diff --git a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
index edb598c83819d..c2b9b123cd6e5 100644
--- a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
+++ b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
@@ -25,7 +25,10 @@ from nemo.utils import logging
 
 
 
-def pack_hypotheses(hypotheses: List[rnnt_utils.Hypothesis], logitlen: torch.Tensor,) -> List[rnnt_utils.Hypothesis]:
+def pack_hypotheses(
+    hypotheses: List[rnnt_utils.Hypothesis],
+    logitlen: torch.Tensor,
+) -> List[rnnt_utils.Hypothesis]:
 
     if logitlen is not None:
         if hasattr(logitlen, 'cpu'):
@@ -54,6 +57,7 @@ def _states_to_device(dec_state, device='cpu'):
 
     return dec_state
 
 
+_DECODER_LENGTHS_NONE_WARNING = "Passing in decoder_lengths=None for CTC decoding is likely to be an error, since it is unlikely that each element of your batch has exactly the same length. decoder_lengths will default to decoder_output.shape[0]."
 
 
@@ -110,8 +114,7 @@ class GreedyCTCInfer(Typing, ConfidenceMethodMixin):
 
     @property
     def input_types(self):
-        """Returns definitions of module input ports.
-        """
+        """Returns definitions of module input ports."""
         # Input can be of dimension -
         # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]
 
@@ -122,8 +125,7 @@ def input_types(self):
 
     @property
     def output_types(self):
-        """Returns definitions of module output ports.
-        """
+        """Returns definitions of module output ports."""
         return {"predictions": [NeuralType(elements_type=HypothesisType())]}
 
     def __init__(
@@ -331,8 +333,7 @@ class GreedyBatchedCTCInfer(Typing, ConfidenceMethodMixin):
 
     @property
     def input_types(self):
-        """Returns definitions of module input ports.
-        """
+        """Returns definitions of module input ports."""
         # Input can be of dimension -
         # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]
 
@@ -343,8 +344,7 @@ def input_types(self):
 
     @property
     def output_types(self):
-        """Returns definitions of module output ports.
-        """
+        """Returns definitions of module output ports."""
         return {"predictions": [NeuralType(elements_type=HypothesisType())]}
 
     def __init__(
diff --git a/tests/collections/asr/decoding/test_ctc_decoding.py b/tests/collections/asr/decoding/test_ctc_decoding.py
index d74c24f45e4bc..a42d61f051adc 100644
--- a/tests/collections/asr/decoding/test_ctc_decoding.py
+++ b/tests/collections/asr/decoding/test_ctc_decoding.py
@@ -90,7 +90,9 @@ def test_constructor_subword(self, tmp_tokenizer):
         assert decoding is not None
 
     @pytest.mark.unit
-    def test_char_decoding_greedy_forward(self,):
+    def test_char_decoding_greedy_forward(
+        self,
+    ):
         cfg = CTCDecodingConfig(strategy='greedy')
         vocab = char_vocabulary()
         decoding = CTCDecoding(decoding_cfg=cfg, vocabulary=vocab)
@@ -198,7 +200,9 @@ def test_subword_decoding_greedy_forward_hypotheses(self, tmp_tokenizer, alignme
     @pytest.mark.parametrize('timestamps', [False, True])
     @pytest.mark.parametrize('preserve_frame_confidence', [False, True])
     @pytest.mark.parametrize('length_is_none', [False, True])
-    def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence, length_is_none):
+    def test_batched_decoding_logprobs(
+        self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence, length_is_none
+    ):
         cfg = CTCBPEDecodingConfig(
             strategy='greedy',
             preserve_alignments=alignments,
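
For context, a caller-side sketch of the decoding path these changes touch. Everything below is illustrative and not part of the patch: the toy vocabulary, tensor shapes, and variable names are assumptions, and it relies only on the CTCDecoding / ctc_decoder_predictions_tensor API already exercised by the tests above. Passing explicit decoder_lengths avoids the path guarded by the new module-level _DECODER_LENGTHS_NONE_WARNING constant.

    import torch

    from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig

    # Hypothetical character vocabulary for illustration; the real tests use char_vocabulary().
    vocab = [" ", "a", "b", "c", "d"]
    decoding = CTCDecoding(decoding_cfg=CTCDecodingConfig(strategy='greedy'), vocabulary=vocab)

    # ('B', 'T', 'D') log-probs, matching the input_types docstring; D includes the CTC blank.
    batch, time, dim = 4, 20, len(vocab) + 1
    log_probs = torch.randn(batch, time, dim).log_softmax(dim=-1)
    lengths = torch.tensor([20, 18, 15, 12])

    # Supplying per-utterance lengths is the intended usage; decoder_lengths=None now logs
    # _DECODER_LENGTHS_NONE_WARNING and falls back to a default derived from decoder_output's shape.
    hyps = decoding.ctc_decoder_predictions_tensor(log_probs, decoder_lengths=lengths, return_hypotheses=True)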