import importlib
import warnings

import dagshub
import mlflow
import numpy as np
import pandas as pd
from loguru import logger

from turing.config import INPUT_COLUMN, LABELS_MAP, LANGS, MODEL_CONFIG, MODELS_DIR
from turing.dataset import DatasetManager
from turing.modeling.model_selector import get_best_model_info
from turing.modeling.models.codeBerta import CodeBERTa


class ModelInference:
    """Resolves, caches, and runs per-language comment-classification models from MLflow."""

# Model Configuration (Fallback Registry)
FALLBACK_MODEL_REGISTRY = {
"java": {
"run_id": "446f4459780347da8c796e619129be37",
"artifact": "fine-tuned-CodeBERTa_java",
"model_id": "codeberta",
},
"python": {
"run_id": "ef5fd8ebf33a412087dcf02afd9e3147",
"artifact": "fine-tuned-CodeBERTa_python",
"model_id": "codeberta",
},
"pharo": {
"run_id": "97822c6d84fc40c5b2363c9201a39997",
"artifact": "fine-tuned-CodeBERTa_pharo",
"model_id": "codeberta",
},
}

    def __init__(self, repo_owner="se4ai2526-uniba", repo_name="Turing", use_best_model_tags=True):
dagshub.init(repo_owner=repo_owner, repo_name=repo_name, mlflow=True)
warnings.filterwarnings("ignore")
self.dataset_manager = DatasetManager()
self.use_best_model_tags = use_best_model_tags
# Initialize model registry based on configuration
if use_best_model_tags:
logger.info("Using MLflow tags to find best models")
self.model_registry = {}
for lang in LANGS:
try:
                    model_info = get_best_model_info(
                        lang, fallback_registry=self.FALLBACK_MODEL_REGISTRY
                    )
                    # Validate before registering: fail fast if any required key is missing
                    if not all(k in model_info for k in ("run_id", "artifact", "model_id")):
                        raise ValueError(f"Incomplete model info for {lang}: {model_info}")
                    self.model_registry[lang] = model_info
                    logger.info(f"Loaded model info for {lang}: {model_info}")
                except Exception as e:
                    logger.warning(f"Could not load model info for {lang}: {e}")
                    if lang in self.FALLBACK_MODEL_REGISTRY:
                        self.model_registry[lang] = self.FALLBACK_MODEL_REGISTRY[lang]
                        # Pre-cache the fallback model locally
                        run_id = self.model_registry[lang]["run_id"]
                        artifact = self.model_registry[lang]["artifact"]
                        self._get_cached_model_path(run_id, artifact, lang)
                    else:
                        logger.error(f"No fallback configured for {lang}; it will be unavailable.")
else:
logger.info("Using hardcoded model registry")
self.model_registry = self.FALLBACK_MODEL_REGISTRY
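
    # Note: ModelInference(use_best_model_tags=False) skips the MLflow tag lookup
    # entirely and serves models straight from FALLBACK_MODEL_REGISTRY.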

    def _decode_predictions(self, raw_predictions, language: str):
"""
Converts the binary matrix from the model into human-readable labels.
Args:
raw_predictions: Numpy array or similar with binary predictions
language: Programming language for label mapping
"""
labels_map = LABELS_MAP.get(language, [])
decoded_results = []
# Ensure input is a numpy array for processing
if isinstance(raw_predictions, list):
raw_array = np.array(raw_predictions)
elif isinstance(raw_predictions, pd.DataFrame):
raw_array = raw_predictions.values
else:
raw_array = raw_predictions
# Iterate over rows
for row in raw_array:
indices = np.where(row == 1)[0]
# Map indices to labels safely
row_labels = [labels_map[i] for i in indices if i < len(labels_map)]
decoded_results.append(row_labels)
return decoded_results
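
    # Illustrative decoding example (label names are hypothetical, not the real
    # LABELS_MAP entries): with labels_map ["summary", "usage", "expand"], the row
    # [1, 0, 1] decodes to ["summary", "expand"] and an all-zero row decodes to [].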

    def _get_cached_model_path(self, run_id: str, artifact_name: str, language: str) -> str:
"""Checks if model exists locally; if not, downloads it from MLflow."""
# Define local path: models/mlflow_temp_models/language/artifact_name
local_path = MODELS_DIR / "mlflow_temp_models" / language / artifact_name
if local_path.exists():
logger.info(f"Loading {language} model from local cache: {local_path}")
return str(local_path)
logger.info(
f"Model not found locally. Downloading {language} model from MLflow (Run ID: {run_id})..."
)
# Ensure parent directory exists
local_path.parent.mkdir(parents=True, exist_ok=True)
# Download artifacts to the parent directory (artifact_name folder will be created inside)
mlflow.artifacts.download_artifacts(
run_id=run_id, artifact_path=artifact_name, dst_path=str(local_path.parent)
)
logger.success(f"Model downloaded and cached at: {local_path}")
return str(local_path)
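
    # Cache layout example, using the java fallback entry above: artifacts for run
    # "446f4459780347da8c796e619129be37" are cached under
    # MODELS_DIR / "mlflow_temp_models" / "java" / "fine-tuned-CodeBERTa_java".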

    def predict_payload(self, texts: list[str], language: str):
"""
API Prediction: Automatically fetches the correct model from the registry based on language.
Args:
texts: List of code comments to classify
language: Programming language
"""
# 1. Validate Language and Fetch Config
if language not in self.model_registry:
raise ValueError(
f"Language '{language}' is not supported or the model is not configured."
)
model_config = self.model_registry[language]
run_id = model_config["run_id"]
artifact_name = model_config["artifact"]
model_id = model_config["model_id"]
# Dynamically import model class
config_entry = MODEL_CONFIG[model_id]
module_name = config_entry["model_class_module"]
class_name = config_entry["model_class_name"]
module = importlib.import_module(module_name)
model_class = getattr(module, class_name)
# 2. Get Model Path (Local Cache or Download)
model_path = self._get_cached_model_path(run_id, artifact_name, language)
# Load Model
model = model_class(language=language, path=model_path)
# 3. Predict
raw_predictions = model.predict(texts)
# 4. Decode Labels
decoded_labels = self._decode_predictions(raw_predictions, language)
return raw_predictions, decoded_labels, run_id, artifact_name

    def predict_from_mlflow(
self, mlflow_run_id: str, artifact_name: str, language: str, model_class=CodeBERTa
):
"""
Legacy method for CML/CLI: Predicts on the test dataset stored on disk.
"""
# Load Dataset
try:
full_dataset = self.dataset_manager.get_dataset()
dataset_key = f"{language}_test"
if dataset_key not in full_dataset:
raise ValueError(f"Dataset key '{dataset_key}' not found.")
test_ds = full_dataset[dataset_key]
X_test = test_ds[INPUT_COLUMN]
except Exception as e:
logger.error(f"Error loading dataset: {e}")
            raise
# Load Model (Local Cache or Download)
model_path = self._get_cached_model_path(mlflow_run_id, artifact_name, language)
model = model_class(language=language, path=model_path)
raw_predictions = model.predict(X_test)
# Decode output
readable_predictions = self._decode_predictions(raw_predictions, language)
logger.info("Dataset prediction completed.")
return readable_predictions
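

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module's public surface. It assumes
    # DagsHub/MLflow credentials are available in the environment and that a "java"
    # model resolves from the registry; the sample comment text is made up.
    inference = ModelInference()
    raw, labels, run_id, artifact = inference.predict_payload(
        ["// Returns the user id associated with the current session."],
        language="java",
    )
    logger.info(f"Run {run_id} ({artifact}) predicted labels: {labels}")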