Do inference by casting the model to bfloat16, not by using AMP.
Do feature preprocessing in float32 for accuracy.

Warn if someone tries to input a non-float32 tensor.

Always create the output in the dtype that the rest of the model expects.

Sort manifests by duration.
galv committed May 17, 2024
1 parent f7626e0 commit a97048e
Showing 11 changed files with 112 additions and 179 deletions.
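As a rough illustration of the change described in the commit message (not part of this diff), the two inference modes can be contrasted with a minimal PyTorch sketch; the nn.Linear stand-in for an ASR model and the CUDA device are assumptions, not NeMo code:

import torch

model = torch.nn.Linear(80, 128).cuda().eval()  # stand-in for an ASR model
x = torch.randn(4, 80, device="cuda")

# AMP route: weights stay float32, selected ops run in bfloat16 inside an autocast region.
with torch.autocast("cuda", dtype=torch.bfloat16), torch.no_grad():
    y_amp = model(x)

# compute_dtype route (this commit): cast the whole model once and feed matching inputs.
model = model.to(torch.bfloat16)
with torch.no_grad():
    y_bf16 = model(x.to(torch.bfloat16))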
57 changes: 32 additions & 25 deletions examples/asr/transcribe_speech.py
@@ -18,6 +18,7 @@
import os
from dataclasses import dataclass, is_dataclass
from tempfile import NamedTemporaryFile
import time
from typing import List, Optional, Union

import pytorch_lightning as pl
@@ -83,6 +84,8 @@
langid: Str used for convert_num_to_words during groundtruth cleaning
use_cer: Bool to use Character Error Rate (CER) or Word Error Rate (WER)
calculate_rtfx: Bool to calculate the RTFx throughput when transcribing the input dataset.
# Usage
ASR model can be specified by either "model_path" or "pretrained_name".
Data for transcription can be defined with either "audio_dir" or "dataset_manifest".
@@ -152,6 +155,7 @@ class TranscriptionConfig:
allow_mps: bool = False # allow to select MPS device (Apple Silicon M-series GPU)
amp: bool = False
amp_dtype: str = "float16" # can be set to "float16" or "bfloat16" when using amp
compute_dtype: str = "float32"
matmul_precision: str = "highest" # Literal["highest", "high", "medium"]
audio_type: str = "wav"

@@ -201,6 +205,8 @@ class TranscriptionConfig:
allow_partial_transcribe: bool = False
extract_nbest: bool = False # Extract n-best hypotheses from the model

calculate_rtfx: bool = False


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis]]:
@@ -257,10 +263,15 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis

trainer = pl.Trainer(devices=device, accelerator=accelerator)
asr_model.set_trainer(trainer)

# assert all(param.requires_grad for param in asr_model.parameters())
asr_model = asr_model.eval()
# assert all(param.requires_grad for param in asr_model.parameters())

if cfg.compute_dtype != "float32" and cfg.amp:
raise ValueError("amp=true is mutually exclusive with a compute_dtype other than float32")

amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16

if cfg.compute_dtype != "float32":
asr_model.to(getattr(torch, cfg.compute_dtype))

# we will adjust this flag if the model does not support it
compute_timestamps = cfg.compute_timestamps
@@ -374,7 +385,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
else:

@contextlib.contextmanager
def autocast(dtype=None):
def autocast(dtype=None, enabled=True):
yield

# Compute output filename
@@ -390,9 +401,17 @@

# transcribe audio

amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16
if cfg.calculate_rtfx:
total_duration = 0.0

with autocast(dtype=amp_dtype):
with open(cfg.dataset_manifest, "rt") as fh:
for line in fh:
item = json.loads(line)
if "duration" not in item:
raise ValueError(f"Requested calculate_rtfx=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field.")
total_duration += item["duration"]

with autocast(dtype=amp_dtype, enabled=cfg.amp):
with torch.no_grad():
if partial_audio:
transcriptions = transcribe_partial_audio(
@@ -414,27 +433,12 @@ def autocast(dtype=None):
override_cfg.augmentor = augmentor
override_cfg.text_field = cfg.gt_text_attr_name
override_cfg.lang_field = cfg.gt_lang_attr_name
# assert all(param.requires_grad for param in asr_model.parameters())
for i in range(5):
if i == 1:
# import nvtx
# pr = nvtx.Profile()
# pr.enable() # begin annotating function calls
# ctx = torch.autograd.profiler.emit_nvtx()
# ctx.__enter__()
torch.cuda.cudart().cudaProfilerStart()
import time

if cfg.calculate_rtfx:
start_time = time.time()
# assert all(param.requires_grad for param in asr_model.parameters())
transcriptions = asr_model.transcribe(audio=filepaths, override_config=override_cfg,)
# assert all(param.requires_grad for param in asr_model.parameters())
end_time = time.time()
print(5.1 * 60 * 60 / (end_time - start_time))
if i == 1:
# pr.disable()
# ctx.__exit__(None, None, None)
torch.cuda.cudart().cudaProfilerStop()
transcriptions = asr_model.transcribe(audio=filepaths, override_config=override_cfg,)
if cfg.calculate_rtfx:
transcribe_time = time.time() - start_time

if cfg.dataset_manifest is not None:
logging.info(f"Finished transcribing from manifest file: {cfg.dataset_manifest}")
@@ -486,6 +490,9 @@ def autocast(dtype=None):
logging.info(f"Writing prediction and error rate of each sample to {output_manifest_w_wer}!")
logging.info(f"{total_res}")

if cfg.calculate_rtfx:
logging.info(f"Dataset RTFx {(transcribe_time/total_duration):.2}")

return cfg


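For reference, a hypothetical invocation exercising the new config fields (the model name and paths are placeholders, not taken from this commit):

python examples/asr/transcribe_speech.py \
    pretrained_name=stt_en_fastconformer_ctc_large \
    dataset_manifest=/path/to/test_manifest.json \
    compute_dtype=bfloat16 \
    amp=false \
    calculate_rtfx=true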
21 changes: 9 additions & 12 deletions nemo/collections/asr/models/ctc_models.py
@@ -502,7 +502,6 @@ def forward(
if self.spec_augmentation is not None and self.training:
processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length)

# assert all(param.requires_grad for param in self.parameters())
encoder_output = self.encoder(audio_signal=processed_signal, length=processed_signal_length)
encoded = encoder_output[0]
encoded_len = encoder_output[1]
@@ -670,18 +669,17 @@ def _transcribe_on_begin(self, audio, trcfg: TranscribeConfig):
super()._transcribe_on_begin(audio, trcfg)

# Freeze the encoder and decoder modules
# self.encoder.freeze()
# self.decoder.freeze()
self.encoder.freeze()
self.decoder.freeze()

def _transcribe_on_end(self, trcfg: TranscribeConfig):
super()._transcribe_on_end(trcfg)

# Unfreeze the encoder and decoder modules
# self.encoder.unfreeze()
# self.decoder.unfreeze()
self.encoder.unfreeze()
self.decoder.unfreeze()

def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig):
# assert all(param.requires_grad for param in self.parameters())
logits, logits_len, greedy_predictions = self.forward(input_signal=batch[0], input_signal_length=batch[1])
output = dict(logits=logits, logits_len=logits_len)
del greedy_predictions
@@ -701,19 +699,18 @@ def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> Gen
# See comment in
# ctc_greedy_decoding.py::GreedyCTCInfer::forward() to
# understand this idiom.

# This is way way wayyyyy too slow. A single
# cudaHostAlloc takes an average of 10ms if the
# caching allocator fails to return an
# existing allocation. Consider reverting this change.
logits_cpu = torch.empty(logits.shape, dtype=logits.dtype, device=torch.device("cpu"), pin_memory=True)
logits_cpu.copy_(logits, non_blocking=True)
else:
logits_cpu = logits
logits_len = logits_len.cpu()
# dump log probs per file
for idx in range(logits_cpu.shape[0]):
current_hypotheses[idx].y_sequence = logits_cpu[idx][: logits_len[idx]]
# We clone because we don't want references to the
# cudaMallocHost()-allocated tensor to be floating
# around. Were that to be the case, then the pinned
# memory cache would always miss.
current_hypotheses[idx].y_sequence = logits_cpu[idx, :logits_len[idx]].clone()
if current_hypotheses[idx].alignments is None:
current_hypotheses[idx].alignments = current_hypotheses[idx].y_sequence
del logits_cpu
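The pinned-memory copy plus per-sample clone() pattern above can be sketched outside NeMo; a rough illustration, assuming a CUDA device and fake CTC log-probs (the variable names are illustrative, not NeMo API):

import torch

logits = torch.randn(8, 200, 1025, device="cuda")      # fake batch of log-probs
logits_len = torch.randint(100, 200, (8,), device="cuda")

# Page-locked host memory lets the device-to-host copy run asynchronously.
logits_cpu = torch.empty(logits.shape, dtype=logits.dtype,
                         device=torch.device("cpu"), pin_memory=True)
logits_cpu.copy_(logits, non_blocking=True)
logits_len = logits_len.cpu()
torch.cuda.current_stream().synchronize()               # make sure the async copy finished

# Clone each slice so nothing holds a view into the pinned buffer; lingering views
# would keep the cudaMallocHost() allocation alive and make the pinned-memory cache miss.
y_sequences = [logits_cpu[i, :logits_len[i]].clone() for i in range(logits_cpu.shape[0])]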
23 changes: 20 additions & 3 deletions nemo/collections/asr/modules/audio_preprocessing.py
@@ -39,7 +39,7 @@
)
from nemo.core.utils import numba_utils
from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__
from nemo.utils import logging
from nemo.utils import logging, logging_mode

try:
import torchaudio
@@ -85,11 +85,28 @@ def __init__(self, win_length, hop_length):
None: torch.ones,
}

# Normally, when you call to(dtype) on a torch.nn.Module, all
# floating point parameters and buffers will change to that
# dtype, rather than staying float32. The AudioPreprocessor
# classes, uniquely, don't actually have any parameters or
# buffers from what I see. In addition, we want the input to
# the preprocessor to be float32, but need to create the
# output in the appropriate precision. We have this empty tensor
# here just to detect which dtype this module should
# output at the end of execution.
self.register_buffer("dtype_sentinel_tensor",
torch.tensor((), dtype=torch.float32),
persistent=False)

@typecheck()
@torch.no_grad()
def forward(self, input_signal, length):
processed_signal, processed_length = self.get_features(input_signal, length)

if input_signal.dtype != torch.float32:
logging.warning(
f"AudioPreprocessor received an input signal of dtype {input_signal.dtype}, rather than torch.float32. In sweeps across multiple datasets, we have found that the preprocessor is not robust to low precision mathematics. As such, it runs in float32. Your input will be cast to float32, but this is not necessarily enough to recover full accuracy. For example, simply casting input_signal from torch.float32 to torch.bfloat16, then back to torch.float32 before running AudioPreprocessor causes drops in absolute WER of up to 0.1%. torch.bfloat16 simply does not have enough mantissa bits to represent enough values in the range [-1.0,+1.0] correctly.",
mode=logging_mode.ONCE)
processed_signal, processed_length = self.get_features(input_signal.to(torch.float32), length)
processed_signal = processed_signal.to(self.dtype_sentinel_tensor.dtype)
return processed_signal, processed_length

@abstractmethod
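The sentinel-buffer idiom introduced above can be demonstrated with a toy module; a minimal sketch with illustrative names (not NeMo code) of how a parameter-free module can still learn which dtype .to() requested:

import torch

class ToyPreprocessor(torch.nn.Module):
    """Parameter-free module: a non-persistent buffer records the dtype .to() requested."""

    def __init__(self):
        super().__init__()
        self.register_buffer("dtype_sentinel", torch.tensor((), dtype=torch.float32), persistent=False)

    def forward(self, x):
        # Do the feature math in float32 for accuracy...
        features = torch.log1p(x.to(torch.float32).abs())
        # ...but emit whatever dtype the surrounding model was cast to.
        return features.to(self.dtype_sentinel.dtype)

m = ToyPreprocessor()
m.to(torch.bfloat16)                  # Module.to() converts floating-point buffers, so the sentinel follows
print(m(torch.randn(2, 16)).dtype)    # torch.bfloat16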
9 changes: 2 additions & 7 deletions nemo/collections/asr/modules/conformer_encoder.py
@@ -321,8 +321,6 @@ def __init__(
conv_kernel_size=conv_kernel_size,
)

# import ipdb; ipdb.set_trace()

if xscaling:
self.xscale = math.sqrt(d_model)
else:
@@ -552,7 +550,6 @@ def forward_internal(
cache_len = 0
offset = None

# Need to cast pos_emb to float16... or something.
audio_signal, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len)

# Create the self-attention and padding masks
@@ -678,9 +675,9 @@ def set_max_audio_length(self, max_audio_length):
"""
self.max_audio_length = max_audio_length
device = next(self.parameters()).device
self.pos_enc.extend_pe(max_audio_length, device)
dtype = next(self.parameters()).dtype
self.pos_enc.extend_pe(max_audio_length, device, dtype)

# Why is this rerun every time that forward() is called? Seems like it needs to run only once
def _create_masks(self, att_context_size, padding_length, max_audio_length, offset, device):
if self.self_attention_model != "rel_pos_local_attn":
att_mask = torch.ones(1, max_audio_length, max_audio_length, dtype=torch.bool, device=device)
@@ -792,7 +789,6 @@ def _calc_context_sizes(
return att_context_size_all, att_context_size_all[0], att_context_probs, conv_context_size

def set_default_att_context_size(self, att_context_size):
print("GALVEZ:", self.att_context_size_all)
if att_context_size not in self.att_context_size_all:
logging.warning(
f"att_context_size={att_context_size} is not among the list of the supported look-aheads: {self.att_context_size_all}"
@@ -831,7 +827,6 @@ def setup_streaming_params(
if chunk_size < 1:
raise ValueError("chunk_size needs to be a number larger or equal to one.")
lookahead_steps = chunk_size - 1
# So it looks like it retains its own cache on its own?
streaming_cfg.cache_drop_size = chunk_size - shift_size
elif self.att_context_style == "chunked_limited":
lookahead_steps = att_context_size[1]
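The extend_pe() call sites above now receive the parameter dtype so the cached positional encodings match the (possibly bfloat16) activations; a toy sketch of that idea, not the actual RelPositionalEncoding implementation (assumes an even d_model):

import math
import torch

class ToyPositionalEncoding(torch.nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.pe = None

    def extend_pe(self, length: int, device: torch.device, dtype: torch.dtype):
        if self.pe is not None and self.pe.size(1) >= length and self.pe.dtype == dtype:
            return
        position = torch.arange(length, device=device, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, device=device, dtype=torch.float32)
            * (-math.log(10000.0) / self.d_model)
        )
        pe = torch.zeros(length, self.d_model, device=device)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Build in float32 for accuracy, then cast once so pos_emb matches the model dtype.
        self.pe = pe.unsqueeze(0).to(dtype)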
5 changes: 3 additions & 2 deletions nemo/collections/asr/modules/squeezeformer_encoder.py
@@ -280,15 +280,16 @@ def set_max_audio_length(self, max_audio_length):
"""
self.max_audio_length = max_audio_length
device = next(self.parameters()).device
dtype = next(self.parameters()).dtype
seq_range = torch.arange(0, self.max_audio_length, device=device)
if hasattr(self, 'seq_range'):
self.seq_range = seq_range
else:
self.register_buffer('seq_range', seq_range, persistent=False)
self.pos_enc.extend_pe(max_audio_length, device)
self.pos_enc.extend_pe(max_audio_length, device, dtype)

if self.time_reduce_pos_enc is not None:
self.time_reduce_pos_enc.extend_pe(max_audio_length, device)
self.time_reduce_pos_enc.extend_pe(max_audio_length, device, dtype)

@typecheck()
def forward(self, audio_signal, length=None):
27 changes: 7 additions & 20 deletions nemo/collections/asr/parts/mixins/transcription.py
@@ -235,8 +235,6 @@ def transcribe(
- Dict[str, List[str/Hypothesis]]
"""

# assert all(param.requires_grad for param in self.parameters())

if override_config is None:
transcribe_cfg = TranscribeConfig(
batch_size=batch_size,
@@ -273,10 +271,9 @@
# Hold the results here
results = None # type: GenericTranscriptionType

# assert all(param.requires_grad for param in self.parameters())
try:
generator = self.transcribe_generator(audio, override_config=transcribe_cfg)
# assert all(param.requires_grad for param in self.parameters())

for processed_outputs in generator:
# Store results
if isinstance(processed_outputs, list):
@@ -366,9 +363,7 @@ def transcribe_generator(self, audio, override_config: Optional[TranscribeConfig

try:
# Initialize and assert the transcription environment
# assert all(param.requires_grad for param in self.parameters())
self._transcribe_on_begin(audio, transcribe_cfg)
# assert all(param.requires_grad for param in self.parameters())

# Work in tmp directory - will store manifest file there
with tempfile.TemporaryDirectory() as tmpdir:
@@ -387,9 +382,7 @@

for test_batch in tqdm(dataloader, desc="Transcribing", disable=not verbose):
# Move batch to device
# assert all(param.requires_grad for param in self.parameters())
test_batch = move_to_device(test_batch, transcribe_cfg._internal.device)
# assert all(param.requires_grad for param in self.parameters())
# Run forward pass
model_outputs = self._transcribe_forward(test_batch, transcribe_cfg)
processed_outputs = self._transcribe_output_processing(model_outputs, transcribe_cfg)
@@ -459,9 +452,7 @@ def _transcribe_on_begin(self, audio, trcfg: TranscribeConfig):
self.preprocessor.featurizer.pad_to = 0

# Switch model to evaluation mode
# assert all(param.requires_grad for param in self.parameters())
self.eval()
# assert all(param.requires_grad for param in self.parameters())

# Disable logging
trcfg._internal.logging_level = logging.get_verbosity()
@@ -761,19 +752,15 @@ def _transcribe_on_begin(self, audio, trcfg: TranscribeConfig):
"""
super()._transcribe_on_begin(audio, trcfg)

# assert all(param.requires_grad for param in self.parameters())

# Freeze the encoder and decoder modules
# if hasattr(self, 'encoder'):
# self.encoder.freeze()
if hasattr(self, 'encoder'):
self.encoder.freeze()

# if hasattr(self, 'decoder'):
# self.decoder.freeze()
if hasattr(self, 'decoder'):
self.decoder.freeze()

# if hasattr(self, 'joint'):
# self.joint.freeze()
# assert all(param.requires_grad for param in self.parameters())
# import ipdb; ipdb.set_trace()
if hasattr(self, 'joint'):
self.joint.freeze()

def _transcribe_on_end(self, trcfg: TranscribeConfig):
"""
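The re-enabled freeze()/unfreeze() calls above bracket transcription; roughly, and only as an approximation of what NeMo's freeze() helper does (not the actual implementation), freezing disables gradients and switches the module to eval-mode behavior:

import torch

def freeze(module: torch.nn.Module) -> None:
    # Approximation: no gradients and eval-mode behavior (e.g. dropout disabled) during transcription.
    for p in module.parameters():
        p.requires_grad = False
    module.eval()

def unfreeze(module: torch.nn.Module) -> None:
    for p in module.parameters():
        p.requires_grad = True
    module.train()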