From 69355c0073850d26e4f62f8fedc5306582141804 Mon Sep 17 00:00:00 2001 From: galv Date: Wed, 29 May 2024 15:50:06 +0000 Subject: [PATCH] Apply isort and black reformatting Signed-off-by: galv --- .../cuda_graph_rnnt_greedy_decoding.py | 8 +++++--- .../submodules/rnnt_loop_labels_computer.py | 16 ++++++++++++---- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py index b18dfa4384237..aa49435ded163 100644 --- a/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py @@ -37,7 +37,7 @@ def create_outer_for_loop_kernel(): """ - Creates a kernel that evaluates whether or not to enter the for loop body. + Creates a kernel that evaluates whether or not to enter the for loop body. Effectively substitutes for `for time_idx in range(trip_count)` such that that for loop can run on a GPU. """ @@ -171,8 +171,10 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. stream_for_graph = torch.cuda.Stream(self.device) - with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph( - self.graph, stream=stream_for_graph, capture_error_mode="thread_local" + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph(self.graph, stream=stream_for_graph, capture_error_mode="thread_local"), ): # This is failing... self.f = torch.zeros( diff --git a/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py b/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py index c53b2c04160ec..c0783c301c440 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py @@ -630,14 +630,18 @@ def _partial_graphs_compile(self): with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - torch.cuda.graph(self.separate_graphs.before_outer_loop, stream=stream_for_graph, capture_error_mode="thread_local"), + torch.cuda.graph( + self.separate_graphs.before_outer_loop, stream=stream_for_graph, capture_error_mode="thread_local" + ), ): self._before_outer_loop() with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - torch.cuda.graph(self.separate_graphs.before_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"), + torch.cuda.graph( + self.separate_graphs.before_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local" + ), ): self._before_inner_loop_get_decoder_output() self._before_inner_loop_get_joint_output() @@ -645,14 +649,18 @@ def _partial_graphs_compile(self): with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - torch.cuda.graph(self.separate_graphs.inner_loop_code, stream=stream_for_graph, capture_error_mode="thread_local"), + torch.cuda.graph( + self.separate_graphs.inner_loop_code, stream=stream_for_graph, capture_error_mode="thread_local" + ), ): self._inner_loop_code() with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - torch.cuda.graph(self.separate_graphs.after_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"), + torch.cuda.graph( + self.separate_graphs.after_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local" + ), ): self._after_inner_loop()