import importlib
import warnings

import dagshub
from loguru import logger
import mlflow
import numpy as np
import pandas as pd

from turing.config import INPUT_COLUMN, LABELS_MAP, LANGS, MODEL_CONFIG, MODELS_DIR
from turing.dataset import DatasetManager
from turing.modeling.model_selector import get_best_model_info
from turing.modeling.models.codeBerta import CodeBERTa


class ModelInference:
    """Resolve, locally cache and run the per-language comment-classification models.

    Each supported language maps to a registry entry with an MLflow ``run_id``, the
    ``artifact`` name inside that run, and a ``model_id`` key into ``MODEL_CONFIG``.
    Entries come either from MLflow tags (``get_best_model_info``) or from
    ``FALLBACK_MODEL_REGISTRY``.
    """

    # Model Configuration (Fallback Registry)
    FALLBACK_MODEL_REGISTRY = {
        "java": {
            "run_id": "446f4459780347da8c796e619129be37",
            "artifact": "fine-tuned-CodeBERTa_java",
            "model_id": "codeberta",
        },
        "python": {
            "run_id": "ef5fd8ebf33a412087dcf02afd9e3147",
            "artifact": "fine-tuned-CodeBERTa_python",
            "model_id": "codeberta",
        },
        "pharo": {
            "run_id": "97822c6d84fc40c5b2363c9201a39997",
            "artifact": "fine-tuned-CodeBERTa_pharo",
            "model_id": "codeberta",
        },
    }

    # Keys every registry entry must provide to be usable by predict_payload.
    _REQUIRED_MODEL_KEYS = ("run_id", "artifact", "model_id")

    def __init__(self, repo_owner="se4ai2526-uniba", repo_name="Turing", use_best_model_tags=True):
        """Connect to DagsHub/MLflow and build the per-language model registry.

        Args:
            repo_owner: DagsHub repository owner passed to ``dagshub.init``.
            repo_name: DagsHub repository name passed to ``dagshub.init``.
            use_best_model_tags: If True, resolve each language in ``LANGS`` via
                ``get_best_model_info`` (falling back to ``FALLBACK_MODEL_REGISTRY``)
                and pre-download the models. If False, use a copy of
                ``FALLBACK_MODEL_REGISTRY`` and download lazily on first use.
        """
        dagshub.init(repo_owner=repo_owner, repo_name=repo_name, mlflow=True)
        warnings.filterwarnings("ignore")

        self.dataset_manager = DatasetManager()
        self.use_best_model_tags = use_best_model_tags

        # Initialize model registry based on configuration
        if use_best_model_tags:
            logger.info("Using MLflow tags to find best models")
            self.model_registry = {}
            for lang in LANGS:
                model_info = self._resolve_model_info(lang)
                if model_info is None:
                    # Nothing usable for this language: leave it unregistered so that
                    # predict_payload reports it as unsupported instead of crashing here.
                    continue
                self.model_registry[lang] = model_info
                # Pre-cache models locally
                self._get_cached_model_path(model_info["run_id"], model_info["artifact"], lang)
        else:
            logger.info("Using hardcoded model registry")
            # Copy so per-instance changes never leak into the class-level registry.
            self.model_registry = {
                lang: dict(info) for lang, info in self.FALLBACK_MODEL_REGISTRY.items()
            }

    def _resolve_model_info(self, lang: str) -> dict | None:
        """Return a validated (copied) registry entry for ``lang``, or None if unavailable.

        Tries ``get_best_model_info`` first. Any failure, including an entry missing
        one of ``_REQUIRED_MODEL_KEYS``, falls back to ``FALLBACK_MODEL_REGISTRY``.
        """
        try:
            model_info = get_best_model_info(
                lang, fallback_registry=self.FALLBACK_MODEL_REGISTRY
            )
            # Validate BEFORE registering, so an incomplete entry is never stored.
            if not model_info or any(k not in model_info for k in self._REQUIRED_MODEL_KEYS):
                raise ValueError(f"Incomplete model info for {lang}: {model_info}")
        except Exception as e:
            logger.warning(f"Could not load model info for {lang}: {e}")
            fallback = self.FALLBACK_MODEL_REGISTRY.get(lang)
            if fallback is None:
                logger.warning(f"No fallback model configured for {lang}; skipping it.")
                return None
            return dict(fallback)

        logger.info(f"Loaded model info for {lang}: {model_info}")
        # The lookup may hand back a FALLBACK_MODEL_REGISTRY entry; copy to avoid aliasing.
        return dict(model_info)

    def _decode_predictions(self, raw_predictions, language: str):
        """
        Converts the binary matrix from the model into human-readable labels.

        Args:
            raw_predictions: List, DataFrame or array-like with one row of 0/1
                predictions per input; column i corresponds to LABELS_MAP[language][i].
            language: Programming language for label mapping

        Returns:
            A list with one list of label names per row. Positive columns beyond
            the known labels are ignored; an unknown language yields empty lists.
        """
        labels_map = LABELS_MAP.get(language, [])
        decoded_results = []

        # Ensure input is a numpy array for processing
        if isinstance(raw_predictions, list):
            raw_array = np.array(raw_predictions)
        elif isinstance(raw_predictions, pd.DataFrame):
            raw_array = raw_predictions.values
        else:
            raw_array = raw_predictions

        # Iterate over rows
        for row in raw_array:
            indices = np.where(row == 1)[0]
            # Map indices to labels safely
            row_labels = [labels_map[i] for i in indices if i < len(labels_map)]
            decoded_results.append(row_labels)

        return decoded_results

    def _get_cached_model_path(self, run_id: str, artifact_name: str, language: str) -> str:
        """Return a local path to the model artifact, downloading it from MLflow if absent.

        The cache layout is ``MODELS_DIR/mlflow_temp_models/<language>/<artifact_name>``.
        """
        language_dir = MODELS_DIR / "mlflow_temp_models" / language
        local_path = language_dir / artifact_name

        if local_path.exists():
            logger.info(f"Loading {language} model from local cache: {local_path}")
            return str(local_path)

        logger.info(
            f"Model not found locally. Downloading {language} model from MLflow (Run ID: {run_id})..."
        )

        language_dir.mkdir(parents=True, exist_ok=True)

        # MLflow recreates artifact_path under dst_path. Downloading into the language
        # directory therefore lands exactly at local_path, even when artifact_name is
        # nested (e.g. "dir/model").
        downloaded_path = mlflow.artifacts.download_artifacts(
            run_id=run_id, artifact_path=artifact_name, dst_path=str(language_dir)
        )

        logger.success(f"Model downloaded and cached at: {downloaded_path}")
        return downloaded_path

    @staticmethod
    def _load_model_class(model_id: str):
        """Import and return the model class configured for ``model_id`` in MODEL_CONFIG."""
        config_entry = MODEL_CONFIG[model_id]
        module = importlib.import_module(config_entry["model_class_module"])
        return getattr(module, config_entry["model_class_name"])

    def predict_payload(self, texts: list[str], language: str):
        """
        API Prediction: Automatically fetches the correct model from the registry based on language.

        Args:
            texts: List of code comments to classify
            language: Programming language

        Returns:
            Tuple ``(raw_predictions, decoded_labels, run_id, artifact_name)``.

        Raises:
            ValueError: If ``language`` has no entry in the model registry.
        """
        # 1. Validate Language and Fetch Config
        if language not in self.model_registry:
            raise ValueError(
                f"Language '{language}' is not supported or the model is not configured."
            )

        model_config = self.model_registry[language]
        run_id = model_config["run_id"]
        artifact_name = model_config["artifact"]

        # Dynamically import model class
        model_class = self._load_model_class(model_config["model_id"])

        # 2. Get Model Path (Local Cache or Download)
        model_path = self._get_cached_model_path(run_id, artifact_name, language)

        # Load Model
        model = model_class(language=language, path=model_path)

        # 3. Predict
        raw_predictions = model.predict(texts)

        # 4. Decode Labels
        decoded_labels = self._decode_predictions(raw_predictions, language)

        return raw_predictions, decoded_labels, run_id, artifact_name

    def predict_from_mlflow(
        self, mlflow_run_id: str, artifact_name: str, language: str, model_class=CodeBERTa
    ):
        """
        Legacy method for CML/CLI: Predicts on the test dataset stored on disk.

        Args:
            mlflow_run_id: MLflow run that holds the model artifact.
            artifact_name: Artifact name/path within that run.
            language: Programming language; selects the ``<language>_test`` split.
            model_class: Class constructed as ``model_class(language=..., path=...)``.

        Returns:
            A list with one list of decoded label names per test example.

        Raises:
            ValueError: If the ``<language>_test`` split is missing. Any other
                dataset-loading error is logged and re-raised unchanged.
        """
        # Load Dataset
        try:
            full_dataset = self.dataset_manager.get_dataset()
            dataset_key = f"{language}_test"

            if dataset_key not in full_dataset:
                raise ValueError(f"Dataset key '{dataset_key}' not found.")

            test_ds = full_dataset[dataset_key]
            X_test = test_ds[INPUT_COLUMN]

        except Exception as e:
            # logger.exception keeps the traceback; bare raise re-raises unchanged.
            logger.exception(f"Error loading dataset: {e}")
            raise

        # Load Model (Local Cache or Download)
        model_path = self._get_cached_model_path(mlflow_run_id, artifact_name, language)
        model = model_class(language=language, path=model_path)

        raw_predictions = model.predict(X_test)

        # Decode output
        readable_predictions = self._decode_predictions(raw_predictions, language)

        logger.info("Dataset prediction completed.")
        return readable_predictions