| """Submission-root inference entrypoint for OpenEnv gates.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| import time |
| from dataclasses import dataclass |
| from pathlib import Path |
| from statistics import mean |
| from typing import Any, Mapping, Sequence, TextIO |
| from urllib.parse import urlparse |
|
|
|
|
| ENV_ROOT = Path(__file__).resolve().parent |
| PARENT = ENV_ROOT.parent |
| if str(PARENT) not in sys.path: |
| sys.path.insert(0, str(PARENT)) |
|
|
| from openai import OpenAI |
|
|
| from legacy_cobol_env.eval.model_rollout import run_model_repair_rollout, run_model_rollout |
| from legacy_cobol_env.eval.providers import StaticResponseProvider, TextProvider |
| from legacy_cobol_env.server.task_bank import TaskInstance, all_tasks, load_task |
|
|
|
|
| VALID_MARKERS = {"START", "STEP", "END"} |
| STATIC_RESPONSE = '{"code": "def migrate(input_record: str) -> str:\\n return input_record\\n"}' |
|
|
|
|
| @dataclass(frozen=True) |
| class RuntimeConfig: |
| api_base_url: str |
| model_name: str |
| hf_token: str |
| mode: str |
| api_version: str = "2024-12-01-preview" |
|
|
|
|
| def load_runtime_config(env: Mapping[str, str] | None = None, mode: str | None = None) -> RuntimeConfig: |
| values = os.environ if env is None else env |
| selected_mode = mode or values.get("INFERENCE_MODE") or values.get("MODE") or "live" |
|
|
| if selected_mode in {"static", "mock"}: |
| return RuntimeConfig( |
| api_base_url=values.get("API_BASE_URL", ""), |
| model_name=values.get("MODEL_NAME", "static"), |
| hf_token=values.get("HF_TOKEN", ""), |
| mode=selected_mode, |
| api_version=values.get( |
| "API_VERSION", |
| values.get( |
| "OPENAI_API_VERSION", |
| values.get("AZURE_OPENAI_API_VERSION", "2024-12-01-preview"), |
| ), |
| ), |
| ) |
|
|
| required = ["API_BASE_URL", "MODEL_NAME", "HF_TOKEN"] |
| missing = [key for key in required if not values.get(key)] |
| if missing: |
| raise ValueError(f"missing inference configuration: {', '.join(missing)}") |
|
|
| return RuntimeConfig( |
| api_base_url=values["API_BASE_URL"], |
| model_name=values["MODEL_NAME"], |
| hf_token=values["HF_TOKEN"], |
| mode=selected_mode, |
| api_version=values.get( |
| "API_VERSION", |
| values.get("OPENAI_API_VERSION", values.get("AZURE_OPENAI_API_VERSION", "2024-12-01-preview")), |
| ), |
| ) |
|
|
|
|
| def format_event(marker: str, payload: Mapping[str, object]) -> str: |
| if marker not in VALID_MARKERS: |
| raise ValueError(f"invalid log marker: {marker}") |
| data = json.dumps(dict(payload), sort_keys=True, separators=(",", ":")) |
| return f"[{marker}] {data}" |
|
|
|
|
| def build_openai_client(config: RuntimeConfig) -> OpenAI: |
| base_url = config.api_base_url.rstrip("/") |
| if _is_azure_endpoint(base_url): |
| deployment_base = ( |
| base_url |
| if "/openai/deployments/" in base_url |
| else f"{base_url}/openai/deployments/{config.model_name}" |
| ) |
| return OpenAI( |
| base_url=deployment_base, |
| api_key=config.hf_token, |
| default_headers={"api-key": config.hf_token}, |
| default_query={"api-version": config.api_version}, |
| timeout=60.0, |
| ) |
| return OpenAI(base_url=base_url, api_key=config.hf_token, timeout=60.0) |
|
|
|
|
| class OpenAITextProvider: |
| name = "openai-client" |
|
|
| def __init__(self, client: OpenAI, model_name: str) -> None: |
| self._client = client |
| self._model_name = model_name |
|
|
| def generate(self, prompt: str) -> str: |
| response = self._client.chat.completions.create( |
| model=self._model_name, |
| messages=[ |
| { |
| "role": "system", |
| "content": "Return only JSON with a single code field containing a Python migrate function.", |
| }, |
| {"role": "user", "content": prompt}, |
| ], |
| temperature=0, |
| ) |
| return response.choices[0].message.content or "" |
|
|
|
|
| def build_provider(config: RuntimeConfig) -> TextProvider: |
| if config.mode in {"static", "mock"}: |
| return StaticResponseProvider(config.mode, STATIC_RESPONSE) |
| return OpenAITextProvider(build_openai_client(config), config.model_name) |
|
|
|
|
| def run_inference(task_id: str | None, max_repairs: int, config: RuntimeConfig) -> dict[str, object]: |
| tasks = [load_task(task_id=task_id)] if task_id else all_tasks() |
| provider = build_provider(config) |
| results = [_run_task(task, provider, max_repairs) for task in tasks] |
|
|
| return { |
| "mode": config.mode, |
| "model_name": config.model_name, |
| "max_repairs": max_repairs, |
| "task_count": len(results), |
| "mean_public_score": mean(item["score"] for item in results) if results else 0.0, |
| "accepted_count": sum(1 for item in results if item["accepted"]), |
| "results": results, |
| } |
|
|
|
|
| def _run_task(task: TaskInstance, provider: TextProvider, max_repairs: int) -> dict[str, Any]: |
| started = time.perf_counter() |
| try: |
| trajectory = ( |
| run_model_repair_rollout(task=task, provider=provider, max_repairs=max_repairs) |
| if max_repairs > 0 |
| else run_model_rollout(task=task, provider=provider) |
| ) |
| except Exception as exc: |
| return { |
| "task_id": task.task_id, |
| "family_id": task.family_id, |
| "difficulty": task.metadata["difficulty"], |
| "score": 0.0, |
| "accepted": False, |
| "duration_s": round(time.perf_counter() - started, 3), |
| "error": f"{type(exc).__name__}: {exc}", |
| } |
|
|
| final = trajectory["final"] |
| return { |
| "task_id": task.task_id, |
| "family_id": task.family_id, |
| "difficulty": task.metadata["difficulty"], |
| "score": final["public_score"], |
| "accepted": final["accepted"], |
| "visible_pass_rate": trajectory["visible"]["pass_rate"], |
| "duration_s": round(time.perf_counter() - started, 3), |
| "trajectory": trajectory, |
| } |
|
|
|
|
| def write_output(path: str | None, payload: Mapping[str, object]) -> None: |
| if not path: |
| return |
| output_path = Path(path) |
| output_path.parent.mkdir(parents=True, exist_ok=True) |
| output_path.write_text(json.dumps(dict(payload), sort_keys=True, indent=2) + "\n", encoding="utf-8") |
|
|
|
|
| def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: |
| parser = argparse.ArgumentParser(description="Run OpenEnv submission inference gate.") |
| parser.add_argument("--task-id") |
| parser.add_argument("--max-repairs", type=int, default=0) |
| parser.add_argument("--output") |
| parser.add_argument("--mode", choices=["live", "static", "mock"]) |
| return parser.parse_args(argv) |
|
|
|
|
| def main( |
| argv: Sequence[str] | None = None, |
| env: Mapping[str, str] | None = None, |
| stdout: TextIO | None = None, |
| ) -> int: |
| args = parse_args(argv) |
| stream = sys.stdout if stdout is None else stdout |
| config = load_runtime_config(env, mode=args.mode) |
|
|
| print( |
| format_event( |
| "START", |
| { |
| "mode": config.mode, |
| "model_name": config.model_name, |
| "task_id": args.task_id or "all", |
| "max_repairs": args.max_repairs, |
| }, |
| ), |
| file=stream, |
| ) |
| result = run_inference(args.task_id, args.max_repairs, config) |
| for index, task_result in enumerate(result["results"], start=1): |
| step_payload = { |
| "index": index, |
| "task_id": task_result["task_id"], |
| "family_id": task_result["family_id"], |
| "difficulty": task_result["difficulty"], |
| "score": task_result["score"], |
| "accepted": task_result["accepted"], |
| } |
| if "error" in task_result: |
| step_payload["error"] = task_result["error"] |
| print(format_event("STEP", step_payload), file=stream) |
| write_output(args.output, result) |
| print( |
| format_event( |
| "END", |
| { |
| "mode": config.mode, |
| "task_count": result["task_count"], |
| "mean_public_score": result["mean_public_score"], |
| "accepted_count": result["accepted_count"], |
| }, |
| ), |
| file=stream, |
| ) |
| return 0 |
|
|
|
|
| def _is_azure_endpoint(base_url: str) -> bool: |
| host = urlparse(base_url).netloc.lower() |
| return "openai.azure.com" in host or "cognitiveservices.azure.com" in host |
|
|
|
|
| if __name__ == "__main__": |
| raise SystemExit(main()) |
|
|