Fix the "cast ping pong" problem when we run AMP inference.
This has been tested only for Parakeet-CTC-1.1B right now. This
problem certainly exists elsewhere.

Automatic mixed precision and inference do not play well together.

First, automatic mixed precision was created back when neural networks
were much simpler. In particular, they did not have softmax and layer
norm as frequent operations. In the era of transformers, softmax and
layer norm are very common. AMP will uncoditionally output fp32
outputs from these operations, even if their inputs are fp16. See
here: https://pytorch.org/docs/stable/amp.html#cuda-ops-that-can-autocast-to-float32
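
A minimal sketch (not part of this commit) of the resulting "cast ping pong", assuming a CUDA device is available: autocast promotes softmax to fp32, and the following matmul immediately casts back down to fp16.

    import torch

    x = torch.randn(2, 8, device="cuda", dtype=torch.float16)
    w = torch.randn(8, 8, device="cuda", dtype=torch.float16)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        probs = torch.softmax(x, dim=-1)  # autocast runs softmax in fp32
        out = probs @ w                   # matmul autocasts back down to fp16
    print(probs.dtype, out.dtype)  # torch.float32 torch.float16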

This is no longer necessary, now that layer norm does accumulation in
fp32 in pytorch, even if the input is fp16:
pytorch/pytorch#66707

Do inference by casting the model to bfloat16, not by using AMP.

Do feature preprocessing in float32 for accuracy. Warn if someone
tries to input a non-float32 tensor.

Always create the output in the type the rest of the model expects.
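
Roughly, the intended usage looks like this (illustrative sketch only; `asr_model` and `filepaths` stand for a loaded NeMo ASR model and a list of audio files, as in transcribe_speech.py):

    import torch

    asr_model = asr_model.eval()
    asr_model = asr_model.to(torch.bfloat16)  # cast the whole model, no autocast

    with torch.inference_mode():
        # The preprocessor still computes features in float32 internally and
        # casts its output to the model's dtype (bfloat16 here) before the encoder.
        transcriptions = asr_model.transcribe(audio=filepaths)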

Sort manifests by duration.
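
The sorting can be sketched as follows (hypothetical helper; assumes the usual NeMo JSON-lines manifest where each entry carries a "duration" field):

    import json

    def sort_manifest_by_duration(path):
        # Read a JSON-lines manifest and return its entries sorted by duration.
        with open(path, "rt") as fh:
            items = [json.loads(line) for line in fh]
        return sorted(items, key=lambda item: item["duration"])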

Signed-off-by: Daniel Galvez <[email protected]>
galv committed Jun 5, 2024
1 parent 55a9738 commit a6ffea9
Showing 10 changed files with 117 additions and 46 deletions.
38 changes: 35 additions & 3 deletions examples/asr/transcribe_speech.py
@@ -16,6 +16,7 @@
import glob
import json
import os
import time
from dataclasses import dataclass, field, is_dataclass
from tempfile import NamedTemporaryFile
from typing import List, Optional, Union
@@ -84,6 +85,8 @@
langid: Str used for convert_num_to_words during groundtruth cleaning
use_cer: Bool to use Character Error Rate (CER) or Word Error Rate (WER)
calculate_rtfx: Bool to calculate the RTFx throughput to transcribe the input dataset.
# Usage
ASR model can be specified by either "model_path" or "pretrained_name".
Data for transcription can be defined with either "audio_dir" or "dataset_manifest".
@@ -153,6 +156,7 @@ class TranscriptionConfig:
allow_mps: bool = False # allow to select MPS device (Apple Silicon M-series GPU)
amp: bool = False
amp_dtype: str = "float16" # can be set to "float16" or "bfloat16" when using amp
compute_dtype: str = "float32"
matmul_precision: str = "highest" # Literal["highest", "high", "medium"]
audio_type: str = "wav"

@@ -208,6 +212,8 @@ class TranscriptionConfig:
allow_partial_transcribe: bool = False
extract_nbest: bool = False # Extract n-best hypotheses from the model

calculate_rtfx: bool = False


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis]]:
@@ -266,6 +272,14 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
asr_model.set_trainer(trainer)
asr_model = asr_model.eval()

if cfg.compute_dtype != "float32" and cfg.amp:
raise ValueError("amp=true is mutually exclusive with a compute_dtype other than float32")

amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16

if cfg.compute_dtype != "float32":
asr_model.to(getattr(torch, cfg.compute_dtype))

# we will adjust this flag if the model does not support it
compute_timestamps = cfg.compute_timestamps
compute_langs = cfg.compute_langs
@@ -378,7 +392,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
else:

@contextlib.contextmanager
def autocast(dtype=None):
def autocast(dtype=None, enabled=True):
yield

# Compute output filename
@@ -394,10 +408,22 @@ def autocast(dtype=None):

# transcribe audio

amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16
if cfg.calculate_rtfx:
total_duration = 0.0

with open(cfg.dataset_manifest, "rt") as fh:
for line in fh:
item = json.loads(line)
if "duration" not in item:
raise ValueError(
f"Requested calculate_rtfx=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field."
)
total_duration += item["duration"]

with autocast(dtype=amp_dtype):
with autocast(dtype=amp_dtype, enabled=cfg.amp):
with torch.no_grad():
if cfg.calculate_rtfx:
start_time = time.time()
if partial_audio:
transcriptions = transcribe_partial_audio(
asr_model=asr_model,
@@ -420,10 +446,13 @@ def autocast(dtype=None):
override_cfg.lang_field = cfg.gt_lang_attr_name
if hasattr(override_cfg, "prompt"):
override_cfg.prompt = parse_multitask_prompt(OmegaConf.to_container(cfg.prompt))

transcriptions = asr_model.transcribe(
audio=filepaths,
override_config=override_cfg,
)
if cfg.calculate_rtfx:
transcribe_time = time.time() - start_time

if cfg.dataset_manifest is not None:
logging.info(f"Finished transcribing from manifest file: {cfg.dataset_manifest}")
@@ -475,6 +504,9 @@ def autocast(dtype=None):
logging.info(f"Writing prediction and error rate of each sample to {output_manifest_w_wer}!")
logging.info(f"{total_res}")

if cfg.calculate_rtfx:
logging.info(f"Dataset RTFx {(total_duration/transcribe_time)}")

return cfg


8 changes: 6 additions & 2 deletions nemo/collections/asr/models/ctc_models.py
@@ -668,7 +668,7 @@ def test_dataloader(self):
def _transcribe_on_begin(self, audio, trcfg: TranscribeConfig):
super()._transcribe_on_begin(audio, trcfg)

# Freeze the encoder and decoure_exder modules
# Freeze the encoder and decoder modules
self.encoder.freeze()
self.decoder.freeze()

Expand Down Expand Up @@ -706,7 +706,11 @@ def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> Gen
logits_len = logits_len.cpu()
# dump log probs per file
for idx in range(logits_cpu.shape[0]):
current_hypotheses[idx].y_sequence = logits_cpu[idx][: logits_len[idx]]
# We clone because we don't want references to the
# cudaMallocHost()-allocated tensor to be floating
# around. Were that to be the case, then the pinned
# memory cache would always miss.
current_hypotheses[idx].y_sequence = logits_cpu[idx, : logits_len[idx]].clone()
if current_hypotheses[idx].alignments is None:
current_hypotheses[idx].alignments = current_hypotheses[idx].y_sequence
del logits_cpu
22 changes: 19 additions & 3 deletions nemo/collections/asr/modules/audio_preprocessing.py
@@ -39,7 +39,7 @@
)
from nemo.core.utils import numba_utils
from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__
from nemo.utils import logging
from nemo.utils import logging, logging_mode

try:
import torchaudio
@@ -85,11 +85,27 @@ def __init__(self, win_length, hop_length):
None: torch.ones,
}

# Normally, when you call to(dtype) on a torch.nn.Module, all
# floating point parameters and buffers will change to that
# dtype, rather than being float32. The AudioPreprocessor
# classes, uniquely, don't actually have any parameters or
# buffers from what I see. In addition, we want the input to
# the preprocessor to be float32, but need to create the
# output in appropriate precision. We have this empty tensor
# here just to detect which dtype tensor this module should
# output at the end of execution.
self.register_buffer("dtype_sentinel_tensor", torch.tensor((), dtype=torch.float32), persistent=False)

@typecheck()
@torch.no_grad()
def forward(self, input_signal, length):
processed_signal, processed_length = self.get_features(input_signal, length)

if input_signal.dtype != torch.float32:
logging.warn(
f"AudioPreprocessor received an input signal of dtype {input_signal.dtype}, rather than torch.float32. In sweeps across multiple datasets, we have found that the preprocessor is not robust to low precision mathematics. As such, it runs in float32. Your input will be cast to float32, but this is not necessarily enough to recovery full accuracy. For example, simply casting input_signal from torch.float32 to torch.bfloat16, then back to torch.float32 before running AudioPreprocessor causes drops in absolute WER of up to 0.1%. torch.bfloat16 simply does not have enough mantissa bits to represent enough values in the range [-1.0,+1.0] correctly.",
mode=logging_mode.ONCE,
)
processed_signal, processed_length = self.get_features(input_signal.to(torch.float32), length)
processed_signal = processed_signal.to(self.dtype_sentinel_tensor.dtype)
return processed_signal, processed_length

@abstractmethod
3 changes: 2 additions & 1 deletion nemo/collections/asr/modules/conformer_encoder.py
@@ -679,7 +679,8 @@ def set_max_audio_length(self, max_audio_length):
"""
self.max_audio_length = max_audio_length
device = next(self.parameters()).device
self.pos_enc.extend_pe(max_audio_length, device)
dtype = next(self.parameters()).dtype
self.pos_enc.extend_pe(max_audio_length, device, dtype)

def _create_masks(self, att_context_size, padding_length, max_audio_length, offset, device):
if self.self_attention_model != "rel_pos_local_attn":
25 changes: 15 additions & 10 deletions nemo/collections/asr/modules/squeezeformer_encoder.py
@@ -99,8 +99,7 @@ def input_example(self, max_batch=1, max_dim=256):

@property
def input_types(self):
"""Returns definitions of module input ports.
"""
"""Returns definitions of module input ports."""
return OrderedDict(
{
"audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
@@ -110,8 +109,7 @@ def input_types(self):

@property
def output_types(self):
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return OrderedDict(
{
"outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
@@ -253,7 +251,11 @@ def __init__(
# Chose same type of positional encoding as the originally determined above
if self_attention_model == "rel_pos":
self.time_reduce_pos_enc = RelPositionalEncoding(
d_model=d_model, dropout_rate=0.0, max_len=pos_emb_max_len, xscale=None, dropout_rate_emb=0.0,
d_model=d_model,
dropout_rate=0.0,
max_len=pos_emb_max_len,
xscale=None,
dropout_rate_emb=0.0,
)
else:
self.time_reduce_pos_enc = PositionalEncoding(
@@ -275,20 +277,21 @@ def __init__(
self.interctc_capture_at_layers = None

def set_max_audio_length(self, max_audio_length):
""" Sets maximum input length.
Pre-calculates internal seq_range mask.
"""Sets maximum input length.
Pre-calculates internal seq_range mask.
"""
self.max_audio_length = max_audio_length
device = next(self.parameters()).device
dtype = next(self.parameters()).dtype
seq_range = torch.arange(0, self.max_audio_length, device=device)
if hasattr(self, 'seq_range'):
self.seq_range = seq_range
else:
self.register_buffer('seq_range', seq_range, persistent=False)
self.pos_enc.extend_pe(max_audio_length, device)
self.pos_enc.extend_pe(max_audio_length, device, dtype)

if self.time_reduce_pos_enc is not None:
self.time_reduce_pos_enc.extend_pe(max_audio_length, device)
self.time_reduce_pos_enc.extend_pe(max_audio_length, device, dtype)

@typecheck()
def forward(self, audio_signal, length=None):
@@ -434,7 +437,9 @@ def _update_adapter_cfg_input_dim(self, cfg: DictConfig):
cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.d_model)
return cfg

def get_accepted_adapter_types(self,) -> Set[type]:
def get_accepted_adapter_types(
self,
) -> Set[type]:
types = super().get_accepted_adapter_types()

if len(types) == 0:
3 changes: 1 addition & 2 deletions nemo/collections/asr/parts/mixins/transcription.py
@@ -181,7 +181,7 @@ class TranscriptionMixin(ABC):
"""

@torch.no_grad()
@torch.inference_mode()
def transcribe(
self,
audio: Union[str, List[str], np.ndarray, DataLoader],
@@ -381,7 +381,6 @@ def transcribe_generator(self, audio, override_config: Optional[TranscribeConfig
for test_batch in tqdm(dataloader, desc="Transcribing", disable=not verbose):
# Move batch to device
test_batch = move_to_device(test_batch, transcribe_cfg._internal.device)

# Run forward pass
model_outputs = self._transcribe_forward(test_batch, transcribe_cfg)
processed_outputs = self._transcribe_output_processing(model_outputs, transcribe_cfg)
2 changes: 1 addition & 1 deletion nemo/collections/asr/parts/submodules/conformer_modules.py
@@ -377,7 +377,7 @@ def forward(self, x, pad_mask=None, cache=None):
x = self.pointwise_activation(x)

if pad_mask is not None:
x = x.float().masked_fill(pad_mask.unsqueeze(1), 0.0)
x = x.masked_fill(pad_mask.unsqueeze(1), 0.0)

x = self.depthwise_conv(x, cache=cache)
if cache is not None:
18 changes: 9 additions & 9 deletions nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -669,7 +669,7 @@ def _compute_out_global_to_all(
global_attn_scores = global_attn_scores.view(batch_size * self.h, max_num_global_attn_indices, seq_len)

# compute global attn probs
global_attn_probs_float = nn.functional.softmax(global_attn_scores, dim=-1, dtype=torch.float32)
global_attn_probs_float = nn.functional.softmax(global_attn_scores, dim=-1)

global_attn_probs = self.dropout(global_attn_probs_float)

@@ -906,7 +906,7 @@ def __init__(self, d_model, dropout_rate, max_len=5000, xscale=None, dropout_rat
else:
self.dropout_emb = None

def create_pe(self, positions):
def create_pe(self, positions, dtype):
pos_length = positions.size(0)
pe = torch.zeros(pos_length, self.d_model, device=positions.device)
div_term = torch.exp(
@@ -915,18 +915,18 @@
)
pe[:, 0::2] = torch.sin(positions * div_term)
pe[:, 1::2] = torch.cos(positions * div_term)
pe = pe.unsqueeze(0)
pe = pe.unsqueeze(0).to(dtype)
if hasattr(self, 'pe'):
self.pe = pe
else:
self.register_buffer('pe', pe, persistent=False)

def extend_pe(self, length, device):
def extend_pe(self, length, device, dtype):
"""Reset and extend the positional encodings if needed."""
if hasattr(self, 'pe') and self.pe.size(1) >= length:
return
positions = torch.arange(0, length, dtype=torch.float32, device=device).unsqueeze(1)
self.create_pe(positions=positions)
self.create_pe(positions=positions, dtype=dtype)

def forward(self, x: torch.Tensor, cache_len=0):
"""Adds positional encoding.
@@ -958,15 +958,15 @@ class RelPositionalEncoding(PositionalEncoding):
dropout_rate_emb (float): dropout rate for the positional embeddings
"""

def extend_pe(self, length, device):
def extend_pe(self, length, device, dtype):
"""Reset and extend the positional encodings if needed."""
needed_size = 2 * length - 1
if hasattr(self, 'pe') and self.pe.size(1) >= needed_size:
return
# positions would be from negative numbers to positive
# positive positions would be used for left positions and negative for right positions
positions = torch.arange(length - 1, -length, -1, dtype=torch.float32, device=device).unsqueeze(1)
self.create_pe(positions=positions)
self.create_pe(positions=positions, dtype=dtype)

def forward(self, x, cache_len=0):
"""Compute positional encoding.
@@ -1012,15 +1012,15 @@ def __init__(self, att_context_size, **kwargs):
self.left_context = att_context_size[0]
self.right_context = att_context_size[1]

def extend_pe(self, length, device):
def extend_pe(self, length, device, dtype):
"""Reset and extend the positional encodings only at the beginning"""
if hasattr(self, 'pe'):
return

positions = torch.arange(
self.left_context, -self.right_context - 1, -1, dtype=torch.float32, device=device
).unsqueeze(1)
self.create_pe(positions=positions)
self.create_pe(positions=positions, dtype=dtype)

def forward(self, x, cache_len=0):
"""Compute positional encoding.