Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: galv <[email protected]>
  • Loading branch information
galv committed May 17, 2024
1 parent 366ed7d commit bd8c5d4
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 44 deletions.
4 changes: 3 additions & 1 deletion nemo/collections/asr/parts/submodules/ctc_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,9 @@ class CTCDecoding(AbstractCTCDecoding):
"""

def __init__(
self, decoding_cfg, vocabulary,
self,
decoding_cfg,
vocabulary,
):
blank_id = len(vocabulary)
self.vocabulary = vocabulary
Expand Down
25 changes: 14 additions & 11 deletions nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
from nemo.utils import logging, logging_mode


def pack_hypotheses(hypotheses: List[rnnt_utils.Hypothesis], logitlen: torch.Tensor,) -> List[rnnt_utils.Hypothesis]:
def pack_hypotheses(
hypotheses: List[rnnt_utils.Hypothesis],
logitlen: torch.Tensor,
) -> List[rnnt_utils.Hypothesis]:

if logitlen is not None:
if hasattr(logitlen, 'cpu'):
Expand Down Expand Up @@ -108,8 +111,7 @@ class GreedyCTCInfer(Typing, ConfidenceMethodMixin):

@property
def input_types(self):
"""Returns definitions of module input ports.
"""
"""Returns definitions of module input ports."""
# Input can be of dimension -
# ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]

Expand All @@ -120,8 +122,7 @@ def input_types(self):

@property
def output_types(self):
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {"predictions": [NeuralType(elements_type=HypothesisType())]}

def __init__(
Expand All @@ -145,7 +146,9 @@ def __init__(

@typecheck()
def forward(
self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor,
self,
decoder_output: torch.Tensor,
decoder_lengths: torch.Tensor,
):
"""Returns a list of hypotheses given an input batch of the encoder hidden embedding.
Output token is generated auto-repressively.
Expand Down Expand Up @@ -330,8 +333,7 @@ class GreedyBatchedCTCInfer(Typing, ConfidenceMethodMixin):

@property
def input_types(self):
"""Returns definitions of module input ports.
"""
"""Returns definitions of module input ports."""
# Input can be of dimension -
# ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]

Expand All @@ -342,8 +344,7 @@ def input_types(self):

@property
def output_types(self):
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {"predictions": [NeuralType(elements_type=HypothesisType())]}

