Add support for finetuning with huggingface datasets #7834
Merged
Changes from 14 commits
6367e2e add finetune with huggingface dataset (stevehuang52)
d132f1d Merge remote-tracking branch 'origin' into add_hf_finetune (stevehuang52)
22c8c21 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
763766d update yaml (stevehuang52)
141417b Merge branch 'add_hf_finetune' of https://github.com/NVIDIA/NeMo into… (stevehuang52)
3e36a85 update (stevehuang52)
8704df3 update and refactor (stevehuang52)
baea154 add extrac hf text and update (stevehuang52)
31740ef update and refactor (stevehuang52)
1d87db1 move dataset dependency to common (stevehuang52)
40c643b add docstring (stevehuang52)
c9994a1 Add to Dics
d80ea6c Merge branch 'main' into add_hf_finetune (nithinraok)
8353abb add ci test
ab75759 add max steps in jenkins
db5a2ce reduce max steps
ec75267 jenkins test
aeeec74 add bs=2
File renamed without changes.
189 changes: 189 additions & 0 deletions
examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
name: "Speech_To_Text_HF_Finetuning_using_HF_Datasets" | ||
|
||
# use `init_from_nemo_model` or `init_from_pretrained_model` to initialize the model | ||
# We do not currently support `init_from_ptl_ckpt` to create a single script for all types of models. | ||
init_from_nemo_model: null # path to nemo model | ||
init_from_pretrained_model: null # name of pretrained NeMo model, e.g., `stt_en_fastconformer_transducer_large` | ||
|
||
model: | ||
  sample_rate: 16000 | ||
  compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag. | ||
  log_prediction: true # enables logging sample predictions in the output during training | ||
  rnnt_reduction: 'mean_volume' | ||
  skip_nan_grad: false | ||
|
||
  # configs for huggingface load_dataset function | ||
  data_path: "librispeech_asr" | ||
  data_name: null # name for the specific dataset to load, e.g., 'en' for MCV datasets, but some datasets don't require this field. | ||
  streaming: false # set True to use streaming mode, which doesn't wait for data downloading but each training step takes longer in the first epoch. If True, you'll need to specify trainer.max_steps instead of trainer.max_epochs. | ||
|
||
  # keys for audio, sample_rate and transcription in the huggingface dataset, keys separated by `.` for nested fields. See example at the bottom of this file. | ||
  audio_key: "audio.array" | ||
  sample_rate_key: "audio.sampling_rate" | ||
  text_key: "text" # the key for ground-truth transcription, e.g., MCV usually uses "sentence" while some others use "text" | ||
|
||
  # simple text cleaning, by default converts all chars to lower-case and only keeps alpha-numeric chars. | ||
  normalize_text: true | ||
  symbols_to_keep: ["'"] # a list of symbols to keep during text cleaning. | ||
|
||
  train_ds: | ||
    manifest_filepath: "hugginface" # set to a non-None value to avoid breaking existing code | ||
    streaming: ${model.streaming} | ||
    normalize_text: ${model.normalize_text} | ||
    symbols_to_keep: ${model.symbols_to_keep} | ||
    audio_key: ${model.audio_key} | ||
    sample_rate_key: ${model.sample_rate_key} | ||
    text_key: ${model.text_key} | ||
    hf_data_cfg: # hf_data_cfg can be a ListConfig or DictConfig. Params for each dataset are passed into huggingface load_dataset(). Add more params if needed | ||
      - path: ${model.data_path} | ||
        name: ${model.data_name} | ||
        split: 'train.clean.360' | ||
        streaming: ${model.streaming} | ||
      - path: ${model.data_path} | ||
        name: ${model.data_name} | ||
        split: 'train.clean.100' | ||
        streaming: ${model.streaming} | ||
      - path: ${model.data_path} | ||
        name: ${model.data_name} | ||
        split: 'train.other.500' | ||
        streaming: ${model.streaming} | ||
|
||
    sample_rate: ${model.sample_rate} | ||
    batch_size: 16 # you may increase batch_size if your memory allows | ||
    shuffle: true | ||
    shuffle_n: 2048 | ||
    num_workers: 8 | ||
    pin_memory: true | ||
    use_start_end_token: false | ||
|
||
  validation_ds: | ||
    manifest_filepath: "hugginface" # set to a non-None value to avoid breaking existing code | ||
    streaming: ${model.streaming} | ||
    normalize_text: ${model.normalize_text} | ||
    symbols_to_keep: ${model.symbols_to_keep} | ||
    audio_key: ${model.audio_key} | ||
    sample_rate_key: ${model.sample_rate_key} | ||
    text_key: ${model.text_key} | ||
    hf_data_cfg: # An example of using only one dataset | ||
      path: ${model.data_path} | ||
      name: ${model.data_name} | ||
      split: 'validation.other' | ||
      streaming: ${model.streaming} | ||
|
||
    sample_rate: ${model.sample_rate} | ||
    batch_size: 8 | ||
    shuffle: false | ||
    shuffle_n: 2048 | ||
    num_workers: 8 | ||
    pin_memory: true | ||
    use_start_end_token: false | ||
|
||
  test_ds: | ||
    manifest_filepath: "hugginface" # set to a non-None value to avoid breaking existing code | ||
    streaming: ${model.streaming} | ||
    normalize_text: ${model.normalize_text} | ||
    symbols_to_keep: ${model.symbols_to_keep} | ||
    audio_key: ${model.audio_key} | ||
    sample_rate_key: ${model.sample_rate_key} | ||
    text_key: ${model.text_key} | ||
    hf_data_cfg: # hf_data_cfg can be a ListConfig or DictConfig. Params for each dataset are passed into huggingface load_dataset(). Add more params if needed | ||
      - path: ${model.data_path} | ||
        name: ${model.data_name} | ||
        split: 'test.other' | ||
        streaming: ${model.streaming} | ||
      - path: ${model.data_path} | ||
        name: ${model.data_name} | ||
        split: 'test.clean' | ||
        streaming: ${model.streaming} | ||
|
||
    sample_rate: ${model.sample_rate} | ||
    batch_size: 8 | ||
    shuffle: false | ||
    shuffle_n: 2048 | ||
    num_workers: 8 | ||
    pin_memory: true | ||
    use_start_end_token: false | ||
|
||
  char_labels: # use for char based models | ||
    update_labels: false | ||
    labels: null # example list config: [' ', 'a', 'b', 'c'] | ||
|
||
  tokenizer: # use for spe/bpe based tokenizer models | ||
    update_tokenizer: false | ||
    dir: null # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe) | ||
    type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) | ||
|
||
  spec_augment: | ||
    _target_: nemo.collections.asr.modules.SpectrogramAugmentation | ||
    freq_masks: 2 # set to zero to disable it | ||
    time_masks: 10 # set to zero to disable it | ||
    freq_width: 27 | ||
    time_width: 0.05 | ||
|
||
  optim: | ||
    name: adamw | ||
    lr: 1e-4 | ||
    # optimizer arguments | ||
    betas: [0.9, 0.98] | ||
    weight_decay: 1e-3 | ||
|
||
    # scheduler setup | ||
    sched: | ||
      name: CosineAnnealing | ||
      # scheduler config override | ||
      warmup_steps: 5000 | ||
      warmup_ratio: null | ||
      min_lr: 5e-6 | ||
|
||
trainer: | ||
  devices: -1 # number of GPUs, -1 would use all available GPUs | ||
  num_nodes: 1 | ||
  max_epochs: 100 | ||
  max_steps: -1 # computed at runtime if not set | ||
  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations | ||
  accelerator: auto | ||
  strategy: ddp | ||
  accumulate_grad_batches: 1 | ||
  gradient_clip_val: 0.0 | ||
  precision: 32 # 16, 32, or bf16 | ||
  log_every_n_steps: 10 # Interval of logging. | ||
  enable_progress_bar: True | ||
  num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training starts; setting to 0 disables it | ||
  check_val_every_n_epoch: 1 # run validation every n epochs | ||
  sync_batchnorm: true | ||
  enable_checkpointing: False # Provided by exp_manager | ||
  logger: false # Provided by exp_manager | ||
  benchmark: false # needs to be false for models with variable-length speech input as it slows down training | ||
|
||
|
||
exp_manager: | ||
  exp_dir: null | ||
  name: ${name} | ||
  create_tensorboard_logger: true | ||
  create_checkpoint_callback: true | ||
  checkpoint_callback_params: | ||
    # in case of multiple validation sets, first one is used | ||
    monitor: "val_wer" | ||
    mode: "min" | ||
    save_top_k: 5 | ||
    always_save_nemo: True # saves the checkpoints as nemo files along with PTL checkpoints | ||
  resume_if_exists: false | ||
  resume_ignore_no_checkpoint: false | ||
|
||
  create_wandb_logger: false | ||
  wandb_logger_kwargs: | ||
    name: null | ||
    project: null | ||
|
||
|
||
# An example item in the HuggingFace `librispeech_asr` dataset: | ||
# {'chapter_id': 141231, | ||
# 'file': '/home/patrick/.cache/huggingface/datasets/downloads/extracted/b7ded9969e09942ab65313e691e6fc2e12066192ee8527e21d634aca128afbe2/dev_clean/1272/141231/1272-141231-0000.flac', | ||
# 'audio': { | ||
# 'path': '/home/patrick/.cache/huggingface/datasets/downloads/extracted/b7ded9969e09942ab65313e691e6fc2e12066192ee8527e21d634aca128afbe2/dev_clean/1272/141231/1272-141231-0000.flac', | ||
# 'array': array([-0.00048828, -0.00018311, -0.00137329, ..., 0.00079346, 0.00091553, 0.00085449], dtype=float32), | ||
# 'sampling_rate': 16000 | ||
# }, | ||
# 'id': '1272-141231-0000', | ||
# 'speaker_id': 1272, | ||
# 'text': 'A MAN SAID TO THE UNIVERSE SIR I EXIST'} |
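For readers following the config comments, here is a minimal, stdlib-only sketch of how the three conventions described in this YAML could be interpreted: `hf_data_cfg` given as either one mapping or a list of mappings, dotted keys such as `audio.array` for nested fields, and the simple text cleaning controlled by `normalize_text` and `symbols_to_keep`. The helper names `as_kwargs_list`, `get_nested`, and `clean_text` are hypothetical and are not taken from this PR. Keeping spaces and collapsing repeated whitespace during cleaning are also assumptions, and the sample item uses a shortened stand-in for the audio array.

```python
from functools import reduce
from typing import Any, Dict, Iterable, List, Mapping, Sequence, Union


def as_kwargs_list(
    hf_data_cfg: Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]
) -> List[Dict[str, Any]]:
    """Turn a single mapping or a list of mappings into a list of kwargs dicts.

    Each dict holds the parameters (path, name, split, streaming) that the
    config says are passed to huggingface load_dataset().
    """
    if isinstance(hf_data_cfg, Mapping):
        return [dict(hf_data_cfg)]
    return [dict(cfg) for cfg in hf_data_cfg]


def get_nested(item: Mapping[str, Any], dotted_key: str) -> Any:
    """Follow a '.'-separated key such as 'audio.array' through nested dicts."""
    return reduce(lambda node, key: node[key], dotted_key.split("."), item)


def clean_text(text: str, symbols_to_keep: Iterable[str] = ("'",)) -> str:
    """Lower-case the text and keep only alphanumerics, whitespace, and listed symbols.

    Assumption: whitespace is preserved (and collapsed) so that word
    boundaries survive; the config comment only mentions alphanumerics.
    """
    keep = set(symbols_to_keep)
    kept = (ch for ch in text.lower() if ch.isalnum() or ch.isspace() or ch in keep)
    return " ".join("".join(kept).split())


# Plain dicts standing in for the resolved validation_ds / test_ds entries.
val_cfg = {"path": "librispeech_asr", "name": None, "split": "validation.other", "streaming": False}
test_cfg = [
    {"path": "librispeech_asr", "name": None, "split": "test.other", "streaming": False},
    {"path": "librispeech_asr", "name": None, "split": "test.clean", "streaming": False},
]

# Shortened stand-in for the example item shown at the bottom of the YAML.
item = {
    "audio": {"array": [-0.00048828, -0.00018311], "sampling_rate": 16000},
    "text": "A MAN SAID TO THE UNIVERSE SIR I EXIST",
}

if __name__ == "__main__":
    for kwargs in as_kwargs_list(val_cfg) + as_kwargs_list(test_cfg):
        print(kwargs["split"])
    print(get_nested(item, "audio.sampling_rate"))
    print(clean_text(get_nested(item, "text")))
```

With the `datasets` package installed, each dict returned by `as_kwargs_list` could be unpacked into `datasets.load_dataset(**kwargs)`; that call is left out here so the sketch runs without network access.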
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
Let's set all batch sizes to 1 for train and val.