Audio-to-Audio
MLX
audio
music-source-separation
source-separation
demucs
htdemucs
hdemucs
apple-silicon
float16
Instructions to use iky1e/demucs-mlx-fp16 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use iky1e/demucs-mlx-fp16 with MLX:
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir demucs-mlx-fp16 iky1e/demucs-mlx-fp16
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- LM Studio
| #!/usr/bin/env python3 | |
| """ | |
| Export Demucs PyTorch models directly to safetensors + JSON config for Swift MLX. | |
| Converts all 8 pretrained models directly from the original PyTorch demucs package. | |
| No dependency on demucs-mlx or any other re-implementation. | |
| Usage: | |
| # Export all models | |
| python scripts/export_from_pytorch.py --out-dir ~/.cache/demucs-mlx-swift-models | |
| # Export specific models | |
| python scripts/export_from_pytorch.py --models htdemucs htdemucs_ft --out-dir ./Models | |
| Requirements: | |
| pip install demucs safetensors numpy | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import inspect | |
| import json | |
| import re | |
| import sys | |
| from fractions import Fraction | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
# Names exported when --models is not given; order is the export order.
ALL_MODELS = [
    "htdemucs", "htdemucs_ft", "htdemucs_6s", "hdemucs_mmi",
    "mdx", "mdx_extra", "mdx_q", "mdx_extra_q",
]

# PyTorch class name -> class name written into the config for the Swift MLX loader.
CLASS_MAP = {
    "Demucs": "DemucsMLX",
    "HDemucs": "HDemucsMLX",
    "HTDemucs": "HTDemucsMLX",
}

# Module names whose weight/bias keys receive an extra ".conv." path segment.
CONV_LAYER_NAMES = {
    "conv",
    "conv_tr",
    "rewrite",
    "channel_upsampler",
    "channel_downsampler",
    "channel_upsampler_t",
    "channel_downsampler_t",
}

# Names of the DConv attention (LocalState) sub-modules.
DCONV_ATTN_NAMES = {
    "content",
    "key",
    "query",
    "proj",
    "query_decay",
    "query_freqs",
}
def to_json_serializable(obj):
    """Recursively convert *obj* into something ``json.dump`` can write natively.

    Conversions performed:
        * ``Fraction``      -> ``"numerator/denominator"`` string
        * ``torch.Tensor``  -> Python scalar (single element) or nested list
        * NumPy scalar      -> the equivalent built-in Python scalar
        * ``np.ndarray``    -> nested list
        * ``list``/``tuple``-> list with every element converted
        * ``dict``          -> dict with ``str`` keys and converted values

    Any other object is returned unchanged.
    """
    if isinstance(obj, Fraction):
        return f"{obj.numerator}/{obj.denominator}"
    if isinstance(obj, torch.Tensor):
        return obj.item() if obj.numel() == 1 else obj.tolist()
    if isinstance(obj, np.generic):
        # NumPy scalars such as np.int64 / np.float32 are not JSON-native; without
        # this branch a json.dump(..., default=str) call would emit them as strings.
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, (list, tuple)):
        return [to_json_serializable(x) for x in obj]
    if isinstance(obj, dict):
        return {str(k): to_json_serializable(v) for k, v in obj.items()}
    return obj
def transpose_conv_weights(key: str, value: np.ndarray, is_conv_transpose: bool = False) -> np.ndarray:
    """Permute a PyTorch convolution weight into the MLX axis order.

    Only arrays stored under a key ending in ``.weight`` and having exactly
    three or four dimensions are permuted; anything else is returned as-is.

    Axis permutations:
        Conv1d           (out, in, k)    -> (out, k, in)     axes (0, 2, 1)
        Conv2d           (out, in, h, w) -> (out, h, w, in)  axes (0, 2, 3, 1)
        ConvTranspose1d  (in, out, k)    -> (out, k, in)     axes (1, 2, 0)
        ConvTranspose2d  (in, out, h, w) -> (out, h, w, in)  axes (1, 2, 3, 0)
    """
    if not key.endswith(".weight"):
        return value

    rank = value.ndim
    if rank == 3:
        axes = (1, 2, 0) if is_conv_transpose else (0, 2, 1)
    elif rank == 4:
        axes = (1, 2, 3, 0) if is_conv_transpose else (0, 2, 3, 1)
    else:
        # Not a 1-D/2-D convolution kernel: leave untouched.
        return value
    return np.transpose(value, axes)
def remap_key(
    key: str,
    value: np.ndarray,
    model_type: str = "HTDemucs",
    dconv_conv_slots: set | None = None,
    seq_conv_slots: set | None = None,
) -> list[tuple[str, np.ndarray]]:
    """Translate one PyTorch state-dict key into the MLX naming convention.

    The rules are tried in a fixed order and the first one that matches
    decides the result. Usually a single ``(key, value)`` pair is returned;
    a fused attention ``in_proj`` tensor yields three pairs (query/key/value).
    LSTM ``bias_ih`` and ``bias_hh`` both map to the same target key; merging
    those duplicates is left to the caller.

    Args:
        key: PyTorch state dict key.
        value: numpy array (conv weights already transposed by the caller).
        model_type: PyTorch class name ("Demucs", "HDemucs", "HTDemucs").
        dconv_conv_slots: set of (block_prefix, slot_str) for DConv slots
            whose weight has two or more dimensions; used to route biases.
        seq_conv_slots: set of (enc_dec, layer, slot) for Demucs v1/v2
            Sequential slots that hold a convolution.

    Returns:
        List of (new_key, array) pairs.
    """
    dconv_conv_slots = dconv_conv_slots or set()
    seq_conv_slots = seq_conv_slots or set()

    # ---------------------------------------------------------------------
    # Rule 1 (Demucs v1/v2 only): Sequential children.
    #   encoder.{i}.{j}.rest -> encoder.{i}.layers.{j}.rest
    #   and, for direct weight/bias of a slot, add ".conv." when the slot
    #   is listed in seq_conv_slots.
    # ---------------------------------------------------------------------
    if model_type == "Demucs":
        seq = re.match(r"(encoder|decoder)\.(\d+)\.(\d+)(\..*)?$", key)
        if seq:
            stage, idx, slot, tail = seq.groups()
            key = f"{stage}.{idx}.layers.{slot}{tail or ''}"

        direct = re.match(r"(encoder|decoder)\.(\d+)\.layers\.(\d+)\.(weight|bias)$", key)
        if direct:
            stage, idx, slot, param = direct.groups()
            wrapper = "conv." if (stage, idx, slot) in seq_conv_slots else ""
            return [(f"{stage}.{idx}.layers.{slot}.{wrapper}{param}", value)]

    # ---------------------------------------------------------------------
    # Rule 2: DConv internals, i.e. keys shaped *.layers.{block}.{slot}.{rest}
    # (also reached by Demucs v1/v2 keys after the rewrite in rule 1).
    # ---------------------------------------------------------------------
    dconv = re.match(r"(.+\.layers\.\d+)\.(\d+)\.(.+)$", key)
    if dconv:
        block_prefix, slot, rest = dconv.groups()
        base = f"{block_prefix}.layers.{slot}"

        if rest == "weight":
            # A weight with 2+ dims is treated as a convolution kernel and gets
            # the ".conv." wrapper; a 1-D weight (norm) is left unwrapped.
            if len(value.shape) >= 2:
                return [(f"{base}.conv.weight", value)]
            return [(f"{base}.weight", value)]
        if rest == "bias":
            # Bias shape cannot tell conv from norm, so consult the pre-scan.
            if (block_prefix, slot) in dconv_conv_slots:
                return [(f"{base}.conv.bias", value)]
            return [(f"{base}.bias", value)]
        if rest == "scale":
            return [(f"{base}.scale", value)]

        lstm = re.match(r"lstm\.(weight|bias)_(ih|hh)_l(\d+)(_reverse)?$", rest)
        if lstm:
            kind, gate, layer_idx, reverse = lstm.groups()
            direction = "backward_lstms" if reverse else "forward_lstms"
            if kind == "bias":
                # bias_ih and bias_hh intentionally collide; caller merges them.
                leaf = "bias"
            elif gate == "ih":
                leaf = "Wx"
            else:
                leaf = "Wh"
            return [(f"{base}.{direction}.{layer_idx}.{leaf}", value)]

        linear = re.match(r"linear\.(weight|bias)$", rest)
        if linear:
            return [(f"{base}.linear.{linear.group(1)}", value)]

        attn = re.match(r"(content|key|query|proj|query_decay|query_freqs)\.(weight|bias)$", rest)
        if attn:
            attn_name, param = attn.groups()
            # LocalState sub-modules get the ".conv." wrapper.
            return [(f"{base}.{attn_name}.conv.{param}", value)]

        # Unknown compound suffix: only insert the ".layers." segment.
        return [(f"{base}.{rest}", value)]

    # ---------------------------------------------------------------------
    # Rule 3: fused attention in_proj -> separate query/key/value tensors,
    # split into thirds along axis 0.
    # ---------------------------------------------------------------------
    in_proj = re.match(r"(.+)\.(self_attn|cross_attn)\.in_proj_(weight|bias)$", key)
    if in_proj:
        prefix, attn_type, param = in_proj.groups()
        target = "attn" if attn_type == "self_attn" else "cross_attn"
        third = value.shape[0] // 3
        pieces = (
            ("query_proj", value[:third]),
            ("key_proj", value[third : 2 * third]),
            ("value_proj", value[2 * third :]),
        )
        return [(f"{prefix}.{target}.{name}.{param}", part) for name, part in pieces]

    # self_attn.out_proj -> attn.out_proj
    out_proj = re.match(r"(.+)\.self_attn\.out_proj\.(weight|bias)$", key)
    if out_proj:
        prefix, param = out_proj.groups()
        return [(f"{prefix}.attn.out_proj.{param}", value)]

    # ---------------------------------------------------------------------
    # Rule 4: norm_out.{param} -> norm_out.gn.{param}
    # ---------------------------------------------------------------------
    norm_out = re.match(r"(.+)\.norm_out\.(weight|bias)$", key)
    if norm_out:
        prefix, param = norm_out.groups()
        return [(f"{prefix}.norm_out.gn.{param}", value)]

    # ---------------------------------------------------------------------
    # Rule 5: bottleneck LSTM, e.g. lstm.lstm.weight_ih_l0 -> lstm.forward_lstms.0.Wx
    # ---------------------------------------------------------------------
    bottleneck = re.match(r"(.+)\.lstm\.(weight|bias)_(ih|hh)_l(\d+)(_reverse)?$", key)
    if bottleneck:
        prefix, kind, gate, layer_idx, reverse = bottleneck.groups()
        direction = "backward_lstms" if reverse else "forward_lstms"
        if kind == "bias":
            # Duplicate target for bias_ih / bias_hh; caller merges them.
            leaf = "bias"
        elif gate == "ih":
            leaf = "Wx"
        else:
            leaf = "Wh"
        return [(f"{prefix}.{direction}.{layer_idx}.{leaf}", value)]

    # ---------------------------------------------------------------------
    # Rule 6: named conv-like layers -> insert ".conv." before weight/bias
    # ---------------------------------------------------------------------
    path, dot, param = key.rpartition(".")
    if dot:
        last_name = path.rpartition(".")[2]
        if last_name in CONV_LAYER_NAMES and param in ("weight", "bias"):
            return [(f"{path}.conv.{param}", value)]

    # No rule applied: keep the key as it is.
    return [(key, value)]
def convert_sub_model(model, prefix: str) -> dict[str, np.ndarray]:
    """Convert one sub-model's state dict into MLX-named float32 numpy arrays.

    Steps:
        1. Record which modules are ``ConvTranspose1d``/``ConvTranspose2d`` so
           their kernels get the transposed-conv axis permutation.
        2. Materialise the state dict as CPU float numpy arrays.
        3. Pre-scan for slots holding a weight with two or more dimensions, so
           the matching biases can be routed under ``.conv.`` by ``remap_key``.
        4. Transpose conv weights, remap every key, and prepend *prefix*.

    LSTM ``bias_ih``/``bias_hh`` pairs map to one target key and are summed.
    Any other duplicate target key is treated as an error rather than being
    silently added together.

    Args:
        model: module exposing ``named_modules()`` and ``state_dict()``.
        prefix: string prepended to every output key (e.g. ``"model_0."``).

    Returns:
        Mapping of MLX key -> numpy array.

    Raises:
        ValueError: if two source keys collide on a target key that is not an
            LSTM bias, or if the colliding arrays differ in shape.
    """
    cls_name = type(model).__name__
    is_demucs = cls_name == "Demucs"

    # Patterns are loop-invariant; compile them once.
    seq_pattern = re.compile(r"(encoder|decoder)\.(\d+)\.(\d+)(\..*)?$")
    dconv_weight_pattern = re.compile(r"(.+\.layers\.\d+)\.(\d+)\.weight$")
    seq_weight_pattern = re.compile(r"(encoder|decoder)\.(\d+)\.(\d+)\.weight$")
    lstm_bias_pattern = re.compile(r"(?:forward|backward)_lstms\.\d+\.bias$")

    # --- Pre-scan: parameter-key prefixes of ConvTranspose modules ---
    conv_tr_prefixes = tuple(
        f"{name}."
        for name, module in model.named_modules()
        if isinstance(module, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d))
    )

    # --- Collect state dict as numpy ---
    state_items = [
        (key, tensor.detach().cpu().float().numpy())
        for key, tensor in model.state_dict().items()
    ]

    # --- Pre-scan: DConv slots whose weight has 2+ dims ---
    # For Demucs v1/v2 the Sequential ".layers." insertion is applied first so
    # the recorded prefixes match what remap_key looks up.
    dconv_conv_slots: set[tuple[str, str]] = set()
    for key, arr in state_items:
        scan_key = key
        if is_demucs:
            m = seq_pattern.match(scan_key)
            if m:
                enc_dec, layer, slot, rest = m.groups()
                scan_key = f"{enc_dec}.{layer}.layers.{slot}{rest or ''}"
        m = dconv_weight_pattern.match(scan_key)
        if m and len(arr.shape) >= 2:
            dconv_conv_slots.add((m.group(1), m.group(2)))

    # --- Pre-scan: Demucs v1/v2 Sequential slots holding a convolution ---
    seq_conv_slots: set[tuple[str, str, str]] = set()
    if is_demucs:
        for key, arr in state_items:
            m = seq_weight_pattern.match(key)
            if m and len(arr.shape) >= 2:
                seq_conv_slots.add((m.group(1), m.group(2), m.group(3)))

    # --- Convert ---
    weights: dict[str, np.ndarray] = {}
    for key, arr in state_items:
        # str.startswith accepts a tuple; an empty tuple yields False.
        is_conv_tr = key.startswith(conv_tr_prefixes)
        arr = transpose_conv_weights(key, arr, is_conv_transpose=is_conv_tr)

        for new_key, new_val in remap_key(key, arr, cls_name, dconv_conv_slots, seq_conv_slots):
            full_key = f"{prefix}{new_key}"
            if full_key not in weights:
                weights[full_key] = new_val
                continue

            # The only expected collision is bias_ih + bias_hh of an LSTM,
            # which the target stores as a single summed bias.
            existing = weights[full_key]
            if not lstm_bias_pattern.search(new_key) or existing.shape != new_val.shape:
                raise ValueError(
                    f"Unexpected duplicate target key {full_key!r} while converting {key!r}"
                )
            weights[full_key] = existing + new_val
    return weights
def extract_kwargs(model) -> dict:
    """Return the model's constructor keyword arguments in JSON-friendly form.

    When the model carries an ``_init_args_kwargs`` attribute, its keyword
    part is used, keeping only values of simple types (numbers, strings,
    booleans, lists, tuples, ``None``, ``Fraction``). Otherwise the names in
    the ``__init__`` signature are looked up as attributes on the model and
    every one that exists is included.
    """
    if hasattr(model, "_init_args_kwargs"):
        simple_types = (int, float, str, bool, list, tuple, type(None), Fraction)
        _, captured = model._init_args_kwargs
        result = {}
        for name, val in captured.items():
            if isinstance(val, simple_types):
                result[name] = to_json_serializable(val)
        return result

    # Fallback: mirror __init__ parameter names that exist as attributes.
    params = inspect.signature(type(model).__init__).parameters
    return {
        name: to_json_serializable(getattr(model, name))
        for name in params
        if name != "self" and hasattr(model, name)
    }
def export_model(model_name: str, out_dir: Path) -> bool:
    """Export one pretrained model (or bag of models) to safetensors + JSON.

    Files are written to ``<out_dir>/<model_name>/`` as
    ``<model_name>.safetensors`` and ``<model_name>_config.json``.

    Args:
        model_name: name passed to ``demucs.pretrained.get_model``.
        out_dir: root output directory; the model sub-directory is created.

    Returns:
        ``True`` after both files are written, ``False`` if the model could
        not be loaded.
    """
    from demucs.pretrained import get_model
    from demucs.apply import BagOfModels

    print(f"\n--- Exporting {model_name} ---")
    try:
        model = get_model(model_name)
    except Exception as e:
        print(f" Failed to load model: {e}")
        return False

    is_bag = isinstance(model, BagOfModels)
    sub_models = list(model.models) if is_bag else [model]
    num_models = len(sub_models)
    bag_weights = None
    if is_bag:
        raw_weights = model.weights
        bag_weights = raw_weights.tolist() if hasattr(raw_weights, "tolist") else list(raw_weights)

    summary = f"Bag of {num_models} models" if is_bag else "Single model"
    print(f" {summary}")

    # Gather converted tensors and per-sub-model metadata.
    all_weights: dict[str, np.ndarray] = {}
    model_classes: list[str] = []
    model_configs: list[dict] = []
    for i, sub in enumerate(sub_models):
        cls_name = type(sub).__name__
        mlx_cls = CLASS_MAP.get(cls_name, cls_name)
        model_classes.append(mlx_cls)
        print(f" Model {i}: {cls_name} → {mlx_cls}")

        # Bag members are namespaced so their keys cannot clash.
        prefix = f"model_{i}." if is_bag else ""
        all_weights.update(convert_sub_model(sub, prefix))
        model_configs.append({"model_class": mlx_cls, "kwargs": extract_kwargs(sub)})

    # Assemble the config; insertion order determines JSON key order.
    config: dict = {"model_name": model_name, "tensor_count": len(all_weights)}
    if not is_bag:
        config.update(model_class=model_classes[0], kwargs=model_configs[0]["kwargs"])
    else:
        config.update(
            model_class="BagOfModelsMLX",
            num_models=num_models,
            weights=bag_weights,
            sub_model_classes=model_classes,
        )
        # Homogeneous bags also expose the shared class under one key.
        if len(set(model_classes)) == 1:
            config["sub_model_class"] = model_classes[0]
        config["model_configs"] = model_configs
        # Single-model bags repeat their kwargs at the top level.
        if num_models == 1:
            config["kwargs"] = model_configs[0]["kwargs"]

    model_dir = out_dir / model_name
    model_dir.mkdir(parents=True, exist_ok=True)
    safetensors_path = model_dir / f"{model_name}.safetensors"
    config_path = model_dir / f"{model_name}_config.json"

    # Prefer the safetensors package; fall back to mlx when it is missing.
    try:
        from safetensors.numpy import save_file
        save_file(all_weights, str(safetensors_path))
    except ImportError:
        import mlx.core as mx
        mx.save_safetensors(
            str(safetensors_path),
            {name: mx.array(arr) for name, arr in all_weights.items()},
        )

    with config_path.open("w") as fh:
        json.dump(config, fh, indent=2, default=str)

    size_mb = safetensors_path.stat().st_size / (1024 * 1024)
    print(f" Wrote {safetensors_path} ({len(all_weights)} tensors, {size_mb:.0f} MB)")
    print(f" Wrote {config_path}")
    return True
def main():
    """Parse command-line options, export each requested model, report totals.

    Exits with status 1 when at least one export fails.
    """
    parser = argparse.ArgumentParser(
        description="Export Demucs PyTorch models to safetensors for Swift MLX"
    )
    parser.add_argument(
        "--models",
        nargs="*",
        default=None,
        help=f"Models to export (default: all). Choices: {', '.join(ALL_MODELS)}",
    )
    parser.add_argument(
        "--out-dir",
        default="./Models",
        help="Output root directory (files go into <out-dir>/<model_name>/)",
    )
    args = parser.parse_args()

    # An absent or empty --models list means "export everything".
    selected = args.models or ALL_MODELS
    out_dir = Path(args.out_dir).resolve()

    outcomes = [export_model(name, out_dir) for name in selected]
    exported = sum(1 for ok in outcomes if ok)
    failed = len(outcomes) - exported

    print(f"\n=== Done: {exported} exported, {failed} failed ===")
    if failed:
        sys.exit(1)


if __name__ == "__main__":
    main()