[Frontend] Validate minimum TMA block size in the frontend (#5611)
The current TMA implementation requires the TMA block shape to have a
minimum size, which is the swizzle tile size in smem. Add an error
message for this in the frontend so users stop running into codegen
crashes.
Mogball authored Jan 16, 2025
1 parent b4f89c9 commit 94f80f4
Showing 2 changed files with 43 additions and 0 deletions.
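
For context on the numbers in this patch: the shared-memory swizzle tile is 8 rows by 32 bytes, so the smallest legal TMA block is 8 rows tall and wide enough to cover 32 bytes per row. A minimal sketch of that arithmetic in Python (illustrative only; the dtype bitwidths below are assumed, and validate_descriptor_block in the diff is the authoritative check):

# Minimum TMA block shape implied by the swizzle tile (8 rows x 32 bytes),
# mirroring the formula used in validate_descriptor_block below.
def min_block_shape(primitive_bitwidth):
    min_rows = 8
    min_cols = 32 // primitive_bitwidth * 8  # 32-byte swizzle row / element size
    return (min_rows, min_cols)

# Assumed bitwidths for a few common dtypes, for illustration.
for name, bits in [("int8", 8), ("float16", 16), ("int16", 16), ("float32", 32)]:
    print(f"{name:8s} -> minimum block {min_block_shape(bits)}")
# int8     -> minimum block (8, 32)
# float16  -> minimum block (8, 16)
# int16    -> minimum block (8, 16)
# float32  -> minimum block (8, 8)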
28 changes: 28 additions & 0 deletions python/test/unit/language/test_core.py
@@ -4733,6 +4733,34 @@ def kernel():
    assert "reshape" in str(exc_info.value)


def test_tma_load_block_shape_err():

    @triton.jit
    def kernel(ptr):
        desc = tl._experimental_make_tensor_descriptor(ptr, [128, 128], [128, 1], [1, 32])
        desc.load([0, 0])

    input = torch.empty((128, 128), dtype=torch.int32, device='cuda')
    with pytest.raises(triton.CompilationError) as e:
        kernel[(1, )](input)

    assert "tensor descriptor block shape must have at least 8 rows" in str(e.value.__cause__)


def test_tma_store_block_shape_err():

    @triton.jit
    def kernel(ptr):
        desc = tl._experimental_make_tensor_descriptor(ptr, [128, 128], [128, 1], [8, 8])
        desc.store([0, 0], tl.zeros((1, 32), dtype=tl.int16))

    input = torch.empty((128, 128), dtype=torch.int16, device='cuda')
    with pytest.raises(triton.CompilationError) as e:
        kernel[(1, )](input)

    assert "int16 tensor descriptor block shape must have at least 16 columns" in str(e.value.__cause__)


def test_trans_reshape(device):

    @triton.jit
15 changes: 15 additions & 0 deletions python/triton/language/semantic.py
@@ -1147,9 +1147,22 @@ def reinterpret_tensor_descriptor(desc_ptr: tl.tensor, block_ty: tl.block_type,
    return tl._experimental_tensor_descriptor_base(handle, block_ty)


def validate_descriptor_block(shape, dtype):
    if len(shape) != 2:
        return
    # Due to limitations of the shared memory encoding, the TMA bounding box has
    # to be at least as big as the swizzle tile.
    assert shape[0] >= 8, f"tensor descriptor block shape must have at least 8 rows, but got {shape[0]}"
    min_cols = 32 // dtype.primitive_bitwidth * 8
    assert shape[1] >= min_cols, \
        f"{dtype} tensor descriptor block shape must have at least {min_cols} columns, but got {shape[1]}"


def descriptor_load(desc: tl._experimental_tensor_descriptor_base, offsets, cache_modifier: str, eviction_policy: str,
                    builder: ir.builder) -> tl.tensor:
    assert isinstance(desc, tl._experimental_tensor_descriptor_base)
    validate_descriptor_block(desc.block_shape, desc.type.element_ty)

    offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
    x = builder.create_descriptor_load(desc.handle, offsets, _str_to_load_cache_modifier(cache_modifier),
                                       _str_to_eviction_policy(eviction_policy))
@@ -1159,6 +1172,8 @@ def descriptor_load(desc: tl._experimental_tensor_descriptor_base, offsets, cache
def descriptor_store(desc: tl._experimental_tensor_descriptor_base, value: tl.tensor, offsets,
                     builder: ir.builder) -> tl.tensor:
    assert isinstance(desc, tl._experimental_tensor_descriptor_base)
    validate_descriptor_block(desc.block_shape, desc.type.element_ty)

    offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
    return tl.tensor(builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void)

