Spaces:
Runtime error
Runtime error
| from typing import Optional | |
| from loguru import logger | |
| from mlflow.tracking import MlflowClient | |
def get_best_model_by_tag(
    language: str,
    tag_key: str = "best_model",
    metric: str = "f1_score"
) -> Optional[dict]:
    """
    Retrieve the best model for a specific language using MLflow tags.

    Searches the runs of every experiment returned by
    ``MlflowClient.search_experiments()``. It keeps runs tagged
    ``<tag_key> = 'true'`` whose ``Language`` tag equals ``language``, and
    returns the one with the highest value of ``metric``.

    Args:
        language: Programming language (java, python, pharo)
        tag_key: Tag key to search for (default: "best_model")
        metric: Metric to use for ordering (default: "f1_score")

    Returns:
        Dict with keys ``run_id``, ``artifact`` (the run's ``model_name`` tag)
        and ``model_id`` (the run's ``model_id`` tag). Either tag value may be
        None if the run lacks that tag. Returns None if no experiments or
        matching runs exist, or if any MLflow call fails.
    """
    try:
        # Creating the client and listing experiments also hit the tracking
        # server. They live inside the try so that a server failure yields
        # None, like every other lookup failure, instead of raising past the
        # caller's fallback logic.
        client = MlflowClient()
        experiments = client.search_experiments()
        if not experiments:
            logger.error("No experiments found in MLflow")
            return None

        # NOTE(review): tag_key and language are interpolated into the filter
        # unescaped. A quote in `language`, or special characters in `tag_key`,
        # would likely produce an invalid filter; the resulting error is
        # caught below. Confirm the inputs are always simple identifiers.
        runs = client.search_runs(
            experiment_ids=[exp.experiment_id for exp in experiments],
            filter_string=f"tags.{tag_key} = 'true' and tags.Language = '{language}'",
            order_by=[f"metrics.{metric} DESC"],
            max_results=1
        )
        if not runs:
            logger.warning(f"No runs found with tag '{tag_key}' for language '{language}'")
            return None

        best_run = runs[0]
        run_id = best_run.info.run_id
        exp_name = client.get_experiment(best_run.info.experiment_id).name
        run_name = best_run.info.run_name
        artifact_name = best_run.data.tags.get("model_name")
        model_id = best_run.data.tags.get("model_id")
        logger.info(f"Found best model for {language}: {exp_name}/{run_name} ({run_id}), artifact={artifact_name}")
        return {
            "run_id": run_id,
            "artifact": artifact_name,
            "model_id": model_id
        }
    except Exception as e:
        # Best-effort lookup: report the failure with its traceback, then let
        # the caller fall back to other strategies.
        logger.exception(f"Error searching for best model: {e}")
        return None
def get_best_model_info(
    language: str,
    fallback_registry: Optional[dict] = None
) -> dict:
    """
    Retrieve the best model information for a language.

    Strategies are tried in order:
      1. the run tagged ``best_model = 'true'`` for the language
         (``get_best_model_by_tag``);
      2. ``fallback_registry[language]``, if a registry is given and has
         an entry for the language;
      3. the run with the highest metric for the language
         (``get_best_model_by_metric`` with its default metric).

    Args:
        language: Programming language
        fallback_registry: Optional mapping of language -> model info dict
            (run_id/artifact, ...) used when no tagged model is found

    Returns:
        Dict describing the model (``run_id``, ``artifact``, ...). On the
        registry path it is whatever ``fallback_registry[language]`` holds.

    Raises:
        ValueError: if none of the strategies yields a model.
    """
    try:
        model_info = get_best_model_by_tag(language, "best_model")
    except Exception:
        # Defensive: a failure inside the tag lookup (e.g. the MLflow server
        # being unreachable) must not prevent the registry/metric fallbacks.
        logger.exception(f"Tag-based model lookup failed for {language}")
        model_info = None
    if model_info:
        logger.info(f"Using tagged best model for {language}")
        return model_info

    if fallback_registry and language in fallback_registry:
        logger.warning(f"No tagged model found for {language}, using fallback registry")
        return fallback_registry[language]

    model_info = get_best_model_by_metric(language)
    if model_info:
        logger.warning(f"Using best model by metric for {language}")
        return model_info

    raise ValueError(f"No model found for language {language}")
def get_best_model_by_metric(
    language: str,
    metric: str = "f1_score"
) -> Optional[dict]:
    """
    Find the model with the best metric for a language.

    Searches the runs of every experiment returned by
    ``MlflowClient.search_experiments()`` whose ``Language`` tag equals
    ``language``, and returns the run with the highest value of ``metric``.
    No other tag is required.

    Args:
        language: Programming language
        metric: Metric to use for ordering

    Returns:
        Dict with keys ``run_id``, ``artifact`` (the run's ``model_name`` tag)
        and ``model_id`` (the run's ``model_id`` tag). Either tag value may be
        None if the run lacks that tag. Returns None if no experiments or
        matching runs exist, or if any MLflow call fails.
    """
    try:
        # Client creation and experiment listing also talk to the tracking
        # server. Keep them inside the try so that a server failure returns
        # None instead of raising, consistent with the rest of the lookup.
        client = MlflowClient()
        experiments = client.search_experiments()
        if not experiments:
            logger.error("No experiments found in MLflow")
            return None

        # NOTE(review): `language` is interpolated into the filter unescaped;
        # a quote in it would likely make the filter invalid (the error is
        # caught below). Confirm inputs are simple identifiers.
        runs = client.search_runs(
            experiment_ids=[exp.experiment_id for exp in experiments],
            filter_string=f"tags.Language = '{language}'",
            order_by=[f"metrics.{metric} DESC"],
            max_results=1
        )
        if not runs:
            logger.warning(f"No runs found for language '{language}'")
            return None

        best_run = runs[0]
        run_id = best_run.info.run_id
        exp_name = client.get_experiment(best_run.info.experiment_id).name
        run_name = best_run.info.run_name
        artifact_name = best_run.data.tags.get("model_name")
        model_id = best_run.data.tags.get("model_id")
        logger.info(f"Found best model for {language}: {exp_name}/{run_name} ({run_id}), artifact={artifact_name}")
        return {
            "run_id": run_id,
            "artifact": artifact_name,
            "model_id": model_id
        }
    except Exception as e:
        # Best-effort lookup: keep the traceback in the log, return None so
        # the caller can decide what to do next.
        logger.exception(f"Error finding best model by metric: {e}")
        return None