# tse-conv-tasnet-48k
Causal, streaming, speaker-conditioned target speaker extraction at native 48 kHz.
Given (a) a noisy mixture of several speakers and (b) a 192-dimensional ECAPA speaker embedding of the target, the model extracts only the target speaker's voice. The graph is exported as a per-chunk streaming ONNX model with explicit causal-conv state tensors. It is designed to run in real time, chunk by chunk, on CPU or any ONNX Runtime backend.
## Performance (v3: EMA-fixed, speaker-disjoint test)
| Split | # Speakers | SI-SDR (dB) |
|---|---|---|
| TRAIN | 87 (seen) | +9.96 |
| VAL | 11 (unseen) | +8.61 |
| TEST | 11 (unseen) | +8.55 |
Generalisation gap (train − test) = +1.40 dB: on speakers it has never seen, the model loses only about 1.4 dB of SI-SDR relative to its training speakers.
Numbers are measured on 500 VCTK utterance mixtures (target + an interferer from a different speaker + DEMAND noise). Each mixture is evaluated on a randomly sampled crop within the utterance. Cropping avoids the silent intro of VCTK recordings, which would otherwise bias SI-SDR downward (see mellonella#151).
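For reference, SI-SDR can be computed as in the minimal NumPy sketch below. It assumes the common zero-mean convention and a small `eps`; the actual evaluation script may differ in such details, and the signals are synthetic placeholders.

The last call shows why silent crops are avoided. When the reference is (near-)silent, the projection onto it is about zero, so almost the whole estimate counts as error and the score collapses.

```python
import numpy as np


def si_sdr(estimate: np.ndarray, reference: np.ndarray, eps: float = 1e-8) -> float:
    """Scale-invariant SDR in dB between two 1-D signals (zero-mean convention)."""
    estimate = estimate - estimate.mean()
    reference = reference - reference.mean()
    # Project the estimate onto the reference: alpha * reference is the "target" part.
    alpha = np.dot(estimate, reference) / (np.dot(reference, reference) + eps)
    target = alpha * reference
    residual = estimate - target
    return float(10.0 * np.log10((np.dot(target, target) + eps) / (np.dot(residual, residual) + eps)))


rng = np.random.default_rng(0)
clean = rng.standard_normal(48_000)                 # 1 s of placeholder "speech" at 48 kHz
noisy = clean + 0.3 * rng.standard_normal(48_000)

print(si_sdr(noisy, clean))                         # roughly 10 dB
print(si_sdr(3.0 * noisy, clean))                   # same value: rescaling the estimate does not matter
print(si_sdr(noisy, np.zeros(48_000)))              # silent reference: large negative score
```

The negative of this quantity is also the training loss listed in the recipe.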
## Highlights
- 48 kHz native, end-to-end. Most public TSE / source-separation releases (SpeakerBeam, SpEx+, and much of the WSJ0-mix literature) are trained at 8 or 16 kHz and resampled at inference. This model trains, exports, and runs at the same 48 kHz that the consumer audio path actually uses.
- Clean licensing. Training data is VCTK (English multi-speaker studio recordings, CC BY 4.0) and DEMAND (real-world noise field recordings, CC BY 4.0). The released weights inherit CC BY 4.0 — no copyleft, no research-only restriction.
- Causal & streaming. The encoder is left-padded, the separator uses cumulative (causal) layer norm, and every depthwise dilated conv runs with an explicit per-chunk state buffer. Look-ahead is zero: the model never sees future samples.
- Small. 1.45 M parameters, 624 KB graph + 5.3 MB external weights.
- Verified parity. PyTorch ↔ ONNX Runtime per-chunk max|Δ| = 5.2 × 10⁻⁸ (well under the 1 × 10⁻⁴ tolerance).
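The per-chunk state mechanism can be illustrated with a single causal dilated depthwise conv. In this minimal NumPy sketch, the helper names are hypothetical and do not describe the exported graph's internals. Each streaming call keeps the last `(K − 1)·d` input frames as state. Processing a signal in 10-frame chunks (480 samples / encoder stride 48) then reproduces the offline, left-padded result exactly, which is the property the parity check above measures for the full model.

```python
import numpy as np


def depthwise_taps(buf: np.ndarray, w: np.ndarray, dilation: int, n_out: int) -> np.ndarray:
    """Apply per-channel taps w[C, K] to buf[C, n_out + (K-1)*dilation]; tap K-1 hits the newest frame."""
    y = np.zeros((buf.shape[0], n_out), dtype=buf.dtype)
    for k in range(w.shape[1]):
        y += w[:, k:k + 1] * buf[:, k * dilation : k * dilation + n_out]
    return y


def causal_conv_offline(x: np.ndarray, w: np.ndarray, dilation: int) -> np.ndarray:
    """Whole-sequence reference: left-pad only, so output t never sees frames after t."""
    pad = (w.shape[1] - 1) * dilation
    buf = np.concatenate([np.zeros((x.shape[0], pad), x.dtype), x], axis=1)
    return depthwise_taps(buf, w, dilation, x.shape[1])


def causal_conv_step(chunk: np.ndarray, state: np.ndarray, w: np.ndarray, dilation: int):
    """One streaming call: returns (output chunk, new state); state holds the last (K-1)*dilation frames."""
    pad = (w.shape[1] - 1) * dilation
    buf = np.concatenate([state, chunk], axis=1)
    y = depthwise_taps(buf, w, dilation, chunk.shape[1])
    return y, buf[:, buf.shape[1] - pad:]


rng = np.random.default_rng(0)
C, K, d, frames_per_chunk = 256, 3, 32, 10   # H = 256, P = 3, largest dilation, 480 / 48 frames
x = rng.standard_normal((C, 200))
w = rng.standard_normal((C, K))

state = np.zeros((C, (K - 1) * d))           # zero-init at session start
chunks = []
for start in range(0, x.shape[1], frames_per_chunk):
    y, state = causal_conv_step(x[:, start:start + frames_per_chunk], state, w, d)
    chunks.append(y)
streamed = np.concatenate(chunks, axis=1)

print(np.max(np.abs(streamed - causal_conv_offline(x, w, d))))   # 0.0
```

The state (64 frames for `d = 32`) can be longer than the chunk (10 frames). This is why the buffer is sliced from the concatenated `[state, chunk]` rather than taken from the chunk alone.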
## Architecture
Causal Conv-TasNet (Luo & Mesgarani, 2019) with SpeakerBeam-style FiLM conditioning (Žmolíková et al., 2019).
- Encoder. 1-D conv, kernel 96 / stride 48 / 256 basis channels, causal (left-padded). Kernel and stride are 3× those of the 16 kHz PoC (32 / 16), so the latent frame rate stays at 1 kHz in both regimes (48 000 / 48 = 16 000 / 16 = 1 000 frames/s). As a result, the separator is byte-identical between 16 kHz and 48 kHz.
- Conditioning. A frozen 192-dim ECAPA-TDNN enrollment embedding feeds a trainable 2-layer MLP (`192 → 256 → 2·128`) producing FiLM(γ, β). The ECAPA model itself is not part of this release: compute the embedding once per enrolled speaker and pass it as a plain input.
- Separator. Temporal Convolutional Network with `R = 2` repeats × `X = 6` depthwise-separable conv blocks, dilations `1, 2, 4, 8, 16, 32`. Each block: 1×1 conv → causal dilated depthwise conv (left-pad only) → PReLU → cumulative (causal) layer norm → FiLM → 1×1 residual + skip. Bottleneck `B = 128`, hidden `H = 256`, depthwise kernel `P = 3`.
- Mask + decoder. Skip-sum → 1×1 conv → sigmoid mask over the 256 basis channels → multiply with the encoder output → 1-D transposed-conv decoder.
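The cumulative layer norm and FiLM conditioning described above can be sketched in NumPy. The sketch also works out the TCN's receptive field from `R`, `X`, `P` and the encoder kernel/stride. Several details are assumptions not confirmed by this card:

- a ReLU between the two MLP layers;
- γ/β applied to a 128-channel feature map, matching the `2·128` MLP output;
- random placeholder weights throughout.

```python
import numpy as np


def cumulative_layer_norm(x, gain, bias, eps=1e-8):
    """Causal layer norm over x[C, T]: frame t uses statistics of frames 0..t only."""
    C, T = x.shape
    count = C * np.arange(1, T + 1)               # elements seen so far
    mean = np.cumsum(x.sum(axis=0)) / count       # running mean over channels and past frames
    mean_sq = np.cumsum((x ** 2).sum(axis=0)) / count
    var = np.maximum(mean_sq - mean ** 2, 0.0)
    return gain[:, None] * (x - mean) / np.sqrt(var + eps) + bias[:, None]


def film_from_embedding(emb, w1, b1, w2, b2):
    """2-layer MLP 192 -> 256 -> 2*128, split into FiLM scale (gamma) and shift (beta)."""
    hidden = np.maximum(emb @ w1 + b1, 0.0)       # ReLU here is an assumption
    out = hidden @ w2 + b2
    return out[:128], out[128:]


def receptive_field(R=2, X=6, P=3, kernel=96, stride=48):
    """Receptive field of the stacked dilated depthwise convs, in latent frames and input samples."""
    dilations = [2 ** i for i in range(X)]        # 1, 2, 4, 8, 16, 32
    frames = 1 + R * sum((P - 1) * d for d in dilations)
    return frames, (frames - 1) * stride + kernel


rng = np.random.default_rng(0)
emb = rng.standard_normal(192)                    # placeholder for an ECAPA embedding
gamma, beta = film_from_embedding(
    emb,
    rng.standard_normal((192, 256)) * 0.05, np.zeros(256),
    rng.standard_normal((256, 256)) * 0.05, np.zeros(256),
)
feats = rng.standard_normal((128, 50))            # placeholder [channels, frames] feature map
normed = cumulative_layer_norm(feats, np.ones(128), np.zeros(128))
conditioned = gamma[:, None] * normed + beta[:, None]  # FiLM: per-channel scale and shift

print(conditioned.shape)        # (128, 50)
print(receptive_field())        # (253, 12192), i.e. 12192 / 48000 ≈ 0.254 s
```

The receptive field counts only the dilated convs. The cumulative layer norm's running statistics span everything since session start, and its running sums grow with sequence length (compare the AMP note in the training recipe).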
## Files

| File | Size | Purpose |
|---|---|---|
| `tse_prod_48k.onnx` | 624 KB | Streaming ONNX graph |
| `tse_prod_48k.onnx.data` | 5.3 MB | External weights |
| `tse_prod_48k.onnx.weights.pt` | 5.6 MB | PyTorch sidecar (verification / Python inference) |
| `metrics.json` | ~10 KB | Per-epoch SI-SDR (train and held-out val) + config |
| `ckpt_source.txt` | < 1 KB | Provenance |
## I/O contract

The ONNX graph has a fixed chunk length of 480 samples (10 ms at 48 kHz) and threads 89 state tensors across `process_chunk` calls.

| Input | Shape | dtype | Notes |
|---|---|---|---|
| `audio_chunk` | `[1, 480]` | f32 | Mixture chunk |
| `cond` | `[1, 192]` | f32 | ECAPA enrollment embedding (compute once) |
| `state_in_0` … `state_in_88` | various | f32 | Causal-conv state (zero-init at session start) |

| Output | Shape | dtype |
|---|---|---|
| `extracted_chunk` | `[1, 480]` | f32 |
| `state_out_0` … `state_out_88` | matches `state_in_k` | f32 |
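Because the contract pairs `state_out_k` with `state_in_k` by index, a caller can thread state by name rather than by output position. The sketch below uses a hypothetical helper, `next_state`. With ONNX Runtime, `output_names` would come from `[o.name for o in session.get_outputs()]`, which is the same order in which `session.run(None, ...)` returns outputs. The arrays here are fake stand-ins, and only 2 of the 89 states are shown.

```python
import numpy as np


def next_state(output_names, outputs):
    """Map every state_out_k output to the state_in_k feed for the next chunk."""
    prefix = "state_out_"
    return {
        "state_in_" + name[len(prefix):]: value
        for name, value in zip(output_names, outputs)
        if name.startswith(prefix)
    }


# Fake outputs shaped like one process_chunk call (state shapes are made up).
names = ["extracted_chunk", "state_out_0", "state_out_1"]
outs = [np.zeros((1, 480), np.float32), np.ones((1, 4), np.float32), np.full((1, 8), 2.0, np.float32)]

state = next_state(names, outs)
print(sorted(state))            # ['state_in_0', 'state_in_1']
```

Checking `len(state) == 89` after the first real call is a cheap guard against loading a mismatched graph.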
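## Quickstart: Python (onnxruntime)

```python
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download

REPO = "penta2himajin/tse-conv-tasnet-48k"
CHUNK = 480  # fixed chunk length: 10 ms at 48 kHz

# The .data file must sit next to the .onnx file; both land in the same cache snapshot.
onnx_path = hf_hub_download(REPO, "tse_prod_48k.onnx")
hf_hub_download(REPO, "tse_prod_48k.onnx.data")
session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

# Zero-initialise every causal-conv state tensor at session start.
state_inputs = [i for i in session.get_inputs() if i.name.startswith("state_in_")]
state = {i.name: np.zeros(i.shape, dtype=np.float32) for i in state_inputs}

cond = np.zeros((1, 192), dtype=np.float32)   # replace with the real ECAPA embedding
mixture = np.zeros(48_000, dtype=np.float32)  # replace with 48 kHz mono audio

# Zero-pad the tail so the signal splits into whole 480-sample chunks.
n_chunks = -(-len(mixture) // CHUNK)
padded = np.pad(mixture, (0, n_chunks * CHUNK - len(mixture)))

extracted = []
for chunk in padded.reshape(n_chunks, CHUNK):
    out = session.run(None, {"audio_chunk": chunk[None, :], "cond": cond, **state})
    extracted.append(out[0])
    # out = [extracted_chunk, state_out_0, ..., state_out_88]; feed the states back in.
    state = {f"state_in_{k}": out[k + 1] for k in range(len(state_inputs))}

target = np.concatenate(extracted, axis=1)[0, : len(mixture)]
```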
## Quickstart: Rust (mellonella-core)

This model is the inference target of the mellonella real-time voice-filter project. See `rust/mellonella-core` for a streaming wrapper (`TseSession`) and `TseConfig::prod_48k()` for the matching config. Rust ↔ Python ONNX parity is enforced by `tests/tse_parity_prod_48k.rs`.
## Training recipe
- Data. VCTK (44.1 kHz, resampled to 48 kHz and stored as int16 FLAC) + channel ch01 of each DEMAND noise category. Speaker-level split: 87 train / 11 val / 11 test (speakers ranked by SHA-1 hash; the three sets are fully disjoint).
- Loss. Negative SI-SDR.
- Optimiser. AdamW (lr 1e-3, weight_decay 1e-2), cosine schedule with 2 warmup epochs, EMA decay 0.999, gradient clip 5.0.
- AMP. Off. FP16 AMP produced NaNs at peak LR with 48 kHz sequence lengths, because the cumulative layer norm's intermediate activations exceed the fp16 range. FP32 training costs about 2× as much per step but is stable.
- Trained 50 epochs in ~2 h on a Kaggle T4.
- Released checkpoint. Epoch 49 EMA weights.
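The split, schedule and EMA settings above can be sketched in plain Python. This is an illustration, not the training code, and it rests on these assumptions:

- Warmup is linear to the peak LR, followed by cosine decay to zero, stepped per optimiser step.
- The split SHA-1-hashes speaker IDs and takes the first 87 / next 11 / last 11 in hash order.
- Speaker IDs and `steps_per_epoch` are hypothetical.

The EMA update strips the `_orig_mod.` prefix that `torch.compile` adds to parameter names, and it refuses to run if any shadow entry has no live counterpart. That kind of check would surface a silent name mismatch like the one described in the history below.

```python
import hashlib
import math


def speaker_split(speakers, n_train=87, n_val=11):
    """Deterministic speaker-disjoint split by SHA-1 rank (ordering convention is an assumption)."""
    ranked = sorted(speakers, key=lambda s: hashlib.sha1(s.encode()).hexdigest())
    return ranked[:n_train], ranked[n_train:n_train + n_val], ranked[n_train + n_val:]


def lr_at(step, total_steps, warmup_steps, peak_lr=1e-3):
    """Linear warmup to peak_lr, then cosine decay to 0 at total_steps."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * min(progress, 1.0)))


def ema_update(shadow, live, decay=0.999, prefix="_orig_mod."):
    """shadow <- decay * shadow + (1 - decay) * live, matching names after stripping the compile prefix."""
    live = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in live.items()}
    missing = sorted(set(shadow) - set(live))
    if missing:
        raise KeyError(f"EMA shadow entries with no live parameter: {missing[:5]}")
    for name in shadow:
        shadow[name] = decay * shadow[name] + (1.0 - decay) * live[name]


speakers = [f"p{i}" for i in range(225, 334)]     # 109 hypothetical speaker IDs
train, val, test = speaker_split(speakers)
print(len(train), len(val), len(test))            # 87 11 11

steps_per_epoch = 100                             # hypothetical
print(lr_at(2 * steps_per_epoch, 50 * steps_per_epoch, 2 * steps_per_epoch))  # 0.001 (peak, end of warmup)

shadow = {"sep.w": 0.0}
ema_update(shadow, {"_orig_mod.sep.w": 1.0})
print(shadow["sep.w"])                            # about 0.001
```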
## History
- v1 (2026-05-18): initial release. The EMA shadow was silently never updated because of a `torch.compile` name-prefix mismatch (mellonella#151), so the exported weights were effectively the random initialisation.
- v2 (2026-05-19 hotfix): re-export from the LIVE weights of the same 50-epoch run.
- v3 (current): retrained with the EMA fix and a speaker-disjoint train/val/test split; ships the epoch-49 EMA weights. The performance numbers above refer to this version.
## Licence & citation
Released under CC BY 4.0. Training data carries the same licence:
- VCTK Corpus 0.92 — Junichi Yamagishi, Christophe Veaux, Kirsten MacDonald, CSTR VCTK Corpus, CSTR, University of Edinburgh, 2019.
- DEMAND — Joachim Thiemann, Nobutaka Ito, Emmanuel Vincent, The Diverse Environments Multi-channel Acoustic Noise Database, ICA 2013.
Architecture references:
- Yi Luo, Nima Mesgarani, Conv-TasNet: Surpassing Ideal Time-Frequency Magnitude Masking for Speech Separation, TASLP 2019.
- Kateřina Žmolíková et al., SpeakerBeam: Speaker Aware Neural Network for Target Speaker Extraction in Speech Mixtures, IEEE JSTSP 2019.
Speaker embedding model used at training time (not redistributed here): ECAPA-TDNN (Desplanques et al., 2020), weights from `speechbrain/spkrec-ecapa-voxceleb`.