def __init__(
Expand All @@ -367,7 +368,9 @@ def __init__(

@typecheck()
def forward(
self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor,
self,
decoder_output: torch.Tensor,
decoder_lengths: torch.Tensor,
):
"""Returns a list of hypotheses given an input batch of the encoder hidden embedding.
Output token is generated auto-repressively.
Expand Down
4 changes: 3 additions & 1 deletion tests/collections/asr/decoding/test_ctc_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def test_constructor_subword(self, tmp_tokenizer):
assert decoding is not None

@pytest.mark.unit
def test_char_decoding_greedy_forward(self,):
def test_char_decoding_greedy_forward(
self,
):
cfg = CTCDecodingConfig(strategy='greedy')
vocab = char_vocabulary()
decoding = CTCDecoding(decoding_cfg=cfg, vocabulary=vocab)
Expand Down
5 changes: 4 additions & 1 deletion tests/collections/asr/test_asr_ctc_encoder_model_bpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,10 @@ def test_ASRDatasetConfig_for_AudioToBPEDataset(self):
REMAP_ARGS = {'trim_silence': 'trim', 'labels': 'tokenizer'}

result = assert_dataclass_signature_match(
audio_to_text.AudioToBPEDataset, configs.ASRDatasetConfig, ignore_args=IGNORE_ARGS, remap_args=REMAP_ARGS,
audio_to_text.AudioToBPEDataset,
configs.ASRDatasetConfig,
ignore_args=IGNORE_ARGS,
remap_args=REMAP_ARGS,
)
signatures_match, cls_subset, dataclass_subset = result

Expand Down
5 changes: 4 additions & 1 deletion tests/collections/asr/test_asr_ctcencdec_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,10 @@ def test_ASRDatasetConfig_for_AudioToCharDataset(self):
REMAP_ARGS = {'trim_silence': 'trim'}

result = assert_dataclass_signature_match(
audio_to_text.AudioToCharDataset, configs.ASRDatasetConfig, ignore_args=IGNORE_ARGS, remap_args=REMAP_ARGS,
audio_to_text.AudioToCharDataset,
configs.ASRDatasetConfig,
ignore_args=IGNORE_ARGS,
remap_args=REMAP_ARGS,
)
signatures_match, cls_subset, dataclass_subset = result

Expand Down
31 changes: 22 additions & 9 deletions tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,18 @@ def hybrid_asr_model(test_data_dir):

decoder = {
'_target_': 'nemo.collections.asr.modules.RNNTDecoder',
'prednet': {'pred_hidden': model_defaults['pred_hidden'], 'pred_rnn_layers': 1,},
'prednet': {
'pred_hidden': model_defaults['pred_hidden'],
'pred_rnn_layers': 1,
},
}

joint = {
'_target_': 'nemo.collections.asr.modules.RNNTJoint',
'jointnet': {'joint_hidden': 32, 'activation': 'relu',},
'jointnet': {
'joint_hidden': 32,
'activation': 'relu',
},
}

decoding = {'strategy': 'greedy_batch', 'greedy': {'max_symbols': 30}}
Expand Down Expand Up @@ -111,7 +117,8 @@ def hybrid_asr_model(test_data_dir):

class TestEncDecHybridRNNTCTCBPEModel:
@pytest.mark.skipif(
not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
not NUMBA_RNNT_LOSS_AVAILABLE,
reason='RNNTLoss has not been compiled with appropriate numba version.',
)
@pytest.mark.with_downloads()
@pytest.mark.unit
Expand All @@ -125,7 +132,8 @@ def test_constructor(self, hybrid_asr_model):

@pytest.mark.with_downloads()
@pytest.mark.skipif(
not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
not NUMBA_RNNT_LOSS_AVAILABLE,
reason='RNNTLoss has not been compiled with appropriate numba version.',
)
@pytest.mark.unit
def test_forward(self, hybrid_asr_model):
Expand Down Expand Up @@ -160,7 +168,8 @@ def test_forward(self, hybrid_asr_model):

@pytest.mark.with_downloads()
@pytest.mark.skipif(
not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
not NUMBA_RNNT_LOSS_AVAILABLE,
reason='RNNTLoss has not been compiled with appropriate numba version.',
)
@pytest.mark.unit
def test_save_restore_artifact(self, hybrid_asr_model):
Expand All @@ -178,7 +187,8 @@ def test_save_restore_artifact(self, hybrid_asr_model):

@pytest.mark.with_downloads()
@pytest.mark.skipif(
not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
not NUMBA_RNNT_LOSS_AVAILABLE,
reason='RNNTLoss has not been compiled with appropriate numba version.',
)
@pytest.mark.unit
def test_save_restore_artifact_spe(self, hybrid_asr_model, test_data_dir):
Expand Down Expand Up @@ -224,7 +234,8 @@ def test_save_restore_artifact_agg(self, hybrid_asr_model, test_data_dir):

@pytest.mark.with_downloads()
@pytest.mark.skipif(
not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
not NUMBA_RNNT_LOSS_AVAILABLE,
reason='RNNTLoss has not been compiled with appropriate numba version.',
)
@pytest.mark.unit
def test_vocab_change(self, test_data_dir, hybrid_asr_model):
Expand Down Expand Up @@ -255,7 +266,8 @@ def test_vocab_change(self, test_data_dir, hybrid_asr_model):

@pytest.mark.with_downloads()
@pytest.mark.skipif(
not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
not NUMBA_RNNT_LOSS_AVAILABLE,
reason='RNNTLoss has not been compiled with appropriate numba version.',
)
@pytest.mark.unit
def test_decoding_change(self, hybrid_asr_model):
Expand Down Expand Up @@ -309,7 +321,8 @@ def test_decoding_change(self, hybrid_asr_model):
assert hybrid_asr_model.cur_decoder == "ctc"

@pytest.mark.skipif(
not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
not NUMBA_RNNT_LOSS_AVAILABLE,
reason='RNNTLoss has not been compiled with appropriate numba version.',
)
@pytest.mark.unit
def test_decoding_type_change(self, hybrid_asr_model):
Expand Down
Loading

0 comments on commit bd8c5d4

Please sign in to comment.