Skip to content

Commit

Permalink
[TTS] Modify codec training tutorial to allow HF download (NVIDIA#11518)
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan <[email protected]>
  • Loading branch information
rlangman authored Dec 10, 2024
1 parent f1616c1 commit a426f38
Showing 1 changed file with 19 additions and 26 deletions.
45 changes: 19 additions & 26 deletions tutorials/tts/Audio_Codec_Training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
"BRANCH = 'main'\n",
"# Install NeMo library. If you are running locally (rather than on Google Colab), comment out the below line\n",
"# and instead follow the instructions at https://github.com/NVIDIA/NeMo#Installation\n",
"!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]"
"!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[tts]"
]
},
{
Expand Down Expand Up @@ -157,7 +157,7 @@
},
"outputs": [],
"source": [
"CONFIG_FILENAME = \"audio_codec_16000.yaml\"\n",
"CONFIG_FILENAME = \"audio_codec_22050.yaml\"\n",
"CONFIG_DIR = NEMO_CONFIG_DIR / \"audio_codec\"\n",
"\n",
"config_filepath = CONFIG_DIR / CONFIG_FILENAME\n",
Expand Down Expand Up @@ -187,44 +187,29 @@
"id": "W7F--_0maLh5"
},
"source": [
"We provide pretrained model checkpoints for fine-tuning. The list of available models can be found [here](https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/models/audio_codec.py#L645)."
"We provide pretrained model checkpoints for fine-tuning.\n",
"\n",
"A list of models available on NGC can be found [here](https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/models/audio_codec.py#L645).\n",
"\n",
"A list of models available on Hugging Face can be found [here](https://huggingface.co/collections/nvidia/nemo-audio-codecs-674f57ab6cb1324f997b5d5b). To use a checkpoint from hugging face, add \"nvidia/\" before the model name."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XqAYWR65aKTx"
"id": "cADIAIDUcGWd"
},
"outputs": [],
"source": [
"import wget\n",
"from nemo.collections.tts.models.audio_codec import AudioCodecModel\n",
"\n",
"# Optionally specify a pretrained model to fine-tune from. To train from scratch, set this to 'None'.\n",
"pretrained_model_name = \"audio_codec_16khz_small\"\n",
"pretrained_model_name = \"nvidia/audio-codec-22khz\"\n",
"\n",
"if pretrained_model_name is None:\n",
" MODEL_CHECKPOINT_PATH = None\n",
"else:\n",
" model_list = AudioCodecModel.list_available_models()\n",
"\n",
" pretrained_model_url = None\n",
" for model in model_list:\n",
" if model.pretrained_model_name == pretrained_model_name:\n",
" pretrained_model_url = model.location\n",
" break\n",
"\n",
" if pretrained_model_url is None:\n",
" raise ValueError(f\"Could not find pretrained model {pretrained_model_name}. Models available {model_list}\")\n",
"\n",
" # Optionally load pretrained checkpoint\n",
" MODEL_CHECKPOINT_PATH = ROOT_DIR / \"models\" / f\"{pretrained_model_name}.nemo\"\n",
"\n",
" if not MODEL_CHECKPOINT_PATH.exists():\n",
" print(f\"Downloading {pretrained_model_url} to {MODEL_CHECKPOINT_PATH}\")\n",
" MODEL_CHECKPOINT_PATH.parent.mkdir(exist_ok=True)\n",
" wget.download(pretrained_model_url, out=str(MODEL_CHECKPOINT_PATH))"
" MODEL_CHECKPOINT_PATH = AudioCodecModel.from_pretrained(model_name=pretrained_model_name, return_model_file=True)"
]
},
{
Expand Down Expand Up @@ -254,6 +239,7 @@
"outputs": [],
"source": [
"import tarfile\n",
"import wget\n",
"\n",
"from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest"
]
Expand Down Expand Up @@ -636,9 +622,12 @@
"if torch.cuda.is_available():\n",
" accelerator=\"gpu\"\n",
" batch_size = 4\n",
" devices = -1\n",
"else:\n",
" import multiprocessing\n",
" accelerator=\"cpu\"\n",
" batch_size = 2\n",
" devices = multiprocessing.cpu_count()\n",
"\n",
"args = [\n",
" f\"--config-path={CONFIG_DIR}\",\n",
Expand All @@ -653,6 +642,7 @@
" f\"model.log_config.log_tensorboard={log_to_tensorboard}\",\n",
" f\"model.log_config.generators.0.log_dequantized={log_dequantized}\",\n",
" f\"trainer.accelerator={accelerator}\",\n",
" f\"trainer.devices={devices}\",\n",
" f\"+train_ds_meta.{dataset_name}.manifest_path={train_manifest_filepath}\",\n",
" f\"+train_ds_meta.{dataset_name}.audio_dir={audio_dir}\",\n",
" f\"+val_ds_meta.{dataset_name}.manifest_path={dev_manifest_filepath}\",\n",
Expand Down Expand Up @@ -790,6 +780,9 @@
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
Expand All @@ -800,4 +793,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}

0 comments on commit a426f38

Please sign in to comment.