Commit

WIP on improving streaming inference.
Important changes are in nemo/collections/asr/parts/mixins/mixins.py

We should never do a serial loop over the elements of a batch.
galv committed May 6, 2024
1 parent bb566a8 commit bfefe72
Showing 10 changed files with 221 additions and 108 deletions.
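
The guiding principle from the commit message — never run a serial Python loop over the elements of a batch — is sketched below in plain PyTorch. This is a toy illustration with made-up helper names, not the NeMo API: a single batched argmax followed by cheap per-sample length slicing yields the same greedy predictions as a per-sample loop, without launching separate kernels for every batch element.

```
import torch

def greedy_ctc_serial(log_probs: torch.Tensor, lengths: torch.Tensor):
    # Anti-pattern: a Python loop over the batch, one small argmax per sample.
    return [log_probs[i, : lengths[i]].argmax(dim=-1) for i in range(log_probs.shape[0])]

def greedy_ctc_batched(log_probs: torch.Tensor, lengths: torch.Tensor):
    # Preferred: one argmax over the whole (B, T, V) tensor, then slice by length (views only).
    all_preds = log_probs.argmax(dim=-1)  # (B, T)
    return [all_preds[i, : lengths[i]] for i in range(lengths.shape[0])]

# Example: 4 streams, 100 frames, 128-token vocabulary.
lp, lens = torch.randn(4, 100, 128), torch.tensor([100, 80, 60, 40])
assert all(torch.equal(a, b) for a, b in zip(greedy_ctc_serial(lp, lens), greedy_ctc_batched(lp, lens)))
```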
@@ -14,7 +14,7 @@

"""
This script can be used to simulate cache-aware streaming for ASR models. The ASR model to be used with this script needs to be trained in streaming mode. Currently only Conformer models support this streaming mode.
You may find examples of streaming models under 'NeMo/example/asr/conf/conformer/streaming/'.
You may find examples of streaming models under 'NeMo/examples/asr/conf/conformer/streaming/'.
It works on either a manifest of audio files or a single audio file. It can perform streaming for a single stream (audio) or perform the evaluation in multi-stream mode (batch_size>1).
The manifest file must conform to the standard ASR definition, containing `audio_filepath` and `text` as the ground truth.
@@ -23,7 +23,7 @@
## To evaluate a model in cache-aware streaming mode on a single audio file:
python speech_to_text_streaming_infer.py \
python speech_to_text_cache_aware_streaming_infer.py \
--asr_model=asr_model.nemo \
--audio_file=audio_file.wav \
--compare_vs_offline \
@@ -32,7 +32,7 @@
## To evaluate a model in cache-aware streaming mode on a manifest file:
python speech_to_text_streaming_infer.py \
python speech_to_text_cache_aware_streaming_infer.py \
--asr_model=asr_model.nemo \
--manifest_file=manifest_file.json \
--batch_size=16 \
@@ -97,6 +97,7 @@
from nemo.utils import logging


# This is where the transcribed text is extracted from the hypotheses
def extract_transcriptions(hyps):
"""
The transcribed_texts returned by CTC and RNNT models are different.
@@ -123,68 +124,63 @@ def perform_streaming(
asr_model, streaming_buffer, compare_vs_offline=False, debug_mode=False, pad_and_drop_preencoded=False
):
batch_size = len(streaming_buffer.streams_length)
if compare_vs_offline:
# would pass the whole audio at once through the model, like offline mode, in order to compare the results with the streaming mode
# the output of the model in the offline and streaming modes should be exactly the same
with torch.inference_mode():
with autocast():
processed_signal, processed_signal_length = streaming_buffer.get_all_audios()
with torch.no_grad():
(
pred_out_offline,
transcribed_texts,
cache_last_channel_next,
cache_last_time_next,
cache_last_channel_len,
best_hyp,
) = asr_model.conformer_stream_step(
processed_signal=processed_signal,
processed_signal_length=processed_signal_length,
return_transcription=True,
)
final_offline_tran = extract_transcriptions(transcribed_texts)
logging.info(f" Final offline transcriptions: {final_offline_tran}")
else:
final_offline_tran = None
# Seems like this is just assuming fp16 as the amp dtype then...
with torch.inference_mode(), autocast(), torch.no_grad():
if compare_vs_offline:
# would pass the whole audio at once through the model, like offline mode, in order to compare the results with the streaming mode
# the output of the model in the offline and streaming modes should be exactly the same
processed_signal, processed_signal_length = streaming_buffer.get_all_audios()
(
pred_out_offline,
transcribed_texts,
cache_last_channel_next,
cache_last_time_next,
cache_last_channel_len,
best_hyp,
) = asr_model.conformer_stream_step(
processed_signal=processed_signal,
processed_signal_length=processed_signal_length,
return_transcription=True,
)
final_offline_tran = extract_transcriptions(transcribed_texts)
# logging.info(f" Final offline transcriptions: {final_offline_tran}")
else:
final_offline_tran = None

cache_last_channel, cache_last_time, cache_last_channel_len = asr_model.encoder.get_initial_cache_state(
batch_size=batch_size
)
cache_last_channel, cache_last_time, cache_last_channel_len = asr_model.encoder.get_initial_cache_state(
batch_size=batch_size,
dtype=torch.float16 # TODO, make it depend upon amp dtype
)
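# A possible way to address the TODO above (sketch, not part of this commit): query the
# configured autocast dtype instead of hard-coding float16, e.g.
#     cache_dtype = torch.get_autocast_gpu_dtype()  # torch.float16 by default on CUDA
# and pass dtype=cache_dtype to get_initial_cache_state.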

previous_hypotheses = None
streaming_buffer_iter = iter(streaming_buffer)
pred_out_stream = None
with torch.inference_mode(), autocast(), torch.no_grad():
for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter):
(
pred_out_stream,
transcribed_texts,
cache_last_channel,
cache_last_time,
cache_last_channel_len,
previous_hypotheses,
) = asr_model.conformer_stream_step(
processed_signal=chunk_audio,
processed_signal_length=chunk_lengths,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
cache_last_channel_len=cache_last_channel_len,
keep_all_outputs=streaming_buffer.is_buffer_empty(),
previous_hypotheses=previous_hypotheses,
previous_pred_out=pred_out_stream,
drop_extra_pre_encoded=calc_drop_extra_pre_encoded(
asr_model, step_num, pad_and_drop_preencoded
),
return_transcription=True,
)

previous_hypotheses = None
streaming_buffer_iter = iter(streaming_buffer)
pred_out_stream = None
for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter):
with torch.inference_mode():
with autocast():
# keep_all_outputs needs to be True for the last step of streaming when model is trained with att_context_style=regular
# otherwise the last outputs would get dropped

with torch.no_grad():
(
pred_out_stream,
transcribed_texts,
cache_last_channel,
cache_last_time,
cache_last_channel_len,
previous_hypotheses,
) = asr_model.conformer_stream_step(
processed_signal=chunk_audio,
processed_signal_length=chunk_lengths,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
cache_last_channel_len=cache_last_channel_len,
keep_all_outputs=streaming_buffer.is_buffer_empty(),
previous_hypotheses=previous_hypotheses,
previous_pred_out=pred_out_stream,
drop_extra_pre_encoded=calc_drop_extra_pre_encoded(
asr_model, step_num, pad_and_drop_preencoded
),
return_transcription=True,
)

if debug_mode:
logging.info(f"Streaming transcriptions: {extract_transcriptions(transcribed_texts)}")
if debug_mode:
logging.info(f"Streaming transcriptions: {extract_transcriptions(transcribed_texts)}")

final_streaming_tran = extract_transcriptions(transcribed_texts)
logging.info(f"Final streaming transcriptions: {final_streaming_tran}")
@@ -297,11 +293,6 @@ def main():
asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=args.asr_model)

logging.info(asr_model.encoder.streaming_cfg)
if args.set_decoder is not None:
if hasattr(asr_model, "cur_decoder"):
asr_model.change_decoding_strategy(decoder_type=args.set_decoder)
else:
raise ValueError("Decoder cannot get changed for non-Hybrid ASR models.")

if args.att_context_size is not None:
if hasattr(asr_model.encoder, "set_default_att_context_size"):
@@ -325,14 +316,33 @@ def autocast():
yield

# configure the decoding config

if args.set_decoder is not None:
if hasattr(asr_model, "cur_decoder"):
decoder_type = args.set_decoder
else:
raise ValueError("Decoder cannot get changed for non-Hybrid ASR models.")
else:
decoder_type = None


decoding_cfg = asr_model.cfg.decoding
with open_dict(decoding_cfg):
decoding_cfg.strategy = "greedy"
decoding_cfg.preserve_alignments = False
if hasattr(asr_model, 'joint'): # if an RNNT model
decoding_cfg.greedy.max_symbols = 10
if decoder_type == "rnnt": # if an RNNT model
# We need partial hypothesis support here...
decoding_cfg.strategy = "greedy_batch"
decoding_cfg.fused_batch_size = -1
asr_model.change_decoding_strategy(decoding_cfg)
decoding_cfg.greedy.max_symbols_per_step = 10
decoding_cfg.greedy.loop_labels = True
# TODO: Why isn't this working???
decoding_cfg.greedy.use_cuda_graph_decoder = True
# import ipdb; ipdb.set_trace()
elif decoder_type == "ctc":
decoding_cfg.greedy.batched_inference = True

asr_model.change_decoding_strategy(decoding_cfg, decoder_type=decoder_type)

asr_model = asr_model.to(args.device)
asr_model.eval()
@@ -393,15 +403,31 @@ def autocast():
logging.info(f"Loaded {len(samples)} from the manifest at {args.manifest_file}.")

start_time = time.time()
# Need to do autocast here...

input_audio_time = 0.0

first_time = True

for sample_idx, sample in enumerate(samples):
# print("GALVEZ:", sample_idx)
processed_signal, processed_signal_length, stream_id = streaming_buffer.append_audio_file(
sample['audio_filepath'], stream_id=-1
)
if "text" in sample:
all_refs_text.append(sample["text"])
logging.info(f'Added this sample to the buffer: {sample["audio_filepath"]}')

input_audio_time += sample["duration"]

if (sample_idx + 1) % args.batch_size == 0 or sample_idx == len(samples) - 1:
if not first_time:
# import nvtx
# pr = nvtx.Profile()
# pr.enable() # begin annotating function calls
# ctx = torch.autograd.profiler.emit_nvtx()
# ctx.__enter__()
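# cudaProfilerStart()/cudaProfilerStop() delimit the capture window when profiling with
# `nsys profile --capture-range=cudaProfilerApi`; the `first_time` guard skips the first
# (warm-up) batch so its one-time setup cost stays out of the profile.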
torch.cuda.cudart().cudaProfilerStart()
logging.info(f"Starting to stream samples {sample_idx - len(streaming_buffer) + 1} to {sample_idx}...")
streaming_tran, offline_tran = perform_streaming(
asr_model=asr_model,
@@ -414,6 +440,11 @@ def autocast():
if args.compare_vs_offline:
all_offline_tran.extend(offline_tran)
streaming_buffer.reset_buffer()
if not first_time:
# pr.disable()
# ctx.__exit__(None, None, None)
torch.cuda.cudart().cudaProfilerStop()
first_time = False

if args.compare_vs_offline and len(all_refs_text) == len(all_offline_tran):
offline_wer = word_error_rate(hypotheses=all_offline_tran, references=all_refs_text)
@@ -424,6 +455,7 @@

end_time = time.time()
logging.info(f"The whole streaming process took: {round(end_time - start_time, 2)}s")
logging.info(f"RTFx={input_audio_time/(end_time - start_time)}")

# stores the results including the transcriptions of the streaming inference in a json file
if args.output_path is not None and len(all_refs_text) == len(all_streaming_tran):
4 changes: 2 additions & 2 deletions examples/asr/triton-inference-server/SETUP.md
@@ -31,6 +31,6 @@ python client.py --audio_file=xxx.wav
# verify manifest accuracy
pip install nemo_toolkit['asr']
python client.py --manifest=<path to manifest>.json --do_wer_cer=1 # 1 for wer, 2 for cer
/home/dgalvez/scratch/data/test_other_sorted_downward.json
python client.py --manifest=/home/dgalvez/scratch/data/test_other_sorted_downward.json --do_wer_cer=1 # 1 for wer, 2 for cer
```
4 changes: 2 additions & 2 deletions examples/asr/triton-inference-server/client/client.py
@@ -63,7 +63,7 @@ async def recognize(self, wav_file):
result = b" ".join(decoding_results).decode("utf-8")
else:
result = decoding_results.decode("utf-8")
print("Recognized: ", wav_file, result)
# print("Recognized: ", wav_file, result)
return (wav_file, result, latency)


@@ -189,4 +189,4 @@ async def main(args):

args = parser.parse_args()
asyncio.run(main(args))


50 changes: 23 additions & 27 deletions nemo/collections/asr/parts/mixins/mixins.py
@@ -566,6 +566,7 @@ def change_subsampling_conv_chunking_factor(
with open_dict(self.cfg):
self.cfg.encoder.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor

# So I should run this, right? Either that or cache_aware_stream_step
def conformer_stream_step(
self,
processed_signal: torch.Tensor,
@@ -646,45 +647,40 @@ def conformer_stream_step(
decoding = self.decoding
decoder = self.decoder

# I don't think this is a very useful thing to do.
# TODO: Consider rewriting this.
log_probs = decoder(encoder_output=encoded)
predictions_tensor = log_probs.argmax(dim=-1, keepdim=False)
# predictions_tensor = log_probs.argmax(dim=-1, keepdim=False)

assert return_transcription

# Concatenate the previous predictions with the current one to have the full predictions.
# We drop the extra predictions for each sample by using the lengths returned by the encoder (encoded_len)
# Then create a list of the predictions for the batch. The predictions can have different lengths because of the paddings.
greedy_predictions = []
if return_transcription:
all_hyp_or_transcribed_texts = []
else:
all_hyp_or_transcribed_texts = None
for preds_idx, preds in enumerate(predictions_tensor):
if encoded_len is None:
preds_cur = predictions_tensor[preds_idx]
else:
preds_cur = predictions_tensor[preds_idx, : encoded_len[preds_idx]]
if previous_pred_out is not None:
greedy_predictions_concat = torch.cat((previous_pred_out[preds_idx], preds_cur), dim=-1)
encoded_len[preds_idx] += len(previous_pred_out[preds_idx])
else:
greedy_predictions_concat = preds_cur
greedy_predictions.append(greedy_predictions_concat)

# TODO: make decoding more efficient by avoiding the decoding process from the beginning
if return_transcription:
decoded_out = decoding.ctc_decoder_predictions_tensor(
decoder_outputs=greedy_predictions_concat.unsqueeze(0),
decoder_lengths=encoded_len[preds_idx : preds_idx + 1],
return_hypotheses=False,
)
all_hyp_or_transcribed_texts.append(decoded_out[0][0])
best_hyp = None

decoded_out = decoding.ctc_decoder_predictions_tensor(
decoder_outputs=log_probs,
decoder_lengths=encoded_len,
return_hypotheses=True,
previous_hypotheses=previous_hypotheses,
)
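# A single batched decode over the full (B, T, V) log-probs; this replaces the old
# per-sample loop that called ctc_decoder_predictions_tensor once per batch element.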

batch_size = encoded_len.shape[0]

# TODO: We need to merge the previous hypothesis output with this one...

all_hyp_or_transcribed_texts: List[Hypothesis] = [decoded_out[0][i] for i in range(batch_size)]

best_hyp = all_hyp_or_transcribed_texts
else:
best_hyp, all_hyp_or_transcribed_texts = self.decoding.rnnt_decoder_predictions_tensor(
encoder_output=encoded,
encoded_lengths=encoded_len,
return_hypotheses=True,
partial_hypotheses=previous_hypotheses,
partial_hypotheses=previous_hypotheses, # Here is my problem. Could change to CTC for now.
)
# Just return hypotheses for now, I think? When does it get turned into text? Unclear.
greedy_predictions = [hyp.y_sequence for hyp in best_hyp]

if all_hyp_or_transcribed_texts is None:
5 changes: 5 additions & 0 deletions nemo/collections/asr/parts/preprocessing/features.py
@@ -299,6 +299,7 @@ def __init__(
self.win_length = n_window_size
self.hop_length = n_window_stride
self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
# import ipdb; ipdb.set_trace()
self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None

if exact_pad:
@@ -398,6 +399,7 @@ def log_zero_guard_value_fn(self, x):
def get_seq_len(self, seq_len):
# Assuming that center is True when stft_pad_amount is None
pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2
# pad_amount = 512
seq_len = torch.floor_divide((seq_len + pad_amount - self.n_fft), self.hop_length) + 1
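# e.g. with center-style padding (pad_amount == n_fft) and hop_length=160 this reduces to
# floor(seq_len / 160) + 1, i.e. one frame per 10 ms of 16 kHz audio, plus one.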
return seq_len.to(dtype=torch.long)

@@ -414,6 +416,7 @@ def forward(self, x, seq_len, linear_spec=False):
).squeeze(1)

# dither (only in training mode for eval determinism)
# TODO: We probably want to apply "deterministic dithering" to unbias this.
if self.training and self.dither > 0:
x += self.dither * torch.randn_like(x)
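# One possible "deterministic dithering" sketch (illustrative only, with a hypothetical per-utterance seed):
#     g = torch.Generator(device=x.device).manual_seed(seed)
#     x = x + self.dither * torch.randn(x.shape, generator=g, device=x.device, dtype=x.dtype)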

@@ -476,6 +479,8 @@ def forward(self, x, seq_len, linear_spec=False):
pad_amt = x.size(-1) % pad_to
if pad_amt != 0:
x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)

# import ipdb; ipdb.set_trace()
return x, seq_len

