import gradio as gr
import torch
from transformers import LlamaForCausalLM, GPT2TokenizerFast
import os

# Optional import for quantization (GPU only)
try:
    from transformers import BitsAndBytesConfig
    HAS_BITSANDBYTES = True
except ImportError:
    HAS_BITSANDBYTES = False

# Global variables for model and tokenizer
model = None
tokenizer = None
device = "cuda" if torch.cuda.is_available() else "cpu"

# Configuration:
# Set HF_MODEL_ID to load from your HuggingFace Hub repository instead of a local checkpoint.
HF_MODEL_ID = os.getenv("HF_MODEL_ID", None)  # e.g., "your-username/smollm2-135m-coriolanus"
# Note: quantization (8-bit/4-bit) requires a GPU and is disabled in CPU-only environments.
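# Example invocation (hypothetical values, shown for illustration only; assumes the
# file is saved as app.py):
#   HF_MODEL_ID=your-username/smollm2-135m-coriolanus USE_QUANTIZATION=4bit python app.py
# On a Hugging Face Space, the same settings can be provided as Space variables/secrets,
# which are exposed as environment variables. USE_QUANTIZATION accepts "4bit" for 4-bit
# (QLoRA-style) loading, or "8bit"/"true"/"1" for 8-bit loading (see load_model below).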
def load_model():
    """Load the model and tokenizer with optional quantization."""
    global model, tokenizer

    if model is None:
        print("Loading model...")

        # Determine model source
        model_path = None
        if HF_MODEL_ID:
            # Load from HuggingFace Hub repository
            print(f"Loading from HuggingFace Hub: {HF_MODEL_ID}")
            model_path = HF_MODEL_ID
        elif os.path.exists("checkpoint_5000"):
            # Load from local checkpoint
            print("Loading from local checkpoint: checkpoint_5000")
            model_path = "checkpoint_5000"
        else:
            # Fallback: load the pretrained model
            print("Loading pretrained model: HuggingFaceTB/SmolLM2-135M")
            model_path = "HuggingFaceTB/SmolLM2-135M"

        # Load model (quantization is only available on GPU)
        if device == "cuda" and HAS_BITSANDBYTES:
            # GPU: can use quantization if requested
            use_quantization = os.getenv("USE_QUANTIZATION", "").lower()
            if use_quantization == "4bit":
                try:
                    print("Using 4-bit quantization (QLoRA)")
                    quantization_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.float16,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4",
                    )
                    model = LlamaForCausalLM.from_pretrained(
                        model_path,
                        quantization_config=quantization_config,
                        device_map="auto",
                    )
                except Exception as e:
                    print(f"Quantization failed, loading normally: {e}")
                    model = LlamaForCausalLM.from_pretrained(model_path)
                    model.to(device)
            elif use_quantization in ["8bit", "true", "1"]:
                try:
                    print("Using 8-bit quantization")
                    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
                    model = LlamaForCausalLM.from_pretrained(
                        model_path,
                        quantization_config=quantization_config,
                        device_map="auto",
                    )
                except Exception as e:
                    print(f"Quantization failed, loading normally: {e}")
                    model = LlamaForCausalLM.from_pretrained(model_path)
                    model.to(device)
            else:
                # GPU without quantization
                model = LlamaForCausalLM.from_pretrained(model_path)
                model.to(device)
        else:
            # CPU: load without quantization (bitsandbytes does not work on CPU)
            print(f"Loading on {device.upper()} (quantization not available)")
            # Use float32 for CPU (more compatible, though slower)
            model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
            model.to(device)

        model.eval()

        tokenizer = GPT2TokenizerFast.from_pretrained("HuggingFaceTB/SmolLM2-135M")
        tokenizer.pad_token = tokenizer.eos_token

        print("Model loaded successfully!")

    return model, tokenizer
def count_parameters(model):
    """Count total and trainable model parameters."""
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    def format_number(num):
        if num >= 1_000_000_000:
            return f"{num / 1_000_000_000:.2f}B"
        elif num >= 1_000_000:
            return f"{num / 1_000_000:.2f}M"
        elif num >= 1_000:
            return f"{num / 1_000:.2f}K"
        else:
            return str(num)

    return f"Total: {total_params:,} ({format_number(total_params)}) | Trainable: {trainable_params:,} ({format_number(trainable_params)})"
def generate_text(
    prompt,
    max_new_tokens,
    temperature,
    top_p,
    top_k,
    do_sample,
    repetition_penalty,
):
    """Generate text from a prompt."""
    global model, tokenizer

    if model is None or tokenizer is None:
        model, tokenizer = load_model()

    try:
        # Tokenize input
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

        # Generate
        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature if do_sample else 1.0,
                top_p=top_p if do_sample else 1.0,
                top_k=top_k if do_sample else 50,
                do_sample=do_sample,
                repetition_penalty=repetition_penalty,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Remove the prompt from the output if it is included
        if generated_text.startswith(prompt):
            generated_text = generated_text[len(prompt):].strip()

        return generated_text
    except Exception as e:
        return f"Error: {str(e)}"
def get_model_info():
    """Get model information as a markdown string."""
    global model, tokenizer

    if model is None or tokenizer is None:
        model, tokenizer = load_model()

    info = f"""
## Model Information

**Model Type:** SmolLM2-135M (LLaMA Architecture)
**Fine-tuned on:** Shakespeare's Coriolanus
**Device:** {device}
**Parameters:** {count_parameters(model)}

### Architecture Details
- **Hidden Size:** 576
- **Intermediate Size:** 1536
- **Number of Layers:** 30
- **Attention Heads:** 9
- **Key-Value Heads:** 3 (GQA)
- **Vocabulary Size:** 49,152
- **Max Position Embeddings:** 8,192
- **RoPE Theta:** 100,000

### Features
- ✅ Flash Attention (SDPA)
- ✅ Grouped Query Attention (GQA)
- ✅ RMSNorm
- ✅ Rotary Position Embeddings (RoPE)
- ✅ Tied Word Embeddings
- ✅ Fine-tuned on Coriolanus (writes like a dramatic play)
"""
    return info
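# For reference, the architecture numbers listed above correspond roughly to a
# transformers LlamaConfig like the sketch below. This is illustrative only; the
# actual checkpoint ships its own config.json, which is what from_pretrained uses.
#
#   from transformers import LlamaConfig
#   smollm2_135m_config = LlamaConfig(   # hypothetical variable name
#       hidden_size=576,
#       intermediate_size=1536,
#       num_hidden_layers=30,
#       num_attention_heads=9,
#       num_key_value_heads=3,           # GQA: 3 key-value heads shared across 9 query heads
#       vocab_size=49152,
#       max_position_embeddings=8192,
#       rope_theta=100000.0,
#       tie_word_embeddings=True,
#   )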
# Load model on startup
model, tokenizer = load_model()

# Create Gradio interface
with gr.Blocks(title="SmolLM2-135M Demo") as demo:
    gr.Markdown(
        """
        # SmolLM2-135M Text Generation Demo

        A lightweight language model (135M parameters) fine-tuned exclusively on Shakespeare's **Coriolanus**.
        The model writes in the style of a dramatic play, complete with character names, stage directions, and Shakespearean dialogue.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Model Information")
            model_info = gr.Markdown(get_model_info())
            refresh_btn = gr.Button("🔄 Refresh Info")

        with gr.Column(scale=2):
            gr.Markdown("### Text Generation")

            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here... (e.g., 'CORIOLANUS:' or 'Enter CORIOLANUS and MENENIUS')",
                lines=3,
                value="CORIOLANUS:"
            )

            with gr.Row():
                max_tokens = gr.Slider(
                    label="Max New Tokens",
                    minimum=10,
                    maximum=512,
                    value=100,
                    step=10
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=2.0,
                    value=0.8,
                    step=0.1
                )

            with gr.Row():
                top_p = gr.Slider(
                    label="Top-p (Nucleus Sampling)",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05
                )
                top_k = gr.Slider(
                    label="Top-k",
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1
                )

            with gr.Row():
                repetition_penalty = gr.Slider(
                    label="Repetition Penalty",
                    minimum=1.0,
                    maximum=2.0,
                    value=1.1,
                    step=0.05
                )
                do_sample = gr.Checkbox(
                    label="Enable Sampling",
                    value=True
                )

            generate_btn = gr.Button("✨ Generate")

            output = gr.Textbox(
                label="Generated Text",
                lines=10,
                interactive=False
            )

    # Event handlers
    generate_btn.click(
        fn=generate_text,
        inputs=[
            prompt_input,
            max_tokens,
            temperature,
            top_p,
            top_k,
            do_sample,
            repetition_penalty
        ],
        outputs=output
    )

    refresh_btn.click(
        fn=get_model_info,
        outputs=model_info
    )

    gr.Markdown(
        """
        ### Usage Tips
        - **Model Style**: This model is fine-tuned on Coriolanus and generates text in dramatic play format with character names and dialogue
        - **Prompt Examples**: Try prompts like "CORIOLANUS:", "Enter CORIOLANUS and MENENIUS", or "ACT I, SCENE I"
        - **Temperature**: Lower values (0.1-0.5) give more focused outputs; higher values (0.8-1.5) give more creative text
        - **Top-p**: Controls diversity via nucleus sampling (0.9 is a good default)
        - **Top-k**: Limits sampling to the top k tokens (50 is a good default)
        - **Repetition Penalty**: Higher values (1.1-1.3) reduce repetition
        - **Max New Tokens**: Maximum length of the generated text
        """
    )

if __name__ == "__main__":
    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)
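# Programmatic usage (a sketch, not part of the app): once the demo is running, the
# generation endpoint can be called with gradio_client. The api_name below is an
# assumption based on the function name; check client.view_api() for the endpoint
# actually exposed by your Gradio version.
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   text = client.predict(
#       "CORIOLANUS:",   # prompt
#       100,             # max_new_tokens
#       0.8,             # temperature
#       0.9,             # top_p
#       50,              # top_k
#       True,            # do_sample
#       1.1,             # repetition_penalty
#       api_name="/generate_text",
#   )
#   print(text)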