Ishangtxl's picture
Sync from GitHub c4e4dad
a537615 verified
"""Submission-root inference entrypoint for OpenEnv gates."""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from statistics import mean
from typing import Any, Mapping, Sequence, TextIO
from urllib.parse import urlparse
# Directory containing this file, and the directory one level above it.
ENV_ROOT = Path(__file__).resolve().parent
PARENT = ENV_ROOT.parent
# Put the parent directory at the front of the import path (once) so the
# `legacy_cobol_env` package imports that follow can be resolved.
if sys.path.count(str(PARENT)) == 0:
    sys.path.insert(0, str(PARENT))
from openai import OpenAI
from legacy_cobol_env.eval.model_rollout import run_model_repair_rollout, run_model_rollout
from legacy_cobol_env.eval.providers import StaticResponseProvider, TextProvider
from legacy_cobol_env.server.task_bank import TaskInstance, all_tasks, load_task
# Markers accepted by format_event for the bracketed log prefix.
VALID_MARKERS = {
    "START",
    "STEP",
    "END",
}
# Canned JSON reply handed to the static provider in "static"/"mock" mode:
# a migrate function that returns its input unchanged.
STATIC_RESPONSE = (
    '{"code": "def migrate(input_record: str) -> str:\\n'
    ' return input_record\\n"}'
)
@dataclass(frozen=True)
class RuntimeConfig:
    """Immutable settings for one inference run.

    The credential is deliberately kept out of ``repr()`` so that printing or
    logging a config object (for example in a traceback or debug statement)
    does not expose it.
    """

    # Base URL handed to the OpenAI client.
    api_base_url: str
    # Model (or Azure deployment) name used for chat completions.
    model_name: str
    # Credential passed to the client as its API key.
    hf_token: str
    # Run mode; "static" and "mock" select the canned provider.
    mode: str
    # Value sent as the ``api-version`` query parameter for Azure endpoints.
    api_version: str = "2024-12-01-preview"

    def __repr__(self) -> str:
        """Return a dataclass-style repr with the token value redacted."""
        # Only reveal whether a token is present, never its value.
        shown_token = "***" if self.hf_token else ""
        return (
            f"{self.__class__.__qualname__}("
            f"api_base_url={self.api_base_url!r}, "
            f"model_name={self.model_name!r}, "
            f"hf_token={shown_token!r}, "
            f"mode={self.mode!r}, "
            f"api_version={self.api_version!r})"
        )
def load_runtime_config(env: Mapping[str, str] | None = None, mode: str | None = None) -> RuntimeConfig:
    """Build the runtime configuration from environment-style settings.

    Args:
        env: Mapping to read settings from; ``os.environ`` is used when None.
        mode: Explicit run mode. When falsy, ``INFERENCE_MODE`` and then
            ``MODE`` are consulted, defaulting to ``"live"``.

    Returns:
        A populated ``RuntimeConfig``.

    Raises:
        ValueError: In any mode other than ``"static"``/``"mock"`` when
            ``API_BASE_URL``, ``MODEL_NAME`` or ``HF_TOKEN`` is missing or
            blank.
    """
    values = os.environ if env is None else env
    selected_mode = mode or values.get("INFERENCE_MODE") or values.get("MODE") or "live"

    # Resolve the API version once for both branches. A blank variable is
    # treated as unset (matching how the mode and required keys are handled)
    # so an empty API_VERSION cannot mask the fallbacks or the default.
    api_version = (
        values.get("API_VERSION")
        or values.get("OPENAI_API_VERSION")
        or values.get("AZURE_OPENAI_API_VERSION")
        or "2024-12-01-preview"
    )

    if selected_mode in {"static", "mock"}:
        # Offline modes need no credentials; everything is optional.
        return RuntimeConfig(
            api_base_url=values.get("API_BASE_URL", ""),
            model_name=values.get("MODEL_NAME") or "static",
            hf_token=values.get("HF_TOKEN", ""),
            mode=selected_mode,
            api_version=api_version,
        )

    required = ["API_BASE_URL", "MODEL_NAME", "HF_TOKEN"]
    missing = [key for key in required if not values.get(key)]
    if missing:
        raise ValueError(f"missing inference configuration: {', '.join(missing)}")
    return RuntimeConfig(
        api_base_url=values["API_BASE_URL"],
        model_name=values["MODEL_NAME"],
        hf_token=values["HF_TOKEN"],
        mode=selected_mode,
        api_version=api_version,
    )
