Add loop_labels algorithm for TDT greedy decoding #8215

Merged · 24 commits · Feb 16, 2024
Changes from 21 commits
Commits (24)
4a849ee
Add `loop_labels` algorithm for TDT greedy decoding
artbataev Jan 22, 2024
882ad62
Use `loop_labels` by default
artbataev Jan 22, 2024
abb998c
Loop labels greedy decoding v2
artbataev Jan 31, 2024
c4a5672
Add comments. Clean up
artbataev Jan 31, 2024
1ce53d3
Add comments
artbataev Jan 31, 2024
eb1b425
Add comments
artbataev Jan 31, 2024
5982b4d
Add tests for batched hypotheses
artbataev Jan 31, 2024
c2a8b1f
Add tests for batched alignments
artbataev Jan 31, 2024
e6b6bb7
Add comments
artbataev Jan 31, 2024
fb6000f
Fix comment
artbataev Feb 1, 2024
3445357
Merge branch 'main' into rnnt_greedy_loop_lables_v2
artbataev Feb 1, 2024
00ffb75
Fix test
artbataev Feb 1, 2024
712a78c
Merge remote-tracking branch 'origin/rnnt_greedy_loop_lables_v2' into…
artbataev Feb 1, 2024
cc8e7d5
Merge branch 'rnnt_greedy_loop_lables_v2' into tdt_decode_loop_labels
artbataev Feb 1, 2024
645303b
Add computer for TDT
artbataev Feb 1, 2024
85e5591
Fix TDT decoding algorithm
artbataev Feb 1, 2024
28416f0
Merge branch 'main' into tdt_decode_loop_labels
artbataev Feb 12, 2024
33c685c
Use loop frames by default for TDT
artbataev Feb 12, 2024
938a97b
Remove "loop frames" implementation for TDT
artbataev Feb 12, 2024
aa63e96
Clean up
artbataev Feb 12, 2024
8c98235
Add comments
artbataev Feb 12, 2024
db72408
Fix confidence. Use tensor for durations.
artbataev Feb 14, 2024
f10442a
Merge branch 'main' into tdt_decode_loop_labels
artbataev Feb 14, 2024
4fe22f6
Merge branch 'main' into tdt_decode_loop_labels
artbataev Feb 15, 2024
199 changes: 33 additions & 166 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
@@ -35,6 +35,7 @@

from nemo.collections.asr.modules import rnnt_abstract
from nemo.collections.asr.parts.submodules.rnnt_loop_labels_computer import GreedyBatchedRNNTLoopLabelsComputer
from nemo.collections.asr.parts.submodules.tdt_loop_labels_computer import GreedyBatchedTDTLoopLabelsComputer
from nemo.collections.asr.parts.utils import rnnt_utils
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin
from nemo.collections.common.parts.rnn import label_collate
@@ -2638,8 +2639,20 @@ def __init__(

# Depending on availability of `blank_as_pad` support
# switch between more efficient batch decoding technique
self._decoding_computer = None
if self.decoder.blank_as_pad:
self._greedy_decode = self._greedy_decode_blank_as_pad
# batched "loop frames" is not implemented for TDT
self._decoding_computer = GreedyBatchedTDTLoopLabelsComputer(
decoder=self.decoder,
joint=self.joint,
blank_index=self._blank_index,
durations=self.durations,
max_symbols_per_step=self.max_symbols,
preserve_alignments=preserve_alignments,
preserve_frame_confidence=preserve_frame_confidence,
confidence_method_cfg=confidence_method_cfg,
)
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_labels
else:
self._greedy_decode = self._greedy_decode_masked
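# Dispatch summary (editor's note, not part of this diff): when the prediction
# network supports blank-as-pad, self._greedy_decode(...) resolves to
# _greedy_decode_blank_as_pad_loop_labels below, which delegates to the
# loop-labels computer; otherwise the masked path raises NotImplementedError
# for TDT models.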

@@ -2685,179 +2698,33 @@ def forward(

return (packed_result,)

def _greedy_decode_blank_as_pad(
def _greedy_decode_masked(
self,
x: torch.Tensor,
out_len: torch.Tensor,
device: torch.device,
partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None,
):
if partial_hypotheses is not None:
raise NotImplementedError("`partial_hypotheses` support is not implemented")

with torch.inference_mode():
# x: [B, T, D]
# out_len: [B]
# device: torch.device

# Initialize list of Hypothesis
batchsize = x.shape[0]
hypotheses = [
rnnt_utils.Hypothesis(score=0.0, y_sequence=[], timestep=[], dec_state=None) for _ in range(batchsize)
]

# Initialize Hidden state matrix (shared by entire batch)
hidden = None

# If alignments need to be preserved, register a dangling list to hold the values
if self.preserve_alignments:
# alignments is a 3-dimensional dangling list representing B x T x U
for hyp in hypotheses:
hyp.alignments = [[]]

# If confidence scores need to be preserved, register a dangling list to hold the values
if self.preserve_frame_confidence:
# frame_confidence is a 3-dimensional dangling list representing B x T x U
for hyp in hypotheses:
hyp.frame_confidence = [[]]

# Last Label buffer + Last Label without blank buffer
# batch level equivalent of the last_label
last_label = torch.full([batchsize, 1], fill_value=self._blank_index, dtype=torch.long, device=device)

# Mask buffers
blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device)

# Get max sequence length
max_out_len = out_len.max()

# skip is the number of frames the next decoding step should "jump" ahead. When skip == 1,
# the next decoding step will just use the next input frame.
skip = 1
for time_idx in range(max_out_len):
if skip > 1: # if skip > 1 at the current step, we decrement it and skip the current frame.
skip -= 1
continue
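# Illustrative example (hypothetical values, not from this diff): with
# durations = [0, 1, 2, 4] and a predicted duration of 4 at frame t, skip is
# set to 4, frames t+1..t+3 are consumed by the decrements above, and the
# next joint evaluation happens at frame t+4.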
f = x.narrow(dim=1, start=time_idx, length=1) # [B, 1, D]

# need_to_stay is a boolean indicating whether the next decoding step should remain in the same frame.
need_to_stay = True
symbols_added = 0

# Reset blank mask
blank_mask.mul_(False)

# Update blank mask with time mask
# Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch)
# Forcibly mask with "blank" tokens, for all samples where the current time step T >= seq_len
blank_mask = time_idx >= out_len

# Start inner loop
while need_to_stay and (self.max_symbols is None or symbols_added < self.max_symbols):
# Batch prediction and joint network steps
# If very first prediction step, submit SOS tag (blank) to pred_step.
# This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state
if time_idx == 0 and symbols_added == 0 and hidden is None:
g, hidden_prime = self._pred_step(self._SOS, hidden, batch_size=batchsize)
else:
# Perform batch step prediction of decoder, getting new states and scores ("g")
g, hidden_prime = self._pred_step(last_label, hidden, batch_size=batchsize)

# Batched joint step - Output = [B, V + 1 + num-big-blanks]
# Note: log_normalize must not be True here since the joiner output is a concatenation of token logits and duration logits,
# and they need to be normalized independently.
joined = self._joint_step(f, g, log_normalize=None)
logp = joined[:, 0, 0, : -len(self.durations)]
duration_logp = joined[:, 0, 0, -len(self.durations) :]

if logp.dtype != torch.float32:
logp = logp.float()
duration_logp = duration_logp.float()
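# Minimal sketch (editor's illustration, not part of this diff) of the
# independent normalization the note above refers to, e.g. if probabilities
# were needed for confidence estimation:
#   token_logp = torch.log_softmax(joined[:, 0, 0, : -len(self.durations)], dim=-1)
#   duration_logp = torch.log_softmax(joined[:, 0, 0, -len(self.durations) :], dim=-1)
# For pure argmax decoding below, raw logits suffice: log_softmax is
# monotonic and does not change the argmax.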

# get the max for both token and duration predictions.
v, k = logp.max(1)
dv, dk = duration_logp.max(1)

# here we set the skip value to be the minimum of all predicted durations, hence the "torch.min(dk)" call there.
# Please refer to Section 5.2 of our paper https://arxiv.org/pdf/2304.06795.pdf for an explanation of this.
skip = self.durations[int(torch.min(dk))]

# this is a special case: if all samples in the batch emit blanks, we require that skip be at least 1
# so we don't loop forever at the current frame.
if blank_mask.all():
if skip == 0:
skip = 1

need_to_stay = skip == 0
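# Worked example (hypothetical batch of 3): if the per-sample predicted
# durations are [2, 0, 1], then skip = 0, need_to_stay is True, and another
# symbol is emitted at the current frame; once the minimum duration is > 0
# (or all samples predict blank), the whole batch jumps ahead by skip frames.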
del g

# Update blank mask with current predicted blanks
# This is accumulating blanks over all time steps T and all target steps min(max_symbols, U)
k_is_blank = k == self._blank_index
blank_mask.bitwise_or_(k_is_blank)

del k_is_blank
del logp, duration_logp

# If all samples predict / have predicted prior blanks, exit loop early
# This is equivalent to if single sample predicted k
if not blank_mask.all():
# Collect batch indices where blanks occurred now/past
blank_indices = (blank_mask == 1).nonzero(as_tuple=False)

# Recover prior state for all samples which predicted blank now/past
if hidden is not None:
hidden_prime = self.decoder.batch_copy_states(hidden_prime, hidden, blank_indices)

elif len(blank_indices) > 0 and hidden is None:
# Reset state if there were some blank and other non-blank predictions in batch
# Original state is filled with zeros so we just multiply
# LSTM has 2 states
hidden_prime = self.decoder.batch_copy_states(hidden_prime, None, blank_indices, value=0.0)

# Recover prior predicted label for all samples which predicted blank now/past
k[blank_indices] = last_label[blank_indices, 0]

# Update new label and hidden state for next iteration
last_label = k.clone().view(-1, 1)
hidden = hidden_prime

# Update predicted labels, accounting for time mask
# If blank was predicted even once, now or in the past,
# force the current predicted label to also be blank.
# This ensures that blanks propagate across all timesteps
# once they have occurred (normally the stopping condition of the sample-level loop).
for kidx, ki in enumerate(k):
if blank_mask[kidx] == 0:
hypotheses[kidx].y_sequence.append(ki)
hypotheses[kidx].timestep.append(time_idx)
hypotheses[kidx].score += float(v[kidx])
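# Example of the propagation above (hypothetical): if sample 0 predicted
# blank at u=1 while sample 1 keeps emitting, blank_mask[0] stays True for
# the rest of this frame's inner loop, so sample 0's label and state are
# frozen (restored from last_label / hidden) while sample 1 proceeds.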

symbols_added += 1

# Remove trailing empty list of alignments at T_{am-len} x Uj
if self.preserve_alignments:
for batch_idx in range(batchsize):
if len(hypotheses[batch_idx].alignments[-1]) == 0:
del hypotheses[batch_idx].alignments[-1]

# Remove trailing empty list of confidence scores at T_{am-len} x Uj
if self.preserve_frame_confidence:
for batch_idx in range(batchsize):
if len(hypotheses[batch_idx].frame_confidence[-1]) == 0:
del hypotheses[batch_idx].frame_confidence[-1]

# Preserve states
for batch_idx in range(batchsize):
hypotheses[batch_idx].dec_state = self.decoder.batch_select_state(hidden, batch_idx)

return hypotheses
raise NotImplementedError("masked greedy-batched decode is not supported for TDT models.")

def _greedy_decode_masked(
@torch.inference_mode()
def _greedy_decode_blank_as_pad_loop_labels(
self,
x: torch.Tensor,
out_len: torch.Tensor,
device: torch.device,
partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None,
):
raise NotImplementedError("masked greedy-batched decode is not supported for TDT models.")
partial_hypotheses: Optional[list[rnnt_utils.Hypothesis]] = None,
) -> list[rnnt_utils.Hypothesis]:
"""
Optimized batched greedy decoding.
The main idea: search for the next labels for the whole batch (evaluating the Joint)
and thus always evaluate the prediction network with the maximum possible batch size
"""
if partial_hypotheses is not None:
raise NotImplementedError("`partial_hypotheses` support is not implemented")

batched_hyps, alignments, last_decoder_state = self._decoding_computer(x=x, out_len=out_len)
hyps = rnnt_utils.batched_hyps_to_hypotheses(batched_hyps, alignments)
for hyp, state in zip(hyps, self.decoder.batch_split_states(last_decoder_state)):
hyp.dec_state = state
return hyps
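For context, a minimal usage sketch of the decoding path added in this PR; the checkpoint name and exact transcribe call are assumptions, not part of this diff:

import torch
import nemo.collections.asr as nemo_asr

# Hypothetical end-to-end example: greedy batched decoding for TDT models now
# routes through GreedyBatchedTDTLoopLabelsComputer whenever the prediction
# network supports blank-as-pad (the checkpoint name below is assumed).
model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-tdt-1.1b")
with torch.inference_mode():
    transcriptions = model.transcribe(["audio.wav"])
print(transcriptions[0])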