New extended prompt format for Canary, short utterances inference fix, and training micro-optimizations #11058

Merged: 19 commits, merged Dec 5, 2024. Changes shown from 16 commits.
3 changes: 2 additions & 1 deletion examples/asr/conf/speech_multitask/fast-conformer_aed.yaml
@@ -24,11 +24,12 @@ spl_tokens:
model:
  sample_rate: 16000
  label_smoothing: 0.0
- context_len_for_AR_decoding: 5 # Length of input prompt tokens. For example, in Canary models, we use [BOS,src_lang,task,tgt_lang,pnc] and thus the length is 5
+ use_loss_mask_for_prompt: false
  log_prediction: true # enables logging sample predictions in the output during training

  # Important ! Set the prompt format to the class you need
  prompt_format: ??? # Options supported: ["canary"]
+ prompt_defaults: null

  model_defaults:
    asr_enc_hidden: 1024
24 changes: 23 additions & 1 deletion nemo/collections/asr/data/audio_to_text_lhotse_prompted.py
@@ -16,6 +16,7 @@

import torch.utils.data
from lhotse import CutSet
+ from lhotse.cut import MixedCut
from lhotse.dataset import AudioSamples
from lhotse.dataset.collation import collate_vectors

@@ -99,7 +100,7 @@ def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch:
            prompt_lens=prompt_lens,
            prompted_transcript=prompts_with_answers,
            prompted_transcript_lens=prompts_with_answers_lens,
-           cuts=cuts.drop_in_memory_data(),
+           cuts=_drop_in_memory_data(cuts),
        )

    def _collate_tokens(self, tokens: list[Union[list[int], torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
@@ -111,3 +112,24 @@

class ProbablyIncorrectLanguageKeyError(RuntimeError):
    pass


+ def _drop_in_memory_data(
+     cuts: CutSet,
+     _fields=frozenset(MixedCut.__dataclass_fields__.keys()),
+ ) -> CutSet:
+     """Workaround for an edge case in cuts.drop_in_memory_data() on MixedCut with Lhotse<1.29.0"""
+     ans = []
+     for c in cuts:
+         # Not a mixed cut or a mixed cut that wasn't assigned any extra attributes.
+         if not isinstance(c, MixedCut) or _fields.issuperset(c.__dict__.keys()):
+             ans.append(c.drop_in_memory_data())
+         else:
+             extra_attrs = {k: v for k, v in c.__dict__.items() if k not in _fields}
+             for k in extra_attrs:
+                 delattr(c, k)
+             ans.append(c.drop_in_memory_data())
+             for k, v in extra_attrs.items():
+                 setattr(ans[-1], k, v)
+                 setattr(c, k, v)
+     return CutSet(ans)
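The detach-and-restore pattern above is easier to see with a toy stand-in for MixedCut. A minimal sketch (ToyCut and its rebuild() method are illustrative stand-ins, not Lhotse API):

```python
from dataclasses import dataclass, fields

@dataclass
class ToyCut:
    id: str

    def rebuild(self) -> "ToyCut":
        # Stands in for MixedCut.drop_in_memory_data() on Lhotse < 1.29.0:
        # returns a fresh dataclass instance, losing dynamically set attributes.
        return ToyCut(id=self.id)

cut = ToyCut(id="utt-1")
cut.context = "extra metadata"  # attribute attached outside the dataclass fields
declared = {f.name for f in fields(ToyCut)}
extra = {k: v for k, v in cut.__dict__.items() if k not in declared}
for k in extra:
    delattr(cut, k)  # detach so the rebuild sees a clean instance
new_cut = cut.rebuild()
for k, v in extra.items():
    setattr(new_cut, k, v)  # restore on the new instance...
    setattr(cut, k, v)      # ...and back on the original
assert new_cut.context == "extra metadata"
```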
46 changes: 32 additions & 14 deletions nemo/collections/asr/models/aed_multitask_models.py
@@ -70,7 +70,8 @@ def lens_to_mask(lens, max_length):
    Create a mask from a tensor of lengths.
    """
    batch_size = lens.shape[0]
-   mask = torch.arange(max_length).repeat(batch_size, 1).to(lens.device) < lens[:, None]
+   arange = torch.arange(max_length, device=lens.device)
+   mask = arange.expand(batch_size, max_length) < lens.unsqueeze(1)
Comment (Collaborator Author): micro-optimization, removes some copies and memory movement

    return mask
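For illustration, a standalone sketch of the before/after (toy shapes, not part of the diff): the old version allocates the arange on CPU, materializes a full (batch, length) copy via repeat(), and then transfers it to the target device, while the new version allocates directly on the device and expand() returns a copy-free broadcasted view.

```python
import torch

lens = torch.tensor([3, 5])
max_length = 6
batch_size = lens.shape[0]

# Before: CPU allocation + (batch, length) copy via repeat() + device transfer.
mask_old = torch.arange(max_length).repeat(batch_size, 1).to(lens.device) < lens[:, None]

# After: one allocation on the right device; expand() is a view, not a copy.
arange = torch.arange(max_length, device=lens.device)
mask_new = arange.expand(batch_size, max_length) < lens.unsqueeze(1)

assert torch.equal(mask_old, mask_new)
```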


@@ -697,24 +698,33 @@ def training_step(self, batch: PromptedAudioToTextMiniBatch, batch_nb):
            return torch.tensor([0.0])

        input_ids, labels = batch.get_decoder_inputs_outputs()
+       input_ids_lens = batch.prompted_transcript_lens - 1
Comment (Collaborator Author): fixing off-by-one issue that included an extra padding frame in decoder masks
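A toy illustration of the off-by-one (token ids invented; get_decoder_inputs_outputs is assumed to do the usual shift): a prompted transcript of length M yields decoder inputs and labels of length M - 1, so passing the unshifted lengths masked one extra padding position.

```python
import torch

prompted_transcript = torch.tensor([[101, 7, 8, 42, 43, 102]])  # [BOS, prompt..., text..., EOS]
prompted_transcript_lens = torch.tensor([6])

input_ids = prompted_transcript[:, :-1]        # decoder input: drop the final token
labels = prompted_transcript[:, 1:]            # loss target: drop BOS
input_ids_lens = prompted_transcript_lens - 1  # 5, matching input_ids; 6 would cover one pad frame
```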


+       num_frames = batch.audio_lens.sum().float()
+       num_tokens = batch.prompted_transcript_lens.sum().float()
+       tot_frames = torch.as_tensor(batch.audio.numel(), device=num_frames.device, dtype=torch.float)
+       tot_tokens = torch.as_tensor(batch.prompted_transcript.numel(), device=num_frames.device, dtype=torch.float)
Comment (Collaborator Author): micro-optimizations
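Sketched in isolation (toy shapes): .numel() returns a plain Python int, and torch.as_tensor wraps both totals so the ratios logged below are computed as tensors rather than Python floats, which, per the author's comment, sidesteps slow scalar-to-tensor conversion later in the logger.

```python
import torch

audio = torch.randn(4, 16000)  # padded batch of 4 utterances
audio_lens = torch.tensor([16000, 12000, 8000, 16000])

num_frames = audio_lens.sum().float()  # tensor on the batch's device
tot_frames = torch.as_tensor(audio.numel(), device=num_frames.device, dtype=torch.float)
input_to_padding_ratio = num_frames / tot_frames  # tensor math, no .item() round-trip
```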


        transf_log_probs, encoded_len, enc_states, enc_mask = self.forward(
            input_signal=batch.audio,
            input_signal_length=batch.audio_lens,
            transcript=input_ids,
-           transcript_length=batch.prompted_transcript_lens,
+           transcript_length=input_ids_lens,
        )

-       audio_loss = self.loss(log_probs=transf_log_probs, labels=labels)
+       # Mask components: 1) discard padding & 2) discard prompt (notice the negation)
+       # For a full decoder sequence O with len M, the loss mask skips the first element,
+       # covering the remaining M-1 elements - hence we subtract 1 from prompt lens to account for BOS.
+       loss_mask = None
+       if self.cfg.get("use_loss_mask_for_prompt", False):
+           maxlen = batch.prompted_transcript.shape[1] - 1
+           loss_mask = lens_to_mask(input_ids_lens, maxlen) & ~lens_to_mask(batch.prompt_lens - 1, maxlen)
+       audio_loss = self.loss(log_probs=transf_log_probs, labels=labels, output_mask=loss_mask)

-       num_frames = batch.audio_lens.sum()
-       num_tokens = batch.prompted_transcript_lens.sum()
-       tot_frames = batch.audio.numel()
-       tot_tokens = batch.prompted_transcript.numel()
        tensorboard_logs = {
            'train_loss': audio_loss,
-           'learning_rate': self._optimizer.param_groups[0]['lr'],
-           'batch_size': batch.audio.shape[0],
+           'learning_rate': torch.as_tensor(self._optimizer.param_groups[0]['lr']),
+           'batch_size': torch.as_tensor(batch.audio.shape[0]),
Comment (Collaborator Author): micro-optimizations (the PTL logger turned out to have an inefficient way of converting scalars to tensors)

            'num_frames': num_frames,
            'num_tokens': num_tokens,
            'input_to_padding_ratio': num_frames / tot_frames,
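As an aside, here is how the two components of the new loss mask compose on a toy example (values invented; lens_to_mask is the helper defined at the top of this file): the first term drops padding, the negated second term drops the prompt, and only the answer tokens contribute to the loss.

```python
import torch

def lens_to_mask(lens, max_length):
    arange = torch.arange(max_length, device=lens.device)
    return arange.expand(lens.shape[0], max_length) < lens.unsqueeze(1)

prompted_transcript_lens = torch.tensor([7])  # [BOS, p1, p2, t1, t2, t3, EOS]
prompt_lens = torch.tensor([3])               # [BOS, p1, p2]
maxlen = 7 - 1                                # labels have length M - 1 after the shift

not_padding = lens_to_mask(prompted_transcript_lens - 1, maxlen)
not_prompt = ~lens_to_mask(prompt_lens - 1, maxlen)  # -1: BOS is gone after the shift
loss_mask = not_padding & not_prompt
print(loss_mask)  # tensor([[False, False,  True,  True,  True,  True]])
# i.e. the loss covers [t1, t2, t3, EOS] only.
```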
@@ -725,6 +735,7 @@

    def validation_pass(self, batch: PromptedAudioToTextMiniBatch, batch_idx, dataloader_idx=0, eval_mode="val"):
        input_ids, labels = batch.get_decoder_inputs_outputs()
+       input_ids_lens = batch.prompted_transcript_lens - 1

        transf_log_probs, encoded_len, enc_states, enc_mask = self.forward(
            input_signal=batch.audio,
@@ -733,11 +744,16 @@
            transcript_length=batch.prompted_transcript_lens,
        )

-       transf_loss = self.loss(log_probs=transf_log_probs, labels=labels)
-       self.val_loss(loss=transf_loss, num_measurements=transf_log_probs.shape[0] * transf_log_probs.shape[1])
-       output_dict = {
-           f'{eval_mode}_loss': transf_loss,
-       }
+       # Mask components: 1) discard padding & 2) discard prompt (notice the negation)
+       # For a full decoder sequence O with len M, the loss mask skips the first element,
+       # covering the remaining M-1 elements - hence we subtract 1 from prompt lens to account for BOS.
+       loss_mask = None
+       if self.cfg.get("use_loss_mask_for_prompt", False):
+           maxlen = batch.prompted_transcript.shape[1] - 1
+           loss_mask = lens_to_mask(input_ids_lens, maxlen) & ~lens_to_mask(batch.prompt_lens - 1, maxlen)
+       transf_loss = self.loss(log_probs=transf_log_probs, labels=labels, output_mask=loss_mask)
+       # Guard the measurement count so this also works when the prompt mask is disabled.
+       num_measurements = (
+           loss_mask.long().sum() if loss_mask is not None else transf_log_probs.shape[0] * transf_log_probs.shape[1]
+       )
+       self.val_loss(loss=transf_loss, num_measurements=num_measurements)
+       output_dict = {f'{eval_mode}_loss': transf_loss}

        self.wer.update(
            predictions=enc_states,
@@ -983,6 +999,8 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
            'text_field': config.get('text_field', 'answer'),
            'lang_field': config.get('lang_field', 'target_lang'),
            'channel_selector': config.get('channel_selector', None),
+           'pad_min_duration': config.get('pad_min_duration', 1.0),
+           'pad_direction': config.get('pad_direction', 'both'),
        }

        temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
20 changes: 0 additions & 20 deletions nemo/collections/asr/modules/transformer/transformer_modules.py
@@ -58,27 +58,7 @@ def _build_pos_enc(self, hidden_size, max_sequence_length, device=None):
        self.register_buffer('pos_enc', pos_enc)

    def forward(self, position_ids):
-       max_pos_id = position_ids.max()
-       # update positional encoding if needed
-       if max_pos_id >= self._max_sequence_length:
Comment (@pzelasko, Collaborator Author, Dec 4, 2024): This check is super costly as it triggers a DtoH transfer and CUDA sync on every call to transformer decoder forward, and the proposed solution doesn't work anyway (bad results instead of a crash).

-           logging.warning(
-               f'Max position id {max_pos_id} is greater than max sequence length {self._max_sequence_length}. Expanding position embeddings just for this batch. This is not expected to work very well. Consider chunking your input into smaller sequences.'
-           )
-           self._build_pos_enc(
-               hidden_size=self._hidden_size,
-               max_sequence_length=max_pos_id + 1,
-               device=position_ids.device,
-           )

        embeddings = torch.embedding(self.pos_enc, position_ids)

-       # Revert expansion of position embeddings since this will cause checkpoint size mismatches.
-       if max_pos_id >= self._max_sequence_length:
-           self._build_pos_enc(
-               hidden_size=self._hidden_size,
-               max_sequence_length=self._max_sequence_length,
-               device=position_ids.device,
-           )
        return embeddings
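A minimal repro of the cost being deleted (sketch; the 512 threshold is arbitrary): using a CUDA tensor in a Python comparison forces an implicit device-to-host copy and a stream synchronization on every forward call.

```python
import torch

if torch.cuda.is_available():
    position_ids = torch.arange(128, device="cuda").unsqueeze(0)
    max_pos_id = position_ids.max()  # GPU reduction, still asynchronous
    if max_pos_id >= 512:            # bool(tensor) here: implicit DtoH copy + CUDA sync
        print("would have expanded the position embeddings")
```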


6 changes: 6 additions & 0 deletions nemo/collections/common/data/lhotse/dataloader.py
@@ -135,6 +135,9 @@ class LhotseDataLoadingConfig:
    rir_enabled: bool = False
    rir_path: str | None = None  # str, must point to a lhotse RecordingSet manifest
    rir_prob: float = 0.5
+   # f. Padding to a minimum duration. Examples shorter than this will be padded, others are unaffected.
+   pad_min_duration: Optional[float] = None
+   pad_direction: str = "right"  # "right" | "left" | "both" | "random"

    # 5. Other Lhotse options.
    text_field: str = "text"  # key to read the transcript from
@@ -278,6 +281,9 @@ def get_lhotse_dataloader_from_config(
        keep_excessive_supervisions=config.keep_excessive_supervisions,
    )

+   if config.pad_min_duration is not None:
+       cuts = cuts.pad(duration=config.pad_min_duration, direction=config.pad_direction, preserve_id=True)

    # Duration filtering, same as native NeMo dataloaders.
    # We can filter after the augmentations because they are applied only when calling load_audio().
    cuts = cuts.filter(DurationFilter(config.min_duration, config.max_duration))
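For reference, a hedged usage sketch of the new option applied directly with Lhotse (the manifest path is hypothetical): this is the same CutSet.pad call the branch above makes, and preserve_id=True keeps cut ids stable for downstream bookkeeping.

```python
from lhotse import CutSet

cuts = CutSet.from_file("cuts.jsonl.gz")  # hypothetical manifest path
# Pad every cut shorter than 1 s up to 1 s, adding silence on both sides.
cuts = cuts.pad(duration=1.0, direction="both", preserve_id=True)
```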
1 change: 1 addition & 0 deletions nemo/collections/common/prompts/__init__.py
@@ -1,4 +1,5 @@
from nemo.collections.common.prompts.canary import CanaryPromptFormatter
+ from nemo.collections.common.prompts.canary2 import Canary2PromptFormatter
from nemo.collections.common.prompts.fn import get_prompt_format_fn, registered_prompt_format_fn
from nemo.collections.common.prompts.formatter import PromptFormatter
from nemo.collections.common.prompts.gemma import GemmaPromptFormatter
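A sketch of how the new formatter becomes reachable once the package is imported (resolving by name is an assumption about the PromptFormatter registry, and the "canary2" name is assumed rather than shown in this diff):

```python
from nemo.collections.common.prompts import Canary2PromptFormatter, PromptFormatter

# Assumed registry behavior: importing the module registers the subclass,
# and resolve() maps a prompt_format name from the config to the class.
formatter_cls = PromptFormatter.resolve("canary2")
assert formatter_cls is Canary2PromptFormatter
```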