Skip to content

Commit

Permalink
Speed up RNN-T greedy decoding with cuda graphs
Browse files Browse the repository at this point in the history
This uses CUDA 12.3's conditional node support.

Signed-off-by: Daniel Galvez <[email protected]>
  • Loading branch information
galv committed Feb 14, 2024
1 parent b100cd1 commit 5b31417
Show file tree
Hide file tree
Showing 6 changed files with 605 additions and 2 deletions.
2 changes: 1 addition & 1 deletion nemo/collections/asr/modules/rnnt_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> List
def batch_replace_states_mask(
cls, src_states: list[torch.Tensor], dst_states: list[torch.Tensor], mask: torch.Tensor,
):
"""Replace states in dst_states with states from src_states using the mask"""
"""Replace states in dst_states with states from src_states using the mask, in a way that does not synchronize with the CPU"""
raise NotImplementedError()

def batch_split_states(self, batch_states: list[torch.Tensor]) -> list[list[torch.Tensor]]:
Expand Down
Loading

0 comments on commit 5b31417

Please sign in to comment.