Add support for batched inference for label inputs as well.
galv committed May 2, 2024
1 parent 7853c4f commit 6779d39
Showing 2 changed files with 95 additions and 10 deletions.
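For context, a minimal usage sketch (not part of the commit): with `batched_inference` enabled, the greedy CTC decoder now dispatches on input rank, treating a 3-D tensor as log-probabilities and a 2-D tensor as already-argmaxed labels. The tokenizer construction and the import path below are assumptions; the tests in this commit build a decoder from a `tmp_tokenizer` fixture instead.

```python
# Hedged sketch, not taken from the diff. `tokenizer` and the import path are assumptions.
import torch

from nemo.collections.asr.parts.submodules.ctc_decoding import CTCBPEDecoding, CTCBPEDecodingConfig

cfg = CTCBPEDecodingConfig(strategy='greedy')
cfg.greedy.batched_inference = True
decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tokenizer)  # `tokenizer`: assumed to exist

B, T = 4, 20
V = decoding.tokenizer.tokenizer.vocab_size + 1  # +1 for the CTC blank
lengths = torch.randint(low=1, high=T, size=[B])

# Rank-3 input is treated as log-probabilities [B, T, V].
log_probs = torch.randn(B, T, V)
hyps, _ = decoding.ctc_decoder_predictions_tensor(log_probs, lengths, return_hypotheses=True)

# Rank-2 input is treated as already-argmaxed labels [B, T] and takes the
# new _greedy_decode_labels_batched path.
labels = torch.randint(V, size=(B, T))
hyps, _ = decoding.ctc_decoder_predictions_tensor(labels, lengths, return_hypotheses=True)
```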
57 changes: 50 additions & 7 deletions nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
@@ -164,10 +164,14 @@ def forward(

if self.batched_inference:
torch.cuda.nvtx.range_push("batched hypotheses")
hypotheses = self._greedy_decode_logprobs_batched(decoder_output, decoder_lengths)
if decoder_output.ndim == 2:
hypotheses = self._greedy_decode_labels_batched(decoder_output, decoder_lengths)
else:
hypotheses = self._greedy_decode_logprobs_batched(decoder_output, decoder_lengths)
torch.cuda.nvtx.range_pop()
packed_result = pack_hypotheses(hypotheses, decoder_lengths)
return (packed_result,)

with torch.inference_mode():
hypotheses = []
# Process each sequence independently
@@ -218,7 +222,6 @@ def forward(
torch.cuda.nvtx.range_pop()
return (packed_result,)

# Cannot naively call vmap because this does not return tensors...
@torch.no_grad()
def _greedy_decode_logprobs_batched(self, x: torch.Tensor, out_len: torch.Tensor):
# x: [B, T, D]
@@ -248,19 +251,59 @@ def _greedy_decode_logprobs_batched(self, x: torch.Tensor, out_len: torch.Tensor
hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestep=[], last_token=None)
hypothesis.score = scores[i]

# prediction_labels_no_padding = predictions_labels[i, :labels_segments[i]]
prediction_labels_no_padding = predictions_labels[i, :out_len[i]]
prediction_labels_no_padding = predictions_labels[i, :out_len[i]].tolist()

assert predictions_labels.dtype == torch.int64
hypothesis.y_sequence = prediction_labels_no_padding

if self.preserve_alignments:
hypothesis.alignments = (predictions[i].clone(), predictions_labels[i].clone())
hypothesis.alignments = (predictions[i, :out_len[i], :].clone(),
predictions_labels[i, :out_len[i]].clone())
if self.compute_timestamps:
# Could do this in a vectorized manner...
# TODO: Could do this in a vectorized manner... Would
# prefer to have nonzero_static, though, for sanity.
hypothesis.timestep = torch.nonzero(non_blank_ids_mask[i], as_tuple=False)[:, 0].cpu().tolist()
if self.preserve_frame_confidence:
hypothesis.frame_confidence = self._get_confidence(predictions[i])
hypothesis.frame_confidence = self._get_confidence(predictions[i, :out_len[i], :])

hypotheses.append(hypothesis)

return hypotheses

@torch.no_grad()
def _greedy_decode_labels_batched(self, x: torch.Tensor, out_len: torch.Tensor):
# x: [B, T]
# out_len: [B]

batch_size = x.shape[0]
max_time = x.shape[1]

predictions_labels = x
time_steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(batch_size, max_time)
non_blank_ids_mask = torch.logical_and(predictions_labels != self.blank_id,
time_steps < out_len.unsqueeze(1))
predictions_labels = predictions_labels.cpu()
out_len = out_len.cpu()

hypotheses = []

for i in range(batch_size):
hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestep=[], last_token=None)
hypothesis.y_sequence = predictions_labels[i, :out_len[i]].tolist()
hypothesis.score = -1.0

if self.preserve_alignments:
raise ValueError("Alignments were requested, but the predictions provided are labels, not log probabilities.")
if self.compute_timestamps:
# TODO: Could do this in a vectorized manner... Would
# prefer to have nonzero_static, though, for sanity.
# Or do a prefix sum on out_len
hypothesis.timestep = torch.nonzero(non_blank_ids_mask[i], as_tuple=False)[:, 0].cpu().tolist()
if self.preserve_frame_confidence:
raise ValueError(
    "Per-frame confidence was requested, but the predictions provided are labels, not log probabilities."
)


hypotheses.append(hypothesis)

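As an aside, here is a self-contained illustration (with assumed example values, not taken from the diff) of the masking logic that `_greedy_decode_labels_batched` uses: blank and padded positions are masked out jointly, each utterance's label sequence is cut at `out_len`, and timesteps come from `torch.nonzero` on the mask.

```python
import torch

blank_id = 4                                       # example value, assumed
labels = torch.tensor([[1, 2, 4, 3, 0],
                       [4, 4, 2, 1, 1]])           # [B, T] argmaxed labels
out_len = torch.tensor([4, 3])                     # valid length per sequence

B, T = labels.shape
time_steps = torch.arange(T).unsqueeze(0).expand(B, T)
non_blank_mask = torch.logical_and(labels != blank_id, time_steps < out_len.unsqueeze(1))

for i in range(B):
    y_sequence = labels[i, :out_len[i]].tolist()   # padding dropped, blanks kept
    timestep = torch.nonzero(non_blank_mask[i], as_tuple=False)[:, 0].tolist()
    print(y_sequence, timestep)
# [1, 2, 4, 3] [0, 1, 3]
# [4, 4, 2]    [2]
```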
48 changes: 45 additions & 3 deletions tests/collections/asr/decoding/test_ctc_decoding.py
Expand Up @@ -26,6 +26,7 @@
CTCDecoding,
CTCDecodingConfig,
)
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis


@@ -195,8 +196,10 @@ def test_subword_decoding_greedy_forward_hypotheses(self, tmp_tokenizer, alignme
@pytest.mark.unit
@pytest.mark.parametrize('alignments', [False, True])
@pytest.mark.parametrize('timestamps', [False, True])
def test_batched_decoding(self, tmp_tokenizer, alignments, timestamps):
cfg = CTCBPEDecodingConfig(strategy='greedy', preserve_alignments=alignments, compute_timestamps=timestamps)
@pytest.mark.parametrize('preserve_frame_confidence', [False, True])
def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence):
# timestamps not working...
cfg = CTCBPEDecodingConfig(
    strategy='greedy',
    preserve_alignments=alignments,
    compute_timestamps=timestamps,
    confidence_cfg=ConfidenceConfig(preserve_frame_confidence=preserve_frame_confidence),
)
cfg.greedy.batched_inference = False
unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)

@@ -207,11 +210,13 @@ def test_batched_decoding(self, tmp_tokenizer, alignments, timestamps):
B, T = 4, 20
V = unbatched_decoding.tokenizer.tokenizer.vocab_size + 1
input_signal = torch.randn(size=(B, T, V))
# Set the blank index to a very high probability to make sure
# that we always handle at least a few blanks.
input_signal[:, 0, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
input_signal[:, 1, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
length = torch.randint(low=1, high=T, size=[B])

with torch.no_grad():
with torch.inference_mode():
hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
input_signal, length, fold_consecutive=True, return_hypotheses=True
)
@@ -229,3 +234,40 @@ def test_batched_decoding(self, tmp_tokenizer, alignments, timestamps):
if alignments:
assert torch.all(hyp.alignments[0] == batched_hyp.alignments[0])
assert torch.all(hyp.alignments[1] == batched_hyp.alignments[1])


@pytest.mark.unit
@pytest.mark.parametrize('timestamps', [False, True])
def test_batched_decoding_labels(self, tmp_tokenizer, timestamps):
cfg = CTCBPEDecodingConfig(strategy='greedy', compute_timestamps=timestamps)
cfg.greedy.batched_inference = False
unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)

cfg.greedy.batched_inference = True
batched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)

torch.manual_seed(1)
B, T = 4, 20
V = unbatched_decoding.tokenizer.tokenizer.vocab_size + 1
input_labels = torch.randint(V, size=(B, T))
# Set the first two timesteps to the blank index to make sure
# that we always handle at least a few blanks.
input_labels[:, 0] = unbatched_decoding.tokenizer.tokenizer.vocab_size
input_labels[:, 1] = unbatched_decoding.tokenizer.tokenizer.vocab_size
length = torch.randint(low=1, high=T, size=[B])

with torch.inference_mode():
hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
input_labels, length, fold_consecutive=True, return_hypotheses=True
)

batched_hyps, _ = batched_decoding.ctc_decoder_predictions_tensor(
input_labels, length, fold_consecutive=True, return_hypotheses=True
)

assert len(hyps) == len(batched_hyps) == B
for hyp, batched_hyp in zip(hyps, batched_hyps):
assert abs(hyp.score - batched_hyp.score) <= 1e-5
assert torch.all(hyp.y_sequence == batched_hyp.y_sequence)
if timestamps:
assert hyp.timestep == batched_hyp.timestep