def format_event(marker: str, payload: Mapping[str, object]) -> str:
    """Render one log line as ``[MARKER] <compact JSON>``.

    Args:
        marker: One of the names in ``VALID_MARKERS``.
        payload: Mapping serialised as JSON with sorted keys and no
            whitespace between separators.

    Returns:
        The formatted log line.

    Raises:
        ValueError: If ``marker`` is not a valid marker.
    """
    if marker in VALID_MARKERS:
        encoded = json.dumps(
            dict(payload),
            sort_keys=True,
            separators=(",", ":"),
        )
        return f"[{marker}] {encoded}"
    raise ValueError(f"invalid log marker: {marker}")
def build_openai_client(config: RuntimeConfig) -> OpenAI:
    """Create an ``OpenAI`` client for the configured endpoint.

    Non-Azure endpoints get a plain client on the base URL. Azure endpoints
    are pointed at a deployment path (appended from the model name unless the
    URL already contains one) and additionally receive the ``api-key`` header
    and the ``api-version`` query parameter. Both variants use a 60 second
    timeout.
    """
    endpoint = config.api_base_url.rstrip("/")
    token = config.hf_token

    if not _is_azure_endpoint(endpoint):
        return OpenAI(base_url=endpoint, api_key=token, timeout=60.0)

    # Only append the deployment segment when the caller did not supply it.
    if "/openai/deployments/" not in endpoint:
        endpoint = f"{endpoint}/openai/deployments/{config.model_name}"

    return OpenAI(
        base_url=endpoint,
        api_key=token,
        default_headers={"api-key": token},
        default_query={"api-version": config.api_version},
        timeout=60.0,
    )
class OpenAITextProvider:
    """Text provider backed by a chat-completions client."""

    name = "openai-client"

    def __init__(self, client: OpenAI, model_name: str) -> None:
        """Store the client and the model name used for every request."""
        self._model_name = model_name
        self._client = client

    def generate(self, prompt: str) -> str:
        """Send ``prompt`` as the user message and return the reply text.

        A fixed system message asks for JSON with a single ``code`` field,
        and temperature 0 is requested. An absent or empty reply yields "".
        """
        system_message = {
            "role": "system",
            "content": "Return only JSON with a single code field containing a Python migrate function.",
        }
        user_message = {"role": "user", "content": prompt}
        completion = self._client.chat.completions.create(
            model=self._model_name,
            messages=[system_message, user_message],
            temperature=0,
        )
        reply_text = completion.choices[0].message.content
        return reply_text if reply_text else ""
def build_provider(config: RuntimeConfig) -> TextProvider:
    """Pick the text provider for the configured mode.

    "static" and "mock" use the canned ``STATIC_RESPONSE``; every other mode
    gets a provider wrapping a freshly built OpenAI client.
    """
    uses_canned_reply = config.mode in {"static", "mock"}
    if not uses_canned_reply:
        client = build_openai_client(config)
        return OpenAITextProvider(client, config.model_name)
    return StaticResponseProvider(config.mode, STATIC_RESPONSE)
def run_inference(task_id: str | None, max_repairs: int, config: RuntimeConfig) -> dict[str, object]:
    """Run the selected task (or every task) and summarise the outcomes.

    Args:
        task_id: Task to load; when falsy, all tasks are run.
        max_repairs: Repair budget forwarded to each task run.
        config: Runtime settings used to build the provider.

    Returns:
        A summary dict with the mode, model name, repair budget, task count,
        mean public score (0.0 when there are no results), number of
        accepted tasks, and the per-task result list.
    """
    if task_id:
        selected_tasks = [load_task(task_id=task_id)]
    else:
        selected_tasks = all_tasks()
    provider = build_provider(config)

    outcomes = []
    for task in selected_tasks:
        outcomes.append(_run_task(task, provider, max_repairs))

    scores = [outcome["score"] for outcome in outcomes]
    accepted = [outcome for outcome in outcomes if outcome["accepted"]]
    return {
        "mode": config.mode,
        "model_name": config.model_name,
        "max_repairs": max_repairs,
        "task_count": len(outcomes),
        "mean_public_score": mean(scores) if scores else 0.0,
        "accepted_count": len(accepted),
        "results": outcomes,
    }
