"""
Fine-tuning Script for Medical AI Models
Trains models on real medical datasets for production use
"""
import os
import torch
import pandas as pd
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import (
    ViTImageProcessor,
    ViTForImageClassification,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    AutoModelForSequenceClassification
)
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import json
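
# Note: this script uses the classic transformers `Trainer` API. The
# `evaluation_strategy` argument used below was renamed `eval_strategy` in
# transformers >= 4.46, so either pin an older release or rename it there.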

class SkinLesionDataset(Dataset):
    """Dataset for skin lesion images (HAM10000 format)"""

    def __init__(self, image_paths, labels, processor):
        self.image_paths = image_paths
        self.labels = labels
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        encoding = self.processor(images=image, return_tensors="pt")
        encoding = {key: val.squeeze() for key, val in encoding.items()}
        encoding['labels'] = torch.tensor(self.labels[idx])
        return encoding
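
# Minimal usage sketch for SkinLesionDataset (illustrative only; the image path
# and label below are placeholders, not files shipped with this script):
#
#   processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
#   ds = SkinLesionDataset(['./images/ISIC_0024306.jpg'], [5], processor)
#   sample = ds[0]
#   sample['pixel_values'].shape  # torch.Size([3, 224, 224]) after squeeze()
#   sample['labels']              # tensor(5), i.e. 'nv' in diagnosis_map below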

class SymptomDataset(Dataset):
    """Dataset for symptom-to-disease classification"""

    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        encoding = {key: val.squeeze() for key, val in encoding.items()}
        encoding['labels'] = torch.tensor(self.labels[idx])
        return encoding
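
# Minimal usage sketch for SymptomDataset (illustrative only; any BERT-style
# checkpoint works as the tokenizer):
#
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#   ds = SymptomDataset(["headache fever cough"], [0], tokenizer)
#   sample = ds[0]
#   sample['input_ids'].shape  # torch.Size([128]) -- padded to max_length
#   sample['labels']           # tensor(0)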

class MedicalModelTrainer:
    """Fine-tune models on medical datasets"""

    def __init__(self, output_dir="./trained_models"):
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

    def finetune_skin_model(self, data_dir, num_epochs=10):
        """
        Fine-tune a Vision Transformer on the HAM10000 skin lesion dataset.

        Expected dataset structure:
            data_dir/
            ├── images/
            │   ├── image1.jpg
            │   └── image2.jpg
            └── HAM10000_metadata.csv  (columns include: image_id, dx)

        Download from: https://www.kaggle.com/datasets/kmader/skin-cancer-mnist-ham10000
        """
        print("🔬 Fine-tuning Skin Condition Model...")

        # Load dataset metadata
        try:
            labels_df = pd.read_csv(os.path.join(data_dir, "HAM10000_metadata.csv"))
        except FileNotFoundError:
            print("❌ Dataset not found. Download HAM10000 from Kaggle:")
            print("   kaggle datasets download -d kmader/skin-cancer-mnist-ham10000")
            return None

        # Map diagnoses to indices
        diagnosis_map = {
            'akiec': 0,  # Actinic keratoses
            'bcc': 1,    # Basal cell carcinoma
            'bkl': 2,    # Benign keratosis
            'df': 3,     # Dermatofibroma
            'mel': 4,    # Melanoma
            'nv': 5,     # Melanocytic nevi
            'vasc': 6    # Vascular lesions
        }
        labels_df['label'] = labels_df['dx'].map(diagnosis_map)

        # Prepare image paths
        image_dir = os.path.join(data_dir, "images")
        labels_df['image_path'] = labels_df['image_id'].apply(
            lambda x: os.path.join(image_dir, f"{x}.jpg")
        )

        # Keep only rows whose image file exists on disk
        labels_df = labels_df[labels_df['image_path'].apply(os.path.exists)]
        print(f"📊 Loaded {len(labels_df)} images")

        # Split dataset, stratified to preserve class balance
        train_df, val_df = train_test_split(
            labels_df,
            test_size=0.2,
            stratify=labels_df['label'],
            random_state=42
        )

        # Load processor and model
        processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
        model = ViTForImageClassification.from_pretrained(
            'google/vit-base-patch16-224',
            num_labels=len(diagnosis_map),
            ignore_mismatched_sizes=True  # swap the 1000-class ImageNet head for 7 classes
        )

        # Create datasets
        train_dataset = SkinLesionDataset(
            train_df['image_path'].tolist(),
            train_df['label'].tolist(),
            processor
        )
        val_dataset = SkinLesionDataset(
            val_df['image_path'].tolist(),
            val_df['label'].tolist(),
            processor
        )

        # Training arguments
        training_args = TrainingArguments(
            output_dir=os.path.join(self.output_dir, "skin-condition-vit"),
            evaluation_strategy="epoch",
            save_strategy="epoch",
            learning_rate=2e-5,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
            num_train_epochs=num_epochs,
            weight_decay=0.01,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            logging_dir='./logs',
            logging_steps=100,
            save_total_limit=2
        )

        # Define metrics
        def compute_metrics(eval_pred):
            predictions, labels = eval_pred
            predictions = np.argmax(predictions, axis=1)
            accuracy = (predictions == labels).mean()
            return {"accuracy": accuracy}

        # Create trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=compute_metrics
        )

        # Train
        print("🏋️ Training started...")
        trainer.train()

        # Save model
        model_path = os.path.join(self.output_dir, "skin-condition-vit-final")
        trainer.save_model(model_path)
        processor.save_pretrained(model_path)

        # Save label mapping (index -> diagnosis code)
        with open(os.path.join(model_path, "label_map.json"), "w") as f:
            reverse_map = {v: k for k, v in diagnosis_map.items()}
            json.dump(reverse_map, f)

        print(f"✅ Model saved to {model_path}")
        return model_path
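
    # Inference sketch for a model saved by finetune_skin_model() (not part of
    # the training pipeline; `path` is a placeholder for a local image file):
    #
    #   model_path = "./trained_models/skin-condition-vit-final"
    #   processor = ViTImageProcessor.from_pretrained(model_path)
    #   model = ViTForImageClassification.from_pretrained(model_path)
    #   with open(os.path.join(model_path, "label_map.json")) as f:
    #       label_map = json.load(f)  # JSON stringifies the integer keys
    #   inputs = processor(images=Image.open(path).convert('RGB'), return_tensors="pt")
    #   with torch.no_grad():
    #       pred = model(**inputs).logits.argmax(-1).item()
    #   print(label_map[str(pred)])  # e.g. 'mel'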

    def finetune_symptom_model(self, data_file, num_epochs=5):
        """
        Fine-tune BERT on a symptom-to-disease dataset.

        Dataset format (CSV):
            symptoms,disease
            "headache fever cough","Influenza"
            "chest pain shortness of breath","Heart Condition"

        Download from Kaggle: Disease Symptom Prediction Dataset
        """
        print("🔬 Fine-tuning Symptom Analysis Model...")

        try:
            # Load dataset
            df = pd.read_csv(data_file)

            # Create disease label mapping
            diseases = df['disease'].unique()
            disease_map = {disease: idx for idx, disease in enumerate(diseases)}
            df['label'] = df['disease'].map(disease_map)
            print(f"📊 Loaded {len(df)} examples with {len(diseases)} diseases")

            # Split dataset
            train_df, val_df = train_test_split(
                df,
                test_size=0.2,
                stratify=df['label'],
                random_state=42
            )

            # Load tokenizer and model
            model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForSequenceClassification.from_pretrained(
                model_name,
                num_labels=len(diseases)
            )

            # Create datasets
            train_dataset = SymptomDataset(
                train_df['symptoms'].tolist(),
                train_df['label'].tolist(),
                tokenizer
            )
            val_dataset = SymptomDataset(
                val_df['symptoms'].tolist(),
                val_df['label'].tolist(),
                tokenizer
            )

            # Training arguments
            training_args = TrainingArguments(
                output_dir=os.path.join(self.output_dir, "symptom-bert"),
                evaluation_strategy="epoch",
                save_strategy="epoch",
                learning_rate=2e-5,
                per_device_train_batch_size=16,
                per_device_eval_batch_size=16,
                num_train_epochs=num_epochs,
                weight_decay=0.01,
                load_best_model_at_end=True,
                metric_for_best_model="accuracy",
                logging_steps=50
            )

            # Define metrics
            def compute_metrics(eval_pred):
                predictions, labels = eval_pred
                predictions = np.argmax(predictions, axis=1)
                accuracy = (predictions == labels).mean()
                return {"accuracy": accuracy}

            # Create trainer
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=val_dataset,
                compute_metrics=compute_metrics
            )

            # Train
            print("🏋️ Training started...")
            trainer.train()

            # Save model
            model_path = os.path.join(self.output_dir, "symptom-bert-final")
            trainer.save_model(model_path)
            tokenizer.save_pretrained(model_path)

            # Save label mapping (index -> disease name)
            with open(os.path.join(model_path, "disease_map.json"), "w") as f:
                reverse_map = {v: k for k, v in disease_map.items()}
                json.dump(reverse_map, f)

            print(f"✅ Model saved to {model_path}")
            return model_path

        except FileNotFoundError:
            print("❌ Dataset not found. Create or download a symptom-disease dataset")
            print("   Format: CSV with columns 'symptoms' and 'disease'")
            return None
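
    # Inference sketch for a model saved by finetune_symptom_model() (not part
    # of the training pipeline; the symptom string is illustrative):
    #
    #   model_path = "./trained_models/symptom-bert-final"
    #   tokenizer = AutoTokenizer.from_pretrained(model_path)
    #   model = AutoModelForSequenceClassification.from_pretrained(model_path)
    #   with open(os.path.join(model_path, "disease_map.json")) as f:
    #       disease_map = json.load(f)  # JSON stringifies the integer keys
    #   inputs = tokenizer("headache fever cough", return_tensors="pt",
    #                      truncation=True, padding=True)
    #   with torch.no_grad():
    #       pred = model(**inputs).logits.argmax(-1).item()
    #   print(disease_map[str(pred)])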

    def create_sample_symptom_dataset(self, output_file="symptom_dataset.csv"):
        """Create a sample symptom dataset for testing"""
        print("📝 Creating sample symptom dataset...")

        sample_data = [
            ("headache fever fatigue", "Influenza"),
            ("cough shortness of breath chest pain", "Pneumonia"),
            ("nausea vomiting diarrhea", "Gastroenteritis"),
            ("rash itching redness", "Allergic Reaction"),
            ("sore throat fever headache", "Strep Throat"),
            ("fatigue weakness pale skin", "Anemia"),
            ("headache sensitivity to light nausea", "Migraine"),
            ("chest pain shortness of breath", "Heart Condition"),
            ("fever cough body aches", "Common Cold"),
            ("abdominal pain nausea fever", "Appendicitis")
        ] * 50  # Duplicate for a larger dataset

        df = pd.DataFrame(sample_data, columns=['symptoms', 'disease'])
        df.to_csv(output_file, index=False)
        print(f"✅ Sample dataset saved to {output_file}")
        return output_file


def main():
    """Main training pipeline"""
    trainer = MedicalModelTrainer()

    print("=" * 60)
    print("🏥 Medical AI Model Fine-tuning Pipeline")
    print("=" * 60)

    # Option 1: Fine-tune skin condition model
    print("\n1️⃣ Skin Condition Model")
    print("   Dataset: HAM10000 (download from Kaggle)")
    print("   Command: kaggle datasets download -d kmader/skin-cancer-mnist-ham10000")

    skin_data_dir = "./HAM10000"
    if os.path.exists(skin_data_dir):
        trainer.finetune_skin_model(skin_data_dir, num_epochs=3)
    else:
        print("   ⏭️ Skipping (dataset not found)")

    # Option 2: Fine-tune symptom model
    print("\n2️⃣ Symptom Analysis Model")
    symptom_dataset = "./symptom_dataset.csv"
    if not os.path.exists(symptom_dataset):
        symptom_dataset = trainer.create_sample_symptom_dataset()
    trainer.finetune_symptom_model(symptom_dataset, num_epochs=3)

    print("\n" + "=" * 60)
    print("✅ Training complete!")
    print("=" * 60)
    print("\n📦 Trained models saved in ./trained_models/")
    print("\n🚀 To use in production:")
    print("   1. Update ai_models.py to load from ./trained_models/")
    print("   2. Replace model_name with local path")
    print("   3. Test with test_api.py")


if __name__ == "__main__":
    main()