Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: galv <[email protected]>
  • Loading branch information
galv committed May 29, 2024
1 parent cec4f53 commit 69355c0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

def create_outer_for_loop_kernel():
"""
Creates a kernel that evaluates whether or not to enter the for loop body.
Creates a kernel that evaluates whether or not to enter the for loop body.
Effectively substitutes for `for time_idx in range(trip_count)`
such that that for loop can run on a GPU.
"""
Expand Down Expand Up @@ -171,8 +171,10 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len

# Always create a new stream, because the per-thread default stream disallows stream capture to a graph.
stream_for_graph = torch.cuda.Stream(self.device)
with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph(
self.graph, stream=stream_for_graph, capture_error_mode="thread_local"
with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(self.graph, stream=stream_for_graph, capture_error_mode="thread_local"),
):
# This is failing...
self.f = torch.zeros(
Expand Down
16 changes: 12 additions & 4 deletions nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,29 +630,37 @@ def _partial_graphs_compile(self):
with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(self.separate_graphs.before_outer_loop, stream=stream_for_graph, capture_error_mode="thread_local"),
torch.cuda.graph(
self.separate_graphs.before_outer_loop, stream=stream_for_graph, capture_error_mode="thread_local"
),
):
self._before_outer_loop()

with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(self.separate_graphs.before_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"),
torch.cuda.graph(
self.separate_graphs.before_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
),
):
self._before_inner_loop_get_decoder_output()
self._before_inner_loop_get_joint_output()

with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(self.separate_graphs.inner_loop_code, stream=stream_for_graph, capture_error_mode="thread_local"),
torch.cuda.graph(
self.separate_graphs.inner_loop_code, stream=stream_for_graph, capture_error_mode="thread_local"
),
):
self._inner_loop_code()

with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(self.separate_graphs.after_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"),
torch.cuda.graph(
self.separate_graphs.after_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
),
):
self._after_inner_loop()

Expand Down

0 comments on commit 69355c0

Please sign in to comment.