"""
Model Manager for preloading and managing AI models in memory
"""
import gc
import time
from typing import Any, Dict, Optional

import torch
from transformers import pipeline


class ModelManager:
    """Singleton class to manage preloaded models"""

    _instance = None

    def __new__(cls):
        # Create the single shared instance on the first call; subsequent
        # calls return the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

def __init__(self):
if self._initialized:
return
self.depth_estimator = None
self.upscaler = None
self.models_loaded = False
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model_config = {
"depth_model": "Intel/dpt-hybrid-midas", # Lighter model for faster loading
"upscale_model": None, # Placeholder for future upscaling models
}
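        # Note: "Intel/dpt-large" is a heavier, higher-quality drop-in
        # alternative for "depth_model" if load time matters less.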
        self._initialized = True

def load_models_sync(self):
"""Load models synchronously to avoid threading issues"""
if self.models_loaded:
return True
try:
print("πŸ”„ Loading AI models...")
start_time = time.time()
# Load depth estimation model
print("Loading depth estimation model...")
self.depth_estimator = pipeline(
"depth-estimation",
model=self.model_config["depth_model"],
device=self.device,
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
)
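            # The returned pipeline accepts image paths, URLs, or PIL images
            # and yields a dict with a "depth" PIL image and a
            # "predicted_depth" tensor.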
# Load upscaler model (placeholder for now)
self.upscaler = None
self.models_loaded = True
load_time = time.time() - start_time
print(f"βœ… Models loaded successfully in {load_time:.2f}s")
return True
except Exception as e:
print(f"❌ Error loading models: {e}")
self.models_loaded = False
return False

    def wait_for_models(self, timeout: Optional[float] = None) -> bool:
        """Wait for models to be loaded.

        Loading runs synchronously in this process, so ``timeout`` is kept
        for API compatibility but is not currently used.
        """
        if self.models_loaded:
            return True
        return self.load_models_sync()

    def get_depth_estimator(self):
        """Get the depth estimation model, loading it on first use"""
        if not self.models_loaded:
            self.load_models_sync()
        return self.depth_estimator

    def get_upscaler(self):
        """Get the upscaler model, loading it on first use"""
        if not self.models_loaded:
            self.load_models_sync()
        return self.upscaler

    def is_ready(self) -> bool:
        """Check if models are ready"""
        return self.models_loaded

    def get_model_info(self) -> Dict[str, Any]:
"""Get information about loaded models"""
return {
"models_loaded": self.models_loaded,
"device": self.device,
"depth_model": self.model_config["depth_model"],
"upscale_model": self.model_config["upscale_model"]
}

    def cleanup(self):
        """Clean up models and free memory"""
        if self.depth_estimator is not None:
            del self.depth_estimator
            self.depth_estimator = None
        if self.upscaler is not None:
            del self.upscaler
            self.upscaler = None
        self.models_loaded = False
        # Drop the Python references first, then release cached GPU memory.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# Global instance
model_manager = ModelManager()
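

if __name__ == "__main__":
    # Minimal usage sketch, not part of the backend's public API. Assumes
    # torch and transformers are installed; "example.jpg" is an illustrative
    # placeholder path, not a file shipped with the project.
    if model_manager.wait_for_models():
        print(model_manager.get_model_info())
        estimator = model_manager.get_depth_estimator()
        result = estimator("example.jpg")
        # "predicted_depth" is a torch.Tensor; "depth" is a PIL image.
        print(result["predicted_depth"].shape)
    model_manager.cleanup()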