File size: 4,652 Bytes
2ab6e0a
a8d3381
2ab6e0a
 
 
a8d3381
 
 
34b56b2
2ab6e0a
 
 
 
 
 
 
 
 
 
 
 
 
a8d3381
 
 
 
 
c57c49d
a8d3381
 
 
c57c49d
 
2ab6e0a
c57c49d
 
 
 
 
a8d3381
 
 
 
 
 
 
 
 
 
c57c49d
 
 
 
 
 
a8d3381
 
 
c57c49d
 
 
2ab6e0a
a8d3381
2ab6e0a
a8d3381
2ab6e0a
 
 
 
51780f2
a8d3381
2ab6e0a
c57c49d
51780f2
 
5e3ba22
51780f2
2ab6e0a
 
 
 
 
 
5e3ba22
 
 
 
 
 
34b56b2
 
 
 
 
 
 
 
 
 
 
 
a8d3381
2ab6e0a
a8d3381
2ab6e0a
 
 
 
34b56b2
a8d3381
2ab6e0a
 
51780f2
 
 
 
 
 
 
 
5e3ba22
51780f2
5e3ba22
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import logging
from typing import Any, Dict, List, Optional, Sequence, Tuple

import cv2
import numpy as np
from models.model_loader import load_detector
from mission_planner import MissionPlan, get_mission_plan
from mission_summarizer import summarize_results
from pipeline_registry import get_pipeline_spec
from utils.video import extract_frames, write_video


def draw_boxes(
    frame: np.ndarray,
    boxes: np.ndarray,
    color: Tuple[int, int, int] = (0, 255, 0),
    thickness: int = 2,
) -> np.ndarray:
    """Return a copy of *frame* with one rectangle drawn per detection box.

    Args:
        frame: BGR image array as used by OpenCV.
        boxes: Iterable of (x1, y1, x2, y2) coordinates; may be None.
        color: BGR rectangle color (default green, matching prior behavior).
        thickness: Rectangle line thickness in pixels (default 2).

    Returns:
        A new array; the input frame is never mutated.
    """
    output = frame.copy()
    # A detector may legitimately return no boxes; just hand back the copy.
    if boxes is None:
        return output
    for box in boxes:
        x1, y1, x2, y2 = (int(coord) for coord in box)
        cv2.rectangle(output, (x1, y1), (x2, y2), color, thickness=thickness)
    return output


def _build_detection_records(
    boxes: np.ndarray,
    scores: Sequence[float],
    labels: Sequence[int],
    queries: Sequence[str],
    label_names: Optional[Sequence[str]] = None,
) -> List[Dict[str, Any]]:
    detections: List[Dict[str, Any]] = []
    for idx, box in enumerate(boxes):
        if label_names is not None and idx < len(label_names):
            label = label_names[idx]
        else:
            label_idx = int(labels[idx]) if idx < len(labels) else -1
            if 0 <= label_idx < len(queries):
                label = queries[label_idx]
            else:
                label = f"label_{label_idx}"
        detections.append(
            {
                "label": label,
                "score": float(scores[idx]) if idx < len(scores) else 0.0,
                "bbox": [int(coord) for coord in box.tolist()],
            }
        )
    return detections


def infer_frame(
    frame: np.ndarray,
    queries: Sequence[str],
    detector_name: Optional[str] = None,
) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
    """Run the detector on one frame and return (annotated_frame, records).

    The query list is coerced to a non-empty list (falling back to
    ["object"] when empty). Any failure during prediction or record
    building is logged with the active queries and re-raised unchanged.
    """
    detector = load_detector(detector_name)
    text_queries = list(queries) or ["object"]
    try:
        result = detector.predict(frame, text_queries)
        records = _build_detection_records(
            result.boxes,
            result.scores,
            result.labels,
            text_queries,
            result.label_names,
        )
    except Exception:
        logging.exception("Inference failed for queries %s", text_queries)
        raise
    annotated = draw_boxes(frame, result.boxes)
    return annotated, records


def _resolve_plan_detector(plan: MissionPlan) -> Optional[str]:
    """Return the first detector_key bound to the plan's primary pipeline, or None."""
    if not (plan.pipeline and plan.pipeline.primary_id):
        return None
    spec = get_pipeline_spec(plan.pipeline.primary_id)
    if not spec:
        return None
    hf_bindings = spec.get("huggingface") or {}
    for entry in hf_bindings.get("detection") or []:
        detector_key = entry.get("detector_key")
        if detector_key:
            return detector_key
    return None


def run_inference(
    input_video_path: str,
    output_video_path: Optional[str],
    mission_prompt: str,
    max_frames: Optional[int] = None,
    detector_name: Optional[str] = None,
    write_output_video: bool = True,
    generate_summary: bool = True,
    mission_plan: Optional[MissionPlan] = None,
) -> Tuple[Optional[str], MissionPlan, Optional[str]]:
    """Run detection over a video according to a mission prompt/plan.

    Args:
        input_video_path: Path to the video to decode.
        output_video_path: Destination for the annotated video; required
            when write_output_video is True.
        mission_prompt: Natural-language mission; must be non-empty.
        max_frames: If set, stop after processing this many frames.
        detector_name: Explicit detector override; when None, the plan's
            pipeline binding (if any) is used.
        write_output_video: When False, no video is written and the first
            return value is None.
        generate_summary: When False, the summary return value is None.
        mission_plan: Pre-resolved plan; when None one is derived from the
            prompt via get_mission_plan.

    Returns:
        (output video path or None, resolved MissionPlan, summary or None).

    Raises:
        ValueError: empty mission prompt, undecodable video, or missing
            output_video_path when write_output_video is True.
    """
    # Fail fast on an empty prompt before paying the cost of decoding video.
    mission_prompt_clean = (mission_prompt or "").strip()
    if not mission_prompt_clean:
        raise ValueError("Mission prompt is required.")

    try:
        frames, fps, width, height = extract_frames(input_video_path)
    except ValueError:
        logging.exception("Failed to decode video at %s", input_video_path)
        raise

    resolved_plan = mission_plan or get_mission_plan(mission_prompt_clean)
    logging.info("Mission plan: %s", resolved_plan.to_json())
    queries = resolved_plan.queries()
    # An explicit detector_name wins over the plan's pipeline binding.
    active_detector = detector_name or _resolve_plan_detector(resolved_plan)

    processed_frames: List[np.ndarray] = []
    detection_log: List[Dict[str, Any]] = []
    for idx, frame in enumerate(frames):
        if max_frames is not None and idx >= max_frames:
            break
        logging.debug("Processing frame %d", idx)
        processed_frame, frame_detections = infer_frame(
            frame, queries, detector_name=active_detector
        )
        detection_log.append({"frame_index": idx, "detections": frame_detections})
        processed_frames.append(processed_frame)

    if write_output_video:
        if not output_video_path:
            raise ValueError("output_video_path is required when write_output_video=True.")
        write_video(processed_frames, output_video_path, fps=fps, width=width, height=height)
        video_path_result: Optional[str] = output_video_path
    else:
        video_path_result = None

    mission_summary = (
        summarize_results(mission_prompt_clean, resolved_plan, detection_log)
        if generate_summary
        else None
    )
    return video_path_result, resolved_plan, mission_summary