From 6779d39c17d1d82913c75d9fa0800b7ba123b5f1 Mon Sep 17 00:00:00 2001
From: Daniel Galvez
Date: Thu, 2 May 2024 14:42:32 -0700
Subject: [PATCH] Add support for batched inference for label inputs as well.

---
 .../parts/submodules/ctc_greedy_decoding.py | 57 ++++++++++++++++---
 .../asr/decoding/test_ctc_decoding.py       | 48 +++++++++++++++-
 2 files changed, 95 insertions(+), 10 deletions(-)

diff --git a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
index e68dde59b1f35..bcef828d8f8e4 100644
--- a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
+++ b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
@@ -164,10 +164,14 @@ def forward(
         if self.batched_inference:
             torch.cuda.nvtx.range_push("batched hypotheses")
-            hypotheses = self._greedy_decode_logprobs_batched(decoder_output, decoder_lengths)
+            if decoder_output.ndim == 2:
+                hypotheses = self._greedy_decode_labels_batched(decoder_output, decoder_lengths)
+            else:
+                hypotheses = self._greedy_decode_logprobs_batched(decoder_output, decoder_lengths)
             torch.cuda.nvtx.range_pop()
             packed_result = pack_hypotheses(hypotheses, decoder_lengths)
             return (packed_result,)
+
         with torch.inference_mode():
             hypotheses = []
 
             # Process each sequence independently
@@ -218,7 +222,6 @@ def forward(
             torch.cuda.nvtx.range_pop()
         return (packed_result,)
 
-    # Cannot naively call vmap because this does not return tensors...
     @torch.no_grad()
     def _greedy_decode_logprobs_batched(self, x: torch.Tensor, out_len: torch.Tensor):
         # x: [B, T, D]
@@ -248,19 +251,59 @@ def _greedy_decode_logprobs_batched(self, x: torch.Tensor, out_len: torch.Tensor
             hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestep=[], last_token=None)
             hypothesis.score = scores[i]
 
-            # prediction_labels_no_padding = predictions_labels[i, :labels_segments[i]]
-            prediction_labels_no_padding = predictions_labels[i, :out_len[i]]
+            prediction_labels_no_padding = predictions_labels[i, :out_len[i]].tolist()
             assert predictions_labels.dtype == torch.int64
             hypothesis.y_sequence = prediction_labels_no_padding
 
             if self.preserve_alignments:
-                hypothesis.alignments = (predictions[i].clone(), predictions_labels[i].clone())
+                hypothesis.alignments = (predictions[i, :out_len[i], :].clone(),
+                                         predictions_labels[i, :out_len[i]].clone())
             if self.compute_timestamps:
-                # Could do this in a vectorized manner...
+                # TODO: Could do this in a vectorized manner... Would
+                # prefer to have nonzero_static, though, for sanity.
                 hypothesis.timestep = torch.nonzero(non_blank_ids_mask[i], as_tuple=False)[:, 0].cpu().tolist()
             if self.preserve_frame_confidence:
-                hypothesis.frame_confidence = self._get_confidence(predictions[i])
+                hypothesis.frame_confidence = self._get_confidence(predictions[i, :out_len[i], :])
 
             hypotheses.append(hypothesis)
 
         return hypotheses
 
+    @torch.no_grad()
+    def _greedy_decode_labels_batched(self, x: torch.Tensor, out_len: torch.Tensor):
+        # x: [B, T]
+        # out_len: [B]
+
+        batch_size = x.shape[0]
+        max_time = x.shape[1]
+
+        predictions_labels = x
+        time_steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(batch_size, max_time)
+        non_blank_ids_mask = torch.logical_and(predictions_labels != self.blank_id,
+                                               time_steps < out_len.unsqueeze(1))
+        predictions_labels = predictions_labels.cpu()
+        out_len = out_len.cpu()
+
+        hypotheses = []
+
+        for i in range(batch_size):
+            hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestep=[], last_token=None)
+            hypothesis.y_sequence = predictions_labels[i, :out_len[i]].tolist()
+            hypothesis.score = -1.0
+
+            if self.preserve_alignments:
+                raise ValueError(
+                    "Alignments were requested, but the predictions provided are labels, not log probabilities."
+                )
+            if self.compute_timestamps:
+                # TODO: Could do this in a vectorized manner... Would
+                # prefer to have nonzero_static, though, for sanity.
+                # Or do a prefix sum on out_len
+                hypothesis.timestep = torch.nonzero(non_blank_ids_mask[i], as_tuple=False)[:, 0].cpu().tolist()
+            if self.preserve_frame_confidence:
+                raise ValueError(
+                    "Per-frame confidence was requested, but the predictions provided are labels, not log probabilities."
+                )
+
+            hypotheses.append(hypothesis)
+
+        return hypotheses
diff --git a/tests/collections/asr/decoding/test_ctc_decoding.py b/tests/collections/asr/decoding/test_ctc_decoding.py
index 9afea44a5ecca..79fd1ac2a1bc8 100644
--- a/tests/collections/asr/decoding/test_ctc_decoding.py
+++ b/tests/collections/asr/decoding/test_ctc_decoding.py
@@ -26,6 +26,7 @@
     CTCDecoding,
     CTCDecodingConfig,
 )
+from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig
 from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
 
 
@@ -195,8 +196,10 @@ def test_subword_decoding_greedy_forward_hypotheses(self, tmp_tokenizer, alignme
     @pytest.mark.unit
     @pytest.mark.parametrize('alignments', [False, True])
     @pytest.mark.parametrize('timestamps', [False, True])
-    def test_batched_decoding(self, tmp_tokenizer, alignments, timestamps):
-        cfg = CTCBPEDecodingConfig(strategy='greedy', preserve_alignments=alignments, compute_timestamps=timestamps)
+    @pytest.mark.parametrize('preserve_frame_confidence', [False, True])
+    def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence):
+        # timestamps not working...
+        cfg = CTCBPEDecodingConfig(strategy='greedy', preserve_alignments=alignments, compute_timestamps=timestamps, confidence_cfg=ConfidenceConfig(preserve_frame_confidence=preserve_frame_confidence))
         cfg.greedy.batched_inference = False
         unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
 
@@ -207,11 +210,13 @@ def test_batched_decoding(self, tmp_tokenizer, alignments, timestamps):
         B, T = 4, 20
         V = unbatched_decoding.tokenizer.tokenizer.vocab_size + 1
         input_signal = torch.randn(size=(B, T, V))
+        # Set the blank index to a very high probability to make sure
+        # that we always handle at least a few blanks.
         input_signal[:, 0, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
         input_signal[:, 1, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
         length = torch.randint(low=1, high=T, size=[B])
 
-        with torch.no_grad():
+        with torch.inference_mode():
             hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
                 input_signal, length, fold_consecutive=True, return_hypotheses=True
             )
@@ -229,3 +234,40 @@ def test_batched_decoding(self, tmp_tokenizer, alignments, timestamps):
             if alignments:
                 assert torch.all(hyp.alignments[0] == batched_hyp.alignments[0])
                 assert torch.all(hyp.alignments[1] == batched_hyp.alignments[1])
+
+
+    @pytest.mark.unit
+    @pytest.mark.parametrize('timestamps', [False, True])
+    def test_batched_decoding_labels(self, tmp_tokenizer, timestamps):
+        cfg = CTCBPEDecodingConfig(strategy='greedy', compute_timestamps=timestamps)
+        cfg.greedy.batched_inference = False
+        unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
+
+        cfg.greedy.batched_inference = True
+        batched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
+
+        torch.manual_seed(1)
+        B, T = 4, 20
+        V = unbatched_decoding.tokenizer.tokenizer.vocab_size + 1
+        input_labels = torch.randint(V, size=(B, T))
+        # Set the first two frames to the blank index to make sure
+        # that we always handle at least a few blanks.
+        input_labels[:, 0] = unbatched_decoding.tokenizer.tokenizer.vocab_size
+        input_labels[:, 1] = unbatched_decoding.tokenizer.tokenizer.vocab_size
+        length = torch.randint(low=1, high=T, size=[B])
+
+        with torch.inference_mode():
+            hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
+                input_labels, length, fold_consecutive=True, return_hypotheses=True
+            )
+
+            batched_hyps, _ = batched_decoding.ctc_decoder_predictions_tensor(
+                input_labels, length, fold_consecutive=True, return_hypotheses=True
+            )
+
+        assert len(hyps) == len(batched_hyps) == B
+        for hyp, batched_hyp in zip(hyps, batched_hyps):
+            assert abs(hyp.score - batched_hyp.score) <= 1e-5
+            assert torch.all(hyp.y_sequence == batched_hyp.y_sequence)
+            if timestamps:
+                assert hyp.timestep == batched_hyp.timestep
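
Note (not part of the patch): the core of the new _greedy_decode_labels_batched path is a single mask over the [B, T] label tensor, combining "not blank" with "inside the valid length", from which y_sequence and timesteps are read per utterance. Below is a minimal standalone sketch of that masking logic; blank_id and the toy label/length tensors are illustrative assumptions, not NeMo code.

    import torch

    blank_id = 4  # hypothetical blank index for this sketch
    labels = torch.tensor([[1, 4, 2, 2, 4],
                           [3, 4, 4, 1, 0]])   # [B, T] greedy label indices
    out_len = torch.tensor([5, 3])             # valid length of each utterance

    # Frames that are non-blank and lie inside the valid length.
    time_steps = torch.arange(labels.shape[1]).unsqueeze(0).expand_as(labels)
    non_blank_mask = torch.logical_and(labels != blank_id, time_steps < out_len.unsqueeze(1))

    for i in range(labels.shape[0]):
        y_sequence = labels[i, : out_len[i]].tolist()   # padding dropped, blanks kept
        timesteps = torch.nonzero(non_blank_mask[i], as_tuple=False)[:, 0].tolist()
        print(y_sequence, timesteps)
    # prints: [1, 4, 2, 2, 4] [0, 2, 3]
    #         [3, 4, 4] [0]

Because labels carry no per-token scores, this path can only produce token ids and timestamps; alignments and per-frame confidence genuinely require log probabilities, which is why the patch raises ValueError for those options.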