# turing-space/turing/modeling/model_selector.py
# Deploy FastAPI ML service to Hugging Face Spaces (commit 5fc6e5d)
from typing import Optional
from loguru import logger
from mlflow.tracking import MlflowClient
def get_best_model_by_tag(
    language: str,
    tag_key: str = "best_model",
    metric: str = "f1_score"
) -> Optional[dict]:
    """
    Look up the run explicitly tagged as the best model for a language.

    Searches every MLflow experiment for runs carrying ``tags.<tag_key> = 'true'``
    and a matching ``tags.Language``, keeping the single run with the highest
    value of *metric*.

    Args:
        language: Programming language (java, python, pharo).
        tag_key: Tag key marking the preferred run (default: "best_model").
        metric: Metric used to rank matching runs (default: "f1_score").

    Returns:
        Dict with run_id, artifact and model_id of the best run, or None when
        no experiments/runs match or the MLflow query fails.
    """
    mlflow_client = MlflowClient()
    all_experiments = mlflow_client.search_experiments()
    if not all_experiments:
        logger.error("No experiments found in MLflow")
        return None
    try:
        # Single query across every experiment; max_results=1 with a DESC
        # ordering yields only the top-scoring tagged run.
        top_runs = mlflow_client.search_runs(
            experiment_ids=[experiment.experiment_id for experiment in all_experiments],
            filter_string=f"tags.{tag_key} = 'true' and tags.Language = '{language}'",
            order_by=[f"metrics.{metric} DESC"],
            max_results=1
        )
        if not top_runs:
            logger.warning(f"No runs found with tag '{tag_key}' for language '{language}'")
            return None
        winner = top_runs[0]
        winner_id = winner.info.run_id
        experiment_name = mlflow_client.get_experiment(winner.info.experiment_id).name
        winner_run_name = winner.info.run_name
        # Artifact/model identifiers live in run tags; .get returns None when absent.
        artifact = winner.data.tags.get("model_name")
        model_identifier = winner.data.tags.get("model_id")
        logger.info(f"Found best model for {language}: {experiment_name}/{winner_run_name} ({winner_id}), artifact={artifact}")
        return {"run_id": winner_id, "artifact": artifact, "model_id": model_identifier}
    except Exception as exc:
        # Boundary handler: any MLflow/backend failure degrades to "not found".
        logger.error(f"Error searching for best model: {exc}")
        return None
def get_best_model_info(
    language: str,
    fallback_registry: Optional[dict] = None
) -> dict:
    """
    Resolve the best model information for a language.

    Resolution order:
      1. Run explicitly tagged as the best model ("best_model" tag).
      2. Caller-provided fallback registry, if it covers *language*.
      3. Highest-metric run across all experiments for *language*.

    Args:
        language: Programming language (java, python, pharo).
        fallback_registry: Optional mapping of language -> dict with run_id
            and artifact, used when no tagged run exists.

    Returns:
        Dict with run_id and artifact of the model (plus model_id when
        resolved via MLflow).

    Raises:
        ValueError: If no model can be found by any of the strategies.
    """
    model_info = get_best_model_by_tag(language, "best_model")
    if model_info:
        logger.info(f"Using tagged best model for {language}")
        return model_info
    # No tagged run: prefer the explicit registry over a metric-based guess.
    if fallback_registry and language in fallback_registry:
        logger.warning(f"No tagged model found for {language}, using fallback registry")
        return fallback_registry[language]
    model_info = get_best_model_by_metric(language)
    if model_info:
        logger.warning(f"Using best model by metric for {language}")
        return model_info
    raise ValueError(f"No model found for language {language}")
def get_best_model_by_metric(
    language: str,
    metric: str = "f1_score"
) -> Optional[dict]:
    """
    Find the top-ranked run for a language purely by metric value.

    Unlike the tag-based lookup, this considers every run whose
    ``tags.Language`` matches, regardless of any "best model" tag.

    Args:
        language: Programming language (java, python, pharo).
        metric: Metric used to rank runs (default: "f1_score").

    Returns:
        Dict with run_id, artifact and model_id of the top run, or None when
        no experiments/runs exist or the MLflow query fails.
    """
    mlflow_client = MlflowClient()
    all_experiments = mlflow_client.search_experiments()
    if not all_experiments:
        logger.error("No experiments found in MLflow")
        return None
    try:
        # DESC ordering + max_results=1 makes MLflow return only the top run.
        ranked_runs = mlflow_client.search_runs(
            experiment_ids=[experiment.experiment_id for experiment in all_experiments],
            filter_string=f"tags.Language = '{language}'",
            order_by=[f"metrics.{metric} DESC"],
            max_results=1
        )
        if not ranked_runs:
            logger.warning(f"No runs found for language '{language}'")
            return None
        winner = ranked_runs[0]
        winner_id = winner.info.run_id
        experiment_name = mlflow_client.get_experiment(winner.info.experiment_id).name
        winner_run_name = winner.info.run_name
        # Artifact/model identifiers live in run tags; .get returns None when absent.
        artifact = winner.data.tags.get("model_name")
        model_identifier = winner.data.tags.get("model_id")
        logger.info(f"Found best model for {language}: {experiment_name}/{winner_run_name} ({winner_id}), artifact={artifact}")
        return {"run_id": winner_id, "artifact": artifact, "model_id": model_identifier}
    except Exception as exc:
        # Boundary handler: any MLflow/backend failure degrades to "not found".
        logger.error(f"Error finding best model by metric: {exc}")
        return None