Skip to content

Commit

Permalink
Remove spurious print, spurious nvtx ranges
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Galvez <[email protected]>
  • Loading branch information
galv committed Feb 22, 2024
1 parent c2d9248 commit 118c01a
Showing 1 changed file with 0 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@ def __init__(self, max_symbols: int, caller):
self.caller = caller

def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_length):
torch.cuda.nvtx.range_push("Init")
if self.first_call:
# We need to call the original _greedy_decode_blank_as_pad
# implementation at least once beforehand in order to make
Expand Down Expand Up @@ -283,10 +282,6 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len

self.graph = torch.cuda.CUDAGraph()

torch.cuda.nvtx.range_pop()

torch.cuda.nvtx.range_push("build graph")

# Always create a new stream, because the per-thread default stream disallows stream capture to a graph.
stream_for_graph = torch.cuda.Stream(self.device)
with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph(
Expand Down Expand Up @@ -402,8 +397,6 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len
self.last_label.fill_(self.caller._SOS)
self.time_idx_t.fill_(0)

torch.cuda.nvtx.range_pop()

def __call__(
self,
x: torch.Tensor,
Expand Down Expand Up @@ -440,17 +433,14 @@ def __call__(
# set self.first_call to True to make sure that all
# possibly blocking initializers are initialized properly
# again on the new device.
print("GALVEZ: reinit!")
if self.device != x.device:
self.first_call = True
self._reinitialize(max_time, batch_size, x, out_len)

torch.cuda.nvtx.range_push("Graph")
self.encoder_output[: x.shape[0], : x.shape[1], ...].copy_(x)
self.encoder_output_length[: out_len.shape[0]].copy_(out_len)
self.graph.replay()
torch.cuda.current_stream(device=self.device).synchronize()
torch.cuda.nvtx.range_pop()

self.scores_cpu[self.labels_cpu == self.caller._blank_index] = 0.0
total_scores = self.scores_cpu.sum(dtype=torch.float32, axis=(1, 2))
Expand Down

0 comments on commit 118c01a

Please sign in to comment.