diff --git a/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py index 5746850eb93b..4d6943abad51 100644 --- a/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py @@ -228,7 +228,6 @@ def __init__(self, max_symbols: int, caller): self.caller = caller def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_length): - torch.cuda.nvtx.range_push("Init") if self.first_call: # We need to call the original _greedy_decode_blank_as_pad # implementation at least once beforehand in order to make @@ -283,10 +282,6 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len self.graph = torch.cuda.CUDAGraph() - torch.cuda.nvtx.range_pop() - - torch.cuda.nvtx.range_push("build graph") - # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. stream_for_graph = torch.cuda.Stream(self.device) with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( @@ -402,8 +397,6 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len self.last_label.fill_(self.caller._SOS) self.time_idx_t.fill_(0) - torch.cuda.nvtx.range_pop() - def __call__( self, x: torch.Tensor, @@ -440,17 +433,14 @@ def __call__( # set self.first_call to True to make sure that all # possibly blocking initializers are initialized properly # again on the new device. - print("GALVEZ: reinit!") if self.device != x.device: self.first_call = True self._reinitialize(max_time, batch_size, x, out_len) - torch.cuda.nvtx.range_push("Graph") self.encoder_output[: x.shape[0], : x.shape[1], ...].copy_(x) self.encoder_output_length[: out_len.shape[0]].copy_(out_len) self.graph.replay() torch.cuda.current_stream(device=self.device).synchronize() - torch.cuda.nvtx.range_pop() self.scores_cpu[self.labels_cpu == self.caller._blank_index] = 0.0 total_scores = self.scores_cpu.sum(dtype=torch.float32, axis=(1, 2))