# backbone_base_kmeans: Beat-aligned Piano LM with Hierarchical k-means + Velocity VQ
Autoregressive LLaMA-based LM over a beat-aligned piano-roll tokenizer. Supports two modes:

- Base: pattern tokens only (sustain + onset), no velocity.
- +Velocity: unified hierarchical tokenizer. Each active (beat, pitch) cell gets a cluster-index pattern token and one learned velocity token from a single-codebook VQ-VAE. Captures expressive dynamics end-to-end in one LM.
## Vocab layout (4789 total)
| Range | Size | Meaning |
|---|---|---|
| 0..2047 | 2048 | L1 k-means codes (48-frame full beat) |
| 2048..3071 | 1024 | L2 k-means codes (24-frame half beat) |
| 3072..3583 | 512 | L3 k-means codes (12-frame quarter beat) |
| 3584..3711 | 128 | MIDI pitch tokens |
| 3712..4223 | 512 | Velocity VQ tokens (K=512, learned) |
| 4224..4735 | 512 | (reserved velocity slots, currently unused) |
| 4736..4785 | 50 | `<beat_bpm>` separators (BPM-bucketed, 10 BPM/bin) |
| 4786/87/88 | - | BOS / EOS / PAD |
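For concreteness, the layout above maps directly to offset constants. A minimal sketch (the helper names are illustrative, not the repo's actual API, and the beat-token bucket origin at 0 BPM is an assumption):

```python
# Offset constants from the vocab table above; helper names are hypothetical.
L1_BASE, L2_BASE, L3_BASE = 0, 2048, 3072   # k-means pattern codes
PITCH_BASE = 3584                            # MIDI pitches 0..127
VEL_BASE = 3712                              # velocity VQ codes 0..511
RESERVED_BASE = 4224                         # unused velocity slots
BEAT_BASE = 4736                             # 50 BPM-bucketed beat separators
BOS, EOS, PAD = 4786, 4787, 4788

def pitch_token(midi_pitch: int) -> int:
    return PITCH_BASE + midi_pitch           # 3584..3711

def vel_token(code: int) -> int:
    return VEL_BASE + code                   # 3712..4223

def beat_token(bpm: float) -> int:
    # 10 BPM per bin, clamped to the 50 available buckets
    # (bucket origin is an assumption -- check the tokenizer source).
    return BEAT_BASE + min(int(bpm // 10), 49)

def token_kind(tok: int) -> str:
    # Classify a token id by which range of the table it falls into.
    for base, name in [(BOS, "special"), (BEAT_BASE, "beat"),
                       (RESERVED_BASE, "reserved"), (VEL_BASE, "velocity"),
                       (PITCH_BASE, "pitch"), (L3_BASE, "L3"),
                       (L2_BASE, "L2"), (L1_BASE, "L1")]:
        if tok >= base:
            return name
    raise ValueError(f"negative token id: {tok}")
```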
## Sequence format

```
[BOS] <beat_bpm_X> <pitch_P1> <L1/L2/L3 codes ...> <vel_token>
      <pitch_P2> <codes ...> <vel_token>
<beat_bpm_Y> ...
[EOS]
```
Every active (beat, pitch) cell produces a pitch token, one or more pattern codes, and one `vel_token`. When velocity is unavailable or uninformative (constant-velocity score files), the `vel_token` is still emitted for structural consistency, but its CE loss is masked (`labels[vel_pos] = -100`).
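This masking rule can be sketched as follows (a simplified illustration; the actual implementation lives in `ismir_dataset.py` and may differ in detail):

```python
import numpy as np

IGNORE_INDEX = -100  # label value ignored by PyTorch/HF cross-entropy

def mask_vel_labels(token_ids, vel_positions, flat_velocity: bool):
    """Build CE labels for a token sequence. If the source score file has
    flat (uninformative) velocity, the velocity tokens stay in the input
    for structural consistency but are excluded from the loss."""
    labels = np.asarray(token_ids, dtype=np.int64).copy()
    if flat_velocity:
        labels[list(vel_positions)] = IGNORE_INDEX
    return labels

# [BOS, beat, pitch, pattern code, vel_token, EOS]; index 4 is the vel token
labels = mask_vel_labels([4786, 4748, 3644, 100, 3712, 4787],
                         vel_positions=[4], flat_velocity=True)
```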
## Directory layout

```
backbone_base_kmeans/
├── config.py                 # ModelConfig, TrainingConfig
├── model.py                  # PianoLLaMA (LlamaForCausalLM wrapper)
├── hierarchical_tokenizer.py # HierCodebook (k-means) + VelVQEncoder
├── ismir_dataset.py          # ISMIRPianoDataset + StreamingDataset (+ flat-vel mask)
├── train.py                  # base (no-vel) training
├── train_finetune.py         # fine-tune warm-started from a base ckpt, adding velocity
├── inference_ismir.py        # generate tokens -> decode -> MIDI (BPM-aware timing)
├── codebook/
│   └── hierarchical3_codebook.npz  # pre-trained 3-level k-means codebook (unchanged)
├── vel_vq/                   # Velocity VQ-VAE (single-codebook, K=512)
│   ├── model.py              # BeatVelVQModel
│   ├── config.py             # BeatVelVQConfig
│   ├── dataset.py            # VelCellCachedDataset / VelRowDataset
│   ├── train.py              # VelVQ training loop
│   ├── build_cache.py        # extract (vel, mask) cells to single npz (filters flat-vel score files)
│   ├── vel_vq_best.pt        # trained checkpoint (val MIDI MAE ~ 0.54)
│   └── code_freq.npz         # codebook usage histogram over full cache
└── checkpoints/
    ├── best_model/           # base LM (no-velocity, old checkpoint)
    ├── epoch_4_0422_0406/
    └── epoch_9_0422_0834/
```
## Quick start

### Base model inference (no velocity)

```bash
python inference_ismir.py \
    --model checkpoints/best_model/model.safetensors \
    --codebook codebook/hierarchical3_codebook.npz \
    --out_dir generated_base
```
### Fine-tune to add velocity

```bash
# Expects the matching dataset at config.data_root:
#   https://huggingface.co/datasets/marisa0v0/triad-perf-midi-processed
#   (use processed_v2_velocity_fixed.tar.gz; v1 has a score velocity=204 bug)
python train_finetune.py \
    --resume_ckpt checkpoints/best_model/model.safetensors \
    --vel_vq_ckpt ./vel_vq/vel_vq_best.pt
```
### Train VelVQ from scratch

```bash
cd vel_vq
python build_cache.py   # ~2 min, 20M cells to ./cache/vel_cells.npz
python train.py         # ~15 min on a single 3090, val MIDI MAE ~ 0.54
```
## Key design decisions

- Hierarchical k-means codebook (non-learned) for binary patterns: lossless, interpretable, no training drift.
- Single VQ-VAE (learned) for continuous velocity: `num_quantizers=1` using `vector_quantize_pytorch.ResidualVQ` with EMA + dead-code revival.
- Mask injection in the VelVQ decoder (both at the input projection and as an output side-channel): tells the decoder which frames are "silent" so it can focus capacity on predicting real velocity values.
- Masked MSE loss: only active frames contribute; silent frames are ignored.
- Flat-velocity score filter at two places:
  - VelVQ training: `build_cache.py` drops score files whose velocity has ≤1 unique non-zero value (MuseScore defaults, no dynamics info).
  - Backbone fine-tune: sequences from those files still get `vel_token`s for structure, but `labels[vel_pos] = -100`, so they contribute no loss signal.
- Resample alignment: identical nearest-neighbor downsampling in VelVQ training and backbone tokenization (`F.interpolate(mode='nearest')` or its integer-index equivalent); verified round-trip identical on real files.
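The integer-index equivalent of that nearest-neighbor downsampling can be written without torch; for the 48 → 24 → 12 frame ratios it selects source frame `floor(i * in_len / out_len)`, which is the index rule nearest interpolation uses. A minimal sketch (not the repo's actual code):

```python
import numpy as np

def nearest_downsample(vel: np.ndarray, out_len: int) -> np.ndarray:
    """Nearest-neighbor downsample along the last axis, using the index
    rule floor(i * in_len / out_len)."""
    in_len = vel.shape[-1]
    idx = (np.arange(out_len) * in_len) // out_len
    return vel[..., idx]

# 48-frame full beat -> 24-frame half beat -> 12-frame quarter beat
beat = np.arange(48)
half = nearest_downsample(beat, 24)     # picks frames 0, 2, 4, ...
quarter = nearest_downsample(beat, 12)  # picks frames 0, 4, 8, ...
```

Because the index computation is pure integer arithmetic, it is exactly reproducible between the two code paths, which is what makes the round-trip check meaningful.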
## Training data

See the companion dataset:

- https://huggingface.co/datasets/marisa0v0/triad-perf-midi-processed (`processed_v2_velocity_fixed.tar.gz`, 602 MB)
- 46k beat-aligned performance + score files across 6 sources (maestro, asap, aria, cpm, amaps + musescore-rendered scores)
## Dependencies

```
torch >= 2.0
transformers >= 4.40      # LlamaForCausalLM
accelerate                # training
vector_quantize_pytorch
numpy
tqdm
pretty_midi               # inference MIDI output
```