def _run_task(task: TaskInstance, provider: TextProvider, max_repairs: int) -> dict[str, Any]:
    """Run one task and return its result record.

    A repair rollout is used when ``max_repairs`` is positive, otherwise a
    plain rollout. Any exception raised by the rollout is converted into a
    zero-score, not-accepted record carrying an ``error`` string; successful
    runs include the score, acceptance flag, visible pass rate and the full
    trajectory. Both records carry the elapsed time rounded to milliseconds.
    """
    clock_start = time.perf_counter()

    def elapsed() -> float:
        return round(time.perf_counter() - clock_start, 3)

    def identity() -> dict[str, Any]:
        # Fields shared by the success and failure records.
        return {
            "task_id": task.task_id,
            "family_id": task.family_id,
            "difficulty": task.metadata["difficulty"],
        }

    try:
        if max_repairs > 0:
            trajectory = run_model_repair_rollout(task=task, provider=provider, max_repairs=max_repairs)
        else:
            trajectory = run_model_rollout(task=task, provider=provider)
    except Exception as exc:
        failure = identity()
        failure["score"] = 0.0
        failure["accepted"] = False
        failure["duration_s"] = elapsed()
        failure["error"] = f"{type(exc).__name__}: {exc}"
        return failure

    final = trajectory["final"]
    record = identity()
    record["score"] = final["public_score"]
    record["accepted"] = final["accepted"]
    record["visible_pass_rate"] = trajectory["visible"]["pass_rate"]
    record["duration_s"] = elapsed()
    record["trajectory"] = trajectory
    return record
def write_output(path: str | None, payload: Mapping[str, object]) -> None:
    """Write ``payload`` as indented, key-sorted JSON to ``path``.

    Nothing is written when ``path`` is None or empty. Missing parent
    directories are created, and the file ends with a trailing newline.
    """
    if path:
        destination = Path(path)
        destination.parent.mkdir(parents=True, exist_ok=True)
        serialized = json.dumps(dict(payload), sort_keys=True, indent=2)
        destination.write_text(f"{serialized}\n", encoding="utf-8")
def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for the inference gate.

    Args:
        argv: Argument list to parse; argparse falls back to ``sys.argv``
            when None.

    Returns:
        Namespace with ``task_id``, ``max_repairs`` (int, default 0),
        ``output`` and ``mode`` (one of live/static/mock, or None).
    """
    cli = argparse.ArgumentParser(description="Run OpenEnv submission inference gate.")
    option_specs = (
        ("--task-id", {}),
        ("--max-repairs", {"type": int, "default": 0}),
        ("--output", {}),
        ("--mode", {"choices": ["live", "static", "mock"]}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args(argv)
def main(
    argv: Sequence[str] | None = None,
    env: Mapping[str, str] | None = None,
    stdout: TextIO | None = None,
) -> int:
    """Run the inference gate and emit START / STEP / END log lines.

    Args:
        argv: Command-line arguments; forwarded to ``parse_args``.
        env: Settings mapping; forwarded to ``load_runtime_config``.
        stdout: Stream for log lines; ``sys.stdout`` when None.

    Returns:
        0 once all events have been written.
    """
    args = parse_args(argv)
    stream = stdout if stdout is not None else sys.stdout
    config = load_runtime_config(env, mode=args.mode)

    def emit(marker: str, payload: Mapping[str, object]) -> None:
        print(format_event(marker, payload), file=stream)

    emit(
        "START",
        {
            "mode": config.mode,
            "model_name": config.model_name,
            "task_id": args.task_id or "all",
            "max_repairs": args.max_repairs,
        },
    )

    result = run_inference(args.task_id, args.max_repairs, config)

    # One STEP line per task, numbered from 1; the error text is only
    # attached for tasks whose result record carries one.
    copied_fields = ("task_id", "family_id", "difficulty", "score", "accepted")
    for index, task_result in enumerate(result["results"], start=1):
        step_payload = {"index": index}
        for field_name in copied_fields:
            step_payload[field_name] = task_result[field_name]
        if "error" in task_result:
            step_payload["error"] = task_result["error"]
        emit("STEP", step_payload)

    write_output(args.output, result)

    emit(
        "END",
        {
            "mode": config.mode,
            "task_count": result["task_count"],
            "mean_public_score": result["mean_public_score"],
            "accepted_count": result["accepted_count"],
        },
    )
    return 0
def _is_azure_endpoint(base_url: str) -> bool:
    """Return True when ``base_url`` points at an Azure OpenAI host.

    The host must be, or be a subdomain of, ``openai.azure.com`` or
    ``cognitiveservices.azure.com``. Matching uses the parsed hostname with a
    suffix test rather than a substring test on the whole network location,
    so user-info, ports and look-alike hosts such as
    ``openai.azure.com.example.org`` are not mistaken for Azure.
    """
    # ``hostname`` is already lower-cased and excludes credentials and port;
    # it is None when the URL has no network location. A trailing dot
    # (fully-qualified form) is ignored.
    host = (urlparse(base_url).hostname or "").rstrip(".")
    azure_domains = ("openai.azure.com", "cognitiveservices.azure.com")
    return any(host == domain or host.endswith(f".{domain}") for domain in azure_domains)
# Script entry point: exit the process with main()'s return code.
if __name__ == "__main__":
    sys.exit(main())