import gradio as gr
import torch
from transformers import LlamaForCausalLM, GPT2TokenizerFast
import os
# Optional import for quantization (GPU only)
try:
from transformers import BitsAndBytesConfig
HAS_BITSANDBYTES = True
except ImportError:
HAS_BITSANDBYTES = False
# Global variables for model and tokenizer
model = None
tokenizer = None
device = "cuda" if torch.cuda.is_available() else "cpu"
# Configuration:
# Set HF_MODEL_ID to load from your HuggingFace Hub repository instead of local checkpoint
HF_MODEL_ID = os.getenv("HF_MODEL_ID", None) # e.g., "your-username/smollm2-135m-coriolanus"
# Note: Quantization (8-bit/4-bit) requires GPU and is disabled for CPU-only environments
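# Set USE_QUANTIZATION to "4bit" (NF4, QLoRA-style) or to "8bit"/"true"/"1" for 8-bit loading on GPU,
# e.g. HF_MODEL_ID="your-username/smollm2-135m-coriolanus" USE_QUANTIZATION=4bit python app.py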
def load_model():
"""Load the model and tokenizer with optional quantization"""
global model, tokenizer
if model is None:
print("Loading model...")
# Determine model source
model_path = None
if HF_MODEL_ID:
# Load from HuggingFace Hub repository
print(f"Loading from HuggingFace Hub: {HF_MODEL_ID}")
model_path = HF_MODEL_ID
elif os.path.exists("checkpoint_5000"):
# Load from local checkpoint
print("Loading from local checkpoint: checkpoint_5000")
model_path = "checkpoint_5000"
else:
# Fallback: load pretrained model
print("Loading pretrained model: HuggingFaceTB/SmolLM2-135M")
model_path = "HuggingFaceTB/SmolLM2-135M"
# Load model (quantization only available on GPU)
if device == "cuda" and HAS_BITSANDBYTES:
# GPU: Can use quantization if requested
use_quantization = os.getenv("USE_QUANTIZATION", "").lower()
if use_quantization == "4bit":
try:
print("Using 4-bit quantization (QLoRA)")
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)
model = LlamaForCausalLM.from_pretrained(
model_path,
quantization_config=quantization_config,
device_map="auto"
)
except Exception as e:
print(f"Quantization failed, loading normally: {e}")
model = LlamaForCausalLM.from_pretrained(model_path)
model.to(device)
elif use_quantization in ["8bit", "true", "1"]:
try:
print("Using 8-bit quantization")
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = LlamaForCausalLM.from_pretrained(
model_path,
quantization_config=quantization_config,
device_map="auto"
)
except Exception as e:
print(f"Quantization failed, loading normally: {e}")
model = LlamaForCausalLM.from_pretrained(model_path)
model.to(device)
else:
# GPU without quantization
model = LlamaForCausalLM.from_pretrained(model_path)
model.to(device)
else:
# CPU: Load without quantization (bitsandbytes doesn't work on CPU)
print(f"Loading on {device.upper()} (quantization not available)")
# Use float32 for CPU (more compatible, though slower)
model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
model.to(device)
model.eval()
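        # The tokenizer always comes from the base SmolLM2 repo (the fine-tune keeps its vocabulary),
        # and EOS is reused as the pad token so generate() has a valid pad id.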
tokenizer = GPT2TokenizerFast.from_pretrained("HuggingFaceTB/SmolLM2-135M")
tokenizer.pad_token = tokenizer.eos_token
print("Model loaded successfully!")
return model, tokenizer
def count_parameters(model):
"""Count model parameters"""
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
def format_number(num):
if num >= 1_000_000_000:
return f"{num / 1_000_000_000:.2f}B"
elif num >= 1_000_000:
return f"{num / 1_000_000:.2f}M"
elif num >= 1_000:
return f"{num / 1_000:.2f}K"
else:
return str(num)
return f"Total: {total_params:,} ({format_number(total_params)}) | Trainable: {trainable_params:,} ({format_number(trainable_params)})"
def generate_text(
prompt,
max_new_tokens,
temperature,
top_p,
top_k,
do_sample,
repetition_penalty
):
"""Generate text from prompt"""
global model, tokenizer
if model is None or tokenizer is None:
model, tokenizer = load_model()
try:
        # Tokenize the prompt, keeping the attention mask so generate() receives it explicitly
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
# Generate
with torch.no_grad():
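            # When sampling is disabled, decoding is greedy and the sampling knobs
            # below are reset to their neutral/default values.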
outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
max_new_tokens=max_new_tokens,
temperature=temperature if do_sample else 1.0,
top_p=top_p if do_sample else 1.0,
top_k=top_k if do_sample else 50,
do_sample=do_sample,
repetition_penalty=repetition_penalty,
pad_token_id=tokenizer.eos_token_id,
)
# Decode
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Remove the prompt from output if it's included
if generated_text.startswith(prompt):
generated_text = generated_text[len(prompt):].strip()
return generated_text
except Exception as e:
return f"Error: {str(e)}"
def get_model_info():
"""Get model information"""
global model, tokenizer
if model is None or tokenizer is None:
model, tokenizer = load_model()
info = f"""
## Model Information
**Model Type:** SmolLM2-135M (LLaMA Architecture)
**Fine-tuned on:** Shakespeare's Coriolanus
**Device:** {device}
**Parameters:** {count_parameters(model)}
### Architecture Details
- **Hidden Size:** 576
- **Intermediate Size:** 1536
- **Number of Layers:** 30
- **Attention Heads:** 9
- **Key-Value Heads:** 3 (GQA)
- **Vocabulary Size:** 49,152
- **Max Position Embeddings:** 8,192
- **RoPE Theta:** 100,000
### Features
- ✅ Flash Attention (SDPA)
- ✅ Grouped Query Attention (GQA)
- ✅ RMSNorm
- ✅ Rotary Position Embeddings (RoPE)
- ✅ Tied Word Embeddings
- ✅ Fine-tuned on Coriolanus (writes like a dramatic play)
"""
return info
# Load model on startup
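# (Loading eagerly here means the first generation request does not pay the model-load cost.)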
model, tokenizer = load_model()
# Create Gradio interface
with gr.Blocks(title="SmolLM2-135M Demo") as demo:
gr.Markdown(
"""
# 🚀 SmolLM2-135M Text Generation Demo
A lightweight language model (135M parameters) fine-tuned exclusively on Shakespeare's **Coriolanus**.
The model writes in the style of a dramatic play, complete with character names, stage directions, and Shakespearean dialogue.
"""
)
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Model Information")
model_info = gr.Markdown(get_model_info())
            refresh_btn = gr.Button("🔄 Refresh Info")
with gr.Column(scale=2):
gr.Markdown("### Text Generation")
prompt_input = gr.Textbox(
label="Prompt",
placeholder="Enter your prompt here... (e.g., 'CORIOLANUS:' or 'Enter CORIOLANUS and MENENIUS')",
lines=3,
value="CORIOLANUS:"
)
with gr.Row():
max_tokens = gr.Slider(
label="Max New Tokens",
minimum=10,
maximum=512,
value=100,
step=10
)
temperature = gr.Slider(
label="Temperature",
minimum=0.1,
maximum=2.0,
value=0.8,
step=0.1
)
with gr.Row():
top_p = gr.Slider(
label="Top-p (Nucleus Sampling)",
minimum=0.1,
maximum=1.0,
value=0.9,
step=0.05
)
top_k = gr.Slider(
label="Top-k",
minimum=1,
maximum=100,
value=50,
step=1
)
with gr.Row():
repetition_penalty = gr.Slider(
label="Repetition Penalty",
minimum=1.0,
maximum=2.0,
value=1.1,
step=0.05
)
do_sample = gr.Checkbox(
label="Enable Sampling",
value=True
)
            generate_btn = gr.Button("✨ Generate")
output = gr.Textbox(
label="Generated Text",
lines=10,
interactive=False
)
# Event handlers
generate_btn.click(
fn=generate_text,
inputs=[
prompt_input,
max_tokens,
temperature,
top_p,
top_k,
do_sample,
repetition_penalty
],
outputs=output
)
refresh_btn.click(
fn=get_model_info,
outputs=model_info
)
gr.Markdown(
"""
### Usage Tips
- **Model Style**: This model is fine-tuned on Coriolanus and generates text in dramatic play format with character names and dialogue
- **Prompt Examples**: Try prompts like "CORIOLANUS:", "Enter CORIOLANUS and MENENIUS", or "ACT I, SCENE I"
- **Temperature**: Lower values (0.1-0.5) for more focused outputs, higher (0.8-1.5) for more creative text
- **Top-p**: Controls diversity via nucleus sampling (0.9 is a good default)
- **Top-k**: Limits sampling to top k tokens (50 is a good default)
- **Repetition Penalty**: Higher values (1.1-1.3) reduce repetition
- **Max New Tokens**: Maximum number of new tokens generated after the prompt
"""
)
if __name__ == "__main__":
demo.launch(share=False, server_name="0.0.0.0", server_port=7860)