A fine-tuned DistilBERT model for classifying news articles into 4 categories: World, Sports, Business, and Sci/Tech.
| Metric | Value |
|---|---|
| Test Accuracy | 94.33% |
| Dataset | AG News (120,000 train / 7,600 test) |
| Random-guess baseline | 25% (4 balanced classes) |
For reference, published state-of-the-art accuracy on AG News is roughly 95-96%, so this model is within about 2 percentage points of the best reported results.
```python
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer

# 1. Load tokenizer (the same one used during fine-tuning)
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")


# 2. Define the model class (must match the architecture used in training)
class AGNewsClassifier(nn.Module):
    def __init__(self, num_classes=4, dropout=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained("distilbert-base-uncased")
        self.dropout = nn.Dropout(dropout)
        # 768 = DistilBERT hidden size
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Hidden state of the first ([CLS]) token: shape (batch, 768)
        cls_output = outputs.last_hidden_state[:, 0, :]
        dropped = self.dropout(cls_output)
        return self.classifier(dropped)  # logits: shape (batch, num_classes)


# 3. Load model weights
# Replace "your_username/ag-news-classifier" with the repo that hosts best_model.pt
model = AGNewsClassifier()
weights_path = hf_hub_download("your_username/ag-news-classifier", "best_model.pt")
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()  # disables dropout for inference

# 4. Classify new text
LABELS = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}


def classify(text):
    tokens = tokenizer(
        text,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(tokens["input_ids"], tokens["attention_mask"])
    pred = logits.argmax(dim=1).item()
    return LABELS[pred]


# Example
print(classify("Apple stock surges after record quarterly earnings"))
# → Business
print(classify("Scientists discover potential signs of water on Mars"))
# → Sci/Tech
```
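`classify()` returns only the arg-max label. If you also want a confidence score, you can apply a softmax over the four logits to get one probability per class. Below is a minimal pure-Python sketch that assumes you already have the logits as a list of four floats, for example `logits[0].tolist()` inside `classify()`. The helper name `top_prediction` is hypothetical, and `LABELS` is repeated so the snippet is self-contained.

```python
import math

# Same label mapping as in the usage example.
LABELS = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}


def top_prediction(logits):
    """Return (label, probability) for a list of 4 raw logits.

    Hypothetical helper: applies a numerically stable softmax
    (subtracting the max logit before exponentiating), then picks the arg-max.
    """
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return LABELS[best], probs[best]


# Made-up logits for illustration, not real model output.
label, confidence = top_prediction([0.1, -1.2, 3.4, 0.7])
print(label, round(confidence, 3))
```

Subtracting the max logit does not change the softmax result. It only prevents `math.exp` from overflowing on large logits.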
Title and description were combined into a single input string using a separator token: `[TITLE]{title}[SEP]{description}`
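Here is a minimal sketch of that preprocessing step. It assumes the template is applied literally as printed, with no extra spaces; the card does not show the actual training code, and `build_input` is a hypothetical name. In the `distilbert-base-uncased` vocabulary, `[SEP]` is a special token and stays a single token. `[TITLE]` is not, so the tokenizer splits it into ordinary word pieces.

```python
def build_input(title: str, description: str) -> str:
    """Combine an AG News title and description into one string.

    Assumption: the template [TITLE]{title}[SEP]{description} from this card
    is used literally. Adjust it if your training code formatted inputs differently.
    """
    return f"[TITLE]{title}[SEP]{description}"


text = build_input(
    "Apple stock surges",
    "Shares rose after record quarterly earnings.",
)
print(text)
```

`classify()` in the usage example passes raw text straight to the tokenizer. When you have both a title and a description, formatting them this way first will likely keep inference closer to what the model saw in training.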
The longest combined sample is 383 tokens. Because 99.9% of samples fit within 128 tokens, `max_length=128` was chosen for efficiency, and only about 0.1% of samples lose any text to truncation.
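The coverage figure above is the fraction of samples whose token count is at most the chosen `max_length`. Below is a minimal sketch of that check. It assumes you have already computed one token count per sample; the list values are made up, and `fraction_within` and `texts` are hypothetical names.

```python
def fraction_within(lengths, max_length):
    """Fraction of samples whose token count is <= max_length."""
    if not lengths:
        raise ValueError("lengths must be non-empty")
    return sum(1 for n in lengths if n <= max_length) / len(lengths)


# In practice, lengths would come from the tokenizer, e.g.
#   lengths = [len(tokenizer(t)["input_ids"]) for t in texts]
# The values below are made up for illustration.
lengths = [42, 57, 88, 120, 131, 64, 383, 95]
print(f"max length: {max(lengths)}")
print(f"within 128 tokens: {fraction_within(lengths, 128):.1%}")
```

Counting with `tokenizer(t)["input_ids"]` includes the `[CLS]` and final `[SEP]` tokens. Those also occupy slots within `max_length`, so this matches what truncation actually sees.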
```
Input (title + description)
        ↓
DistilBERT (distilbert-base-uncased, 66M parameters)
        ↓
[CLS] token hidden state (768-dim)
        ↓
Dropout (p=0.3)
        ↓
Linear layer (768 → 4)
        ↓
Logits (4 classes)
```
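The final linear layer holds the only weights that do not come from the pretrained checkpoint. Dropout has no parameters, so the new classification head adds 768 × 4 weights plus 4 biases on top of DistilBERT's ~66M parameters. A quick check of that arithmetic (`linear_params` is a hypothetical helper):

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Parameter count of nn.Linear(in_features, out_features) with bias."""
    weights = in_features * out_features  # weight matrix: out_features x in_features
    biases = out_features                 # one bias per output unit
    return weights + biases


head = linear_params(768, 4)
print(head)  # 3076
```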
| Parameter | Value |
|---|---|
| Base model | distilbert-base-uncased |
| Max sequence length | 128 |
| Batch size | 64 |
| Optimizer | AdamW |
| Learning rate | 2e-5 |
| Epochs | 3 |
| Dropout | 0.3 |
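You need the total number of optimizer steps if you pair AdamW with a learning-rate schedule. The card does not say whether it used one. Below is a minimal sketch based on the table above. It assumes all 120,000 training samples are used, one optimizer step per batch, and no gradient accumulation. `optimizer_steps` is a hypothetical helper.

```python
def optimizer_steps(num_samples: int, batch_size: int, epochs: int) -> int:
    """Total optimizer steps, assuming one step per batch.

    Rounds up so a final partial batch still counts as a step
    (matching PyTorch DataLoader's default drop_last=False).
    """
    steps_per_epoch = (num_samples + batch_size - 1) // batch_size
    return steps_per_epoch * epochs


# Assumption: the full 120,000-sample training set, batch size 64, 3 epochs.
print(optimizer_steps(120_000, 64, 3))  # 5625
```

With these values each epoch is exactly 1,875 batches, so no partial batch occurs.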
This model was built as part of a data science portfolio project to demonstrate fine-tuning transformer models for NLP classification tasks using PyTorch and Hugging Face Transformers.