diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index ae82ba2de16cb..390e1ae2e226b 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -18,6 +18,7 @@ import os from dataclasses import dataclass, is_dataclass from tempfile import NamedTemporaryFile +import time from typing import List, Optional, Union import pytorch_lightning as pl @@ -83,6 +84,8 @@ langid: Str used for convert_num_to_words during groundtruth cleaning use_cer: Bool to use Character Error Rate (CER) or Word Error Rate (WER) + calculate_rtfx: Bool to calculate the RTFx throughput to transcribe the input dataset. + # Usage ASR model can be specified by either "model_path" or "pretrained_name". Data for transcription can be defined with either "audio_dir" or "dataset_manifest". @@ -152,6 +155,7 @@ class TranscriptionConfig: allow_mps: bool = False # allow to select MPS device (Apple Silicon M-series GPU) amp: bool = False amp_dtype: str = "float16" # can be set to "float16" or "bfloat16" when using amp + compute_dtype: str = "float32" matmul_precision: str = "highest" # Literal["highest", "high", "medium"] audio_type: str = "wav" @@ -201,6 +205,8 @@ class TranscriptionConfig: allow_partial_transcribe: bool = False extract_nbest: bool = False # Extract n-best hypotheses from the model + calculate_rtfx: bool = False + @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig) def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis]]: @@ -257,10 +263,15 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis trainer = pl.Trainer(devices=device, accelerator=accelerator) asr_model.set_trainer(trainer) - - # assert all(param.requires_grad for param in asr_model.parameters()) asr_model = asr_model.eval() - # assert all(param.requires_grad for param in asr_model.parameters()) + + if cfg.compute_dtype != "float32" and cfg.amp: + raise ValueError("amp=true is mutually exclusive with a compute_dtype other than float32") + + amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16 + + if cfg.compute_dtype != "float32": + asr_model.to(getattr(torch, cfg.compute_dtype)) # we will adjust this flag if the model does not support it compute_timestamps = cfg.compute_timestamps @@ -374,7 +385,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis else: @contextlib.contextmanager - def autocast(dtype=None): + def autocast(dtype=None, enabled=True): yield # Compute output filename @@ -390,9 +401,17 @@ def autocast(dtype=None): # transcribe audio - amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16 + if cfg.calculate_rtfx: + total_duration = 0.0 - with autocast(dtype=amp_dtype): + with open(cfg.dataset_manifest, "rt") as fh: + for line in fh: + item = json.loads(line) + if "duration" not in item: + raise ValueError(f"Requested calculate_rtfx=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field.") + total_duration += item["duration"] + + with autocast(dtype=amp_dtype, enabled=cfg.amp): with torch.no_grad(): if partial_audio: transcriptions = transcribe_partial_audio( @@ -414,27 +433,12 @@ def autocast(dtype=None): override_cfg.augmentor = augmentor override_cfg.text_field = cfg.gt_text_attr_name override_cfg.lang_field = cfg.gt_lang_attr_name - # assert all(param.requires_grad for param in asr_model.parameters()) - for i in range(5): - if i == 1: - # import nvtx - # pr = nvtx.Profile() - # pr.enable() # 
begin annotating function calls - # ctx = torch.autograd.profiler.emit_nvtx() - # ctx.__enter__() - torch.cuda.cudart().cudaProfilerStart() - import time + if cfg.calculate_rtfx: start_time = time.time() - # assert all(param.requires_grad for param in asr_model.parameters()) - transcriptions = asr_model.transcribe(audio=filepaths, override_config=override_cfg,) - # assert all(param.requires_grad for param in asr_model.parameters()) - end_time = time.time() - print(5.1 * 60 * 60 / (end_time - start_time)) - if i == 1: - # pr.disable() - # ctx.__exit__(None, None, None) - torch.cuda.cudart().cudaProfilerStop() + transcriptions = asr_model.transcribe(audio=filepaths, override_config=override_cfg,) + if cfg.calculate_rtfx: + transcribe_time = time.time() - start_time if cfg.dataset_manifest is not None: logging.info(f"Finished transcribing from manifest file: {cfg.dataset_manifest}") @@ -486,6 +490,9 @@ def autocast(dtype=None): logging.info(f"Writing prediction and error rate of each sample to {output_manifest_w_wer}!") logging.info(f"{total_res}") + if cfg.calculate_rtfx: + logging.info(f"Dataset RTFx {(total_duration/transcribe_time):.2f}") + return cfg diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index 0c510c72cc06c..24ffe9b435031 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -502,7 +502,6 @@ def forward( if self.spec_augmentation is not None and self.training: processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) - # assert all(param.requires_grad for param in self.parameters()) encoder_output = self.encoder(audio_signal=processed_signal, length=processed_signal_length) encoded = encoder_output[0] encoded_len = encoder_output[1] @@ -670,18 +669,17 @@ def _transcribe_on_begin(self, audio, trcfg: TranscribeConfig): super()._transcribe_on_begin(audio, trcfg) # Freeze the encoder and decoder modules - # self.encoder.freeze() - # self.decoder.freeze() + self.encoder.freeze() + self.decoder.freeze() def _transcribe_on_end(self, trcfg: TranscribeConfig): super()._transcribe_on_end(trcfg) # Unfreeze the encoder and decoder modules - # self.encoder.unfreeze() - # self.decoder.unfreeze() + self.encoder.unfreeze() + self.decoder.unfreeze() def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig): - # assert all(param.requires_grad for param in self.parameters()) logits, logits_len, greedy_predictions = self.forward(input_signal=batch[0], input_signal_length=batch[1]) output = dict(logits=logits, logits_len=logits_len) del greedy_predictions @@ -701,11 +699,6 @@ def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> Gen # See comment in # ctc_greedy_decoding.py::GreedyCTCInfer::forward() to # understand this idiom. - - # This is way way wayyyyy too slow. A single - # cudaHostAlloc takes an average of 10ms if the - # caching allocator fails to return an - # existing allocation. Consider reverting this change.
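# --- Editor's note: illustrative sketch, not part of the diff. ---
# It spells out the pinned-memory idiom used in _transcribe_output_processing
# in the hunk that continues below: copy the GPU logits asynchronously into a
# page-locked (pinned) CPU buffer, then .clone() each slice that is kept.
# Cloning matters because cudaHostAlloc is expensive; if the kept slices still
# referenced the pinned buffer, PyTorch's pinned-memory cache could never
# reuse the allocation and every batch would pay for a new one. The function
# name and the explicit synchronize are assumptions of this sketch.
import torch

def copy_logits_to_host(logits: torch.Tensor, logits_len: torch.Tensor) -> list:
    if logits.is_cuda:
        logits_cpu = torch.empty(logits.shape, dtype=logits.dtype, device="cpu", pin_memory=True)
        logits_cpu.copy_(logits, non_blocking=True)   # async device-to-host copy
        logits_len = logits_len.cpu()
        torch.cuda.current_stream().synchronize()     # make sure the copy has landed
    else:
        logits_cpu = logits
        logits_len = logits_len.cpu()
    # Keep only independent copies so the pinned buffer can be freed and recycled.
    return [logits_cpu[idx, : logits_len[idx]].clone() for idx in range(logits_cpu.shape[0])]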
logits_cpu = torch.empty(logits.shape, dtype=logits.dtype, device=torch.device("cpu"), pin_memory=True) logits_cpu.copy_(logits, non_blocking=True) else: @@ -713,7 +706,11 @@ def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> Gen logits_len = logits_len.cpu() # dump log probs per file for idx in range(logits_cpu.shape[0]): - current_hypotheses[idx].y_sequence = logits_cpu[idx][: logits_len[idx]] + # We clone because we don't want references to the + # cudaMallocHost()-allocated tensor to be floating + # around. Were that to be the case, then the pinned + # memory cache would always miss. + current_hypotheses[idx].y_sequence = logits_cpu[idx, :logits_len[idx]].clone() if current_hypotheses[idx].alignments is None: current_hypotheses[idx].alignments = current_hypotheses[idx].y_sequence del logits_cpu diff --git a/nemo/collections/asr/modules/audio_preprocessing.py b/nemo/collections/asr/modules/audio_preprocessing.py index d45c0acf314fb..7ddf24680d35c 100644 --- a/nemo/collections/asr/modules/audio_preprocessing.py +++ b/nemo/collections/asr/modules/audio_preprocessing.py @@ -39,7 +39,7 @@ ) from nemo.core.utils import numba_utils from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__ -from nemo.utils import logging +from nemo.utils import logging, logging_mode try: import torchaudio @@ -85,11 +85,28 @@ def __init__(self, win_length, hop_length): None: torch.ones, } + # Normally, when you call to(dtype) on a torch.nn.Module, all + # floating point parameters and buffers will change to that + # dtype, rather than remaining float32. The AudioPreprocessor + # classes, however, don't have any parameters or + # buffers of their own. In addition, we want the input to + # the preprocessor to be float32, but need to create the + # output in the appropriate precision. We keep this empty tensor + # here just to detect which dtype this module should + # output at the end of execution. + self.register_buffer("dtype_sentinel_tensor", + torch.tensor((), dtype=torch.float32), + persistent=False) + @typecheck() @torch.no_grad() def forward(self, input_signal, length): - processed_signal, processed_length = self.get_features(input_signal, length) - + if input_signal.dtype != torch.float32: + logging.warning( + f"AudioPreprocessor received an input signal of dtype {input_signal.dtype}, rather than torch.float32. In sweeps across multiple datasets, we have found that the preprocessor is not robust to low precision mathematics. As such, it runs in float32. Your input will be cast to float32, but this is not necessarily enough to recover full accuracy. For example, simply casting input_signal from torch.float32 to torch.bfloat16, then back to torch.float32 before running AudioPreprocessor causes drops in absolute WER of up to 0.1%.
torch.bfloat16 simply does not have enough mantissa bits to represent enough values in the range [-1.0,+1.0] correctly.", + mode=logging_mode.ONCE) + processed_signal, processed_length = self.get_features(input_signal.to(torch.float32), length) + processed_signal = processed_signal.to(self.dtype_sentinel_tensor.dtype) return processed_signal, processed_length @abstractmethod diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index cfcdbf2bd8b0d..f63f7beba3f3f 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -321,8 +321,6 @@ def __init__( conv_kernel_size=conv_kernel_size, ) - # import ipdb; ipdb.set_trace() - if xscaling: self.xscale = math.sqrt(d_model) else: @@ -552,7 +550,6 @@ def forward_internal( cache_len = 0 offset = None - # Need to cast pos_emb to float16... or something. audio_signal, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len) # Create the self-attention and padding masks @@ -678,9 +675,9 @@ def set_max_audio_length(self, max_audio_length): """ self.max_audio_length = max_audio_length device = next(self.parameters()).device - self.pos_enc.extend_pe(max_audio_length, device) + dtype = next(self.parameters()).dtype + self.pos_enc.extend_pe(max_audio_length, device, dtype) - # Why is this rerun every time that forward() is called? Seems like it needs to run only once def _create_masks(self, att_context_size, padding_length, max_audio_length, offset, device): if self.self_attention_model != "rel_pos_local_attn": att_mask = torch.ones(1, max_audio_length, max_audio_length, dtype=torch.bool, device=device) @@ -792,7 +789,6 @@ def _calc_context_sizes( return att_context_size_all, att_context_size_all[0], att_context_probs, conv_context_size def set_default_att_context_size(self, att_context_size): - print("GALVEZ:", self.att_context_size_all) if att_context_size not in self.att_context_size_all: logging.warning( f"att_context_size={att_context_size} is not among the list of the supported look-aheads: {self.att_context_size_all}" @@ -831,7 +827,6 @@ def setup_streaming_params( if chunk_size < 1: raise ValueError("chunk_size needs to be a number larger or equal to one.") lookahead_steps = chunk_size - 1 - # So it looks like it retains its own cache on its own? 
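# --- Editor's note: illustrative sketch, not part of the diff. ---
# It demonstrates the "dtype sentinel" pattern the audio_preprocessing.py hunk
# above introduces: a module with no parameters of its own registers an empty,
# non-persistent buffer so that Module.to(dtype) still records which floating
# point dtype the module should emit, while the math itself stays in float32.
# The class name and the log1p stand-in computation are hypothetical.
import torch

class Float32ComputeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Empty tensor whose dtype tracks .to(dtype)/.half()/.bfloat16() calls.
        self.register_buffer("dtype_sentinel", torch.tensor((), dtype=torch.float32), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = torch.log1p(x.to(torch.float32).abs())   # always computed in float32
        return y.to(self.dtype_sentinel.dtype)       # emitted in the module's dtype

# After .to(torch.bfloat16) the sentinel (and hence the output) is bfloat16,
# even though there are no parameters for .to() to convert.
module = Float32ComputeModule().to(torch.bfloat16)
assert module(torch.randn(4)).dtype == torch.bfloat16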
streaming_cfg.cache_drop_size = chunk_size - shift_size elif self.att_context_style == "chunked_limited": lookahead_steps = att_context_size[1] diff --git a/nemo/collections/asr/modules/squeezeformer_encoder.py b/nemo/collections/asr/modules/squeezeformer_encoder.py index ce0d49843d4f8..b00dd53c6b4a4 100644 --- a/nemo/collections/asr/modules/squeezeformer_encoder.py +++ b/nemo/collections/asr/modules/squeezeformer_encoder.py @@ -280,15 +280,16 @@ def set_max_audio_length(self, max_audio_length): """ self.max_audio_length = max_audio_length device = next(self.parameters()).device + dtype = next(self.parameters()).dtype seq_range = torch.arange(0, self.max_audio_length, device=device) if hasattr(self, 'seq_range'): self.seq_range = seq_range else: self.register_buffer('seq_range', seq_range, persistent=False) - self.pos_enc.extend_pe(max_audio_length, device) + self.pos_enc.extend_pe(max_audio_length, device, dtype) if self.time_reduce_pos_enc is not None: - self.time_reduce_pos_enc.extend_pe(max_audio_length, device) + self.time_reduce_pos_enc.extend_pe(max_audio_length, device, dtype) @typecheck() def forward(self, audio_signal, length=None): diff --git a/nemo/collections/asr/parts/mixins/transcription.py b/nemo/collections/asr/parts/mixins/transcription.py index b8ca4cb44e072..af5f7ae444f48 100644 --- a/nemo/collections/asr/parts/mixins/transcription.py +++ b/nemo/collections/asr/parts/mixins/transcription.py @@ -235,8 +235,6 @@ def transcribe( - Dict[str, List[str/Hypothesis]] """ - # assert all(param.requires_grad for param in self.parameters()) - if override_config is None: transcribe_cfg = TranscribeConfig( batch_size=batch_size, @@ -273,10 +271,9 @@ def transcribe( # Hold the results here results = None # type: GenericTranscriptionType - # assert all(param.requires_grad for param in self.parameters()) try: generator = self.transcribe_generator(audio, override_config=transcribe_cfg) - # assert all(param.requires_grad for param in self.parameters()) + for processed_outputs in generator: # Store results if isinstance(processed_outputs, list): @@ -366,9 +363,7 @@ def transcribe_generator(self, audio, override_config: Optional[TranscribeConfig try: # Initialize and assert the transcription environment - # assert all(param.requires_grad for param in self.parameters()) self._transcribe_on_begin(audio, transcribe_cfg) - # assert all(param.requires_grad for param in self.parameters()) # Work in tmp directory - will store manifest file there with tempfile.TemporaryDirectory() as tmpdir: @@ -387,9 +382,7 @@ def transcribe_generator(self, audio, override_config: Optional[TranscribeConfig for test_batch in tqdm(dataloader, desc="Transcribing", disable=not verbose): # Move batch to device - # assert all(param.requires_grad for param in self.parameters()) test_batch = move_to_device(test_batch, transcribe_cfg._internal.device) - # assert all(param.requires_grad for param in self.parameters()) # Run forward pass model_outputs = self._transcribe_forward(test_batch, transcribe_cfg) processed_outputs = self._transcribe_output_processing(model_outputs, transcribe_cfg) @@ -459,9 +452,7 @@ def _transcribe_on_begin(self, audio, trcfg: TranscribeConfig): self.preprocessor.featurizer.pad_to = 0 # Switch model to evaluation mode - # assert all(param.requires_grad for param in self.parameters()) self.eval() - # assert all(param.requires_grad for param in self.parameters()) # Disable logging trcfg._internal.logging_level = logging.get_verbosity() @@ -761,19 +752,15 @@ def _transcribe_on_begin(self, audio, 
trcfg: TranscribeConfig): """ super()._transcribe_on_begin(audio, trcfg) - # assert all(param.requires_grad for param in self.parameters()) - # Freeze the encoder and decoder modules - # if hasattr(self, 'encoder'): - # self.encoder.freeze() + if hasattr(self, 'encoder'): + self.encoder.freeze() - # if hasattr(self, 'decoder'): - # self.decoder.freeze() + if hasattr(self, 'decoder'): + self.decoder.freeze() - # if hasattr(self, 'joint'): - # self.joint.freeze() - # assert all(param.requires_grad for param in self.parameters()) - # import ipdb; ipdb.set_trace() + if hasattr(self, 'joint'): + self.joint.freeze() def _transcribe_on_end(self, trcfg: TranscribeConfig): """ diff --git a/nemo/collections/asr/parts/submodules/conformer_modules.py b/nemo/collections/asr/parts/submodules/conformer_modules.py index 838fc54f183dd..22d98c349933a 100644 --- a/nemo/collections/asr/parts/submodules/conformer_modules.py +++ b/nemo/collections/asr/parts/submodules/conformer_modules.py @@ -83,11 +83,11 @@ def __init__( self.fc_factor = 0.5 # first feed forward module - self.norm_feed_forward1 = LayerNorm(d_model).half() + self.norm_feed_forward1 = LayerNorm(d_model) self.feed_forward1 = ConformerFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout) # convolution module - self.norm_conv = LayerNorm(d_model).half() + self.norm_conv = LayerNorm(d_model) self.conv = ConformerConvolution( d_model=d_model, kernel_size=conv_kernel_size, @@ -96,7 +96,7 @@ def __init__( ) # multi-headed self-attention module - self.norm_self_att = LayerNorm(d_model).half() + self.norm_self_att = LayerNorm(d_model) MHA_max_cache_len = att_context_size[0] if self_attention_model == 'rel_pos': @@ -132,11 +132,11 @@ def __init__( ) # second feed forward module - self.norm_feed_forward2 = LayerNorm(d_model).half() + self.norm_feed_forward2 = LayerNorm(d_model) self.feed_forward2 = ConformerFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout) self.dropout = nn.Dropout(dropout) - self.norm_out = LayerNorm(d_model).half() + self.norm_out = LayerNorm(d_model) def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_channel=None, cache_last_time=None): """ @@ -153,16 +153,11 @@ def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_chan cache_last_time (torch.tensor) : next cache for convolutional layers (B, d_model, T_cache) """ residual = x - with torch.cuda.amp.autocast(enabled=False): - x = self.norm_feed_forward1(x) - assert x.dtype == torch.float16 + x = self.norm_feed_forward1(x) x = self.feed_forward1(x) residual = residual + self.dropout(x) * self.fc_factor - assert residual.dtype == torch.float16 - - with torch.cuda.amp.autocast(enabled=False): - x = self.norm_self_att(residual) + x = self.norm_self_att(residual) if self.self_attention_model == 'rel_pos': x = self.self_attn(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb, cache=cache_last_channel) elif self.self_attention_model == 'rel_pos_local_attn': @@ -188,20 +183,17 @@ def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_chan pack_ip = self.forward_enabled_adapters(pack_ip) residual = pack_ip['x'] - with torch.cuda.amp.autocast(enabled=False): - x = self.norm_conv(residual) + x = self.norm_conv(residual) x = self.conv(x, pad_mask=pad_mask, cache=cache_last_time) if cache_last_time is not None: (x, cache_last_time) = x residual = residual + self.dropout(x) - with torch.cuda.amp.autocast(enabled=False): - x = self.norm_feed_forward2(residual) + x = self.norm_feed_forward2(residual) x = 
self.feed_forward2(x) residual = residual + self.dropout(x) * self.fc_factor - with torch.cuda.amp.autocast(enabled=False): - x = self.norm_out(residual) + x = self.norm_out(residual) if self.is_adapter_available(): # Call the adapters @@ -356,8 +348,6 @@ def forward(self, x, pad_mask=None, cache=None): x = self.pointwise_activation(x) if pad_mask is not None: - # TODO: Get rid of this? - # assert False x = x.masked_fill(pad_mask.unsqueeze(1), 0.0) x = self.depthwise_conv(x, cache=cache) @@ -407,28 +397,10 @@ def __init__(self, d_model, d_ff, dropout, activation=Swish()): self.linear2 = nn.Linear(d_ff, d_model) def forward(self, x): - assert torch.is_inference_mode_enabled() - assert self.linear1.weight.dtype == torch.float32 - assert self.linear1.bias.dtype == torch.float32 - assert self.linear1.weight.is_leaf - assert self.linear1.bias.is_leaf - assert self.linear1.weight.requires_grad - assert self.linear1.bias.requires_grad - # TODO: is_view - assert torch.is_autocast_cache_enabled() - - assert x.dtype == torch.float16 - # It seems as though it is constantly casting the weight to - # fp16 here for some reason... x = self.linear1(x) - # import ipdb; ipdb.set_trace() - assert x.dtype == torch.float16 x = self.activation(x) - assert x.dtype == torch.float16 x = self.dropout(x) - assert x.dtype == torch.float16 x = self.linear2(x) - assert x.dtype == torch.float16 return x def reset_parameters_ff(self): diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py index 52b81cf9461c3..8a0406994c62c 100644 --- a/nemo/collections/asr/parts/submodules/multi_head_attention.py +++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py @@ -107,12 +107,10 @@ def forward_attention(self, value, scores, mask): n_batch = value.size(0) if mask is not None: mask = mask.unsqueeze(1) # (batch, 1, time1, time2) - with torch.cuda.amp.autocast(enabled=False): - scores = scores.masked_fill(mask, -10000.0) - attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) + scores = scores.masked_fill(mask, -10000.0) + attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) else: - with torch.cuda.amp.autocast(enabled=False): - attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) p_attn = self.dropout(attn) x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) @@ -173,11 +171,8 @@ def __init__(self, n_head, n_feat, dropout_rate, pos_bias_u, pos_bias_v, max_cac # these two learnable biases are used in matrix c and matrix d # as described in https://arxiv.org/abs/1901.02860 Section 3.3 if pos_bias_u is None or pos_bias_v is None: - # AMP does not really handle this situation that well. It - # would be better just to call half() on the model rather - # than do these hacks. 
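# --- Editor's note: illustrative sketch, not part of the diff. ---
# The multi_head_attention.py hunk below restores NeMo's
# avoid_float16_autocast_context() around the attention-score math. The helper
# here is a hypothetical standalone equivalent: when the surrounding autocast
# region runs in float16, widen just this block to bfloat16 (or disable
# autocast) so the -10000.0 mask fill and the softmax cannot overflow.
import contextlib
import torch

def widen_if_float16_autocast():
    if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16:
        if torch.cuda.is_bf16_supported():
            return torch.cuda.amp.autocast(dtype=torch.bfloat16)
        return torch.cuda.amp.autocast(enabled=False)
    return contextlib.nullcontext()

def masked_attention_probs(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    with widen_if_float16_autocast():
        scores = scores.masked_fill(mask, -10000.0)
        return torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)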
- self.pos_bias_u = nn.Parameter(torch.HalfTensor(self.h, self.d_k)) - self.pos_bias_v = nn.Parameter(torch.HalfTensor(self.h, self.d_k)) + self.pos_bias_u = nn.Parameter(torch.FloatTensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.FloatTensor(self.h, self.d_k)) # nn.init.normal_(self.pos_bias_u, 0.0, 0.02) # nn.init.normal_(self.pos_bias_v, 0.0, 0.02) nn.init.zeros_(self.pos_bias_u) @@ -215,28 +210,19 @@ def forward(self, query, key, value, mask, pos_emb, cache=None): """ key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache) - # if torch.is_autocast_enabled(): - # query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32) + if torch.is_autocast_enabled(): + query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32) # temporary until we solve this more gracefully - from contextlib import nullcontext - - with nullcontext(): # avoid_float16_autocast_context(): + with avoid_float16_autocast_context(): q, k, v = self.forward_qkv(query, key, value) q = q.transpose(1, 2) # (batch, time1, head, d_k) n_batch_pos = pos_emb.size(0) - # embedding is not affected by autocast. Ignore for now... - - # Could make a custom torch.nn.Module that checks if we're - # in inference mode, and then caches its own weights' - # casted versions that way... - # assert pos_emb.dtype == torch.float16 p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) p = p.transpose(1, 2) # (batch, head, time1, d_k) # (batch, head, time1, d_k) - # We are stuck casting this up to float32... ugh. q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) # (batch, head, time1, d_k) q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) @@ -245,17 +231,10 @@ def forward(self, query, key, value, mask, pos_emb, cache=None): # first compute matrix a and matrix c # as described in https://arxiv.org/abs/1901.02860 Section 3.3 # (batch, head, time1, time2) - assert q_with_bias_u.dtype == torch.float16 - assert k.dtype == torch.float16 - # This has some reshapes associated with it. For some - # reason the transpose needs to be done explicitly... It's - # a bmm call as well... - # TODO: Consider whether we can do better matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) # compute matrix b and matrix d # (batch, head, time1, time2) - # TODO: Here too... 
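# --- Editor's note: illustrative sketch, not part of the diff. ---
# The positional-encoding hunks below thread a dtype argument through
# extend_pe()/create_pe(). The idea, sketched here with hypothetical names, is
# to build the sinusoidal table in float32 for accuracy and cast it once when
# it is stored, so adding it to float16/bfloat16 activations later does not
# force an implicit upcast.
import math
import torch

def build_sinusoidal_pe(length: int, d_model: int, device, dtype) -> torch.Tensor:
    positions = torch.arange(0, length, dtype=torch.float32, device=device).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32, device=device) * -(math.log(10000.0) / d_model)
    )
    pe = torch.zeros(length, d_model, device=device)
    pe[:, 0::2] = torch.sin(positions * div_term)
    pe[:, 1::2] = torch.cos(positions * div_term)
    return pe.unsqueeze(0).to(dtype)   # cast once, at storage time

# e.g. build_sinusoidal_pe(512, 256, torch.device("cpu"), torch.bfloat16)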
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) matrix_bd = self.rel_shift(matrix_bd) # drops extra elements in the matrix_bd to match the matrix_ac's size @@ -912,7 +891,7 @@ def __init__(self, d_model, dropout_rate, max_len=5000, xscale=None, dropout_rat else: self.dropout_emb = None - def create_pe(self, positions): + def create_pe(self, positions, dtype): pos_length = positions.size(0) pe = torch.zeros(pos_length, self.d_model, device=positions.device) div_term = torch.exp( @@ -921,18 +900,18 @@ def create_pe(self, positions): ) pe[:, 0::2] = torch.sin(positions * div_term) pe[:, 1::2] = torch.cos(positions * div_term) - pe = pe.unsqueeze(0) + pe = pe.unsqueeze(0).to(dtype) if hasattr(self, 'pe'): self.pe = pe else: self.register_buffer('pe', pe, persistent=False) - def extend_pe(self, length, device): + def extend_pe(self, length, device, dtype): """Reset and extend the positional encodings if needed.""" if hasattr(self, 'pe') and self.pe.size(1) >= length: return positions = torch.arange(0, length, dtype=torch.float32, device=device).unsqueeze(1) - self.create_pe(positions=positions) + self.create_pe(positions=positions, dtype=dtype) def forward(self, x: torch.Tensor, cache_len=0): """Adds positional encoding. @@ -964,7 +943,7 @@ class RelPositionalEncoding(PositionalEncoding): dropout_rate_emb (float): dropout rate for the positional embeddings """ - def extend_pe(self, length, device): + def extend_pe(self, length, device, dtype): """Reset and extend the positional encodings if needed.""" needed_size = 2 * length - 1 if hasattr(self, 'pe') and self.pe.size(1) >= needed_size: @@ -972,7 +951,7 @@ def extend_pe(self, length, device): # positions would be from negative numbers to positive # positive positions would be used for left positions and negative for right positions positions = torch.arange(length - 1, -length, -1, dtype=torch.float32, device=device).unsqueeze(1) - self.create_pe(positions=positions) + self.create_pe(positions=positions, dtype=dtype) def forward(self, x, cache_len=0): """Compute positional encoding. @@ -1018,7 +997,7 @@ def __init__(self, att_context_size, **kwargs): self.left_context = att_context_size[0] self.right_context = att_context_size[1] - def extend_pe(self, length, device): + def extend_pe(self, length, device, dtype): """Reset and extend the positional encodings only at the beginning""" if hasattr(self, 'pe'): return @@ -1026,7 +1005,7 @@ def extend_pe(self, length, device): positions = torch.arange( self.left_context, -self.right_context - 1, -1, dtype=torch.float32, device=device ).unsqueeze(1) - self.create_pe(positions=positions) + self.create_pe(positions=positions, dtype=dtype) def forward(self, x, cache_len=0): """Compute positional encoding. diff --git a/nemo/collections/asr/parts/utils/transcribe_utils.py b/nemo/collections/asr/parts/utils/transcribe_utils.py index 8465406224e7d..e8165b4e6d971 100644 --- a/nemo/collections/asr/parts/utils/transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/transcribe_utils.py @@ -15,12 +15,14 @@ import json import os import re +from tempfile import NamedTemporaryFile from dataclasses import dataclass from pathlib import Path from typing import List, Optional, Tuple, Union import torch from omegaconf import DictConfig +import soundfile as sf from tqdm.auto import tqdm import nemo.collections.asr as nemo_asr @@ -282,11 +284,18 @@ def prepare_audio_data(cfg: DictConfig) -> Tuple[List[str], bool]: logging.error(f"The input dataset_manifest {cfg.dataset_manifest} is empty. 
Exiting!") return None + audio_key = cfg.get('audio_key', 'audio_filepath') + + with open(cfg.dataset_manifest, "rt") as fh: + for line in fh: + item = json.loads(line) + item["audio_filepath"] = get_full_path(item["audio_filepath"], cfg.dataset_manifest) + if item.get("duration") is None and cfg.presort_manifest: + raise ValueError(f"Requested presort_manifest=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field.") all_entries_have_offset_and_duration = True for item in read_and_maybe_sort_manifest(cfg.dataset_manifest, try_sort=cfg.presort_manifest): if not ("offset" in item and "duration" in item): all_entries_have_offset_and_duration = False - audio_key = cfg.get('audio_key', 'audio_filepath') audio_file = get_full_path(audio_file=item[audio_key], manifest_file=cfg.dataset_manifest) filepaths.append(audio_file) partial_audio = all_entries_have_offset_and_duration diff --git a/nemo/core/classes/module.py b/nemo/core/classes/module.py index 531bcd4fdc65b..2d7bd0179447f 100644 --- a/nemo/core/classes/module.py +++ b/nemo/core/classes/module.py @@ -55,39 +55,8 @@ def freeze(self) -> None: r""" Freeze all params for inference. """ - # Freezing your parameters when you are running in mixed - # precision prevents the automatic mixed precision "cast - # cache" from working. That is to say, your fp32 weights will - # be repeatedly cast to fp16, even though this work could be - # cached. This is a serious problem, as it can make running in - # mixed precision slower than running in fp32. - - # See "arg.requires_grad()" here: https://github.com/pytorch/pytorch/blob/6f5f405b057c7de0f5fce0b1432cb74468f96f95/aten/src/ATen/autocast_mode.cpp#L121C61-L121C80 - - # It won't use the cache unless the parameter requires a - # gradient. This honestly seems like a mistake in pytorch's - # AMP implementation. - - # TODO: This silly attempt does not really fix things. It will fix: - - # with torch.cuda.amp.autocast(): - # my_model.freeze() - # output = my_model(input) - - # But it doesn't fix the following situation: - - # my_model.freeze() - # with torch.cuda.amp.autocast(): - # output = my_model(input) - - # A better way to handle this might be to call - # register_forward_pre_hook on every thing that ineherits from - # NeuralModule, and have the register hook check if - # `torch.is_autocast_enabled() == True and any(not param.requires_grad for param in self.parameters)` - # If so, it should warn (or crash), because that is probably not what the user wants to be doing. - if not torch.is_autocast_enabled(): - for param in self.parameters(): - param.requires_grad = False + for param in self.parameters(): + param.requires_grad = False self.eval() diff --git a/nemo/utils/cast_utils.py b/nemo/utils/cast_utils.py index 0968e505220de..21e977ec494d8 100644 --- a/nemo/utils/cast_utils.py +++ b/nemo/utils/cast_utils.py @@ -37,7 +37,7 @@ def avoid_float16_autocast_context(): if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16: if torch.jit.is_scripting() or torch.jit.is_tracing(): - return torch.cuda.amp.autocast(enabled=False) + return torch.cuda.amp.autocast(dtype=torch.float32) if torch.cuda.is_bf16_supported(): return torch.cuda.amp.autocast(dtype=torch.bfloat16)
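# --- Editor's note: illustrative sketch, not part of the diff. ---
# It contrasts the two inference-precision paths this PR touches. The comment
# removed from module.py above explains that autocast's weight-cast cache only
# caches casts of tensors that require grad, so frozen float32 weights are
# re-cast to half precision on every forward call under AMP. Casting the model
# dtype once (the new compute_dtype option in transcribe_speech.py) avoids that
# repeated work. Module names are hypothetical; assumes a CUDA device.
import torch

model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.GELU(), torch.nn.Linear(512, 512)).cuda()
for param in model.parameters():
    param.requires_grad = False                      # frozen for inference
x = torch.randn(8, 512, device="cuda")

# (a) amp=True path: weights stay float32 and are cast to float16 on each call.
with torch.autocast("cuda", dtype=torch.float16), torch.inference_mode():
    y_amp = model(x)

# (b) compute_dtype path: cast the weights once, then run without autocast.
model = model.to(torch.bfloat16)
with torch.inference_mode():
    y_cast = model(x.to(torch.bfloat16))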