Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: titu1994 <[email protected]>
  • Loading branch information
titu1994 committed May 18, 2024
1 parent e47f141 commit 2066dc3
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
18 changes: 9 additions & 9 deletions nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
from nemo.utils import logging


def pack_hypotheses(hypotheses: List[rnnt_utils.Hypothesis], logitlen: torch.Tensor,) -> List[rnnt_utils.Hypothesis]:
def pack_hypotheses(
hypotheses: List[rnnt_utils.Hypothesis],
logitlen: torch.Tensor,
) -> List[rnnt_utils.Hypothesis]:

if logitlen is not None:
if hasattr(logitlen, 'cpu'):
Expand Down Expand Up @@ -54,6 +57,7 @@ def _states_to_device(dec_state, device='cpu'):

return dec_state


_DECODER_LENGTHS_NONE_WARNING = "Passing in decoder_lengths=None for CTC decoding is likely to be an error, since it is unlikely that each element of your batch has exactly the same length. decoder_lengths will default to decoder_output.shape[0]."


Expand Down Expand Up @@ -110,8 +114,7 @@ class GreedyCTCInfer(Typing, ConfidenceMethodMixin):

@property
def input_types(self):
"""Returns definitions of module input ports.
"""
"""Returns definitions of module input ports."""
# Input can be of dimension -
# ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]

Expand All @@ -122,8 +125,7 @@ def input_types(self):

@property
def output_types(self):
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {"predictions": [NeuralType(elements_type=HypothesisType())]}

def __init__(
Expand Down Expand Up @@ -331,8 +333,7 @@ class GreedyBatchedCTCInfer(Typing, ConfidenceMethodMixin):

@property
def input_types(self):
"""Returns definitions of module input ports.
"""
"""Returns definitions of module input ports."""
# Input can be of dimension -
# ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]

Expand All @@ -343,8 +344,7 @@ def input_types(self):

@property
def output_types(self):
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {"predictions": [NeuralType(elements_type=HypothesisType())]}

def __init__(
Expand Down
8 changes: 6 additions & 2 deletions tests/collections/asr/decoding/test_ctc_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def test_constructor_subword(self, tmp_tokenizer):
assert decoding is not None

@pytest.mark.unit
def test_char_decoding_greedy_forward(self,):
def test_char_decoding_greedy_forward(
self,
):
cfg = CTCDecodingConfig(strategy='greedy')
vocab = char_vocabulary()
decoding = CTCDecoding(decoding_cfg=cfg, vocabulary=vocab)
Expand Down Expand Up @@ -198,7 +200,9 @@ def test_subword_decoding_greedy_forward_hypotheses(self, tmp_tokenizer, alignme
@pytest.mark.parametrize('timestamps', [False, True])
@pytest.mark.parametrize('preserve_frame_confidence', [False, True])
@pytest.mark.parametrize('length_is_none', [False, True])
def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence, length_is_none):
def test_batched_decoding_logprobs(
self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence, length_is_none
):
cfg = CTCBPEDecodingConfig(
strategy='greedy',
preserve_alignments=alignments,
Expand Down

0 comments on commit 2066dc3

Please sign in to comment.