Add support for finetuning with huggingface datasets #7834

Merged: 18 commits, Nov 3, 2023
Changes from 5 commits
172 changes: 172 additions & 0 deletions examples/asr/conf/speech_to_text_hf_finetune.yaml
@@ -0,0 +1,172 @@
name: "Speech_To_Text_HF_Finetuning"

# use `init_from_nemo_model` or `init_from_pretrained_model` to initialize the model
# `init_from_ptl_ckpt` is not currently supported, so that a single script can serve all model types.
init_from_nemo_model: null # path to nemo model
init_from_pretrained_model: null # name of pretrained NeMo model, e.g., `stt_en_fastconformer_transducer_large`
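# Example (illustrative): to start finetuning from a public NeMo checkpoint, set
#   init_from_pretrained_model: "stt_en_fastconformer_transducer_large"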

model:
sample_rate: 16000
compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag.
log_prediction: true # enables logging sample predictions in the output during training
rnnt_reduction: 'mean_volume'
skip_nan_grad: false

# configs for huggingface load_dataset function
data_path: "librispeech_asr"
data_name: null
num_proc: 8 # Number of processes when downloading and generating the dataset locally.
streaming: false # set to True to use streaming mode, which starts training without waiting for the full download but makes each step slower during the first epoch. If True, you must set trainer.max_steps instead of trainer.max_epochs.
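# For reference, the four settings above are forwarded to HuggingFace's `datasets.load_dataset()`;
# with the defaults here this is roughly equivalent to (illustrative; `split` is set per dataset section below):
#   load_dataset(path="librispeech_asr", name=None, split=..., streaming=False, num_proc=8)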

# keys for audio, sample_rate and text in the huggingface dataset; keys of nested fields are separated by `.`.
# An example of data item in the `librispeech_asr` dataset:
# {'chapter_id': 141231,
# 'file': '/home/patrick/.cache/huggingface/datasets/downloads/extracted/b7ded9969e09942ab65313e691e6fc2e12066192ee8527e21d634aca128afbe2/dev_clean/1272/141231/1272-141231-0000.flac',
# 'audio': {'path': '/home/patrick/.cache/huggingface/datasets/downloads/extracted/b7ded9969e09942ab65313e691e6fc2e12066192ee8527e21d634aca128afbe2/dev_clean/1272/141231/1272-141231-0000.flac',
# 'array': array([-0.00048828, -0.00018311, -0.00137329, ..., 0.00079346,
# 0.00091553, 0.00085449], dtype=float32),
# 'sampling_rate': 16000},
# 'id': '1272-141231-0000',
# 'speaker_id': 1272,
# 'text': 'A MAN SAID TO THE UNIVERSE SIR I EXIST'}
audio_key: "audio.array"
sample_rate_key: "audio.sampling_rate"
text_key: "text"
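# With the example item above, `audio.array` thus resolves to item['audio']['array'],
# `audio.sampling_rate` to item['audio']['sampling_rate'], and `text` to item['text'].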

# simple text cleaning: by default, converts all characters to lower-case and keeps only alphanumeric characters (plus any `symbols_to_keep`).
normalize_text: true
symbols_to_keep: "-'"
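# Illustrative example (the exact behavior is defined by the cleaning function):
# with the defaults above, "Hello, World! It's over." -> "hello world it's over",
# the apostrophe surviving because of `symbols_to_keep`.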

train_ds:
normalize_text: ${model.normalize_text}
symbols_to_keep: ${model.symbols_to_keep}
manifest_filepath: "huggingface" # placeholder; set to any non-None value to avoid breaking existing code
audio_key: ${model.audio_key}
sample_rate_key: ${model.sample_rate_key}
text_key: ${model.text_key}
hf_data_cfg: # configs for huggingface load_dataset, add more if needed
path: ${model.data_path}
name: ${model.data_name}
split: 'validation.clean'
streaming: ${model.streaming}
num_proc: ${model.num_proc}

sample_rate: ${model.sample_rate}
batch_size: 16 # you may increase batch_size if your memory allows
shuffle: true
num_workers: 0
pin_memory: true
use_start_end_token: false

validation_ds:
normalize_text: ${model.normalize_text}
symbols_to_keep: ${model.symbols_to_keep}
manifest_filepath: "huggingface" # placeholder; set to any non-None value to avoid breaking existing code
audio_key: ${model.audio_key}
sample_rate_key: ${model.sample_rate_key}
text_key: ${model.text_key}
hf_data_cfg: # configs for huggingface load_dataset, add more if needed
path: ${model.data_path}
name: ${model.data_name}
split: 'validation.clean'
streaming: ${model.streaming}
num_proc: ${model.num_proc}

sample_rate: ${model.sample_rate}
batch_size: 16
shuffle: false
num_workers: 0
pin_memory: true
use_start_end_token: false

test_ds:
normalize_text: ${model.normalize_text}
symbols_to_keep: ${model.symbols_to_keep}
manifest_filepath: "huggingface" # placeholder; set to any non-None value to avoid breaking existing code
audio_key: ${model.audio_key}
sample_rate_key: ${model.sample_rate_key}
text_key: ${model.text_key}
hf_data_cfg: # configs for huggingface load_dataset
path: ${model.data_path}
name: ${model.data_name}
split: 'test.other'
streaming: ${model.streaming}
num_proc: ${model.num_proc}

sample_rate: ${model.sample_rate}
batch_size: 16
shuffle: false
num_workers: 0
pin_memory: true
use_start_end_token: false

char_labels: # use for char based models
update_labels: false
labels: null # example list config: \[' ', 'a', 'b', 'c'\]

tokenizer: # use for spe/bpe based tokenizer models
update_tokenizer: false
dir: null # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)
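# Example (illustrative, hypothetical path): to swap in a tokenizer built for the target data:
#   update_tokenizer: true
#   dir: "path/to/new_tokenizer_dir"
#   type: bpe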

spec_augment:
_target_: nemo.collections.asr.modules.SpectrogramAugmentation
freq_masks: 2 # set to zero to disable it
time_masks: 10 # set to zero to disable it
freq_width: 27
time_width: 0.05

optim:
name: adamw
lr: 1e-4
# optimizer arguments
betas: [0.9, 0.98]
weight_decay: 1e-3

# scheduler setup
sched:
name: CosineAnnealing
# scheduler config override
warmup_steps: 5000
warmup_ratio: null
min_lr: 5e-6

trainer:
devices: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: 100
max_steps: -1 # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
strategy: ddp
accumulate_grad_batches: 1
gradient_clip_val: 0.0
precision: 32 # 16, 32, or bf16
log_every_n_steps: 10 # Interval of logging.
enable_progress_bar: True
num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training starts; set to 0 to disable
check_val_every_n_epoch: 1 # run validation every n epochs
sync_batchnorm: true
enable_checkpointing: False # Provided by exp_manager
logger: false # Provided by exp_manager
benchmark: false # needs to be false for models with variable-length speech input, as cudnn benchmarking would slow down training


exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: true
create_checkpoint_callback: true
checkpoint_callback_params:
# in case of multiple validation sets, first one is used
monitor: "val_wer"
mode: "min"
save_top_k: 5
always_save_nemo: True # saves the checkpoints as nemo files along with PTL checkpoints
resume_if_exists: false
resume_ignore_no_checkpoint: false

create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null
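As a quick sanity check before training, the dataset half of this config can be exercised directly with the HuggingFace `datasets` library. The following is a minimal sketch, not part of the PR: the `resolve` helper is hypothetical, and the values mirror the example config above.

from datasets import load_dataset

# Load the same split the config points at (non-streaming).
ds = load_dataset(
    path="librispeech_asr",
    name=None,
    split="validation.clean",
    streaming=False,
    num_proc=8,
)

def resolve(item, dotted_key):
    # Walk nested fields: "audio.array" -> item["audio"]["array"].
    for key in dotted_key.split("."):
        item = item[key]
    return item

sample = ds[0]
audio = resolve(sample, "audio.array")                # float32 waveform
sample_rate = resolve(sample, "audio.sampling_rate")  # 16000 for librispeech_asr
text = resolve(sample, "text")                        # transcript string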
14 changes: 14 additions & 0 deletions nemo/collections/asr/data/audio_to_text_dataset.py
@@ -25,6 +25,10 @@
from torch.utils.data import ChainDataset

from nemo.collections.asr.data import audio_to_text, audio_to_text_dali
from nemo.collections.asr.data.huggingface.hf_audio_to_text_dataset import (
get_hf_audio_to_text_bpe_dataset,
get_hf_audio_to_text_char_dataset,
)
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
from nemo.collections.common.data.dataset import CodeSwitchedDataset, ConcatDataset
from nemo.utils import logging
@@ -598,6 +602,11 @@ def get_audio_to_text_char_dataset_from_config(
else:
augmentor = None

if 'hf_data_cfg' in config:
return get_hf_audio_to_text_char_dataset(
config=config, global_rank=global_rank, world_size=world_size, augmentor=augmentor
)

is_concat = config.get('is_concat', False)
if is_concat:
if 'concat_sampling_technique' in config and config['concat_sampling_technique'] is None:
@@ -722,6 +731,11 @@ def get_audio_to_text_bpe_dataset_from_config(
else:
augmentor = None

if 'hf_data_cfg' in config:
return get_hf_audio_to_text_bpe_dataset(
config=config, global_rank=global_rank, world_size=world_size, tokenizer=tokenizer, augmentor=augmentor
)

is_concat = config.get('is_concat', False)
if is_concat:
if 'concat_sampling_technique' in config and config['concat_sampling_technique'] is None:
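For illustration, the new branch fires whenever a dataloader config carries an `hf_data_cfg` entry; everything else about the config keeps its existing shape. A minimal sketch of a config that would be routed to the HuggingFace loader (hypothetical values, mirroring the YAML above):

from omegaconf import OmegaConf

config = OmegaConf.create(
    {
        "manifest_filepath": "huggingface",  # placeholder; only needs to be non-None
        "sample_rate": 16000,
        "audio_key": "audio.array",
        "sample_rate_key": "audio.sampling_rate",
        "text_key": "text",
        "hf_data_cfg": {
            "path": "librispeech_asr",
            "split": "validation.clean",
            "streaming": False,
        },
    }
)

# Presence of the key is the entire dispatch condition:
assert "hf_data_cfg" in config  # -> routed to get_hf_audio_to_text_char/bpe_dataset(...)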
13 changes: 13 additions & 0 deletions nemo/collections/asr/data/huggingface/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.