AG News Text Classifier

A fine-tuned DistilBERT model for classifying news articles into 4 categories: World, Sports, Business, and Sci/Tech.

Model Performance

Metric              Value
Test Accuracy       94.33%
Dataset             AG News (120,000 train / 7,600 test)
Baseline (random)   25%

For reference, the best published results on AG News are around 95-96% accuracy, so this model is within about 2 percentage points of the state of the art.
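As a sanity check on the numbers above, the sketch below computes accuracy and a majority-class baseline from plain Python lists. The label lists are hypothetical toy data, not the actual test predictions. It also shows that 94.33% of 7,600 test samples corresponds to about 7,169 correct predictions.

```python
from collections import Counter


def accuracy(predictions, labels):
    """Fraction of predictions that match the gold labels."""
    if len(predictions) != len(labels) or not labels:
        raise ValueError("predictions and labels must be non-empty and equal length")
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)


def majority_baseline(labels):
    """Accuracy of always predicting the most frequent class.

    On a perfectly balanced 4-class test set this is 1/4 = 25%,
    the same as uniform random guessing in expectation.
    """
    most_common_count = Counter(labels).most_common(1)[0][1]
    return most_common_count / len(labels)


# Hypothetical toy labels (0=World, 1=Sports, 2=Business, 3=Sci/Tech)
gold = [0, 1, 2, 3, 0, 1, 2, 3]
pred = [0, 1, 2, 3, 0, 1, 3, 3]
print(accuracy(pred, gold))         # 0.875
print(majority_baseline(gold))      # 0.25

# 7,169 correct out of 7,600 test samples rounds to the reported 94.33%
print(round(7169 / 7600 * 100, 2))  # 94.33
```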


How to Use

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer

# 1. Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# 2. Define the model class (must match the architecture used for training)

class AGNewsClassifier(nn.Module):
    def __init__(self, num_classes=4, dropout=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained("distilbert-base-uncased")
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(768, num_classes)

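    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Hidden state of the first token ([CLS]) is used as the sequence summary
        cls_output = outputs.last_hidden_state[:, 0, :]
        dropped = self.dropout(cls_output)  # dropout is only active in train mode
        return self.classifier(dropped)     # raw logits, shape (batch, num_classes)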

# 3. Load model weights
model = AGNewsClassifier()
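weights_path = hf_hub_download("sushamarangarajan/ag-news-classifier", "best_model.pt")
# weights_only=True restricts unpickling to tensors and plain containers (safer for downloaded files)
model.load_state_dict(torch.load(weights_path, map_location="cpu", weights_only=True))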
model.eval()

# 4. Classify new text
LABELS = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

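def classify(text):
    # Tokenize with the same max_length (128) used during training
    tokens = tokenizer(
        text,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )
    with torch.no_grad():  # inference only, no gradients needed
        logits = model(tokens["input_ids"], tokens["attention_mask"])
    pred = logits.argmax(dim=1).item()  # index of the highest-scoring class
    return LABELS[pred]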

# Example
print(classify("Apple stock surges after record quarterly earnings"))
# → Business

print(classify("Scientists discover potential signs of water on Mars"))
# → Sci/Tech
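The classify() function above returns only the label. If you also want a confidence score, one option is to apply a softmax to the logits. The dependency-free sketch below does this on a plain list of floats; the function names and the logit values are hypothetical, and inside classify() you would more typically call torch.softmax(logits, dim=1) directly.

```python
import math

LABELS = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}


def softmax(logits):
    """Numerically stable softmax over a list of raw scores."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def label_with_confidence(logits):
    """Return (label, probability) for the highest-scoring class."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return LABELS[best], probs[best]


# Hypothetical logits for one article (e.g. logits[0].tolist() inside classify())
label, conf = label_with_confidence([0.1, -1.2, 3.4, 0.5])
print(label, round(conf, 3))
```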

Training Details

Dataset

  • Source: AG News
  • Train size: 120,000 samples (30,000 per class)
  • Test size: 7,600 samples (1,900 per class)
  • Classes: Perfectly balanced across all 4 categories
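A quick way to confirm the class balance described above is to count labels per class. This sketch assumes integer label ids 0-3 in the same order as the LABELS mapping in the usage code; some AG News releases (such as the original CSV files) number the classes 1-4, so shift them first if needed. The toy label list is made up for illustration.

```python
from collections import Counter

LABELS = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}


def class_distribution(labels):
    """Count samples per class name, in label-id order."""
    counts = Counter(labels)
    return {LABELS[i]: counts.get(i, 0) for i in sorted(LABELS)}


def is_balanced(labels):
    """True if every class has exactly the same number of samples."""
    return len(set(class_distribution(labels).values())) == 1


# Toy label list; on the real training split each class would have 30,000 samples
toy = [0, 1, 2, 3] * 3
print(class_distribution(toy))
print(is_balanced(toy))
```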

Input Format

Title and description were combined using a separator token:

[TITLE]{title}[SEP]{description}
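A minimal helper that builds the combined input string might look like the sketch below. It reads `[TITLE]` as a literal prefix and `{title}`/`{description}` as placeholders; the helper name format_input and the whitespace stripping are assumptions, not taken from the original training code. When an article has a separate title and description, formatting it this way before calling classify() is likely to keep inference closer to the training setup.

```python
def format_input(title: str, description: str) -> str:
    """Combine title and description as [TITLE]{title}[SEP]{description}.

    Stripping surrounding whitespace is an assumption, not documented behavior.
    """
    return f"[TITLE]{title.strip()}[SEP]{description.strip()}"


text = format_input(
    "Apple stock surges",
    "Shares rose after record quarterly earnings. ",
)
print(text)
```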

After combining title and description, the longest sample is 383 tokens. Because 99.9% of samples fit within 128 tokens, max_length=128 was chosen for efficiency: only the longest ~0.1% of samples are truncated, so the information loss is negligible.
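The choice of max_length can be checked by measuring what fraction of samples fit. The sketch below assumes you have already computed one token count per combined sample (for example with len(tokenizer(text)["input_ids"])); the lengths shown are made-up toy values, not the real AG News distribution.

```python
import math


def coverage(lengths, max_length):
    """Fraction of samples whose token count fits within max_length."""
    return sum(n <= max_length for n in lengths) / len(lengths)


def percentile(lengths, q):
    """Nearest-rank percentile (q in (0, 100]) of a list of token lengths."""
    ordered = sorted(lengths)
    rank = math.ceil(q * len(ordered) / 100)
    return ordered[max(rank, 1) - 1]


# Hypothetical token lengths for ten combined samples
lengths = [40, 52, 61, 75, 88, 90, 97, 110, 126, 383]
print(coverage(lengths, 128))   # 0.9
print(max(lengths))             # 383
print(percentile(lengths, 90))  # 126
```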

Model Architecture

Input (title + description)
        ↓
DistilBERT (distilbert-base-uncased, 66M parameters)
        ↓
[CLS] token hidden state (768-dim)
        ↓
Dropout (p=0.3)
        ↓
Linear layer (768 → 4)
        ↓
Logits (4 classes)
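To make the diagram concrete, the sketch below traces tensor shapes through the head and counts the parameters the classification layer adds on top of DistilBERT. It is plain arithmetic under the stated sizes (hidden size 768, 4 classes), not code from the training notebook.

```python
HIDDEN = 768      # DistilBERT hidden size
NUM_CLASSES = 4


def head_parameter_count(hidden=HIDDEN, num_classes=NUM_CLASSES):
    """Linear weights (hidden x num_classes) plus one bias per class; dropout has none."""
    return hidden * num_classes + num_classes


def shapes(batch, seq_len, hidden=HIDDEN, num_classes=NUM_CLASSES):
    """Tensor shapes at each stage of the diagram for a padded batch."""
    return {
        "last_hidden_state": (batch, seq_len, hidden),
        "cls_output": (batch, hidden),     # [:, 0, :] selects position 0 ([CLS])
        "logits": (batch, num_classes),
    }


print(head_parameter_count())  # 3076
print(shapes(64, 128))
```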

Hyperparameters

Parameter            Value
Base model           distilbert-base-uncased
Max sequence length  128
Batch size           64
Optimizer            AdamW
Learning rate        2e-5
Epochs               3
Dropout              0.3
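With these hyperparameters, the number of optimizer updates follows directly from the dataset size. The sketch assumes the full 120,000-sample training split is used, no gradient accumulation, and that the last partial batch is kept (drop_last=False); the card does not state these details. The count is useful, for example, when sizing a learning-rate schedule.

```python
import math

TRAIN_SIZE = 120_000
BATCH_SIZE = 64
EPOCHS = 3
LEARNING_RATE = 2e-5  # e.g. torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)


def optimizer_steps(n_samples, batch_size, epochs, drop_last=False):
    """Return (steps per epoch, total steps) for one update per batch."""
    if drop_last:
        per_epoch = n_samples // batch_size
    else:
        per_epoch = math.ceil(n_samples / batch_size)
    return per_epoch, per_epoch * epochs


per_epoch, total = optimizer_steps(TRAIN_SIZE, BATCH_SIZE, EPOCHS)
print(per_epoch, total)  # 1875 5625
```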

Training Environment

  • Platform: Kaggle Notebooks
  • GPU: Tesla T4 x2
  • Framework: PyTorch + HuggingFace Transformers

Limitations

  • Trained on English news articles only; performance on other languages or domains (e.g. social media, academic text) may degrade
  • The model was fine-tuned for 3 epochs; further tuning with validation-based early stopping could improve results
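The early-stopping idea in the second point could look like the sketch below: track validation accuracy after each epoch and stop once it has not improved for `patience` epochs. The class, its parameters, and the accuracy values are hypothetical illustrations, not part of the original training run.

```python
class EarlyStopping:
    """Stop training when validation accuracy has not improved for `patience` epochs."""

    def __init__(self, patience=2, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.best_epoch = None
        self.bad_epochs = 0

    def step(self, epoch, val_accuracy):
        """Record one epoch's validation accuracy; return True if training should stop."""
        if val_accuracy > self.best + self.min_delta:
            self.best = val_accuracy
            self.best_epoch = epoch
            self.bad_epochs = 0  # improvement: a natural point to save a checkpoint such as best_model.pt
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Hypothetical validation accuracies per epoch
stopper = EarlyStopping(patience=2)
for epoch, acc in enumerate([0.930, 0.941, 0.943, 0.942, 0.940, 0.944], start=1):
    if stopper.step(epoch, acc):
        print(f"stopping after epoch {epoch}; best epoch {stopper.best_epoch} ({stopper.best:.3f})")
        break
```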

About

This model was built as part of a data science portfolio project to demonstrate fine-tuning transformer models for NLP classification tasks using PyTorch and HuggingFace Transformers.

Model repository: sushamarangarajan/ag-news-classifier
