Small fixups
galv committed May 6, 2024
1 parent d374dde commit e193293
Showing 1 changed file with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
@@ -110,7 +110,7 @@ class GreedyCTCInfer(Typing, ConfidenceMethodMixin):
     def input_types(self):
         """Returns definitions of module input ports.
         """
-        # Input can be of dimention -
+        # Input can be of dimension -
         # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]
 
         return {
@@ -131,7 +131,6 @@ def __init__(
         compute_timestamps: bool = False,
         preserve_frame_confidence: bool = False,
         confidence_method_cfg: Optional[DictConfig] = None,
-        batched_inference: bool = False,
     ):
         super().__init__()
 
@@ -144,8 +143,6 @@ def __init__(
         # set confidence calculation method
         self._init_confidence_method(confidence_method_cfg)
 
-        self.batched_inference = batched_inference
-
     @typecheck()
     def forward(
         self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor,
@@ -161,7 +158,6 @@ def forward(
         Returns:
             packed list containing batch number of sentences (Hypotheses).
         """
-
         with torch.inference_mode():
             hypotheses = []
             # Process each sequence independently
@@ -183,9 +179,7 @@ def forward(
             # each scalar from GPU to CPU one at a time, in the line:
             # prediction = prediction[:out_len]
             # Doing one GPU to CPU copy ahead of time amortizes that overhead.
-            decoder_lengths = (
-                decoder_lengths.cpu()
-            ) # synchronizes. But this synchronizations is necessary and appropriate.
+            decoder_lengths = decoder_lengths.cpu()
 
             if prediction_cpu_tensor.ndim < 2 or prediction_cpu_tensor.ndim > 3:
                 raise ValueError(
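As an aside on the hunk above: the win comes from paying for one bulk device-to-host copy instead of one implicit synchronization per batch element. A minimal sketch of the two patterns, using made-up tensors rather than NeMo's actual variables:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
predictions = torch.randint(0, 128, (8, 100), device=device)  # hypothetical [B, T] labels
lengths = torch.randint(1, 100, (8,), device=device)          # hypothetical [B] lengths

# Slow pattern: each loop iteration pulls a scalar off the GPU,
# forcing a separate device synchronization per element.
# for i in range(predictions.shape[0]):
#     out_len = int(lengths[i])          # sync here, every iteration
#     pred = predictions[i, :out_len]

# Pattern the diff adopts: copy once up front, then index on the CPU.
predictions_cpu = predictions.cpu()  # one synchronization for the whole batch
lengths_cpu = lengths.cpu()          # amortized over all elements
for i in range(predictions_cpu.shape[0]):
    out_len = int(lengths_cpu[i])
    pred = predictions_cpu[i, :out_len]
```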
@@ -201,7 +195,6 @@ def forward(
 
             for ind in range(prediction_cpu_tensor.shape[0]):
                 out_len = decoder_lengths[ind] if decoder_lengths is not None else None
-                # Gross, why are we doing this on CPU one at a time?
                 hypothesis = greedy_decode(prediction_cpu_tensor[ind], out_len)
                 hypotheses.append(hypothesis)
 
@@ -279,9 +272,9 @@ class GreedyVectorizedCTCInfer(Typing, ConfidenceMethodMixin):
     This is basically always faster than GreedyCTCInfer, and supports
     the same interface. See issue #8891 on github for what is wrong
     with GreedyCTCInfer. GreedyCTCInfer loops over each element in the
-    batch, running kernels one at a time. This implementation does
-    appropriate masking to appropriately do the same operation in a
-    vectorized manner.
+    batch, running kernels at batch size one. CPU overheads end up
+    dominating. This implementation does appropriate masking to
+    appropriately do the same operation in a vectorized manner.
 
     Args:
         blank_index: int index of the blank token. Can be 0 or len(vocabulary).
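For readers who want the masking idea spelled out: a rough sketch (not the NeMo implementation itself) of how a single batched argmax plus a length mask replaces the per-element loop:

```python
import torch

def greedy_labels_batched(log_probs: torch.Tensor, lengths: torch.Tensor, blank_id: int) -> torch.Tensor:
    """Sketch only: log_probs is [B, T, V] padded log-probabilities,
    lengths is [B]. Returns per-frame argmax labels with padded frames
    forced to the blank token."""
    labels = log_probs.argmax(dim=-1)  # one kernel for the whole batch, [B, T]

    # Mask out frames beyond each sequence's valid length.
    time_idx = torch.arange(labels.shape[1], device=labels.device)  # [T]
    pad_mask = time_idx.unsqueeze(0) >= lengths.unsqueeze(1)        # [B, T]
    return labels.masked_fill(pad_mask, blank_id)
```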
@@ -333,7 +326,7 @@ class GreedyVectorizedCTCInfer(Typing, ConfidenceMethodMixin):
     def input_types(self):
         """Returns definitions of module input ports.
         """
-        # Input can be of dimention -
+        # Input can be of dimension -
         # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]
 
         return {
@@ -420,7 +413,7 @@ def _greedy_decode_logprobs_batched(self, x: torch.Tensor, out_len: torch.Tensor
 
         hypotheses = []
 
-        # This mimics the for loop in GreedyCTCInfer::_greedy_decode_logprobs
+        # This mimics the for loop in GreedyCTCInfer::forward.
         for i in range(batch_size):
             hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestep=[], last_token=None)
             hypothesis.score = scores[i]
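The `scores[i]` above is indexed out of a batched tensor computed before the loop. A hedged sketch of how such per-sequence greedy-path scores could be produced in one shot (an assumption about the surrounding code, which is not shown in this hunk):

```python
import torch

def greedy_path_scores(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Sketch: sum the per-frame maximum log-probability over the valid
    frames of each sequence. log_probs is [B, T, V], lengths is [B]."""
    max_logp = log_probs.max(dim=-1).values                             # [B, T]
    time_idx = torch.arange(max_logp.shape[1], device=max_logp.device)  # [T]
    valid = time_idx.unsqueeze(0) < lengths.unsqueeze(1)                # [B, T]
    return (max_logp * valid).sum(dim=-1)                               # [B]
```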
@@ -449,6 +442,10 @@ def _greedy_decode_logprobs_batched(self, x: torch.Tensor, out_len: torch.Tensor
 
     @torch.no_grad()
     def _greedy_decode_labels_batched(self, x: torch.Tensor, out_len: torch.Tensor):
+        """
+        This does greedy decoding in the case where you have already found the
+        most likely token at each timestep.
+        """
         # x: [B, T]
         # out_len: [B]
 
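Since the new docstring says the per-frame argmax labels are already available, the remaining work is the standard CTC collapse. A minimal, self-contained sketch of that step (not the method body itself):

```python
import torch

def collapse_ctc_labels(labels: torch.Tensor, length: int, blank_id: int) -> list:
    """Sketch: merge repeated tokens, then drop blanks, over one [T] label sequence."""
    labels = labels[:length]
    keep = torch.ones_like(labels, dtype=torch.bool)
    keep[1:] = labels[1:] != labels[:-1]   # keep only the first frame of each run
    deduped = labels[keep]
    return deduped[deduped != blank_id].tolist()
```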
@@ -496,7 +493,6 @@ class GreedyCTCInferConfig:
     compute_timestamps: bool = False
     preserve_frame_confidence: bool = False
     confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig())
-    batched_inference: bool = False
 
     def __post_init__(self):
         # OmegaConf.structured ensures that post_init check is always executed