models : add conversion scripts from HuggingFace models to CoreML (#1304)
1 parent 2f668c3 · commit 91c0b23
Showing 2 changed files with 134 additions and 6 deletions.
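The new script wires the Hugging Face checkpoint format into the existing CoreML conversion path, so a fine-tuned Whisper model published on the Hub can be exported to an .mlpackage directly. As a usage sketch (the script path models/convert-h5-to-coreml.py is not shown on this page and is assumed here; the model id is taken from the script's own help text), a typical run would look like:

    python models/convert-h5-to-coreml.py --model-name tiny --model-path Oblivion208/whisper-tiny-cantonese --encoder-only True

This first writes an intermediate models/hf-tiny.pt checkpoint in openai-whisper format and then saves models/coreml-encoder-tiny.mlpackage, as the __main__ block below shows.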
@@ -0,0 +1,117 @@
import argparse
import importlib.util

spec = importlib.util.spec_from_file_location('whisper_to_coreml', 'models/convert-whisper-to-coreml.py')
whisper_to_coreml = importlib.util.module_from_spec(spec)
spec.loader.exec_module(whisper_to_coreml)

from whisper import load_model

from copy import deepcopy
import torch
from transformers import WhisperForConditionalGeneration
from huggingface_hub import metadata_update

# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py
WHISPER_MAPPING = {
    "layers": "blocks",
    "fc1": "mlp.0",
    "fc2": "mlp.2",
    "final_layer_norm": "mlp_ln",
    ".self_attn.q_proj": ".attn.query",
    ".self_attn.k_proj": ".attn.key",
    ".self_attn.v_proj": ".attn.value",
    ".self_attn_layer_norm": ".attn_ln",
    ".self_attn.out_proj": ".attn.out",
    ".encoder_attn.q_proj": ".cross_attn.query",
    ".encoder_attn.k_proj": ".cross_attn.key",
    ".encoder_attn.v_proj": ".cross_attn.value",
    ".encoder_attn_layer_norm": ".cross_attn_ln",
    ".encoder_attn.out_proj": ".cross_attn.out",
    "decoder.layer_norm.": "decoder.ln.",
    "encoder.layer_norm.": "encoder.ln_post.",
    "embed_tokens": "token_embedding",
    "encoder.embed_positions.weight": "encoder.positional_embedding",
    "decoder.embed_positions.weight": "decoder.positional_embedding",
    "layer_norm": "ln_post",
}

# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py
def rename_keys(s_dict):
    keys = list(s_dict.keys())
    for key in keys:
        new_key = key
        for k, v in WHISPER_MAPPING.items():
            if k in key:
                new_key = new_key.replace(k, v)

        print(f"{key} -> {new_key}")

        s_dict[new_key] = s_dict.pop(key)
    return s_dict

# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py
def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
    transformer_model = WhisperForConditionalGeneration.from_pretrained(hf_model_name_or_path)
    config = transformer_model.config

    # first build dims
    dims = {
        'n_mels': config.num_mel_bins,
        'n_vocab': config.vocab_size,
        'n_audio_ctx': config.max_source_positions,
        'n_audio_state': config.d_model,
        'n_audio_head': config.encoder_attention_heads,
        'n_audio_layer': config.encoder_layers,
        'n_text_ctx': config.max_target_positions,
        'n_text_state': config.d_model,
        'n_text_head': config.decoder_attention_heads,
        'n_text_layer': config.decoder_layers
    }

    state_dict = deepcopy(transformer_model.model.state_dict())
    state_dict = rename_keys(state_dict)

    torch.save({"dims": dims, "model_state_dict": state_dict}, whisper_state_path)

# Ported from models/convert-whisper-to-coreml.py
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, help="name of model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1)", required=True)
    parser.add_argument("--model-path", type=str, help="path to the model (e.g. if published on HuggingFace: Oblivion208/whisper-tiny-cantonese)", required=True)
    # note: argparse's type=bool treats any non-empty string as True (so "--encoder-only False" would still enable it)
    parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False)
    parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False)
    parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
    args = parser.parse_args()

    if args.model_name not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1"]:
        raise ValueError("Invalid model name")

    pt_target_path = f"models/hf-{args.model_name}.pt"
    convert_hf_whisper(args.model_path, pt_target_path)

    whisper = load_model(pt_target_path).cpu()
    hparams = whisper.dims
    print(hparams)

    if args.optimize_ane:
        whisperANE = whisper_to_coreml.WhisperANE(hparams).eval()
        whisperANE.load_state_dict(whisper.state_dict())

        encoder = whisperANE.encoder
        decoder = whisperANE.decoder
    else:
        encoder = whisper.encoder
        decoder = whisper.decoder

    # Convert encoder
    encoder = whisper_to_coreml.convert_encoder(hparams, encoder, quantize=args.quantize)
    encoder.save(f"models/coreml-encoder-{args.model_name}.mlpackage")

    if args.encoder_only is False:
        # Convert decoder
        decoder = whisper_to_coreml.convert_decoder(hparams, decoder, quantize=args.quantize)
        decoder.save(f"models/coreml-decoder-{args.model_name}.mlpackage")

    print("done converting")