
Transformer-based Neural Machine Translation (NMT)

A Chinese-English neural machine translation system based on the Transformer architecture, supporting both training from scratch and fine-tuning of pre-trained models.

🚀🚀🚀 Because the model checkpoint is too large to host on GitHub, it is uploaded to the Hugging Face Hub: https://huggingface.co/DarcyCheng/Transformer-based-NMT

Introduction

This project implements a complete Transformer-based NMT system, with core tasks including:

1. From Scratch Training

Construct and train a Chinese-English translation model based on the Transformer architecture, including:

  • Adoption of the Encoder-Decoder structure
  • Model training from scratch
  • Support for distributed training (multi-GPU)

2. Architectural Ablation

Implement and compare different architectural variants:

  • Positional Encoding Schemes: Absolute Positional Encoding vs Relative Positional Encoding
  • Normalization Methods: LayerNorm vs RMSNorm (see the sketch after this list)
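
The two normalization variants differ only in a few lines of code. Below is a minimal RMSNorm sketch for orientation; the repository's own implementation lives in src/model.py and may differ in details such as the epsilon value or weight initialization.

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    # RMSNorm rescales by the root-mean-square of the features only;
    # unlike LayerNorm it does not subtract the mean and has no bias term.
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return self.weight * (x * rms)

# The LayerNorm variant can simply use the built-in nn.LayerNorm(d_model).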

3. Hyperparameter Sensitivity

Evaluate the impact of hyperparameter adjustments on translation performance:

  • Batch Size
  • Learning Rate
  • Model Scale

4. From Pretrained Language Model

Support fine-tuning based on pre-trained language models (e.g., T5) and compare performance with models trained from scratch.
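
As a rough illustration of the fine-tuning path, the sketch below uses the Hugging Face transformers library (not listed in the dependencies above) with the public t5-small checkpoint as a stand-in; a multilingual variant such as mT5 would be more natural for Chinese input, and the project's own fine-tuning code may look quite different.

from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")   # stand-in checkpoint
model = T5ForConditionalGeneration.from_pretrained("t5-small")
model.train()

# T5 is text-to-text, so the source sentence is framed with a task prefix.
inputs = tokenizer("translate Chinese to English: 你好,世界。", return_tensors="pt")
labels = tokenizer("Hello, world.", return_tensors="pt").input_ids

# One training step: the model returns the cross-entropy loss directly.
loss = model(**inputs, labels=labels).loss
loss.backward()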

Data Preparation

The dataset consists of four JSONL files, corresponding to:

  • Small Training Set: 10k samples
  • Large Training Set: 100k samples
  • Validation Set: 500 samples
  • Test Set: 200 samples

Each line in the JSONL files contains one parallel Chinese-English sentence pair. Final model performance is evaluated on the test set.

Data Format Example:

{"zh": "δ½ ε₯½οΌŒδΈ–η•Œγ€‚", "en": "Hello, world."}

Environment

System Requirements

  • Python: 3.9.25
  • PyTorch: 2.0.1+cu118
  • CUDA: 11.8 (recommended)

Install Dependencies

pip install -r requirements.txt

Key dependencies include:

  • torch>=2.0.1
  • torchvision
  • numpy
  • matplotlib
  • tqdm
  • hydra-core
  • omegaconf
  • sentencepiece
  • nltk

Project Structure

Transformer_NMT/
├── src/                    # Core source code
│   ├── model.py           # Transformer model definition
│   ├── dataset.py         # Data processing and dataset classes
│   ├── utils.py           # Utility functions
│   └── visualize_training.py  # Training visualization
├── configs/                # Configuration files
│   ├── train.yaml         # Training configuration
│   └── inference.yaml     # Inference configuration
├── checkpoints/            # Model checkpoint directory
│   └── exp_*/             # Experiment-specific checkpoint directories
├── logs/                   # Training logs
├── outputs/                # Output results
├── train.py               # Training script
├── evaluation.py          # Evaluation script
└── inference.py           # Inference script

Training, Evaluation, and Inference

Train the Model

Single-GPU Training

python train.py

Multi-GPU Distributed Training

torchrun --nproc_per_node=<num_gpus> train.py
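
Under torchrun, each spawned process reads its rank from environment variables and wraps the model in DistributedDataParallel. The sketch below shows the typical setup; the actual logic is in train.py and may differ.

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(model: torch.nn.Module) -> torch.nn.Module:
    # torchrun exports LOCAL_RANK, RANK, and WORLD_SIZE for every process.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")
    # Gradients are all-reduced across GPUs after each backward pass.
    return DDP(model.cuda(local_rank), device_ids=[local_rank])

Each rank should also draw a distinct data shard, typically via torch.utils.data.distributed.DistributedSampler.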

Configure Training Parameters

Edit the configs/train.yaml file to adjust training parameters, including:

  • Model architecture parameters (D_MODEL, NHEAD, NUM_ENCODER_LAYERS, etc.)
  • Training hyperparameters (BATCH_SIZE, LEARNING_RATE, NUM_EPOCHS, etc.)
  • Ablation experiment configurations (POS_ENCODING_TYPE, NORM_TYPE, etc.)

Ablation Experiment Configuration Examples

Absolute Positional Encoding + LayerNorm:

POS_ENCODING_TYPE: absolute
NORM_TYPE: layernorm

Relative Positional Encoding + RMSNorm:

POS_ENCODING_TYPE: relative
NORM_TYPE: rmsnorm
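
For reference, the absolute variant usually means the fixed sinusoidal encoding from the original Transformer paper, sketched below with this project's defaults (D_MODEL=256, MAX_LEN=128); the implementation in src/model.py may differ.

import math
import torch
import torch.nn as nn

class SinusoidalPositionalEncoding(nn.Module):
    # Adds a fixed sin/cos pattern to the embeddings so positions are
    # distinguishable without any learned parameters.
    def __init__(self, d_model: int = 256, max_len: int = 128):
        super().__init__()
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float)
                             * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))  # shape (1, max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        return x + self.pe[:, : x.size(1)]

The relative variant instead injects a position-dependent bias into the attention scores rather than into the embeddings.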

During training, model checkpoints will be automatically saved to the checkpoints/exp_<experiment_name>/ directory, where the experiment name is generated automatically based on the configuration.

Evaluate the Model

Evaluate the model's performance on the test set, outputting BLEU-1, BLEU-2, BLEU-3, BLEU-4, and Perplexity scores:

python evaluation.py model_path=<checkpoint_path>

Example:

python evaluation.py model_path=checkpoints/exp_abs_pos_ln_bs32_lre4/checkpoint_best.pth
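
For context on the metrics: BLEU-n uses uniform weights over 1-gram through n-gram precisions, and perplexity is the exponential of the average per-token cross-entropy. The sketch below uses nltk (already in the dependency list); evaluation.py may apply different tokenization or smoothing.

import math
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

# Toy tokenized data: one hypothesis and one reference per sentence.
references = [[["hello", ",", "world", "."]]]
hypotheses = [["hello", "world", "."]]

smooth = SmoothingFunction().method1
for n in range(1, 5):
    weights = tuple(1.0 / n for _ in range(n))
    print(f"BLEU-{n}:", corpus_bleu(references, hypotheses,
                                    weights=weights, smoothing_function=smooth))

avg_nll = 2.3                      # placeholder: mean per-token cross-entropy
print("Perplexity:", math.exp(avg_nll))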

Inference for Translation

Use the trained model for single-sentence or batch translation:

python inference.py

Or specify the model path:

python inference.py model_path=<checkpoint_path>

Configuration

Training Configuration (configs/train.yaml)

Key configuration items:

  • Model Parameters (see the sketch after this list):

    • D_MODEL: Model dimension (default: 256)
    • NHEAD: Number of attention heads (default: 8)
    • NUM_ENCODER_LAYERS: Number of encoder layers (default: 4)
    • NUM_DECODER_LAYERS: Number of decoder layers (default: 4)
    • DIM_FEEDFORWARD: Feed-forward network dimension (default: 1024)
    • DROPOUT: Dropout rate (default: 0.1)
    • MAX_LEN: Maximum sequence length (default: 128)
  • Training Parameters:

    • BATCH_SIZE: Batch size (default: 64)
    • LEARNING_RATE: Learning rate (default: 1e-5)
    • NUM_EPOCHS: Number of training epochs (default: 30)
    • CLIP_GRAD: Gradient clipping threshold (default: 5.0)
    • LABEL_SMOOTHING: Label smoothing coefficient (default: 0.1)
  • Ablation Experiment Parameters:

    • POS_ENCODING_TYPE: Positional encoding type (absolute or relative)
    • NORM_TYPE: Normalization type (layernorm or rmsnorm)
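
The model defaults above map roughly onto PyTorch's built-in nn.Transformer as shown below; the repository defines its own model in src/model.py, so treat this only as a guide to how the hyperparameters relate.

import torch.nn as nn

model = nn.Transformer(
    d_model=256,            # D_MODEL
    nhead=8,                # NHEAD
    num_encoder_layers=4,   # NUM_ENCODER_LAYERS
    num_decoder_layers=4,   # NUM_DECODER_LAYERS
    dim_feedforward=1024,   # DIM_FEEDFORWARD
    dropout=0.1,            # DROPOUT
    batch_first=True,
)
# Token embeddings, positional encoding, and the projection to the target
# vocabulary sit outside nn.Transformer and are defined in src/model.py.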

Features

  • ✅ Complete Transformer architecture implementation
  • ✅ Support for absolute and relative positional encoding
  • ✅ Support for LayerNorm and RMSNorm
  • ✅ Distributed training support (DDP)
  • ✅ Mixed precision training (AMP) (see the sketch after this list)
  • ✅ Automatic experiment management and checkpoint saving
  • ✅ Comprehensive evaluation metrics (BLEU-1/2/3/4, Perplexity)
  • ✅ Training process visualization
  • ✅ Numerical stability optimization (NaN detection and handling)
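
A minimal sketch of how mixed precision and gradient clipping typically combine in a training step; train.py implements the full loop, including the NaN handling mentioned above.

import torch

scaler = torch.cuda.amp.GradScaler()

def train_step(model, batch, optimizer, clip_grad=5.0):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():        # forward pass in mixed precision
        loss = model(**batch)              # assumed to return the training loss
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)             # unscale before clipping real gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
    scaler.step(optimizer)
    scaler.update()
    return loss.item()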

Acknowledgement

Thanks to the following repositories and projects:

License

This project is for educational and research purposes only.
