[INTERPRETER] Implement implicit tensor conversion for assignment operators (#4214)
Jokeren authored Jun 28, 2024
1 parent 938e388 commit 1b35f11
Showing 7 changed files with 159 additions and 40 deletions.
2 changes: 0 additions & 2 deletions docs/programming-guide/chapter-3/debugging.rst
@@ -70,8 +70,6 @@ The interpreter has several known limitations:
ptr = tl.load(ptr)
x = tl.load(ptr)
- Unlike the compilation mode, a scalar in interpreter mode is treated as a simple float or integer but not as a 0-d tensor. This means it lacks tensor attributes such as :code:`x.dtype`. A workaround is to explicitly convert the scalar to a tensor using :code:`tl.to_tensor(x)`, where :code:`x` is the scalar.

----------------------------
Using Third-party Tools
----------------------------
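The limitation removed above is what this commit implements: under the interpreter, a plain Python scalar assigned inside a kernel is now converted to a 0-d tensor, so attributes such as .dtype work without an explicit tl.to_tensor call. A minimal illustrative sketch (not part of the commit), assuming a CUDA device; with TRITON_INTERPRET=1 the same kernel now also runs under the interpreter:

import torch
import triton
import triton.language as tl

@triton.jit
def fill_kernel(out_ptr, BLOCK: tl.constexpr):
    value = 1                                                  # plain Python scalar ...
    block = tl.full([BLOCK], value=value, dtype=value.dtype)   # ... yet value.dtype now works in interpreter mode too
    tl.store(out_ptr + tl.arange(0, BLOCK), block)

out = torch.zeros(16, dtype=torch.int32, device='cuda')
fill_kernel[(1, )](out, BLOCK=16)                              # out is now all ones (int32)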
33 changes: 27 additions & 6 deletions python/test/unit/language/test_core.py
@@ -1702,6 +1702,27 @@ def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
assert torch.all(output == ref)


@pytest.mark.interpreter
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_store_constant_default_dtype(num_ctas, device):
"""Tests that boolean True is stored as 1"""

@triton.jit
def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
value = 1
output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype)
tl.store(output_ptr + offsets, output, mask=mask)

block_size = 128
ref = torch.ones([block_size], dtype=getattr(torch, 'int32'), device=device)
output = torch.zeros([block_size], dtype=getattr(torch, 'int32'), device=device)
kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas)

assert torch.all(output == ref)
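This test carries @pytest.mark.interpreter, so it is also intended to run with the interpreter enabled via the environment variable checked by the decorator; an illustrative invocation might look like:

TRITON_INTERPRET=1 pytest python/test/unit/language/test_core.py -k test_store_constant_default_dtype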


def test_load_store_same_ptr(device):

@triton.jit()
@@ -5342,12 +5363,12 @@ def test_tl_range(device):
torch.testing.assert_close(ref_out, c, rtol=1e-2, atol=1e-1)
else:
torch.testing.assert_close(ref_out, c, rtol=1e-3, atol=1e-3)
if device in ['cuda']:
capability = torch.cuda.get_device_capability()
if capability[0] >= 8:
ptx = pgm.asm['ptx']
# check that the loop got pipelined with the right number of stages.
assert 'cp.async.wait_group 0x6' in ptx
if device in ['cuda']:
capability = torch.cuda.get_device_capability()
if capability[0] >= 8:
ptx = pgm.asm['ptx']
# check that the loop got pipelined with the right number of stages.
assert 'cp.async.wait_group 0x6' in ptx


@triton.jit(noinline=True)
37 changes: 34 additions & 3 deletions python/test/unit/language/test_line_info.py
@@ -118,9 +118,8 @@ def check_file_lines(file_lines, file_name, lineno, should_contain=True):
should_contain: whether the file name and line number should be in the file_lines
"""
for file, line in file_lines:
if lineno == -1:
if file_name in file:
return True
if lineno == -1 and file_name in file:
return True
if file_name in file and str(lineno) in line:
return should_contain
return not should_contain
@@ -169,3 +168,35 @@ def test_line_info(func: str):
elif func == "dot_combine":
assert (check_file_lines(file_lines, "test_line_info.py", 65))
assert (check_file_lines(file_lines, "test_line_info.py", 66, should_contain=False))


def is_interpreter():
import os
return os.environ.get('TRITON_INTERPRET', '0') == '1'


@pytest.mark.interpreter
@pytest.mark.parametrize("func", func_types)
def test_line_info_interpreter(func: str):
if not is_interpreter():
pytest.skip("interpreter is not enabled")

kernel = None
expected_offset = 0
if func == "single":
kernel = kernel_single
expected_offset = 12
elif func == "call":
kernel = kernel_call
expected_offset = 25
elif func == "call_noinline":
kernel = kernel_call_noinline
expected_offset = 41
elif func == "autotune":
kernel = kernel_autotune.fn
expected_offset = 52
elif func == "dot_combine":
kernel = kernel_dot_combine
expected_offset = 62
kernel._rewrite_ast()
assert kernel.ast_transformer.offset == expected_offset
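The expected offsets (12, 25, 41, 52, 62) are the lines in test_line_info.py on which each kernel's def statement sits, as reported by get_jit_fn_file_line; _rewrite_ast stores that value on the shared ast_transformer, whose generic_visit then shifts every rewritten node's line numbers by it so line information still maps back to the original source file.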
24 changes: 3 additions & 21 deletions python/triton/compiler/code_generator.py
@@ -10,7 +10,7 @@
from .._C.libtriton import ir
from ..language import constexpr, tensor, str_to_ty
from ..language.core import _unwrap_if_constexpr
from ..runtime.jit import _normalize_ty
from ..runtime.jit import _normalize_ty, get_jit_fn_file_line
# ideally we wouldn't need any runtime component
from ..runtime import JITFunction
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
@@ -73,24 +73,6 @@ def _check_fn_args(node, fn, args):
)


def _get_fn_file_line(fn):
base_fn = fn
while not isinstance(base_fn, JITFunction):
base_fn = base_fn.fn
file_name = base_fn.fn.__code__.co_filename
lines, begin_line = inspect.getsourcelines(base_fn.fn)
# Match the following pattern:
# @triton.autotune(...) <- foo.__code__.co_firstlineno
# @triton.heuristics(...)
# @triton.jit
# def foo(...): <- this line is the first line
for idx, line in enumerate(lines):
if line.strip().startswith("def "):
begin_line += idx
break
return file_name, begin_line


_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels


@@ -1059,7 +1041,7 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
prototype = language.function_type([], arg_types)
gscope = fn.__globals__
# If the callee is not set, we use the same debug setting as the caller
file_name, begin_line = _get_fn_file_line(fn)
file_name, begin_line = get_jit_fn_file_line(fn)
debug = self.debug if fn.debug is None else fn.debug
generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module,
jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types,
@@ -1282,7 +1264,7 @@ def ast_to_ttir(fn, specialization, context, options, codegen_fns):
all_constants = constants.copy()
all_constants.update(new_constants)
arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants]
file_name, begin_line = _get_fn_file_line(fn)
file_name, begin_line = get_jit_fn_file_line(fn)

prototype = language.function_type([], arg_types)
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name,
8 changes: 5 additions & 3 deletions python/triton/language/core.py
@@ -115,7 +115,7 @@ def to_tensor(x, _builder=None):
return _to_tensor(x, _builder)


def _to_tensor(x, builder):
def _to_tensor(x, builder, check_type: bool = True):
if isinstance(x, bool):
return tensor(builder.get_int1(x), int1)
# Note: compile-time const integers are represented by unsigned values
@@ -129,7 +129,7 @@ def _to_tensor(x, builder):
elif 2**63 <= x < 2**64:
return tensor(builder.get_uint64(x), uint64)
else:
raise RuntimeError(f'Nonrepresentable integer {x}.')
raise ValueError(f'Nonrepresentable integer {x}.')
elif isinstance(x, float):
min_float32 = 2**-126
max_float32 = (2 - 2**-23) * 2**127
@@ -146,7 +146,9 @@ def _to_tensor(x, builder):
return _to_tensor(x.value, builder)
elif isinstance(x, tensor):
return x
assert False, f"cannot convert {x} of type {type(x)} to tensor"
if check_type:
raise TypeError(f"cannot convert {x} of type {type(x)} to tensor")
return x


# -----------------------
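Two behavioral changes are visible in these hunks: out-of-range integers now raise ValueError instead of RuntimeError, and a value that cannot be converted either raises TypeError or, when check_type=False (the mode used by the interpreter's rewritten assignments), is returned unchanged. A rough summary of the conversion rules, written as illustrative comments (builder is an IR builder in compiled mode or interpreter_builder under the interpreter):

# _to_tensor(True,  builder)                       -> int1 tensor
# _to_tensor(1,     builder)                       -> int32 tensor (the default integer dtype the new test asserts)
# _to_tensor(2**64, builder)                       -> ValueError: Nonrepresentable integer
# _to_tensor(object(), builder)                    -> TypeError (check_type defaults to True)
# _to_tensor(object(), builder, check_type=False)  -> the object, returned unchanged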
74 changes: 70 additions & 4 deletions python/triton/runtime/interpreter.py
@@ -1,3 +1,5 @@
import ast
import textwrap
import inspect
from typing import Tuple

@@ -1094,30 +1096,94 @@ def __call__(self, *args_dev, **kwargs):
self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst)


class ASTTransformer(ast.NodeTransformer):

def __init__(self) -> None:
self.offset = 0

def visit_Assign(self, node):
names = []
for target in node.targets:
names += [self.visit(target)]
if len(names) > 1:
raise ValueError("Multiple assignments are not supported")
# Rewrite the assignment x = value as
# x = triton.language.core._to_tensor(value, interpreter_builder, False)
node.value = ast.Call(
func=ast.Attribute(
value=ast.Attribute(
value=ast.Attribute(value=ast.Name(id='triton', ctx=ast.Load()), attr='language', ctx=ast.Load()),
attr='core', ctx=ast.Load()), attr='_to_tensor', ctx=ast.Load()),
args=[node.value, ast.Name(id='interpreter_builder', ctx=ast.Load()),
ast.Constant(value=False)], keywords=[])
return node

def generic_visit(self, node):
# Shift the node's line numbers (begin and end) by the offset of the original source
if hasattr(node, 'lineno') and node.lineno is not None:
node.lineno += self.offset
if hasattr(node, 'end_lineno') and node.end_lineno is not None:
node.end_lineno += self.offset
return super().generic_visit(node)
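In source terms, visit_Assign rewrites every single-target assignment in the kernel body so its right-hand side passes through _to_tensor with type checking disabled; a sketch of the before/after (the variable name is illustrative):

# before (user kernel source)
value = 1
# after the AST rewrite (source-level equivalent)
value = triton.language.core._to_tensor(1, interpreter_builder, False)
# scalars become 0-d tensors, existing tensors come back unchanged, and with
# check_type=False anything else is passed through untouched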


class InterpretedFunction:
rewritted_fn = {}
ast_transformer = ASTTransformer()

def __init__(self, fn) -> None:
def __init__(self, fn, **kwargs) -> None:
self.fn = fn

def run(*args, **kwargs):
grid = kwargs["grid"]
return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs)
fn = self._rewrite_ast()
return GridExecutor(fn, self.arg_names, grid)(*args, **kwargs)

self.run = run
self.kwargs = kwargs
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]

def _rewrite_ast(self):
if self.fn in self.rewritted_fn:
return self.rewritted_fn[self.fn]
# If an exception is raised, the function has no source code available
# (e.g., it was generated dynamically), so we cannot rewrite it and just return the original function
try:
lines, lineno = inspect.getsourcelines(self.fn)
except Exception:
self.rewritted_fn[self.fn] = self.fn
return self.fn
from .jit import get_jit_fn_file_line, JITFunction
filename, lineno = get_jit_fn_file_line(JITFunction(self.fn))
src = ''.join(lines)
src = textwrap.dedent(src)
parsed_ast = ast.parse(src)
self.ast_transformer.offset = lineno
transformed_ast = self.ast_transformer.visit(parsed_ast)
transformed_ast = ast.fix_missing_locations(transformed_ast)
compiled_code = compile(transformed_ast, filename=filename, mode='exec')
local_namespace = {**self.kwargs}
if self.fn.__name__ in local_namespace:
raise ValueError(f"Function name {self.fn.__name__} is reserved")
exec(compiled_code, globals(), local_namespace)
fn = local_namespace[self.fn.__name__].fn
self.rewritted_fn[self.fn] = fn
return fn

@property
def __name__(self):
return self.fn.__name__

def __getitem__(self, grid):
return GridExecutor(self.fn, self.arg_names, grid)
fn = self._rewrite_ast()
return GridExecutor(fn, self.arg_names, grid)

def __call__(self, *args, **kwargs):
# This is a device function call
_patch_lang(self.fn)
fn = self._rewrite_ast()
try:
return self.fn(*args, **kwargs)
return fn(*args, **kwargs)
except Exception as e:
raise InterpreterError(repr(e)) from e
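_rewrite_ast follows the standard source-rewriting recipe: fetch the source with inspect, dedent and parse it, run a NodeTransformer over the tree, fix up node locations, then compile and exec the result and pull the rebuilt function out of the namespace. A self-contained sketch of that recipe outside Triton (wrap_rhs is a made-up stand-in for _to_tensor):

import ast
import inspect
import textwrap

def wrap_rhs(value):
    # Stand-in for _to_tensor: tag every assigned value so the rewrite is visible.
    return ("wrapped", value)

class RewriteAssign(ast.NodeTransformer):
    def visit_Assign(self, node):
        # x = <expr>  ->  x = wrap_rhs(<expr>)
        node.value = ast.Call(func=ast.Name(id="wrap_rhs", ctx=ast.Load()),
                              args=[node.value], keywords=[])
        return node

def rewrite(fn):
    src = textwrap.dedent(inspect.getsource(fn))
    tree = ast.fix_missing_locations(RewriteAssign().visit(ast.parse(src)))
    code = compile(tree, filename="<rewritten>", mode="exec")
    namespace = {"wrap_rhs": wrap_rhs}
    exec(code, namespace)
    return namespace[fn.__name__]

def f():
    x = 1
    return x

print(rewrite(f)())  # ('wrapped', 1)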
21 changes: 20 additions & 1 deletion python/triton/runtime/jit.py
@@ -847,7 +847,8 @@ def decorator(fn: T) -> JITFunction[T]:
assert callable(fn)
if os.getenv("TRITON_INTERPRET", "0") == "1":
from .interpreter import InterpretedFunction
return InterpretedFunction(fn)
return InterpretedFunction(fn, version=version, do_not_specialize=do_not_specialize, debug=debug,
noinline=noinline, repr=repr, launch_metadata=launch_metadata)
else:
return JITFunction(
fn,
@@ -935,3 +936,21 @@ def reinterpret(tensor, dtype):
return TensorWrapper(tensor, dtype)
else:
raise TypeError(f"Cannot reinterpret a {type(tensor)}.")


def get_jit_fn_file_line(fn):
base_fn = fn
while not isinstance(base_fn, JITFunction):
base_fn = base_fn.fn
file_name = base_fn.fn.__code__.co_filename
lines, begin_line = inspect.getsourcelines(base_fn.fn)
# Match the following pattern:
# @triton.autotune(...) <- foo.__code__.co_firstlineno
# @triton.heuristics(...)
# @triton.jit
# def foo(...): <- this line is the first line
for idx, line in enumerate(lines):
if line.strip().startswith("def "):
begin_line += idx
break
return file_name, begin_line
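This helper (moved here from code_generator.py so both the compiler and the interpreter can reuse it) returns the line of the def statement rather than the line reported by inspect, which for stacked decorators is the first decorator. A minimal standalone illustration of the same scan (deco is a stand-in decorator, not a Triton API):

import inspect

def deco(fn):  # stand-in for @triton.autotune / @triton.heuristics / @triton.jit
    return fn

@deco
@deco
def foo():
    pass

lines, begin_line = inspect.getsourcelines(foo)  # begin_line starts at the first decorator
for idx, line in enumerate(lines):
    if line.strip().startswith("def "):
        begin_line += idx  # advance to the line of "def foo():"
        break
print(begin_line)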
