Commit

Clean up
galv committed May 2, 2024
1 parent b02af5f commit f44aa43
Showing 2 changed files with 2 additions and 11 deletions.
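The diff below strips leftover NVTX profiling annotations from the CTC greedy decoder. For context, here is a minimal sketch of the pattern being removed: torch.cuda.nvtx.range_push/range_pop bracket a named region that shows up on an Nsight Systems timeline when the process runs under a profiler. The function and tensor names here are illustrative, not NeMo's.

    import torch

    def decode_with_profiling(decoder_output: torch.Tensor) -> torch.Tensor:
        # Open a named NVTX range; it appears on the profiler timeline.
        torch.cuda.nvtx.range_push("batched hypotheses")
        labels = decoder_output.argmax(dim=-1)  # stand-in for the real decode step
        # Close the most recently opened range.
        torch.cuda.nvtx.range_pop()
        return labels

Such ranges are useful while profiling but are typically stripped before merging, which is what this commit does.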
8 changes: 0 additions & 8 deletions nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
@@ -163,12 +163,10 @@ def forward(
         """
 
         if self.batched_inference:
-            torch.cuda.nvtx.range_push("batched hypotheses")
             if decoder_output.ndim == 2:
                 hypotheses = self._greedy_decode_labels_batched(decoder_output, decoder_lengths)
             else:
                 hypotheses = self._greedy_decode_logprobs_batched(decoder_output, decoder_lengths)
-            torch.cuda.nvtx.range_pop()
             packed_result = pack_hypotheses(hypotheses, decoder_lengths)
             return (packed_result,)
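
The batched branch above hands the whole (B, T, V) tensor to a single vectorized decode rather than looping over samples one at a time on CPU, as the next hunk does. A rough sketch of that idea, with illustrative names rather than the file's actual implementation:

    import torch

    def greedy_decode_labels_batched(logprobs: torch.Tensor, lengths: torch.Tensor) -> list[list[int]]:
        # One argmax over the vocabulary dimension covers every frame of
        # every sequence in the batch; no Python-level per-sample loop.
        labels = logprobs.argmax(dim=-1)  # (B, T)
        # Mask out frames past each sequence's true length.
        t = torch.arange(labels.shape[1], device=labels.device)
        valid = t[None, :] < lengths[:, None]
        return [labels[i, valid[i]].tolist() for i in range(labels.shape[0])]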

@@ -208,20 +206,16 @@ def forward(
                 greedy_decode = self._greedy_decode_labels
             else:
                 greedy_decode = self._greedy_decode_logprobs
-            torch.cuda.nvtx.range_push("hypotheses")
 
             for ind in range(prediction_cpu_tensor.shape[0]):
                 out_len = decoder_lengths[ind] if decoder_lengths is not None else None
                 # Gross, why are we doing this on CPU one at a time?
                 hypothesis = greedy_decode(prediction_cpu_tensor[ind], out_len)
                 hypotheses.append(hypothesis)
-            torch.cuda.nvtx.range_pop()
 
-            torch.cuda.nvtx.range_push("pack hypotheses")
             # Pack results into Hypotheses
             packed_result = pack_hypotheses(hypotheses, decoder_lengths)
 
-            torch.cuda.nvtx.range_pop()
             return (packed_result,)
 
     @torch.no_grad()
@@ -324,9 +318,7 @@ def _greedy_decode_logprobs(self, x: torch.Tensor, out_len: torch.Tensor):
         if out_len is not None:
             prediction = prediction[:out_len]
 
-        torch.cuda.nvtx.range_push("max")
         prediction_logprobs, prediction_labels = prediction.max(dim=-1)
-        torch.cuda.nvtx.range_pop()
 
         non_blank_ids = prediction_labels != self.blank_id
         hypothesis.y_sequence = prediction_labels.tolist()
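For reference, _greedy_decode_logprobs takes only the per-frame argmax here; collapsing repeated labels and dropping blanks is handled downstream when tokens are converted to text. A minimal sketch of the standard greedy CTC recipe that this pipeline amounts to, not the file's exact code:

    import torch

    def greedy_ctc(logprobs: torch.Tensor, blank_id: int) -> list[int]:
        # logprobs: (T, V) for a single utterance.
        labels = logprobs.argmax(dim=-1)            # best label per frame
        labels = torch.unique_consecutive(labels)   # collapse repeated labels
        return labels[labels != blank_id].tolist()  # drop blanks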
5 changes: 2 additions & 3 deletions tests/collections/asr/decoding/test_ctc_decoding.py
@@ -198,7 +198,6 @@ def test_subword_decoding_greedy_forward_hypotheses(self, tmp_tokenizer, alignme
     @pytest.mark.parametrize('timestamps', [False, True])
     @pytest.mark.parametrize('preserve_frame_confidence', [False, True])
     def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence):
-        # timestamps not working...
         cfg = CTCBPEDecodingConfig(
             strategy='greedy',
             preserve_alignments=alignments,
@@ -254,8 +253,8 @@ def test_batched_decoding_labels(self, tmp_tokenizer, timestamps):
         B, T = 4, 20
         V = unbatched_decoding.tokenizer.tokenizer.vocab_size + 1
         input_labels = torch.randint(V, size=(B, T))
-        # Set the blank index to a very high probability to make sure
-        # that we always handle at least a few blanks.
+        # Set some indices to blank to make sure that we always handle
+        # at least a few blanks.
         input_labels[:, 0] = unbatched_decoding.tokenizer.tokenizer.vocab_size
         input_labels[:, 1] = unbatched_decoding.tokenizer.tokenizer.vocab_size
         length = torch.randint(low=1, high=T, size=[B])
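The comment fix above matters because the test works on label indices, not probabilities: it forces the first two positions of every sequence to the blank index so the blank-handling path is always exercised. A self-contained sketch of that setup, with illustrative values in place of the test's tokenizer fixture:

    import torch

    B, T = 4, 20
    vocab_size = 128      # stand-in for the tokenizer's vocab size
    V = vocab_size + 1    # +1 for the blank index, which is the last id
    input_labels = torch.randint(V, size=(B, T))
    # Force the first two frames of every sequence to the blank index.
    input_labels[:, 0] = vocab_size
    input_labels[:, 1] = vocab_size
    length = torch.randint(low=1, high=T, size=[B])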
