diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index f7e1cfbc3e6d..0d9f82305352 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -596,6 +596,19 @@ TVM_DLL const Op& tvm_store_matrix_sync();
  */
 TVM_DLL const Op& ptx_mma();
 
+/*!
+ * \brief tvm intrinsic for sparse tensor core ptx instructions.
+ *
+ * void ptx_mma_sp(StringImm shape, StringImm A_layout, StringImm B_layout,
+ *                 StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
+ *                 Var multiplicand_a, Expr a_index,
+ *                 Var multiplicand_b, Expr b_index,
+ *                 Var accumulator, Expr c_index,
+ *                 Var metadata, Expr meta_index,
+ *                 Var sparse_selector, bool saturate);
+ */
+TVM_DLL const Op& ptx_mma_sp();
+
 // TODO(tvm-team) replace the usage of the vector operations by Shuffle.
 /*!
  * \brief Get the high level half of the vector
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 0dda079066d9..f74d5cf484b9 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -744,7 +744,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
     // arg 10: C accumulator
     // arg 11: C accumulator index
     // arg 12: saturate
-    ICHECK_EQ(op->args.size(), 13U);
+    // arg 13: (optional) 1-bit operator (xor or and)
+    ICHECK(op->args.size() == 13U || op->args.size() == 14U);
     std::string shape = Downcast<StringImm>(op->args[0])->value;
     std::string A_layout = Downcast<StringImm>(op->args[1])->value;
     std::string B_layout = Downcast<StringImm>(op->args[2])->value;
@@ -757,11 +758,51 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
     std::string b_bias = this->PrintExpr(op->args[9]);
     std::string c_ref = this->PrintExpr(op->args[10]);
     std::string c_bias = this->PrintExpr(op->args[11]);
-    bool saturate = (Downcast<IntImm>(op->args[12])->value != 0);
-    std::string asm_code = PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype,
-                                            a_ref, a_bias, b_ref, b_bias, c_ref, c_bias, saturate);
+    bool saturate = Downcast<Bool>(op->args[12])->value;
+    std::string bit_op = op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : "";
+    std::string asm_code =
+        PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, b_ref,
+                         b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate);
     this->stream << asm_code;
+  } else if (op->op.same_as(builtin::ptx_mma_sp())) {
+    // arg 0: shape: mXnXkX
+    // arg 1: A layout: row/col
+    // arg 2: B layout: row/col
+    // arg 3: A precision: fp16, fp32, ...
+    // arg 4: B precision: fp16, fp32, ...
+    // arg 5: C precision: fp16, fp32, ...
+    // arg 6: A multiplicand
+    // arg 7: A multiplicand index
+    // arg 8: B multiplicand
+    // arg 9: B multiplicand index
+    // arg 10: C accumulator
+    // arg 11: C accumulator index
+    // arg 12: metadata
+    // arg 13: metadata index
+    // arg 14: sparse_selector
+    // arg 15: saturate
+    ICHECK_EQ(op->args.size(), 16U);
+    std::string shape = Downcast<StringImm>(op->args[0])->value;
+    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
+    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
+    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
+    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
+    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
+    std::string a_ref = this->PrintExpr(op->args[6]);
+    std::string a_offset = this->PrintExpr(op->args[7]);
+    std::string b_ref = this->PrintExpr(op->args[8]);
+    std::string b_offset = this->PrintExpr(op->args[9]);
+    std::string c_ref = this->PrintExpr(op->args[10]);
+    std::string c_offset = this->PrintExpr(op->args[11]);
+    std::string metadata = this->PrintExpr(op->args[12]);
+    std::string metadata_offset = this->PrintExpr(op->args[13]);
+    std::string sparse_selector = this->PrintExpr(op->args[14]);
+    bool saturate = Downcast<Bool>(op->args[15])->value;
+    std::string asm_code = PrintMMAAssembly(
+        shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset,
+        c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate);
+    this->stream << asm_code;
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
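For reference, the sketch below shows roughly the kind of inline PTX the new ptx_mma_sp branch is expected to emit. It is a hand-written illustration, not copied from the generated output: it assumes shape m16n8k32 with fp16 multiplicands and an fp32 accumulator, and the names A, B, C, E stand for the multiplicand, accumulator, and metadata pointer expressions substituted by the code generator, with 0x0 as the sparse selector.

```cpp
// Illustrative sketch only (requires sm_80): a 2:4 sparse MMA in which the
// accumulator C is read and written in place and E holds the sparsity metadata.
__device__ void mma_sp_m16n8k32_f16f16f32(unsigned* A, unsigned* B, float* C, unsigned* E) {
  __asm__ __volatile__(
      "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n"
      : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
      : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
        "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
        "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]),
        "r"(E[0]));
}
```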
diff --git a/src/target/source/ptx_mma.cc b/src/target/source/ptx_mma.cc
index b6182720416c..d04c01896ed7 100644
--- a/src/target/source/ptx_mma.cc
+++ b/src/target/source/ptx_mma.cc
@@ -23,1351 +23,543 @@
 #include "ptx_mma.h"
 
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
 namespace tvm {
 namespace codegen {
 
-std::string ReplaceMMAArgument(std::string asm_code, const std::string& original,
-                               const std::string& new_arg) {
-  size_t len = original.size();
-  size_t new_len = new_arg.size();
-  size_t pos = asm_code.find(original);
-  while (pos != std::string::npos) {
-    asm_code = asm_code.replace(pos, len, new_arg);
-    pos = asm_code.find(original, pos + new_len);
-  }
-  return asm_code;
-}
+// PTX related data structures and functions.
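As a quick orientation for the string helpers that the ptx namespace below provides, here is a hypothetical standalone usage snippet (not part of the patch); the results shown in comments follow directly from the definitions that follow.

```cpp
#include <cstdint>
#include <string>
#include <tuple>

// Assumes the ptx:: helpers defined below in this file.
void ExampleParse() {
  int m, n, k;
  std::tie(m, n, k) = ptx::ParseMMAShape("m16n8k32");    // m = 16, n = 8, k = 32
  ptx::DataType dtype_a = ptx::DTypeFromString("fp16");  // DataType::kFloat16
  uint32_t bits = ptx::DTypeBits(dtype_a);               // 16
  std::string name = ptx::DTypeToString(dtype_a);        // ".f16"
  (void)m; (void)n; (void)k; (void)bits; (void)name;
}
```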
+namespace ptx { -std::string PrintMMAm8n8k4Assembly(const std::string& A_layout, const std::string& B_layout, - const std::string& A_dtype, const std::string& B_dtype, - const std::string& C_dtype, const std::string& a_ref, - const std::string& a_bias, const std::string& b_ref, - const std::string& b_bias, const std::string& c_ref, - const std::string& c_bias, bool saturate) { - std::string asm_code = ""; - std::string new_a_ref = ""; - std::string new_b_ref = ""; - std::string new_c_ref = ""; - ICHECK(((A_dtype == "fp16") && (B_dtype == "fp16")) || - ((A_dtype == "fp64") && (B_dtype == "fp64"))); - ICHECK(saturate == false) << "Saturate is not allowed for m8n8k4 mma."; - if ((A_dtype == "fp16") && (B_dtype == "fp16")) { - // A/B multiplicand is fp16, SM 70 Tensor Core instructions - ICHECK((C_dtype == "fp16") || (C_dtype == "fp32")); - if (C_dtype == "fp16") { - // C accumulator is fp16 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((unsigned *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k4.left_layout.right_layout.f16.f16.f16.f16 " - "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, " - "{%8,%9,%10,%11};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // C accumulator is fp32 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k4.left_layout.right_layout.f32.f16.f16.f32 " - "{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " - "{%12,%13,%14,%15,%16,%17,%18,%19};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]), - "=f"(D[4]), "=f"(D[5]), "=f"(D[6]), "=f"(D[7]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), - "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); - } - )"; - } +/*! + * \brief PTX data type. + * \note + * PTX fundamental data types: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types + * PTX matrix data types: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types + */ +enum class DataType : int { + kInt4 = 0, + kUInt4 = 1, + kInt8 = 2, + kUInt8 = 3, + kInt16 = 4, + kUInt16 = 5, + kInt32 = 6, + kUInt32 = 7, + kInt64 = 8, + kUInt64 = 9, + kFloat16 = 10, + kBFloat16 = 11, + kFloat16x2 = 12, + kFloat32 = 13, + kTensorFloat32 = 14, + kFloat64 = 15, + kBit1 = 16 +}; + +static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16", + ".s32", ".u32", ".s64", ".u64", ".f16", ".bf16", + ".f16x2", ".f32", ".tf32", ".f64", ".b1"}; +static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16, 16, 32, 32, 32, 64, 1}; + +/*! + * \brief Create PTX data type from string. 
+ */ +inline DataType DTypeFromString(const std::string str) { + if (str == "int4" || str == ".s4") { + return DataType::kInt4; + } else if (str == "uint4" || str == ".u4") { + return DataType::kUInt4; + } else if (str == "int8" || str == ".s8") { + return DataType::kInt8; + } else if (str == "uint8" || str == ".u8") { + return DataType::kUInt8; + } else if (str == "int16" || str == ".s16") { + return DataType::kInt16; + } else if (str == "uint16" || str == ".u16") { + return DataType::kUInt16; + } else if (str == "int32" || str == ".s32") { + return DataType::kInt32; + } else if (str == "uint32" || str == ".u32") { + return DataType::kUInt32; + } else if (str == "int64" || str == ".s64") { + return DataType::kInt64; + } else if (str == "uint64" || str == ".u64") { + return DataType::kUInt64; + } else if (str == "float16" || str == "fp16" || str == ".f16") { + return DataType::kFloat16; + } else if (str == "bfloat16" || str == "bf16") { + return DataType::kBFloat16; + } else if (str == ".f16x2") { + return DataType::kFloat16x2; + } else if (str == "float32" || str == "fp32" || str == ".f32") { + return DataType::kFloat32; + } else if (str == "tf32") { + return DataType::kTensorFloat32; + } else if (str == "float64" || str == "fp64" || str == ".f64") { + return DataType::kFloat64; + } else if (str == "int1" || str == ".b1") { + return DataType::kBit1; } else { - // A/B multiplicand is fp64, SM 80 Tensor Core instructions - ICHECK(C_dtype == "fp64"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Fp64 Tensor Core instructions " - << "with shape m8n8k4 expect A layout is row major and B layout is col major."; - // C accumulator is fp64 - new_a_ref = "((double *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((double *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((double *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=d"(D[0]), "=d"(D[1]) - : "d"(A[0]), "d"(B[0]), - "d"(C[0]), "d"(C[1])); - } - )"; + LOG(FATAL) << "Unrecognized PTX data type " << str; + return DataType(0); } - asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); - asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); - asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); - asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); - asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); - asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); - return asm_code; } -std::string PrintMMAm16n8k8Assembly(const std::string& A_layout, const std::string& B_layout, - const std::string& A_dtype, const std::string& B_dtype, - const std::string& C_dtype, const std::string& a_ref, - const std::string& a_bias, const std::string& b_ref, - const std::string& b_bias, const std::string& c_ref, - const std::string& c_bias, bool saturate) { - std::string asm_code = ""; - std::string new_a_ref = ""; - std::string new_b_ref = ""; - std::string new_c_ref = ""; - ICHECK(((A_dtype == "fp16") && (B_dtype == "fp16")) || - ((A_dtype == "bf16") && (B_dtype == "bf16"))); - ICHECK(saturate == false) << "Saturate is not allowed for m16n8k8 mma."; - if ((A_dtype == "fp16") && (B_dtype == "fp16")) { - // A/B multiplicand is fp16, SM 75 Tensor Core instructions - ICHECK((C_dtype == "fp16") || (C_dtype == "fp32")); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM75 Tensor Core instructions " - << "with shape m16n8k8 expect A layout is row major and B layout 
is col major."; - if (C_dtype == "fp16") { - // C accumulator is fp16 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((unsigned *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " - "{%0,%1}, {%2,%3}, {%5}, " - "{%5,%6};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // C accumulator is fp32 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - } - )"; - } +/*! + * \brief Get the string representation of given PTX data type. + */ +inline std::string DTypeToString(DataType dtype) { return dtype_str[static_cast(dtype)]; } + +/*! + * \brief Get the number of bits of given PTX data type. + */ +inline uint32_t DTypeBits(DataType dtype) { return num_bits[static_cast(dtype)]; } + +/*! + * \brief Extract the value m, n, k from string m*n*k* + */ +inline std::tuple ParseMMAShape(const std::string& str) { + size_t pos_m = str.find("m"), pos_n = str.find("n"), pos_k = str.find("k"); + CHECK(pos_m != str.npos && pos_n != str.npos && pos_k != str.npos) + << "Cannot parse MMA shape " << str; + int m = std::stoi(str.substr(pos_m + 1, pos_n - pos_m - 1)), + n = std::stoi(str.substr(pos_n + 1, pos_k - pos_n - 1)), k = std::stoi(str.substr(pos_k + 1)); + return std::make_tuple(m, n, k); +} + +/*! + * \brief Layout Type + */ +enum class LayoutType : int { kRowMajor = 0, kColumnMajor = 1 }; + +/*! 
+ * \brief Parse layout type + */ +LayoutType LayoutTypeFromString(const std::string& str) { + if (str == "row") { + return LayoutType::kRowMajor; + } else if (str == "col") { + return LayoutType::kColumnMajor; } else { - // A/B multiplicand is bf16, SM 80 Tensor Core instructions - ICHECK(C_dtype == "fp32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k8 expect A layout is row major and B layout is col major."; - // C accumulator is fp32 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - } - )"; + LOG(FATAL) << "Unrecognized layout type " << str; + return LayoutType::kRowMajor; } - asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); - asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); - asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); - asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); - asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); - asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); - return asm_code; } -std::string PrintMMAm8n8k16Assembly(const std::string& A_layout, const std::string& B_layout, - const std::string& A_dtype, const std::string& B_dtype, - const std::string& C_dtype, const std::string& a_ref, - const std::string& a_bias, const std::string& b_ref, - const std::string& b_bias, const std::string& c_ref, - const std::string& c_bias, bool saturate) { - std::string asm_code = ""; - std::string new_a_ref = ""; - std::string new_b_ref = ""; - std::string new_c_ref = ""; - ICHECK(((A_dtype == "int8") && (B_dtype == "int8")) || - ((A_dtype == "uint8") && (B_dtype == "int8")) || - ((A_dtype == "int8") && (B_dtype == "uint8")) || - ((A_dtype == "uint8") && (B_dtype == "uint8"))); - if ((A_dtype == "int8") && (B_dtype == "int8")) { - // A/B multiplicand is int8, SM 75 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM75 Tensor Core instructions " - << "with shape m8n8k16 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } - } else if ((A_dtype == "uint8") && (B_dtype == "int8")) { - // A multiplicand is uint8, B multiplicand is int8 - // SM 75 
Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM75 Tensor Core instructions " - << "with shape m8n8k16 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.s8.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } - } else if ((A_dtype == "int8") && (B_dtype == "uint8")) { - // A multiplicand is int8, B multiplicand is uint8 - // SM 75 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM75 Tensor Core instructions " - << "with shape m8n8k16 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.u8.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } - } else { - // A/B multiplicand is uint8, SM 75 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM75 Tensor Core instructions " - << "with shape m8n8k16 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.u8.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" 
- : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } +static const char* layout_type_str[] = {"row", "col"}; + +/*! + * \brief Convert layout type to string. + */ +inline std::string LayoutTypeToString(LayoutType layout) { + return layout_type_str[static_cast(layout)]; +} + +/*! + * \brief MMA Configurations, used to determine validity. + */ +struct MMAConfig { + explicit MMAConfig(int m, int n, int k, DataType dtype_mul, bool use_bit_op, bool sparse) + : m(m), n(n), k(k), dtype_mul(dtype_mul), use_bit_op(use_bit_op), sparse(sparse) {} + int m, n, k; + DataType dtype_mul; + bool use_bit_op; + bool sparse; + inline bool operator==(const MMAConfig& other) { + return m == other.m && n == other.n && k == other.k && dtype_mul == other.dtype_mul && + use_bit_op == other.use_bit_op && sparse == other.sparse; + } +}; + +/*! + * \brief Valid MMA configurations + * \note Reference: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-shape + */ +const MMAConfig valid_mma_configs[] = { + MMAConfig(8, 8, 4, DataType::kFloat64, false, false), + MMAConfig(8, 8, 4, DataType::kFloat16, false, false), + MMAConfig(16, 8, 8, DataType::kFloat16, false, false), + MMAConfig(16, 8, 16, DataType::kFloat16, false, false), + MMAConfig(16, 8, 8, DataType::kBFloat16, false, false), + MMAConfig(16, 8, 16, DataType::kBFloat16, false, false), + MMAConfig(16, 8, 4, DataType::kTensorFloat32, false, false), + MMAConfig(16, 8, 8, DataType::kTensorFloat32, false, false), + MMAConfig(8, 8, 16, DataType::kInt8, false, false), + MMAConfig(16, 8, 16, DataType::kInt8, false, false), + MMAConfig(16, 8, 32, DataType::kInt8, false, false), + MMAConfig(8, 8, 16, DataType::kUInt8, false, false), + MMAConfig(16, 8, 16, DataType::kUInt8, false, false), + MMAConfig(16, 8, 32, DataType::kUInt8, false, false), + MMAConfig(8, 8, 32, DataType::kInt4, false, false), + MMAConfig(16, 8, 32, DataType::kInt4, false, false), + MMAConfig(16, 8, 64, DataType::kInt4, false, false), + MMAConfig(8, 8, 32, DataType::kUInt4, false, false), + MMAConfig(16, 8, 32, DataType::kUInt4, false, false), + MMAConfig(16, 8, 64, DataType::kUInt4, false, false), + MMAConfig(8, 8, 128, DataType::kBit1, true, false), + MMAConfig(16, 8, 128, DataType::kBit1, true, false), + MMAConfig(16, 8, 256, DataType::kBit1, true, false), + MMAConfig(16, 8, 16, DataType::kFloat16, false, true), + MMAConfig(16, 8, 32, DataType::kFloat16, false, true), + MMAConfig(16, 8, 16, DataType::kBFloat16, false, true), + MMAConfig(16, 8, 32, DataType::kBFloat16, false, true), + MMAConfig(16, 8, 8, DataType::kTensorFloat32, false, true), + MMAConfig(16, 8, 16, DataType::kTensorFloat32, false, true), + MMAConfig(16, 8, 32, DataType::kInt8, false, true), + MMAConfig(16, 8, 64, DataType::kInt8, false, true), + MMAConfig(16, 8, 32, DataType::kUInt8, false, true), + MMAConfig(16, 8, 64, DataType::kUInt8, false, true), + MMAConfig(16, 8, 64, DataType::kInt4, false, true), + MMAConfig(16, 8, 128, DataType::kInt4, false, true), + MMAConfig(16, 8, 64, DataType::kUInt4, false, true), + MMAConfig(16, 8, 128, DataType::kUInt4, false, true), +}; + +/*! + * \brief Check whether the multiplicand data type and accumulator data type is valid for MMA + * computation. + * \param dtype_a The data type of multiplicand a. + * \param dtype_b The data type of multiplicand b. + * \param dtype_c The data type of accumulator c. 
+ * \note Reference: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types + */ +void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_c) { + std::string ab_not_match_err_str = "The multiplicands' data type " + DTypeToString(dtype_a) + + DTypeToString(dtype_b) + " do not match."; + // check a and b + switch (dtype_a) { + case DataType::kBit1: + case DataType::kFloat16: + case DataType::kBFloat16: + case DataType::kTensorFloat32: + case DataType::kFloat64: + CHECK(dtype_a == dtype_b) << ab_not_match_err_str; + break; + case DataType::kInt4: + case DataType::kUInt4: + CHECK(dtype_b == DataType::kInt4 || dtype_b == DataType::kUInt4) << ab_not_match_err_str; + break; + case DataType::kInt8: + case DataType::kUInt8: + CHECK(dtype_b == DataType::kInt8 || dtype_b == DataType::kUInt8) << ab_not_match_err_str; + break; + default: + CHECK(false) << "Invalid multiplicand data types: " << DTypeToString(dtype_a) + << DTypeToString(dtype_b); + } + // check a,b and c + switch (dtype_a) { + case DataType::kBit1: + case DataType::kInt4: + case DataType::kUInt4: + case DataType::kInt8: + case DataType::kUInt8: + CHECK(dtype_c == DataType::kInt32) + << "For multiplicand data type " << DTypeToString(dtype_a) << DTypeToString(dtype_b) + << ", accumulator data type should be s32."; + break; + case DataType::kFloat16: + CHECK(dtype_c == DataType::kFloat16 || dtype_c == DataType::kFloat32) + << "For multiplicand data type f16, accumulator data type should be f16/f32."; + break; + case DataType::kBFloat16: + case DataType::kTensorFloat32: + CHECK(dtype_c == DataType::kFloat32) + << "For multiplicand data type bf16/tf32, accumulator data type can only be f32."; + break; + case DataType::kFloat64: + CHECK(dtype_c == DataType::kFloat64) + << "For multiplicand data type f64, accumulator data type can only be f64."; + break; + default: + CHECK(false) << "Invalid multiplicand/accumulator data types: " << DTypeToString(dtype_a) + << DTypeToString(dtype_b) << DTypeToString(dtype_c) << "."; } - asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); - asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); - asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); - asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); - asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); - asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); - return asm_code; } -std::string PrintMMAm8n8k32Assembly(const std::string& A_layout, const std::string& B_layout, - const std::string& A_dtype, const std::string& B_dtype, - const std::string& C_dtype, const std::string& a_ref, - const std::string& a_bias, const std::string& b_ref, - const std::string& b_bias, const std::string& c_ref, - const std::string& c_bias, bool saturate) { - std::string asm_code = ""; - std::string new_a_ref = ""; - std::string new_b_ref = ""; - std::string new_c_ref = ""; - ICHECK(((A_dtype == "int4") && (B_dtype == "int4")) || - ((A_dtype == "uint4") && (B_dtype == "int4")) || - ((A_dtype == "int4") && (B_dtype == "uint4")) || - ((A_dtype == "uint4") && (B_dtype == "uint4"))); - if ((A_dtype == "int4") && (B_dtype == "int4")) { - // A/B multiplicand is int4, SM 75 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM75 Tensor Core instructions " - << "with shape m8n8k32 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = 
"((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.s4.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } - } else if ((A_dtype == "uint4") && (B_dtype == "int4")) { - // A multiplicand is uint4, B multiplicand is int4 - // SM 75 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM75 Tensor Core instructions " - << "with shape m8n8k32 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } - } else if ((A_dtype == "int4") && (B_dtype == "uint4")) { - // A multiplicand is int4, B multiplicand is uint4 - // SM 75 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM75 Tensor Core instructions " - << "with shape m8n8k32 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.u4.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } - } else { - // A/B multiplicand is uint4, SM 75 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && 
(B_layout == "col")) - << "SM75 Tensor Core instructions " - << "with shape m8n8k32 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.u4.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; +/*! + * \brief Check whether the given configuration is valid for MMA computation. + * \param m The M in mMnNkK of MMA instructions. + * \param n The N in mMnNkK of MMA instructions. + * \param k The K in mMnNkK of MMA instructions. + * \param layout_a The layout of multiplicand A (row/col). + * \param layout_b The layout of multiplicand B (row/col). + * \param dtype_a The data type of multiplicand A. + * \param dtype_b The data type of multiplicand B. + * \param dtype_c The data type of accumulator C. + * \param bit_op The bit operator for 1-bit MMA computation, can be "xor"/"and" or ""(if it's not + * 1-bit MMA). + * \param sparse Whether it's Sparse MMA or not. + * \param saturate Whether saturate output or not. + */ +void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType layout_b, + DataType dtype_a, DataType dtype_b, DataType dtype_c, + const std::string& bit_op, bool sparse, bool saturate) { + CHECK(bit_op == "xor" || bit_op == "and" || bit_op == "") + << "Unrecognized 1-bit operation " << bit_op << " , can only be xor/and."; + bool use_bit_op = !bit_op.empty(); + if (use_bit_op) { + CHECK(dtype_a == DataType::kBit1) << "Bit operator is only compatible with 1-bit multiplicand."; + } + CheckMMADTypeCompatible(dtype_a, dtype_b, dtype_c); + if (saturate) { + CHECK(dtype_a == DataType::kInt4 || dtype_a == DataType::kUInt4 || dtype_a == DataType::kInt8 || + dtype_a == DataType::kUInt8) + << "Output saturation only applicable to multiplicand type s4/u4/s8/u8."; + } + + if (!(m == 8 && n == 8 && k == 4 && dtype_a == ptx::DataType::kFloat16)) { + // Only MMA on m8n8k4 for fp16 supports customized layouts. 
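+  // Note: PTX allows any row/col combination for A and B only for the
+  // m8n8k4 fp16 instruction; every other shape requires a row-major A and a
+  // column-major B, which is what the check below enforces.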
+ CHECK(layout_a == LayoutType::kRowMajor && layout_b == LayoutType::kColumnMajor) + << "Invalid layout combination " << LayoutTypeToString(layout_a) << "," + << LayoutTypeToString(layout_b) << "."; + } + + MMAConfig config(m, n, k, dtype_a, use_bit_op, sparse); + bool match = false; + for (const MMAConfig& valid_config : valid_mma_configs) { + if (config == valid_config) { + match = true; + break; } } - asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); - asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); - asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); - asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); - asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); - asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); - return asm_code; + CHECK(match) << "Cannot find matched MMA configurations."; } -std::string PrintMMAm16n8k4Assembly(const std::string& A_layout, const std::string& B_layout, - const std::string& A_dtype, const std::string& B_dtype, - const std::string& C_dtype, const std::string& a_ref, - const std::string& a_bias, const std::string& b_ref, - const std::string& b_bias, const std::string& c_ref, - const std::string& c_bias, bool saturate) { - std::string asm_code = ""; - std::string new_a_ref = ""; - std::string new_b_ref = ""; - std::string new_c_ref = ""; - ICHECK((A_dtype == "tf32") && (B_dtype == "tf32")); - ICHECK(saturate == false) << "Saturate is not allowed for m16n8k4 mma."; - // A/B multiplicand is tf32, SM 80 Tensor Core instructions - ICHECK(C_dtype == "fp32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k4 expect A layout is row major and B layout is col major."; - // C accumulator is fp32 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "f"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - } - )"; - asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); - asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); - asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); - asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); - asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); - asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); - return asm_code; -} +/*! + * \brief Fragment attributes + */ +class FragAttrs { + public: + explicit FragAttrs(char reg_type, uint32_t size, std::string ptr_type) + : reg_type(reg_type), size(size), ptr_type(ptr_type) {} + /*! \brief PTX register type */ + char reg_type; + /*! \brief Fragment size */ + uint32_t size; + /*! 
\brief Fragment pointer type */ + std::string ptr_type; +}; -std::string PrintMMAm16n8k16Assembly(const std::string& A_layout, const std::string& B_layout, - const std::string& A_dtype, const std::string& B_dtype, - const std::string& C_dtype, const std::string& a_ref, - const std::string& a_bias, const std::string& b_ref, - const std::string& b_bias, const std::string& c_ref, - const std::string& c_bias, bool saturate) { - std::string asm_code = ""; - std::string new_a_ref = ""; - std::string new_b_ref = ""; - std::string new_c_ref = ""; - ICHECK(((A_dtype == "fp16") && (B_dtype == "fp16")) || - ((A_dtype == "bf16") && (B_dtype == "bf16")) || - ((A_dtype == "int8") && (B_dtype == "int8")) || - ((A_dtype == "uint8") && (B_dtype == "int8")) || - ((A_dtype == "int8") && (B_dtype == "uint8")) || - ((A_dtype == "uint8") && (B_dtype == "uint8"))); - if ((A_dtype == "fp16") && (B_dtype == "fp16")) { - ICHECK(saturate == false) << "Saturate is not allowed for m16n8k8 fp16 mma."; - // A/B multiplicand is fp16, SM 80 Tensor Core instructions - ICHECK((C_dtype == "fp16") || (C_dtype == "fp32")); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k16 expect A layout is row major and B layout is col major."; - if (C_dtype == "fp16") { - // C accumulator is fp16 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((unsigned *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " - "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, " - "{%8,%9};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // C accumulator is fp32 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - } - )"; - } - } else if ((A_dtype == "bf16") && (B_dtype == "bf16")) { - // A/B multiplicand is bf16, SM 80 Tensor Core instructions - ICHECK(saturate == false) << "Saturate is not allowed for m16n8k8 bf16 mma."; - ICHECK(C_dtype == "fp32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k16 expect A layout is row major and B layout is col major."; - // C accumulator is fp32 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - } - )"; - } else if ((A_dtype == "int8") && (B_dtype == "int8")) { - // A/B multiplicand is int8, SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && 
(B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k16 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } - } else if ((A_dtype == "uint8") && (B_dtype == "int8")) { - // A multiplicand is uint8, B multiplicand is int8 - // SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k16 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } - } else if ((A_dtype == "int8") && (B_dtype == "uint8")) { - // A multiplicand is int8, B multiplicand is uint8 - // SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k16 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - 
)"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } - } else { - // A/B multiplicand is uint8, SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k16 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } +/*! + * \brief Fragment attributes of given data type. 
+ */ +inline FragAttrs GetFragAttrs(DataType dtype) { + switch (dtype) { + case DataType::kBit1: + case DataType::kInt4: + case DataType::kUInt4: + case DataType::kInt8: + case DataType::kUInt8: + case DataType::kFloat16: // .f16x2 register + case DataType::kBFloat16: + case DataType::kTensorFloat32: + return FragAttrs('r', 32, "(unsigned *)"); + case DataType::kInt32: + return FragAttrs('r', 32, "(int *)"); + case DataType::kFloat32: + return FragAttrs('f', 32, "(float *)"); + case DataType::kFloat64: + return FragAttrs('d', 64, "(double *)"); + default: + ICHECK(false) << DTypeToString(dtype) << " is not matrix data type in MMA."; + return FragAttrs('\0', 0, ""); } - asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); - asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); - asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); - asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); - asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); - asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); - return asm_code; } -std::string PrintMMAm16n8k32Assembly(const std::string& A_layout, const std::string& B_layout, - const std::string& A_dtype, const std::string& B_dtype, - const std::string& C_dtype, const std::string& a_ref, - const std::string& a_bias, const std::string& b_ref, - const std::string& b_bias, const std::string& c_ref, - const std::string& c_bias, bool saturate) { - std::string asm_code = ""; - std::string new_a_ref = ""; - std::string new_b_ref = ""; - std::string new_c_ref = ""; - ICHECK(((A_dtype == "int8") && (B_dtype == "int8")) || - ((A_dtype == "uint8") && (B_dtype == "int8")) || - ((A_dtype == "int8") && (B_dtype == "uint8")) || - ((A_dtype == "uint8") && (B_dtype == "uint8"))); - if ((A_dtype == "int8") && (B_dtype == "int8")) { - // A/B multiplicand is int8, SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k32 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } - } else if ((A_dtype == "uint8") && (B_dtype == "int8")) { - // A multiplicand is uint8, B multiplicand is int8 - // SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k32 
expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } - } else if ((A_dtype == "int8") && (B_dtype == "uint8")) { - // A multiplicand is int8, B multiplicand is uint8 - // SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k32 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; +}; // namespace ptx + +/*! + * \brief Replace patterns with replacement strings. + * \note should use std::format instead when codebase is ported to C++20. 
+ */ +class Replacer { + public: + void register_rule(const std::string& pattern, const std::string& replacement) { + _rules.emplace_back(pattern, replacement); + } + std::string rewrite(std::string str) { + for (auto&& rule : _rules) { + std::string pattern, replacement; + std::tie(pattern, replacement) = rule; + size_t len = pattern.size(); + size_t new_len = replacement.size(); + size_t pos = str.find(pattern); + while (pos != std::string::npos) { + str = str.replace(pos, len, replacement); + pos = str.find(pattern, pos + new_len); + } } + return str; + } + void empty_rules() { _rules.clear(); } + + private: + std::vector> _rules; +}; + +/*! + * \brief Get the number of MMA computations for given shape and datatype. + */ +inline uint32_t GetNumMMAComputations(int m, int n, int k, ptx::DataType dtype) { + if (m == 8 && n == 8 && k == 4 && dtype == ptx::DataType::kFloat16) { + // MMA for m8n8k4 on fp16 would launch 4 MMA computations instead of one. + return 4; } else { - // A/B multiplicand is uint8, SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k32 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } + return 1; } - asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); - asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); - asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); - asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); - asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); - asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); - return asm_code; } -std::string PrintMMAm16n8k64Assembly(const std::string& A_layout, const std::string& B_layout, - const std::string& A_dtype, const std::string& B_dtype, - const std::string& C_dtype, const std::string& a_ref, - const std::string& a_bias, const std::string& b_ref, - const std::string& b_bias, const std::string& c_ref, - const std::string& c_bias, bool saturate) { - std::string asm_code = ""; - std::string new_a_ref = ""; - std::string new_b_ref = ""; - std::string new_c_ref = ""; - ICHECK(((A_dtype == "int4") && (B_dtype == "int4")) || - ((A_dtype == "uint4") && (B_dtype == "int4")) || - ((A_dtype == "int4") && (B_dtype == "uint4")) || - ((A_dtype == "uint4") && (B_dtype == "uint4"))); - if ((A_dtype == "int4") 
&& (B_dtype == "int4")) { - // A/B multiplicand is int4, SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k64 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } - } else if ((A_dtype == "uint4") && (B_dtype == "int4")) { - // A multiplicand is uint4, B multiplicand is int4 - // SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k64 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } - } else if ((A_dtype == "int4") && (B_dtype == "uint4")) { - // A multiplicand is int4, B multiplicand is uint4 - // SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k64 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = 
"((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } - } else { - // A/B multiplicand is uint4, SM 75 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k64 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; +/*! + * \brief Return template string, input operands string and output operands string. + * \param m The M in mMnNkK of MMA instructions. + * \param n The N in mMnNkK of MMA instructions. + * \param k The K in mMnNkK of MMA instructions. + * \param dtype_a The data type of multiplicand a. + * \param dtype_b The data type of multiplicand b. + * \param dtype_c The data type of accumulator c. + * \param sparse Whether it's Sparse MMA or not. + */ +inline std::tuple GetMMAOperands(int m, int n, int k, + ptx::DataType dtype_a, + ptx::DataType dtype_b, + ptx::DataType dtype_c, + bool sparse) { + std::stringstream templates, inputs, outputs; + const ptx::FragAttrs frag_attr_a = ptx::GetFragAttrs(dtype_a), + frag_attr_b = ptx::GetFragAttrs(dtype_b), + frag_attr_c = ptx::GetFragAttrs(dtype_c); + constexpr uint32_t warp_size = 32; + const uint32_t threads = warp_size / GetNumMMAComputations(m, n, k, dtype_a); + const int num_operands_a = + (m * k) * ptx::DTypeBits(dtype_a) / frag_attr_a.size / threads / (sparse ? 
2 : 1), + num_operands_b = (k * n) * ptx::DTypeBits(dtype_b) / frag_attr_b.size / threads, + num_operands_c = (m * n) * ptx::DTypeBits(dtype_c) / frag_attr_c.size / threads; + + // generate templates; + int arg_counter = 0; + templates << "{" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_c; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}, {" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_a; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}, {" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_b; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}, {" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_c; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}"; + // templates of metadata and sparse selector for sparse mma. + if (sparse) { + templates << ", %" << (arg_counter++) << ", F"; + } + + // generate inputs + for (int i = 0; i < num_operands_a; ++i) { + if (i != 0) { + inputs << ", "; } + inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_type << "(A))[" << i + << "])"; + } + for (int i = 0; i < num_operands_b; ++i) { + inputs << ", \"" << frag_attr_b.reg_type << "\"((" << frag_attr_b.ptr_type << "(B))[" << i + << "])"; + } + for (int i = 0; i < num_operands_c; ++i) { + inputs << ", \"" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type << "(C))[" << i + << "])"; + } + // input of metadata for sparse mma. + if (sparse) { + inputs << ", \"r\"(((unsigned *)(E))[0])"; } - asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); - asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); - asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); - asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); - asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); - asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); - return asm_code; -} -std::string PrintMMAm16n8k256Assembly(const std::string& A_layout, const std::string& B_layout, - const std::string& A_dtype, const std::string& B_dtype, - const std::string& C_dtype, const std::string& a_ref, - const std::string& a_bias, const std::string& b_ref, - const std::string& b_bias, const std::string& c_ref, - const std::string& c_bias, bool saturate) { - std::string asm_code = ""; - std::string new_a_ref = ""; - std::string new_b_ref = ""; - std::string new_c_ref = ""; - ICHECK(((A_dtype == "uint1") && (B_dtype == "uint1")) || - ((A_dtype == "int1") && (B_dtype == "int1"))); - if ((A_dtype == "uint1") && (B_dtype == "uint1")) { - // A/B multiplicand is uint1, SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k256 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // A/B multiplicand is int1, SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == 
"row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k256 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; + // generate outputs + for (int i = 0; i < num_operands_c; ++i) { + if (i != 0) { + outputs << ","; + } + outputs << " \"=" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type << "(D))[" << i + << "])"; } - asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); - asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); - asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); - asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); - asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); - asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); - return asm_code; + return std::make_tuple(templates.str(), inputs.str(), outputs.str()); } std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout, const std::string& B_layout, const std::string& A_dtype, const std::string& B_dtype, const std::string& C_dtype, - const std::string& a_ref, const std::string& a_bias, - const std::string& b_ref, const std::string& b_bias, - const std::string& c_ref, const std::string& c_bias, bool saturate) { - ICHECK((shape == "m8n8k4") || (shape == "m16n8k8") || (shape == "m8n8k16") || - (shape == "m8n8k32") || (shape == "m16n8k4") || (shape == "m16n8k16") || - (shape == "m16n8k32") || (shape == "m16n8k64") || (shape == "m16n8k256")); - ICHECK((A_layout == "row") || (A_layout == "col")) << "Unknown A layout: " << A_layout; - ICHECK((B_layout == "row") || (B_layout == "col")) << "Unknown B layout: " << B_layout; - - if (shape == "m8n8k4") { - return PrintMMAm8n8k4Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, - b_ref, b_bias, c_ref, c_bias, saturate); - } else if (shape == "m16n8k8") { - return PrintMMAm16n8k8Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, - b_ref, b_bias, c_ref, c_bias, saturate); - } else if (shape == "m8n8k16") { - return PrintMMAm8n8k16Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, - b_ref, b_bias, c_ref, c_bias, saturate); - } else if (shape == "m8n8k32") { - return PrintMMAm8n8k32Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, - b_ref, b_bias, c_ref, c_bias, saturate); - } else if (shape == "m16n8k4") { - return PrintMMAm16n8k4Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, - b_ref, b_bias, c_ref, c_bias, saturate); - } else if (shape == "m16n8k16") { - return PrintMMAm16n8k16Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, - b_ref, b_bias, c_ref, c_bias, saturate); - } else if (shape == "m16n8k32") { - return PrintMMAm16n8k32Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, - b_ref, b_bias, c_ref, c_bias, saturate); - } else if (shape == "m16n8k64") { - return PrintMMAm16n8k64Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, - b_ref, b_bias, 
c_ref, c_bias, saturate); - } else if (shape == "m16n8k256") { - return PrintMMAm16n8k256Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, - b_ref, b_bias, c_ref, c_bias, saturate); + const std::string& a_ref, const std::string& a_offset, + const std::string& b_ref, const std::string& b_offset, + const std::string& c_ref, const std::string& c_offset, + const std::string& metadata, const std::string& metadata_offset, + const std::string& sparsity_selector, const std::string& bit_op, + bool sparse, bool saturate) { + ptx::DataType dtype_a = ptx::DTypeFromString(A_dtype), dtype_b = ptx::DTypeFromString(B_dtype), + dtype_c = ptx::DTypeFromString(C_dtype); + ptx::LayoutType layout_a = ptx::LayoutTypeFromString(A_layout), + layout_b = ptx::LayoutTypeFromString(B_layout); + int m, n, k; + std::tie(m, n, k) = ptx::ParseMMAShape(shape); + CheckMMAConfigValidity(m, n, k, layout_a, layout_b, dtype_a, dtype_b, dtype_c, bit_op, sparse, + saturate); + std::string asm_code = R"( + { + __asm__ __volatile__( + "mma{sparse}.sync.aligned.{shape}.{alayout}.{blayout}{saturate}{dtype}{atype}{btype}{ctype}{bitop}" + "{templates};\n" + : {outputs} + : {inputs}); } - /* - * TODO: add mma.m16n8k128 - */ - throw Error("Unknown PTX mma instructions."); +)"; + std::string templates_str, inputs_str, outputs_str; + std::tie(templates_str, inputs_str, outputs_str) = + GetMMAOperands(m, n, k, dtype_a, dtype_b, dtype_c, sparse); + + // replace patterns + Replacer replacer; + replacer.register_rule("{sparse}", sparse ? ".sp" : ""); + replacer.register_rule("{shape}", shape); + replacer.register_rule("{saturate}", saturate ? ".satfinite" : ""); + replacer.register_rule("{alayout}", A_layout); + replacer.register_rule("{blayout}", B_layout); + replacer.register_rule("{atype}", ptx::DTypeToString(dtype_a)); + replacer.register_rule("{btype}", ptx::DTypeToString(dtype_b)); + replacer.register_rule("{ctype}", ptx::DTypeToString(dtype_c)); + replacer.register_rule("{dtype}", ptx::DTypeToString(dtype_c)); + replacer.register_rule("{bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc"); + replacer.register_rule("{templates}", templates_str); + replacer.register_rule("{outputs}", outputs_str); + replacer.register_rule("{inputs}", inputs_str); + asm_code = replacer.rewrite(asm_code); + replacer.empty_rules(); + replacer.register_rule("A", a_ref + " + " + a_offset); + replacer.register_rule("B", b_ref + " + " + b_offset); + replacer.register_rule("C", c_ref + " + " + c_offset); + replacer.register_rule("D", c_ref + " + " + c_offset); + replacer.register_rule("E", metadata + " + " + metadata_offset); + replacer.register_rule("F", sparsity_selector); + asm_code = replacer.rewrite(asm_code); + return asm_code; } } // namespace codegen diff --git a/src/target/source/ptx_mma.h b/src/target/source/ptx_mma.h index d2a7a6705d6d..728478cdf5fb 100644 --- a/src/target/source/ptx_mma.h +++ b/src/target/source/ptx_mma.h @@ -32,12 +32,36 @@ namespace tvm { namespace codegen { +/*! + * \brief Print MMA assembly string given parameters. + * \param shape The shape string mMnNkK + * \param A_layout The layout of multiplicand A, can be either "row" or "col". + * \param B_layout The layout of multiplicand B, can be either "row" or "col". + * \param A_dtype The data type of multiplicand A. + * \param B_dtype The data type of multiplicand B. + * \param C_dtype The data type of multiplicand C. + * \param a_ref Pointer to buffer A. + * \param a_offset The offset of element in A. + * \param b_ref Pointer to buffer B. 
+ * \param b_offset The offset of element in B.
+ * \param c_ref Pointer to buffer C.
+ * \param c_offset The offset of element in C.
+ * \param metadata Pointer to metadata buffer (only used for sparse mma).
+ * \param metadata_offset The offset of element in metadata.
+ * \param sparsity_selector The sparsity selector in sparse mma.
+ * \param bit_op The bit operator used in 1-bit mma, can be either "xor" or "and".
+ * \param sparse Whether it's sparse mma or not.
+ * \param saturate Whether saturate output or not.
+ */
 std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout,
                              const std::string& B_layout, const std::string& A_dtype,
                              const std::string& B_dtype, const std::string& C_dtype,
-                             const std::string& a_ref, const std::string& a_bias,
-                             const std::string& b_ref, const std::string& b_bias,
-                             const std::string& c_ref, const std::string& c_bias, bool saturate);
+                             const std::string& a_ref, const std::string& a_offset,
+                             const std::string& b_ref, const std::string& b_offset,
+                             const std::string& c_ref, const std::string& c_offset,
+                             const std::string& metadata, const std::string& metadata_offset,
+                             const std::string& sparsity_selector, const std::string& bit_op,
+                             bool sparse, bool saturate);
 
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 0e767ead4e6b..977050a2d2ce 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -237,6 +237,9 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync)
 TIR_DEFINE_BUILTIN_FUNC(ptx_mma).set_attr<TCallEffectKind>("TCallEffectKind",
                                                            Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
 
diff --git a/tests/python/unittest/test_tir_ptx_mma.py b/tests/python/unittest/test_tir_ptx_mma.py
index 8f653c614d42..23405fdee98a 100644
--- a/tests/python/unittest/test_tir_ptx_mma.py
+++ b/tests/python/unittest/test_tir_ptx_mma.py
@@ -1311,6 +1311,7 @@ def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle):
             Accum.data,
             0,
             False,
+            "xor",
             dtype="int32",
         )
     )
diff --git a/tests/python/unittest/test_tir_ptx_mma_sp.py b/tests/python/unittest/test_tir_ptx_mma_sp.py
new file mode 100644
index 000000000000..321cd28ff6f7
--- /dev/null
+++ b/tests/python/unittest/test_tir_ptx_mma_sp.py
@@ -0,0 +1,346 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
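+#
+# The tests in this file exercise the new ptx_mma_sp intrinsic with 2:4
+# structured sparsity on the A operand:
+#   * gen_2in4_mask picks, for every 4-wide chunk of a row, the two column
+#     indices that hold the nonzero values.
+#   * get_dense_mat_by_mask scatters the compressed A values back into a dense
+#     matrix so that np.matmul can serve as the reference result.
+#   * get_meta_m16n8k16_half / get_meta_m16n8k32_half pack those indices, two
+#     bits per kept element, into the uint32 metadata words that are handed to
+#     T.ptx_mma_sp together with the sparsity selector.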
+ +import tvm +from tvm.script import tir as T +import numpy as np +import tvm.testing + + +def gen_2in4_mask(m: int, n: int): + assert n % 4 == 0 + return np.array( + [[np.sort(np.random.choice(4, 2, replace=False)) for _ in range(n // 4)] for _ in range(m)] + ).astype("uint8") + + +def get_dense_mat_by_mask(val, mask): + m, n_chunks, _ = mask.shape + val = val.reshape(m, n_chunks, 2) + ret = np.zeros((m, n_chunks, 4)).astype(val.dtype) + for i in range(m): + for j in range(n_chunks): + for k in range(2): + ret[i, j, mask[i, j, k]] = val[i, j, k] + return ret.reshape(m, n_chunks * 4) + + +@T.prim_func +def mma_sp_m16n8k16_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 8], dtype="float16") + B = T.match_buffer(b, [16, 8], dtype="float16") + C = T.match_buffer(c, [16, 8], dtype="float16") + metadata = T.match_buffer(_metadata, [8], dtype="uint32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + multi_a = T.allocate([4], "float16", scope="local") + multi_b = T.allocate([4], "float16", scope="local") + accum = T.allocate([4], "float16", scope="local") + meta_local = T.allocate([1], "uint32", scope="local") + for i in range(4): + accum[i] = T.float16(0) + + for i in range(4): + multi_a[i] = A[tx // 4 + i // 2 * 8, tx % 4 * 2 + i % 2] + + for i in range(4): + multi_b[i] = B[tx % 4 * 2 + i % 2 + i // 2 * 8, tx // 4] + + meta_local[0] = metadata[tx // 4] + + T.evaluate( + T.ptx_mma_sp( + "m16n8k16", + "row", + "col", + "fp16", + "fp16", + "fp16", + multi_a.data, + 0, + multi_b.data, + 0, + accum.data, + 0, + meta_local.data, + 0, + 0, + False, + dtype="float16", + ) + ) + + for i in range(4): + C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i] + + +@T.prim_func +def mma_sp_m16n8k16_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 8], dtype="float16") + B = T.match_buffer(b, [16, 8], dtype="float16") + C = T.match_buffer(c, [16, 8], dtype="float32") + metadata = T.match_buffer(_metadata, [8], dtype="uint32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + multi_a = T.allocate([4], "float16", scope="local") + multi_b = T.allocate([4], "float16", scope="local") + accum = T.allocate([4], "float32", scope="local") + meta_local = T.allocate([1], "uint32", scope="local") + for i in range(4): + accum[i] = T.float16(0) + + for i in range(4): + multi_a[i] = A[tx // 4 + i // 2 * 8, tx % 4 * 2 + i % 2] + + for i in range(4): + multi_b[i] = B[tx % 4 * 2 + i % 2 + i // 2 * 8, tx // 4] + + meta_local[0] = metadata[tx // 4] + + T.evaluate( + T.ptx_mma_sp( + "m16n8k16", + "row", + "col", + "fp16", + "fp16", + "fp32", + multi_a.data, + 0, + multi_b.data, + 0, + accum.data, + 0, + meta_local.data, + 0, + 0, + False, + dtype="float32", + ) + ) + + for i in range(4): + C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i] + + +@T.prim_func +def mma_sp_m16n8k32_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 16], dtype="float16") + B = T.match_buffer(b, [32, 8], 
dtype="float16") + C = T.match_buffer(c, [16, 8], dtype="float16") + metadata = T.match_buffer(_metadata, [16], dtype="uint32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + multi_a = T.allocate([8], "float16", scope="local") + multi_b = T.allocate([8], "float16", scope="local") + accum = T.allocate([4], "float16", scope="local") + meta_local = T.allocate([1], "uint32", scope="local") + for i in range(4): + accum[i] = T.float16(0) + + for i in range(8): + multi_a[i] = A[(i % 4) // 2 * 8 + tx // 4, i // 4 * 8 + tx % 4 * 2 + i % 2] + + for i in range(8): + multi_b[i] = B[i // 2 * 8 + tx % 4 * 2 + i % 2, tx // 4] + + meta_local[0] = metadata[tx // 4 * 2 + tx % 2] + + T.evaluate( + T.ptx_mma_sp( + "m16n8k32", + "row", + "col", + "fp16", + "fp16", + "fp16", + multi_a.data, + 0, + multi_b.data, + 0, + accum.data, + 0, + meta_local.data, + 0, + 0, + False, + dtype="float16", + ) + ) + + for i in range(4): + C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i] + + +@T.prim_func +def mma_sp_m16n8k32_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 16], dtype="float16") + B = T.match_buffer(b, [32, 8], dtype="float16") + C = T.match_buffer(c, [16, 8], dtype="float32") + metadata = T.match_buffer(_metadata, [16], dtype="uint32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + multi_a = T.allocate([8], "float16", scope="local") + multi_b = T.allocate([8], "float16", scope="local") + accum = T.allocate([4], "float32", scope="local") + meta_local = T.allocate([1], "uint32", scope="local") + for i in range(4): + accum[i] = T.float16(0) + + for i in range(8): + multi_a[i] = A[(i % 4) // 2 * 8 + tx // 4, i // 4 * 8 + tx % 4 * 2 + i % 2] + + for i in range(8): + multi_b[i] = B[i // 2 * 8 + tx % 4 * 2 + i % 2, tx // 4] + + meta_local[0] = metadata[tx // 4 * 2 + tx % 2] + + T.evaluate( + T.ptx_mma_sp( + "m16n8k32", + "row", + "col", + "fp16", + "fp16", + "fp32", + multi_a.data, + 0, + multi_b.data, + 0, + accum.data, + 0, + meta_local.data, + 0, + 0, + False, + dtype="float32", + ) + ) + + for i in range(4): + C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i] + + +@tvm.testing.requires_cuda +def test_mma_sp_m16n8k16_f16(): + def get_meta_m16n8k16_half(mask): + assert mask.shape == (16, 4, 2) + mask = mask.reshape(16, 8) + ret = np.zeros((8,)).astype("uint32") + + for i in range(8): + base = 1 + for blk in range(2): + for j in range(8): + ret[i] |= int(mask[blk * 8 + i, j]) * base + base = base << 2 + return ret + + for out_dtype in ["float16", "float32"]: + func = mma_sp_m16n8k16_f16f16f16 if out_dtype == "float16" else mma_sp_m16n8k16_f16f16f32 + sch = tvm.tir.Schedule(func) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # Requires SM80+ + return + cuda_mod = tvm.build(sch.mod, target="cuda") + + A_np = np.random.uniform(-1, 1, [16, 8]).astype("float16") + B_np = np.random.uniform(-1, 1, [16, 8]).astype("float16") + mask = gen_2in4_mask(16, 16) + A_dense_np = get_dense_mat_by_mask(A_np, mask) + C_np = np.matmul(A_dense_np, B_np).astype(out_dtype) + meta = get_meta_m16n8k16_half(mask) + + ctx = tvm.cuda() + A_tvm = 
tvm.nd.array(A_np, ctx) + B_tvm = tvm.nd.array(B_np, ctx) + C_tvm = tvm.nd.array(np.zeros_like(C_np), ctx) + meta_tvm = tvm.nd.array(meta, ctx) + cuda_mod(A_tvm, B_tvm, C_tvm, meta_tvm) + + tvm.testing.assert_allclose(C_tvm.numpy(), C_np, atol=1e-3, rtol=1e-3) + + +@tvm.testing.requires_cuda +def test_mma_sp_m16n8k32_f16(): + def get_meta_m16n8k32_half(mask): + assert mask.shape == (16, 8, 2) + mask = mask.reshape(16, 2, 8) + ret = np.zeros((8, 2)).astype("uint32") + + for i in range(8): + for k in range(2): + base = 1 + for blk in range(2): + for j in range(8): + ret[i, k] |= int(mask[blk * 8 + i, k, j]) * base + base = base << 2 + + return ret.reshape(16) + + for out_dtype in ["float16", "float32"]: + func = mma_sp_m16n8k32_f16f16f16 if out_dtype == "float16" else mma_sp_m16n8k32_f16f16f32 + sch = tvm.tir.Schedule(func) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # Requires SM80+ + return + cuda_mod = tvm.build(sch.mod, target="cuda") + + A_np = np.random.uniform(-1, 1, [16, 16]).astype("float16") + B_np = np.random.uniform(-1, 1, [32, 8]).astype("float16") + mask = gen_2in4_mask(16, 32) + A_dense_np = get_dense_mat_by_mask(A_np, mask) + C_np = np.matmul(A_dense_np, B_np).astype(out_dtype) + meta = get_meta_m16n8k32_half(mask) + + ctx = tvm.cuda() + A_tvm = tvm.nd.array(A_np, ctx) + B_tvm = tvm.nd.array(B_np, ctx) + C_tvm = tvm.nd.array(np.zeros_like(C_np), ctx) + meta_tvm = tvm.nd.array(meta, ctx) + cuda_mod(A_tvm, B_tvm, C_tvm, meta_tvm) + + tvm.testing.assert_allclose(C_tvm.numpy(), C_np, atol=1e-3, rtol=1e-3) + + +if __name__ == "__main__": + test_mma_sp_m16n8k16_f16() + test_mma_sp_m16n8k32_f16()
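
For concreteness, the following is a minimal sketch of the inline asm the templated PrintMMAAssembly path is expected to produce for the sparse m16n8k16 fp16 case exercised by mma_sp_m16n8k16_f16f16f16. It assumes ptx::GetFragAttrs maps fp16 fragments to 32-bit "r" registers with "(unsigned *)" casts, in line with the hand-written variants this patch deletes; A, B, C, D and E stand for the substituted "buffer + offset" pointer expressions, and the trailing 0 is the sparsity selector passed by the test. The exact spacing and operand text are illustrative, not a verbatim copy of the generated string.

  {
    __asm__ __volatile__(
        "mma.sp.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
        "{%0, %1}, {%2, %3}, {%4, %5}, {%6, %7}, %8, 0;\n"
        : "=r"(((unsigned *)(D))[0]), "=r"(((unsigned *)(D))[1])
        : "r"(((unsigned *)(A))[0]), "r"(((unsigned *)(A))[1]),
          "r"(((unsigned *)(B))[0]), "r"(((unsigned *)(B))[1]),
          "r"(((unsigned *)(C))[0]), "r"(((unsigned *)(C))[1]),
          "r"(((unsigned *)(E))[0]));
  }

The operand grouping follows GetMMAOperands: with 2:4 sparsity the A fragment shrinks to two 32-bit registers (four halves), B and the fp16 accumulator each occupy two registers, and the per-thread metadata word rides in %8.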