From 0a8b1a1593700bf9d0203764448a1c5445ea4a28 Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 14 Jan 2025 22:41:09 -0800 Subject: [PATCH 1/4] [Frontend] Validate minimum TMA block size in the frontend The current TMA implementation requires the TMA block shape to have a minimum size, which is the swizzle tile size in smem. Add an error message for this in the frontend so users stop running into codegen crashes. --- python/triton/language/semantic.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 7939d7f5af63..6377457c1a91 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1147,9 +1147,21 @@ def reinterpret_tensor_descriptor(desc_ptr: tl.tensor, block_ty: tl.block_type, return tl._experimental_tensor_descriptor_base(handle, block_ty) +def validate_descriptor_block(shape, dtype): + if len(shape) != 2: + return + # Due to limitations of the shared memory encoding, the TMA bounding box has + # to be at least as big as the swizzle tile. + assert shape[0] >= 8, f"tensor descriptor block shape must have at least 8 rows, but got {shape[0]}" + min_cols = 32 / dtype.primitive_bitwidth * 8 + assert shape[1] >= min_cols, f"{dtype} tensor descriptor block shape must have at least {min_cols} columns, but got {shape[1]}" + + def descriptor_load(desc: tl._experimental_tensor_desciptor_base, offsets, cache_modifier: str, eviction_policy: str, builder: ir.builder) -> tl.tensor: assert isinstance(desc, tl._experimental_tensor_descriptor_base) + validate_descriptor_block(desc.block_shape, desc.type.element_ty) + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) x = builder.create_descriptor_load(desc.handle, offsets, _str_to_load_cache_modifier(cache_modifier), _str_to_eviction_policy(eviction_policy)) @@ -1159,6 +1171,8 @@ def descriptor_load(desc: tl._experimental_tensor_desciptor_base, offsets, cache def descriptor_store(desc: tl._experimental_tensor_descriptor_base, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: assert isinstance(desc, tl._experimental_tensor_descriptor_base) + validate_descriptor_block(desc.block_shape, desc.type.element_ty) + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) return tl.tensor(builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void) From 5fc5052016d3a1cb4701382fb6b82315d4de3a20 Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 14 Jan 2025 22:43:17 -0800 Subject: [PATCH 2/4] fmt --- python/triton/language/semantic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 6377457c1a91..f156f4257242 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1154,7 +1154,8 @@ def validate_descriptor_block(shape, dtype): # to be at least as big as the swizzle tile. assert shape[0] >= 8, f"tensor descriptor block shape must have at least 8 rows, but got {shape[0]}" min_cols = 32 / dtype.primitive_bitwidth * 8 - assert shape[1] >= min_cols, f"{dtype} tensor descriptor block shape must have at least {min_cols} columns, but got {shape[1]}" + assert shape[ + 1] >= min_cols, f"{dtype} tensor descriptor block shape must have at least {min_cols} columns, but got {shape[1]}" def descriptor_load(desc: tl._experimental_tensor_desciptor_base, offsets, cache_modifier: str, eviction_policy: str, From ffc1b6643322edfec85c431238e5677c9060965f Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 15 Jan 2025 12:48:47 -0800 Subject: [PATCH 3/4] fix div --- python/triton/language/semantic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index f156f4257242..94cf3eaf0729 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1153,7 +1153,7 @@ def validate_descriptor_block(shape, dtype): # Due to limitations of the shared memory encoding, the TMA bounding box has # to be at least as big as the swizzle tile. assert shape[0] >= 8, f"tensor descriptor block shape must have at least 8 rows, but got {shape[0]}" - min_cols = 32 / dtype.primitive_bitwidth * 8 + min_cols = 32 // dtype.primitive_bitwidth * 8 assert shape[ 1] >= min_cols, f"{dtype} tensor descriptor block shape must have at least {min_cols} columns, but got {shape[1]}" From 90efa29313fcf25cef2ae8ba9743b702b1ea20cb Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 15 Jan 2025 17:11:46 -0800 Subject: [PATCH 4/4] add unit tests that check for CompilationError --- python/test/unit/language/test_core.py | 28 ++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index bcbf00e5c8dd..78827725a43d 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -4732,6 +4732,34 @@ def kernel(): assert "reshape" in str(exc_info.value) +def test_tma_load_block_shape_err(): + + @triton.jit + def kernel(ptr): + desc = tl._experimental_make_tensor_descriptor(ptr, [128, 128], [128, 1], [1, 32]) + desc.load([0, 0]) + + input = torch.empty((128, 128), dtype=torch.int32, device='cuda') + with pytest.raises(triton.CompilationError) as e: + kernel[(1, )](input) + + assert "tensor descriptor block shape must have at least 8 rows" in str(e.value.__cause__) + + +def test_tma_store_block_shape_err(): + + @triton.jit + def kernel(ptr): + desc = tl._experimental_make_tensor_descriptor(ptr, [128, 128], [128, 1], [8, 8]) + desc.store([0, 0], tl.zeros((1, 32), dtype=tl.int16)) + + input = torch.empty((128, 128), dtype=torch.int16, device='cuda') + with pytest.raises(triton.CompilationError) as e: + kernel[(1, )](input) + + assert "int16 tensor descriptor block shape must have at least 16 columns" in str(e.value.__cause__) + + def test_trans_reshape(device): @triton.jit