diff --git a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
index d0063ee81150..a7f57c82279a 100644
--- a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
+++ b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
@@ -58,6 +58,9 @@ def _states_to_device(dec_state, device='cpu'):
     return dec_state
 
 
+_DECODER_LENGTHS_NONE_WARNING = "Passing in decoder_lengths=None for CTC decoding is likely to be an error, since it is unlikely that each element of your batch has exactly the same length. decoder_lengths will default to decoder_output.shape[0]."
+
+
 class GreedyCTCInfer(Typing, ConfidenceMethodMixin):
     """A greedy CTC decoder.
 
@@ -148,7 +151,7 @@ def __init__(
     def forward(
         self,
         decoder_output: torch.Tensor,
-        decoder_lengths: torch.Tensor,
+        decoder_lengths: Optional[torch.Tensor],
     ):
         """Returns a list of hypotheses given an input batch of the encoder hidden embedding.
         Output token is generated auto-repressively.
@@ -167,6 +170,9 @@ def forward(
                 mode=logging_mode.ONCE,
             )
 
+        if decoder_lengths is None:
+            logging.warning(_DECODER_LENGTHS_NONE_WARNING, mode=logging_mode.ONCE)
+
         with torch.inference_mode():
             hypotheses = []
             # Process each sequence independently
@@ -213,7 +219,7 @@ def forward(
         return (packed_result,)
 
     @torch.no_grad()
-    def _greedy_decode_logprobs(self, x: torch.Tensor, out_len: torch.Tensor):
+    def _greedy_decode_logprobs(self, x: torch.Tensor, out_len: Optional[torch.Tensor]):
         # x: [T, D]
         # out_len: [seq_len]
 
@@ -243,7 +249,7 @@ def _greedy_decode_logprobs(self, x: torch.Tensor, out_len: torch.Tensor):
         return hypothesis
 
     @torch.no_grad()
-    def _greedy_decode_labels(self, x: torch.Tensor, out_len: torch.Tensor):
+    def _greedy_decode_labels(self, x: torch.Tensor, out_len: Optional[torch.Tensor]):
         # x: [T]
         # out_len: [seq_len]
 
@@ -370,7 +376,7 @@ def __init__(
     def forward(
         self,
         decoder_output: torch.Tensor,
-        decoder_lengths: torch.Tensor,
+        decoder_lengths: Optional[torch.Tensor],
    ):
         """Returns a list of hypotheses given an input batch of the encoder hidden embedding.
         Output token is generated auto-repressively.
@@ -383,11 +389,18 @@ def forward(
         Returns:
             packed list containing batch number of sentences (Hypotheses).
         """
+
+        input_decoder_lengths = decoder_lengths
+
+        if decoder_lengths is None:
+            logging.warning(_DECODER_LENGTHS_NONE_WARNING, mode=logging_mode.ONCE)
+            decoder_lengths = torch.tensor([decoder_output.shape[1]], dtype=torch.long).expand(decoder_output.shape[0])
+
         if decoder_output.ndim == 2:
             hypotheses = self._greedy_decode_labels_batched(decoder_output, decoder_lengths)
         else:
             hypotheses = self._greedy_decode_logprobs_batched(decoder_output, decoder_lengths)
-        packed_result = pack_hypotheses(hypotheses, decoder_lengths)
+        packed_result = pack_hypotheses(hypotheses, input_decoder_lengths)
         return (packed_result,)
 
     @torch.no_grad()
diff --git a/tests/collections/asr/decoding/test_ctc_decoding.py b/tests/collections/asr/decoding/test_ctc_decoding.py
index ea2cdea58119..dd8871c329fc 100644
--- a/tests/collections/asr/decoding/test_ctc_decoding.py
+++ b/tests/collections/asr/decoding/test_ctc_decoding.py
@@ -199,7 +199,10 @@ def test_subword_decoding_greedy_forward_hypotheses(self, tmp_tokenizer, alignme
     @pytest.mark.parametrize('alignments', [False, True])
     @pytest.mark.parametrize('timestamps', [False, True])
     @pytest.mark.parametrize('preserve_frame_confidence', [False, True])
-    def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence):
+    @pytest.mark.parametrize('length_is_none', [False, True])
+    def test_batched_decoding_logprobs(
+        self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence, length_is_none
+    ):
         cfg = CTCBPEDecodingConfig(
             strategy='greedy',
             preserve_alignments=alignments,
@@ -219,7 +222,10 @@ def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps,
         # that we always handle at least a few blanks.
         input_signal[:, 0, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
         input_signal[:, 1, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
-        length = torch.randint(low=1, high=T, size=[B])
+        if length_is_none:
+            length = None
+        else:
+            length = torch.randint(low=1, high=T, size=[B])
 
         with torch.inference_mode():
             hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
@@ -242,7 +248,8 @@ def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps,
 
     @pytest.mark.unit
     @pytest.mark.parametrize('timestamps', [False, True])
-    def test_batched_decoding_labels(self, tmp_tokenizer, timestamps):
+    @pytest.mark.parametrize('length_is_none', [False, True])
+    def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none):
         cfg = CTCBPEDecodingConfig(strategy='greedy', compute_timestamps=timestamps)
         unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
         cfg.strategy = 'greedy_batch'
@@ -256,7 +263,10 @@ def test_batched_decoding_labels(self, tmp_tokenizer, timestamps):
         # at least a few blanks.
         input_labels[:, 0] = unbatched_decoding.tokenizer.tokenizer.vocab_size
         input_labels[:, 1] = unbatched_decoding.tokenizer.tokenizer.vocab_size
-        length = torch.randint(low=1, high=T, size=[B])
+        if length_is_none:
+            length = None
+        else:
+            length = torch.randint(low=1, high=T, size=[B])
 
         with torch.inference_mode():
             hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(