import logging
from typing import Sequence

import numpy as np
import torch
from transformers import DetrForObjectDetection, DetrImageProcessor

from models.detectors.base import DetectionResult, ObjectDetector


class DetrDetector(ObjectDetector):
    """Object detector backed by the ``facebook/detr-resnet-50`` checkpoint.

    The image processor and model are loaded once at construction time and
    kept on CUDA when it is available, otherwise on the CPU.
    """

    MODEL_NAME = "facebook/detr-resnet-50"

    def __init__(self, score_threshold: float = 0.3) -> None:
        """Load the processor and model and put the model in eval mode.

        Args:
            score_threshold: Threshold forwarded to the processor's
                ``post_process_object_detection`` call in ``predict``.
        """
        self.name = "detr_resnet50"
        self.score_threshold = score_threshold

        if torch.cuda.is_available():
            device_name = "cuda"
        else:
            device_name = "cpu"
        self.device = torch.device(device_name)

        logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
        self.processor = DetrImageProcessor.from_pretrained(self.MODEL_NAME)

        detr = DetrForObjectDetection.from_pretrained(self.MODEL_NAME)
        self.model = detr
        detr.to(self.device)
        detr.eval()

    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
        """Run DETR on a single frame and package the detections.

        Args:
            frame: Image array handed to the processor. Its first two
                dimensions are used as the target size for rescaling boxes
                (presumably height and width -- confirm against callers).
            queries: Accepted but not read by this detector.

        Returns:
            A ``DetectionResult`` holding the boxes (NumPy array), scores and
            label ids (plain lists), and the label names resolved through the
            model config's ``id2label`` mapping.
        """
        encoded = self.processor(images=frame, return_tensors="pt")

        # Every tensor the processor produced must live on the model's device.
        model_inputs = {}
        for field, tensor in encoded.items():
            model_inputs[field] = tensor.to(self.device)

        with torch.no_grad():
            raw_outputs = self.model(**model_inputs)

        # One-entry batch: the post-processor expects one size per image.
        frame_size = frame.shape[:2]
        sizes = torch.tensor([frame_size], device=self.device)

        per_image = self.processor.post_process_object_detection(
            raw_outputs,
            threshold=self.score_threshold,
            target_sizes=sizes,
        )
        detections = per_image[0]

        box_array = detections["boxes"].cpu().numpy()
        score_list = detections["scores"].cpu().tolist()
        label_ids = detections["labels"].cpu().tolist()

        # Ids missing from the config mapping fall back to "class_<id>".
        id_to_name = self.model.config.id2label
        names = [
            id_to_name.get(int(class_id), f"class_{class_id}")
            for class_id in label_ids
        ]

        return DetectionResult(
            boxes=box_array,
            scores=score_list,
            labels=label_ids,
            label_names=names,
        )