From ca92032be1ac08deb44474ebf447307c094d1bbf Mon Sep 17 00:00:00 2001
From: Zihao
Date: Thu, 24 Feb 2022 21:55:44 -0800
Subject: [PATCH] add u8s8s32 back

---
 src/target/source/ptx_mma.cc              |  93 +++--
 tests/python/unittest/test_tir_ptx_mma.py | 391 ++++++++++++++++++++++
 2 files changed, 450 insertions(+), 34 deletions(-)

diff --git a/src/target/source/ptx_mma.cc b/src/target/source/ptx_mma.cc
index cebb6ebad95f5..45eb8303183ae 100644
--- a/src/target/source/ptx_mma.cc
+++ b/src/target/source/ptx_mma.cc
@@ -216,35 +216,61 @@ const MMAConfig valid_mma_configs[] = {
 /*!
  * \brief Check whether the multiplicand data type and accumulator data type is valid for MMA
  * computation.
- * \param mul The multiplicand data type.
- * \param acc The accumulator data type.
+ * \param dtype_a The data type of multiplicand a.
+ * \param dtype_b The data type of multiplicand b.
+ * \param dtype_c The data type of accumulator c.
  */
-void CheckMMADTypeCompatible(DataType mul, DataType acc) {
-  switch (mul) {
+void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_c) {
+  std::string ab_not_match_err_str = "The multiplicands' data type " + DTypeToString(dtype_a) +
+                                     DTypeToString(dtype_b) + " do not match.";
+  // check a and b
+  switch (dtype_a) {
     case DataType::kBit1:
+    case DataType::kFloat16:
+    case DataType::kBFloat16:
+    case DataType::kTensorFloat32:
+    case DataType::kFloat64:
+      CHECK(dtype_a == dtype_b) << ab_not_match_err_str;
+      break;
     case DataType::kInt4:
     case DataType::kUInt4:
+      CHECK(dtype_b == DataType::kInt4 || dtype_b == DataType::kUInt4) << ab_not_match_err_str;
+      break;
     case DataType::kInt8:
     case DataType::kUInt8:
-      CHECK(acc == DataType::kInt32) << "For multiplicand data type " << DTypeToString(mul)
-                                     << ", accumulator data type should be s32.";
+      CHECK(dtype_b == DataType::kInt8 || dtype_b == DataType::kUInt8) << ab_not_match_err_str;
+      break;
+    default:
+      CHECK(false) << "Invalid multiplicand data types: " << DTypeToString(dtype_a)
+                   << DTypeToString(dtype_b);
+  }
+  // check a,b and c
+  switch (dtype_a) {
+    case DataType::kBit1:
+    case DataType::kInt4:
+    case DataType::kUInt4:
+    case DataType::kInt8:
+    case DataType::kUInt8:
+      CHECK(dtype_c == DataType::kInt32)
+          << "For multiplicand data type " << DTypeToString(dtype_a) << DTypeToString(dtype_b)
+          << ", accumulator data type should be s32.";
       break;
     case DataType::kFloat16:
-      CHECK(acc == DataType::kFloat16 || acc == DataType::kFloat32)
+      CHECK(dtype_c == DataType::kFloat16 || dtype_c == DataType::kFloat32)
           << "For multiplicand data type f16, accumulator data type should be f16/f32.";
       break;
     case DataType::kBFloat16:
    case DataType::kTensorFloat32:
-      CHECK(acc == DataType::kFloat32)
-          << "For multiplicand data type bf16/tf32, accumulator data type can only be f32";
+      CHECK(dtype_c == DataType::kFloat32)
+          << "For multiplicand data type bf16/tf32, accumulator data type can only be f32.";
       break;
     case DataType::kFloat64:
-      CHECK(acc == DataType::kFloat64)
-          << "For multiplicand data type f64, accumulator data type can only be f64";
+      CHECK(dtype_c == DataType::kFloat64)
+          << "For multiplicand data type f64, accumulator data type can only be f64.";
       break;
     default:
-      CHECK(false) << "Invalid multiplicand/accumulator data type pair: " << DTypeToString(mul)
-                   << ", " << DTypeToString(acc) << ".";
+      CHECK(false) << "Invalid multiplicand/accumulator data types: " << DTypeToString(dtype_a)
+                   << DTypeToString(dtype_b) << DTypeToString(dtype_c) << ".";
   }
 }
 
@@ -272,10 +298,7 @@ void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType
   if (use_bit_op) {
     CHECK(dtype_a == DataType::kBit1) << "Bit operator is only compatible with 1bit multiplicand.";
   }
-  CHECK(dtype_a == dtype_b) << "The multiplicand data type must be equal, found "
-                            << DTypeToString(dtype_a) << " and " << ptx::DTypeToString(dtype_b)
-                            << ".";
-  CheckMMADTypeCompatible(dtype_a, dtype_c);
+  CheckMMADTypeCompatible(dtype_a, dtype_b, dtype_c);
   if (saturate) {
     CHECK(dtype_a == DataType::kInt4 || dtype_a == DataType::kUInt4 || dtype_a == DataType::kInt8 ||
           dtype_a == DataType::kUInt8)
@@ -389,23 +412,26 @@ inline uint32_t GetNumMMAComputations(int m, int n, int k, ptx::DataType dtype)
  * \param m The M in mMnNkK of MMA instructions.
  * \param n The N in mMnNkK of MMA instructions.
  * \param k The K in mMnNkK of MMA instructions.
- * \param dtype_mul The data type of multiplicand.
- * \param dtype_acc The data type of accumulator.
+ * \param dtype_a The data type of multiplicand a.
+ * \param dtype_b The data type of multiplicand b.
+ * \param dtype_c The data type of accumulator c.
  * \param sparse Whether it's Sparse MMA or not.
  */
 inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, int n, int k,
-                                                                        ptx::DataType dtype_mul,
-                                                                        ptx::DataType dtype_acc,
+                                                                        ptx::DataType dtype_a,
+                                                                        ptx::DataType dtype_b,
+                                                                        ptx::DataType dtype_c,
                                                                         bool sparse) {
   std::stringstream templates, inputs, outputs;
-  const ptx::FragAttrs frag_attr_mul = ptx::GetFragAttrs(dtype_mul),
-                       frag_attr_acc = ptx::GetFragAttrs(dtype_acc);
+  const ptx::FragAttrs frag_attr_a = ptx::GetFragAttrs(dtype_a),
+                       frag_attr_b = ptx::GetFragAttrs(dtype_b),
+                       frag_attr_c = ptx::GetFragAttrs(dtype_c);
   constexpr uint32_t warp_size = 32;
-  const uint32_t threads = warp_size / GetNumMMAComputations(m, n, k, dtype_mul);
-  const int num_operands_a = (m * k) * ptx::DTypeBits(dtype_mul) / frag_attr_acc.size / threads /
-                                 (sparse ? 2 : 1),
-            num_operands_b = (k * n) * ptx::DTypeBits(dtype_mul) / frag_attr_mul.size / threads,
-            num_operands_c = (m * n) * ptx::DTypeBits(dtype_acc) / frag_attr_acc.size / threads;
+  const uint32_t threads = warp_size / GetNumMMAComputations(m, n, k, dtype_a);
+  const int num_operands_a =
+                (m * k) * ptx::DTypeBits(dtype_a) / frag_attr_a.size / threads / (sparse ? 2 : 1),
+            num_operands_b = (k * n) * ptx::DTypeBits(dtype_b) / frag_attr_b.size / threads,
+            num_operands_c = (m * n) * ptx::DTypeBits(dtype_c) / frag_attr_c.size / threads;
 
   // generate templates;
   int arg_counter = 0;
@@ -440,15 +466,14 @@ inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, i
     if (i != 0) {
       inputs << ", ";
     }
-    inputs << "\"" << frag_attr_mul.reg_type << "\"((" << frag_attr_mul.ptr_sig << "(A))[" << i
-           << "])";
+    inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_sig << "(A))[" << i << "])";
   }
   for (int i = 0; i < num_operands_b; ++i) {
-    inputs << ", \"" << frag_attr_mul.reg_type << "\"((" << frag_attr_mul.ptr_sig << "(B))[" << i
+    inputs << ", \"" << frag_attr_b.reg_type << "\"((" << frag_attr_b.ptr_sig << "(B))[" << i
            << "])";
   }
   for (int i = 0; i < num_operands_c; ++i) {
-    inputs << ", \"" << frag_attr_acc.reg_type << "\"((" << frag_attr_acc.ptr_sig << "(C))[" << i
+    inputs << ", \"" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_sig << "(C))[" << i
            << "])";
   }
   // input of metadata for sparse mma.
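For reference, GetMMAOperands now derives the per-thread register counts from three independent data types instead of a single multiplicand type. The following standalone Python sketch (illustrative only, not part of the patch; it assumes the 32-bit register fragments that ptx::GetFragAttrs reports for the 8-bit and 32-bit integer types) reproduces that arithmetic for the m16n8k32 s8u8s32 case:

    # Reproduce the operand-count arithmetic of GetMMAOperands for one warp of 32 threads.
    def num_operands(m, n, k, bits_a, bits_b, bits_c, frag_bits=32, threads=32, sparse=False):
        a = (m * k) * bits_a // frag_bits // threads // (2 if sparse else 1)
        b = (k * n) * bits_b // frag_bits // threads
        c = (m * n) * bits_c // frag_bits // threads
        return a, b, c

    # m16n8k32 with int8 x uint8 multiplicands and int32 accumulators:
    # 4 A registers, 2 B registers, and 4 C registers per thread.
    assert num_operands(16, 8, 32, 8, 8, 32) == (4, 2, 4)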
@@ -461,7 +486,7 @@ inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, i
     if (i != 0) {
       outputs << ",";
     }
-    outputs << " \"=" << frag_attr_acc.reg_type << "\"((" << frag_attr_acc.ptr_sig << "(D))[" << i
+    outputs << " \"=" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_sig << "(D))[" << i
             << "])";
   }
   return std::make_tuple(templates.str(), inputs.str(), outputs.str());
@@ -495,7 +520,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
 )";
   std::string templates_str, inputs_str, outputs_str;
   std::tie(templates_str, inputs_str, outputs_str) =
-      GetMMAOperands(m, n, k, dtype_a, dtype_c, sparse);
+      GetMMAOperands(m, n, k, dtype_a, dtype_b, dtype_c, sparse);
 
   // replace patterns
   Replacer replacer;
diff --git a/tests/python/unittest/test_tir_ptx_mma.py b/tests/python/unittest/test_tir_ptx_mma.py
index 55547dfcc7e84..2af922b5e6946 100644
--- a/tests/python/unittest/test_tir_ptx_mma.py
+++ b/tests/python/unittest/test_tir_ptx_mma.py
@@ -342,6 +342,84 @@ def test_gemm_mma_m8n8k16_row_col_s8s8s32():
     tvm.testing.assert_allclose(golden, C_numpy, atol=1e-3, rtol=1e-3)
 
 
+@T.prim_func
+def gemm_mma_m8n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle):
+    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+    A = T.match_buffer(a, [8, 16], dtype="int8")
+    B = T.match_buffer(b, [8, 16], dtype="uint8")
+    C = T.match_buffer(c, [8, 8], dtype="int32")
+    brow = T.env_thread("blockIdx.y")
+    bcol = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(brow, 1)
+    T.launch_thread(bcol, 1)
+    T.launch_thread(tx, 32)
+    MultiA = T.allocate([4], "int8", scope="local")
+    MultiB = T.allocate([4], "uint8", scope="local")
+    Accum = T.allocate([2], "int32", scope="local")
+    for i in range(2):
+        Accum[i] = T.int32(0)
+
+    for mma_multi_a_col in T.vectorized(4):
+        MultiA[mma_multi_a_col] = A[(tx % 32) // 4, mma_multi_a_col + (tx % 32) % 4 * 4]
+    for mma_multi_b_col in T.vectorized(4):
+        MultiB[mma_multi_b_col] = B[(tx % 32) // 4, mma_multi_b_col + (tx % 32) % 4 * 4]
+    T.evaluate(
+        T.ptx_mma(
+            "m8n8k16",
+            "row",
+            "col",
+            "int8",
+            "uint8",
+            "int32",
+            MultiA,
+            0,
+            MultiB,
+            0,
+            Accum,
+            0,
+            False,
+            dtype="int32",
+        )
+    )
+    for mma_accum_c_id in range(2):
+        C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load(
+            "int32", Accum, mma_accum_c_id
+        )
+
+
+# This test uses mma instructions that are not available on NVCC 10.1.
+# Failure occurs during the external call to nvcc, when attempting to
+# generate the .fatbin file.
+@tvm.testing.requires_nvcc_version(11)
+@tvm.testing.requires_cuda
+def test_gemm_mma_m8n8k16_row_col_s8u8s32():
+    sch = tvm.tir.Schedule(gemm_mma_m8n8k16_row_col_s8u8s32)
+    arch = tvm.contrib.nvcc.get_target_compute_version()
+    major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
+    if major * 10 + minor < 75:
+        # Require at least SM75
+        return
+    cuda_mod = tvm.build(sch.mod, target="cuda")
+
+    A_np = np.random.uniform(-10, 10, [8, 16]).astype("int8")
+    B_np = np.random.uniform(-10, 10, [8, 16]).astype("uint8")
+    C_np = np.zeros([8, 8]).astype("int32")
+
+    ctx = tvm.cuda()
+    A_tvm = tvm.nd.array(A_np, ctx)
+    B_tvm = tvm.nd.array(B_np, ctx)
+    C_tvm = tvm.nd.array(C_np, ctx)
+
+    cuda_mod(A_tvm, B_tvm, C_tvm)
+
+    golden = np.matmul(A_np.astype("int32"), B_np.astype("int32").T)
+
+    C_numpy = C_tvm.numpy()
+
+    tvm.testing.assert_allclose(golden, C_numpy, atol=1e-3, rtol=1e-3)
+
+
 @T.prim_func
 def gemm_mma_m8n8k32_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle):
     T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
@@ -412,6 +490,76 @@ def test_gemm_mma_m8n8k32_row_col_s4s4s32():
     # TODO: add correctness checking here.
 
 
+@T.prim_func
+def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle):
+    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+    A = T.match_buffer(a, [8, 32], dtype="int4")
+    B = T.match_buffer(b, [8, 32], dtype="uint4")
+    C = T.match_buffer(c, [8, 8], dtype="int32")
+    brow = T.env_thread("blockIdx.y")
+    bcol = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(brow, 1)
+    T.launch_thread(bcol, 1)
+    T.launch_thread(tx, 32)
+    MultiA = T.allocate([8], "int4", scope="local")
+    MultiB = T.allocate([8], "uint4", scope="local")
+    Accum = T.allocate([2], "int32", scope="local")
+    for i in range(2):
+        Accum[i] = T.int32(0)
+
+    for mma_multi_a_col in T.vectorized(8):
+        MultiA[mma_multi_a_col] = A[(tx % 32) // 4, mma_multi_a_col + (tx % 32) % 4 * 8]
+    for mma_multi_b_col in T.vectorized(8):
+        MultiB[mma_multi_b_col] = B[(tx % 32) // 4, mma_multi_b_col + (tx % 32) % 4 * 8]
+    T.evaluate(
+        T.ptx_mma(
+            "m8n8k32",
+            "row",
+            "col",
+            "int4",
+            "uint4",
+            "int32",
+            MultiA,
+            0,
+            MultiB,
+            0,
+            Accum,
+            0,
+            False,
+            dtype="int32",
+        )
+    )
+    for mma_accum_c_id in range(2):
+        C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load(
+            "int32", Accum, mma_accum_c_id
+        )
+
+
+# This test uses mma instructions that are not available on NVCC 10.1.
+# Failure occurs during the external call to nvcc, when attempting to
+# generate the .fatbin file.
+@tvm.testing.requires_nvcc_version(11)
+@tvm.testing.requires_cuda
+def test_gemm_mma_m8n8k32_row_col_s4u4s32():
+    sch = tvm.tir.Schedule(gemm_mma_m8n8k32_row_col_s4u4s32)
+    arch = tvm.contrib.nvcc.get_target_compute_version()
+    major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
+    if major * 10 + minor < 75:
+        # Require at least SM75
+        return
+    cuda_mod = tvm.build(sch.mod, target="cuda")
+
+    ctx = tvm.cuda()
+    A_tvm = tvm.nd.empty([8, 32], "int4", ctx)
+    B_tvm = tvm.nd.empty([8, 32], "uint4", ctx)
+    C_tvm = tvm.nd.empty([8, 8], "int32", ctx)
+
+    cuda_mod(A_tvm, B_tvm, C_tvm)
+    # Currently the correctness is not checked.
+    # TODO: add correctness checking here.
+
+
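The tests above hard-code the per-thread fragment layout used by these mma shapes. A quick host-side sanity check such as the one below (illustrative only, not part of the patch) confirms that the addressing used in gemm_mma_m8n8k16_row_col_s8u8s32 assigns every element of the 8x16 A tile to exactly one lane:

    # Each of the 32 lanes loads 4 consecutive int8 values of A; together they
    # cover the 8x16 tile exactly once, mirroring the indices in the TIR above.
    covered = set()
    for tx in range(32):
        for col in range(4):
            covered.add((tx // 4, col + tx % 4 * 4))
    assert len(covered) == 8 * 16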
 @T.prim_func
 def gemm_mma_m16n8k8_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle):
     T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
@@ -736,6 +884,88 @@ def test_gemm_mma_m16n8k16_row_col_s8s8s32():
     tvm.testing.assert_allclose(golden, C_numpy, atol=1e-3, rtol=1e-3)
 
 
+@T.prim_func
+def gemm_mma_m16n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle):
+    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+    A = T.match_buffer(a, [16, 16], dtype="int8")
+    B = T.match_buffer(b, [8, 16], dtype="uint8")
+    C = T.match_buffer(c, [16, 8], dtype="int32")
+    brow = T.env_thread("blockIdx.y")
+    bcol = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(brow, 1)
+    T.launch_thread(bcol, 1)
+    T.launch_thread(tx, 32)
+    MultiA = T.allocate([8], "int8", scope="local")
+    MultiB = T.allocate([4], "uint8", scope="local")
+    Accum = T.allocate([4], "int32", scope="local")
+    for i in range(4):
+        Accum[i] = T.int32(0)
+
+    for mma_multi_a_col in range(8):
+        MultiA[mma_multi_a_col] = A[
+            (tx % 32) // 4 + mma_multi_a_col // 4 * 8,
+            (tx % 32) % 4 * 4 + mma_multi_a_col % 4,
+        ]
+    for mma_multi_b_col in T.vectorized(4):
+        MultiB[mma_multi_b_col] = B[
+            (tx % 32) // 4,
+            (tx % 32) % 4 * 4 + mma_multi_b_col,
+        ]
+    T.evaluate(
+        T.ptx_mma(
+            "m16n8k16",
+            "row",
+            "col",
+            "int8",
+            "uint8",
+            "int32",
+            MultiA,
+            0,
+            MultiB,
+            0,
+            Accum,
+            0,
+            False,
+            dtype="int32",
+        )
+    )
+    for mma_accum_c_id in range(4):
+        C[
+            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
+            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
+        ] = T.load("int32", Accum, mma_accum_c_id)
+
+
+@tvm.testing.requires_cuda
+def test_gemm_mma_m16n8k16_row_col_s8u8s32():
+    sch = tvm.tir.Schedule(gemm_mma_m16n8k16_row_col_s8u8s32)
+    arch = tvm.contrib.nvcc.get_target_compute_version()
+    major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
+    if major < 8:
+        # Require at least SM80
+        return
+    cuda_mod = tvm.build(sch.mod, target="cuda")
+    cuda_mod = tvm.build(sch.mod, target="cuda")
+
+    A_np = np.random.uniform(-10, 10, [16, 16]).astype("int8")
+    B_np = np.random.uniform(-10, 10, [8, 16]).astype("uint8")
+    C_np = np.zeros([16, 8]).astype("int32")
+
+    ctx = tvm.cuda()
+    A_tvm = tvm.nd.array(A_np, ctx)
+    B_tvm = tvm.nd.array(B_np, ctx)
+    C_tvm = tvm.nd.array(C_np, ctx)
+
+    cuda_mod(A_tvm, B_tvm, C_tvm)
+
+    golden = np.matmul(A_np.astype("int32"), B_np.astype("int32").T)
+
+    C_numpy = C_tvm.numpy()
+
+    tvm.testing.assert_allclose(golden, C_numpy, atol=1e-3, rtol=1e-3)
+
+
 @T.prim_func
 def gemm_mma_m16n8k32_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle):
     T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
@@ -818,6 +1048,88 @@ def test_gemm_mma_m16n8k32_row_col_s8s8s32():
     tvm.testing.assert_allclose(golden, C_numpy, atol=1e-3, rtol=1e-3)
 
 
+@T.prim_func
+def gemm_mma_m16n8k32_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle):
+    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+    A = T.match_buffer(a, [16, 32], dtype="int8")
+    B = T.match_buffer(b, [8, 32], dtype="uint8")
+    C = T.match_buffer(c, [16, 8], dtype="int32")
+    brow = T.env_thread("blockIdx.y")
+    bcol = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(brow, 1)
+    T.launch_thread(bcol, 1)
+    T.launch_thread(tx, 32)
+    MultiA = T.allocate([16], "int8", scope="local")
+    MultiB = T.allocate([8], "uint8", scope="local")
+    Accum = T.allocate([4], "int32", scope="local")
+    for i in range(4):
+        Accum[i] = T.int32(0)
+
+    for mma_multi_a_col in range(16):
+        MultiA[mma_multi_a_col] = A[
+            (tx % 32) // 4 + mma_multi_a_col % 8 // 4 * 8,
+            (tx % 32) % 4 * 4 + mma_multi_a_col % 4 + mma_multi_a_col // 8 * 16,
+        ]
+    for mma_multi_b_col in range(8):
+        MultiB[mma_multi_b_col] = B[
+            (tx % 32) // 4,
+            (tx % 32) % 4 * 4 + mma_multi_b_col % 4 + mma_multi_b_col // 4 * 16,
+        ]
+    T.evaluate(
+        T.ptx_mma(
+            "m16n8k32",
+            "row",
+            "col",
+            "int8",
+            "uint8",
+            "int32",
+            MultiA,
+            0,
+            MultiB,
+            0,
+            Accum,
+            0,
+            False,
+            dtype="int32",
+        )
+    )
+    for mma_accum_c_id in range(4):
+        C[
+            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
+            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
+        ] = T.load("int32", Accum, mma_accum_c_id)
+
+
+@tvm.testing.requires_cuda
+def test_gemm_mma_m16n8k32_row_col_s8u8s32():
+    sch = tvm.tir.Schedule(gemm_mma_m16n8k32_row_col_s8u8s32)
+    arch = tvm.contrib.nvcc.get_target_compute_version()
+    major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
+    if major < 8:
+        # Require at least SM80
+        return
+    cuda_mod = tvm.build(sch.mod, target="cuda")
+    cuda_mod = tvm.build(sch.mod, target="cuda")
+
+    A_np = np.random.uniform(-10, 10, [16, 32]).astype("int8")
+    B_np = np.random.uniform(-10, 10, [8, 32]).astype("uint8")
+    C_np = np.zeros([16, 8]).astype("int32")
+
+    ctx = tvm.cuda()
+    A_tvm = tvm.nd.array(A_np, ctx)
+    B_tvm = tvm.nd.array(B_np, ctx)
+    C_tvm = tvm.nd.array(C_np, ctx)
+
+    cuda_mod(A_tvm, B_tvm, C_tvm)
+
+    golden = np.matmul(A_np.astype("int32"), B_np.astype("int32").T)
+
+    C_numpy = C_tvm.numpy()
+
+    tvm.testing.assert_allclose(golden, C_numpy, atol=1e-3, rtol=1e-3)
+
+
 @T.prim_func
 def gemm_mma_m16n8k64_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle):
     T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
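Each new test repeats the same compute-capability guard before building: SM75 for the m8n8k* cases and SM80 for the m16n8k* integer cases. A hypothetical helper along these lines (not part of the patch; the name is illustrative) captures the shared pattern:

    # Return True when the local GPU is at least the requested SM version,
    # e.g. meets_sm(80) for the m16n8k16/m16n8k32/m16n8k64 integer MMAs.
    def meets_sm(required):
        major, minor = tvm.contrib.nvcc.parse_compute_version(
            tvm.contrib.nvcc.get_target_compute_version()
        )
        return major * 10 + minor >= required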
@@ -892,6 +1204,80 @@ def test_gemm_mma_m16n8k64_row_col_s4s4s32():
     # TODO: add correctness checking here.
 
 
+@T.prim_func
+def gemm_mma_m16n8k64_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle):
+    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+    A = T.match_buffer(a, [16, 64], dtype="int4")
+    B = T.match_buffer(b, [8, 64], dtype="uint4")
+    C = T.match_buffer(c, [16, 8], dtype="int32")
+    brow = T.env_thread("blockIdx.y")
+    bcol = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(brow, 1)
+    T.launch_thread(bcol, 1)
+    T.launch_thread(tx, 32)
+    MultiA = T.allocate([32], "int4", scope="local")
+    MultiB = T.allocate([16], "uint4", scope="local")
+    Accum = T.allocate([4], "int32", scope="local")
+    for i in range(4):
+        Accum[i] = T.int32(0)
+
+    for mma_multi_a_col in range(32):
+        MultiA[mma_multi_a_col] = A[
+            (tx % 32) // 4 + mma_multi_a_col % 16 // 8 * 8,
+            (tx % 32) % 4 * 8 + mma_multi_a_col % 8 + mma_multi_a_col // 16 * 32,
+        ]
+    for mma_multi_b_col in range(16):
+        MultiB[mma_multi_b_col] = B[
+            (tx % 32) // 4,
+            (tx % 32) % 4 * 8 + mma_multi_b_col % 8 + mma_multi_b_col // 8 * 32,
+        ]
+    T.evaluate(
+        T.ptx_mma(
+            "m16n8k64",
+            "row",
+            "col",
+            "int4",
+            "uint4",
+            "int32",
+            MultiA,
+            0,
+            MultiB,
+            0,
+            Accum,
+            0,
+            False,
+            dtype="int32",
+        )
+    )
+    for mma_accum_c_id in range(4):
+        C[
+            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
+            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
+        ] = T.load("int32", Accum, mma_accum_c_id)
+
+
+@tvm.testing.requires_cuda
+def test_gemm_mma_m16n8k64_row_col_s4u4s32():
+    sch = tvm.tir.Schedule(gemm_mma_m16n8k64_row_col_s4u4s32)
+    arch = tvm.contrib.nvcc.get_target_compute_version()
+    major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
+    if major < 8:
+        # Require at least SM80
+        return
+    cuda_mod = tvm.build(sch.mod, target="cuda")
+    cuda_mod = tvm.build(sch.mod, target="cuda")
+
+    ctx = tvm.cuda()
+    A_tvm = tvm.nd.empty([16, 64], "int4", ctx)
+    B_tvm = tvm.nd.empty([8, 64], "uint4", ctx)
+    C_tvm = tvm.nd.empty([16, 8], "int32", ctx)
+
+    cuda_mod(A_tvm, B_tvm, C_tvm)
+    # Currently the correctness is not checked.
+    # TODO: add correctness checking here.
+
+
 @T.prim_func
 def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle):
     T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
@@ -972,11 +1358,16 @@ def test_gemm_mma_m16n8k256_row_col_b1b1s32():
     test_gemm_mma_m8n8k4_row_row_fp16fp16fp16()
     test_gemm_mma_m8n8k4_row_row_fp16fp16fp32()
     test_gemm_mma_m8n8k16_row_col_s8s8s32()
+    test_gemm_mma_m8n8k16_row_col_s8u8s32()
     test_gemm_mma_m8n8k32_row_col_s4s4s32()
+    test_gemm_mma_m8n8k32_row_col_s4u4s32()
     test_gemm_mma_m16n8k8_row_col_fp16fp16fp32()
     test_gemm_mma_m16n8k16_row_col_fp16fp16fp16()
     test_gemm_mma_m16n8k16_row_col_fp16fp16fp32()
     test_gemm_mma_m16n8k16_row_col_s8s8s32()
+    test_gemm_mma_m16n8k16_row_col_s8u8s32()
     test_gemm_mma_m16n8k32_row_col_s8s8s32()
+    test_gemm_mma_m16n8k32_row_col_s8u8s32()
     test_gemm_mma_m16n8k64_row_col_s4s4s32()
+    test_gemm_mma_m16n8k64_row_col_s4u4s32()
     test_gemm_mma_m16n8k256_row_col_b1b1s32()
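As a cross-check on the intrinsic shape used in gemm_mma_m16n8k64_row_col_s4u4s32, the A-side addressing spreads the 32 int4 values held by each lane over the full 16x64 tile, which corresponds to the m16n8k64 A fragment. A small host-side sketch (illustrative only, not part of the patch) makes this concrete:

    # Every (row, col) of the 16x64 int4 A tile is touched exactly once
    # across the 32 lanes, matching the indices in the TIR above.
    covered = set()
    for tx in range(32):
        for i in range(32):
            covered.add((tx // 4 + i % 16 // 8 * 8, tx % 4 * 8 + i % 8 + i // 16 * 32))
    assert len(covered) == 16 * 64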