diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index 6ee00ee6342c..409a1ff10a78 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -14,18 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name,missing-function-docstring +# pylint: disable=invalid-name,missing-function-docstring,unused-variable """Intrinsics for tensorization on NVIDIA GPU.""" -from typing import Dict, Tuple - -from typing_extensions import Literal +from typing import Dict, Optional, Tuple, Literal +from tvm._ffi import register_func +from tvm.runtime import convert from tvm.script import tir as T from tvm.tir.function import PrimFunc - -from ..._ffi import register_func -from ...runtime import convert -from .. import Cast, IntImm, TensorIntrin +from tvm.tir import Cast, IntImm, TensorIntrin def shared_16x16_to_ldmatrix_32x8_layout(i, j): @@ -43,6 +40,12 @@ def shared_32x16_to_ldmatrix_32x16_layout(i, j): return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4 +def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id): + row = 8 * (local_id % 4 // 2) + (thread_id // 4) + col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2) + return row, col + + @register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout") def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind): i, j = ind[0], ind[1] @@ -59,70 +62,94 @@ def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind): HALF_WARP_expr = lift(HALF_WARP) -def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"): +def get_ldmatrix_intrin( + k_dim: int, + dtype: str, + matrix_name: Literal["A", "B"], + transposed: bool, + shared_scope: str = "shared", +): local_size = (M_DIM * k_dim) // WARP_SIZE - shared_offset = None + smem_offset = None index_map = None - if transposed: - assert is_b, "Transposed A matrix not supported" - - ldmatrix_col_major = is_b and not transposed + if matrix_name == "A": + transpose_in_ldmatrix = transposed + # transpose_layout_for_ldmatrix_input: Every thread loads 8 bytes data. This determines + # which 8 bytes every thread loads. + # If transpose_layout_for_ldmatrix_input is False, the load pattern is + # T0 T0 T0 T0 T16 T16 T16 T16 + # T1 T1 T1 T1 T17 T17 T17 T17 + # ... + # T8 T8 T8 T8 T24 T24 T24 T24 + # T9 T9 T9 T9 T25 T25 T25 T25 + # ... + # T15 T15 T15 T15 T31 T31 T31 T31 + # Otherwise, the load pattern is + # T0 T0 T0 T0 T8 T8 T8 T8 + # T1 T1 T1 T1 T9 T9 T9 T9 + # ... + # T7 T7 T7 T7 T15 T15 T15 T15 + # T16 T16 T16 T16 T24 T24 T24 T24 + # T17 T17 T17 T17 T25 T25 T25 T25 + # ... 
+ # T23 T23 T23 T23 T31 T31 T31 T31 + transpose_layout_for_ldmatrix_input = transposed + smem_tile_row, smem_tile_col = (M_DIM, k_dim) if not transposed else (k_dim, M_DIM) + else: + assert matrix_name == "B" + transpose_in_ldmatrix = not transposed + transpose_layout_for_ldmatrix_input = transposed + smem_tile_row, smem_tile_col = (k_dim, N_DIM) if not transposed else (N_DIM, k_dim) if k_dim == 16: assert dtype == "float16" index_map = shared_16x16_to_ldmatrix_32x8_layout - if transposed: - shared_offset = ( + if transpose_layout_for_ldmatrix_input: + smem_offset = ( lambda tx, stride: stride * 8 * (tx // HALF_WARP_expr) + stride * (tx % 8) + 8 * ((tx % HALF_WARP_expr) // 8) ) else: - shared_offset = lambda tx, stride: stride * (tx % HALF_WARP_expr) + 8 * ( + smem_offset = lambda tx, stride: stride * (tx % HALF_WARP_expr) + 8 * ( tx // HALF_WARP_expr ) else: + # TODO(yixin): Support TN and TT matmul for int8 + assert ( + matrix_name == "B" or not transposed + ), "Now only B matrix can be transposed for int8 matmul" assert ( k_dim == 32 and dtype == "int8" ), "Only k_dim == 16 (float16) or k_dim == 32 (int8) supported for now" - if ldmatrix_col_major: + if matrix_name == "B" and not transposed: index_map = shared_32x16_to_ldmatrix_32x16_layout # A dummy offset, ldmatrix cannot be used for int8 + trans case. # We still use the ldmatrix intrinsic, but lower it to a manual loop in the codegen. # Only the stride information is required. - shared_offset = lambda _, stride: stride - elif is_b and transposed: + smem_offset = lambda _, stride: stride + elif matrix_name == "B" and transposed: index_map = shared_16x32_to_ldmatrix_32x16_layout - shared_offset = ( + smem_offset = ( lambda tx, stride: stride * 8 * (tx // HALF_WARP_expr) + (tx % 8) * stride + 16 * ((tx % HALF_WARP_expr) // 8) ) - else: + else: # A, not transposed index_map = shared_16x32_to_ldmatrix_32x16_layout - shared_offset = lambda tx, stride: stride * (tx % 16) + 16 * (tx // 16) - - assert index_map and shared_offset - - if is_b and not transposed: - row_dim = k_dim - col_dim = M_DIM - else: - row_dim = M_DIM - col_dim = k_dim + smem_offset = lambda tx, stride: stride * (tx % 16) + 16 * (tx // 16) - shmem_shape = (row_dim, col_dim) - offset_factor = col_dim + offset_factor = smem_tile_col @T.prim_func def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None: shared = T.match_buffer( shared_handle, - shmem_shape, + (smem_tile_row, smem_tile_col), dtype, align=64, offset_factor=offset_factor, @@ -138,10 +165,10 @@ def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None: ) with T.block("root"): - T.reads(shared[0:row_dim, 0:col_dim]) + T.reads(shared[0:smem_tile_row, 0:smem_tile_col]) T.writes(warp[0:WARP_SIZE, 0:local_size]) - for ax0, ax1 in T.grid(row_dim, col_dim): + for ax0, ax1 in T.grid(smem_tile_row, smem_tile_col): with T.block("shared_warp"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) T.reads(shared[v0, v1]) @@ -156,7 +183,7 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None: s1 = T.int32() shared = T.match_buffer( shared_handle, - shmem_shape, + (smem_tile_row, smem_tile_col), dtype, align=64, offset_factor=offset_factor, @@ -173,28 +200,68 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None: ) with T.block("root"): - T.reads(shared[0:row_dim, 0:col_dim]) + T.reads(shared[0:smem_tile_row, 0:smem_tile_col]) T.writes(warp[0:WARP_SIZE, 0:local_size]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, WARP_SIZE) - - T.evaluate( - 
T.ptx_ldmatrix( - ldmatrix_col_major, - 4, # Always load 4 matrices - ".b16", - warp.data, - warp.elem_offset + lift(local_size) * tx, - shared.access_ptr("r"), - shared_offset(tx, s0), - dtype=dtype, + for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"): + T.evaluate( + T.ptx_ldmatrix( + transpose_in_ldmatrix, + 4, # Always load 4 matrices + ".b16", + warp.data, + warp.elem_offset + lift(local_size) * tx, + shared.access_ptr("r"), + smem_offset(tx, s0), + dtype=dtype, + ) ) - ) return ldmatrix_desc, ldmatrix_impl -def get_mma_intrin(k_dim, out_dtype, b_transposed): +LDMATRIX_f16_A_INTRIN = "mma_ldmatrix_f16_a" +TensorIntrin.register(LDMATRIX_f16_A_INTRIN, *get_ldmatrix_intrin(16, "float16", "A", False)) + +LDMATRIX_f16_B_INTRIN = "mma_ldmatrix_f16_b" +TensorIntrin.register(LDMATRIX_f16_B_INTRIN, *get_ldmatrix_intrin(16, "float16", "B", False)) + +LDMATRIX_f16_A_TRANS_INTRIN = "mma_ldmatrix_f16_a_trans" +TensorIntrin.register(LDMATRIX_f16_A_TRANS_INTRIN, *get_ldmatrix_intrin(16, "float16", "A", True)) + +LDMATRIX_f16_B_TRANS_INTRIN = "mma_ldmatrix_f16_b_trans" +TensorIntrin.register(LDMATRIX_f16_B_TRANS_INTRIN, *get_ldmatrix_intrin(16, "float16", "B", True)) + +LDMATRIX_f16_A_DYN_INTRIN = "mma_ldmatrix_f16_a_dyn" +TensorIntrin.register( + LDMATRIX_f16_A_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", "A", False, "shared.dyn") +) + +LDMATRIX_f16_B_DYN_INTRIN = "mma_ldmatrix_f16_b_dyn" +TensorIntrin.register( + LDMATRIX_f16_B_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", "B", False, "shared.dyn") +) + +LDMATRIX_f16_A_TRANS_DYN_INTRIN = "mma_ldmatrix_f16_a_trans_dyn" +TensorIntrin.register( + LDMATRIX_f16_A_TRANS_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", "A", True, "shared.dyn") +) + +LDMATRIX_f16_B_TRANS_DYN_INTRIN = "mma_ldmatrix_f16_b_trans_dyn" +TensorIntrin.register( + LDMATRIX_f16_B_TRANS_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", "B", True, "shared.dyn") +) + +LDMATRIX_i8_A_INTRIN = "mma_ldmatrix_i8_a" +TensorIntrin.register(LDMATRIX_i8_A_INTRIN, *get_ldmatrix_intrin(32, "int8", "A", False)) + +LDMATRIX_i8_B_INTRIN = "mma_ldmatrix_i8_b" +TensorIntrin.register(LDMATRIX_i8_B_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", False)) + +LDMATRIX_i8_B_TRANS_INTRIN = "mma_ldmatrix_i8_b_trans" +TensorIntrin.register(LDMATRIX_i8_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", True)) + + +def get_mma_intrin(k_dim, out_dtype, a_transposed, b_transposed): local_size = (M_DIM * k_dim) // WARP_SIZE local_size_out = (M_DIM * N_DIM) // 32 @@ -223,18 +290,16 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed): in_dtype = "int8" in_dtype_abbrv = "int8" - def maybe_cast(v): + def cast_to_out_dtype(v): if out_dtype in ["float32", "int32"]: return Cast(out_dtype, v) return v - def maybe_swap(i, j): - if b_transposed: - return j, i - return i, j + def swap_if_flag(i, j, flag): + return (j, i) if flag else (i, j) - A_offset_factor = k_dim - B_offset_factor = maybe_swap(k_dim, N_DIM)[-1] + A_offset_factor = M_DIM if a_transposed else k_dim + B_offset_factor = k_dim if b_transposed else N_DIM out_offset_factor = N_DIM @T.prim_func @@ -275,10 +340,11 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: for i, j, k in T.grid(M_DIM, N_DIM, k_dim): with T.block("C"): i, j, k = T.axis.remap("SSR", [i, j, k]) - b_row_ind, b_col_ind = T.meta_var(maybe_swap(k, j)) + a_row_ind, a_col_ind = T.meta_var(swap_if_flag(i, k, a_transposed)) + b_row_ind, b_col_ind = T.meta_var(swap_if_flag(k, j, b_transposed)) thread_id_C, local_id_C = T.meta_var(index_map_C(i, j)) - thread_id_A, 
local_id_A = T.meta_var(index_map_A(i, k)) + thread_id_A, local_id_A = T.meta_var(index_map_A(a_row_ind, a_col_ind)) thread_id_B, local_id_B = T.meta_var(index_map_B(b_row_ind, b_col_ind)) T.reads( @@ -288,9 +354,9 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: ) T.writes(C[thread_id_C, local_id_C]) - C[thread_id_C, local_id_C] += maybe_cast( + C[thread_id_C, local_id_C] += cast_to_out_dtype( A[thread_id_A, local_id_A] - ) * maybe_cast(B[thread_id_B, local_id_B]) + ) * cast_to_out_dtype(B[thread_id_B, local_id_B]) @T.prim_func def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: @@ -326,50 +392,84 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: B[0:WARP_SIZE, 0:local_size], ) T.writes(C[0:WARP_SIZE, 0:local_size_out]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, WARP_SIZE) - T.evaluate( - T.ptx_mma( - mma_prefix, - "row", - "col", - in_dtype_abbrv, - in_dtype_abbrv, - out_dtype_abbrv, - A.data, - A.elem_offset + tx * lift(local_size), - B.data, - B.elem_offset + tx * lift(local_size), - C.data, - C.elem_offset + tx * lift(local_size_out), - False, - dtype=out_dtype, + for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"): + T.evaluate( + T.ptx_mma( + mma_prefix, + "row", + "col", + in_dtype_abbrv, + in_dtype_abbrv, + out_dtype_abbrv, + A.data, + A.elem_offset + tx * lift(local_size), + B.data, + B.elem_offset + tx * lift(local_size), + C.data, + C.elem_offset + tx * lift(local_size_out), + False, + dtype=out_dtype, + ) ) - ) - T.evaluate( - T.ptx_mma( - mma_prefix, - "row", - "col", - in_dtype_abbrv, - in_dtype_abbrv, - out_dtype_abbrv, - A.data, - A.elem_offset + tx * lift(local_size), - B.data, - B.elem_offset + tx * lift(local_size) + lift(local_size) // 2, - C.data, - C.elem_offset + tx * lift(local_size_out) + lift(local_size_out) // 2, - False, - dtype=out_dtype, + T.evaluate( + T.ptx_mma( + mma_prefix, + "row", + "col", + in_dtype_abbrv, + in_dtype_abbrv, + out_dtype_abbrv, + A.data, + A.elem_offset + tx * lift(local_size), + B.data, + B.elem_offset + tx * lift(local_size) + lift(local_size) // 2, + C.data, + C.elem_offset + tx * lift(local_size_out) + lift(local_size_out) // 2, + False, + dtype=out_dtype, + ) ) - ) return mma_sync_desc, mma_sync_impl +MMA_f16f16f32_INTRIN = "mma_f16f16f32" +TensorIntrin.register(MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float32", False, False)) + +MMA_f16f16f32_TRANS_B_INTRIN = "mma_f16f16f32_trans_b" +TensorIntrin.register(MMA_f16f16f32_TRANS_B_INTRIN, *get_mma_intrin(16, "float32", False, True)) + +MMA_f16f16f32_TRANS_A_INTRIN = "mma_f16f16f32_trans_a" +TensorIntrin.register(MMA_f16f16f32_TRANS_A_INTRIN, *get_mma_intrin(16, "float32", True, False)) + +MMA_f16f16f32_TRANS_A_TRANS_B_INTRIN = "mma_f16f16f32_trans_a_trans_b" +TensorIntrin.register( + MMA_f16f16f32_TRANS_A_TRANS_B_INTRIN, *get_mma_intrin(16, "float32", True, True) +) + +MMA_f16f16f16_INTRIN = "mma_f16f16f16" +TensorIntrin.register(MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", False, False)) + +MMA_f16f16f16_TRANS_B_INTRIN = "mma_f16f16f16_trans_b" +TensorIntrin.register(MMA_f16f16f16_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", False, True)) + +MMA_f16f16f16_TRANS_A_INTRIN = "mma_f16f16f16_trans_a" +TensorIntrin.register(MMA_f16f16f16_TRANS_A_INTRIN, *get_mma_intrin(16, "float16", True, False)) + +MMA_f16f16f16_TRANS_A_TRANS_B_INTRIN = "mma_f16f16f16_trans_a_trans_b" +TensorIntrin.register( + MMA_f16f16f16_TRANS_A_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", True, True) +) + +MMA_i8i8i32_INTRIN = 
"mma_i8i8i32" +TensorIntrin.register(MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int32", False, False)) + +MMA_i8i8i32_TRANS_B_INTRIN = "mma_i8i8i32_trans_b" +TensorIntrin.register(MMA_i8i8i32_TRANS_B_INTRIN, *get_mma_intrin(32, "int32", False, True)) + + def get_mma_fill_intrin(dtype, local_size): zero = IntImm("int32", 0).astype(dtype) @@ -400,17 +500,27 @@ def mma_fill_impl(a: T.handle) -> None: with T.block("root"): T.reads() T.writes(C_warp[0:WARP_SIZE, 0:local_size]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, WARP_SIZE) - T.evaluate(T.mma_fill(local_size, C_warp.data, C_warp.elem_offset, dtype=dtype)) + for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"): + T.evaluate(T.mma_fill(local_size, C_warp.data, C_warp.elem_offset, dtype=dtype)) return mma_fill_desc, mma_fill_impl -def get_mma_store_intrin(dtype, local_size, scope="global"): +MMA_fill_16x16_f32_INTRIN = "mma_fill_16x16_f32" +TensorIntrin.register(MMA_fill_16x16_f32_INTRIN, *get_mma_fill_intrin("float32", 8)) + +MMA_fill_16x16_f16_INTRIN = "mma_fill_16x16_f16" +TensorIntrin.register(MMA_fill_16x16_f16_INTRIN, *get_mma_fill_intrin("float16", 8)) + +MMA_fill_16x16_i32_INTRIN = "mma_fill_16x16_i32" +TensorIntrin.register(MMA_fill_16x16_i32_INTRIN, *get_mma_fill_intrin("int32", 8)) + + +def get_mma_store_intrin(dtype, local_size, scope="global", use_mma_store_intrinic=True): # Assume M = N = 16 index_map = shared_16x16_to_ldmatrix_32x8_layout + index_map_rev = ldmatrix_32x8_to_shared_16x16_layout @T.prim_func def mma_store_desc(a: T.handle, c: T.handle) -> None: @@ -428,110 +538,183 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None: T.writes(C[v0, v1]) C[v0, v1] = C_warp[thread_id, local_id] - @T.prim_func - def mma_store_impl(a: T.handle, c: T.handle) -> None: - s0 = T.int32() - s1 = T.int32() + if use_mma_store_intrinic: - C_warp = T.match_buffer( - a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1 - ) - C = T.match_buffer( - c, [M_DIM, N_DIM], dtype=dtype, scope=scope, offset_factor=1, strides=[s0, s1] - ) + @T.prim_func + def mma_store_impl(a: T.handle, c: T.handle) -> None: + s0 = T.int32() + s1 = T.int32() - with T.block("root"): - T.reads(C_warp[0:WARP_SIZE, 0:local_size]) - T.writes(C[0:M_DIM, 0:N_DIM]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, WARP_SIZE) + C_warp = T.match_buffer( + a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1 + ) + C = T.match_buffer( + c, [M_DIM, N_DIM], dtype=dtype, scope=scope, offset_factor=1, strides=[s0, s1] + ) - T.evaluate( - T.mma_store( - M_DIM, - N_DIM, - C.access_ptr("w"), - C_warp.data, - C_warp.elem_offset, - s0, - dtype=dtype, - ) + with T.block("root"): + T.reads(C_warp[0:WARP_SIZE, 0:local_size]) + T.writes(C[0:M_DIM, 0:N_DIM]) + + for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"): + T.evaluate( + T.mma_store( + M_DIM, + N_DIM, + C.access_ptr("w"), + C_warp.data, + C_warp.elem_offset, + s0, + dtype=dtype, + ) + ) + + else: + + @T.prim_func + def mma_store_impl(a: T.handle, c: T.handle) -> None: + s0 = T.int32() + s1 = T.int32() + + C_warp = T.match_buffer( + a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1 + ) + C = T.match_buffer( + c, [M_DIM, N_DIM], dtype=dtype, scope=scope, offset_factor=1, strides=[s0, s1] ) + with T.block("root"): + T.reads(C_warp[0:WARP_SIZE, 0:local_size]) + T.writes(C[0:M_DIM, 0:N_DIM]) + + for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"): + for local_id in T.serial(local_size): + row, col = T.meta_var(index_map_rev(tx, local_id)) + C[row, col] = 
C_warp[tx, local_id] + return mma_store_desc, mma_store_impl -LDMATRIX_16x16_A_INTRIN = "mma.ldmatrix_16x16_a" -TensorIntrin.register(LDMATRIX_16x16_A_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False)) +MMA_store_16x16_f32_global_INTRIN = "mma_store_16x16_f32_global_" +TensorIntrin.register( + MMA_store_16x16_f32_global_INTRIN, *get_mma_store_intrin("float32", 8, "global", True) +) -LDMATRIX_16x16_B_INTRIN = "mma.ldmatrix_16x16_b" -TensorIntrin.register(LDMATRIX_16x16_B_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False)) +MMA_store_16x16_f32_shared_dyn_INTRIN = "mma_store_16x16_f32_shared_dyn_" +TensorIntrin.register( + MMA_store_16x16_f32_shared_dyn_INTRIN, *get_mma_store_intrin("float32", 8, "shared.dyn", True) +) -LDMATRIX_16x16_A_DYN_INTRIN = "mma.ldmatrix_16x16_a_dyn" +MMA_store_16x16_f32_shared_dyn_INTRIN_SIMPLE = "mma_store_16x16_f32_shared_dyn_simple_" TensorIntrin.register( - LDMATRIX_16x16_A_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False, "shared.dyn") + MMA_store_16x16_f32_shared_dyn_INTRIN_SIMPLE, + *get_mma_store_intrin("float32", 8, "shared.dyn", False), ) -LDMATRIX_16x16_B_DYN_INTRIN = "mma.ldmatrix_16x16_b_dyn" +MMA_store_16x16_f16_shared_dyn_INTRIN_SIMPLE = "mma_store_16x16_f16_shared_dyn_simple_" TensorIntrin.register( - LDMATRIX_16x16_B_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False, "shared.dyn") + MMA_store_16x16_f16_shared_dyn_INTRIN_SIMPLE, + *get_mma_store_intrin("float16", 8, "shared.dyn", False), ) -LDMATRIX_16x16_B_TRANS_INTRIN = "mma.ldmatrix_16x16_b_trans" +MMA_store_16x16_f16_global_INTRIN = "mma_store_16x16_f16_global_" TensorIntrin.register( - LDMATRIX_16x16_B_TRANS_INTRIN, *get_ldmatrix_intrin(16, "float16", True, True) + MMA_store_16x16_f16_global_INTRIN, *get_mma_store_intrin("float16", 8, "global", True) ) -LDMATRIX_16x32_A_INTRIN = "mma.ldmatrix_16x32_a" -TensorIntrin.register(LDMATRIX_16x32_A_INTRIN, *get_ldmatrix_intrin(32, "int8", False, False)) +MMA_store_16x16_i32_global_INTRIN = "mma_store_16x16_i32_global_" +TensorIntrin.register( + MMA_store_16x16_i32_global_INTRIN, *get_mma_store_intrin("int32", 8, "global", True) +) -LDMATRIX_32x16_B_INTRIN = "mma.ldmatrix_32x16_b" -TensorIntrin.register(LDMATRIX_32x16_B_INTRIN, *get_ldmatrix_intrin(32, "int8", True, False)) -LDMATRIX_16x32_B_TRANS_INTRIN = "mma.ldmatrix_16x32_b_trans" -TensorIntrin.register(LDMATRIX_16x32_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", True, True)) +def get_mma_intrin_group( + load_scope: Literal["shared", "shared.dyn"], + store_scope: Literal["global", "shared", "shared.dyn"], + in_dtype: Literal["float16", "int8"], + out_dtype: Literal["float16", "float32", "int32"], + trans_a: bool, + trans_b: bool, + not_use_mma_store_intrinic: bool = True, + store_to_smem_dtype: Optional[Literal["float16", "float32", "int32"]] = None, +) -> Dict[str, str]: + """Get a group of intrinsics for mma tensor core with the given configurations -MMA_f16f16f32_INTRIN = "mma_f16f16f32" -TensorIntrin.register(MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float32", False)) + Parameters + ---------- + load_scope : Literal["shared", "shared.dyn"] + The memory scope of the input buffer. -MMA_f16f16f32_TRANS_INTRIN = "mma_f16f16f32_trans" -TensorIntrin.register(MMA_f16f16f32_TRANS_INTRIN, *get_mma_intrin(16, "float32", True)) + store_scope : Literal["global", "shared", "shared.dyn"] + The memory scope of the result buffer. 
-MMA_f16f16f16_INTRIN = "mma_f16f16f16" -TensorIntrin.register(MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", False)) + in_dtype : str + The input data type. + + out_dtype : str + The output data dtype. -MMA_f16f16f16_TRANS_INTRIN = "mma_f16f16f16_trans" -TensorIntrin.register(MMA_f16f16f16_TRANS_INTRIN, *get_mma_intrin(16, "float16", True)) + trans_a : bool + Whether the input matrix A is transposed. -MMA_i8i8i32_INTRIN = "mma_i8i8i32" -TensorIntrin.register(MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int32", False)) + trans_b : bool + Whether the input matrix B is transposed. -MMA_i8i8i32_TRANS_INTRIN = "mma_i8i8i32_trans" -TensorIntrin.register(MMA_i8i8i32_TRANS_INTRIN, *get_mma_intrin(32, "int32", True)) + not_use_mma_store_intrinic : bool + Whether to not use the mma_store intrinsic. If True, use BufferStore stmts to store the + result of mma. Otherwise, use mma_store intrinsic. -MMA_fill_16x16_f32_INTRIN = "mma_fill_16x16_f32" -TensorIntrin.register(MMA_fill_16x16_f32_INTRIN, *get_mma_fill_intrin("float32", 8)) + This is because if we use mma_store intrinsic, during swizzling shared memory visits, our + rearrangement scheme will involve areas accessed by different mma_store calls. This makes + swizzling quite complex. But BufferStore will not face this problem. -MMA_fill_16x16_f16_INTRIN = "mma_fill_16x16_f16" -TensorIntrin.register(MMA_fill_16x16_f16_INTRIN, *get_mma_fill_intrin("float16", 8)) + store_to_smem_dtype : Optional[Literal["float16", "float32", "int32"]] + The dtype that we use to store from register to shared memory. By default it is out_dtype. -MMA_fill_16x16_i32_INTRIN = "mma_fill_16x16_i32" -TensorIntrin.register(MMA_fill_16x16_i32_INTRIN, *get_mma_fill_intrin("int32", 8)) + Returns + ------- + ret : Dict[str, str] + A group of tensor intrinsics. + """ + assert load_scope in ["shared", "shared.dyn"] + assert store_scope in ["global", "shared", "shared.dyn"] + assert in_dtype in ["float16", "int8"] + assert out_dtype in ["float16", "float32", "int32"] -MMA_store_16x16_f32_global_INTRIN = "mma_store_16x16_f32_global_" -TensorIntrin.register( - MMA_store_16x16_f32_global_INTRIN, *get_mma_store_intrin("float32", 8, "global") -) + shape = "16x16" -MMA_store_16x16_f16_global_INTRIN = "mma_store_16x16_f16_global_" -TensorIntrin.register( - MMA_store_16x16_f16_global_INTRIN, *get_mma_store_intrin("float16", 8, "global") -) + dtype_mapping = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"} + in_dtype = dtype_mapping[in_dtype] + out_dtype = dtype_mapping[out_dtype] -MMA_store_16x16_i32_global_INTRIN = "mma_store_16x16_i32_global_" -TensorIntrin.register( - MMA_store_16x16_i32_global_INTRIN, *get_mma_store_intrin("int32", 8, "global") -) + # e.g. mma_fill_16x16_f32 + init_intrin = f"mma_fill_{shape}_{out_dtype}" + + # e.g. mma_ldmatrix_f16_a_trans_dyn, mma_ldmatrix_f16_b_trans_dyn + trans_a = "_trans" if trans_a else "" + trans_b = "_trans" if trans_b else "" + load_scope = "_dyn" if load_scope == "shared.dyn" else "" + load_a_intrin = f"mma_ldmatrix_{in_dtype}_a{trans_a}{load_scope}" + load_b_intrin = f"mma_ldmatrix_{in_dtype}_b{trans_b}{load_scope}" + + # e.g. mma_f16f16f32_trans_a_trans_b + trans_a_str = trans_a + "_a" if trans_a != "" else "" + trans_b_str = trans_b + "_b" if trans_b != "" else "" + compute_intrin = f"mma_{in_dtype}{in_dtype}{out_dtype}{trans_a_str}{trans_b_str}" + + # e.g. 
mma_store_16x16_f32_shared_dyn_simple_ + store_scope = store_scope.replace(".", "_") + store_to_smem_dtype = dtype_mapping[store_to_smem_dtype] if store_to_smem_dtype else out_dtype + suffix = "simple_" if not_use_mma_store_intrinic else "" + store_intrin = f"mma_store_{shape}_{store_to_smem_dtype}_{store_scope}_{suffix}" + + return { + "init": init_intrin, + "load_a": load_a_intrin, + "load_b": load_b_intrin, + "compute": compute_intrin, + "store": store_intrin, + } ######## WMMA intrinsics ######## @@ -1235,11 +1418,11 @@ def mma_init_impl(c: T.handle) -> None: with T.block("root"): T.reads() T.writes(dst[0:m_dim, 0:n_dim]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - for b in range(m_dim // 8): - for v in T.vectorized(n_dim // 4): - dst[b * 8 + tx // 4, (tx % 4) * (n_dim // 4) + v] = zero + + for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"): + for b in range(m_dim // 8): + for v in T.vectorized(n_dim // 4): + dst[b * 8 + tx // 4, (tx % 4) * (n_dim // 4) + v] = zero return mma_init_desc, mma_init_impl @@ -1310,21 +1493,19 @@ def mma_load_impl(a: T.handle, c: T.handle) -> None: T.reads(src[0:frag_m, 0:frag_n]) T.writes(dst[0:frag_m, 0:frag_n]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - - T.evaluate( - T.ptx_ldmatrix( - trans, - 4, # Always load 4 matrices - ".b16", - dst.data, - get_index(dst.elem_offset, d0), - src.access_ptr("r"), - get_tx_index(tx, s0), - dtype=dtype, + for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"): + T.evaluate( + T.ptx_ldmatrix( + trans, + 4, # Always load 4 matrices + ".b16", + dst.data, + get_index(dst.elem_offset, d0), + src.access_ptr("r"), + get_tx_index(tx, s0), + dtype=dtype, + ) ) - ) return mma_load_desc, mma_load_impl diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py similarity index 88% rename from tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py rename to tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py index 2a853a24318c..d704dc243891 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py +++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py @@ -22,21 +22,21 @@ from tvm import te from tvm.testing.tir import mma_schedule from tvm.tir.tensor_intrin.cuda import ( - LDMATRIX_16x16_A_INTRIN, - LDMATRIX_16x16_B_INTRIN, - LDMATRIX_16x16_B_TRANS_INTRIN, - LDMATRIX_16x32_A_INTRIN, - LDMATRIX_16x32_B_TRANS_INTRIN, - LDMATRIX_32x16_B_INTRIN, + LDMATRIX_f16_A_INTRIN, + LDMATRIX_f16_B_INTRIN, + LDMATRIX_f16_B_TRANS_INTRIN, + LDMATRIX_i8_A_INTRIN, + LDMATRIX_i8_B_TRANS_INTRIN, + LDMATRIX_i8_B_INTRIN, MMA_f16f16f16_INTRIN, - MMA_f16f16f16_TRANS_INTRIN, + MMA_f16f16f16_TRANS_B_INTRIN, MMA_f16f16f32_INTRIN, - MMA_f16f16f32_TRANS_INTRIN, + MMA_f16f16f32_TRANS_B_INTRIN, MMA_fill_16x16_f16_INTRIN, MMA_fill_16x16_f32_INTRIN, MMA_fill_16x16_i32_INTRIN, MMA_i8i8i32_INTRIN, - MMA_i8i8i32_TRANS_INTRIN, + MMA_i8i8i32_TRANS_B_INTRIN, MMA_store_16x16_f16_global_INTRIN, MMA_store_16x16_f32_global_INTRIN, MMA_store_16x16_i32_global_INTRIN, @@ -116,15 +116,15 @@ def run_test( dev = tvm.device("cuda", 0) if in_dtype == "float16": - a_np = np.random.uniform(size=(M, K)).astype("float16") + a_np = np.random.normal(size=(M, K)).astype("float16") if b_transposed: - b_np = np.random.uniform(size=(N, K)).astype("float16") + b_np = np.random.normal(size=(N, K)).astype("float16") c_np = np.dot(a_np.astype("float32"), 
b_np.astype("float32").transpose()).astype( out_dtype ) else: - b_np = np.random.uniform(size=(K, N)).astype("float16") + b_np = np.random.normal(size=(K, N)).astype("float16") c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype) else: a_np = np.random.randint(-128, 128, (M, K)).astype("int8") @@ -147,7 +147,7 @@ def run_test( if out_dtype != "float16": # The numpy reference is computed with fp32 precision (otherwise too slow). # So there is non-trivial accuracy difference if TVM result is computed with fp16 accumulation. - tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-2, atol=1e-2) return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c) @@ -177,8 +177,8 @@ def index_map(i, j): index_map, index_map, index_map, - LDMATRIX_16x16_A_INTRIN, - LDMATRIX_16x16_B_INTRIN, + LDMATRIX_f16_A_INTRIN, + LDMATRIX_f16_B_INTRIN, MMA_f16f16f32_INTRIN, MMA_fill_16x16_f32_INTRIN, MMA_store_16x16_f32_global_INTRIN, @@ -198,9 +198,9 @@ def index_map(i, j): index_map, index_map, index_map, - LDMATRIX_16x16_A_INTRIN, - LDMATRIX_16x16_B_TRANS_INTRIN, - MMA_f16f16f32_TRANS_INTRIN, + LDMATRIX_f16_A_INTRIN, + LDMATRIX_f16_B_TRANS_INTRIN, + MMA_f16f16f32_TRANS_B_INTRIN, MMA_fill_16x16_f32_INTRIN, MMA_store_16x16_f32_global_INTRIN, ) @@ -234,8 +234,8 @@ def index_map(i, j): index_map, index_map, index_map, - LDMATRIX_16x16_A_INTRIN, - LDMATRIX_16x16_B_INTRIN, + LDMATRIX_f16_A_INTRIN, + LDMATRIX_f16_B_INTRIN, MMA_f16f16f16_INTRIN, MMA_fill_16x16_f16_INTRIN, MMA_store_16x16_f16_global_INTRIN, @@ -255,9 +255,9 @@ def index_map(i, j): index_map, index_map, index_map, - LDMATRIX_16x16_A_INTRIN, - LDMATRIX_16x16_B_TRANS_INTRIN, - MMA_f16f16f16_TRANS_INTRIN, + LDMATRIX_f16_A_INTRIN, + LDMATRIX_f16_B_TRANS_INTRIN, + MMA_f16f16f16_TRANS_B_INTRIN, MMA_fill_16x16_f16_INTRIN, MMA_store_16x16_f16_global_INTRIN, ) @@ -305,8 +305,8 @@ def index_map_C(i, j): index_map_A, index_map_B, index_map_C, - LDMATRIX_16x32_A_INTRIN, - LDMATRIX_32x16_B_INTRIN, + LDMATRIX_i8_A_INTRIN, + LDMATRIX_i8_B_INTRIN, MMA_i8i8i32_INTRIN, MMA_fill_16x16_i32_INTRIN, MMA_store_16x16_i32_global_INTRIN, @@ -326,9 +326,9 @@ def index_map_C(i, j): index_map_A, index_map_A, index_map_C, - LDMATRIX_16x32_A_INTRIN, - LDMATRIX_16x32_B_TRANS_INTRIN, - MMA_i8i8i32_TRANS_INTRIN, + LDMATRIX_i8_A_INTRIN, + LDMATRIX_i8_B_TRANS_INTRIN, + MMA_i8i8i32_TRANS_B_INTRIN, MMA_fill_16x16_i32_INTRIN, MMA_store_16x16_i32_global_INTRIN, ) diff --git a/tests/python/unittest/test_tir_schedule_tensorize_mfma.py b/tests/python/unittest/test_tir_schedule_tensorize_mfma_numeric.py similarity index 100% rename from tests/python/unittest/test_tir_schedule_tensorize_mfma.py rename to tests/python/unittest/test_tir_schedule_tensorize_mfma_numeric.py diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index a013cf0f65b8..bc3e979f94cd 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -26,8 +26,8 @@ from tvm.script import tir as T from tvm.testing.tir import mma_schedule from tvm.tir.tensor_intrin.cuda import ( - LDMATRIX_16x16_A_DYN_INTRIN, - LDMATRIX_16x16_B_DYN_INTRIN, + LDMATRIX_f16_A_DYN_INTRIN, + LDMATRIX_f16_B_DYN_INTRIN, MMA_f16f16f32_INTRIN, MMA_fill_16x16_f32_INTRIN, MMA_store_16x16_f32_global_INTRIN, @@ -1520,8 +1520,8 @@ def index_map(i, j): index_map, index_map, 
index_map, - LDMATRIX_16x16_A_DYN_INTRIN, - LDMATRIX_16x16_B_DYN_INTRIN, + LDMATRIX_f16_A_DYN_INTRIN, + LDMATRIX_f16_B_DYN_INTRIN, MMA_f16f16f32_INTRIN, MMA_fill_16x16_f32_INTRIN, MMA_store_16x16_f32_global_INTRIN,
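
Sanity-check note for the new inverse layout: `ldmatrix_32x8_to_shared_16x16_layout` added in this patch is intended to invert the existing `shared_16x16_to_ldmatrix_32x8_layout`, whose body is not shown in this diff. A minimal standalone sketch, runnable without TVM; the forward-map body below is reproduced from the current module and should be treated as an assumption:

    def shared_16x16_to_ldmatrix_32x8_layout(i, j):
        # Forward map assumed unchanged by this patch (copied from cuda.py).
        thread_id = 4 * (i % 8) + (j % 8) // 2
        return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2

    def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id):
        # Inverse map added by this patch.
        row = 8 * (local_id % 4 // 2) + (thread_id // 4)
        col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2)
        return row, col

    # Every (i, j) in the 16x16 tile should map to a distinct (thread_id, local_id)
    # pair and round-trip back to the same (i, j).
    covered = set()
    for i in range(16):
        for j in range(16):
            t, l = shared_16x16_to_ldmatrix_32x8_layout(i, j)
            assert 0 <= t < 32 and 0 <= l < 8
            assert ldmatrix_32x8_to_shared_16x16_layout(t, l) == (i, j)
            covered.add((t, l))
    assert len(covered) == 32 * 8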
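
Usage note for the new `get_mma_intrin_group` helper: it only composes and returns the names of intrinsics registered earlier in the patch. A hedged sketch of how a schedule might consume the returned dictionary; the schedule handle `sch` and the loop variables are hypothetical, for illustration only:

    from tvm.tir.tensor_intrin.cuda import get_mma_intrin_group

    # f16 x f16 -> f32, B transposed, operands loaded from dynamic shared
    # memory, using the "simple" BufferStore-based store path.
    group = get_mma_intrin_group(
        load_scope="shared.dyn",
        store_scope="shared.dyn",
        in_dtype="float16",
        out_dtype="float32",
        trans_a=False,
        trans_b=True,
        not_use_mma_store_intrinic=True,
    )
    # Following the naming scheme built inside the function, this yields:
    # {
    #     "init": "mma_fill_16x16_f32",
    #     "load_a": "mma_ldmatrix_f16_a_dyn",
    #     "load_b": "mma_ldmatrix_f16_b_trans_dyn",
    #     "compute": "mma_f16f16f32_trans_b",
    #     "store": "mma_store_16x16_f32_shared_dyn_simple_",
    # }

    # Hypothetical use inside a TensorIR schedule (sch and the loop handles are
    # assumptions, not part of this patch):
    # sch.tensorize(loop_load_a, group["load_a"])
    # sch.tensorize(loop_load_b, group["load_b"])
    # sch.tensorize(loop_init, group["init"])
    # sch.tensorize(loop_mma, group["compute"])
    # sch.tensorize(loop_store, group["store"])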