Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Validate minimum TMA block size in the frontend #5611

Merged
merged 6 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4732,6 +4732,34 @@ def kernel():
assert "reshape" in str(exc_info.value)


def test_tma_load_block_shape_err():

@triton.jit
def kernel(ptr):
desc = tl._experimental_make_tensor_descriptor(ptr, [128, 128], [128, 1], [1, 32])
desc.load([0, 0])

input = torch.empty((128, 128), dtype=torch.int32, device='cuda')
with pytest.raises(triton.CompilationError) as e:
kernel[(1, )](input)

assert "tensor descriptor block shape must have at least 8 rows" in str(e.value.__cause__)
peterbell10 marked this conversation as resolved.
Show resolved Hide resolved


def test_tma_store_block_shape_err():

@triton.jit
def kernel(ptr):
desc = tl._experimental_make_tensor_descriptor(ptr, [128, 128], [128, 1], [8, 8])
desc.store([0, 0], tl.zeros((1, 32), dtype=tl.int16))

input = torch.empty((128, 128), dtype=torch.int16, device='cuda')
with pytest.raises(triton.CompilationError) as e:
kernel[(1, )](input)

assert "int16 tensor descriptor block shape must have at least 16 columns" in str(e.value.__cause__)


def test_trans_reshape(device):

@triton.jit
Expand Down
15 changes: 15 additions & 0 deletions python/triton/language/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1147,9 +1147,22 @@ def reinterpret_tensor_descriptor(desc_ptr: tl.tensor, block_ty: tl.block_type,
return tl._experimental_tensor_descriptor_base(handle, block_ty)


def validate_descriptor_block(shape, dtype):
if len(shape) != 2:
return
# Due to limitations of the shared memory encoding, the TMA bounding box has
# to be at least as big as the swizzle tile.
assert shape[0] >= 8, f"tensor descriptor block shape must have at least 8 rows, but got {shape[0]}"
min_cols = 32 // dtype.primitive_bitwidth * 8
assert shape[
1] >= min_cols, f"{dtype} tensor descriptor block shape must have at least {min_cols} columns, but got {shape[1]}"


def descriptor_load(desc: tl._experimental_tensor_desciptor_base, offsets, cache_modifier: str, eviction_policy: str,
builder: ir.builder) -> tl.tensor:
assert isinstance(desc, tl._experimental_tensor_descriptor_base)
validate_descriptor_block(desc.block_shape, desc.type.element_ty)

offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
x = builder.create_descriptor_load(desc.handle, offsets, _str_to_load_cache_modifier(cache_modifier),
_str_to_eviction_policy(eviction_policy))
Expand All @@ -1159,6 +1172,8 @@ def descriptor_load(desc: tl._experimental_tensor_desciptor_base, offsets, cache
def descriptor_store(desc: tl._experimental_tensor_descriptor_base, value: tl.tensor, offsets,
builder: ir.builder) -> tl.tensor:
assert isinstance(desc, tl._experimental_tensor_descriptor_base)
validate_descriptor_block(desc.block_shape, desc.type.element_ty)

offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
return tl.tensor(builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void)

Expand Down
Loading