# zye0616's picture
# Update: Add mission pipeline registry
# 34b56b2
import logging
from typing import Sequence
import numpy as np
import torch
from transformers import DetrForObjectDetection, DetrImageProcessor
from models.detectors.base import DetectionResult, ObjectDetector
class DetrDetector(ObjectDetector):
    """Wrapper around facebook/detr-resnet-50 for mission-aligned detection.

    The model predicts integer class ids, which are mapped to names through
    ``self.model.config.id2label``. Free-text ``queries`` passed to
    :meth:`predict` are not used by this detector.
    """

    MODEL_NAME = "facebook/detr-resnet-50"

    def __init__(self, score_threshold: float = 0.3) -> None:
        """Load the DETR image processor and model onto the available device.

        Args:
            score_threshold: Confidence threshold forwarded to the processor's
                ``post_process_object_detection``; detections scoring below
                it are dropped.
        """
        self.name = "detr_resnet50"
        self.score_threshold = score_threshold
        # Use CUDA when torch can see a GPU, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Named module logger rather than the root logger, so log config can
        # target this module specifically.
        logging.getLogger(__name__).info(
            "Loading %s onto %s", self.MODEL_NAME, self.device
        )
        self.processor = DetrImageProcessor.from_pretrained(self.MODEL_NAME)
        self.model = DetrForObjectDetection.from_pretrained(self.MODEL_NAME)
        self.model.to(self.device)
        # Inference-only usage: switch modules (e.g. dropout) to eval behavior.
        self.model.eval()

    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
        """Detect objects in a single frame.

        Args:
            frame: Image array. ``frame.shape[:2]`` is used as the
                (height, width) target size when rescaling boxes, so this
                assumes a channels-last (H, W[, C]) layout. TODO: confirm
                with callers.
            queries: Not used. This detector only reports classes from the
                model's ``id2label`` mapping. The parameter is presumably kept
                to satisfy the ``ObjectDetector`` interface.

        Returns:
            A ``DetectionResult`` holding:
            - ``boxes``: a numpy array rescaled to ``frame``'s size. Per the
              transformers docs these are absolute (x_min, y_min, x_max, y_max).
            - ``scores``: a list of floats.
            - ``labels``: a list of integer class ids.
            - ``label_names``: the human-readable name for each id.
        """
        inputs = self.processor(images=frame, return_tensors="pt")
        inputs = {key: value.to(self.device) for key, value in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)

        # One (height, width) pair per image in the batch; the batch has one image.
        target_sizes = torch.tensor([frame.shape[:2]], device=self.device)
        processed = self.processor.post_process_object_detection(
            outputs,
            threshold=self.score_threshold,
            target_sizes=target_sizes,
        )[0]

        boxes = processed["boxes"].cpu().numpy()
        scores = processed["scores"].cpu().tolist()
        labels = processed["labels"].cpu().tolist()
        # Hoist the mapping lookup out of the comprehension. Unknown ids get
        # a synthetic "class_<id>" name instead of raising.
        id2label = self.model.config.id2label
        label_names = [id2label.get(int(idx), f"class_{idx}") for idx in labels]
        return DetectionResult(
            boxes=boxes,
            scores=scores,
            labels=labels,
            label_names=label_names,
        )