from __future__ import annotations import json import logging from dataclasses import asdict, dataclass, replace from datetime import datetime from typing import Any, Dict, List, Mapping, Tuple from coco_classes import canonicalize_coco_name, coco_class_catalog from mission_context import ( MissionClass, MissionContext, MissionPlan, MISSION_TYPE_OPTIONS, LOCATION_TYPE_OPTIONS, TIME_OF_DAY_OPTIONS, PRIORITY_LEVEL_OPTIONS, PipelineRecommendation, build_prompt_hints, ) from pipeline_registry import ( PIPELINE_SPECS, fallback_pipeline_for_context, filter_pipelines_for_context, get_pipeline_spec, ) from prompt import mission_planner_system_prompt, mission_planner_user_prompt from utils.openai_client import get_openai_client DEFAULT_OPENAI_MODEL = "gpt-4o-mini" class MissionReasoner: def __init__( self, *, model_name: str = DEFAULT_OPENAI_MODEL, top_k: int = 10, ) -> None: self._model_name = model_name self._top_k = top_k self._coco_catalog = coco_class_catalog() def plan( self, mission: str, *, context: MissionContext, cues: Mapping[str, Any] | None = None, ) -> MissionPlan: mission = (mission or "").strip() if not mission: raise ValueError("Mission prompt cannot be empty.") available_pipelines = self._candidate_pipelines(mission, context, cues) candidate_ids = [spec["id"] for spec in available_pipelines] or [PIPELINE_SPECS[0]["id"]] lock_pipeline_id = candidate_ids[0] if len(candidate_ids) == 1 else None response_payload = self._query_llm( mission, context=context, cues=None, pipeline_ids=candidate_ids, ) relevant = self._parse_plan(response_payload, fallback_mission=mission) enriched_context = self._merge_context(context, response_payload.get("context")) if lock_pipeline_id: pipeline_rec = PipelineRecommendation( primary_id=lock_pipeline_id, primary_reason="Only pipeline compatible with mission context.", ) else: pipeline_rec = self._parse_pipeline_recommendation( response_payload.get("pipelines") or response_payload.get("pipeline"), available_pipelines, context, ) return MissionPlan( mission=response_payload.get("mission", mission), relevant_classes=relevant[: self._top_k], context=enriched_context, pipeline=pipeline_rec, ) def _render_pipeline_catalog(self, specs: List[Dict[str, object]]) -> str: if not specs: return "No compatible pipelines available." sections: List[str] = [] for spec in specs: reason = spec.get("availability_reason") or "Compatible with mission context." hf_bindings = spec.get("huggingface") or {} def _format_models(models: List[Dict[str, object]]) -> str: if not models: return "none" labels = [] for entry in models: model_id = entry.get("model_id") or entry.get("name") or "unknown" label = entry.get("label") or model_id suffix = " (optional)" if entry.get("optional") else "" labels.append(f"{label}{suffix}") return ", ".join(labels) detection_models = _format_models(hf_bindings.get("detection", [])) segmentation_models = _format_models(hf_bindings.get("segmentation", [])) tracking_models = _format_models(hf_bindings.get("tracking", [])) hf_notes = hf_bindings.get("notes") or "" sections.append( "\n".join( [ f"{spec['id']} pipeline", f" Modalities: {', '.join(spec.get('modalities', ())) or 'unspecified'}", f" Locations: {', '.join(spec.get('location_types', ())) or 'any'}", f" Time of day: {', '.join(spec.get('time_of_day', ())) or 'any'}", f" Availability: {reason}", f" HF detection: {detection_models}", f" HF segmentation: {segmentation_models}", f" Tracking: {tracking_models}", f" Notes: {hf_notes or 'n/a'}", ] ) ) return "\n\n".join(sections) def _candidate_pipelines( self, mission: str, context: MissionContext, cues: Mapping[str, Any] | None, ) -> List[Dict[str, object]]: filtered = filter_pipelines_for_context(context) if filtered: return filtered fallback_spec = fallback_pipeline_for_context(context, []) if fallback_spec is None: logging.error("No fallback pipeline available; mission context=%s", context) return [dict(spec) for spec in PIPELINE_SPECS] logging.warning( "No compatible pipelines for context %s; selecting fallback %s.", context, fallback_spec["id"], ) fallback_copy = dict(fallback_spec) fallback_copy["availability_reason"] = ( "Fallback engaged because no specialized pipeline matched this mission context." ) return [fallback_copy] def _query_llm( self, mission: str, *, context: MissionContext, cues: Mapping[str, Any] | None = None, pipeline_ids: List[str] | None, ) -> Dict[str, object]: client = get_openai_client() system_prompt = mission_planner_system_prompt() context_payload = context.to_prompt_payload() user_prompt = mission_planner_user_prompt( mission, self._top_k, context=context_payload, cues=cues, pipeline_candidates=pipeline_ids, coco_catalog=self._coco_catalog, ) completion = client.chat.completions.create( model=self._model_name, temperature=0.1, response_format={"type": "json_object"}, messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ], ) content = completion.choices[0].message.content or "{}" try: return json.loads(content) except json.JSONDecodeError: logging.exception("LLM returned non-JSON content: %s", content) return {"mission": mission, "classes": []} def _parse_plan(self, payload: Dict[str, object], fallback_mission: str) -> List[MissionClass]: entries = ( payload.get("entities") or payload.get("classes") or payload.get("relevant_classes") or [] ) mission = payload.get("mission") or fallback_mission parsed: List[MissionClass] = [] seen = set() for entry in entries: if not isinstance(entry, dict): continue name = str(entry.get("name") or "").strip() if not name: continue canonical_name = canonicalize_coco_name(name) if not canonical_name: logging.warning("Skipping non-COCO entity '%s'.", name) continue if canonical_name in seen: continue seen.add(canonical_name) score_raw = entry.get("score") try: score = float(score_raw) except (TypeError, ValueError): score = 0.5 rationale = str(entry.get("rationale") or f"Track '{name}' for mission '{mission}'.") parsed.append( MissionClass( name=canonical_name, score=max(0.0, min(1.0, score)), rationale=rationale, ) ) if not parsed: raise RuntimeError("LLM returned no semantic entities; aborting instead of fabricating outputs.") return parsed def _merge_context( self, base_context: MissionContext, context_payload: Dict[str, object] | None, ) -> MissionContext: payload = context_payload or {} if not isinstance(payload, dict): return base_context def _coerce_choice(value: object | None, allowed: Tuple[str, ...]) -> str | None: if value is None: return None candidate = str(value).strip().lower() return candidate if candidate in allowed else None updates: Dict[str, Any] = {} new_mission_type = _coerce_choice(payload.get("mission_type"), MISSION_TYPE_OPTIONS) new_location_type = _coerce_choice(payload.get("location_type"), LOCATION_TYPE_OPTIONS) new_time_of_day = _coerce_choice(payload.get("time_of_day"), TIME_OF_DAY_OPTIONS) new_priority = _coerce_choice(payload.get("priority_level"), PRIORITY_LEVEL_OPTIONS) if new_mission_type: updates["mission_type"] = new_mission_type if new_location_type: updates["location_type"] = new_location_type if new_time_of_day: updates["time_of_day"] = new_time_of_day if new_priority: updates["priority_level"] = new_priority if not updates: return base_context return replace(base_context, **updates) def _parse_pipeline_recommendation( self, payload: object, available_specs: List[Dict[str, object]], context: MissionContext, ) -> PipelineRecommendation | None: if not isinstance(payload, dict): return self._validate_pipeline_selection(None, available_specs, context) if "id" in payload or "pipeline_id" in payload or "pipeline" in payload: pipeline_id_raw = payload.get("id") or payload.get("pipeline_id") or payload.get("pipeline") pipeline_id = str(pipeline_id_raw or "").strip() reason = str(payload.get("reason") or "").strip() or None candidate = PipelineRecommendation(primary_id=pipeline_id or None, primary_reason=reason) return self._validate_pipeline_selection(candidate, available_specs, context) def _extract_entry(entry_key: str) -> tuple[str | None, str | None]: value = payload.get(entry_key) if not isinstance(value, dict): return None, None pipeline_id_raw = value.get("id") or value.get("pipeline_id") or value.get("pipeline") pipeline_id = str(pipeline_id_raw).strip() if not pipeline_id: return None, None if not get_pipeline_spec(pipeline_id): return None, None reason = str(value.get("reason") or "").strip() or None return pipeline_id, reason primary_id, primary_reason = _extract_entry("primary") fallback_id, fallback_reason = _extract_entry("fallback") rec = PipelineRecommendation( primary_id=primary_id, primary_reason=primary_reason, fallback_id=fallback_id, fallback_reason=fallback_reason, ) return self._validate_pipeline_selection(rec, available_specs, context) def _validate_pipeline_selection( self, candidate: PipelineRecommendation | None, available_specs: List[Dict[str, object]], context: MissionContext, ) -> PipelineRecommendation | None: if not available_specs: return None available_ids = {spec["id"] for spec in available_specs} def _normalize_reason(reason: str | None, default: str) -> str: text = (reason or "").strip() return text or default primary_id = candidate.primary_id if candidate and candidate.primary_id in available_ids else None if not primary_id: fallback_spec = fallback_pipeline_for_context(context, available_specs) if fallback_spec is None: logging.warning("No pipelines available even after fallback.") return None logging.warning( "Pipeline recommendation invalid or missing. Defaulting to %s.", fallback_spec["id"] ) return PipelineRecommendation( primary_id=fallback_spec["id"], primary_reason=_normalize_reason( candidate.primary_reason if candidate else None, "Auto-selected based on available sensors and context.", ), fallback_id=None, fallback_reason=None, ) primary_reason = _normalize_reason(candidate.primary_reason if candidate else None, "LLM-selected.") fallback_allowed = context.priority_level in {"elevated", "high"} fallback_id = candidate.fallback_id if candidate else None fallback_reason = candidate.fallback_reason if candidate else None if not fallback_allowed or fallback_id not in available_ids or fallback_id == primary_id: if fallback_id: logging.info("Dropping fallback pipeline %s due to priority/context constraints.", fallback_id) fallback_id_valid = None fallback_reason_valid = None else: fallback_id_valid = fallback_id fallback_reason_valid = _normalize_reason(fallback_reason, "Fallback allowed due to priority level.") return PipelineRecommendation( primary_id=primary_id, primary_reason=primary_reason, fallback_id=fallback_id_valid, fallback_reason=fallback_reason_valid, ) _REASONER: MissionReasoner | None = None def get_mission_plan( mission: str, *, latitude: float | None = None, longitude: float | None = None, context_overrides: MissionContext | None = None, ) -> MissionPlan: global _REASONER if _REASONER is None: _REASONER = MissionReasoner() context = context_overrides or MissionContext() cues = build_prompt_hints(mission, latitude, longitude) if latitude is not None and longitude is not None: logging.info("Mission location coordinates: lat=%s, lon=%s", latitude, longitude) local_time_hint = cues.get("local_time") if isinstance(cues, Mapping) else None if local_time_hint: logging.info("Derived local mission time: %s", local_time_hint) timezone_hint = cues.get("timezone") if isinstance(cues, Mapping) else None if timezone_hint: logging.info("Derived local timezone: %s", timezone_hint) locality_hint = cues.get("nearest_locality") if isinstance(cues, Mapping) else None if locality_hint: logging.info("Reverse geocoded locality: %s", locality_hint) inferred_time = _infer_time_of_day_from_cues(context, cues) if inferred_time and context.time_of_day != inferred_time: context = replace(context, time_of_day=inferred_time) return _REASONER.plan(mission, context=context, cues=cues) def _infer_time_of_day_from_cues(context: MissionContext, cues: Mapping[str, Any] | None) -> str | None: if context.time_of_day or not cues: return context.time_of_day local_time_raw = cues.get("local_time") if isinstance(cues, Mapping) else None if not local_time_raw: return None try: local_dt = datetime.fromisoformat(str(local_time_raw)) except (ValueError, TypeError): return None hour = local_dt.hour return "day" if 6 <= hour < 18 else "night"