Skip to content

Commit

Permalink
Attention encoder-decoder models for multiple speech-to-text tasks (#…
Browse files Browse the repository at this point in the history
…8242)

* Rebasing canary changes at current main

Signed-off-by: Piotr Żelasko <[email protected]>

* Move the changes from asr transformer to nlp transformer as originally intended

Signed-off-by: Piotr Żelasko <[email protected]>

* update eval to strip spaces before punctuations

Signed-off-by: stevehuang52 <[email protected]>

* update pc strip

Signed-off-by: stevehuang52 <[email protected]>

* [canary] Refactor: `PromptedAudioToTextLhotseDataset` and `EncDecMultiTaskModel` (#8247)

* Create a separate CanaryDataset and use it inside `transformer_bpe_models.py`. Ditches `token_sequence_format`.

Signed-off-by: Piotr Żelasko <[email protected]>

* [canary] Refactor: move changes in transformer_bpe_models.py to Canar… (#8252)

* [canary] Refactor: move changes in transformer_bpe_models.py to CanaryModel

Signed-off-by: Piotr Żelasko <[email protected]>

* Rename `CanaryModel` to `EncDecMultiTaskModel` and remove inheritance from `EncDecTransfModelBPE`; add a separate config for this model

Signed-off-by: Piotr Żelasko <[email protected]>

---------

Signed-off-by: Piotr Żelasko <[email protected]>

* Rename `CanaryDataset` to `PromptedAudioToTextLhotseDataset`; add `prompt_format_fn` argument; clean-up the `_canary_prompt_format` function a bit

Signed-off-by: Piotr Żelasko <[email protected]>

* Move tokenization into `prompt_format_fn`, fix usage, add docs

Signed-off-by: Piotr Żelasko <[email protected]>

* Backward-compatible utterance validation

Signed-off-by: Piotr Żelasko <[email protected]>

* Improve type annotations

Signed-off-by: Piotr Żelasko <[email protected]>

* config and prompt_fn registration changes from review

Signed-off-by: Piotr Żelasko <[email protected]>

---------

Signed-off-by: Piotr Żelasko <[email protected]>

* fix transcribe config

Signed-off-by: stevehuang52 <[email protected]>

* Refactor Canary to follow schema of remaining ASR models (#8260)

* Initial draft of multi task beam decoding strategy

Signed-off-by: smajumdar <[email protected]>

* Stabilize inference

Signed-off-by: smajumdar <[email protected]>

* Update AED Multi Task model to mostly conform to Archetype-Type format. Update config

Signed-off-by: smajumdar <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add change decoding strategy

Signed-off-by: smajumdar <[email protected]>

* Remove redundant imports

Signed-off-by: smajumdar <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cleanup

Signed-off-by: smajumdar <[email protected]>

* Cleanup

Signed-off-by: smajumdar <[email protected]>

* remove asr transformer dependency on nlp

Signed-off-by: stevehuang52 <[email protected]>

* clean up

Signed-off-by: stevehuang52 <[email protected]>

* copy token_classifier from nlp to asr

Signed-off-by: stevehuang52 <[email protected]>

* Address comments

Signed-off-by: smajumdar <[email protected]>

* Add typing to beam decoding

Signed-off-by: smajumdar <[email protected]>

* Make prompt format configurable

Signed-off-by: smajumdar <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* drop asr dependency on nlp

Signed-off-by: stevehuang52 <[email protected]>

---------

Signed-off-by: smajumdar <[email protected]>
Signed-off-by: stevehuang52 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: stevehuang52 <[email protected]>

* fix transcribe, update asr evaluator

Signed-off-by: stevehuang52 <[email protected]>

* Extend the docs for the canary prompt_fn

Signed-off-by: Piotr Żelasko <[email protected]>

* Incorporate changes from Nithin's code review

Signed-off-by: Piotr Żelasko <[email protected]>

* training bug fix and adding launch script for speech_multitask (#8270)

* bug fix and adding launch script for speech_multitask

Signed-off-by: Krishna Puvvada <[email protected]>

* update launch script example in speech_to_text_aed.py

Signed-off-by: Krishna Puvvada <[email protected]>

---------

Signed-off-by: Krishna Puvvada <[email protected]>
Co-authored-by: Krishna Puvvada <[email protected]>

* Fix: drop_last must be true in validation/test otherwise the training will hang

Signed-off-by: Piotr Żelasko <[email protected]>

* revert to current transcribe API

Signed-off-by: stevehuang52 <[email protected]>

* revert changes to NLP, update docs

Signed-off-by: stevehuang52 <[email protected]>

* update eval utils

Signed-off-by: stevehuang52 <[email protected]>

* update docs

Signed-off-by: stevehuang52 <[email protected]>

* Remove DALI; rename compute_audio_loss to compute_loss

Signed-off-by: Piotr Żelasko <[email protected]>

* set default use_model_transcribe=False

Signed-off-by: stevehuang52 <[email protected]>

* change os.path.dirname to pathlib

Signed-off-by: stevehuang52 <[email protected]>

* [canary] Test for CanaryTokenizer + refactoring (#8285)

* Test for CanaryTokenizer

Signed-off-by: Piotr Żelasko <[email protected]>

* Attempt at refactor...

Signed-off-by: Piotr Żelasko <[email protected]>

---------

Signed-off-by: Piotr Żelasko <[email protected]>

* Update config for AED models (#8294)

Signed-off-by: smajumdar <[email protected]>

* set default calculate_wer=False in transcribe_speech.py

Signed-off-by: stevehuang52 <[email protected]>

* Attention encoder-decoder models for multiple speech-to-text tasks

Signed-off-by: Piotr Żelasko <[email protected]>

* Apply suggestions from code review, part 1

Co-authored-by: Nithin Rao <[email protected]>
Signed-off-by: Piotr Żelasko <[email protected]>

* Apply suggestions from code review, part 2

Signed-off-by: Piotr Żelasko <[email protected]>

* Document compute_loss

Signed-off-by: Piotr Żelasko <[email protected]>

* update transcribe_speech.py

Signed-off-by: stevehuang52 <[email protected]>

* add docstring

Signed-off-by: stevehuang52 <[email protected]>

* Attention encoder-decoder models for multiple speech-to-text tasks

Signed-off-by: Piotr Żelasko <[email protected]>

---------

Signed-off-by: Piotr Żelasko <[email protected]>
Signed-off-by: stevehuang52 <[email protected]>
Signed-off-by: smajumdar <[email protected]>
Signed-off-by: Krishna Puvvada <[email protected]>
Signed-off-by: Piotr Żelasko <[email protected]>
Co-authored-by: stevehuang52 <[email protected]>
Co-authored-by: Somshubra Majumdar <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Krishna Puvvada <[email protected]>
Co-authored-by: Krishna Puvvada <[email protected]>
Co-authored-by: He Huang (Steve) <[email protected]>
Co-authored-by: Nithin Rao <[email protected]>
Signed-off-by: Pablo Garay <[email protected]>
  • Loading branch information
8 people authored and pablo-garay committed Mar 19, 2024
1 parent 07d25a8 commit ee0cc6f
Show file tree
Hide file tree
Showing 30 changed files with 2,643 additions and 119 deletions.
277 changes: 277 additions & 0 deletions examples/asr/conf/speech_multitask/fast-conformer_aed.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
# It contains the default values for training an autoregressive FastConformer-Transformer AED model with sub-word encoding.

# Architecture and training config:
# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective
# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches.
# Here are the recommended configs for different variants of FastConformer-Transformer, other parameters are the same as in this config file.
# One extra (linear projection) layer is added between FastConformer encoder and Transformer decoder if they have different hidden sizes
# It is recommended to initialize FastConformer with ASR/SSL pre-trained encoder for better accuracy and faster convergence

name: "FastConformer-Transformer-MultiTask"

# Note: for larger models (1B+ params) initializing from a pretrained encoder
# may help (or even be required to) stabilize the training.
init_from_nemo_model: null

model:
_target_: nemo.collections.asr.models.EncDecMultiTaskModel
sample_rate: 16000
label_smoothing: 0.0
context_len_for_AR_decoding: 5 # Length of input prompt tokens. For example, in Canary models, we use [BOS,src_lang,task,tgt_lang,pnc] and thus the length is 5
log_prediction: true # enables logging sample predictions in the output during training

# Important ! Set the prompt format to the class you need
prompt_format: ??? # Options supported: ["canary"]

model_defaults:
asr_enc_hidden: 1024
lm_enc_hidden: 512
lm_dec_hidden: 1024

train_ds:
use_lhotse: true
tarred_audio_filepaths: null
manifest_filepath: ???
sample_rate: ${model.sample_rate}
shuffle: true
num_workers: 8
# To understand the settings below, please refer to Lhotse Dataloading documentation:
# https://github.com/NVIDIA/NeMo/blob/main/docs/source/asr/datasets.rst#lhotse-dataloading
# You can also check the following configuration dataclass:
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/data/lhotse/dataloader.py#L36
batch_size: None
batch_duration: 360
quadratic_duration: 15
use_bucketing: True
num_buckets: 20
bucket_buffer_size: 20000
shuffle_buffer_size: 10000

validation_ds:
use_lhotse: true
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 8 # you may increase batch_size if your memory allows
shuffle: false
num_workers: 4
pin_memory: true
use_start_end_token: true
use_bucketing: false

test_ds:
use_lhotse: true
manifest_filepath: ???
sample_rate: ${model.sample_rate}
batch_size: 8 # you may increase batch_size if your memory allows
shuffle: false
num_workers: 4
pin_memory: true
use_start_end_token: true
use_bucketing: false

# recommend small vocab size of 128 or 256 when using 4x sub-sampling
# you may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py
tokenizer:
dir: null # Null for aggregate tokenizers
type: agg # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) or `agg` for aggregate tokenizers
langs:
spl_tokens: # special tokens model
dir: ???
type: bpe
en: # English tokenizer (example, replace with whichever language you would like or add tokenizers to add tokenizer for additional languages)
dir: ???
type: bpe

custom_tokenizer:
_target_: nemo.collections.common.tokenizers.canary_tokenizer.CanaryTokenizer # Can be replaced with other tokenizer for different prompt formats
tokenizers: null # Filled at runtime by all the tokenizers inside the aggregate tokenizer

# Audio Preprocessor
preprocessor:
_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
sample_rate: ${model.sample_rate}
normalize: "per_feature"
window_size: 0.025
window_stride: 0.01
window: "hann"
features: 80
n_fft: 512
log: true
frame_splicing: 1
dither: 0.00001
pad_to: 0
pad_value: 0.0

# SpecAugment is applied either in the model or in the data layer
spec_augment:
_target_: nemo.collections.asr.modules.SpectrogramAugmentation
freq_masks: 2 # set to zero to disable it
# you may use lower time_masks for smaller models to have a faster convergence
time_masks: 10 # set to zero to disable it
freq_width: 27
time_width: 0.05

# FastConformer Encoder
encoder:
_target_: nemo.collections.asr.modules.ConformerEncoder
feat_in: ${model.preprocessor.features}
feat_out: -1 # you may set it if you need different output size other than the default d_model
n_layers: 24
d_model: ${model.model_defaults.asr_enc_hidden}

# Sub-sampling params
subsampling: dw_striding # vggnet or striding, vggnet may give better results but needs more memory
subsampling_factor: 8 # must be power of 2
subsampling_conv_channels: 256 # -1 sets it to d_model
causal_downsampling: false
reduction: null
reduction_position: null
reduction_factor: 1

# Feed forward module's params
ff_expansion_factor: 4

# Multi-headed Attention Module's params
self_attention_model: rel_pos # rel_pos or abs_pos
n_heads: 8 # may need to be lower for smaller d_models
# [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
att_context_size: [-1, -1] # -1 means unlimited context
xscaling: false # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers
pos_emb_max_len: 5000

# Convolution module's params
conv_kernel_size: 9
conv_norm_type: batch_norm
conv_context_size: null

### regularization
dropout: 0.1 # The dropout used in most of the Conformer Modules
dropout_pre_encoder: 0.1
dropout_emb: 0.0 # The dropout used for embeddings
dropout_att: 0.1 # The dropout for multi-headed attention modules

# Optional Transformer Encoder sandwitched between ASR Encoder and Transformer Ddcoder.
# Only used if num_layers > 0
transf_encoder:
_target_: nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder
num_layers: 0
hidden_size: ${model.model_defaults.lm_enc_hidden}
inner_size: ${multiply:${model.model_defaults.lm_enc_hidden}, 4}
num_attention_heads: 8
ffn_dropout: 0.1
attn_score_dropout: 0.1
attn_layer_dropout: 0.1
mask_future: False
pre_ln: True
pre_ln_final_layer_norm: True

transf_decoder:
_target_: nemo.collections.asr.modules.transformer.get_nemo_transformer
model_name: null
pretrained: false
encoder: null
pre_ln_final_layer_norm: true

config_dict:
max_sequence_length: 512
num_token_types: 0
embedding_dropout: 0.1
learn_positional_encodings: false
hidden_size: ${model.model_defaults.lm_dec_hidden}
inner_size: ${multiply:${model.model_defaults.lm_dec_hidden}, 4}
num_layers: 24
num_attention_heads: 8
ffn_dropout: 0.1
attn_score_dropout: 0.1
attn_layer_dropout: 0.1
hidden_act: relu
pre_ln: true
vocab_size: None # Will be set by the model at runtime

# Label Prediction Head (Token Classifier)
head:
_target_: nemo.collections.asr.parts.submodules.token_classifier.TokenClassifier
num_layers: 1
activation: relu
log_softmax: true
hidden_size: ${model.transf_decoder.config_dict.hidden_size}
num_classes: None # Will be set by the model at runtime
dropout: 0.0
use_transformer_init: true

# Decoding Strategy
decoding:
strategy: beam
return_best_hypothesis: true # Returns the most probably hypothesis after beam search

beam:
beam_size: 1
len_pen: 0.0
max_generation_delta: 50

# Loss Config
loss:
_target_: nemo.collections.common.losses.smoothed_cross_entropy.SmoothedCrossEntropyLoss
label_smoothing: ${model.label_smoothing}
pad_id: null

optim:
name: adamw
lr: 3e-4
# optimizer arguments
betas: [0.9, 0.98]
# less necessity for weight_decay as we already have large augmentations with SpecAug
# you may need weight_decay for large models, stable AMP training, small datasets, or when lower augmentations are used
# weight decay of 0.0 with lr of 2.0 also works fine
weight_decay: 1e-3

# scheduler setup
sched:
name: InverseSquareRootAnnealing
# scheduler config override
warmup_steps: 2500
warmup_ratio: null
min_lr: 1e-6

trainer:
devices: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: -1
max_steps: 100000 # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
strategy: ddp
accumulate_grad_batches: 1
gradient_clip_val: 0.0
precision: 16 # Should be set to 16 for O1 and O2 to enable the AMP.
log_every_n_steps: 100 # Interval of logging.
enable_progress_bar: True
num_sanity_val_steps: 2 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
sync_batchnorm: true
enable_checkpointing: False # Provided by exp_manager
logger: false # Provided by exp_manager

exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: true
create_checkpoint_callback: true
checkpoint_callback_params:
# in case of multiple validation sets, first one is used
monitor: "val_sacreBLEU"
mode: "max"
save_top_k: 3
always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints

resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
# you need to set these two to True to continue the training
resume_if_exists: true
resume_ignore_no_checkpoint: false

# You may use this section to create a W&B logger
create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ model:
min_lr: 1e-6

trainer:
gpus: -1 # number of GPUs, -1 would use all available GPUs
devices: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: 100
max_steps: -1 # computed at runtime if not set
Expand Down
81 changes: 81 additions & 0 deletions examples/asr/speech_multitask/speech_to_text_aed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
# Training the model
```sh
python speech_to_text_aed.py \
# (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \
model.train_ds.tarred_audio_filepaths=<path to tar files with audio> \
model.train_ds.manifest_filepath=<path to audio data manifest> \
model.train_ds.batch_duration=360 \
model.train_ds.num_buckets=30 \
model.train_ds.bucket_duration_bins=<optional list of precomputed float bins for bucket durations, speeds up init> \
model.validation_ds.manifest_filepath=<path to validation manifest> \
model.test_ds.manifest_filepath=<path to test manifest> \
model.model_defaults.asr_enc_hidden=1024 \
model.model_defaults.lm_enc_hidden=512 \
model.model_defaults.lm_dec_hidden=1024 \
model.tokenizer.langs.spl_tokens.dir=<path to the directory of prompt special tokens tokenizer> \
model.tokenizer.langs.spl_tokens.type=bpe \
model.tokenizer.langs.en.dir=<path to the directory of en language tokenizer (add new langs the same way)> \
model.tokenizer.langs.en.type=bpe \
model.prompt_format="canary" \
trainer.devices=-1 \
trainer.accelerator="ddp" \
trainer.max_steps=100000 \
+trainer.limit_train_batches=20000 \
trainer.val_check_interval=5000 \
+trainer.use_distributed_sampler=false \
model.optim.name="adamw" \
model.optim.lr=0.001 \
model.optim.betas=[0.9,0.999] \
model.optim.weight_decay=0.0001 \
model.optim.sched.warmup_steps=2000 \
exp_manager.create_wandb_logger=True \
exp_manager.wandb_logger_kwargs.name="<Name of experiment>" \
exp_manager.wandb_logger_kwargs.project="<Name of project>"
```
"""

import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.asr.models import EncDecMultiTaskModel
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="../conf/speech_multitask/", config_name="fast-conformer_aed")
def main(cfg):
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
aed_model = EncDecMultiTaskModel(cfg=cfg.model, trainer=trainer)

# Initialize the weights of the model from another model, if provided via config
aed_model.maybe_init_from_pretrained_checkpoint(cfg)
trainer.fit(aed_model)

if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
if aed_model.prepare_test(trainer):
trainer.test(aed_model)


if __name__ == '__main__':
main()
Loading

0 comments on commit ee0cc6f

Please sign in to comment.