
Setting the environment variable TRITON_INTERPRET prevents the kernel function from receiving reserved keyword arguments #5164

Closed
Asuka0630 opened this issue Nov 15, 2024 · 0 comments · Fixed by #5169

Describe the bug

Thanks to the Triton team for their excellent work 👍

When I list reserved keyword arguments (num_stages here) among the kernel function's parameters and set TRITON_INTERPRET=1, I get an error saying these arguments are missing (the same code works if the environment variable is not set). I'm not sure whether this is a bug, a trivial problem, or simply an operation that is not allowed.

Minimal complete example

import os
os.environ["TRITON_INTERPRET"] = "1"  # run kernels through the Triton interpreter
import torch
import triton
import triton.language as tl

# The autotuner passes num_stages/num_warps from the Config to the kernel
# launch as reserved keyword arguments.
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=2, num_warps=1),
    ],
    key=["BLOCK_SIZE"],
)
@triton.jit
def simple_kernel(
    a_ptr,
    out_ptr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    num_stages: tl.constexpr,  # reserved launch kwarg, also declared as a kernel parameter
):
    _sum = 0.0
    for idx in tl.range(0, n_cols, BLOCK_SIZE, num_stages=num_stages):
        off = idx + tl.arange(0, BLOCK_SIZE)
        a_ptrs = a_ptr + off
        a = tl.load(a_ptrs, mask=off < n_cols, other=0.0)
        _sum += tl.sum(a)
    tl.store(out_ptr, _sum)


N = 12
a = torch.randn((N,), device="cuda", dtype=torch.float16)
triton_out = torch.zeros((1,), device="cuda", dtype=torch.float16)
BLOCK_SIZE = 4
simple_kernel[(1, 1, 1)](
    a_ptr=a,
    out_ptr=triton_out,
    n_cols=N,
    BLOCK_SIZE=BLOCK_SIZE,
)
torch_out = torch.sum(a)
print(torch.allclose(triton_out, torch_out, atol=1e-2, rtol=1e-2))

Error message

Traceback (most recent call last):
  File "/user/test/patch.py", line 39, in <module>
    simple_kernel[(1, 1, 1)](
  File "/user/miniconda3/lib/python3.12/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/miniconda3/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 171, in run
    ret = self.fn.run(
          ^^^^^^^^^^^^
  File "/user/miniconda3/lib/python3.12/site-packages/triton/runtime/interpreter.py", line 1108, in run
    return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/miniconda3/lib/python3.12/site-packages/triton/runtime/interpreter.py", line 1082, in __call__
    args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/miniconda3/lib/python3.12/inspect.py", line 1583, in getcallargs
    _missing_arguments(f_name, req, True, arg2value)
  File "/user/miniconda3/lib/python3.12/inspect.py", line 1512, in _missing_arguments
    raise TypeError("%s() missing %i required %s argument%s: %s" %
TypeError: simple_kernel() missing 1 required positional argument: 'num_stages'

Possible direct cause

In triton/runtime/interpreter.py, the interpreter filters the launch kwargs before calling inspect.getcallargs:

kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}

This removes all reserved keyword arguments, including ones that the kernel itself declares as parameters (num_stages here).
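
A standalone sketch of the failure mode, outside Triton and with a hypothetical RESERVED_KWS subset rather than Triton's real list: once a name that the kernel itself declares is filtered out of kwargs, inspect.getcallargs reports a missing required argument, which is exactly the TypeError in the traceback above.

import inspect

# Illustrative subset only; not Triton's actual RESERVED_KWS definition.
RESERVED_KWS = ["num_warps", "num_stages", "num_ctas"]

# Stand-in for the jitted kernel's Python signature.
def simple_kernel(a_ptr, out_ptr, n_cols, BLOCK_SIZE, num_stages):
    pass

kwargs = {"a_ptr": 0, "out_ptr": 0, "n_cols": 12, "BLOCK_SIZE": 4, "num_stages": 2}

# Same shape of filter as above: drops num_stages even though simple_kernel requires it.
filtered = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
inspect.getcallargs(simple_kernel, **filtered)
# TypeError: simple_kernel() missing 1 required positional argument: 'num_stages'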

Possible solution

        req_args = inspect.getfullargspec(self.fn)[0]
        kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS or k in req_args}
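
A quick, self-contained check of this proposed filter (again using a hypothetical RESERVED_KWS subset and a stand-in signature): a reserved name now survives whenever the kernel's own signature declares it, so inspect.getcallargs succeeds and num_stages reaches the kernel.

import inspect

RESERVED_KWS = ["num_warps", "num_stages", "num_ctas"]  # illustrative subset

def simple_kernel(a_ptr, out_ptr, n_cols, BLOCK_SIZE, num_stages):
    pass

kwargs = {"a_ptr": 0, "out_ptr": 0, "n_cols": 12, "BLOCK_SIZE": 4, "num_stages": 2}

# Keep a reserved kwarg if the kernel's signature explicitly declares it.
req_args = inspect.getfullargspec(simple_kernel)[0]
filtered = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS or k in req_args}
print(inspect.getcallargs(simple_kernel, **filtered))
# {'a_ptr': 0, 'out_ptr': 0, 'n_cols': 12, 'BLOCK_SIZE': 4, 'num_stages': 2}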

Environment details

Triton: 3.0.0
GPU: Tesla V100-PCIE-32GB
Python: 3.12
