diff --git a/nemo/collections/asr/parts/submodules/ctc_decoding.py b/nemo/collections/asr/parts/submodules/ctc_decoding.py index 065e55692821..d2bfb629293e 100644 --- a/nemo/collections/asr/parts/submodules/ctc_decoding.py +++ b/nemo/collections/asr/parts/submodules/ctc_decoding.py @@ -1018,7 +1018,9 @@ class CTCDecoding(AbstractCTCDecoding): """ def __init__( - self, decoding_cfg, vocabulary, + self, + decoding_cfg, + vocabulary, ): blank_id = len(vocabulary) self.vocabulary = vocabulary diff --git a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py index 035c5376e24a..d0063ee81150 100644 --- a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py @@ -25,7 +25,10 @@ from nemo.utils import logging, logging_mode -def pack_hypotheses(hypotheses: List[rnnt_utils.Hypothesis], logitlen: torch.Tensor,) -> List[rnnt_utils.Hypothesis]: +def pack_hypotheses( + hypotheses: List[rnnt_utils.Hypothesis], + logitlen: torch.Tensor, +) -> List[rnnt_utils.Hypothesis]: if logitlen is not None: if hasattr(logitlen, 'cpu'): @@ -108,8 +111,7 @@ class GreedyCTCInfer(Typing, ConfidenceMethodMixin): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" # Input can be of dimension - # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels] @@ -120,8 +122,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return {"predictions": [NeuralType(elements_type=HypothesisType())]} def __init__( @@ -145,7 +146,9 @@ def __init__( @typecheck() def forward( - self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor, + self, + decoder_output: torch.Tensor, + decoder_lengths: torch.Tensor, ): """Returns a list of hypotheses given an input batch of the encoder hidden embedding. Output token is generated auto-repressively. @@ -330,8 +333,7 @@ class GreedyBatchedCTCInfer(Typing, ConfidenceMethodMixin): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" # Input can be of dimension - # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels] @@ -342,8 +344,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return {"predictions": [NeuralType(elements_type=HypothesisType())]} def __init__( @@ -367,7 +368,9 @@ def __init__( @typecheck() def forward( - self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor, + self, + decoder_output: torch.Tensor, + decoder_lengths: torch.Tensor, ): """Returns a list of hypotheses given an input batch of the encoder hidden embedding. Output token is generated auto-repressively. diff --git a/tests/collections/asr/decoding/test_ctc_decoding.py b/tests/collections/asr/decoding/test_ctc_decoding.py index 6c2f38bf5e5e..ea2cdea58119 100644 --- a/tests/collections/asr/decoding/test_ctc_decoding.py +++ b/tests/collections/asr/decoding/test_ctc_decoding.py @@ -90,7 +90,9 @@ def test_constructor_subword(self, tmp_tokenizer): assert decoding is not None @pytest.mark.unit - def test_char_decoding_greedy_forward(self,): + def test_char_decoding_greedy_forward( + self, + ): cfg = CTCDecodingConfig(strategy='greedy') vocab = char_vocabulary() decoding = CTCDecoding(decoding_cfg=cfg, vocabulary=vocab) diff --git a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py index 049abb9c28c0..0d7c555ee778 100644 --- a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py +++ b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py @@ -309,7 +309,10 @@ def test_ASRDatasetConfig_for_AudioToBPEDataset(self): REMAP_ARGS = {'trim_silence': 'trim', 'labels': 'tokenizer'} result = assert_dataclass_signature_match( - audio_to_text.AudioToBPEDataset, configs.ASRDatasetConfig, ignore_args=IGNORE_ARGS, remap_args=REMAP_ARGS, + audio_to_text.AudioToBPEDataset, + configs.ASRDatasetConfig, + ignore_args=IGNORE_ARGS, + remap_args=REMAP_ARGS, ) signatures_match, cls_subset, dataclass_subset = result diff --git a/tests/collections/asr/test_asr_ctcencdec_model.py b/tests/collections/asr/test_asr_ctcencdec_model.py index b6e51e29f7b0..28a07fd54663 100644 --- a/tests/collections/asr/test_asr_ctcencdec_model.py +++ b/tests/collections/asr/test_asr_ctcencdec_model.py @@ -279,7 +279,10 @@ def test_ASRDatasetConfig_for_AudioToCharDataset(self): REMAP_ARGS = {'trim_silence': 'trim'} result = assert_dataclass_signature_match( - audio_to_text.AudioToCharDataset, configs.ASRDatasetConfig, ignore_args=IGNORE_ARGS, remap_args=REMAP_ARGS, + audio_to_text.AudioToCharDataset, + configs.ASRDatasetConfig, + ignore_args=IGNORE_ARGS, + remap_args=REMAP_ARGS, ) signatures_match, cls_subset, dataclass_subset = result diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py index d6a78dad96e4..1743acc6878c 100644 --- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py +++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py @@ -64,12 +64,18 @@ def hybrid_asr_model(test_data_dir): decoder = { '_target_': 'nemo.collections.asr.modules.RNNTDecoder', - 'prednet': {'pred_hidden': model_defaults['pred_hidden'], 'pred_rnn_layers': 1,}, + 'prednet': { + 'pred_hidden': model_defaults['pred_hidden'], + 'pred_rnn_layers': 1, + }, } joint = { '_target_': 'nemo.collections.asr.modules.RNNTJoint', - 'jointnet': {'joint_hidden': 32, 'activation': 'relu',}, + 'jointnet': { + 'joint_hidden': 32, + 'activation': 'relu', + }, } decoding = {'strategy': 'greedy_batch', 'greedy': {'max_symbols': 30}} @@ -111,7 +117,8 @@ def hybrid_asr_model(test_data_dir): class TestEncDecHybridRNNTCTCBPEModel: @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.with_downloads() @pytest.mark.unit @@ -125,7 +132,8 @@ def test_constructor(self, hybrid_asr_model): @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_forward(self, hybrid_asr_model): @@ -160,7 +168,8 @@ def test_forward(self, hybrid_asr_model): @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_save_restore_artifact(self, hybrid_asr_model): @@ -178,7 +187,8 @@ def test_save_restore_artifact(self, hybrid_asr_model): @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_save_restore_artifact_spe(self, hybrid_asr_model, test_data_dir): @@ -224,7 +234,8 @@ def test_save_restore_artifact_agg(self, hybrid_asr_model, test_data_dir): @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_vocab_change(self, test_data_dir, hybrid_asr_model): @@ -255,7 +266,8 @@ def test_vocab_change(self, test_data_dir, hybrid_asr_model): @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_decoding_change(self, hybrid_asr_model): @@ -309,7 +321,8 @@ def test_decoding_change(self, hybrid_asr_model): assert hybrid_asr_model.cur_decoder == "ctc" @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_decoding_type_change(self, hybrid_asr_model): diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py index a1193b7698e8..a0d5627f1a65 100644 --- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py +++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py @@ -117,7 +117,8 @@ def hybrid_asr_model(): class TestEncDecHybridRNNTCTCModel: @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_constructor(self, hybrid_asr_model): @@ -129,7 +130,8 @@ def test_constructor(self, hybrid_asr_model): assert isinstance(instance2, EncDecHybridRNNTCTCModel) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_forward(self, hybrid_asr_model): @@ -163,7 +165,8 @@ def test_forward(self, hybrid_asr_model): assert diff <= 1e-6 @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_vocab_change(self, hybrid_asr_model): @@ -186,10 +189,12 @@ def test_vocab_change(self, hybrid_asr_model): assert hybrid_asr_model.ctc_decoder.vocabulary == hybrid_asr_model.joint.vocabulary @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_decoding_change(self, hybrid_asr_model): @@ -242,7 +247,8 @@ def test_decoding_change(self, hybrid_asr_model): assert hybrid_asr_model.ctc_decoding.compute_timestamps is True @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_decoding_type_change(self, hybrid_asr_model): @@ -306,7 +312,8 @@ def test_BeamRNNTInferConfig(self): assert dataclass_subset is None @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit @pytest.mark.parametrize( @@ -349,11 +356,13 @@ def test_greedy_decoding(self, greedy_class, loop_labels: Optional[bool]): _ = greedy(encoder_output=enc_out, encoded_lengths=enc_len) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit @pytest.mark.parametrize( - "greedy_class", [greedy_decode.GreedyRNNTInfer], + "greedy_class", + [greedy_decode.GreedyRNNTInfer], ) def test_greedy_multi_decoding(self, greedy_class): token_list = [" ", "a", "b", "c"] @@ -386,7 +395,8 @@ def test_greedy_multi_decoding(self, greedy_class): _ = greedy(encoder_output=enc_out, encoded_lengths=enc_len, partial_hypotheses=partial_hyp) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit @pytest.mark.parametrize( @@ -430,11 +440,13 @@ def test_greedy_decoding_stateless_decoder(self, greedy_class, loop_labels: Opti _ = greedy(encoder_output=enc_out, encoded_lengths=enc_len) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit @pytest.mark.parametrize( - "greedy_class", [greedy_decode.GreedyRNNTInfer], + "greedy_class", + [greedy_decode.GreedyRNNTInfer], ) def test_greedy_multi_decoding_stateless_decoder(self, greedy_class): token_list = [" ", "a", "b", "c"] @@ -467,7 +479,8 @@ def test_greedy_multi_decoding_stateless_decoder(self, greedy_class): _ = greedy(encoder_output=enc_out, encoded_lengths=enc_len, partial_hypotheses=partial_hyp) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit @pytest.mark.parametrize( @@ -522,7 +535,8 @@ def test_greedy_decoding_preserve_alignment(self, greedy_class, loop_labels: Opt assert torch.is_tensor(label) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit @pytest.mark.parametrize( @@ -556,7 +570,12 @@ def test_beam_decoding(self, beam_config): decoder = RNNTDecoder(prednet_cfg, vocab_size) joint_net = RNNTJoint(jointnet_cfg, vocab_size, vocabulary=token_list) - beam = beam_decode.BeamRNNTInfer(decoder, joint_net, beam_size=beam_size, **beam_config,) + beam = beam_decode.BeamRNNTInfer( + decoder, + joint_net, + beam_size=beam_size, + **beam_config, + ) # (B, D, T) enc_out = torch.randn(1, encoder_output_size, 30) @@ -566,12 +585,16 @@ def test_beam_decoding(self, beam_config): _ = beam(encoder_output=enc_out, encoded_lengths=enc_len) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit @pytest.mark.parametrize( "beam_config", - [{"search_type": "greedy"}, {"search_type": "default", "score_norm": False, "return_best_hypothesis": False},], + [ + {"search_type": "greedy"}, + {"search_type": "default", "score_norm": False, "return_best_hypothesis": False}, + ], ) def test_beam_decoding_preserve_alignments(self, beam_config): token_list = [" ", "a", "b", "c"] @@ -616,7 +639,8 @@ def test_beam_decoding_preserve_alignments(self, beam_config): assert torch.is_tensor(label) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit @pytest.mark.parametrize( @@ -659,7 +683,8 @@ def test_greedy_decoding_SampledRNNTJoint(self, greedy_class, loop_labels: Optio _ = greedy(encoder_output=enc_out, encoded_lengths=enc_len) @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit @pytest.mark.parametrize( @@ -693,7 +718,12 @@ def test_beam_decoding_SampledRNNTJoint(self, beam_config): decoder = RNNTDecoder(prednet_cfg, vocab_size) joint_net = SampledRNNTJoint(jointnet_cfg, vocab_size, n_samples=2, vocabulary=token_list) - beam = beam_decode.BeamRNNTInfer(decoder, joint_net, beam_size=beam_size, **beam_config,) + beam = beam_decode.BeamRNNTInfer( + decoder, + joint_net, + beam_size=beam_size, + **beam_config, + ) # (B, D, T) enc_out = torch.randn(1, encoder_output_size, 30)