import importlib
import warnings

import dagshub
from loguru import logger
import mlflow
import numpy as np
import pandas as pd

from turing.config import INPUT_COLUMN, LABELS_MAP, LANGS, MODEL_CONFIG, MODELS_DIR
from turing.dataset import DatasetManager
from turing.modeling.model_selector import get_best_model_info
from turing.modeling.models.codeBerta import CodeBERTa


class ModelInference:
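    """Loads language-specific comment-classification models from MLflow and runs inference.

    Models are resolved either via MLflow tags (best-model lookup) or via the
    hardcoded fallback registry below, then cached on disk under MODELS_DIR.
    """
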
    # Model Configuration (Fallback Registry)
    FALLBACK_MODEL_REGISTRY = {
        "java": {
            "run_id": "446f4459780347da8c796e619129be37",
            "artifact": "fine-tuned-CodeBERTa_java",
            "model_id": "codeberta",
        },
        "python": {
            "run_id": "ef5fd8ebf33a412087dcf02afd9e3147",
            "artifact": "fine-tuned-CodeBERTa_python",
            "model_id": "codeberta",
        },
        "pharo": {
            "run_id": "97822c6d84fc40c5b2363c9201a39997",
            "artifact": "fine-tuned-CodeBERTa_pharo",
            "model_id": "codeberta",
        },
    }

    def __init__(self, repo_owner="se4ai2526-uniba", repo_name="Turing", use_best_model_tags=True):
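        """
        Args:
            repo_owner: DagsHub repository owner used for MLflow tracking
            repo_name: DagsHub repository name used for MLflow tracking
            use_best_model_tags: If True, resolve models via MLflow tags;
                otherwise use FALLBACK_MODEL_REGISTRY as-is.
        """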
        dagshub.init(repo_owner=repo_owner, repo_name=repo_name, mlflow=True)
        warnings.filterwarnings("ignore")
        self.dataset_manager = DatasetManager()
        self.use_best_model_tags = use_best_model_tags

        # Initialize model registry based on configuration
        if use_best_model_tags:
            logger.info("Using MLflow tags to find best models")

            self.model_registry = {}
            for lang in LANGS:
                try:
                    model_info = get_best_model_info(
                        lang, fallback_registry=self.FALLBACK_MODEL_REGISTRY
                    )

                    # Validate before registering so an incomplete entry never
                    # reaches the registry
                    if not all(k in model_info for k in ("run_id", "artifact", "model_id")):
                        raise ValueError(f"Incomplete model info for {lang}: {model_info}")

                    self.model_registry[lang] = model_info
                    logger.info(f"Loaded model info for {lang}: {model_info}")
                except Exception as e:
                    logger.warning(f"Could not load model info for {lang}: {e}")
                    if lang in self.FALLBACK_MODEL_REGISTRY:
                        self.model_registry[lang] = self.FALLBACK_MODEL_REGISTRY[lang]
                    else:
                        logger.error(f"No model available for '{lang}'; skipping pre-cache")
                        continue

                # Pre-cache models locally
                run_id = self.model_registry[lang]["run_id"]
                artifact = self.model_registry[lang]["artifact"]
                self._get_cached_model_path(run_id, artifact, lang)
        else:
            logger.info("Using hardcoded model registry")
            self.model_registry = self.FALLBACK_MODEL_REGISTRY

    def _decode_predictions(self, raw_predictions, language: str):
        """
        Converts the binary matrix from the model into human-readable labels.

        Args:
            raw_predictions: Numpy array, list, or pandas DataFrame of binary predictions
            language: Programming language key used to look up the label map

        Returns:
            A list of label-name lists, one per input row.
        """

        labels_map = LABELS_MAP.get(language, [])
        decoded_results = []

        # Ensure input is a numpy array for processing
        if isinstance(raw_predictions, list):
            raw_array = np.array(raw_predictions)
        elif isinstance(raw_predictions, pd.DataFrame):
            raw_array = raw_predictions.values
        else:
            raw_array = raw_predictions

        # Iterate over rows
        for row in raw_array:
            indices = np.where(row == 1)[0]
            # Map indices to labels safely
            row_labels = [labels_map[i] for i in indices if i < len(labels_map)]
            decoded_results.append(row_labels)

        return decoded_results
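
    # Example (hypothetical label map): if labels_map were ["summary", "usage",
    # "expand"], a prediction row [1, 0, 1] would decode to ["summary", "expand"].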

    def _get_cached_model_path(self, run_id: str, artifact_name: str, language: str) -> str:
        """Checks if model exists locally; if not, downloads it from MLflow."""
        # Define local path: models/mlflow_temp_models/language/artifact_name
        local_path = MODELS_DIR / "mlflow_temp_models" / language / artifact_name

        if local_path.exists():
            logger.info(f"Loading {language} model from local cache: {local_path}")
            return str(local_path)

        logger.info(
            f"Model not found locally. Downloading {language} model from MLflow (Run ID: {run_id})..."
        )

        # Ensure parent directory exists
        local_path.parent.mkdir(parents=True, exist_ok=True)

        # Download artifacts to the parent directory (artifact_name folder will be created inside)
        mlflow.artifacts.download_artifacts(
            run_id=run_id, artifact_path=artifact_name, dst_path=str(local_path.parent)
        )
        logger.success(f"Model downloaded and cached at: {local_path}")

        return str(local_path)
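
    # For instance, the java fallback model (artifact "fine-tuned-CodeBERTa_java")
    # is cached at MODELS_DIR / "mlflow_temp_models" / "java" / "fine-tuned-CodeBERTa_java".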

    def predict_payload(self, texts: list[str], language: str):
        """
        API Prediction: Automatically fetches the correct model from the registry based on language.

        Args:
            texts: List of code comments to classify
            language: Programming language of the comments

        Returns:
            Tuple of (raw_predictions, decoded_labels, run_id, artifact_name).
        """

        # 1. Validate Language and Fetch Config
        if language not in self.model_registry:
            raise ValueError(
                f"Language '{language}' is not supported or the model is not configured."
            )

        model_config = self.model_registry[language]
        run_id = model_config["run_id"]
        artifact_name = model_config["artifact"]
        model_id = model_config["model_id"]

        # Dynamically import model class
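        # (Assumed entry shape, matching the CodeBERTa import at the top of this
        # file: {"model_class_module": "turing.modeling.models.codeBerta",
        #        "model_class_name": "CodeBERTa"})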
        config_entry = MODEL_CONFIG[model_id]
        module_name = config_entry["model_class_module"]
        class_name = config_entry["model_class_name"]
        module = importlib.import_module(module_name)
        model_class = getattr(module, class_name)

        # 2. Get Model Path (Local Cache or Download)
        model_path = self._get_cached_model_path(run_id, artifact_name, language)

        # Load Model
        model = model_class(language=language, path=model_path)

        # 3. Predict
        raw_predictions = model.predict(texts)

        # 4. Decode Labels
        decoded_labels = self._decode_predictions(raw_predictions, language)

        return raw_predictions, decoded_labels, run_id, artifact_name

    def predict_from_mlflow(
        self, mlflow_run_id: str, artifact_name: str, language: str, model_class=CodeBERTa
    ):
        """
        Legacy method for CML/CLI: Predicts on the test dataset stored on disk.
        """
        # Load Dataset
        try:
            full_dataset = self.dataset_manager.get_dataset()
            dataset_key = f"{language}_test"
            if dataset_key not in full_dataset:
                raise ValueError(f"Dataset key '{dataset_key}' not found.")
            test_ds = full_dataset[dataset_key]
            X_test = test_ds[INPUT_COLUMN]
        except Exception as e:
            logger.error(f"Error loading dataset: {e}")
            raise

        # Load Model (Local Cache or Download)
        model_path = self._get_cached_model_path(mlflow_run_id, artifact_name, language)
        model = model_class(language=language, path=model_path)

        raw_predictions = model.predict(X_test)

        # Decode output
        readable_predictions = self._decode_predictions(raw_predictions, language)

        logger.info("Dataset prediction completed.")
        return readable_predictions
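

# Illustrative usage sketch (not part of the module's public entry points):
# assumes DagsHub credentials are configured and the MLflow runs in the
# registry are reachable; the sample comments below are made up.
if __name__ == "__main__":
    inference = ModelInference(use_best_model_tags=False)
    sample_comments = [
        "Returns the user id associated with this session.",
        "TODO: refactor this loop for readability.",
    ]
    raw, labels, run_id, artifact = inference.predict_payload(sample_comments, language="java")
    logger.info(f"Run {run_id} ({artifact}) predicted: {labels}")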