Skip to content

Commit

Permalink
[NVIDIA][Launcher] Ensure device context is valid before calling getPointer (#5276)
Browse files Browse the repository at this point in the history
  • Loading branch information
peterbell10 authored and bertmaher committed Dec 20, 2024
1 parent 0d4682f commit f6b96ae
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 9 deletions.
27 changes: 27 additions & 0 deletions python/test/unit/runtime/test_driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import sys
from concurrent.futures import ThreadPoolExecutor
import torch

import triton
import triton.language as tl


def test_is_lazy():
Expand All @@ -12,3 +15,27 @@ def test_is_lazy():
assert triton.runtime.driver.active._obj is None
utils = triton.runtime.driver.active.utils # noqa: F841
assert issubclass(triton.runtime.driver.active._obj.__class__, getattr(triton.backends.driver, "DriverBase"))


def test_kernel_in_thread(device):
    """Launch a kernel from a worker thread to verify a valid device context is set up."""
    # A reasonably large buffer so the launch spans a real grid.
    buf = torch.zeros((38016 * 1024, ), dtype=torch.float32, device=device)

    @triton.jit
    def _kernel(P, BLOCK: tl.constexpr):
        # Identity copy: read each element and write it back in place.
        pid = tl.program_id(0).to(tl.int64)
        offset = pid * BLOCK + tl.arange(0, BLOCK)

        p = tl.load(P + offset)
        tl.store(P + offset, p)

    def call_triton():
        total = buf.numel()

        def grid(meta):
            return (triton.cdiv(total, meta["BLOCK"]), )

        _kernel[grid](buf, BLOCK=1024)
        # Wait for the launch so driver-side failures surface here.
        getattr(torch, device).synchronize()

    # Warm up on the main thread, then repeat on a fresh thread that has
    # never touched the CUDA/device runtime.
    call_triton()
    with ThreadPoolExecutor(1) as pool:
        pool.submit(call_triton).result()
27 changes: 18 additions & 9 deletions third_party/nvidia/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,15 +208,6 @@ def format_of(ty):
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }};
if (gridX*gridY*gridZ > 0) {{
CUcontext pctx;
CUDA_CHECK(cuCtxGetCurrent(&pctx));
if (!pctx) {{
// Ensure device context.
CUdevice device;
CUDA_CHECK(cuDeviceGet(&device, 0));
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
CUDA_CHECK(cuCtxSetCurrent(pctx));
}}
if (num_ctas == 1) {{
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
}} else {{
Expand Down Expand Up @@ -284,6 +275,9 @@ def format_of(ty):
PyErr_Format(PyExc_ValueError,
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
ptr_info.valid = false;
}} else if (status != CUDA_SUCCESS) {{
CUDA_CHECK(status); // Catch any other cuda API errors
ptr_info.valid = false;
}}
ptr_info.dev_ptr = dev_ptr;
Py_DECREF(ret); // Thanks ChatGPT!
Expand Down Expand Up @@ -340,7 +334,22 @@ def format_of(ty):
return (CUtensorMap*)(ptr_as_uint);
}}
static void ensureCudaContext() {{
// Ensure the calling host thread has a current CUDA context before any
// driver API call; a thread spawned after CUDA initialization may have
// none, causing calls like cuPointerGetAttributes to fail.
// (Doubled braces are escapes inside the Python format-string template.)
CUcontext pctx;
CUDA_CHECK(cuCtxGetCurrent(&pctx));
if (!pctx) {{
// Ensure device context: retain the primary context on device 0 and
// make it current for this thread.
// NOTE(review): assumes device 0 is the intended device — confirm for
// multi-GPU setups.
CUdevice device;
CUDA_CHECK(cuDeviceGet(&device, 0));
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
CUDA_CHECK(cuCtxSetCurrent(pctx));
}}
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
// ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes
ensureCudaContext();
int gridX, gridY, gridZ;
uint64_t _stream;
uint64_t _function;
Expand Down

0 comments on commit f6b96ae

Please sign in to comment.