backbone_base_kmeans — Beat-aligned Piano LM with Hierarchical k-means + Velocity VQ

Autoregressive LLaMA-based LM over a beat-aligned piano roll tokenizer. Supports two modes:

  1. Base — pattern tokens only (sustain + onset), no velocity.
  2. +Velocity — unified hierarchical tokenizer: each active (beat, pitch) cell gets a cluster-index pattern token plus one learned velocity token from a single-codebook VQ-VAE, capturing expressive dynamics end-to-end in one LM.

Vocab layout (4789 total)

Range        Size  Meaning
0..2047      2048  L1 k-means codes (48-frame full beat)
2048..3071   1024  L2 k-means codes (24-frame half beat)
3072..3583    512  L3 k-means codes (12-frame quarter beat)
3584..3711    128  MIDI pitch tokens
3712..4223    512  Velocity VQ tokens (K=512, learned)
4224..4735    512  Reserved velocity slots (currently unused)
4736..4785     50  <beat_bpm> separators (BPM-bucketed, 10 BPM/bin)
4786..4788      3  BOS / EOS / PAD
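The layout above can be checked mechanically. A minimal sketch: the boundary constants come straight from the table, while the helper name `token_category` is illustrative, not part of the codebase.

```python
# Hypothetical helper mirroring the vocab layout table; boundaries are
# exclusive upper ends of each region.
L1_END, L2_END, L3_END = 2048, 3072, 3584
PITCH_END, VEL_END, RESERVED_END = 3712, 4224, 4736
BEAT_END = 4786  # <beat_bpm> tokens occupy 4736..4785

def token_category(tok: int) -> str:
    """Map a token id (0..4788) to its vocab region."""
    if tok < L1_END:
        return "L1_pattern"
    if tok < L2_END:
        return "L2_pattern"
    if tok < L3_END:
        return "L3_pattern"
    if tok < PITCH_END:
        return "pitch"
    if tok < VEL_END:
        return "velocity"
    if tok < RESERVED_END:
        return "reserved"
    if tok < BEAT_END:
        return "beat_bpm"
    return ("bos", "eos", "pad")[tok - BEAT_END]
```

This kind of range check is handy when debugging decoded sequences, e.g. asserting that every token following a pattern-code run is a velocity token.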

Sequence format

[BOS] <beat_bpm_X> <pitch_P1> <L1/L2/L3 codes ...> <vel_token>
                    <pitch_P2> <codes ...>       <vel_token>
      <beat_bpm_Y> ...
[EOS]

Every active (beat, pitch) cell emits a pitch token, one or more pattern codes, and one vel_token. When velocity is unavailable or uninformative (constant-velocity score files), the vel_token is still emitted for structural consistency, but its CE loss is masked (labels[vel_pos] = -100).
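The masking step can be sketched as follows. This is a minimal illustration of the -100 ignore-index mechanism, not the repo's actual collate code; `vel_positions` and `velocity_is_flat` are illustrative names, and the token ids are made up to match the vocab layout.

```python
import torch
import torch.nn.functional as F

# Toy sequence: BOS, <beat_bpm>, pitch, L1 pattern code, vel_token.
input_ids = torch.tensor([4786, 4740, 3600, 100, 3800])
labels = input_ids.clone()

vel_positions = torch.tensor([4])  # index of the vel_token in this sequence
velocity_is_flat = True            # e.g. a constant-velocity score file

if velocity_is_flat:
    # Kept in input_ids for structural consistency, but excluded from CE:
    labels[vel_positions] = -100

# Dummy LM logits over the 4789-token vocab; cross_entropy's default
# ignore_index=-100 drops the masked position from the loss.
logits = torch.randn(len(input_ids), 4789)
loss = F.cross_entropy(logits, labels)
```

Because -100 is PyTorch's default `ignore_index`, no extra flag is needed when computing the loss.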

Directory layout

backbone_base_kmeans/
├── config.py                  # ModelConfig, TrainingConfig
├── model.py                   # PianoLLaMA (LlamaForCausalLM wrapper)
├── hierarchical_tokenizer.py  # HierCodebook (k-means) + VelVQEncoder
├── ismir_dataset.py           # ISMIRPianoDataset + StreamingDataset (+ flat-vel mask)
├── train.py                   # base (no-vel) training
├── train_finetune.py          # fine-tune warm-started from a base ckpt, adding velocity
├── inference_ismir.py         # generate tokens → decode → MIDI (BPM-aware timing)
├── codebook/
│   └── hierarchical3_codebook.npz  # pre-trained 3-level k-means codebook (unchanged)
├── vel_vq/                    # Velocity VQ-VAE (single-codebook, K=512)
│   ├── model.py               # BeatVelVQModel
│   ├── config.py              # BeatVelVQConfig
│   ├── dataset.py             # VelCellCachedDataset / VelRowDataset
│   ├── train.py               # VelVQ training loop
│   ├── build_cache.py         # extract (vel, mask) cells to a single npz (filters flat-vel score files)
│   ├── vel_vq_best.pt         # trained checkpoint (val MIDI MAE ≈ 0.54)
│   └── code_freq.npz          # codebook usage histogram over the full cache
└── checkpoints/
    ├── best_model/            # base LM (no-velocity, old checkpoint)
    ├── epoch_4_0422_0406/
    └── epoch_9_0422_0834/

Quick start

Base model inference (no velocity)

python inference_ismir.py \
  --model checkpoints/best_model/model.safetensors \
  --codebook codebook/hierarchical3_codebook.npz \
  --out_dir generated_base

Fine-tune to add velocity

# Expects the matching dataset at config.data_root:
#   https://huggingface.co/datasets/marisa0v0/triad-perf-midi-processed
#   (use processed_v2_velocity_fixed.tar.gz — v1 has the score velocity=204 bug)
python train_finetune.py \
  --resume_ckpt checkpoints/best_model/model.safetensors \
  --vel_vq_ckpt ./vel_vq/vel_vq_best.pt

Train VelVQ from scratch

cd vel_vq
python build_cache.py    # ~2 min, 20M cells to ./cache/vel_cells.npz
python train.py          # ~15 min on a single 3090, val MIDI MAE ≈ 0.54

Key design decisions

  • Hierarchical k-means codebook (non-learned) for binary patterns — lossless, interpretable, no training drift.
  • Single VQ-VAE (learned) for continuous velocity — num_quantizers=1 using vector_quantize_pytorch.ResidualVQ with EMA + dead-code revival.
  • Mask injection in the VelVQ decoder (both at the init projection and as an output side-channel) — tells the decoder which frames are "silent" so it can focus capacity on predicting real velocity values.
  • Masked MSE loss — only active frames contribute; silent frames are ignored.
  • Flat-velocity score filter at two places:
    • VelVQ training: build_cache.py drops score files whose velocity has ≤ 1 unique non-zero value (MuseScore defaults, no dynamics info).
    • Backbone fine-tune: sequences from those files still get vel_tokens for structure, but labels[vel_pos] = -100, so they contribute no loss signal.
  • Resample alignment: identical nearest-neighbor downsampling in VelVQ training and backbone tokenization (F.interpolate(mode='nearest') or its integer-index equivalent) — verified round-trip identical on real files.
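The resample-alignment point above can be sketched concretely. A minimal check, assuming 1-D velocity curves: the two function names are illustrative, but the equivalence between `F.interpolate(mode='nearest')` and integer-index selection is what the design relies on.

```python
import torch
import torch.nn.functional as F

def resample_nearest_interp(x: torch.Tensor, out_len: int) -> torch.Tensor:
    # x: (T,) velocity curve -> (1, 1, T) because interpolate wants N, C, L
    return F.interpolate(x[None, None], size=out_len, mode="nearest")[0, 0]

def resample_nearest_index(x: torch.Tensor, out_len: int) -> torch.Tensor:
    # Integer-index equivalent: output frame i reads source frame
    # floor(i * T / out_len), matching 'nearest' interpolation semantics.
    idx = (torch.arange(out_len) * x.shape[0]) // out_len
    return x[idx]

x = torch.arange(48, dtype=torch.float32)  # one 48-frame full beat
assert torch.equal(resample_nearest_interp(x, 12), resample_nearest_index(x, 12))
assert torch.equal(resample_nearest_interp(x, 24), resample_nearest_index(x, 24))
```

Keeping both pipelines on the same rule matters because a one-frame offset between VelVQ training and backbone tokenization would silently shift which velocities each token encodes.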

Training data

See the companion dataset: https://huggingface.co/datasets/marisa0v0/triad-perf-midi-processed (use processed_v2_velocity_fixed.tar.gz).

Dependencies

torch >= 2.0
transformers >= 4.40   # LlamaForCausalLM
accelerate             # training
vector_quantize_pytorch
numpy
tqdm
pretty_midi            # inference MIDI output