Skip to content

Commit

Permalink
Add loop_labels algorithm for TDT greedy decoding (#8215)
Browse files Browse the repository at this point in the history
* Add `loop_labels` algorithm for TDT greedy decoding

Signed-off-by: Vladimir Bataev <[email protected]>

* Use `loop_labels` by default

Signed-off-by: Vladimir Bataev <[email protected]>

* Loop labels greedy decoding v2

Signed-off-by: Vladimir Bataev <[email protected]>

* Add comments. Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Add comments

Signed-off-by: Vladimir Bataev <[email protected]>

* Add comments

Signed-off-by: Vladimir Bataev <[email protected]>

* Add tests for batched hypotheses

Signed-off-by: Vladimir Bataev <[email protected]>

* Add tests for batched alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Add comments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix comment

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix test

Signed-off-by: Vladimir Bataev <[email protected]>

* Add computer for TDT

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix TDT decoding algorithm

Signed-off-by: Vladimir Bataev <[email protected]>

* Use loop frames by default for TDT

Signed-off-by: Vladimir Bataev <[email protected]>

* Remove "loop frames" implementation for TDT

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Add comments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix confidence. Use tensor for durations.

Signed-off-by: Vladimir Bataev <[email protected]>

---------

Signed-off-by: Vladimir Bataev <[email protected]>
  • Loading branch information
artbataev authored and yaoyu-33 committed Feb 26, 2024
1 parent 49b54c3 commit 9a183d6
Show file tree
Hide file tree
Showing 2 changed files with 301 additions and 166 deletions.
199 changes: 33 additions & 166 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

from nemo.collections.asr.modules import rnnt_abstract
from nemo.collections.asr.parts.submodules.rnnt_loop_labels_computer import GreedyBatchedRNNTLoopLabelsComputer
from nemo.collections.asr.parts.submodules.tdt_loop_labels_computer import GreedyBatchedTDTLoopLabelsComputer
from nemo.collections.asr.parts.utils import rnnt_utils
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin
from nemo.collections.common.parts.rnn import label_collate
Expand Down Expand Up @@ -2638,8 +2639,20 @@ def __init__(

# Depending on availability of `blank_as_pad` support
# switch between more efficient batch decoding technique
self._decoding_computer = None
if self.decoder.blank_as_pad:
self._greedy_decode = self._greedy_decode_blank_as_pad
# batched "loop frames" is not implemented for TDT
self._decoding_computer = GreedyBatchedTDTLoopLabelsComputer(
decoder=self.decoder,
joint=self.joint,
blank_index=self._blank_index,
durations=self.durations,
max_symbols_per_step=self.max_symbols,
preserve_alignments=preserve_alignments,
preserve_frame_confidence=preserve_frame_confidence,
confidence_method_cfg=confidence_method_cfg,
)
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_labels
else:
self._greedy_decode = self._greedy_decode_masked

Expand Down Expand Up @@ -2685,179 +2698,33 @@ def forward(

return (packed_result,)

def _greedy_decode_blank_as_pad(
def _greedy_decode_masked(
self,
x: torch.Tensor,
out_len: torch.Tensor,
device: torch.device,
partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None,
):
if partial_hypotheses is not None:
raise NotImplementedError("`partial_hypotheses` support is not supported")

with torch.inference_mode():
# x: [B, T, D]
# out_len: [B]
# device: torch.device

# Initialize list of Hypothesis
batchsize = x.shape[0]
hypotheses = [
rnnt_utils.Hypothesis(score=0.0, y_sequence=[], timestep=[], dec_state=None) for _ in range(batchsize)
]

# Initialize Hidden state matrix (shared by entire batch)
hidden = None

# If alignments need to be preserved, register a danling list to hold the values
if self.preserve_alignments:
# alignments is a 3-dimensional dangling list representing B x T x U
for hyp in hypotheses:
hyp.alignments = [[]]

# If confidence scores need to be preserved, register a danling list to hold the values
if self.preserve_frame_confidence:
# frame_confidence is a 3-dimensional dangling list representing B x T x U
for hyp in hypotheses:
hyp.frame_confidence = [[]]

# Last Label buffer + Last Label without blank buffer
# batch level equivalent of the last_label
last_label = torch.full([batchsize, 1], fill_value=self._blank_index, dtype=torch.long, device=device)

# Mask buffers
blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device)

# Get max sequence length
max_out_len = out_len.max()

# skip means the number of frames the next decoding step should "jump" to. When skip == 1
# it means the next decoding step will just use the next input frame.
skip = 1
for time_idx in range(max_out_len):
if skip > 1: # if skip > 1 at the current step, we decrement it and skip the current frame.
skip -= 1
continue
f = x.narrow(dim=1, start=time_idx, length=1) # [B, 1, D]

# need_to_stay is a boolean indicates whether the next decoding step should remain in the same frame.
need_to_stay = True
symbols_added = 0

# Reset blank mask
blank_mask.mul_(False)

# Update blank mask with time mask
# Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch)
# Forcibly mask with "blank" tokens, for all sample where current time step T > seq_len
blank_mask = time_idx >= out_len

# Start inner loop
while need_to_stay and (self.max_symbols is None or symbols_added < self.max_symbols):
# Batch prediction and joint network steps
# If very first prediction step, submit SOS tag (blank) to pred_step.
# This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state
if time_idx == 0 and symbols_added == 0 and hidden is None:
g, hidden_prime = self._pred_step(self._SOS, hidden, batch_size=batchsize)
else:
# Perform batch step prediction of decoder, getting new states and scores ("g")
g, hidden_prime = self._pred_step(last_label, hidden, batch_size=batchsize)

# Batched joint step - Output = [B, V + 1 + num-big-blanks]
# Note: log_normalize must not be True here since the joiner output is contanetation of both token logits and duration logits,
# and they need to be normalized independently.
joined = self._joint_step(f, g, log_normalize=None)
logp = joined[:, 0, 0, : -len(self.durations)]
duration_logp = joined[:, 0, 0, -len(self.durations) :]

if logp.dtype != torch.float32:
logp = logp.float()
duration_logp = duration_logp.float()

# get the max for both token and duration predictions.
v, k = logp.max(1)
dv, dk = duration_logp.max(1)

# here we set the skip value to be the minimum of all predicted durations, hense the "torch.min(dk)" call there.
# Please refer to Section 5.2 of our paper https://arxiv.org/pdf/2304.06795.pdf for explanation of this.
skip = self.durations[int(torch.min(dk))]

# this is a special case: if all batches emit blanks, we require that skip be at least 1
# so we don't loop forever at the current frame.
if blank_mask.all():
if skip == 0:
skip = 1

need_to_stay = skip == 0
del g

# Update blank mask with current predicted blanks
# This is accumulating blanks over all time steps T and all target steps min(max_symbols, U)
k_is_blank = k == self._blank_index
blank_mask.bitwise_or_(k_is_blank)

del k_is_blank
del logp, duration_logp

# If all samples predict / have predicted prior blanks, exit loop early
# This is equivalent to if single sample predicted k
if not blank_mask.all():
# Collect batch indices where blanks occurred now/past
blank_indices = (blank_mask == 1).nonzero(as_tuple=False)

# Recover prior state for all samples which predicted blank now/past
if hidden is not None:
hidden_prime = self.decoder.batch_copy_states(hidden_prime, hidden, blank_indices)

elif len(blank_indices) > 0 and hidden is None:
# Reset state if there were some blank and other non-blank predictions in batch
# Original state is filled with zeros so we just multiply
# LSTM has 2 states
hidden_prime = self.decoder.batch_copy_states(hidden_prime, None, blank_indices, value=0.0)

# Recover prior predicted label for all samples which predicted blank now/past
k[blank_indices] = last_label[blank_indices, 0]

# Update new label and hidden state for next iteration
last_label = k.clone().view(-1, 1)
hidden = hidden_prime

# Update predicted labels, accounting for time mask
# If blank was predicted even once, now or in the past,
# Force the current predicted label to also be blank
# This ensures that blanks propogate across all timesteps
# once they have occured (normally stopping condition of sample level loop).
for kidx, ki in enumerate(k):
if blank_mask[kidx] == 0:
hypotheses[kidx].y_sequence.append(ki)
hypotheses[kidx].timestep.append(time_idx)
hypotheses[kidx].score += float(v[kidx])

symbols_added += 1

# Remove trailing empty list of alignments at T_{am-len} x Uj
if self.preserve_alignments:
for batch_idx in range(batchsize):
if len(hypotheses[batch_idx].alignments[-1]) == 0:
del hypotheses[batch_idx].alignments[-1]

# Remove trailing empty list of confidence scores at T_{am-len} x Uj
if self.preserve_frame_confidence:
for batch_idx in range(batchsize):
if len(hypotheses[batch_idx].frame_confidence[-1]) == 0:
del hypotheses[batch_idx].frame_confidence[-1]

# Preserve states
for batch_idx in range(batchsize):
hypotheses[batch_idx].dec_state = self.decoder.batch_select_state(hidden, batch_idx)

return hypotheses
raise NotImplementedError("masked greedy-batched decode is not supported for TDT models.")

def _greedy_decode_masked(
@torch.inference_mode()
def _greedy_decode_blank_as_pad_loop_labels(
self,
x: torch.Tensor,
out_len: torch.Tensor,
device: torch.device,
partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None,
):
raise NotImplementedError("masked greedy-batched decode is not supported for TDT models.")
partial_hypotheses: Optional[list[rnnt_utils.Hypothesis]] = None,
) -> list[rnnt_utils.Hypothesis]:
"""
Optimized batched greedy decoding.
The main idea: search for next labels for the whole batch (evaluating Joint)
and thus always evaluate prediction network with maximum possible batch size
"""
if partial_hypotheses is not None:
raise NotImplementedError("`partial_hypotheses` support is not implemented")

batched_hyps, alignments, last_decoder_state = self._decoding_computer(x=x, out_len=out_len)
hyps = rnnt_utils.batched_hyps_to_hypotheses(batched_hyps, alignments)
for hyp, state in zip(hyps, self.decoder.batch_split_states(last_decoder_state)):
hyp.dec_state = state
return hyps
Loading

0 comments on commit 9a183d6

Please sign in to comment.