diff --git a/nemo/collections/asr/modules/hybrid_autoregressive_transducer.py b/nemo/collections/asr/modules/hybrid_autoregressive_transducer.py index b806692d44f40..8fac6a00e9b66 100644 --- a/nemo/collections/asr/modules/hybrid_autoregressive_transducer.py +++ b/nemo/collections/asr/modules/hybrid_autoregressive_transducer.py @@ -138,9 +138,9 @@ def return_hat_ilm(self): def return_hat_ilm(self, hat_subtract_ilm): self._return_hat_ilm = hat_subtract_ilm - def joint(self, f: torch.Tensor, g: torch.Tensor) -> Union[torch.Tensor, HATJointOutput]: + def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> Union[torch.Tensor, HATJointOutput]: """ - Compute the joint step of the network. + Compute the joint step of the network after Encoder/Decoder projection. Here, B = Batch size @@ -169,14 +169,8 @@ def joint(self, f: torch.Tensor, g: torch.Tensor) -> Union[torch.Tensor, HATJoin Log softmaxed tensor of shape (B, T, U, V + 1). Internal LM probability (B, 1, U, V) -- in case of return_ilm==True. """ - # f = [B, T, H1] - f = self.enc(f) - f.unsqueeze_(dim=2) # (B, T, 1, H) - - # g = [B, U, H2] - g = self.pred(g) - g.unsqueeze_(dim=1) # (B, 1, U, H) - + f = f.unsqueeze(dim=2) # (B, T, 1, H) + g = g.unsqueeze(dim=1) # (B, 1, U, H) inp = f + g # [B, T, U, H] del f diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index fb49524cdf328..c36e13e7c9d9a 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -398,6 +398,22 @@ def batch_copy_states( return old_states + def mask_select_states( + self, states: Optional[List[torch.Tensor]], mask: torch.Tensor + ) -> Optional[List[torch.Tensor]]: + """ + Return states by mask selection + Args: + states: states for the batch + mask: boolean mask for selecting states; batch dimension should be the same as for states + + Returns: + states filtered by mask + """ + if states is None: + return None + return [states[0][mask]] + def batch_score_hypothesis( self, hypotheses: List[rnnt_utils.Hypothesis], cache: Dict[Tuple[int], Any], batch_states: List[torch.Tensor] ) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]: @@ -1047,6 +1063,21 @@ def batch_copy_states( return old_states + def mask_select_states( + self, states: Tuple[torch.Tensor, torch.Tensor], mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Return states by mask selection + Args: + states: states for the batch + mask: boolean mask for selecting states; batch dimension should be the same as for states + + Returns: + states filtered by mask + """ + # LSTM in PyTorch returns a tuple of 2 tensors as a state + return states[0][:, mask], states[1][:, mask] + # Adapter method overrides def add_adapter(self, name: str, cfg: DictConfig): # Update the config with correct input dim @@ -1382,9 +1413,33 @@ def forward( return losses, wer, wer_num, wer_denom - def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor: + def project_encoder(self, encoder_output: torch.Tensor) -> torch.Tensor: + """ + Project the encoder output to the joint hidden dimension. + + Args: + encoder_output: A torch.Tensor of shape [B, T, D] + + Returns: + A torch.Tensor of shape [B, T, H] + """ + return self.enc(encoder_output) + + def project_prednet(self, prednet_output: torch.Tensor) -> torch.Tensor: + """ + Project the Prediction Network (Decoder) output to the joint hidden dimension. + + Args: + prednet_output: A torch.Tensor of shape [B, U, D] + + Returns: + A torch.Tensor of shape [B, U, H] + """ + return self.pred(prednet_output) + + def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor: """ - Compute the joint step of the network. + Compute the joint step of the network after projection. Here, B = Batch size @@ -1412,14 +1467,8 @@ def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor: Returns: Logits / log softmaxed tensor of shape (B, T, U, V + 1). """ - # f = [B, T, H1] - f = self.enc(f) - f.unsqueeze_(dim=2) # (B, T, 1, H) - - # g = [B, U, H2] - g = self.pred(g) - g.unsqueeze_(dim=1) # (B, 1, U, H) - + f = f.unsqueeze(dim=2) # (B, T, 1, H) + g = g.unsqueeze(dim=1) # (B, 1, U, H) inp = f + g # [B, T, U, H] del f, g @@ -1536,7 +1585,7 @@ def set_fuse_loss_wer(self, fuse_loss_wer, loss=None, metric=None): @property def fused_batch_size(self): - return self._fuse_loss_wer + return self._fused_batch_size def set_fused_batch_size(self, fused_batch_size): self._fused_batch_size = fused_batch_size diff --git a/nemo/collections/asr/modules/rnnt_abstract.py b/nemo/collections/asr/modules/rnnt_abstract.py index e473f64e5716c..ae5ff384e1fed 100644 --- a/nemo/collections/asr/modules/rnnt_abstract.py +++ b/nemo/collections/asr/modules/rnnt_abstract.py @@ -28,6 +28,45 @@ class AbstractRNNTJoint(NeuralModule, ABC): """ @abstractmethod + def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> Any: + """ + Compute the joint step of the network after the projection step. + Args: + f: Output of the Encoder model after projection. A torch.Tensor of shape [B, T, H] + g: Output of the Decoder model (Prediction Network) after projection. A torch.Tensor of shape [B, U, H] + + Returns: + Logits / log softmaxed tensor of shape (B, T, U, V + 1). + Arbitrary return type, preferably torch.Tensor, but not limited to (e.g., see HatJoint) + """ + raise NotImplementedError() + + @abstractmethod + def project_encoder(self, encoder_output: torch.Tensor) -> torch.Tensor: + """ + Project the encoder output to the joint hidden dimension. + + Args: + encoder_output: A torch.Tensor of shape [B, T, D] + + Returns: + A torch.Tensor of shape [B, T, H] + """ + raise NotImplementedError() + + @abstractmethod + def project_prednet(self, prednet_output: torch.Tensor) -> torch.Tensor: + """ + Project the Prediction Network (Decoder) output to the joint hidden dimension. + + Args: + prednet_output: A torch.Tensor of shape [B, U, D] + + Returns: + A torch.Tensor of shape [B, U, H] + """ + raise NotImplementedError() + def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor: """ Compute the joint step of the network. @@ -58,7 +97,7 @@ def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor: Returns: Logits / log softmaxed tensor of shape (B, T, U, V + 1). """ - raise NotImplementedError() + return self.joint_after_projection(self.project_encoder(f), self.project_prednet(g)) @property def num_classes_with_blank(self): @@ -277,3 +316,15 @@ def batch_copy_states( (L x B x H, L x B x H) """ raise NotImplementedError() + + def mask_select_states(self, states: Any, mask: torch.Tensor) -> Any: + """ + Return states by mask selection + Args: + states: states for the batch (preferably a list of tensors, but not limited to) + mask: boolean mask for selecting states; batch dimension should be the same as for states + + Returns: + states filtered by mask (same type as `states`) + """ + raise NotImplementedError() diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index 30d7c81785328..3f4e0bc6eac05 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -316,6 +316,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, confidence_method_cfg=self.confidence_method_cfg, + loop_labels=self.cfg.greedy.get('loop_labels', True), ) else: self.decoding = rnnt_greedy_decoding.GreedyBatchedTDTInfer( @@ -1495,8 +1496,8 @@ class RNNTDecodingConfig: rnnt_timestamp_type: str = "all" # can be char, word or all for both # greedy decoding config - greedy: rnnt_greedy_decoding.GreedyRNNTInferConfig = field( - default_factory=lambda: rnnt_greedy_decoding.GreedyRNNTInferConfig() + greedy: rnnt_greedy_decoding.GreedyBatchedRNNTInferConfig = field( + default_factory=rnnt_greedy_decoding.GreedyBatchedRNNTInferConfig ) # beam decoding config diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index a0aea07f7bc8f..83fdad35f6de9 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -38,7 +38,7 @@ from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin from nemo.collections.common.parts.rnn import label_collate from nemo.core.classes import Typing, typecheck -from nemo.core.neural_types import AcousticEncodedRepresentation, ElementType, HypothesisType, LengthsType, NeuralType +from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType from nemo.utils import logging @@ -50,7 +50,11 @@ def pack_hypotheses(hypotheses: List[rnnt_utils.Hypothesis], logitlen: torch.Ten logitlen_cpu = logitlen for idx, hyp in enumerate(hypotheses): # type: rnnt_utils.Hypothesis - hyp.y_sequence = torch.tensor(hyp.y_sequence, dtype=torch.long) + hyp.y_sequence = ( + hyp.y_sequence.to(torch.long) + if isinstance(hyp.y_sequence, torch.Tensor) + else torch.tensor(hyp.y_sequence, dtype=torch.long) + ) hyp.length = logitlen_cpu[idx] if hyp.dec_state is not None: @@ -162,6 +166,9 @@ def __init__( self._blank_index = blank_index self._SOS = blank_index # Start of single index + + if max_symbols_per_step is not None and max_symbols_per_step <= 0: + raise ValueError(f"Expected max_symbols_per_step > 0 (or None), got {max_symbols_per_step}") self.max_symbols = max_symbols_per_step self.preserve_alignments = preserve_alignments self.preserve_frame_confidence = preserve_frame_confidence @@ -235,6 +242,30 @@ def _joint_step(self, enc, pred, log_normalize: Optional[bool] = None): return logits + def _joint_step_after_projection(self, enc, pred, log_normalize: Optional[bool] = None) -> torch.Tensor: + """ + Common joint step based on AbstractRNNTJoint implementation. + + Args: + enc: Output of the Encoder model after projection. A torch.Tensor of shape [B, 1, H] + pred: Output of the Decoder model after projection. A torch.Tensor of shape [B, 1, H] + log_normalize: Whether to log normalize or not. None will log normalize only for CPU. + + Returns: + logits of shape (B, T=1, U=1, V + 1) + """ + with torch.no_grad(): + logits = self.joint.joint_after_projection(enc, pred) + + if log_normalize is None: + if not logits.is_cuda: # Use log softmax only if on CPU + logits = logits.log_softmax(dim=len(logits.shape) - 1) + else: + if log_normalize: + logits = logits.log_softmax(dim=len(logits.shape) - 1) + + return logits + class GreedyRNNTInfer(_GreedyRNNTInfer): """A greedy transducer decoder. @@ -534,6 +565,16 @@ class GreedyBatchedRNNTInfer(_GreedyRNNTInfer): Supported values: - 'lin' for using the linear mapping. - 'exp' for using exponential mapping with linear shift. + loop_labels: Switching between decoding algorithms. Both algorithms produce equivalent results. + loop_labels=True (default) algorithm is faster (especially for large batches) but can use a bit more memory + (negligible overhead compared to the amount of memory used by the encoder). + loop_labels=False is an implementation of a traditional decoding algorithm, which iterates over frames + (encoder output vectors), and in the inner loop, decodes labels for the current frame one by one, + stopping when is found. + loop_labels=True iterates over labels, on each step finding the next non-blank label + (evaluating Joint multiple times in inner loop); It uses a minimal possible amount of calls + to prediction network (with maximum possible batch size), + which makes it especially useful for scaling the prediction network. """ def __init__( @@ -545,6 +586,7 @@ def __init__( preserve_alignments: bool = False, preserve_frame_confidence: bool = False, confidence_method_cfg: Optional[DictConfig] = None, + loop_labels: bool = True, ): super().__init__( decoder_model=decoder_model, @@ -559,7 +601,12 @@ def __init__( # Depending on availability of `blank_as_pad` support # switch between more efficient batch decoding technique if self.decoder.blank_as_pad: - self._greedy_decode = self._greedy_decode_blank_as_pad + if loop_labels: + # default (faster) algo: loop over labels + self._greedy_decode = self._greedy_decode_blank_as_pad_loop_labels + else: + # previous algo: loop over frames + self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames else: self._greedy_decode = self._greedy_decode_masked @@ -607,7 +654,174 @@ def forward( return (packed_result,) - def _greedy_decode_blank_as_pad( + @torch.inference_mode() + def _greedy_decode_blank_as_pad_loop_labels( + self, + x: torch.Tensor, + out_len: torch.Tensor, + device: torch.device, + partial_hypotheses: Optional[list[rnnt_utils.Hypothesis]] = None, + ) -> list[rnnt_utils.Hypothesis]: + """ + Optimized batched greedy decoding. + The main idea: search for next labels for the whole batch (evaluating Joint) + and thus always evaluate prediction network with maximum possible batch size + """ + if partial_hypotheses is not None: + raise NotImplementedError("`partial_hypotheses` support is not implemented") + + batch_size, max_time, _ = x.shape + + x = self.joint.project_encoder(x) # do not recalculate joint projection, project only once + + # Initialize empty hypotheses and all necessary tensors + batched_hyps = rnnt_utils.BatchedHyps( + batch_size=batch_size, init_length=max_time, device=x.device, float_dtype=x.dtype + ) + time_indices = torch.zeros([batch_size], dtype=torch.long, device=device) # always of batch_size + active_indices = torch.arange(batch_size, dtype=torch.long, device=device) # initial: all indices + labels = torch.full([batch_size], fill_value=self._blank_index, dtype=torch.long, device=device) + state = None + + # init additional structs for hypotheses: last decoder state, alignments, frame_confidence + last_decoder_state = [None for _ in range(batch_size)] + + alignments: Optional[rnnt_utils.BatchedAlignments] + if self.preserve_alignments or self.preserve_frame_confidence: + alignments = rnnt_utils.BatchedAlignments( + batch_size=batch_size, + logits_dim=self.joint.num_classes_with_blank, + init_length=max_time * 2, # blank for each timestep + text tokens + device=x.device, + float_dtype=x.dtype, + store_alignments=self.preserve_alignments, + store_frame_confidence=self.preserve_frame_confidence, + ) + else: + alignments = None + + # loop while there are active indices + while (current_batch_size := active_indices.shape[0]) > 0: + # stage 1: get decoder (prediction network) output + if state is None: + # start of the loop, SOS symbol is passed into prediction network + decoder_output, state, *_ = self._pred_step(self._SOS, state, batch_size=current_batch_size) + else: + # pass the labels (found in the inner loop) to the prediction network + decoder_output, state, *_ = self._pred_step(labels.unsqueeze(1), state, batch_size=current_batch_size) + decoder_output = self.joint.project_prednet(decoder_output) # do not recalculate joint projection + + # stage 2: get joint output, iteratively seeking for non-blank labels + # blank label in `labels` tensor means "end of hypothesis" (for this index) + logits = ( + self._joint_step_after_projection( + x[active_indices, time_indices[active_indices]].unsqueeze(1), + decoder_output, + log_normalize=True if self.preserve_frame_confidence else None, + ) + .squeeze(1) + .squeeze(1) + ) + scores, labels = logits.max(-1) + + # search for non-blank labels using joint, advancing time indices for blank labels + # checking max_symbols is not needed, since we already forced advancing time indices for such cases + blank_mask = labels == self._blank_index + if alignments is not None: + alignments.add_results_( + active_indices=active_indices, + time_indices=time_indices[active_indices], + logits=logits if self.preserve_alignments else None, + labels=labels if self.preserve_alignments else None, + confidence=torch.tensor(self._get_confidence(logits), device=device) + if self.preserve_frame_confidence + else None, + ) + # advance_mask is a mask for current batch for searching non-blank labels; + # each element is True if non-blank symbol is not yet found AND we can increase the time index + advance_mask = torch.logical_and(blank_mask, (time_indices[active_indices] + 1 < out_len[active_indices])) + while advance_mask.any(): + advance_indices = active_indices[advance_mask] + time_indices[advance_indices] += 1 + logits = ( + self._joint_step_after_projection( + x[advance_indices, time_indices[advance_indices]].unsqueeze(1), + decoder_output[advance_mask], + log_normalize=True if self.preserve_frame_confidence else None, + ) + .squeeze(1) + .squeeze(1) + ) + # get labels (greedy) and scores from current logits, replace labels/scores with new + # labels[advance_mask] are blank, and we are looking for non-blank labels + more_scores, more_labels = logits.max(-1) + labels[advance_mask] = more_labels + scores[advance_mask] = more_scores + if alignments is not None: + alignments.add_results_( + active_indices=advance_indices, + time_indices=time_indices[advance_indices], + logits=logits if self.preserve_alignments else None, + labels=more_labels if self.preserve_alignments else None, + confidence=torch.tensor(self._get_confidence(logits), device=device) + if self.preserve_frame_confidence + else None, + ) + blank_mask = labels == self._blank_index + advance_mask = torch.logical_and( + blank_mask, (time_indices[active_indices] + 1 < out_len[active_indices]) + ) + + # stage 3: filter labels and state, store hypotheses + # the only case, when there are blank labels in predictions - when we found the end for some utterances + if blank_mask.any(): + non_blank_mask = ~blank_mask + labels = labels[non_blank_mask] + scores = scores[non_blank_mask] + + # select states for hyps that became inactive (is it necessary?) + # this seems to be redundant, but used in the `loop_frames` output + inactive_global_indices = active_indices[blank_mask] + inactive_inner_indices = torch.arange(current_batch_size, device=device, dtype=torch.long)[blank_mask] + for idx, batch_idx in zip(inactive_global_indices.cpu().numpy(), inactive_inner_indices.cpu().numpy()): + last_decoder_state[idx] = self.decoder.batch_select_state(state, batch_idx) + + # update active indices and state + active_indices = active_indices[non_blank_mask] + state = self.decoder.mask_select_states(state, non_blank_mask) + # store hypotheses + batched_hyps.add_results_( + active_indices, labels, time_indices[active_indices].clone(), scores, + ) + + # stage 4: to avoid looping, go to next frame after max_symbols emission + if self.max_symbols is not None: + # if labels are non-blank (not end-of-utterance), check that last observed timestep with label: + # if it is equal to the current time index, and number of observations is >= max_symbols, force blank + force_blank_mask = torch.logical_and( + torch.logical_and( + labels != self._blank_index, + batched_hyps.last_timestep_lasts[active_indices] >= self.max_symbols, + ), + batched_hyps.last_timestep[active_indices] == time_indices[active_indices], + ) + if force_blank_mask.any(): + # forced blank is not stored in the alignments following the original implementation + time_indices[active_indices[force_blank_mask]] += 1 # emit blank => advance time indices + # elements with time indices >= out_len become inactive, remove them from batch + still_active_mask = time_indices[active_indices] < out_len[active_indices] + active_indices = active_indices[still_active_mask] + labels = labels[still_active_mask] + state = self.decoder.mask_select_states(state, still_active_mask) + + hyps = rnnt_utils.batched_hyps_to_hypotheses(batched_hyps, alignments) + # preserve last decoder state (is it necessary?) + for i, last_state in enumerate(last_decoder_state): + # assert last_state is not None + hyps[i].dec_state = last_state + return hyps + + def _greedy_decode_blank_as_pad_loop_frames( self, x: torch.Tensor, out_len: torch.Tensor, @@ -2207,6 +2421,7 @@ class GreedyBatchedRNNTInferConfig: preserve_alignments: bool = False preserve_frame_confidence: bool = False confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig()) + loop_labels: bool = True def __post_init__(self): # OmegaConf.structured ensures that post_init check is always executed diff --git a/nemo/collections/asr/parts/utils/rnnt_utils.py b/nemo/collections/asr/parts/utils/rnnt_utils.py index 18fc8305cac19..13b8a5b5fa7ed 100644 --- a/nemo/collections/asr/parts/utils/rnnt_utils.py +++ b/nemo/collections/asr/parts/utils/rnnt_utils.py @@ -220,3 +220,246 @@ def select_k_expansions( k_expansions.append([(k_best_exp_idx, k_best_exp)]) return k_expansions + + +class BatchedHyps: + """Class to store batched hypotheses (labels, time_indices, scores) for efficient RNNT decoding""" + + def __init__( + self, + batch_size: int, + init_length: int, + device: Optional[torch.device] = None, + float_dtype: Optional[torch.dtype] = None, + ): + """ + + Args: + batch_size: batch size for hypotheses + init_length: initial estimate for the length of hypotheses (if the real length is higher, tensors will be reallocated) + device: device for storing hypotheses + float_dtype: float type for scores + """ + if init_length <= 0: + raise ValueError(f"init_length must be > 0, got {init_length}") + if batch_size <= 0: + raise ValueError(f"batch_size must be > 0, got {batch_size}") + self._max_length = init_length + + # batch of current lengths of hypotheses and correspoinding timesteps + self.current_lengths = torch.zeros(batch_size, device=device, dtype=torch.long) + # tensor for storing transcripts + self.transcript = torch.zeros((batch_size, self._max_length), device=device, dtype=torch.long) + # tensor for storing timesteps corresponding to transcripts + self.timesteps = torch.zeros((batch_size, self._max_length), device=device, dtype=torch.long) + # accumulated scores for hypotheses + self.scores = torch.zeros(batch_size, device=device, dtype=float_dtype) + + # tracking last timestep of each hyp to avoid infinite looping (when max symbols per frame is restricted) + # last observed timestep (with label) for each hypothesis + self.last_timestep = torch.full((batch_size,), -1, device=device, dtype=torch.long) + # number of labels for the last timestep + self.last_timestep_lasts = torch.zeros(batch_size, device=device, dtype=torch.long) + + def _allocate_more(self): + """ + Allocate 2x space for tensors, similar to common C++ std::vector implementations + to maintain O(1) insertion time complexity + """ + self.transcript = torch.cat((self.transcript, torch.zeros_like(self.transcript)), dim=-1) + self.timesteps = torch.cat((self.timesteps, torch.zeros_like(self.timesteps)), dim=-1) + self._max_length *= 2 + + def add_results_( + self, active_indices: torch.Tensor, labels: torch.Tensor, time_indices: torch.Tensor, scores: torch.Tensor + ): + """ + Add results (inplace) from a decoding step to the batched hypotheses + Args: + active_indices: tensor with indices of active hypotheses (indices should be within the original batch_size) + labels: non-blank labels to add + time_indices: tensor of time index for each label + scores: label scores + """ + # we assume that all tensors have the same first dimension, and labels are non-blanks + if active_indices.shape[0] == 0: + return # nothing to add + # if needed - increase storage + if self.current_lengths.max().item() >= self._max_length: + self._allocate_more() + + # accumulate scores + self.scores[active_indices] += scores + + # store transcript and timesteps + active_lengths = self.current_lengths[active_indices] + self.transcript[active_indices, active_lengths] = labels + self.timesteps[active_indices, active_lengths] = time_indices + # store last observed timestep + number of observation for the current timestep + self.last_timestep_lasts[active_indices] = torch.where( + self.last_timestep[active_indices] == time_indices, self.last_timestep_lasts[active_indices] + 1, 1 + ) + self.last_timestep[active_indices] = time_indices + # increase lengths + self.current_lengths[active_indices] += 1 + + +class BatchedAlignments: + """ + Class to store batched alignments (logits, labels, frame_confidence). + Size is different from hypotheses, since blank outputs are preserved + """ + + def __init__( + self, + batch_size: int, + logits_dim: int, + init_length: int, + device: Optional[torch.device] = None, + float_dtype: Optional[torch.dtype] = None, + store_alignments: bool = True, + store_frame_confidence: bool = False, + ): + """ + + Args: + batch_size: batch size for hypotheses + logits_dim: dimension for logits + init_length: initial estimate for the lengths of flatten alignments + device: device for storing data + float_dtype: expected logits/confidence data type + store_alignments: if alignments should be stored + store_frame_confidence: if frame confidence should be stored + """ + if init_length <= 0: + raise ValueError(f"init_length must be > 0, got {init_length}") + if batch_size <= 0: + raise ValueError(f"batch_size must be > 0, got {batch_size}") + self.with_frame_confidence = store_frame_confidence + self.with_alignments = store_alignments + self._max_length = init_length + + # tensor to store observed timesteps (for alignments / confidence scores) + self.timesteps = torch.zeros((batch_size, self._max_length), device=device, dtype=torch.long) + # current lengths of the utterances (alignments) + self.current_lengths = torch.zeros(batch_size, device=device, dtype=torch.long) + + if self.with_alignments: + # logits and labels; labels can contain , different from BatchedHyps + self.logits = torch.zeros((batch_size, self._max_length, logits_dim), device=device, dtype=float_dtype) + self.labels = torch.zeros((batch_size, self._max_length), device=device, dtype=torch.long) + else: + self.logits = None + self.labels = None + + if self.with_frame_confidence: + # tensor to store frame confidence + self.frame_confidence = torch.zeros((batch_size, self._max_length), device=device, dtype=float_dtype) + else: + self.frame_confidence = None + + def _allocate_more(self): + """ + Allocate 2x space for tensors, similar to common C++ std::vector implementations + to maintain O(1) insertion time complexity + """ + self.timesteps = torch.cat((self.timesteps, torch.zeros_like(self.timesteps)), dim=-1) + if self.with_alignments: + self.logits = torch.cat((self.logits, torch.zeros_like(self.logits)), dim=1) + self.labels = torch.cat((self.labels, torch.zeros_like(self.labels)), dim=-1) + if self.with_frame_confidence: + self.frame_confidence = torch.cat((self.frame_confidence, torch.zeros_like(self.frame_confidence)), dim=-1) + self._max_length *= 2 + + def add_results_( + self, + active_indices: torch.Tensor, + time_indices: torch.Tensor, + logits: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + confidence: Optional[torch.Tensor] = None, + ): + """ + Add results (inplace) from a decoding step to the batched hypotheses + Args: + active_indices: tensor with indices of active hypotheses (indices should be within the original batch_size) + logits: tensor with raw network outputs + labels: tensor with decoded labels (can contain blank) + time_indices: tensor of time index for each label + confidence: optional tensor with confidence for each item in batch + """ + # we assume that all tensors have the same first dimension + if active_indices.shape[0] == 0: + return # nothing to add + + # if needed - increase storage + if self.current_lengths.max().item() >= self._max_length: + self._allocate_more() + + active_lengths = self.current_lengths[active_indices] + # store timesteps - same for alignments / confidence + self.timesteps[active_indices, active_lengths] = time_indices + + if self.with_alignments and logits is not None and labels is not None: + self.logits[active_indices, active_lengths] = logits + self.labels[active_indices, active_lengths] = labels + + if self.with_frame_confidence and confidence is not None: + self.frame_confidence[active_indices, active_lengths] = confidence + # increase lengths + self.current_lengths[active_indices] += 1 + + +def batched_hyps_to_hypotheses( + batched_hyps: BatchedHyps, alignments: Optional[BatchedAlignments] = None +) -> List[Hypothesis]: + """ + Convert batched hypotheses to a list of Hypothesis objects. + Keep this function separate to allow for jit compilation for BatchedHyps class (see tests) + + Args: + batched_hyps: BatchedHyps object + alignments: BatchedAlignments object, optional; must correspond to BatchedHyps if present + + Returns: + list of Hypothesis objects + """ + hypotheses = [ + Hypothesis( + score=batched_hyps.scores[i].item(), + y_sequence=batched_hyps.transcript[i, : batched_hyps.current_lengths[i]], + timestep=batched_hyps.timesteps[i, : batched_hyps.current_lengths[i]], + alignments=None, + dec_state=None, + ) + for i in range(batched_hyps.scores.shape[0]) + ] + if alignments is not None: + # move all data to cpu to avoid overhead with moving data by chunks + alignment_lengths = alignments.current_lengths.cpu().tolist() + if alignments.with_alignments: + alignment_logits = alignments.logits.cpu() + alignment_labels = alignments.labels.cpu() + if alignments.with_frame_confidence: + frame_confidence = alignments.frame_confidence.cpu() + + # for each hypothesis - aggregate alignment using unique_consecutive for time indices (~itertools.groupby) + for i in range(len(hypotheses)): + hypotheses[i].alignments = [] + if alignments.with_frame_confidence: + hypotheses[i].frame_confidence = [] + _, grouped_counts = torch.unique_consecutive( + alignments.timesteps[i, : alignment_lengths[i]], return_counts=True + ) + start = 0 + for timestep_cnt in grouped_counts.tolist(): + if alignments.with_alignments: + hypotheses[i].alignments.append( + [(alignment_logits[i, start + j], alignment_labels[i, start + j]) for j in range(timestep_cnt)] + ) + if alignments.with_frame_confidence: + hypotheses[i].frame_confidence.append( + [frame_confidence[i, start + j] for j in range(timestep_cnt)] + ) + start += timestep_cnt + return hypotheses diff --git a/tests/collections/asr/confidence/test_asr_confidence.py b/tests/collections/asr/confidence/test_asr_confidence.py index 0d91abf6d81bf..edf35bb17b0be 100644 --- a/tests/collections/asr/confidence/test_asr_confidence.py +++ b/tests/collections/asr/confidence/test_asr_confidence.py @@ -26,7 +26,7 @@ from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig from nemo.collections.asr.parts.submodules.ctc_greedy_decoding import GreedyCTCInferConfig from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig -from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyRNNTInferConfig +from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyBatchedRNNTInferConfig from nemo.collections.asr.parts.utils.asr_confidence_benchmarking_utils import run_confidence_benchmark from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig @@ -125,7 +125,7 @@ def test_deprecated_config_args(self, model_name, conformer_ctc_bpe_model, confo RNNTDecodingConfig( fused_batch_size=-1, strategy="greedy", - greedy=GreedyRNNTInferConfig(preserve_frame_confidence=True, **test_args_greedy), + greedy=GreedyBatchedRNNTInferConfig(preserve_frame_confidence=True, **test_args_greedy), ) if model_name == "rnnt" else CTCDecodingConfig(greedy=GreedyCTCInferConfig(preserve_frame_confidence=True, **test_args_greedy)) diff --git a/tests/collections/asr/decoding/rnnt_alignments_check.py b/tests/collections/asr/decoding/rnnt_alignments_check.py index 2f19f78c61229..b384062abe42c 100644 --- a/tests/collections/asr/decoding/rnnt_alignments_check.py +++ b/tests/collections/asr/decoding/rnnt_alignments_check.py @@ -28,15 +28,17 @@ PRETRAINED_MODEL_NAME = "stt_en_conformer_transducer_small" -def get_rnnt_alignments(strategy: str): +def get_rnnt_alignments(strategy: str, loop_labels: bool = True, location="cuda"): cfg = OmegaConf.structured(TranscriptionConfig(pretrained_name=PRETRAINED_MODEL_NAME)) cfg.rnnt_decoding.confidence_cfg.preserve_frame_confidence = True cfg.rnnt_decoding.preserve_alignments = True cfg.rnnt_decoding.strategy = strategy + if cfg.rnnt_decoding.strategy == "greedy_batch": + cfg.rnnt_decoding.greedy.loop_labels = loop_labels cfg.dataset_manifest = TEST_DATA_PATH filepaths = prepare_audio_data(cfg)[0][:10] # selecting 10 files only - model = setup_model(cfg, map_location="cuda")[0] + model = setup_model(cfg, map_location=location)[0] model.change_decoding_strategy(cfg.rnnt_decoding) transcriptions = model.transcribe( @@ -70,10 +72,11 @@ def cleanup_local_folder(): # TODO: add the same tests for multi-blank RNNT decoding @pytest.mark.skipif(not os.path.exists('/home/TestData'), reason='Not a Jenkins machine') -def test_rnnt_alignments(): +@pytest.mark.parametrize("loop_labels", [True, False]) +def test_rnnt_alignments(loop_labels: bool): # using greedy as baseline and comparing all other configurations to it ref_transcriptions = get_rnnt_alignments("greedy") - transcriptions = get_rnnt_alignments("greedy_batch") + transcriptions = get_rnnt_alignments("greedy_batch", loop_labels=loop_labels) # comparing that label sequence in alignments is exactly the same # we can't compare logits as well, because they are expected to be # slightly different in batched and single-sample mode diff --git a/tests/collections/asr/decoding/test_batched_hyps_and_alignments.py b/tests/collections/asr/decoding/test_batched_hyps_and_alignments.py new file mode 100644 index 0000000000000..21d958e9eeb83 --- /dev/null +++ b/tests/collections/asr/decoding/test_batched_hyps_and_alignments.py @@ -0,0 +1,288 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import pytest +import torch + +from nemo.collections.asr.parts.utils.rnnt_utils import BatchedAlignments, BatchedHyps, batched_hyps_to_hypotheses + +DEVICES: List[torch.device] = [torch.device("cpu")] + +if torch.cuda.is_available(): + DEVICES.append(torch.device("cuda")) + +if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + DEVICES.append(torch.device("mps")) + + +class TestBatchedHyps: + @pytest.mark.unit + @pytest.mark.parametrize("device", DEVICES) + def test_instantiate(self, device: torch.device): + hyps = BatchedHyps(batch_size=2, init_length=3, device=device) + assert torch.is_tensor(hyps.timesteps) + # device: for mps device we need to use `type`, not directly compare + assert hyps.timesteps.device.type == device.type + assert hyps.timesteps.shape == (2, 3) + + @pytest.mark.unit + @pytest.mark.parametrize("batch_size", [-1, 0]) + def test_instantiate_incorrect_batch_size(self, batch_size): + with pytest.raises(ValueError): + _ = BatchedHyps(batch_size=batch_size, init_length=3) + + @pytest.mark.unit + @pytest.mark.parametrize("init_length", [-1, 0]) + def test_instantiate_incorrect_init_length(self, init_length): + with pytest.raises(ValueError): + _ = BatchedHyps(batch_size=1, init_length=init_length) + + @pytest.mark.unit + @pytest.mark.parametrize("device", DEVICES) + def test_add_results(self, device: torch.device): + # batch of size 2, add label for first utterance + hyps = BatchedHyps(batch_size=2, init_length=1, device=device) + hyps.add_results_( + active_indices=torch.tensor([0], device=device), + labels=torch.tensor([5], device=device), + time_indices=torch.tensor([1], device=device), + scores=torch.tensor([0.5], device=device), + ) + assert hyps.current_lengths.tolist() == [1, 0] + assert hyps.transcript.tolist()[0][:1] == [5] + assert hyps.timesteps.tolist()[0][:1] == [1] + assert hyps.scores.tolist() == pytest.approx([0.5, 0.0]) + assert hyps.last_timestep.tolist() == [1, -1] + assert hyps.last_timestep_lasts.tolist() == [1, 0] + + @pytest.mark.unit + @pytest.mark.parametrize("device", DEVICES) + def test_add_multiple_results(self, device: torch.device): + # batch of size 2, add label for first utterance, then add labels for both utterances + hyps = BatchedHyps(batch_size=2, init_length=1, device=device) + hyps.add_results_( + active_indices=torch.tensor([0], device=device), + labels=torch.tensor([5], device=device), + time_indices=torch.tensor([1], device=device), + scores=torch.tensor([0.5], device=device), + ) + hyps.add_results_( + active_indices=torch.tensor([0, 1], device=device), + labels=torch.tensor([2, 4], device=device), + time_indices=torch.tensor([1, 2], device=device), + scores=torch.tensor([1.0, 1.0], device=device), + ) + assert hyps.current_lengths.tolist() == [2, 1] + assert hyps.transcript.tolist()[0][:2] == [5, 2] + assert hyps.transcript.tolist()[1][:1] == [4] + assert hyps.timesteps.tolist()[0][:2] == [1, 1] + assert hyps.timesteps.tolist()[1][:1] == [2] + assert hyps.scores.tolist() == pytest.approx([1.5, 1.0]) + assert hyps.last_timestep.tolist() == [1, 2] + assert hyps.last_timestep_lasts.tolist() == [2, 1] + + @pytest.mark.unit + @pytest.mark.parametrize("device", DEVICES) + def test_torch_jit_compatibility(self, device: torch.device): + @torch.jit.script + def hyps_add_wrapper( + active_indices: torch.Tensor, labels: torch.Tensor, time_indices: torch.Tensor, scores: torch.Tensor + ): + hyps = BatchedHyps(batch_size=2, init_length=3, device=active_indices.device) + hyps.add_results_(active_indices=active_indices, labels=labels, time_indices=time_indices, scores=scores) + return hyps + + scores = torch.tensor([0.1, 0.1], device=device) + hyps = hyps_add_wrapper( + torch.tensor([0, 1], device=device), + torch.tensor([2, 4], device=device), + torch.tensor([0, 0], device=device), + scores, + ) + assert torch.allclose(hyps.scores, scores) + + +class TestBatchedAlignments: + @pytest.mark.unit + @pytest.mark.parametrize("device", DEVICES) + def test_instantiate(self, device: torch.device): + alignments = BatchedAlignments(batch_size=2, logits_dim=7, init_length=3, device=device) + assert torch.is_tensor(alignments.logits) + # device: for mps device we need to use `type`, not directly compare + assert alignments.logits.device.type == device.type + assert alignments.logits.shape == (2, 3, 7) + + @pytest.mark.unit + @pytest.mark.parametrize("batch_size", [-1, 0]) + def test_instantiate_incorrect_batch_size(self, batch_size): + with pytest.raises(ValueError): + _ = BatchedAlignments(batch_size=batch_size, logits_dim=7, init_length=3) + + @pytest.mark.unit + @pytest.mark.parametrize("init_length", [-1, 0]) + def test_instantiate_incorrect_init_length(self, init_length): + with pytest.raises(ValueError): + _ = BatchedAlignments(batch_size=1, logits_dim=7, init_length=init_length) + + @pytest.mark.unit + @pytest.mark.parametrize("device", DEVICES) + def test_add_results(self, device: torch.device): + # batch of size 2, add label for first utterance + batch_size = 2 + logits_dim = 7 + sample_logits = torch.rand((batch_size, 1, logits_dim), device=device) + alignments = BatchedAlignments(batch_size=batch_size, logits_dim=logits_dim, init_length=1, device=device) + alignments.add_results_( + active_indices=torch.arange(batch_size, device=device), + logits=sample_logits[:, 0], + labels=torch.argmax(sample_logits[:, 0], dim=-1), + time_indices=torch.tensor([0, 0], device=device), + ) + assert alignments.current_lengths.tolist() == [1, 1] + assert torch.allclose(alignments.logits[:, 0], sample_logits[:, 0]) + assert alignments.timesteps[:, 0].tolist() == [0, 0] + + @pytest.mark.unit + @pytest.mark.parametrize("device", DEVICES) + def test_add_multiple_results(self, device: torch.device): + # batch of size 2, add label for first utterance + batch_size = 2 + seq_length = 5 + logits_dim = 7 + alignments = BatchedAlignments(batch_size=batch_size, logits_dim=logits_dim, init_length=1, device=device) + sample_logits = torch.rand((batch_size, seq_length, logits_dim), device=device) + add_logits_mask = torch.rand((batch_size, seq_length), device=device) < 0.6 + for t in range(seq_length): + alignments.add_results_( + active_indices=torch.arange(batch_size, device=device)[add_logits_mask[:, t]], + logits=sample_logits[add_logits_mask[:, t], t], + labels=torch.argmax(sample_logits[add_logits_mask[:, t], t], dim=-1), + time_indices=torch.tensor([0, 0], device=device)[add_logits_mask[:, t]], + ) + + assert (alignments.current_lengths == add_logits_mask.sum(dim=-1)).all() + for i in range(batch_size): + assert ( + alignments.logits[i, : alignments.current_lengths[i]] == sample_logits[i, add_logits_mask[i]] + ).all() + + +class TestConvertToHypotheses: + @pytest.mark.unit + @pytest.mark.parametrize("device", DEVICES) + def test_convert_to_hypotheses(self, device: torch.device): + hyps = BatchedHyps(batch_size=2, init_length=1, device=device) + hyps.add_results_( + active_indices=torch.tensor([0], device=device), + labels=torch.tensor([5], device=device), + time_indices=torch.tensor([1], device=device), + scores=torch.tensor([0.5], device=device), + ) + hyps.add_results_( + active_indices=torch.tensor([0, 1], device=device), + labels=torch.tensor([2, 4], device=device), + time_indices=torch.tensor([1, 2], device=device), + scores=torch.tensor([1.0, 1.0], device=device), + ) + hypotheses = batched_hyps_to_hypotheses(hyps) + assert (hypotheses[0].y_sequence == torch.tensor([5, 2], device=device)).all() + assert (hypotheses[1].y_sequence == torch.tensor([4], device=device)).all() + assert hypotheses[0].score == pytest.approx(1.5) + assert hypotheses[1].score == pytest.approx(1.0) + assert (hypotheses[0].timestep == torch.tensor([1, 1], device=device)).all() + assert (hypotheses[1].timestep == torch.tensor([2], device=device)).all() + + @pytest.mark.unit + @pytest.mark.parametrize("device", DEVICES) + def test_convert_to_hypotheses_with_alignments(self, device: torch.device): + batch_size = 2 + logits_dim = 7 + blank_index = 6 + hyps = BatchedHyps(batch_size=batch_size, init_length=1, device=device) + alignments = BatchedAlignments(batch_size=batch_size, init_length=1, logits_dim=logits_dim, device=device) + sample_logits = torch.rand((batch_size, 4, logits_dim), device=device) + # sequence 0: [[5, blank], [2, blank]] -> [5, 2] + # sequence 1: [[blank ], [4, blank]] -> [4] + + # frame 0 + hyps.add_results_( + active_indices=torch.tensor([0], device=device), + labels=torch.tensor([5], device=device), + time_indices=torch.tensor([0], device=device), + scores=torch.tensor([0.5], device=device), + ) + alignments.add_results_( + active_indices=torch.arange(batch_size, device=device), + logits=sample_logits[:, 0], + labels=torch.tensor([5, blank_index], device=device), + time_indices=torch.tensor([0, 0], device=device), + ) + alignments.add_results_( + active_indices=torch.tensor([0], device=device), + logits=sample_logits[:1, 1], + labels=torch.tensor([blank_index], device=device), + time_indices=torch.tensor([0], device=device), + ) + + # frame 1 + hyps.add_results_( + active_indices=torch.arange(batch_size, device=device), + labels=torch.tensor([2, 4], device=device), + time_indices=torch.tensor([1, 1], device=device), + scores=torch.tensor([1.0, 1.0], device=device), + ) + alignments.add_results_( + active_indices=torch.arange(batch_size, device=device), + logits=sample_logits[:, 2], + labels=torch.tensor([2, 4], device=device), + time_indices=torch.tensor([1, 1], device=device), + ) + alignments.add_results_( + active_indices=torch.arange(batch_size, device=device), + logits=sample_logits[:, 3], + labels=torch.tensor([blank_index, blank_index], device=device), + time_indices=torch.tensor([1, 1], device=device), + ) + + hypotheses = batched_hyps_to_hypotheses(hyps, alignments) + assert (hypotheses[0].y_sequence == torch.tensor([5, 2], device=device)).all() + assert (hypotheses[1].y_sequence == torch.tensor([4], device=device)).all() + assert hypotheses[0].score == pytest.approx(1.5) + assert hypotheses[1].score == pytest.approx(1.0) + assert (hypotheses[0].timestep == torch.tensor([0, 1], device=device)).all() + assert (hypotheses[1].timestep == torch.tensor([1], device=device)).all() + + etalon = [ + [ + [ + (torch.tensor(5), sample_logits[0, 0].cpu()), + (torch.tensor(blank_index), sample_logits[0, 1].cpu()), + ], + [ + (torch.tensor(2), sample_logits[0, 2].cpu()), + (torch.tensor(blank_index), sample_logits[0, 3].cpu()), + ], + ], + [ + [(torch.tensor(blank_index), sample_logits[1, 0].cpu())], + [(torch.tensor(4), sample_logits[1, 2].cpu()), (torch.tensor(blank_index), sample_logits[1, 3].cpu())], + ], + ] + for batch_i in range(batch_size): + for t, group_for_timestep in enumerate(etalon[batch_i]): + for step, (label, current_logits) in enumerate(group_for_timestep): + assert torch.allclose(hypotheses[batch_i].alignments[t][step][0], current_logits) + assert hypotheses[batch_i].alignments[t][step][1] == label diff --git a/tests/collections/asr/decoding/test_rnnt_decoding.py b/tests/collections/asr/decoding/test_rnnt_decoding.py index 3c839387c956e..b898f1bc70a96 100644 --- a/tests/collections/asr/decoding/test_rnnt_decoding.py +++ b/tests/collections/asr/decoding/test_rnnt_decoding.py @@ -177,6 +177,62 @@ def test_greedy_decoding_preserve_alignments(self, test_data_dir): print(f"Tokens at timestep {t} = {t_u}") print() + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.with_downloads + @pytest.mark.unit + @pytest.mark.parametrize("loop_labels", [True, False]) + def test_batched_greedy_decoding_preserve_alignments(self, test_data_dir, loop_labels: bool): + """Test batched greedy decoding using non-batched decoding as a reference""" + model, encoded, encoded_len = get_model_encoder_output(test_data_dir, 'stt_en_conformer_transducer_small') + + search_algo = greedy_decode.GreedyBatchedRNNTInfer( + model.decoder, + model.joint, + blank_index=model.joint.num_classes_with_blank - 1, + max_symbols_per_step=5, + preserve_alignments=True, + loop_labels=loop_labels, + ) + + etalon_search_algo = greedy_decode.GreedyRNNTInfer( + model.decoder, + model.joint, + blank_index=model.joint.num_classes_with_blank - 1, + max_symbols_per_step=5, + preserve_alignments=True, + ) + + enc_out = encoded + enc_len = encoded_len + + with torch.no_grad(): + hyps: list[rnnt_utils.Hypothesis] = search_algo(encoder_output=enc_out, encoded_lengths=enc_len)[0] + hyp = decode_text_from_greedy_hypotheses(hyps, model.decoding)[0] + etalon_hyps: list[rnnt_utils.Hypothesis] = etalon_search_algo( + encoder_output=enc_out, encoded_lengths=enc_len + )[0] + etalon_hyp = decode_text_from_greedy_hypotheses(etalon_hyps, model.decoding)[0] + + assert hyp.alignments is not None + assert etalon_hyp.alignments is not None + + assert hyp.text == etalon_hyp.text + assert len(hyp.alignments) == len(etalon_hyp.alignments) + + for t in range(len(hyp.alignments)): + t_u = [] + for u in range(len(hyp.alignments[t])): + logp, label = hyp.alignments[t][u] + assert torch.is_tensor(logp) + assert torch.is_tensor(label) + etalon_logp, etalon_label = etalon_hyp.alignments[t][u] + assert label == etalon_label + assert torch.allclose(logp, etalon_logp, atol=1e-4, rtol=1e-4) + + t_u.append(int(label)) + @pytest.mark.skipif( not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', ) diff --git a/tests/collections/asr/test_asr_rnnt_encdec_model.py b/tests/collections/asr/test_asr_rnnt_encdec_model.py index 8622ab9b53ce3..d7c47adce1ad2 100644 --- a/tests/collections/asr/test_asr_rnnt_encdec_model.py +++ b/tests/collections/asr/test_asr_rnnt_encdec_model.py @@ -16,6 +16,7 @@ import pytest import torch +import torch.nn.functional as F from omegaconf import DictConfig, ListConfig from nemo.collections.asr.models import EncDecRNNTModel @@ -41,40 +42,67 @@ class DummyRNNTDecoder(AbstractRNNTDecoder): def predict( self, y: Optional[torch.Tensor] = None, - state: Optional[torch.Tensor] = None, + state: Optional[List[torch.Tensor]] = None, add_sos: bool = False, batch_size: Optional[int] = None, ) -> Tuple[torch.Tensor, List[torch.Tensor]]: if batch_size is None: batch_size = 1 if y is not None: - y = y + torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32).repeat(y.size()) + assert len(y.shape) == 2 + assert list(y.shape) == [batch_size, 1] + if state is not None: + assert len(state) == 1 + assert len(state[0].shape) == 3 + assert list(state[0].shape) == [1, batch_size, self.vocab_size + 1] + if y is not None: + # boost blank + output = F.one_hot(y, num_classes=self.vocab_size + 1) + torch.tensor( + [0] * self.vocab_size + [1], dtype=torch.float32 + )[None, None, :].expand([batch_size, 1, -1]) if y is not None and state is not None: - return (y + state) / 2, y * state + return (output + state[0].transpose(0, 1)) / 2, [output.transpose(0, 1) * state[0]] elif state is not None: - return torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32).repeat(state.size()), state + return ( + torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32)[None, None, :].expand( + [batch_size, 1, -1] + ), + state, + ) elif y is not None: - return y, torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32).repeat(y.size()) + return ( + output, + [ + torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32)[None, None, :].exand( + [1, batch_size, -1] + ) + ], + ) + # y, state - None (initial call) return ( - torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32).repeat([1, batch_size, 1]), - torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32).repeat([1, batch_size, 1]), + torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32)[None, None, :].expand( + [batch_size, 1, -1] + ), + [ + torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32)[None, None, :].expand( + [1, batch_size, -1] + ) + ], ) - def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]: - return [torch.tensor()] + def initialize_state(self, y: torch.Tensor) -> Optional[List[torch.Tensor]]: + return None def score_hypothesis( self, hypothesis: Hypothesis, cache: Dict[Tuple[int], Any] ) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]: return torch.tensor(), [torch.tensor()], torch.tensor() - def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> List[List[torch.Tensor]]: + def batch_select_state( + self, batch_states: Optional[List[torch.Tensor]], idx: int + ) -> Optional[List[List[torch.Tensor]]]: if batch_states is not None: - try: - states = batch_states[0][idx] - states = states.long() - except Exception as e: - raise Exception(batch_states, idx) + states = [batch_states[0][:, idx]] return [states] else: return None @@ -87,18 +115,39 @@ def batch_copy_states( value: Optional[float] = None, ) -> List[torch.Tensor]: if value is None: - old_states[0][ids, :] = new_states[0][ids, :] + old_states[0][:, ids] = new_states[0][:, ids] return old_states + def mask_select_states( + self, states: Optional[torch.Tensor], mask: torch.Tensor + ) -> Optional[List[torch.Tensor]]: + if states is None: + return None + return [states[0][:, mask]] + class DummyRNNTJoint(AbstractRNNTJoint): - def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor: + def __init__(self, num_outputs: int): + super().__init__() + self.num_outputs = num_outputs + + @property + def num_classes_with_blank(self): + return self.num_outputs + + def project_encoder(self, encoder_output: torch.Tensor) -> torch.Tensor: + return encoder_output + + def project_prednet(self, prednet_output: torch.Tensor) -> torch.Tensor: + return prednet_output + + def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor: return f.unsqueeze(dim=2) + g.unsqueeze(dim=1) setup = {} setup["decoder"] = DummyRNNTDecoder(vocab_size=2, blank_idx=2, blank_as_pad=True) setup["decoder_masked"] = DummyRNNTDecoder(vocab_size=2, blank_idx=2, blank_as_pad=False) - setup["joint"] = DummyRNNTJoint() + setup["joint"] = DummyRNNTJoint(num_outputs=3) # expected timesteps for max_symbols_per_step=5 are [[0, 0, 0, 0, 0, 1, 1], [1, 1, 1, 1, 1]], # so we have both looped and regular iteration on the second frame setup["encoder_output"] = torch.tensor( @@ -760,8 +809,11 @@ def test_greedy_decoding_preserve_frame_confidence(self, greedy_class): confidence_len = len(hyp.frame_confidence[t]) assert confidence_len <= max_symbols_per_step if t in timestep_count: # non-blank + # if timestep_count[t] less than max_symbols_per_step, + # blank emission and corresponding confidence expected + # if timestep_count[t] == max_symbols_per_step, "forced blank" is not added => no confidence assert confidence_len == timestep_count[t] + ( - 1 if confidence_len < max_symbols_per_step else 0 + 1 if timestep_count[t] < max_symbols_per_step else 0 ) else: # blank assert confidence_len == 1 @@ -777,7 +829,7 @@ def test_greedy_decoding_preserve_frame_confidence(self, greedy_class): @pytest.mark.parametrize( "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], ) - @pytest.mark.parametrize("max_symbols_per_step", [0, 1, 5]) + @pytest.mark.parametrize("max_symbols_per_step", [1, 5]) def test_greedy_decoding_max_symbols_alignment(self, max_symbols_setup, greedy_class, max_symbols_per_step): decoders = [max_symbols_setup["decoder"]] if greedy_class is greedy_decode.GreedyBatchedRNNTInfer: @@ -819,7 +871,32 @@ def test_greedy_decoding_max_symbols_alignment(self, max_symbols_setup, greedy_c @pytest.mark.parametrize( "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], ) - @pytest.mark.parametrize("max_symbols_per_step", [0, 1, 5]) + @pytest.mark.parametrize("max_symbols_per_step", [-1, 0]) + def test_greedy_decoding_max_symbols_confidence(self, max_symbols_setup, greedy_class, max_symbols_per_step): + """Test ValueError for max_symbols_per_step <= 0""" + decoders = [max_symbols_setup["decoder"]] + if greedy_class is greedy_decode.GreedyBatchedRNNTInfer: + decoders.append(max_symbols_setup["decoder_masked"]) + joint = max_symbols_setup["joint"] + + for decoder in decoders: + with pytest.raises(ValueError): + _ = greedy_class( + decoder_model=decoder, + joint_model=joint, + blank_index=decoder.blank_idx, + max_symbols_per_step=max_symbols_per_step, + preserve_frame_confidence=True, + ) + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.unit + @pytest.mark.parametrize( + "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], + ) + @pytest.mark.parametrize("max_symbols_per_step", [1, 5]) def test_greedy_decoding_max_symbols_confidence(self, max_symbols_setup, greedy_class, max_symbols_per_step): decoders = [max_symbols_setup["decoder"]] if greedy_class is greedy_decode.GreedyBatchedRNNTInfer: