Spaces:
Running
Running
models : add conversion scripts from HuggingFace models to CoreML (#1304)
Browse files- models/convert-h5-to-coreml.py +117 -0
- models/generate-coreml-model.sh +17 -6
models/convert-h5-to-coreml.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import importlib.util
|
| 3 |
+
|
| 4 |
+
spec = importlib.util.spec_from_file_location('whisper_to_coreml', 'models/convert-whisper-to-coreml.py')
|
| 5 |
+
whisper_to_coreml = importlib.util.module_from_spec(spec)
|
| 6 |
+
spec.loader.exec_module(whisper_to_coreml)
|
| 7 |
+
|
| 8 |
+
from whisper import load_model
|
| 9 |
+
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
import torch
|
| 12 |
+
from transformers import WhisperForConditionalGeneration
|
| 13 |
+
from huggingface_hub import metadata_update
|
| 14 |
+
|
| 15 |
+
# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py
|
| 16 |
+
WHISPER_MAPPING = {
|
| 17 |
+
"layers": "blocks",
|
| 18 |
+
"fc1": "mlp.0",
|
| 19 |
+
"fc2": "mlp.2",
|
| 20 |
+
"final_layer_norm": "mlp_ln",
|
| 21 |
+
"layers": "blocks",
|
| 22 |
+
".self_attn.q_proj": ".attn.query",
|
| 23 |
+
".self_attn.k_proj": ".attn.key",
|
| 24 |
+
".self_attn.v_proj": ".attn.value",
|
| 25 |
+
".self_attn_layer_norm": ".attn_ln",
|
| 26 |
+
".self_attn.out_proj": ".attn.out",
|
| 27 |
+
".encoder_attn.q_proj": ".cross_attn.query",
|
| 28 |
+
".encoder_attn.k_proj": ".cross_attn.key",
|
| 29 |
+
".encoder_attn.v_proj": ".cross_attn.value",
|
| 30 |
+
".encoder_attn_layer_norm": ".cross_attn_ln",
|
| 31 |
+
".encoder_attn.out_proj": ".cross_attn.out",
|
| 32 |
+
"decoder.layer_norm.": "decoder.ln.",
|
| 33 |
+
"encoder.layer_norm.": "encoder.ln_post.",
|
| 34 |
+
"embed_tokens": "token_embedding",
|
| 35 |
+
"encoder.embed_positions.weight": "encoder.positional_embedding",
|
| 36 |
+
"decoder.embed_positions.weight": "decoder.positional_embedding",
|
| 37 |
+
"layer_norm": "ln_post",
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py
|
| 41 |
+
def rename_keys(s_dict):
|
| 42 |
+
keys = list(s_dict.keys())
|
| 43 |
+
for key in keys:
|
| 44 |
+
new_key = key
|
| 45 |
+
for k, v in WHISPER_MAPPING.items():
|
| 46 |
+
if k in key:
|
| 47 |
+
new_key = new_key.replace(k, v)
|
| 48 |
+
|
| 49 |
+
print(f"{key} -> {new_key}")
|
| 50 |
+
|
| 51 |
+
s_dict[new_key] = s_dict.pop(key)
|
| 52 |
+
return s_dict
|
| 53 |
+
|
| 54 |
+
# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py
|
| 55 |
+
def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
|
| 56 |
+
transformer_model = WhisperForConditionalGeneration.from_pretrained(hf_model_name_or_path)
|
| 57 |
+
config = transformer_model.config
|
| 58 |
+
|
| 59 |
+
# first build dims
|
| 60 |
+
dims = {
|
| 61 |
+
'n_mels': config.num_mel_bins,
|
| 62 |
+
'n_vocab': config.vocab_size,
|
| 63 |
+
'n_audio_ctx': config.max_source_positions,
|
| 64 |
+
'n_audio_state': config.d_model,
|
| 65 |
+
'n_audio_head': config.encoder_attention_heads,
|
| 66 |
+
'n_audio_layer': config.encoder_layers,
|
| 67 |
+
'n_text_ctx': config.max_target_positions,
|
| 68 |
+
'n_text_state': config.d_model,
|
| 69 |
+
'n_text_head': config.decoder_attention_heads,
|
| 70 |
+
'n_text_layer': config.decoder_layers
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
state_dict = deepcopy(transformer_model.model.state_dict())
|
| 74 |
+
state_dict = rename_keys(state_dict)
|
| 75 |
+
|
| 76 |
+
torch.save({"dims": dims, "model_state_dict": state_dict}, whisper_state_path)
|
| 77 |
+
|
| 78 |
+
# Ported from models/convert-whisper-to-coreml.py
|
| 79 |
+
if __name__ == "__main__":
|
| 80 |
+
parser = argparse.ArgumentParser()
|
| 81 |
+
parser.add_argument("--model-name", type=str, help="name of model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1)", required=True)
|
| 82 |
+
parser.add_argument("--model-path", type=str, help="path to the model (e.g. if published on HuggingFace: Oblivion208/whisper-tiny-cantonese)", required=True)
|
| 83 |
+
parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False)
|
| 84 |
+
parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False)
|
| 85 |
+
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
|
| 86 |
+
args = parser.parse_args()
|
| 87 |
+
|
| 88 |
+
if args.model_name not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1"]:
|
| 89 |
+
raise ValueError("Invalid model name")
|
| 90 |
+
|
| 91 |
+
pt_target_path = f"models/hf-{args.model_name}.pt"
|
| 92 |
+
convert_hf_whisper(args.model_path, pt_target_path)
|
| 93 |
+
|
| 94 |
+
whisper = load_model(pt_target_path).cpu()
|
| 95 |
+
hparams = whisper.dims
|
| 96 |
+
print(hparams)
|
| 97 |
+
|
| 98 |
+
if args.optimize_ane:
|
| 99 |
+
whisperANE = whisper_to_coreml.WhisperANE(hparams).eval()
|
| 100 |
+
whisperANE.load_state_dict(whisper.state_dict())
|
| 101 |
+
|
| 102 |
+
encoder = whisperANE.encoder
|
| 103 |
+
decoder = whisperANE.decoder
|
| 104 |
+
else:
|
| 105 |
+
encoder = whisper.encoder
|
| 106 |
+
decoder = whisper.decoder
|
| 107 |
+
|
| 108 |
+
# Convert encoder
|
| 109 |
+
encoder = whisper_to_coreml.convert_encoder(hparams, encoder, quantize=args.quantize)
|
| 110 |
+
encoder.save(f"models/coreml-encoder-{args.model_name}.mlpackage")
|
| 111 |
+
|
| 112 |
+
if args.encoder_only is False:
|
| 113 |
+
# Convert decoder
|
| 114 |
+
decoder = whisper_to_coreml.convert_decoder(hparams, decoder, quantize=args.quantize)
|
| 115 |
+
decoder.save(f"models/coreml-decoder-{args.model_name}.mlpackage")
|
| 116 |
+
|
| 117 |
+
print("done converting")
|
models/generate-coreml-model.sh
CHANGED
|
@@ -1,11 +1,15 @@
|
|
| 1 |
#!/bin/bash
|
| 2 |
|
| 3 |
# Usage: ./generate-coreml-model.sh <model-name>
|
| 4 |
-
if [ $# -eq 0 ]
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
fi
|
| 10 |
|
| 11 |
mname="$1"
|
|
@@ -13,7 +17,14 @@ mname="$1"
|
|
| 13 |
wd=$(dirname "$0")
|
| 14 |
cd "$wd/../"
|
| 15 |
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
xcrun coremlc compile models/coreml-encoder-${mname}.mlpackage models/
|
| 19 |
rm -rf models/ggml-${mname}-encoder.mlmodelc
|
|
|
|
| 1 |
#!/bin/bash
|
| 2 |
|
| 3 |
# Usage: ./generate-coreml-model.sh <model-name>
|
| 4 |
+
if [ $# -eq 0 ]; then
|
| 5 |
+
echo "No model name supplied"
|
| 6 |
+
echo "Usage for Whisper models: ./generate-coreml-model.sh <model-name>"
|
| 7 |
+
echo "Usage for HuggingFace models: ./generate-coreml-model.sh -h5 <model-name> <model-path>"
|
| 8 |
+
exit 1
|
| 9 |
+
elif [[ "$1" == "-h5" && $# != 3 ]]; then
|
| 10 |
+
echo "No model name and model path supplied for a HuggingFace model"
|
| 11 |
+
echo "Usage for HuggingFace models: ./generate-coreml-model.sh -h5 <model-name> <model-path>"
|
| 12 |
+
exit 1
|
| 13 |
fi
|
| 14 |
|
| 15 |
mname="$1"
|
|
|
|
| 17 |
wd=$(dirname "$0")
|
| 18 |
cd "$wd/../"
|
| 19 |
|
| 20 |
+
if [[ $mname == "-h5" ]]; then
|
| 21 |
+
mname="$2"
|
| 22 |
+
mpath="$3"
|
| 23 |
+
echo $mpath
|
| 24 |
+
python3 models/convert-h5-to-coreml.py --model-name $mname --model-path $mpath --encoder-only True
|
| 25 |
+
else
|
| 26 |
+
python3 models/convert-whisper-to-coreml.py --model $mname --encoder-only True
|
| 27 |
+
fi
|
| 28 |
|
| 29 |
xcrun coremlc compile models/coreml-encoder-${mname}.mlpackage models/
|
| 30 |
rm -rf models/ggml-${mname}-encoder.mlmodelc
|