Skip to content

Commit

Permalink
Add relevant changs to TDT cuda graphs decoding as well.
Browse files Browse the repository at this point in the history
I didn't test this because I'm not sure how. But it seems low risk.
  • Loading branch information
galv committed May 29, 2024
1 parent 69355c0 commit abc61e6
Showing 1 changed file with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -691,29 +691,29 @@ def _partial_graphs_compile(self):
with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(self.separate_graphs.before_outer_loop, stream=stream_for_graph),
torch.cuda.graph(self.separate_graphs.before_outer_loop, stream=stream_for_graph, capture_error_mode="thread_local"),
):
self._before_outer_loop()

with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(self.separate_graphs.before_inner_loop, stream=stream_for_graph),
torch.cuda.graph(self.separate_graphs.before_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"),
):
self._before_inner_loop_get_decoder_output()
self._before_inner_loop_get_joint_output()

with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(self.separate_graphs.inner_loop_code, stream=stream_for_graph),
torch.cuda.graph(self.separate_graphs.inner_loop_code, stream=stream_for_graph, capture_error_mode="thread_local"),
):
self._inner_loop_code()

with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(self.separate_graphs.after_inner_loop, stream=stream_for_graph),
torch.cuda.graph(self.separate_graphs.after_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"),
):
self._after_inner_loop()

Expand All @@ -726,7 +726,7 @@ def _full_graph_compile(self):
with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(self.full_graph, stream=stream_for_graph),
torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
):
self._before_outer_loop()

Expand Down

0 comments on commit abc61e6

Please sign in to comment.