# TechJamM2I / app.py
# Last commit: "Update app.py" (179e049, verified) by nimatov — 1.75 kB.
import streamlit as st
from flask.Emotion_spotting_service import _Emotion_spotting_service
from flask.Genre_spotting_service import _Genre_spotting_service
from flask.Beat_tracking_service import _Beat_tracking_service
from diffusers import StableDiffusionPipeline
import torch
import os
import logging
import psutil
# Module-level logger used for the diagnostics below. The root logger is
# configured at INFO so that logger.info() messages are actually emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def print_memory_info(device: int = 0) -> None:
    """Log the available CPU RAM and, when CUDA is usable, the free GPU memory.

    Both figures are logged at INFO level through the module logger, in GB
    (bytes / 1024**3), formatted with two decimals.

    Args:
        device: CUDA device index to query. Defaults to 0, the only device the
            previous implementation ever inspected.
    """
    gib = 1024 ** 3

    free_cpu_mem = psutil.virtual_memory().available / gib
    logger.info("Free CPU Memory: %.2f GB", free_cpu_mem)

    # Querying CUDA device properties on a CPU-only host raises, which used to
    # abort app start-up just to print a diagnostic. Report "n/a" instead.
    if not torch.cuda.is_available():
        logger.info("Free GPU Memory: n/a (CUDA not available)")
        return

    # mem_get_info reports the device-wide free memory as seen by the driver.
    # Unlike `total - memory_reserved`, it also accounts for the CUDA context
    # and for memory held by other processes sharing the GPU.
    free_gpu_bytes, _total_gpu_bytes = torch.cuda.mem_get_info(device)
    logger.info("Free GPU Memory: %.2f GB", free_gpu_bytes / gib)
# Three independent, initially empty module-level lists. They are not read in
# the code visible here; presumably they collect emotion / genre / tempo
# results elsewhere — TODO confirm or remove.
emo_list, gen_list, tempo_list = [], [], []
@st.cache_resource
def load_emo_model():
    """Build the emotion-spotting service from ``flask/emotion_model.h5``.

    The result is cached with Streamlit's ``st.cache_resource`` decorator.
    """
    return _Emotion_spotting_service("flask/emotion_model.h5")
@st.cache_resource
def load_genre_model():
    """Build the genre-spotting service from ``flask/Genre_classifier_model.h5``.

    The result is cached with Streamlit's ``st.cache_resource`` decorator.
    """
    return _Genre_spotting_service("flask/Genre_classifier_model.h5")
@st.cache_resource
def load_beat_model():
    """Build the beat-tracking service (constructed with no arguments).

    The result is cached with Streamlit's ``st.cache_resource`` decorator.
    """
    return _Beat_tracking_service()
@st.cache_resource
def load_image_model():
    """Load Stable Diffusion v1.5 in half precision on CUDA and apply the project LoRA.

    The pipeline is cached with Streamlit's ``st.cache_resource`` decorator.

    Returns:
        A ``StableDiffusionPipeline`` with float16 weights, moved to ``"cuda"``,
        with the LoRA from ``Weights/pytorch_lora_weights.safetensors`` loaded.
    """
    # Give cached-but-unused CUDA blocks back before the large allocation.
    torch.cuda.empty_cache()
    # variant="fp16" only selects which checkpoint files are fetched. Without
    # torch_dtype, diffusers converts those weights to the default float32,
    # which doubles GPU memory use. Request float16 explicitly.
    pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        variant="fp16",
        torch_dtype=torch.float16,
    )
    pipeline.to("cuda")
    pipeline.load_lora_weights(
        "Weights/pytorch_lora_weights.safetensors",
        weight_name="pytorch_lora_weights.safetensors",
    )
    return pipeline
# Start-up: load the (cached) diffusion pipeline, bracketed by memory readings
# taken immediately before and after the load.
print_memory_info()
image_service = load_image_model()
print_memory_info()