New extended prompt format for Canary, short utterances inference fix, and training micro-optimizations #11058
@@ -70,7 +70,8 @@ def lens_to_mask(lens, max_length):
     Create a mask from a tensor of lengths.
     """
     batch_size = lens.shape[0]
-    mask = torch.arange(max_length).repeat(batch_size, 1).to(lens.device) < lens[:, None]
+    arange = torch.arange(max_length, device=lens.device)
+    mask = arange.expand(batch_size, max_length) < lens.unsqueeze(1)
     return mask
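Both formulations produce the same boolean mask; the new one allocates the arange directly on the target device and uses a broadcasting view instead of repeat, so no repeated copy is materialized and no host-to-device transfer happens in the hot path. A minimal standalone sketch of the equivalence (not part of the PR):

    import torch

    def lens_to_mask_old(lens, max_length):
        batch_size = lens.shape[0]
        # repeat() materializes a (batch_size, max_length) copy; .to() may copy across devices
        return torch.arange(max_length).repeat(batch_size, 1).to(lens.device) < lens[:, None]

    def lens_to_mask_new(lens, max_length):
        batch_size = lens.shape[0]
        # expand() is a zero-copy view; the arange is created on the right device up front
        arange = torch.arange(max_length, device=lens.device)
        return arange.expand(batch_size, max_length) < lens.unsqueeze(1)

    lens = torch.tensor([2, 5, 3])
    assert torch.equal(lens_to_mask_old(lens, 6), lens_to_mask_new(lens, 6))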
@@ -697,24 +698,33 @@ def training_step(self, batch: PromptedAudioToTextMiniBatch, batch_nb):
             return torch.tensor([0.0])

         input_ids, labels = batch.get_decoder_inputs_outputs()
+        input_ids_lens = batch.prompted_transcript_lens - 1

Comment: fixing off-by-one issue that included an extra padding frame in decoder masks

+        num_frames = batch.audio_lens.sum().float()
+        num_tokens = batch.prompted_transcript_lens.sum().float()
+        tot_frames = torch.as_tensor(batch.audio.numel(), device=num_frames.device, dtype=torch.float)
+        tot_tokens = torch.as_tensor(batch.prompted_transcript.numel(), device=num_frames.device, dtype=torch.float)

Comment: micro-optimizations

         transf_log_probs, encoded_len, enc_states, enc_mask = self.forward(
             input_signal=batch.audio,
             input_signal_length=batch.audio_lens,
             transcript=input_ids,
-            transcript_length=batch.prompted_transcript_lens,
+            transcript_length=input_ids_lens,
         )

-        audio_loss = self.loss(log_probs=transf_log_probs, labels=labels)
+        # Mask components: 1) discard padding & 2) discard prompt (notice the negation)
+        # For a full decoder sequence O with len M, the loss mask skips the first element,
+        # covering the remaining M-1 elements - hence we subtract 1 from prompt lens to account BOS.
+        loss_mask = None
+        if self.cfg.get("use_loss_mask_for_prompt", False):
+            maxlen = batch.prompted_transcript.shape[1] - 1
+            loss_mask = lens_to_mask(input_ids_lens, maxlen) & ~lens_to_mask(batch.prompt_lens - 1, maxlen)
+        audio_loss = self.loss(log_probs=transf_log_probs, labels=labels, output_mask=loss_mask)

-        num_frames = batch.audio_lens.sum()
-        num_tokens = batch.prompted_transcript_lens.sum()
-        tot_frames = batch.audio.numel()
-        tot_tokens = batch.prompted_transcript.numel()
         tensorboard_logs = {
             'train_loss': audio_loss,
-            'learning_rate': self._optimizer.param_groups[0]['lr'],
-            'batch_size': batch.audio.shape[0],
+            'learning_rate': torch.as_tensor(self._optimizer.param_groups[0]['lr']),
+            'batch_size': torch.as_tensor(batch.audio.shape[0]),
             'num_frames': num_frames,
             'num_tokens': num_tokens,
             'input_to_padding_ratio': num_frames / tot_frames,
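For clarity, here is how the two lens_to_mask calls combine into the prompt/padding loss mask on a made-up single-example batch (an illustrative sketch, not code from the PR):

    import torch

    def lens_to_mask(lens, max_length):
        # same helper as in the diff above
        arange = torch.arange(max_length, device=lens.device)
        return arange.expand(lens.shape[0], max_length) < lens.unsqueeze(1)

    # Toy prompted sequence of length M = 7: [BOS, p1, p2, t1, t2, EOS, PAD]
    # (p* are prompt tokens, t* are transcript tokens).
    prompted_transcript_lens = torch.tensor([6])   # everything except PAD
    prompt_lens = torch.tensor([3])                # BOS, p1, p2
    input_ids_lens = prompted_transcript_lens - 1  # decoder inputs/labels have M - 1 positions
    maxlen = 7 - 1

    # Labels are the sequence shifted left by one: [p1, p2, t1, t2, EOS, PAD]
    keep_non_padding = lens_to_mask(input_ids_lens, maxlen)  # [T, T, T, T, T, F]
    drop_prompt = ~lens_to_mask(prompt_lens - 1, maxlen)     # [F, F, T, T, T, T]
    loss_mask = keep_non_padding & drop_prompt                # [F, F, T, T, T, F]
    # -> the loss is computed only on t1, t2 and EOS; prompt tokens and padding are ignored.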
@@ -725,6 +735,7 @@ def training_step(self, batch: PromptedAudioToTextMiniBatch, batch_nb):

     def validation_pass(self, batch: PromptedAudioToTextMiniBatch, batch_idx, dataloader_idx=0, eval_mode="val"):
         input_ids, labels = batch.get_decoder_inputs_outputs()
+        input_ids_lens = batch.prompted_transcript_lens - 1

         transf_log_probs, encoded_len, enc_states, enc_mask = self.forward(
             input_signal=batch.audio,
@@ -733,11 +744,16 @@ def validation_pass(self, batch: PromptedAudioToTextMiniBatch, batch_idx, dataloader_idx=0, eval_mode="val"):
             transcript_length=batch.prompted_transcript_lens,
         )

-        transf_loss = self.loss(log_probs=transf_log_probs, labels=labels)
-        self.val_loss(loss=transf_loss, num_measurements=transf_log_probs.shape[0] * transf_log_probs.shape[1])
-        output_dict = {
-            f'{eval_mode}_loss': transf_loss,
-        }
+        # Mask components: 1) discard padding & 2) discard prompt (notice the negation)
+        # For a full decoder sequence O with len M, the loss mask skips the first element,
+        # covering the remaining M-1 elements - hence we subtract 1 from prompt lens to account BOS.
+        loss_mask = None
+        if self.cfg.get("use_loss_mask_for_prompt", False):
+            maxlen = batch.prompted_transcript.shape[1] - 1
+            loss_mask = lens_to_mask(input_ids_lens, maxlen) & ~lens_to_mask(batch.prompt_lens - 1, maxlen)
+        transf_loss = self.loss(log_probs=transf_log_probs, labels=labels, output_mask=loss_mask)
+        self.val_loss(loss=transf_loss, num_measurements=loss_mask.long().sum())
+        output_dict = {f'{eval_mode}_loss': transf_loss}

         self.wer.update(
             predictions=enc_states,
@@ -983,6 +999,8 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
             'text_field': config.get('text_field', 'answer'),
             'lang_field': config.get('lang_field', 'target_lang'),
             'channel_selector': config.get('channel_selector', None),
+            'pad_min_duration': config.get('pad_min_duration', 1.0),
+            'pad_direction': config.get('pad_direction', 'both'),
         }

         temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
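These two new keys forward the short-utterance fix to the transcribe dataloader: clips shorter than pad_min_duration seconds are padded (on both sides by default) before inference. A rough sketch of what such padding could look like on a raw waveform; this is a hypothetical helper for illustration only, the actual handling lives in the dataset code, which is not part of this diff:

    import torch

    def pad_short_utterance(audio: torch.Tensor, sample_rate: int,
                            pad_min_duration: float = 1.0, pad_direction: str = "both") -> torch.Tensor:
        # Hypothetical helper (not the PR's implementation): zero-pad a 1-D waveform
        # so that it lasts at least `pad_min_duration` seconds.
        min_samples = int(pad_min_duration * sample_rate)
        missing = min_samples - audio.shape[-1]
        if missing <= 0:
            return audio
        if pad_direction == "both":
            left, right = missing // 2, missing - missing // 2
        elif pad_direction == "left":
            left, right = missing, 0
        else:  # "right"
            left, right = 0, missing
        return torch.nn.functional.pad(audio, (left, right))

    # e.g. a 0.3 s utterance at 16 kHz becomes 1.0 s (16000 samples) after padding
    padded = pad_short_utterance(torch.randn(4800), sample_rate=16000)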
@@ -58,27 +58,7 @@ def _build_pos_enc(self, hidden_size, max_sequence_length, device=None):
         self.register_buffer('pos_enc', pos_enc)

     def forward(self, position_ids):
-        max_pos_id = position_ids.max()
-        # update positional encoding if needed
-        if max_pos_id >= self._max_sequence_length:

Comment: This check is super costly as it triggers a DtoH transfer and CUDA sync on every call to transformer decoder forward, and the proposed solution doesn't work anyway (bad results instead of a crash).

-            logging.warning(
-                f'Max position id {max_pos_id} is greater than max sequence length {self._max_sequence_length}. Expanding position embeddings just for this batch. This is not expected to work very well. Consider chunking your input into smaller sequences.'
-            )
-            self._build_pos_enc(
-                hidden_size=self._hidden_size,
-                max_sequence_length=max_pos_id + 1,
-                device=position_ids.device,
-            )

         embeddings = torch.embedding(self.pos_enc, position_ids)

-        # Revert expansion of position embeddings since this wall checkpoint size mismatches.
-        if max_pos_id >= self._max_sequence_length:
-            self._build_pos_enc(
-                hidden_size=self._hidden_size,
-                max_sequence_length=self._max_sequence_length,
-                device=position_ids.device,
-            )
         return embeddings
Comment: micro-optimization, removes some copies and memory movement
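Judging by the context lines left in the hunk, the positional-encoding forward reduces to a plain embedding lookup. The sketch below (illustrative only, not from the PR) shows why the removed bounds check was costly: using a CUDA tensor in a Python if forces a device-to-host copy and a stream synchronization on every call.

    import torch

    if torch.cuda.is_available():
        position_ids = torch.arange(1024, device="cuda").unsqueeze(0)
        max_pos_id = position_ids.max()            # stays on the GPU, still asynchronous
        needs_expansion = bool(max_pos_id >= 512)  # bool() copies to host and syncs the stream

    # Without the check, the forward is just an embedding lookup with no host round-trip:
    #     def forward(self, position_ids):
    #         embeddings = torch.embedding(self.pos_enc, position_ids)
    #         return embeddings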