Update the test checking for cooperative kernels in conditional nodes.
We now conditionally xfail only when a CUDA driver older than 12.6 is
installed. CUDA 12.6 fixes this issue; before it, cooperative kernels
could not be used within the body of a conditional node.

We also provide a better error message so that users know the fix is to
upgrade to CUDA 12.6.

Signed-off-by: Daniel Galvez <[email protected]>
galv committed Jul 24, 2024
1 parent cd0d2c2 commit 05b6672
Showing 5 changed files with 55 additions and 14 deletions.
@@ -27,6 +27,7 @@
 from nemo.collections.asr.parts.utils import rnnt_utils
 from nemo.core.utils.cuda_python_utils import (
     check_cuda_python_cuda_graphs_conditional_nodes_supported,
+    checked_graph,
     cu_call,
     run_nvrtc,
     with_conditional_node,
@@ -174,7 +175,7 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len
         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(self.graph, stream=stream_for_graph, capture_error_mode="thread_local"),
+            checked_graph(self.graph, stream=stream_for_graph, capture_error_mode="thread_local"),
         ):
             # This is failing...
             self.f = torch.zeros(
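For context, the call sites above follow the standard torch.cuda.graph capture recipe: record work on a dedicated side stream, then refill the static input buffers and replay the recorded graph. Below is a minimal, hedged sketch of that pattern using the new checked_graph wrapper (which simply forwards its arguments to torch.cuda.graph); the decoder_step function, tensor shapes, and stream setup are illustrative assumptions, not code from this commit.

import torch

from nemo.core.utils.cuda_python_utils import checked_graph


# Illustrative stand-in for the per-step work that the real decoders capture.
def decoder_step(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1.0


device = torch.device("cuda")
static_input = torch.zeros(8, 640, device=device)  # assumed placeholder shape
graph = torch.cuda.CUDAGraph()
stream_for_graph = torch.cuda.Stream(device=device)

# Capture on a side stream, mirroring the call sites above.
with (
    torch.cuda.stream(stream_for_graph),
    torch.inference_mode(),
    checked_graph(graph, stream=stream_for_graph, capture_error_mode="thread_local"),
):
    static_output = decoder_step(static_input)

# Replay: copy fresh data into the static input in place, then relaunch the graph.
static_input.copy_(torch.randn(8, 640, device=device))
graph.replay()
torch.cuda.synchronize()
print(static_output.sum().item())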
@@ -24,6 +24,7 @@
 from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin
 from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs
 from nemo.core.utils.cuda_python_utils import (
+    checked_graph,
     check_cuda_python_cuda_graphs_conditional_nodes_supported,
     cu_call,
     run_nvrtc,
@@ -630,7 +631,7 @@ def _partial_graphs_compile(self):
         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(
+            checked_graph(
                 self.separate_graphs.before_outer_loop, stream=stream_for_graph, capture_error_mode="thread_local"
             ),
         ):
@@ -639,7 +640,7 @@ def _partial_graphs_compile(self):
         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(
+            checked_graph(
                 self.separate_graphs.before_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
             ),
         ):
@@ -649,7 +650,7 @@ def _partial_graphs_compile(self):
         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(
+            checked_graph(
                 self.separate_graphs.inner_loop_code, stream=stream_for_graph, capture_error_mode="thread_local"
             ),
         ):
@@ -658,7 +659,7 @@ def _partial_graphs_compile(self):
         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(
+            checked_graph(
                 self.separate_graphs.after_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
             ),
         ):
@@ -672,7 +673,7 @@ def _full_graph_compile(self):
         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
+            checked_graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
         ):
             self._before_outer_loop()

@@ -25,6 +25,7 @@
 from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin
 from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs
 from nemo.core.utils.cuda_python_utils import (
+    checked_graph,
     check_cuda_python_cuda_graphs_conditional_nodes_supported,
     cu_call,
     run_nvrtc,
@@ -691,7 +692,7 @@ def _partial_graphs_compile(self):
         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(
+            checked_graph(
                 self.separate_graphs.before_outer_loop, stream=stream_for_graph, capture_error_mode="thread_local"
             ),
         ):
@@ -700,7 +701,7 @@ def _partial_graphs_compile(self):
         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(
+            checked_graph(
                 self.separate_graphs.before_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
             ),
         ):
@@ -710,7 +711,7 @@ def _partial_graphs_compile(self):
         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(
+            checked_graph(
                 self.separate_graphs.inner_loop_code, stream=stream_for_graph, capture_error_mode="thread_local"
             ),
         ):
@@ -719,7 +720,7 @@ def _partial_graphs_compile(self):
         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(
+            checked_graph(
                 self.separate_graphs.after_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
             ),
         ):
@@ -734,7 +735,7 @@ def _full_graph_compile(self):
         with (
             torch.cuda.stream(stream_for_graph),
             torch.inference_mode(),
-            torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
+            checked_graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
         ):
             self._before_outer_loop()

34 changes: 34 additions & 0 deletions nemo/core/utils/cuda_python_utils.py
@@ -95,6 +95,27 @@ def cu_call(f_call_out):
     return tuple(others)


+def cuda_python_conditional_node_cooperative_kernels_supported():
+    """
+    Returns True if cuda-python is installed and CUDA driver 12.6 or newer is
+    installed. Before this CUDA driver version, cooperative kernels could not run
+    within cuda graph conditional nodes.
+    """
+    try:
+        check_cuda_python_cuda_graphs_conditional_nodes_supported()
+    except:
+        return False
+    else:
+        from cuda import cuda
+
+        error, driver_version = cuda.cuDriverGetVersion()
+        if error != cuda.CUresult.CUDA_SUCCESS:
+            raise ImportError(f"cuDriverGetVersion() returned {cuda.cuGetErrorString(error)}")
+        driver_version_major = driver_version // 1000
+        driver_version_minor = (driver_version % 1000) // 10
+        driver_version = (driver_version_major, driver_version_minor)
+        return driver_version >= (12, 6)
+
 @contextlib.contextmanager
 def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditional_handle, device):
     """
@@ -219,3 +240,16 @@ def run_nvrtc(kernel_string: str, kernel_name: bytes, program_name: bytes):
     assert_drv(err)

     return kernel
+
+@contextlib.contextmanager
+def checked_graph(*args, **kwargs):
+    """
+    Wrapper around torch.cuda.graph that checks for common errors that are too vague for an end user to diagnose based on the error message.
+    """
+    try:
+        with torch.cuda.graph(*args, **kwargs):
+            yield
+    except RuntimeError as err:
+        if "CUDA error: invalid argument" in str(err):
+            raise RuntimeError("CUDA Graph capture failed. It is likely that you are calling a cooperative kernel in your RNN-T or TDT prediction network. Cooperative kernels are not compatible with CUDA Graphs until CUDA 12.6. Please update to CUDA 12.6. File an issue if that still does not work.") from err
+        raise
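Taken together, the two helpers added here are meant to compose as follows: cuDriverGetVersion() returns an integer such as 12060, which decodes to (12, 6), and checked_graph turns the otherwise opaque capture failure on older drivers into an actionable message. The sketch below is a hedged illustration assuming a CUDA device and cuda-python are available; the captured body is a placeholder, not the real RNN-T/TDT decoding loop.

import torch

from nemo.core.utils.cuda_python_utils import (
    checked_graph,
    cuda_python_conditional_node_cooperative_kernels_supported,
)

supported = cuda_python_conditional_node_cooperative_kernels_supported()
print(f"Cooperative kernels inside conditional nodes supported: {supported}")

graph = torch.cuda.CUDAGraph()
try:
    # checked_graph is a drop-in replacement for torch.cuda.graph.
    with torch.inference_mode(), checked_graph(graph, capture_error_mode="thread_local"):
        x = torch.zeros(8, device="cuda")
        x += 1  # placeholder work; the real callers capture their decoding loops here
except RuntimeError as err:
    # On drivers older than 12.6, a cooperative kernel inside the body graph of a
    # conditional node fails capture with "CUDA error: invalid argument";
    # checked_graph re-raises it with the CUDA 12.6 upgrade hint.
    print(err)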
@@ -20,7 +20,10 @@
 from omegaconf import open_dict

 from nemo.collections.asr.models import ASRModel
-from nemo.core.utils.cuda_python_utils import skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported
+from nemo.core.utils.cuda_python_utils import (
+    skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported,
+    cuda_python_conditional_node_cooperative_kernels_supported
+)


 @pytest.fixture(scope="module")
@@ -53,8 +56,9 @@ def stt_en_fastconformer_transducer_large():
             8,
             True,
             marks=pytest.mark.xfail(
-                reason="""Cannot instantiate the
-body cuda graph of a conditional node with a persistent kernel (in this case,
+                not cuda_python_conditional_node_cooperative_kernels_supported(),
+                reason="""Cannot instantiate the
+body cuda graph of a conditional node with a persistent kernel (in this case,
 a persistent LSTM), which is triggered in cudnn by using a batch size of 8."""
             ),
         ),
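In isolation, the conditional xfail above is the standard pytest pattern of passing a boolean condition as the first argument to pytest.mark.xfail, evaluated once at collection time. A hedged, self-contained sketch (the parametrization and test body are stand-ins for the real greedy-decoding test, and assume NeMo and cuda-python are importable):

import pytest

from nemo.core.utils.cuda_python_utils import cuda_python_conditional_node_cooperative_kernels_supported


@pytest.mark.parametrize(
    "batch_size",
    [
        4,
        pytest.param(
            8,
            marks=pytest.mark.xfail(
                # Only expect failure on drivers older than 12.6, where cudnn's
                # persistent LSTM (a cooperative kernel) cannot run inside the
                # body graph of a conditional node.
                not cuda_python_conditional_node_cooperative_kernels_supported(),
                reason="Cooperative kernels inside conditional nodes require CUDA 12.6.",
            ),
        ),
    ],
)
def test_cuda_graph_decoding(batch_size):
    assert batch_size in (4, 8)  # placeholder for the real decoding assertions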
