From 4067af44c1bdbc7a14fdbdfd9c019f4f7ef70ea1 Mon Sep 17 00:00:00 2001 From: Zihao Date: Wed, 19 Jan 2022 00:50:54 -0800 Subject: [PATCH 01/15] init --- include/tvm/tir/builtin.h | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index f7e1cfbc3e6d..d85de9d10d3e 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -596,6 +596,14 @@ TVM_DLL const Op& tvm_store_matrix_sync(); */ TVM_DLL const Op& ptx_mma(); +/*! + * \brief tvm intrinsic for sparse tensor core ptx instructions. + * + * void tvm_ptx_mma_sp(TODO) + */ +TVM_DLL const Op& tvm_ptx_mma_sp(); + + // TODO(tvm-team) replace the usage of the vector operations by Shuffle. /*! * \brief Get the high level half of the vector From 7471f299fad2a94e928a66191c398af9639f815a Mon Sep 17 00:00:00 2001 From: Zihao Date: Wed, 9 Feb 2022 21:31:49 -0800 Subject: [PATCH 02/15] upd --- include/tvm/tir/builtin.h | 4 ++-- src/target/source/codegen_cuda.cc | 5 +++++ src/target/source/ptx_mma.h | 2 ++ src/tir/op/builtin.cc | 3 +++ 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index d85de9d10d3e..fe588d69ae36 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -599,9 +599,9 @@ TVM_DLL const Op& ptx_mma(); /*! * \brief tvm intrinsic for sparse tensor core ptx instructions. * - * void tvm_ptx_mma_sp(TODO) + * void ptx_mma_sp(TODO) */ -TVM_DLL const Op& tvm_ptx_mma_sp(); +TVM_DLL const Op& ptx_mma_sp(); // TODO(tvm-team) replace the usage of the vector operations by Shuffle. diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 0dda079066d9..ba93cfde948c 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -761,6 +761,11 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string asm_code = PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, b_ref, b_bias, c_ref, c_bias, saturate); + this->stream << asm_code; + } else if (op->op.same_as(builtin::ptx_mma_sp())) { + // arg 0: shape: mXnXkX + // arg 1: + std::string asm_code = PrintMMASparseAssembly(); this->stream << asm_code; } else { CodeGenC::VisitExpr_(op, os); diff --git a/src/target/source/ptx_mma.h b/src/target/source/ptx_mma.h index d2a7a6705d6d..43a37d51a08b 100644 --- a/src/target/source/ptx_mma.h +++ b/src/target/source/ptx_mma.h @@ -39,6 +39,8 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo const std::string& b_ref, const std::string& b_bias, const std::string& c_ref, const std::string& c_bias, bool saturate); +std::string PrintMMASparseAssembly(); + } // namespace codegen } // namespace tvm diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 0e767ead4e6b..833139c5f201 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -236,6 +236,9 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync) TIR_DEFINE_BUILTIN_FUNC(ptx_mma).set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp).set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); TIR_DEFINE_BUILTIN_FUNC(vectorhigh) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); From 21d293d362e906a0dcda92116317966eb32aa815 Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 21 Feb 2022 15:50:56 -0800 Subject: [PATCH 03/15] upd --- include/tvm/tir/builtin.h | 7 +- src/target/source/codegen_cuda.cc | 36 ++- src/target/source/ptx_mma.h | 2 - 
src/target/source/ptx_mma_sp.cc | 282 +++++++++++++++++++ src/target/source/ptx_mma_sp.h | 47 ++++ tests/python/unittest/test_tir_ptx_mma_sp.py | 145 ++++++++++ 6 files changed, 514 insertions(+), 5 deletions(-) create mode 100644 src/target/source/ptx_mma_sp.cc create mode 100644 src/target/source/ptx_mma_sp.h create mode 100644 tests/python/unittest/test_tir_ptx_mma_sp.py diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index fe588d69ae36..5b62ce1f8d44 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -599,7 +599,12 @@ TVM_DLL const Op& ptx_mma(); /*! * \brief tvm intrinsic for sparse tensor core ptx instructions. * - * void ptx_mma_sp(TODO) + * void ptx_mma_sp(StringImm shape, StringImm A_layout, StringImm B_layout, + * StringImm A_dtype, StringImm B_dtype, StringImm C_dtype, + * Var multiplicand_a, Expr a_index, + * Var multiplicand_b, Expr b_index, + * Var accumulator, Expr c_index, + * Var metadata, Var sparse_selector, bool saturate); */ TVM_DLL const Op& ptx_mma_sp(); diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index ba93cfde948c..9fda1eb82695 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -34,6 +34,7 @@ #include "literal/cuda_half_t.h" #include "ptx_mma.h" +#include "ptx_mma_sp.h" namespace tvm { namespace codegen { @@ -764,8 +765,39 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->stream << asm_code; } else if (op->op.same_as(builtin::ptx_mma_sp())) { // arg 0: shape: mXnXkX - // arg 1: - std::string asm_code = PrintMMASparseAssembly(); + // arg 1: A layout: row/col + // arg 2: B layout: row/col + // arg 3: A precision: fp16, fp32, ... + // arg 4: B precision: fp16, fp32, ... + // arg 5: C precision: fp16, fp32, ... 
+ // arg 6: A multiplicand + // arg 7: A multiplicand index + // arg 8: B multiplicand + // arg 9: B multiplicand index + // arg 10: C accumulator + // arg 11: C accumulator index + // arg 12: metadata + // arg 13: sparse_selector + // arg 14: saturate + ICHECK_EQ(op->args.size(), 15U); + std::string shape = Downcast(op->args[0])->value; + std::string A_layout = Downcast(op->args[1])->value; + std::string B_layout = Downcast(op->args[2])->value; + std::string A_dtype = Downcast(op->args[3])->value; + std::string B_dtype = Downcast(op->args[4])->value; + std::string C_dtype = Downcast(op->args[5])->value; + std::string a_ref = this->PrintExpr(op->args[6]); + std::string a_offset = this->PrintExpr(op->args[7]); + std::string b_ref = this->PrintExpr(op->args[8]); + std::string b_offset = this->PrintExpr(op->args[9]); + std::string c_ref = this->PrintExpr(op->args[10]); + std::string c_offset = this->PrintExpr(op->args[11]); + std::string metadata = this->PrintExpr(op->args[12]); + std::string sparse_selector = this->PrintExpr(op->args[13]); + bool saturate = (Downcast(op->args[14])->value != 0); + std::string asm_code = PrintMMASparseAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, + C_dtype, a_ref, a_offset, b_ref, b_offset, c_ref, + c_offset, metadata, sparse_selector, saturate); this->stream << asm_code; } else { CodeGenC::VisitExpr_(op, os); diff --git a/src/target/source/ptx_mma.h b/src/target/source/ptx_mma.h index 43a37d51a08b..d2a7a6705d6d 100644 --- a/src/target/source/ptx_mma.h +++ b/src/target/source/ptx_mma.h @@ -39,8 +39,6 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo const std::string& b_ref, const std::string& b_bias, const std::string& c_ref, const std::string& c_bias, bool saturate); -std::string PrintMMASparseAssembly(); - } // namespace codegen } // namespace tvm diff --git a/src/target/source/ptx_mma_sp.cc b/src/target/source/ptx_mma_sp.cc new file mode 100644 index 000000000000..b59cf1371b68 --- /dev/null +++ b/src/target/source/ptx_mma_sp.cc @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file ptx_mma_sp.cc + */ + +#include "ptx_mma_sp.h" + +#include +#include +#include +#include +#include + +namespace tvm { +namespace codegen { + +namespace ptx { + +enum class DataType : int { + kInt4 = 0, + kUInt4 = 1, + kInt8 = 2, + kUInt8 = 3, + kInt16 = 4, + kUInt16 = 5, + kInt32 = 6, + kUInt32 = 7, + kInt64 = 8, + kUInt64 = 9, + kFloat16 = 10, + kBFloat16 = 11, + kFloat16x2 = 12, + kFloat32 = 13, + kTensorFloat32 = 14, + kFloat64 = 15, + kBit1 = 16 +}; + +static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16", + ".s32", ".u32", ".s64", ".u64", ".f16", ".bf16", + ".f16x2", ".f32", ".tf32", ".f64", ".b1"}; +static uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16, 16, 32, 32, 32, 64, 1}; + +inline DataType DTypeFromString(const std::string str) { + if (str == "int4" || str == ".s4") { + return DataType::kInt4; + } else if (str == "uint4" || str == ".u4") { + return DataType::kUInt4; + } else if (str == "int8" || str == ".s8") { + return DataType::kUInt8; + } else if (str == "uint8" || str == ".u8") { + return DataType::kUInt8; + } else if (str == "int16" || str == ".s16") { + return DataType::kInt16; + } else if (str == "uint16" || str == ".u16") { + return DataType::kUInt16; + } else if (str == "int32" || str == ".s32") { + return DataType::kInt32; + } else if (str == "uint32" || str == ".u32") { + return DataType::kUInt32; + } else if (str == "int64" || str == ".s64") { + return DataType::kInt64; + } else if (str == "uint64" || str == ".u64") { + return DataType::kUInt64; + } else if (str == "float16" || str == "fp16" || str == ".f16") { + return DataType::kFloat16; + } else if (str == "bfloat16" || str == "bf16") { + return DataType::kBFloat16; + } else if (str == ".f16x2") { + return DataType::kFloat16x2; + } else if (str == "float32" || str == "fp32" || str == ".f32") { + return DataType::kFloat32; + } else if (str == "tf32") { + return DataType::kTensorFloat32; + } else if (str == "float64" || str == "fp64" || str == ".f64") { + return DataType::kFloat64; + } else if (str == ".b1") { + return DataType::kBit1; + } else { + LOG(FATAL) << "Unrecognized data type " << str << " for PTX."; + return DataType(0); + } +} + +inline std::string DTypeToString(DataType dtype) { return dtype_str[int(dtype)]; } + +inline uint32_t DTypeToBits(DataType dtype) { return num_bits[int(dtype)]; } + +std::tuple ParseMMAShape(const std::string& str) { + size_t pos_m = str.find("m"), pos_n = str.find("n"), pos_k = str.find("k"); + CHECK(pos_m != str.npos && pos_n != str.npos && pos_k != str.npos) + << "Cannot parse MMA shape " << str; + int m = std::stoi(str.substr(pos_m + 1, pos_n - pos_m - 1)), + n = std::stoi(str.substr(pos_n + 1, pos_k - pos_n - 1)), k = std::stoi(str.substr(pos_k + 1)); + return {m, n, k}; +} + +inline std::tuple FragmentAttrs(DataType dtype) { + switch (dtype) { + case DataType::kBit1: + case DataType::kInt4: + case DataType::kUInt4: + case DataType::kInt8: + case DataType::kUInt8: + case DataType::kFloat16: // .f16x2 register + case DataType::kBFloat16: + case DataType::kTensorFloat32: + return {'r', 32, "(unsigned *)"}; + case DataType::kInt32: + return {'r', 32, "(int *)"}; + case DataType::kFloat32: + return {'f', 32, "(float *)"}; + case DataType::kFloat64: + return {'d', 64, "(double *)"}; + default: + LOG(FATAL) << DTypeToString(dtype) << " is not matrix data type in MMA."; + return {'\0', 0, ""}; + } +} + +}; // namespace ptx + +class Replacer { + public: + explicit Replacer() {} + void register_rule(const std::string& pattern, const 
std::string& replacement) { + _rules.emplace_back(pattern, replacement); + } + std::string rewrite(std::string str) { + for (auto&& rule : _rules) { + std::string pattern, replacement; + std::tie(pattern, replacement) = rule; + size_t len = pattern.size(); + size_t new_len = replacement.size(); + size_t pos = str.find(pattern); + while (pos != std::string::npos) { + str = str.replace(pos, len, replacement); + pos = str.find(pattern, pos + new_len); + } + } + return str; + } + void empty_rules() { _rules.clear(); } + + private: + std::vector> _rules; +}; + +inline std::tuple get_mma_sp_operands( + int m, int n, int k, ptx::DataType dtype_a, ptx::DataType dtype_b, ptx::DataType dtype_c) { + std::stringstream templates, inputs, outputs; + auto frag_attr_a = ptx::FragmentAttrs(dtype_a), frag_attr_b = ptx::FragmentAttrs(dtype_b), + frag_attr_c = ptx::FragmentAttrs(dtype_c); + constexpr int warp_size = 32; + int num_operands_a, num_operands_b, num_operands_c; + num_operands_a = (m * k / 2) * ptx::DTypeToBits(dtype_a) / std::get<1>(frag_attr_a) / warp_size; + num_operands_b = (k * n) * ptx::DTypeToBits(dtype_b) / std::get<1>(frag_attr_b) / warp_size; + num_operands_c = (m * n) * ptx::DTypeToBits(dtype_c) / std::get<1>(frag_attr_c) / warp_size; + + // generate templates; + int arg_counter = 0; + templates << "{" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_c; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}, {" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_a; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}, {" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_b; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}, {" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_c; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}, %" << (arg_counter++) << ", F"; + + // generate inputs + for (int i = 0; i < num_operands_a; ++i) { + if (i != 0) { + inputs << ", "; + } + inputs << "\"" << std::get<0>(frag_attr_a) << "\"((" << std::get<2>(frag_attr_a) << "(A))[" << i + << "])"; + } + for (int i = 0; i < num_operands_b; ++i) { + inputs << ", \"" << std::get<0>(frag_attr_b) << "\"((" << std::get<2>(frag_attr_b) << "(B))[" + << i << "])"; + } + for (int i = 0; i < num_operands_c; ++i) { + inputs << ", \"" << std::get<0>(frag_attr_c) << "\"((" << std::get<2>(frag_attr_c) << "(C))[" + << i << "])"; + } + inputs << ", \"r\"(E[0])"; + + // generate outputs + for (int i = 0; i < num_operands_c; ++i) { + if (i != 0) { + outputs << ","; + } + outputs << " \"=" << std::get<0>(frag_attr_c) << "\"((" << std::get<2>(frag_attr_c) << "(D))[" + << i << "])"; + } + return {templates.str(), inputs.str(), outputs.str()}; +} + +std::string PrintMMASparseAssembly(const std::string& shape, const std::string& A_layout, + const std::string& B_layout, const std::string& A_dtype, + const std::string& B_dtype, const std::string& C_dtype, + const std::string& a_ref, const std::string& a_offset, + const std::string& b_ref, const std::string& b_offset, + const std::string& c_ref, const std::string& c_offset, + const std::string& metadata, + const std::string& sparsity_selector, bool saturate) { + ptx::DataType dtype_a = ptx::DTypeFromString(A_dtype), dtype_b = ptx::DTypeFromString(B_dtype), + dtype_c = ptx::DTypeFromString(C_dtype); + int m, n, k; + std::tie(m, n, k) = ptx::ParseMMAShape(shape); + std::string asm_code = R"( + { + __asm__ __volatile__( + 
"mma.sp.sync.aligned.{shape}.{alayout}.{blayout}{satinite}{dtype}{atype}{btype}{ctype}" + "{templates};\n" + : {outputs} + : {inputs}); + } +)"; + std::string templates_str, inputs_str, outputs_str; + std::tie(templates_str, inputs_str, outputs_str) = + get_mma_sp_operands(m, n, k, dtype_a, dtype_b, dtype_c); + + // replace patterns + Replacer replacer; + replacer.register_rule("{shape}", shape); + replacer.register_rule("{satinite}", saturate ? ".satinite" : ""); + replacer.register_rule("{alayout}", A_layout); + replacer.register_rule("{blayout}", B_layout); + replacer.register_rule("{atype}", ptx::DTypeToString(dtype_a)); + replacer.register_rule("{btype}", ptx::DTypeToString(dtype_b)); + replacer.register_rule("{ctype}", ptx::DTypeToString(dtype_c)); + replacer.register_rule("{dtype}", ptx::DTypeToString(dtype_c)); + replacer.register_rule("{templates}", templates_str); + replacer.register_rule("{outputs}", outputs_str); + replacer.register_rule("{inputs}", inputs_str); + asm_code = replacer.rewrite(asm_code); + replacer.empty_rules(); + replacer.register_rule("A", a_ref + " + " + a_offset); + replacer.register_rule("B", b_ref + " + " + b_offset); + replacer.register_rule("C", c_ref + " + " + c_offset); + replacer.register_rule("D", c_ref + " + " + c_offset); + replacer.register_rule("E", metadata); + replacer.register_rule("F", sparsity_selector); + asm_code = replacer.rewrite(asm_code); + return asm_code; +} + +} // namespace codegen +} // namespace tvm \ No newline at end of file diff --git a/src/target/source/ptx_mma_sp.h b/src/target/source/ptx_mma_sp.h new file mode 100644 index 000000000000..54bbceb51b31 --- /dev/null +++ b/src/target/source/ptx_mma_sp.h @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ptx_mma_sp.h + * \brief Sparse MMA code generation with inlined PTX code. 
+ */ +#ifndef TVM_TARGET_SOURCE_PTX_MMA_SP_H_ +#define TVM_TARGET_SOURCE_PTX_MMA_SP_H_ + +#include + +#include +#include + +namespace tvm { +namespace codegen { + +std::string PrintMMASparseAssembly(const std::string& shape, const std::string& A_layout, + const std::string& B_layout, const std::string& A_dtype, + const std::string& B_dtype, const std::string& C_dtype, + const std::string& a_ref, const std::string& a_offset, + const std::string& b_ref, const std::string& b_offset, + const std::string& c_ref, const std::string& c_offset, + const std::string& metadata, + const std::string& sparsity_selector, bool saturate); + +} // namespace codegen +} // namespace tvm + +#endif // TVM_TARGET_SOURCE_PTX_MMA_H_ diff --git a/tests/python/unittest/test_tir_ptx_mma_sp.py b/tests/python/unittest/test_tir_ptx_mma_sp.py new file mode 100644 index 000000000000..98a844cfdef4 --- /dev/null +++ b/tests/python/unittest/test_tir_ptx_mma_sp.py @@ -0,0 +1,145 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys +import pytest + +import tvm +from tvm.script import tir as T +import numpy as np +import tvm.testing + + +def gen_2in4_mask(m: int, n: int): + assert n % 4 == 0 + return np.array([[np.sort(np.random.choice(4, 2, replace=False)) for _ in range(n // 4)] for _ in range(m)]).astype("uint8") + + +def get_dense_mat_by_mask(val, mask): + m, n_chunks, _ = mask.shape + val = val.reshape(m, n_chunks, 2) + ret = np.zeros((m, n_chunks, 4)).astype(val.dtype) + for i in range(m): + for j in range(n_chunks): + for k in range(2): + ret[i, j, mask[i, j, k]] = val[i, j, k] + return ret.reshape(m, n_chunks * 4) + + +def get_meta_m16n8k16_half(mask): + assert mask.shape == (16, 4, 2) + mask = mask.reshape(16, 8) + ret = np.zeros((8,)).astype("uint32") + + for i in range(8): + base = 1 + for k in range(2): + for j in range(8): + ret[i] |= int(mask[k * 8 + i, j]) * base + base = base << 2 + return ret + + +@T.prim_func +def mma_sp_m16n8k16_fp16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 8], dtype="float16") + B = T.match_buffer(b, [16, 8], dtype="float16") + C = T.match_buffer(c, [16, 8], dtype="float16") + metadata = T.match_buffer(_metadata, [8], dtype="uint32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + multi_a = T.allocate([4], "float16", scope="local") + multi_b = T.allocate([4], "float16", scope="local") + accum = T.allocate([4], "float16", scope="local") + meta_local = T.allocate([1], "uint32", scope="local") + for i in range(4): + accum[i] = T.float16(0) + + for i in range(4): + multi_a[i] = A[ + tx // 4 + 
i // 2 * 8, + tx % 4 * 2 + i % 2 + ] + + for i in range(4): + multi_b[i] = B[ + tx % 4 * 2 + i % 2 + i // 2 * 8, + tx // 4 + ] + + meta_local[0] = metadata[tx // 4] + + T.evaluate( + T.ptx_mma_sp( + "m16n8k16", + "row", + "col", + "fp16", + "fp16", + "fp16", + multi_a, + 0, + multi_b, + 0, + accum, + 0, + meta_local, + 0, + False, + dtype="float16" + ) + ) + + for i in range(4): + C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = T.load("float16", accum, i) + + +@tvm.testing.requires_cuda +def test_mma_sp_m16n8k16_fp16(): + sch = tvm.tir.Schedule(mma_sp_m16n8k16_fp16) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # Requires SM80+ + return + cuda_mod = tvm.build(sch.mod, target="cuda") + print(cuda_mod.imported_modules[0].get_source()) + + A_np = np.random.uniform(-1, 1, [16, 8]).astype("float16") + B_np = np.random.uniform(-1, 1, [16, 8]).astype("float16") + mask = gen_2in4_mask(16, 16) + A_dense_np = get_dense_mat_by_mask(A_np, mask) + C_np = np.matmul(A_dense_np, B_np) + meta = get_meta_m16n8k16_half(mask) + + ctx = tvm.cuda() + A_tvm = tvm.nd.array(A_np, ctx) + B_tvm = tvm.nd.array(B_np, ctx) + C_tvm = tvm.nd.array(np.zeros_like(C_np), ctx) + meta_tvm = tvm.nd.array(meta, ctx) + cuda_mod(A_tvm, B_tvm, C_tvm, meta_tvm) + + tvm.testing.assert_allclose(C_tvm.numpy(), C_np, atol=1e-3, rtol=1e-3) + + +if __name__ == "__main__": + test_mma_sp_m16n8k16_fp16() From 708480e082177550ed35a1950841a2be484ea379 Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 21 Feb 2022 16:14:22 -0800 Subject: [PATCH 04/15] lint --- src/target/source/ptx_mma_sp.cc | 7 +++---- src/target/source/ptx_mma_sp.h | 2 +- src/tir/op/builtin.cc | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/target/source/ptx_mma_sp.cc b/src/target/source/ptx_mma_sp.cc index b59cf1371b68..01cc71c8d571 100644 --- a/src/target/source/ptx_mma_sp.cc +++ b/src/target/source/ptx_mma_sp.cc @@ -100,9 +100,9 @@ inline DataType DTypeFromString(const std::string str) { } } -inline std::string DTypeToString(DataType dtype) { return dtype_str[int(dtype)]; } +inline std::string DTypeToString(DataType dtype) { return dtype_str[static_cast(dtype)]; } -inline uint32_t DTypeToBits(DataType dtype) { return num_bits[int(dtype)]; } +inline uint32_t DTypeToBits(DataType dtype) { return num_bits[static_cast(dtype)]; } std::tuple ParseMMAShape(const std::string& str) { size_t pos_m = str.find("m"), pos_n = str.find("n"), pos_k = str.find("k"); @@ -140,7 +140,6 @@ inline std::tuple FragmentAttrs(DataType dtype) { class Replacer { public: - explicit Replacer() {} void register_rule(const std::string& pattern, const std::string& replacement) { _rules.emplace_back(pattern, replacement); } @@ -279,4 +278,4 @@ std::string PrintMMASparseAssembly(const std::string& shape, const std::string& } } // namespace codegen -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/target/source/ptx_mma_sp.h b/src/target/source/ptx_mma_sp.h index 54bbceb51b31..8200e60a943f 100644 --- a/src/target/source/ptx_mma_sp.h +++ b/src/target/source/ptx_mma_sp.h @@ -44,4 +44,4 @@ std::string PrintMMASparseAssembly(const std::string& shape, const std::string& } // namespace codegen } // namespace tvm -#endif // TVM_TARGET_SOURCE_PTX_MMA_H_ +#endif // TVM_TARGET_SOURCE_PTX_MMA_SP_H_ diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 833139c5f201..e47cd139d07f 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -236,7 +236,7 @@ 
TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync) TIR_DEFINE_BUILTIN_FUNC(ptx_mma).set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); - + TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp).set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); From f15632d796f41704329ff97a6b5f35c0abf34a7a Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 21 Feb 2022 16:25:16 -0800 Subject: [PATCH 05/15] lint again --- include/tvm/tir/builtin.h | 1 - src/tir/op/builtin.cc | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 5b62ce1f8d44..8e5f5838d1d6 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -608,7 +608,6 @@ TVM_DLL const Op& ptx_mma(); */ TVM_DLL const Op& ptx_mma_sp(); - // TODO(tvm-team) replace the usage of the vector operations by Shuffle. /*! * \brief Get the high level half of the vector diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index e47cd139d07f..977050a2d2ce 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -237,8 +237,8 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync) TIR_DEFINE_BUILTIN_FUNC(ptx_mma).set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp).set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TIR_DEFINE_BUILTIN_FUNC(vectorhigh) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); From 295de969dcd7194130a06aa23f219fc5e8d06832 Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 21 Feb 2022 16:50:00 -0800 Subject: [PATCH 06/15] upd --- src/target/source/ptx_mma_sp.cc | 34 +++++++++++++++++--- tests/python/unittest/test_tir_ptx_mma_sp.py | 18 ++++------- 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/src/target/source/ptx_mma_sp.cc b/src/target/source/ptx_mma_sp.cc index 01cc71c8d571..9b6a53768687 100644 --- a/src/target/source/ptx_mma_sp.cc +++ b/src/target/source/ptx_mma_sp.cc @@ -34,6 +34,9 @@ namespace codegen { namespace ptx { +/*! + * \brief PTX data type. + */ enum class DataType : int { kInt4 = 0, kUInt4 = 1, @@ -59,6 +62,9 @@ static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u ".f16x2", ".f32", ".tf32", ".f64", ".b1"}; static uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16, 16, 32, 32, 32, 64, 1}; +/*! + * \brief Create PTX data type from string. + */ inline DataType DTypeFromString(const std::string str) { if (str == "int4" || str == ".s4") { return DataType::kInt4; @@ -100,10 +106,19 @@ inline DataType DTypeFromString(const std::string str) { } } +/*! + * \brief Get the string representation of given PTX data type. + */ inline std::string DTypeToString(DataType dtype) { return dtype_str[static_cast(dtype)]; } -inline uint32_t DTypeToBits(DataType dtype) { return num_bits[static_cast(dtype)]; } +/*! + * \brief Get the number of bits of given PTX data type. + */ +inline uint32_t DTypeBits(DataType dtype) { return num_bits[static_cast(dtype)]; } +/*! + * \brief Extract the value m, n, k from string m*n*k* + */ std::tuple ParseMMAShape(const std::string& str) { size_t pos_m = str.find("m"), pos_n = str.find("n"), pos_k = str.find("k"); CHECK(pos_m != str.npos && pos_n != str.npos && pos_k != str.npos) @@ -113,6 +128,10 @@ std::tuple ParseMMAShape(const std::string& str) { return {m, n, k}; } +/*! + * \brief Fragment attributes of given data type. + * \return the register type in ptx, fragment size, fragment pointer string. 
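+ *
+ * For example, f16 fragments are packed two halves per 32-bit register (.f16x2), so they
+ * map to ('r', 32, "(unsigned *)"), while f32 accumulators map to ('f', 32, "(float *)").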
+ */ inline std::tuple FragmentAttrs(DataType dtype) { switch (dtype) { case DataType::kBit1: @@ -138,6 +157,10 @@ inline std::tuple FragmentAttrs(DataType dtype) { }; // namespace ptx +/*! + * \brief Replace patterns with replacement strings. + * \note should use std::format instead when codebase is ported to C++20. + */ class Replacer { public: void register_rule(const std::string& pattern, const std::string& replacement) { @@ -163,6 +186,9 @@ class Replacer { std::vector> _rules; }; +/*! + * \brief Return template string, input operands string and output operands string. + */ inline std::tuple get_mma_sp_operands( int m, int n, int k, ptx::DataType dtype_a, ptx::DataType dtype_b, ptx::DataType dtype_c) { std::stringstream templates, inputs, outputs; @@ -170,9 +196,9 @@ inline std::tuple get_mma_sp_operands( frag_attr_c = ptx::FragmentAttrs(dtype_c); constexpr int warp_size = 32; int num_operands_a, num_operands_b, num_operands_c; - num_operands_a = (m * k / 2) * ptx::DTypeToBits(dtype_a) / std::get<1>(frag_attr_a) / warp_size; - num_operands_b = (k * n) * ptx::DTypeToBits(dtype_b) / std::get<1>(frag_attr_b) / warp_size; - num_operands_c = (m * n) * ptx::DTypeToBits(dtype_c) / std::get<1>(frag_attr_c) / warp_size; + num_operands_a = (m * k / 2) * ptx::DTypeBits(dtype_a) / std::get<1>(frag_attr_a) / warp_size; + num_operands_b = (k * n) * ptx::DTypeBits(dtype_b) / std::get<1>(frag_attr_b) / warp_size; + num_operands_c = (m * n) * ptx::DTypeBits(dtype_c) / std::get<1>(frag_attr_c) / warp_size; // generate templates; int arg_counter = 0; diff --git a/tests/python/unittest/test_tir_ptx_mma_sp.py b/tests/python/unittest/test_tir_ptx_mma_sp.py index 98a844cfdef4..3d8a861bc792 100644 --- a/tests/python/unittest/test_tir_ptx_mma_sp.py +++ b/tests/python/unittest/test_tir_ptx_mma_sp.py @@ -26,7 +26,9 @@ def gen_2in4_mask(m: int, n: int): assert n % 4 == 0 - return np.array([[np.sort(np.random.choice(4, 2, replace=False)) for _ in range(n // 4)] for _ in range(m)]).astype("uint8") + return np.array( + [[np.sort(np.random.choice(4, 2, replace=False)) for _ in range(n // 4)] for _ in range(m)] + ).astype("uint8") def get_dense_mat_by_mask(val, mask): @@ -44,7 +46,7 @@ def get_meta_m16n8k16_half(mask): assert mask.shape == (16, 4, 2) mask = mask.reshape(16, 8) ret = np.zeros((8,)).astype("uint32") - + for i in range(8): base = 1 for k in range(2): @@ -75,16 +77,10 @@ def mma_sp_m16n8k16_fp16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.han accum[i] = T.float16(0) for i in range(4): - multi_a[i] = A[ - tx // 4 + i // 2 * 8, - tx % 4 * 2 + i % 2 - ] + multi_a[i] = A[tx // 4 + i // 2 * 8, tx % 4 * 2 + i % 2] for i in range(4): - multi_b[i] = B[ - tx % 4 * 2 + i % 2 + i // 2 * 8, - tx // 4 - ] + multi_b[i] = B[tx % 4 * 2 + i % 2 + i // 2 * 8, tx // 4] meta_local[0] = metadata[tx // 4] @@ -105,7 +101,7 @@ def mma_sp_m16n8k16_fp16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.han meta_local, 0, False, - dtype="float16" + dtype="float16", ) ) From 4872e327631a1bb60343de8fd5acbeee80e30c6f Mon Sep 17 00:00:00 2001 From: Zihao Date: Tue, 22 Feb 2022 10:44:29 -0800 Subject: [PATCH 07/15] add m16n8k32 testcase --- tests/python/unittest/test_tir_ptx_mma_sp.py | 102 ++++++++++++++++++- 1 file changed, 99 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_tir_ptx_mma_sp.py b/tests/python/unittest/test_tir_ptx_mma_sp.py index 3d8a861bc792..e722680bece8 100644 --- a/tests/python/unittest/test_tir_ptx_mma_sp.py +++ b/tests/python/unittest/test_tir_ptx_mma_sp.py @@ -49,13 +49,29 @@ def 
get_meta_m16n8k16_half(mask): for i in range(8): base = 1 - for k in range(2): + for blk in range(2): for j in range(8): - ret[i] |= int(mask[k * 8 + i, j]) * base + ret[i] |= int(mask[blk * 8 + i, j]) * base base = base << 2 return ret +def get_meta_m16n8k32_half(mask): + assert mask.shape == (16, 8, 2) + mask = mask.reshape(16, 2, 8) + ret = np.zeros((8, 2)).astype("uint32") + + for i in range(8): + for k in range(2): + base = 1 + for blk in range(2): + for j in range(8): + ret[i, k] |= int(mask[blk * 8 + i, k, j]) * base + base = base << 2 + + return ret.reshape(16) + + @T.prim_func def mma_sp_m16n8k16_fp16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle): T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) @@ -109,6 +125,59 @@ def mma_sp_m16n8k16_fp16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.han C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = T.load("float16", accum, i) +@T.prim_func +def mma_sp_m16n8k32_fp16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 16], dtype="float16") + B = T.match_buffer(b, [32, 8], dtype="float16") + C = T.match_buffer(c, [16, 8], dtype="float16") + metadata = T.match_buffer(_metadata, [16], dtype="uint32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + multi_a = T.allocate([8], "float16", scope="local") + multi_b = T.allocate([8], "float16", scope="local") + accum = T.allocate([4], "float16", scope="local") + meta_local = T.allocate([1], "uint32", scope="local") + for i in range(4): + accum[i] = T.float16(0) + + for i in range(8): + multi_a[i] = A[(i % 4) // 2 * 8 + tx // 4, i // 4 * 8 + tx % 4 * 2 + i % 2] + + for i in range(8): + multi_b[i] = B[i // 2 * 8 + tx % 4 * 2 + i % 2, tx // 4] + + meta_local[0] = metadata[tx // 4 * 2 + tx % 2] + + T.evaluate( + T.ptx_mma_sp( + "m16n8k32", + "row", + "col", + "fp16", + "fp16", + "fp16", + multi_a, + 0, + multi_b, + 0, + accum, + 0, + meta_local, + 0, + False, + dtype="float16", + ) + ) + + for i in range(4): + C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = T.load("float16", accum, i) + + @tvm.testing.requires_cuda def test_mma_sp_m16n8k16_fp16(): sch = tvm.tir.Schedule(mma_sp_m16n8k16_fp16) @@ -118,7 +187,6 @@ def test_mma_sp_m16n8k16_fp16(): # Requires SM80+ return cuda_mod = tvm.build(sch.mod, target="cuda") - print(cuda_mod.imported_modules[0].get_source()) A_np = np.random.uniform(-1, 1, [16, 8]).astype("float16") B_np = np.random.uniform(-1, 1, [16, 8]).astype("float16") @@ -137,5 +205,33 @@ def test_mma_sp_m16n8k16_fp16(): tvm.testing.assert_allclose(C_tvm.numpy(), C_np, atol=1e-3, rtol=1e-3) +@tvm.testing.requires_cuda +def test_mma_sp_m16n8k32_fp16(): + sch = tvm.tir.Schedule(mma_sp_m16n8k32_fp16) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # Requires SM80+ + return + cuda_mod = tvm.build(sch.mod, target="cuda") + + A_np = np.random.uniform(-1, 1, [16, 16]).astype("float16") + B_np = np.random.uniform(-1, 1, [32, 8]).astype("float16") + mask = gen_2in4_mask(16, 32) + A_dense_np = get_dense_mat_by_mask(A_np, mask) + C_np = np.matmul(A_dense_np, B_np) + meta = get_meta_m16n8k32_half(mask) + + ctx = tvm.cuda() + A_tvm = tvm.nd.array(A_np, ctx) + B_tvm = tvm.nd.array(B_np, ctx) + C_tvm = tvm.nd.array(np.zeros_like(C_np), ctx) + 
meta_tvm = tvm.nd.array(meta, ctx) + cuda_mod(A_tvm, B_tvm, C_tvm, meta_tvm) + + tvm.testing.assert_allclose(C_tvm.numpy(), C_np, atol=1e-3, rtol=1e-3) + + if __name__ == "__main__": test_mma_sp_m16n8k16_fp16() + test_mma_sp_m16n8k32_fp16() From a1afc4808406f6ab977bae20d29ed2595da183a8 Mon Sep 17 00:00:00 2001 From: Zihao Date: Tue, 22 Feb 2022 10:45:08 -0800 Subject: [PATCH 08/15] format --- tests/python/unittest/test_tir_ptx_mma_sp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_ptx_mma_sp.py b/tests/python/unittest/test_tir_ptx_mma_sp.py index e722680bece8..b94478f2da07 100644 --- a/tests/python/unittest/test_tir_ptx_mma_sp.py +++ b/tests/python/unittest/test_tir_ptx_mma_sp.py @@ -68,7 +68,7 @@ def get_meta_m16n8k32_half(mask): for j in range(8): ret[i, k] |= int(mask[blk * 8 + i, k, j]) * base base = base << 2 - + return ret.reshape(16) From bbf7cec3d14b7bf9d5ee3e3c091650a7b1537a87 Mon Sep 17 00:00:00 2001 From: Zihao Date: Tue, 22 Feb 2022 12:11:52 -0800 Subject: [PATCH 09/15] use make_tuple instead of initializer list --- src/target/source/ptx_mma_sp.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/target/source/ptx_mma_sp.cc b/src/target/source/ptx_mma_sp.cc index 9b6a53768687..a1dd5e50b4a1 100644 --- a/src/target/source/ptx_mma_sp.cc +++ b/src/target/source/ptx_mma_sp.cc @@ -125,7 +125,7 @@ std::tuple ParseMMAShape(const std::string& str) { << "Cannot parse MMA shape " << str; int m = std::stoi(str.substr(pos_m + 1, pos_n - pos_m - 1)), n = std::stoi(str.substr(pos_n + 1, pos_k - pos_n - 1)), k = std::stoi(str.substr(pos_k + 1)); - return {m, n, k}; + return std::make_tuple(m, n, k); } /*! @@ -142,16 +142,16 @@ inline std::tuple FragmentAttrs(DataType dtype) { case DataType::kFloat16: // .f16x2 register case DataType::kBFloat16: case DataType::kTensorFloat32: - return {'r', 32, "(unsigned *)"}; + return std::make_tuple('r', 32, "(unsigned *)"); case DataType::kInt32: - return {'r', 32, "(int *)"}; + return std::make_tuple('r', 32, "(int *)"); case DataType::kFloat32: - return {'f', 32, "(float *)"}; + return std::make_tuple('f', 32, "(float *)"); case DataType::kFloat64: - return {'d', 64, "(double *)"}; + return std::make_tuple('d', 64, "(double *)"); default: LOG(FATAL) << DTypeToString(dtype) << " is not matrix data type in MMA."; - return {'\0', 0, ""}; + return std::make_tuple('\0', 0, ""); } } @@ -250,7 +250,7 @@ inline std::tuple get_mma_sp_operands( outputs << " \"=" << std::get<0>(frag_attr_c) << "\"((" << std::get<2>(frag_attr_c) << "(D))[" << i << "])"; } - return {templates.str(), inputs.str(), outputs.str()}; + return std::make_tuple(templates.str(), inputs.str(), outputs.str()); } std::string PrintMMASparseAssembly(const std::string& shape, const std::string& A_layout, From c4e716addc61ed0d73a8db00ef925d2336c03670 Mon Sep 17 00:00:00 2001 From: Zihao Date: Wed, 23 Feb 2022 14:28:48 -0800 Subject: [PATCH 10/15] add metadata offset --- include/tvm/tir/builtin.h | 3 +- src/target/source/codegen_cuda.cc | 18 +- src/target/source/ptx_mma_sp.cc | 6 +- src/target/source/ptx_mma_sp.h | 2 +- tests/python/unittest/test_tir_ptx_mma_sp.py | 266 +++++++++++++------ 5 files changed, 205 insertions(+), 90 deletions(-) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 8e5f5838d1d6..0d9f82305352 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -604,7 +604,8 @@ TVM_DLL const Op& ptx_mma(); * Var multiplicand_a, Expr a_index, * Var multiplicand_b, 
Expr b_index, * Var accumulator, Expr c_index, - * Var metadata, Var sparse_selector, bool saturate); + * Var metadata, Expr meta_index, + * Var sparse_selector, bool saturate); */ TVM_DLL const Op& ptx_mma_sp(); diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 9fda1eb82695..4562a8702e9b 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -777,9 +777,10 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { // arg 10: C accumulator // arg 11: C accumulator index // arg 12: metadata - // arg 13: sparse_selector - // arg 14: saturate - ICHECK_EQ(op->args.size(), 15U); + // arg 13: metadata index + // arg 14: sparse_selector + // arg 15: saturate + ICHECK_EQ(op->args.size(), 16U); std::string shape = Downcast(op->args[0])->value; std::string A_layout = Downcast(op->args[1])->value; std::string B_layout = Downcast(op->args[2])->value; @@ -793,11 +794,12 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string c_ref = this->PrintExpr(op->args[10]); std::string c_offset = this->PrintExpr(op->args[11]); std::string metadata = this->PrintExpr(op->args[12]); - std::string sparse_selector = this->PrintExpr(op->args[13]); - bool saturate = (Downcast(op->args[14])->value != 0); - std::string asm_code = PrintMMASparseAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, - C_dtype, a_ref, a_offset, b_ref, b_offset, c_ref, - c_offset, metadata, sparse_selector, saturate); + std::string metadata_offset = this->PrintExpr(op->args[13]); + std::string sparse_selector = this->PrintExpr(op->args[14]); + bool saturate = Downcast(op->args[15])->value; + std::string asm_code = PrintMMASparseAssembly( + shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset, + c_ref, c_offset, metadata, metadata_offset, sparse_selector, saturate); this->stream << asm_code; } else { CodeGenC::VisitExpr_(op, os); diff --git a/src/target/source/ptx_mma_sp.cc b/src/target/source/ptx_mma_sp.cc index a1dd5e50b4a1..cc1c8b7bd012 100644 --- a/src/target/source/ptx_mma_sp.cc +++ b/src/target/source/ptx_mma_sp.cc @@ -240,7 +240,7 @@ inline std::tuple get_mma_sp_operands( inputs << ", \"" << std::get<0>(frag_attr_c) << "\"((" << std::get<2>(frag_attr_c) << "(C))[" << i << "])"; } - inputs << ", \"r\"(E[0])"; + inputs << ", \"r\"(((unsigned *)(E))[0])"; // generate outputs for (int i = 0; i < num_operands_c; ++i) { @@ -259,7 +259,7 @@ std::string PrintMMASparseAssembly(const std::string& shape, const std::string& const std::string& a_ref, const std::string& a_offset, const std::string& b_ref, const std::string& b_offset, const std::string& c_ref, const std::string& c_offset, - const std::string& metadata, + const std::string& metadata, const std::string& metadata_offset, const std::string& sparsity_selector, bool saturate) { ptx::DataType dtype_a = ptx::DTypeFromString(A_dtype), dtype_b = ptx::DTypeFromString(B_dtype), dtype_c = ptx::DTypeFromString(C_dtype); @@ -297,7 +297,7 @@ std::string PrintMMASparseAssembly(const std::string& shape, const std::string& replacer.register_rule("B", b_ref + " + " + b_offset); replacer.register_rule("C", c_ref + " + " + c_offset); replacer.register_rule("D", c_ref + " + " + c_offset); - replacer.register_rule("E", metadata); + replacer.register_rule("E", metadata + " + " + metadata_offset); replacer.register_rule("F", sparsity_selector); asm_code = replacer.rewrite(asm_code); return asm_code; diff --git a/src/target/source/ptx_mma_sp.h 
b/src/target/source/ptx_mma_sp.h index 8200e60a943f..8ee7099db2c4 100644 --- a/src/target/source/ptx_mma_sp.h +++ b/src/target/source/ptx_mma_sp.h @@ -38,7 +38,7 @@ std::string PrintMMASparseAssembly(const std::string& shape, const std::string& const std::string& a_ref, const std::string& a_offset, const std::string& b_ref, const std::string& b_offset, const std::string& c_ref, const std::string& c_offset, - const std::string& metadata, + const std::string& metadata, const std::string& metadata_offset, const std::string& sparsity_selector, bool saturate); } // namespace codegen diff --git a/tests/python/unittest/test_tir_ptx_mma_sp.py b/tests/python/unittest/test_tir_ptx_mma_sp.py index b94478f2da07..93c7915c2ab5 100644 --- a/tests/python/unittest/test_tir_ptx_mma_sp.py +++ b/tests/python/unittest/test_tir_ptx_mma_sp.py @@ -42,42 +42,66 @@ def get_dense_mat_by_mask(val, mask): return ret.reshape(m, n_chunks * 4) -def get_meta_m16n8k16_half(mask): - assert mask.shape == (16, 4, 2) - mask = mask.reshape(16, 8) - ret = np.zeros((8,)).astype("uint32") +@T.prim_func +def mma_sp_m16n8k16_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 8], dtype="float16") + B = T.match_buffer(b, [16, 8], dtype="float16") + C = T.match_buffer(c, [16, 8], dtype="float16") + metadata = T.match_buffer(_metadata, [8], dtype="uint32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + multi_a = T.allocate([4], "float16", scope="local") + multi_b = T.allocate([4], "float16", scope="local") + accum = T.allocate([4], "float16", scope="local") + meta_local = T.allocate([1], "uint32", scope="local") + for i in range(4): + accum[i] = T.float16(0) - for i in range(8): - base = 1 - for blk in range(2): - for j in range(8): - ret[i] |= int(mask[blk * 8 + i, j]) * base - base = base << 2 - return ret + for i in range(4): + multi_a[i] = A[tx // 4 + i // 2 * 8, tx % 4 * 2 + i % 2] + for i in range(4): + multi_b[i] = B[tx % 4 * 2 + i % 2 + i // 2 * 8, tx // 4] -def get_meta_m16n8k32_half(mask): - assert mask.shape == (16, 8, 2) - mask = mask.reshape(16, 2, 8) - ret = np.zeros((8, 2)).astype("uint32") + meta_local[0] = metadata[tx // 4] - for i in range(8): - for k in range(2): - base = 1 - for blk in range(2): - for j in range(8): - ret[i, k] |= int(mask[blk * 8 + i, k, j]) * base - base = base << 2 + T.evaluate( + T.ptx_mma_sp( + "m16n8k16", + "row", + "col", + "fp16", + "fp16", + "fp16", + multi_a, + 0, + multi_b, + 0, + accum, + 0, + meta_local, + 0, + 0, + False, + dtype="float16", + ) + ) - return ret.reshape(16) + for i in range(4): + C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = T.load("float16", accum, i) @T.prim_func -def mma_sp_m16n8k16_fp16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle): +def mma_sp_m16n8k16_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle): T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) A = T.match_buffer(a, [16, 8], dtype="float16") B = T.match_buffer(b, [16, 8], dtype="float16") - C = T.match_buffer(c, [16, 8], dtype="float16") + C = T.match_buffer(c, [16, 8], dtype="float32") metadata = T.match_buffer(_metadata, [8], dtype="uint32") brow = T.env_thread("blockIdx.y") bcol = T.env_thread("blockIdx.x") @@ -87,7 +111,7 @@ def mma_sp_m16n8k16_fp16(a: T.handle, b: T.handle, c: T.handle, 
_metadata: T.han T.launch_thread(tx, 32) multi_a = T.allocate([4], "float16", scope="local") multi_b = T.allocate([4], "float16", scope="local") - accum = T.allocate([4], "float16", scope="local") + accum = T.allocate([4], "float32", scope="local") meta_local = T.allocate([1], "uint32", scope="local") for i in range(4): accum[i] = T.float16(0) @@ -107,7 +131,7 @@ def mma_sp_m16n8k16_fp16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.han "col", "fp16", "fp16", - "fp16", + "fp32", multi_a, 0, multi_b, @@ -116,17 +140,18 @@ def mma_sp_m16n8k16_fp16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.han 0, meta_local, 0, + 0, False, - dtype="float16", + dtype="float32", ) ) for i in range(4): - C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = T.load("float16", accum, i) + C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = T.load("float32", accum, i) @T.prim_func -def mma_sp_m16n8k32_fp16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle): +def mma_sp_m16n8k32_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle): T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) A = T.match_buffer(a, [16, 16], dtype="float16") B = T.match_buffer(b, [32, 8], dtype="float16") @@ -169,6 +194,7 @@ def mma_sp_m16n8k32_fp16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.han 0, meta_local, 0, + 0, False, dtype="float16", ) @@ -178,60 +204,146 @@ def mma_sp_m16n8k32_fp16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.han C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = T.load("float16", accum, i) +@T.prim_func +def mma_sp_m16n8k32_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 16], dtype="float16") + B = T.match_buffer(b, [32, 8], dtype="float16") + C = T.match_buffer(c, [16, 8], dtype="float32") + metadata = T.match_buffer(_metadata, [16], dtype="uint32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + multi_a = T.allocate([8], "float16", scope="local") + multi_b = T.allocate([8], "float16", scope="local") + accum = T.allocate([4], "float32", scope="local") + meta_local = T.allocate([1], "uint32", scope="local") + for i in range(4): + accum[i] = T.float16(0) + + for i in range(8): + multi_a[i] = A[(i % 4) // 2 * 8 + tx // 4, i // 4 * 8 + tx % 4 * 2 + i % 2] + + for i in range(8): + multi_b[i] = B[i // 2 * 8 + tx % 4 * 2 + i % 2, tx // 4] + + meta_local[0] = metadata[tx // 4 * 2 + tx % 2] + + T.evaluate( + T.ptx_mma_sp( + "m16n8k32", + "row", + "col", + "fp16", + "fp16", + "fp32", + multi_a, + 0, + multi_b, + 0, + accum, + 0, + meta_local, + 0, + 0, + False, + dtype="float32", + ) + ) + + for i in range(4): + C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = T.load("float32", accum, i) + + @tvm.testing.requires_cuda -def test_mma_sp_m16n8k16_fp16(): - sch = tvm.tir.Schedule(mma_sp_m16n8k16_fp16) - arch = tvm.contrib.nvcc.get_target_compute_version() - major, _ = tvm.contrib.nvcc.parse_compute_version(arch) - if major < 8: - # Requires SM80+ - return - cuda_mod = tvm.build(sch.mod, target="cuda") - - A_np = np.random.uniform(-1, 1, [16, 8]).astype("float16") - B_np = np.random.uniform(-1, 1, [16, 8]).astype("float16") - mask = gen_2in4_mask(16, 16) - A_dense_np = get_dense_mat_by_mask(A_np, mask) - C_np = np.matmul(A_dense_np, B_np) - meta = get_meta_m16n8k16_half(mask) - - ctx = tvm.cuda() - A_tvm = 
tvm.nd.array(A_np, ctx) - B_tvm = tvm.nd.array(B_np, ctx) - C_tvm = tvm.nd.array(np.zeros_like(C_np), ctx) - meta_tvm = tvm.nd.array(meta, ctx) - cuda_mod(A_tvm, B_tvm, C_tvm, meta_tvm) +def test_mma_sp_m16n8k16_f16(): + def get_meta_m16n8k16_half(mask): + assert mask.shape == (16, 4, 2) + mask = mask.reshape(16, 8) + ret = np.zeros((8,)).astype("uint32") - tvm.testing.assert_allclose(C_tvm.numpy(), C_np, atol=1e-3, rtol=1e-3) + for i in range(8): + base = 1 + for blk in range(2): + for j in range(8): + ret[i] |= int(mask[blk * 8 + i, j]) * base + base = base << 2 + return ret + + for out_dtype in ["float16", "float32"]: + func = mma_sp_m16n8k16_f16f16f16 if out_dtype == "float16" else mma_sp_m16n8k16_f16f16f32 + sch = tvm.tir.Schedule(func) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # Requires SM80+ + return + cuda_mod = tvm.build(sch.mod, target="cuda") + + A_np = np.random.uniform(-1, 1, [16, 8]).astype("float16") + B_np = np.random.uniform(-1, 1, [16, 8]).astype("float16") + mask = gen_2in4_mask(16, 16) + A_dense_np = get_dense_mat_by_mask(A_np, mask) + C_np = np.matmul(A_dense_np, B_np).astype(out_dtype) + meta = get_meta_m16n8k16_half(mask) + + ctx = tvm.cuda() + A_tvm = tvm.nd.array(A_np, ctx) + B_tvm = tvm.nd.array(B_np, ctx) + C_tvm = tvm.nd.array(np.zeros_like(C_np), ctx) + meta_tvm = tvm.nd.array(meta, ctx) + cuda_mod(A_tvm, B_tvm, C_tvm, meta_tvm) + + tvm.testing.assert_allclose(C_tvm.numpy(), C_np, atol=1e-3, rtol=1e-3) @tvm.testing.requires_cuda -def test_mma_sp_m16n8k32_fp16(): - sch = tvm.tir.Schedule(mma_sp_m16n8k32_fp16) - arch = tvm.contrib.nvcc.get_target_compute_version() - major, _ = tvm.contrib.nvcc.parse_compute_version(arch) - if major < 8: - # Requires SM80+ - return - cuda_mod = tvm.build(sch.mod, target="cuda") - - A_np = np.random.uniform(-1, 1, [16, 16]).astype("float16") - B_np = np.random.uniform(-1, 1, [32, 8]).astype("float16") - mask = gen_2in4_mask(16, 32) - A_dense_np = get_dense_mat_by_mask(A_np, mask) - C_np = np.matmul(A_dense_np, B_np) - meta = get_meta_m16n8k32_half(mask) - - ctx = tvm.cuda() - A_tvm = tvm.nd.array(A_np, ctx) - B_tvm = tvm.nd.array(B_np, ctx) - C_tvm = tvm.nd.array(np.zeros_like(C_np), ctx) - meta_tvm = tvm.nd.array(meta, ctx) - cuda_mod(A_tvm, B_tvm, C_tvm, meta_tvm) +def test_mma_sp_m16n8k32_f16(): + def get_meta_m16n8k32_half(mask): + assert mask.shape == (16, 8, 2) + mask = mask.reshape(16, 2, 8) + ret = np.zeros((8, 2)).astype("uint32") + + for i in range(8): + for k in range(2): + base = 1 + for blk in range(2): + for j in range(8): + ret[i, k] |= int(mask[blk * 8 + i, k, j]) * base + base = base << 2 + + return ret.reshape(16) + + for out_dtype in ["float16", "float32"]: + func = mma_sp_m16n8k32_f16f16f16 if out_dtype == "float16" else mma_sp_m16n8k32_f16f16f32 + sch = tvm.tir.Schedule(func) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # Requires SM80+ + return + cuda_mod = tvm.build(sch.mod, target="cuda") + + A_np = np.random.uniform(-1, 1, [16, 16]).astype("float16") + B_np = np.random.uniform(-1, 1, [32, 8]).astype("float16") + mask = gen_2in4_mask(16, 32) + A_dense_np = get_dense_mat_by_mask(A_np, mask) + C_np = np.matmul(A_dense_np, B_np).astype(out_dtype) + meta = get_meta_m16n8k32_half(mask) + + ctx = tvm.cuda() + A_tvm = tvm.nd.array(A_np, ctx) + B_tvm = tvm.nd.array(B_np, ctx) + C_tvm = tvm.nd.array(np.zeros_like(C_np), ctx) + meta_tvm = 
tvm.nd.array(meta, ctx) + cuda_mod(A_tvm, B_tvm, C_tvm, meta_tvm) tvm.testing.assert_allclose(C_tvm.numpy(), C_np, atol=1e-3, rtol=1e-3) if __name__ == "__main__": - test_mma_sp_m16n8k16_fp16() - test_mma_sp_m16n8k32_fp16() + test_mma_sp_m16n8k16_f16() + test_mma_sp_m16n8k32_f16() From 24e95e836aceb5d42b09e591c44540bf91db6525 Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 24 Feb 2022 19:09:43 -0800 Subject: [PATCH 11/15] upd --- src/target/source/codegen_cuda.cc | 16 +- src/target/source/ptx_mma.cc | 1766 +++++------------- src/target/source/ptx_mma.h | 9 +- src/target/source/ptx_mma_sp.cc | 307 --- src/target/source/ptx_mma_sp.h | 47 - tests/python/unittest/test_tir_ptx_mma.py | 6 +- tests/python/unittest/test_tir_ptx_mma_sp.py | 3 - 7 files changed, 475 insertions(+), 1679 deletions(-) delete mode 100644 src/target/source/ptx_mma_sp.cc delete mode 100644 src/target/source/ptx_mma_sp.h diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 4562a8702e9b..f74d5cf484b9 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -34,7 +34,6 @@ #include "literal/cuda_half_t.h" #include "ptx_mma.h" -#include "ptx_mma_sp.h" namespace tvm { namespace codegen { @@ -745,7 +744,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { // arg 10: C accumulator // arg 11: C accumulator index // arg 12: saturate - ICHECK_EQ(op->args.size(), 13U); + // arg 13: (optional) 1-bit operator (xor or and) + ICHECK(op->args.size() == 13U || op->args.size() == 14U); std::string shape = Downcast(op->args[0])->value; std::string A_layout = Downcast(op->args[1])->value; std::string B_layout = Downcast(op->args[2])->value; @@ -758,9 +758,11 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string b_bias = this->PrintExpr(op->args[9]); std::string c_ref = this->PrintExpr(op->args[10]); std::string c_bias = this->PrintExpr(op->args[11]); - bool saturate = (Downcast(op->args[12])->value != 0); - std::string asm_code = PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, - a_ref, a_bias, b_ref, b_bias, c_ref, c_bias, saturate); + bool saturate = Downcast(op->args[12])->value; + std::string bit_op = op->args.size() > 13 ? 
Downcast(op->args[13])->value : ""; + std::string asm_code = + PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, b_ref, + b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate); this->stream << asm_code; } else if (op->op.same_as(builtin::ptx_mma_sp())) { @@ -797,9 +799,9 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string metadata_offset = this->PrintExpr(op->args[13]); std::string sparse_selector = this->PrintExpr(op->args[14]); bool saturate = Downcast(op->args[15])->value; - std::string asm_code = PrintMMASparseAssembly( + std::string asm_code = PrintMMAAssembly( shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset, - c_ref, c_offset, metadata, metadata_offset, sparse_selector, saturate); + c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate); this->stream << asm_code; } else { CodeGenC::VisitExpr_(op, os); diff --git a/src/target/source/ptx_mma.cc b/src/target/source/ptx_mma.cc index b6182720416c..49d64b1dc072 100644 --- a/src/target/source/ptx_mma.cc +++ b/src/target/source/ptx_mma.cc @@ -23,1351 +23,503 @@ #include "ptx_mma.h" +#include +#include +#include +#include +#include + namespace tvm { namespace codegen { -std::string ReplaceMMAArgument(std::string asm_code, const std::string& original, - const std::string& new_arg) { - size_t len = original.size(); - size_t new_len = new_arg.size(); - size_t pos = asm_code.find(original); - while (pos != std::string::npos) { - asm_code = asm_code.replace(pos, len, new_arg); - pos = asm_code.find(original, pos + new_len); - } - return asm_code; -} +// PTX related data structures and functions. +namespace ptx { -std::string PrintMMAm8n8k4Assembly(const std::string& A_layout, const std::string& B_layout, - const std::string& A_dtype, const std::string& B_dtype, - const std::string& C_dtype, const std::string& a_ref, - const std::string& a_bias, const std::string& b_ref, - const std::string& b_bias, const std::string& c_ref, - const std::string& c_bias, bool saturate) { - std::string asm_code = ""; - std::string new_a_ref = ""; - std::string new_b_ref = ""; - std::string new_c_ref = ""; - ICHECK(((A_dtype == "fp16") && (B_dtype == "fp16")) || - ((A_dtype == "fp64") && (B_dtype == "fp64"))); - ICHECK(saturate == false) << "Saturate is not allowed for m8n8k4 mma."; - if ((A_dtype == "fp16") && (B_dtype == "fp16")) { - // A/B multiplicand is fp16, SM 70 Tensor Core instructions - ICHECK((C_dtype == "fp16") || (C_dtype == "fp32")); - if (C_dtype == "fp16") { - // C accumulator is fp16 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((unsigned *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k4.left_layout.right_layout.f16.f16.f16.f16 " - "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, " - "{%8,%9,%10,%11};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // C accumulator is fp32 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k4.left_layout.right_layout.f32.f16.f16.f32 " - "{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " - 
"{%12,%13,%14,%15,%16,%17,%18,%19};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]), - "=f"(D[4]), "=f"(D[5]), "=f"(D[6]), "=f"(D[7]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), - "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); - } - )"; - } +/*! + * \brief PTX data type. + */ +enum class DataType : int { + kInt4 = 0, + kUInt4 = 1, + kInt8 = 2, + kUInt8 = 3, + kInt16 = 4, + kUInt16 = 5, + kInt32 = 6, + kUInt32 = 7, + kInt64 = 8, + kUInt64 = 9, + kFloat16 = 10, + kBFloat16 = 11, + kFloat16x2 = 12, + kFloat32 = 13, + kTensorFloat32 = 14, + kFloat64 = 15, + kBit1 = 16 +}; + +static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16", + ".s32", ".u32", ".s64", ".u64", ".f16", ".bf16", + ".f16x2", ".f32", ".tf32", ".f64", ".b1"}; +static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16, 16, 32, 32, 32, 64, 1}; + +/*! + * \brief Create PTX data type from string. + */ +inline DataType DTypeFromString(const std::string str) { + if (str == "int4" || str == ".s4") { + return DataType::kInt4; + } else if (str == "uint4" || str == ".u4") { + return DataType::kUInt4; + } else if (str == "int8" || str == ".s8") { + return DataType::kInt8; + } else if (str == "uint8" || str == ".u8") { + return DataType::kUInt8; + } else if (str == "int16" || str == ".s16") { + return DataType::kInt16; + } else if (str == "uint16" || str == ".u16") { + return DataType::kUInt16; + } else if (str == "int32" || str == ".s32") { + return DataType::kInt32; + } else if (str == "uint32" || str == ".u32") { + return DataType::kUInt32; + } else if (str == "int64" || str == ".s64") { + return DataType::kInt64; + } else if (str == "uint64" || str == ".u64") { + return DataType::kUInt64; + } else if (str == "float16" || str == "fp16" || str == ".f16") { + return DataType::kFloat16; + } else if (str == "bfloat16" || str == "bf16") { + return DataType::kBFloat16; + } else if (str == ".f16x2") { + return DataType::kFloat16x2; + } else if (str == "float32" || str == "fp32" || str == ".f32") { + return DataType::kFloat32; + } else if (str == "tf32") { + return DataType::kTensorFloat32; + } else if (str == "float64" || str == "fp64" || str == ".f64") { + return DataType::kFloat64; + } else if (str == "int1" || str == ".b1") { + return DataType::kBit1; } else { - // A/B multiplicand is fp64, SM 80 Tensor Core instructions - ICHECK(C_dtype == "fp64"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Fp64 Tensor Core instructions " - << "with shape m8n8k4 expect A layout is row major and B layout is col major."; - // C accumulator is fp64 - new_a_ref = "((double *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((double *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((double *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=d"(D[0]), "=d"(D[1]) - : "d"(A[0]), "d"(B[0]), - "d"(C[0]), "d"(C[1])); - } - )"; + LOG(FATAL) << "Unrecognized PTX data type " << str; + return DataType(0); } - asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); - asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); - asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); - asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); - asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); - asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); - return asm_code; } -std::string 
PrintMMAm16n8k8Assembly(const std::string& A_layout, const std::string& B_layout, - const std::string& A_dtype, const std::string& B_dtype, - const std::string& C_dtype, const std::string& a_ref, - const std::string& a_bias, const std::string& b_ref, - const std::string& b_bias, const std::string& c_ref, - const std::string& c_bias, bool saturate) { - std::string asm_code = ""; - std::string new_a_ref = ""; - std::string new_b_ref = ""; - std::string new_c_ref = ""; - ICHECK(((A_dtype == "fp16") && (B_dtype == "fp16")) || - ((A_dtype == "bf16") && (B_dtype == "bf16"))); - ICHECK(saturate == false) << "Saturate is not allowed for m16n8k8 mma."; - if ((A_dtype == "fp16") && (B_dtype == "fp16")) { - // A/B multiplicand is fp16, SM 75 Tensor Core instructions - ICHECK((C_dtype == "fp16") || (C_dtype == "fp32")); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM75 Tensor Core instructions " - << "with shape m16n8k8 expect A layout is row major and B layout is col major."; - if (C_dtype == "fp16") { - // C accumulator is fp16 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((unsigned *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " - "{%0,%1}, {%2,%3}, {%5}, " - "{%5,%6};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // C accumulator is fp32 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - } - )"; - } +/*! + * \brief Get the string representation of given PTX data type. + */ +inline std::string DTypeToString(DataType dtype) { return dtype_str[static_cast(dtype)]; } + +/*! + * \brief Get the number of bits of given PTX data type. + */ +inline uint32_t DTypeBits(DataType dtype) { return num_bits[static_cast(dtype)]; } + +/*! + * \brief Extract the value m, n, k from string m*n*k* + */ +inline std::tuple ParseMMAShape(const std::string& str) { + size_t pos_m = str.find("m"), pos_n = str.find("n"), pos_k = str.find("k"); + CHECK(pos_m != str.npos && pos_n != str.npos && pos_k != str.npos) + << "Cannot parse MMA shape " << str; + int m = std::stoi(str.substr(pos_m + 1, pos_n - pos_m - 1)), + n = std::stoi(str.substr(pos_n + 1, pos_k - pos_n - 1)), k = std::stoi(str.substr(pos_k + 1)); + return std::make_tuple(m, n, k); +} + +/*! + * \brief Layout Type + */ +enum class LayoutType : int { kRowMajor = 0, kColumnMajor = 1 }; + +/*! 
+ * \brief Parse layout type + */ +LayoutType LayoutTypeFromString(const std::string& str) { + if (str == "row") { + return LayoutType::kRowMajor; + } else if (str == "col") { + return LayoutType::kColumnMajor; } else { - // A/B multiplicand is bf16, SM 80 Tensor Core instructions - ICHECK(C_dtype == "fp32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k8 expect A layout is row major and B layout is col major."; - // C accumulator is fp32 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - } - )"; + LOG(FATAL) << "Unrecognized layout type " << str; + return LayoutType::kRowMajor; } - asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); - asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); - asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); - asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); - asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); - asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); - return asm_code; } -std::string PrintMMAm8n8k16Assembly(const std::string& A_layout, const std::string& B_layout, - const std::string& A_dtype, const std::string& B_dtype, - const std::string& C_dtype, const std::string& a_ref, - const std::string& a_bias, const std::string& b_ref, - const std::string& b_bias, const std::string& c_ref, - const std::string& c_bias, bool saturate) { - std::string asm_code = ""; - std::string new_a_ref = ""; - std::string new_b_ref = ""; - std::string new_c_ref = ""; - ICHECK(((A_dtype == "int8") && (B_dtype == "int8")) || - ((A_dtype == "uint8") && (B_dtype == "int8")) || - ((A_dtype == "int8") && (B_dtype == "uint8")) || - ((A_dtype == "uint8") && (B_dtype == "uint8"))); - if ((A_dtype == "int8") && (B_dtype == "int8")) { - // A/B multiplicand is int8, SM 75 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM75 Tensor Core instructions " - << "with shape m8n8k16 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } - } else if ((A_dtype == "uint8") && (B_dtype == "int8")) { - // A multiplicand is uint8, B multiplicand is int8 - // SM 75 
Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM75 Tensor Core instructions " - << "with shape m8n8k16 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.s8.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } - } else if ((A_dtype == "int8") && (B_dtype == "uint8")) { - // A multiplicand is int8, B multiplicand is uint8 - // SM 75 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM75 Tensor Core instructions " - << "with shape m8n8k16 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.u8.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } - } else { - // A/B multiplicand is uint8, SM 75 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM75 Tensor Core instructions " - << "with shape m8n8k16 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.u8.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" 
- : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } +static const char* layout_type_str[] = {"row", "col"}; + +/*! + * \brief Convert layout type to string. + */ +inline std::string LayoutTypeToString(LayoutType ltype) { return layout_type_str[int(ltype)]; } + +/*! + * \brief MMA Configurations, used to determine validity. + */ +struct MMAConfig { + explicit MMAConfig(int m, int n, int k, DataType dtype_mul, bool use_bit_op, bool sparse) + : m(m), n(n), k(k), dtype_mul(dtype_mul), use_bit_op(use_bit_op), sparse(sparse) {} + int m, n, k; + DataType dtype_mul; + bool use_bit_op; + bool sparse; + inline bool operator==(const MMAConfig& other) { + return m == other.m && n == other.n && k == other.k && dtype_mul == other.dtype_mul && + use_bit_op == other.use_bit_op && sparse == other.sparse; + } +}; + +const MMAConfig valid_mma_configs[] = { + MMAConfig(8, 8, 4, DataType::kFloat64, false, false), + MMAConfig(8, 8, 4, DataType::kFloat16, false, false), + MMAConfig(16, 8, 8, DataType::kFloat16, false, false), + MMAConfig(16, 8, 16, DataType::kFloat16, false, false), + MMAConfig(16, 8, 8, DataType::kBFloat16, false, false), + MMAConfig(16, 8, 16, DataType::kBFloat16, false, false), + MMAConfig(16, 8, 4, DataType::kTensorFloat32, false, false), + MMAConfig(16, 8, 8, DataType::kTensorFloat32, false, false), + MMAConfig(8, 8, 16, DataType::kInt8, false, false), + MMAConfig(16, 8, 16, DataType::kInt8, false, false), + MMAConfig(16, 8, 32, DataType::kInt8, false, false), + MMAConfig(8, 8, 16, DataType::kUInt8, false, false), + MMAConfig(16, 8, 16, DataType::kUInt8, false, false), + MMAConfig(16, 8, 32, DataType::kUInt8, false, false), + MMAConfig(8, 8, 32, DataType::kInt4, false, false), + MMAConfig(16, 8, 32, DataType::kInt4, false, false), + MMAConfig(16, 8, 64, DataType::kInt4, false, false), + MMAConfig(8, 8, 32, DataType::kUInt4, false, false), + MMAConfig(16, 8, 32, DataType::kUInt4, false, false), + MMAConfig(16, 8, 64, DataType::kUInt4, false, false), + MMAConfig(8, 8, 128, DataType::kBit1, true, false), + MMAConfig(16, 8, 128, DataType::kBit1, true, false), + MMAConfig(16, 8, 256, DataType::kBit1, true, false), + MMAConfig(16, 8, 16, DataType::kFloat16, false, true), + MMAConfig(16, 8, 32, DataType::kFloat16, false, true), + MMAConfig(16, 8, 16, DataType::kBFloat16, false, true), + MMAConfig(16, 8, 32, DataType::kBFloat16, false, true), + MMAConfig(16, 8, 8, DataType::kTensorFloat32, false, true), + MMAConfig(16, 8, 16, DataType::kTensorFloat32, false, true), + MMAConfig(16, 8, 32, DataType::kInt8, false, true), + MMAConfig(16, 8, 64, DataType::kInt8, false, true), + MMAConfig(16, 8, 32, DataType::kUInt8, false, true), + MMAConfig(16, 8, 64, DataType::kUInt8, false, true), + MMAConfig(16, 8, 64, DataType::kInt4, false, true), + MMAConfig(16, 8, 128, DataType::kInt4, false, true), + MMAConfig(16, 8, 64, DataType::kUInt4, false, true), + MMAConfig(16, 8, 128, DataType::kUInt4, false, true), +}; + +/*! + * \brief Check whether the multiplicand data type and accumulator data type is valid for MMA + * computation. + * \param mul The multiplicand data type. + * \param acc The accumulator data type. 
+ */ +void CheckMMADTypeCompatible(DataType mul, DataType acc) { + switch (mul) { + case DataType::kBit1: + case DataType::kInt4: + case DataType::kUInt4: + case DataType::kInt8: + case DataType::kUInt8: + CHECK(acc == DataType::kInt32) << "For multiplicand data type " << DTypeToString(mul) + << ", accumulator data type should be s32."; + break; + case DataType::kFloat16: + CHECK(acc == DataType::kFloat16 || acc == DataType::kFloat32) + << "For multiplicand data type f16, accumulator data type should be f16/f32."; + break; + case DataType::kBFloat16: + case DataType::kTensorFloat32: + CHECK(acc == DataType::kFloat32) + << "For multiplicand data type bf16/tf32, accumulator data type can only be f32"; + break; + case DataType::kFloat64: + CHECK(acc == DataType::kFloat64) + << "For multiplicand data type f64, accumulator data type can only be f64"; + break; + default: + CHECK(false) << "Invalid multiplicand/accumulator data type pair: " << DTypeToString(mul) + << ", " << DTypeToString(acc) << "."; } - asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); - asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); - asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); - asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); - asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); - asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); - return asm_code; } -std::string PrintMMAm8n8k32Assembly(const std::string& A_layout, const std::string& B_layout, - const std::string& A_dtype, const std::string& B_dtype, - const std::string& C_dtype, const std::string& a_ref, - const std::string& a_bias, const std::string& b_ref, - const std::string& b_bias, const std::string& c_ref, - const std::string& c_bias, bool saturate) { - std::string asm_code = ""; - std::string new_a_ref = ""; - std::string new_b_ref = ""; - std::string new_c_ref = ""; - ICHECK(((A_dtype == "int4") && (B_dtype == "int4")) || - ((A_dtype == "uint4") && (B_dtype == "int4")) || - ((A_dtype == "int4") && (B_dtype == "uint4")) || - ((A_dtype == "uint4") && (B_dtype == "uint4"))); - if ((A_dtype == "int4") && (B_dtype == "int4")) { - // A/B multiplicand is int4, SM 75 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM75 Tensor Core instructions " - << "with shape m8n8k32 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.s4.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } - } else if ((A_dtype == "uint4") && (B_dtype == "int4")) { - // A multiplicand is uint4, B multiplicand is int4 - // SM 75 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") 
&& (B_layout == "col")) - << "SM75 Tensor Core instructions " - << "with shape m8n8k32 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } - } else if ((A_dtype == "int4") && (B_dtype == "uint4")) { - // A multiplicand is int4, B multiplicand is uint4 - // SM 75 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM75 Tensor Core instructions " - << "with shape m8n8k32 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.u4.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } - } else { - // A/B multiplicand is uint4, SM 75 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM75 Tensor Core instructions " - << "with shape m8n8k32 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.u4.s32 " - "{%0,%1}, {%2}, {%3}, " - "{%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(B[0]), - "r"(C[0]), "r"(C[1])); - } - 
)"; +/*! + * \brief Check whether the given configuration is valid for MMA computation. + * \param m The M in mMnNkK of MMA instructions. + * \param n The N in mMnNkK of MMA instructions. + * \param k The K in mMnNkK of MMA instructions. + * \param layout_a The layout of multiplicand A (row/col). + * \param layout_b The layout of multiplicand B (row/col). + * \param dtype_a The data type of multiplicand A. + * \param dtype_b The data type of multiplicand B. + * \param dtype_c The data type of accumulator C. + * \param bit_op The bit operator for 1-bit MMA computation, can be "xor"/"and" or ""(if it's not + * 1-bit MMA). + * \param sparse Whether it's Sparse MMA or not. + * \param saturate Whether saturate output or not. + */ +void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType layout_b, + DataType dtype_a, DataType dtype_b, DataType dtype_c, + const std::string& bit_op, bool sparse, bool saturate) { + CHECK(bit_op == "xor" || bit_op == "and" || bit_op == "") + << "Unrecognized 1-bit operation " << bit_op << " , can only be xor/and."; + bool use_bit_op = !bit_op.empty(); + if (use_bit_op) { + CHECK(dtype_a == DataType::kBit1) << "Bit operator is only compatible with 1bit multiplicand."; + } + CHECK(dtype_a == dtype_b) << "The multiplicand data type must be equal, found " + << DTypeToString(dtype_a) << " and " << ptx::DTypeToString(dtype_b) + << "."; + CheckMMADTypeCompatible(dtype_a, dtype_c); + if (saturate) { + CHECK(dtype_a == DataType::kInt4 || dtype_a == DataType::kUInt4 || dtype_a == DataType::kInt8 || + dtype_a == DataType::kUInt8) + << "Output saturation only applicable to multiplicand type s4/u4/s8/u8."; + } + + if (!(m == 8 && n == 8 && k == 4 && dtype_a == ptx::DataType::kFloat16)) { + // Only MMA on m8n8k4 for fp16 supports customized layouts. 
+ CHECK(layout_a == LayoutType::kRowMajor && layout_b == LayoutType::kColumnMajor) + << "Invalid layout combination " << LayoutTypeToString(layout_a) << "," + << LayoutTypeToString(layout_b) << "."; + } + + MMAConfig config(m, n, k, dtype_a, use_bit_op, sparse); + bool match = false; + for (const MMAConfig& valid_config : valid_mma_configs) { + if (config == valid_config) { + match = true; + break; } } - asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); - asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); - asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); - asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); - asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); - asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); - return asm_code; + CHECK(match) << "Cannot find matched MMA configurations."; } -std::string PrintMMAm16n8k4Assembly(const std::string& A_layout, const std::string& B_layout, - const std::string& A_dtype, const std::string& B_dtype, - const std::string& C_dtype, const std::string& a_ref, - const std::string& a_bias, const std::string& b_ref, - const std::string& b_bias, const std::string& c_ref, - const std::string& c_bias, bool saturate) { - std::string asm_code = ""; - std::string new_a_ref = ""; - std::string new_b_ref = ""; - std::string new_c_ref = ""; - ICHECK((A_dtype == "tf32") && (B_dtype == "tf32")); - ICHECK(saturate == false) << "Saturate is not allowed for m16n8k4 mma."; - // A/B multiplicand is tf32, SM 80 Tensor Core instructions - ICHECK(C_dtype == "fp32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k4 expect A layout is row major and B layout is col major."; - // C accumulator is fp32 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "f"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - } - )"; - asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); - asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); - asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); - asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); - asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); - asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); - return asm_code; -} +/*! + * \brief Fragment attributes + */ +class FragAttrs { + public: + explicit FragAttrs(char reg_type, uint32_t size, std::string ptr_sig) + : reg_type(reg_type), size(size), ptr_sig(ptr_sig) {} + /*! \brief PTX register type */ + char reg_type; + /*! \brief Fragment size */ + uint32_t size; + /*! 
\brief Fragment pointer signature */ + std::string ptr_sig; +}; -std::string PrintMMAm16n8k16Assembly(const std::string& A_layout, const std::string& B_layout, - const std::string& A_dtype, const std::string& B_dtype, - const std::string& C_dtype, const std::string& a_ref, - const std::string& a_bias, const std::string& b_ref, - const std::string& b_bias, const std::string& c_ref, - const std::string& c_bias, bool saturate) { - std::string asm_code = ""; - std::string new_a_ref = ""; - std::string new_b_ref = ""; - std::string new_c_ref = ""; - ICHECK(((A_dtype == "fp16") && (B_dtype == "fp16")) || - ((A_dtype == "bf16") && (B_dtype == "bf16")) || - ((A_dtype == "int8") && (B_dtype == "int8")) || - ((A_dtype == "uint8") && (B_dtype == "int8")) || - ((A_dtype == "int8") && (B_dtype == "uint8")) || - ((A_dtype == "uint8") && (B_dtype == "uint8"))); - if ((A_dtype == "fp16") && (B_dtype == "fp16")) { - ICHECK(saturate == false) << "Saturate is not allowed for m16n8k8 fp16 mma."; - // A/B multiplicand is fp16, SM 80 Tensor Core instructions - ICHECK((C_dtype == "fp16") || (C_dtype == "fp32")); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k16 expect A layout is row major and B layout is col major."; - if (C_dtype == "fp16") { - // C accumulator is fp16 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((unsigned *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " - "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, " - "{%8,%9};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1])); - } - )"; - } else { - // C accumulator is fp32 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - } - )"; - } - } else if ((A_dtype == "bf16") && (B_dtype == "bf16")) { - // A/B multiplicand is bf16, SM 80 Tensor Core instructions - ICHECK(saturate == false) << "Saturate is not allowed for m16n8k8 bf16 mma."; - ICHECK(C_dtype == "fp32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k16 expect A layout is row major and B layout is col major."; - // C accumulator is fp32 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - } - )"; - } else if ((A_dtype == "int8") && (B_dtype == "int8")) { - // A/B multiplicand is int8, SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && 
(B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k16 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } - } else if ((A_dtype == "uint8") && (B_dtype == "int8")) { - // A multiplicand is uint8, B multiplicand is int8 - // SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k16 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } - } else if ((A_dtype == "int8") && (B_dtype == "uint8")) { - // A multiplicand is int8, B multiplicand is uint8 - // SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k16 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - 
)"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } - } else { - // A/B multiplicand is uint8, SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k16 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } +/*! + * \brief Fragment attributes of given data type. 
+ */ +inline FragAttrs GetFragAttrs(DataType dtype) { + switch (dtype) { + case DataType::kBit1: + case DataType::kInt4: + case DataType::kUInt4: + case DataType::kInt8: + case DataType::kUInt8: + case DataType::kFloat16: // .f16x2 register + case DataType::kBFloat16: + case DataType::kTensorFloat32: + return FragAttrs('r', 32, "(unsigned *)"); + case DataType::kInt32: + return FragAttrs('r', 32, "(int *)"); + case DataType::kFloat32: + return FragAttrs('f', 32, "(float *)"); + case DataType::kFloat64: + return FragAttrs('d', 64, "(double *)"); + default: + ICHECK(false) << DTypeToString(dtype) << " is not matrix data type in MMA."; + return FragAttrs('\0', 0, ""); } - asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); - asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); - asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); - asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); - asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); - asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); - return asm_code; } -std::string PrintMMAm16n8k32Assembly(const std::string& A_layout, const std::string& B_layout, - const std::string& A_dtype, const std::string& B_dtype, - const std::string& C_dtype, const std::string& a_ref, - const std::string& a_bias, const std::string& b_ref, - const std::string& b_bias, const std::string& c_ref, - const std::string& c_bias, bool saturate) { - std::string asm_code = ""; - std::string new_a_ref = ""; - std::string new_b_ref = ""; - std::string new_c_ref = ""; - ICHECK(((A_dtype == "int8") && (B_dtype == "int8")) || - ((A_dtype == "uint8") && (B_dtype == "int8")) || - ((A_dtype == "int8") && (B_dtype == "uint8")) || - ((A_dtype == "uint8") && (B_dtype == "uint8"))); - if ((A_dtype == "int8") && (B_dtype == "int8")) { - // A/B multiplicand is int8, SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k32 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } - } else if ((A_dtype == "uint8") && (B_dtype == "int8")) { - // A multiplicand is uint8, B multiplicand is int8 - // SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k32 
expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } - } else if ((A_dtype == "int8") && (B_dtype == "uint8")) { - // A multiplicand is int8, B multiplicand is uint8 - // SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k32 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; +}; // namespace ptx + +/*! + * \brief Replace patterns with replacement strings. + * \note should use std::format instead when codebase is ported to C++20. 
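+ * Rules are applied in registration order and replace every occurrence of a pattern, e.g.
+ *   register_rule("{shape}", "m16n8k16") rewrites each "{shape}" in the assembly template.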
+ */ +class Replacer { + public: + void register_rule(const std::string& pattern, const std::string& replacement) { + _rules.emplace_back(pattern, replacement); + } + std::string rewrite(std::string str) { + for (auto&& rule : _rules) { + std::string pattern, replacement; + std::tie(pattern, replacement) = rule; + size_t len = pattern.size(); + size_t new_len = replacement.size(); + size_t pos = str.find(pattern); + while (pos != std::string::npos) { + str = str.replace(pos, len, replacement); + pos = str.find(pattern, pos + new_len); + } } + return str; + } + void empty_rules() { _rules.clear(); } + + private: + std::vector> _rules; +}; + +/*! + * \brief Get the number of MMA computations for given shape and datatype. + */ +inline uint32_t GetNumMMAComputations(int m, int n, int k, ptx::DataType dtype) { + if (m == 8 && n == 8 && k == 4 && dtype == ptx::DataType::kFloat16) { + // MMA for m8n8k4 on fp16 would launch 4 MMA computations instead of one. + return 4; } else { - // A/B multiplicand is uint8, SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k32 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } + return 1; } - asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); - asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); - asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); - asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); - asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); - asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); - return asm_code; } -std::string PrintMMAm16n8k64Assembly(const std::string& A_layout, const std::string& B_layout, - const std::string& A_dtype, const std::string& B_dtype, - const std::string& C_dtype, const std::string& a_ref, - const std::string& a_bias, const std::string& b_ref, - const std::string& b_bias, const std::string& c_ref, - const std::string& c_bias, bool saturate) { - std::string asm_code = ""; - std::string new_a_ref = ""; - std::string new_b_ref = ""; - std::string new_c_ref = ""; - ICHECK(((A_dtype == "int4") && (B_dtype == "int4")) || - ((A_dtype == "uint4") && (B_dtype == "int4")) || - ((A_dtype == "int4") && (B_dtype == "uint4")) || - ((A_dtype == "uint4") && (B_dtype == "uint4"))); - if ((A_dtype == "int4") 
&& (B_dtype == "int4")) { - // A/B multiplicand is int4, SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k64 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } - } else if ((A_dtype == "uint4") && (B_dtype == "int4")) { - // A multiplicand is uint4, B multiplicand is int4 - // SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k64 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } - } else if ((A_dtype == "int4") && (B_dtype == "uint4")) { - // A multiplicand is int4, B multiplicand is uint4 - // SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k64 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = 
"((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } - } else { - // A/B multiplicand is uint4, SM 75 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k64 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - if (!saturate) { - // no saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // saturate - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; +/*! + * \brief Return template string, input operands string and output operands string. + * \param m The M in mMnNkK of MMA instructions. + * \param n The N in mMnNkK of MMA instructions. + * \param k The K in mMnNkK of MMA instructions. + * \param dtype_mul The data type of multiplicand. + * \param dtype_acc The data type of accumulator. + * \param sparse Whether it's Sparse MMA or not. + */ +inline std::tuple GetMMAOperands(int m, int n, int k, + ptx::DataType dtype_mul, + ptx::DataType dtype_acc, + bool sparse) { + std::stringstream templates, inputs, outputs; + const ptx::FragAttrs frag_attr_mul = ptx::GetFragAttrs(dtype_mul), + frag_attr_acc = ptx::GetFragAttrs(dtype_acc); + constexpr uint32_t warp_size = 32; + const uint32_t threads = warp_size / GetNumMMAComputations(m, n, k, dtype_mul); + const int num_operands_a = (m * k) * ptx::DTypeBits(dtype_mul) / frag_attr_acc.size / threads / + (sparse ? 
2 : 1), + num_operands_b = (k * n) * ptx::DTypeBits(dtype_mul) / frag_attr_mul.size / threads, + num_operands_c = (m * n) * ptx::DTypeBits(dtype_acc) / frag_attr_acc.size / threads; + + // generate templates; + int arg_counter = 0; + templates << "{" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_c; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}, {" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_a; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}, {" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_b; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}, {" + << "%" << arg_counter++; + for (int i = 1; i < num_operands_c; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}"; + // templates of metadata and sparse selector for sparse mma. + if (sparse) { + templates << ", %" << (arg_counter++) << ", F"; + } + + // generate inputs + for (int i = 0; i < num_operands_a; ++i) { + if (i != 0) { + inputs << ", "; } + inputs << "\"" << frag_attr_mul.reg_type << "\"((" << frag_attr_mul.ptr_sig << "(A))[" << i + << "])"; + } + for (int i = 0; i < num_operands_b; ++i) { + inputs << ", \"" << frag_attr_mul.reg_type << "\"((" << frag_attr_mul.ptr_sig << "(B))[" << i + << "])"; + } + for (int i = 0; i < num_operands_c; ++i) { + inputs << ", \"" << frag_attr_acc.reg_type << "\"((" << frag_attr_acc.ptr_sig << "(C))[" << i + << "])"; + } + // input of metadata for sparse mma. + if (sparse) { + inputs << ", \"r\"(((unsigned *)(E))[0])"; } - asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); - asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); - asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); - asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); - asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); - asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); - return asm_code; -} -std::string PrintMMAm16n8k256Assembly(const std::string& A_layout, const std::string& B_layout, - const std::string& A_dtype, const std::string& B_dtype, - const std::string& C_dtype, const std::string& a_ref, - const std::string& a_bias, const std::string& b_ref, - const std::string& b_bias, const std::string& c_ref, - const std::string& c_bias, bool saturate) { - std::string asm_code = ""; - std::string new_a_ref = ""; - std::string new_b_ref = ""; - std::string new_c_ref = ""; - ICHECK(((A_dtype == "uint1") && (B_dtype == "uint1")) || - ((A_dtype == "int1") && (B_dtype == "int1"))); - if ((A_dtype == "uint1") && (B_dtype == "uint1")) { - // A/B multiplicand is uint1, SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k256 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; - } else { - // A/B multiplicand is int1, SM 80 Tensor Core instructions - ICHECK(C_dtype == "int32"); - 
ICHECK((A_layout == "row") && (B_layout == "col")) - << "SM80 Tensor Core instructions " - << "with shape m16n8k256 expect A layout is row major and B layout is col major."; - // C accumulator is int32 - new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; - new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; - new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; - asm_code = R"( - { - __asm__ __volatile__( - "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " - "{%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), - "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - } - )"; + // generate outputs + for (int i = 0; i < num_operands_c; ++i) { + if (i != 0) { + outputs << ","; + } + outputs << " \"=" << frag_attr_acc.reg_type << "\"((" << frag_attr_acc.ptr_sig << "(D))[" << i + << "])"; } - asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); - asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); - asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); - asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); - asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); - asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); - return asm_code; + return std::make_tuple(templates.str(), inputs.str(), outputs.str()); } std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout, const std::string& B_layout, const std::string& A_dtype, const std::string& B_dtype, const std::string& C_dtype, - const std::string& a_ref, const std::string& a_bias, - const std::string& b_ref, const std::string& b_bias, - const std::string& c_ref, const std::string& c_bias, bool saturate) { - ICHECK((shape == "m8n8k4") || (shape == "m16n8k8") || (shape == "m8n8k16") || - (shape == "m8n8k32") || (shape == "m16n8k4") || (shape == "m16n8k16") || - (shape == "m16n8k32") || (shape == "m16n8k64") || (shape == "m16n8k256")); - ICHECK((A_layout == "row") || (A_layout == "col")) << "Unknown A layout: " << A_layout; - ICHECK((B_layout == "row") || (B_layout == "col")) << "Unknown B layout: " << B_layout; - - if (shape == "m8n8k4") { - return PrintMMAm8n8k4Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, - b_ref, b_bias, c_ref, c_bias, saturate); - } else if (shape == "m16n8k8") { - return PrintMMAm16n8k8Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, - b_ref, b_bias, c_ref, c_bias, saturate); - } else if (shape == "m8n8k16") { - return PrintMMAm8n8k16Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, - b_ref, b_bias, c_ref, c_bias, saturate); - } else if (shape == "m8n8k32") { - return PrintMMAm8n8k32Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, - b_ref, b_bias, c_ref, c_bias, saturate); - } else if (shape == "m16n8k4") { - return PrintMMAm16n8k4Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, - b_ref, b_bias, c_ref, c_bias, saturate); - } else if (shape == "m16n8k16") { - return PrintMMAm16n8k16Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, - b_ref, b_bias, c_ref, c_bias, saturate); - } else if (shape == "m16n8k32") { - return PrintMMAm16n8k32Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, - b_ref, b_bias, c_ref, c_bias, saturate); - } else if (shape == "m16n8k64") { - return PrintMMAm16n8k64Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, 
a_bias, - b_ref, b_bias, c_ref, c_bias, saturate); - } else if (shape == "m16n8k256") { - return PrintMMAm16n8k256Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, - b_ref, b_bias, c_ref, c_bias, saturate); + const std::string& a_ref, const std::string& a_offset, + const std::string& b_ref, const std::string& b_offset, + const std::string& c_ref, const std::string& c_offset, + const std::string& metadata, const std::string& metadata_offset, + const std::string& sparsity_selector, const std::string& bit_op, + bool sparse, bool saturate) { + ptx::DataType dtype_a = ptx::DTypeFromString(A_dtype), dtype_b = ptx::DTypeFromString(B_dtype), + dtype_c = ptx::DTypeFromString(C_dtype); + ptx::LayoutType layout_a = ptx::LayoutTypeFromString(A_layout), + layout_b = ptx::LayoutTypeFromString(B_layout); + int m, n, k; + std::tie(m, n, k) = ptx::ParseMMAShape(shape); + CheckMMAConfigValidity(m, n, k, layout_a, layout_b, dtype_a, dtype_b, dtype_c, bit_op, sparse, + saturate); + std::string asm_code = R"( + { + __asm__ __volatile__( + "mma{sparse}.sync.aligned.{shape}.{alayout}.{blayout}{satinite}{dtype}{atype}{btype}{ctype}{1bit}" + "{templates};\n" + : {outputs} + : {inputs}); } - /* - * TODO: add mma.m16n8k128 - */ - throw Error("Unknown PTX mma instructions."); +)"; + std::string templates_str, inputs_str, outputs_str; + std::tie(templates_str, inputs_str, outputs_str) = + GetMMAOperands(m, n, k, dtype_a, dtype_c, sparse); + + // replace patterns + Replacer replacer; + replacer.register_rule("{sparse}", sparse ? ".sp" : ""); + replacer.register_rule("{shape}", shape); + replacer.register_rule("{satinite}", saturate ? ".satfinite" : ""); + replacer.register_rule("{alayout}", A_layout); + replacer.register_rule("{blayout}", B_layout); + replacer.register_rule("{atype}", ptx::DTypeToString(dtype_a)); + replacer.register_rule("{btype}", ptx::DTypeToString(dtype_b)); + replacer.register_rule("{ctype}", ptx::DTypeToString(dtype_c)); + replacer.register_rule("{dtype}", ptx::DTypeToString(dtype_c)); + replacer.register_rule("{1bit}", bit_op.empty() ? "" : "." 
+ bit_op + ".popc"); + replacer.register_rule("{templates}", templates_str); + replacer.register_rule("{outputs}", outputs_str); + replacer.register_rule("{inputs}", inputs_str); + asm_code = replacer.rewrite(asm_code); + replacer.empty_rules(); + replacer.register_rule("A", a_ref + " + " + a_offset); + replacer.register_rule("B", b_ref + " + " + b_offset); + replacer.register_rule("C", c_ref + " + " + c_offset); + replacer.register_rule("D", c_ref + " + " + c_offset); + replacer.register_rule("E", metadata + " + " + metadata_offset); + replacer.register_rule("F", sparsity_selector); + asm_code = replacer.rewrite(asm_code); + return asm_code; } } // namespace codegen diff --git a/src/target/source/ptx_mma.h b/src/target/source/ptx_mma.h index d2a7a6705d6d..34862d1d0ed4 100644 --- a/src/target/source/ptx_mma.h +++ b/src/target/source/ptx_mma.h @@ -35,9 +35,12 @@ namespace codegen { std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout, const std::string& B_layout, const std::string& A_dtype, const std::string& B_dtype, const std::string& C_dtype, - const std::string& a_ref, const std::string& a_bias, - const std::string& b_ref, const std::string& b_bias, - const std::string& c_ref, const std::string& c_bias, bool saturate); + const std::string& a_ref, const std::string& a_offset, + const std::string& b_ref, const std::string& b_offset, + const std::string& c_ref, const std::string& c_offset, + const std::string& metadata, const std::string& metadata_offset, + const std::string& sparsity_selector, const std::string& single_bit_op, + bool sparse, bool saturate); } // namespace codegen } // namespace tvm diff --git a/src/target/source/ptx_mma_sp.cc b/src/target/source/ptx_mma_sp.cc deleted file mode 100644 index cc1c8b7bd012..000000000000 --- a/src/target/source/ptx_mma_sp.cc +++ /dev/null @@ -1,307 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file ptx_mma_sp.cc - */ - -#include "ptx_mma_sp.h" - -#include -#include -#include -#include -#include - -namespace tvm { -namespace codegen { - -namespace ptx { - -/*! - * \brief PTX data type. - */ -enum class DataType : int { - kInt4 = 0, - kUInt4 = 1, - kInt8 = 2, - kUInt8 = 3, - kInt16 = 4, - kUInt16 = 5, - kInt32 = 6, - kUInt32 = 7, - kInt64 = 8, - kUInt64 = 9, - kFloat16 = 10, - kBFloat16 = 11, - kFloat16x2 = 12, - kFloat32 = 13, - kTensorFloat32 = 14, - kFloat64 = 15, - kBit1 = 16 -}; - -static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16", - ".s32", ".u32", ".s64", ".u64", ".f16", ".bf16", - ".f16x2", ".f32", ".tf32", ".f64", ".b1"}; -static uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16, 16, 32, 32, 32, 64, 1}; - -/*! - * \brief Create PTX data type from string. 
- */ -inline DataType DTypeFromString(const std::string str) { - if (str == "int4" || str == ".s4") { - return DataType::kInt4; - } else if (str == "uint4" || str == ".u4") { - return DataType::kUInt4; - } else if (str == "int8" || str == ".s8") { - return DataType::kUInt8; - } else if (str == "uint8" || str == ".u8") { - return DataType::kUInt8; - } else if (str == "int16" || str == ".s16") { - return DataType::kInt16; - } else if (str == "uint16" || str == ".u16") { - return DataType::kUInt16; - } else if (str == "int32" || str == ".s32") { - return DataType::kInt32; - } else if (str == "uint32" || str == ".u32") { - return DataType::kUInt32; - } else if (str == "int64" || str == ".s64") { - return DataType::kInt64; - } else if (str == "uint64" || str == ".u64") { - return DataType::kUInt64; - } else if (str == "float16" || str == "fp16" || str == ".f16") { - return DataType::kFloat16; - } else if (str == "bfloat16" || str == "bf16") { - return DataType::kBFloat16; - } else if (str == ".f16x2") { - return DataType::kFloat16x2; - } else if (str == "float32" || str == "fp32" || str == ".f32") { - return DataType::kFloat32; - } else if (str == "tf32") { - return DataType::kTensorFloat32; - } else if (str == "float64" || str == "fp64" || str == ".f64") { - return DataType::kFloat64; - } else if (str == ".b1") { - return DataType::kBit1; - } else { - LOG(FATAL) << "Unrecognized data type " << str << " for PTX."; - return DataType(0); - } -} - -/*! - * \brief Get the string representation of given PTX data type. - */ -inline std::string DTypeToString(DataType dtype) { return dtype_str[static_cast(dtype)]; } - -/*! - * \brief Get the number of bits of given PTX data type. - */ -inline uint32_t DTypeBits(DataType dtype) { return num_bits[static_cast(dtype)]; } - -/*! - * \brief Extract the value m, n, k from string m*n*k* - */ -std::tuple ParseMMAShape(const std::string& str) { - size_t pos_m = str.find("m"), pos_n = str.find("n"), pos_k = str.find("k"); - CHECK(pos_m != str.npos && pos_n != str.npos && pos_k != str.npos) - << "Cannot parse MMA shape " << str; - int m = std::stoi(str.substr(pos_m + 1, pos_n - pos_m - 1)), - n = std::stoi(str.substr(pos_n + 1, pos_k - pos_n - 1)), k = std::stoi(str.substr(pos_k + 1)); - return std::make_tuple(m, n, k); -} - -/*! - * \brief Fragment attributes of given data type. - * \return the register type in ptx, fragment size, fragment pointer string. - */ -inline std::tuple FragmentAttrs(DataType dtype) { - switch (dtype) { - case DataType::kBit1: - case DataType::kInt4: - case DataType::kUInt4: - case DataType::kInt8: - case DataType::kUInt8: - case DataType::kFloat16: // .f16x2 register - case DataType::kBFloat16: - case DataType::kTensorFloat32: - return std::make_tuple('r', 32, "(unsigned *)"); - case DataType::kInt32: - return std::make_tuple('r', 32, "(int *)"); - case DataType::kFloat32: - return std::make_tuple('f', 32, "(float *)"); - case DataType::kFloat64: - return std::make_tuple('d', 64, "(double *)"); - default: - LOG(FATAL) << DTypeToString(dtype) << " is not matrix data type in MMA."; - return std::make_tuple('\0', 0, ""); - } -} - -}; // namespace ptx - -/*! - * \brief Replace patterns with replacement strings. - * \note should use std::format instead when codebase is ported to C++20. 
- */ -class Replacer { - public: - void register_rule(const std::string& pattern, const std::string& replacement) { - _rules.emplace_back(pattern, replacement); - } - std::string rewrite(std::string str) { - for (auto&& rule : _rules) { - std::string pattern, replacement; - std::tie(pattern, replacement) = rule; - size_t len = pattern.size(); - size_t new_len = replacement.size(); - size_t pos = str.find(pattern); - while (pos != std::string::npos) { - str = str.replace(pos, len, replacement); - pos = str.find(pattern, pos + new_len); - } - } - return str; - } - void empty_rules() { _rules.clear(); } - - private: - std::vector> _rules; -}; - -/*! - * \brief Return template string, input operands string and output operands string. - */ -inline std::tuple get_mma_sp_operands( - int m, int n, int k, ptx::DataType dtype_a, ptx::DataType dtype_b, ptx::DataType dtype_c) { - std::stringstream templates, inputs, outputs; - auto frag_attr_a = ptx::FragmentAttrs(dtype_a), frag_attr_b = ptx::FragmentAttrs(dtype_b), - frag_attr_c = ptx::FragmentAttrs(dtype_c); - constexpr int warp_size = 32; - int num_operands_a, num_operands_b, num_operands_c; - num_operands_a = (m * k / 2) * ptx::DTypeBits(dtype_a) / std::get<1>(frag_attr_a) / warp_size; - num_operands_b = (k * n) * ptx::DTypeBits(dtype_b) / std::get<1>(frag_attr_b) / warp_size; - num_operands_c = (m * n) * ptx::DTypeBits(dtype_c) / std::get<1>(frag_attr_c) / warp_size; - - // generate templates; - int arg_counter = 0; - templates << "{" - << "%" << arg_counter++; - for (int i = 1; i < num_operands_c; ++i) { - templates << ", %" << arg_counter++; - } - templates << "}, {" - << "%" << arg_counter++; - for (int i = 1; i < num_operands_a; ++i) { - templates << ", %" << arg_counter++; - } - templates << "}, {" - << "%" << arg_counter++; - for (int i = 1; i < num_operands_b; ++i) { - templates << ", %" << arg_counter++; - } - templates << "}, {" - << "%" << arg_counter++; - for (int i = 1; i < num_operands_c; ++i) { - templates << ", %" << arg_counter++; - } - templates << "}, %" << (arg_counter++) << ", F"; - - // generate inputs - for (int i = 0; i < num_operands_a; ++i) { - if (i != 0) { - inputs << ", "; - } - inputs << "\"" << std::get<0>(frag_attr_a) << "\"((" << std::get<2>(frag_attr_a) << "(A))[" << i - << "])"; - } - for (int i = 0; i < num_operands_b; ++i) { - inputs << ", \"" << std::get<0>(frag_attr_b) << "\"((" << std::get<2>(frag_attr_b) << "(B))[" - << i << "])"; - } - for (int i = 0; i < num_operands_c; ++i) { - inputs << ", \"" << std::get<0>(frag_attr_c) << "\"((" << std::get<2>(frag_attr_c) << "(C))[" - << i << "])"; - } - inputs << ", \"r\"(((unsigned *)(E))[0])"; - - // generate outputs - for (int i = 0; i < num_operands_c; ++i) { - if (i != 0) { - outputs << ","; - } - outputs << " \"=" << std::get<0>(frag_attr_c) << "\"((" << std::get<2>(frag_attr_c) << "(D))[" - << i << "])"; - } - return std::make_tuple(templates.str(), inputs.str(), outputs.str()); -} - -std::string PrintMMASparseAssembly(const std::string& shape, const std::string& A_layout, - const std::string& B_layout, const std::string& A_dtype, - const std::string& B_dtype, const std::string& C_dtype, - const std::string& a_ref, const std::string& a_offset, - const std::string& b_ref, const std::string& b_offset, - const std::string& c_ref, const std::string& c_offset, - const std::string& metadata, const std::string& metadata_offset, - const std::string& sparsity_selector, bool saturate) { - ptx::DataType dtype_a = ptx::DTypeFromString(A_dtype), dtype_b = 
ptx::DTypeFromString(B_dtype), - dtype_c = ptx::DTypeFromString(C_dtype); - int m, n, k; - std::tie(m, n, k) = ptx::ParseMMAShape(shape); - std::string asm_code = R"( - { - __asm__ __volatile__( - "mma.sp.sync.aligned.{shape}.{alayout}.{blayout}{satinite}{dtype}{atype}{btype}{ctype}" - "{templates};\n" - : {outputs} - : {inputs}); - } -)"; - std::string templates_str, inputs_str, outputs_str; - std::tie(templates_str, inputs_str, outputs_str) = - get_mma_sp_operands(m, n, k, dtype_a, dtype_b, dtype_c); - - // replace patterns - Replacer replacer; - replacer.register_rule("{shape}", shape); - replacer.register_rule("{satinite}", saturate ? ".satinite" : ""); - replacer.register_rule("{alayout}", A_layout); - replacer.register_rule("{blayout}", B_layout); - replacer.register_rule("{atype}", ptx::DTypeToString(dtype_a)); - replacer.register_rule("{btype}", ptx::DTypeToString(dtype_b)); - replacer.register_rule("{ctype}", ptx::DTypeToString(dtype_c)); - replacer.register_rule("{dtype}", ptx::DTypeToString(dtype_c)); - replacer.register_rule("{templates}", templates_str); - replacer.register_rule("{outputs}", outputs_str); - replacer.register_rule("{inputs}", inputs_str); - asm_code = replacer.rewrite(asm_code); - replacer.empty_rules(); - replacer.register_rule("A", a_ref + " + " + a_offset); - replacer.register_rule("B", b_ref + " + " + b_offset); - replacer.register_rule("C", c_ref + " + " + c_offset); - replacer.register_rule("D", c_ref + " + " + c_offset); - replacer.register_rule("E", metadata + " + " + metadata_offset); - replacer.register_rule("F", sparsity_selector); - asm_code = replacer.rewrite(asm_code); - return asm_code; -} - -} // namespace codegen -} // namespace tvm diff --git a/src/target/source/ptx_mma_sp.h b/src/target/source/ptx_mma_sp.h deleted file mode 100644 index 8ee7099db2c4..000000000000 --- a/src/target/source/ptx_mma_sp.h +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file ptx_mma_sp.h - * \brief Sparse MMA code generation with inlined PTX code. 
- */ -#ifndef TVM_TARGET_SOURCE_PTX_MMA_SP_H_ -#define TVM_TARGET_SOURCE_PTX_MMA_SP_H_ - -#include - -#include -#include - -namespace tvm { -namespace codegen { - -std::string PrintMMASparseAssembly(const std::string& shape, const std::string& A_layout, - const std::string& B_layout, const std::string& A_dtype, - const std::string& B_dtype, const std::string& C_dtype, - const std::string& a_ref, const std::string& a_offset, - const std::string& b_ref, const std::string& b_offset, - const std::string& c_ref, const std::string& c_offset, - const std::string& metadata, const std::string& metadata_offset, - const std::string& sparsity_selector, bool saturate); - -} // namespace codegen -} // namespace tvm - -#endif // TVM_TARGET_SOURCE_PTX_MMA_SP_H_ diff --git a/tests/python/unittest/test_tir_ptx_mma.py b/tests/python/unittest/test_tir_ptx_mma.py index 8f653c614d42..1bc0b012b716 100644 --- a/tests/python/unittest/test_tir_ptx_mma.py +++ b/tests/python/unittest/test_tir_ptx_mma.py @@ -1311,6 +1311,7 @@ def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle): Accum.data, 0, False, + "xor", dtype="int32", ) ) @@ -1347,16 +1348,11 @@ def test_gemm_mma_m16n8k256_row_col_b1b1s32(): test_gemm_mma_m8n8k4_row_row_fp16fp16fp16() test_gemm_mma_m8n8k4_row_row_fp16fp16fp32() test_gemm_mma_m8n8k16_row_col_s8s8s32() - test_gemm_mma_m8n8k16_row_col_s8u8s32() test_gemm_mma_m8n8k32_row_col_s4s4s32() - test_gemm_mma_m8n8k32_row_col_s4u4s32() test_gemm_mma_m16n8k8_row_col_fp16fp16fp32() test_gemm_mma_m16n8k16_row_col_fp16fp16fp16() test_gemm_mma_m16n8k16_row_col_fp16fp16fp32() test_gemm_mma_m16n8k16_row_col_s8s8s32() - test_gemm_mma_m16n8k16_row_col_s8u8s32() test_gemm_mma_m16n8k32_row_col_s8s8s32() - test_gemm_mma_m16n8k32_row_col_s8u8s32() test_gemm_mma_m16n8k64_row_col_s4s4s32() - test_gemm_mma_m16n8k64_row_col_s4u4s32() test_gemm_mma_m16n8k256_row_col_b1b1s32() diff --git a/tests/python/unittest/test_tir_ptx_mma_sp.py b/tests/python/unittest/test_tir_ptx_mma_sp.py index 93c7915c2ab5..9c3e8bc4121f 100644 --- a/tests/python/unittest/test_tir_ptx_mma_sp.py +++ b/tests/python/unittest/test_tir_ptx_mma_sp.py @@ -15,9 +15,6 @@ # specific language governing permissions and limitations # under the License. -import sys -import pytest - import tvm from tvm.script import tir as T import numpy as np From 2f512bee2aa359d31a69640e0299fd65e07e2a17 Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 24 Feb 2022 19:19:36 -0800 Subject: [PATCH 12/15] docstring and sanity --- src/target/source/ptx_mma.cc | 4 +++- src/target/source/ptx_mma.h | 23 ++++++++++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/target/source/ptx_mma.cc b/src/target/source/ptx_mma.cc index 49d64b1dc072..cebb6ebad95f 100644 --- a/src/target/source/ptx_mma.cc +++ b/src/target/source/ptx_mma.cc @@ -153,7 +153,9 @@ static const char* layout_type_str[] = {"row", "col"}; /*! * \brief Convert layout type to string. */ -inline std::string LayoutTypeToString(LayoutType ltype) { return layout_type_str[int(ltype)]; } +inline std::string LayoutTypeToString(LayoutType layout) { + return layout_type_str[static_cast(layout)]; +} /*! * \brief MMA Configurations, used to determine validity. diff --git a/src/target/source/ptx_mma.h b/src/target/source/ptx_mma.h index 34862d1d0ed4..728478cdf5fb 100644 --- a/src/target/source/ptx_mma.h +++ b/src/target/source/ptx_mma.h @@ -32,6 +32,27 @@ namespace tvm { namespace codegen { +/*! + * \brief Print MMA assembly string given parameters. 
+ * \param shape The shape string mMnNkK + * \param A_layout The layout of multiplicand A, can be either "row" or "col". + * \param B_layout The layout of multiplicand B, can be either "row" or "col". + * \param A_dtype The data type of multiplicand A. + * \param B_dtype The data type of multiplicand B. + * \param C_dtype The data type of accumulator C. + * \param a_ref Pointer to buffer A. + * \param a_offset The offset of element in A. + * \param b_ref Pointer to buffer B. + * \param b_offset The offset of element in B. + * \param c_ref Pointer to buffer C. + * \param c_offset The offset of element in C. + * \param metadata Pointer to metadata buffer (only used for sparse mma). + * \param metadata_offset The offset of element in metadata. + * \param sparsity_selector The sparsity selector in sparse mma. + * \param bit_op The bit operator used in 1-bit mma, can be either "xor" or "and". + * \param sparse Whether it's sparse mma or not. + * \param saturate Whether saturate output or not. + */ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout, const std::string& B_layout, const std::string& A_dtype, const std::string& B_dtype, const std::string& C_dtype, @@ -39,7 +60,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo const std::string& b_ref, const std::string& b_offset, const std::string& c_ref, const std::string& c_offset, const std::string& metadata, const std::string& metadata_offset, - const std::string& sparsity_selector, const std::string& single_bit_op, + const std::string& sparsity_selector, const std::string& bit_op, bool sparse, bool saturate); } // namespace codegen From 64014cd524a126d206c68f2889212449d98a5baa Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 24 Feb 2022 21:55:44 -0800 Subject: [PATCH 13/15] add u8s8s32 back --- src/target/source/ptx_mma.cc | 93 ++++++++++++++--- tests/python/unittest/test_tir_ptx_mma.py | 5 ++ 2 files changed, 64 insertions(+), 34 deletions(-) diff --git a/src/target/source/ptx_mma.cc b/src/target/source/ptx_mma.cc index cebb6ebad95f..45eb8303183a 100644 --- a/src/target/source/ptx_mma.cc +++ b/src/target/source/ptx_mma.cc @@ -216,35 +216,61 @@ const MMAConfig valid_mma_configs[] = { /*! * \brief Check whether the multiplicand data type and accumulator data type is valid for MMA * computation. - * \param mul The multiplicand data type. - * \param acc The accumulator data type. + * \param dtype_a The data type of multiplicand a. + * \param dtype_b The data type of multiplicand b. + * \param dtype_c The data type of accumulator c.
*/ -void CheckMMADTypeCompatible(DataType mul, DataType acc) { - switch (mul) { +void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_c) { + std::string ab_not_match_err_str = "The multiplicands' data type " + DTypeToString(dtype_a) + + DTypeToString(dtype_b) + " do not match."; + // check a and b + switch (dtype_a) { case DataType::kBit1: + case DataType::kFloat16: + case DataType::kBFloat16: + case DataType::kTensorFloat32: + case DataType::kFloat64: + CHECK(dtype_a == dtype_b) << ab_not_match_err_str; + break; case DataType::kInt4: case DataType::kUInt4: + CHECK(dtype_b == DataType::kInt4 || dtype_b == DataType::kUInt4) << ab_not_match_err_str; + break; case DataType::kInt8: case DataType::kUInt8: - CHECK(acc == DataType::kInt32) << "For multiplicand data type " << DTypeToString(mul) - << ", accumulator data type should be s32."; + CHECK(dtype_b == DataType::kInt8 || dtype_b == DataType::kUInt8) << ab_not_match_err_str; + break; + default: + CHECK(false) << "Invalid multiplicand data types: " << DTypeToString(dtype_a) + << DTypeToString(dtype_b); + } + // check a,b and c + switch (dtype_a) { + case DataType::kBit1: + case DataType::kInt4: + case DataType::kUInt4: + case DataType::kInt8: + case DataType::kUInt8: + CHECK(dtype_c == DataType::kInt32) + << "For multiplicand data type " << DTypeToString(dtype_a) << DTypeToString(dtype_b) + << ", accumulator data type should be s32."; break; case DataType::kFloat16: - CHECK(acc == DataType::kFloat16 || acc == DataType::kFloat32) + CHECK(dtype_c == DataType::kFloat16 || dtype_c == DataType::kFloat32) << "For multiplicand data type f16, accumulator data type should be f16/f32."; break; case DataType::kBFloat16: case DataType::kTensorFloat32: - CHECK(acc == DataType::kFloat32) - << "For multiplicand data type bf16/tf32, accumulator data type can only be f32"; + CHECK(dtype_c == DataType::kFloat32) + << "For multiplicand data type bf16/tf32, accumulator data type can only be f32."; break; case DataType::kFloat64: - CHECK(acc == DataType::kFloat64) - << "For multiplicand data type f64, accumulator data type can only be f64"; + CHECK(dtype_c == DataType::kFloat64) + << "For multiplicand data type f64, accumulator data type can only be f64."; break; default: - CHECK(false) << "Invalid multiplicand/accumulator data type pair: " << DTypeToString(mul) - << ", " << DTypeToString(acc) << "."; + CHECK(false) << "Invalid multiplicand/accumulator data types: " << DTypeToString(dtype_a) + << DTypeToString(dtype_b) << DTypeToString(dtype_c) << "."; } } @@ -272,10 +298,7 @@ void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType if (use_bit_op) { CHECK(dtype_a == DataType::kBit1) << "Bit operator is only compatible with 1bit multiplicand."; } - CHECK(dtype_a == dtype_b) << "The multiplicand data type must be equal, found " - << DTypeToString(dtype_a) << " and " << ptx::DTypeToString(dtype_b) - << "."; - CheckMMADTypeCompatible(dtype_a, dtype_c); + CheckMMADTypeCompatible(dtype_a, dtype_b, dtype_c); if (saturate) { CHECK(dtype_a == DataType::kInt4 || dtype_a == DataType::kUInt4 || dtype_a == DataType::kInt8 || dtype_a == DataType::kUInt8) @@ -389,23 +412,26 @@ inline uint32_t GetNumMMAComputations(int m, int n, int k, ptx::DataType dtype) * \param m The M in mMnNkK of MMA instructions. * \param n The N in mMnNkK of MMA instructions. * \param k The K in mMnNkK of MMA instructions. - * \param dtype_mul The data type of multiplicand. - * \param dtype_acc The data type of accumulator. 
+ * \param dtype_a The data type of multiplicand a. + * \param dtype_b The data type of multiplicand b. + * \param dtype_c The data type of accumulator c. * \param sparse Whether it's Sparse MMA or not. */ inline std::tuple GetMMAOperands(int m, int n, int k, - ptx::DataType dtype_mul, - ptx::DataType dtype_acc, + ptx::DataType dtype_a, + ptx::DataType dtype_b, + ptx::DataType dtype_c, bool sparse) { std::stringstream templates, inputs, outputs; - const ptx::FragAttrs frag_attr_mul = ptx::GetFragAttrs(dtype_mul), - frag_attr_acc = ptx::GetFragAttrs(dtype_acc); + const ptx::FragAttrs frag_attr_a = ptx::GetFragAttrs(dtype_a), + frag_attr_b = ptx::GetFragAttrs(dtype_b), + frag_attr_c = ptx::GetFragAttrs(dtype_c); constexpr uint32_t warp_size = 32; - const uint32_t threads = warp_size / GetNumMMAComputations(m, n, k, dtype_mul); - const int num_operands_a = (m * k) * ptx::DTypeBits(dtype_mul) / frag_attr_acc.size / threads / - (sparse ? 2 : 1), - num_operands_b = (k * n) * ptx::DTypeBits(dtype_mul) / frag_attr_mul.size / threads, - num_operands_c = (m * n) * ptx::DTypeBits(dtype_acc) / frag_attr_acc.size / threads; + const uint32_t threads = warp_size / GetNumMMAComputations(m, n, k, dtype_a); + const int num_operands_a = + (m * k) * ptx::DTypeBits(dtype_a) / frag_attr_a.size / threads / (sparse ? 2 : 1), + num_operands_b = (k * n) * ptx::DTypeBits(dtype_b) / frag_attr_b.size / threads, + num_operands_c = (m * n) * ptx::DTypeBits(dtype_c) / frag_attr_c.size / threads; // generate templates; int arg_counter = 0; @@ -440,15 +466,14 @@ inline std::tuple GetMMAOperands(int m, i if (i != 0) { inputs << ", "; } - inputs << "\"" << frag_attr_mul.reg_type << "\"((" << frag_attr_mul.ptr_sig << "(A))[" << i - << "])"; + inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_sig << "(A))[" << i << "])"; } for (int i = 0; i < num_operands_b; ++i) { - inputs << ", \"" << frag_attr_mul.reg_type << "\"((" << frag_attr_mul.ptr_sig << "(B))[" << i + inputs << ", \"" << frag_attr_b.reg_type << "\"((" << frag_attr_b.ptr_sig << "(B))[" << i << "])"; } for (int i = 0; i < num_operands_c; ++i) { - inputs << ", \"" << frag_attr_acc.reg_type << "\"((" << frag_attr_acc.ptr_sig << "(C))[" << i + inputs << ", \"" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_sig << "(C))[" << i << "])"; } // input of metadata for sparse mma. 
@@ -461,7 +486,7 @@ inline std::tuple GetMMAOperands(int m, i if (i != 0) { outputs << ","; } - outputs << " \"=" << frag_attr_acc.reg_type << "\"((" << frag_attr_acc.ptr_sig << "(D))[" << i + outputs << " \"=" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_sig << "(D))[" << i << "])"; } return std::make_tuple(templates.str(), inputs.str(), outputs.str()); @@ -495,7 +520,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo )"; std::string templates_str, inputs_str, outputs_str; std::tie(templates_str, inputs_str, outputs_str) = - GetMMAOperands(m, n, k, dtype_a, dtype_c, sparse); + GetMMAOperands(m, n, k, dtype_a, dtype_b, dtype_c, sparse); // replace patterns Replacer replacer; diff --git a/tests/python/unittest/test_tir_ptx_mma.py b/tests/python/unittest/test_tir_ptx_mma.py index 1bc0b012b716..23405fdee98a 100644 --- a/tests/python/unittest/test_tir_ptx_mma.py +++ b/tests/python/unittest/test_tir_ptx_mma.py @@ -1348,11 +1348,16 @@ def test_gemm_mma_m16n8k256_row_col_b1b1s32(): test_gemm_mma_m8n8k4_row_row_fp16fp16fp16() test_gemm_mma_m8n8k4_row_row_fp16fp16fp32() test_gemm_mma_m8n8k16_row_col_s8s8s32() + test_gemm_mma_m8n8k16_row_col_s8u8s32() test_gemm_mma_m8n8k32_row_col_s4s4s32() + test_gemm_mma_m8n8k32_row_col_s4u4s32() test_gemm_mma_m16n8k8_row_col_fp16fp16fp32() test_gemm_mma_m16n8k16_row_col_fp16fp16fp16() test_gemm_mma_m16n8k16_row_col_fp16fp16fp32() test_gemm_mma_m16n8k16_row_col_s8s8s32() + test_gemm_mma_m16n8k16_row_col_s8u8s32() test_gemm_mma_m16n8k32_row_col_s8s8s32() + test_gemm_mma_m16n8k32_row_col_s8u8s32() test_gemm_mma_m16n8k64_row_col_s4s4s32() + test_gemm_mma_m16n8k64_row_col_s4u4s32() test_gemm_mma_m16n8k256_row_col_b1b1s32() From 7567b48d98f01aa289fd81302eb978fb8be621b7 Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 24 Feb 2022 22:15:10 -0800 Subject: [PATCH 14/15] improvement --- src/target/source/ptx_mma.cc | 37 ++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/src/target/source/ptx_mma.cc b/src/target/source/ptx_mma.cc index 45eb8303183a..d04c01896ed7 100644 --- a/src/target/source/ptx_mma.cc +++ b/src/target/source/ptx_mma.cc @@ -37,6 +37,11 @@ namespace ptx { /*! * \brief PTX data type. + * \note + * PTX fundamental data types: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types + * PTX matrix data types: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types */ enum class DataType : int { kInt4 = 0, @@ -173,6 +178,11 @@ struct MMAConfig { } }; +/*! + * \brief Valid MMA configurations + * \note Reference: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-shape + */ const MMAConfig valid_mma_configs[] = { MMAConfig(8, 8, 4, DataType::kFloat64, false, false), MMAConfig(8, 8, 4, DataType::kFloat16, false, false), @@ -219,6 +229,8 @@ const MMAConfig valid_mma_configs[] = { * \param dtype_a The data type of multiplicand a. * \param dtype_b The data type of multiplicand b. * \param dtype_c The data type of accumulator c. 
+ * \note Reference: + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types */ void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_c) { std::string ab_not_match_err_str = "The multiplicands' data type " + DTypeToString(dtype_a) + @@ -296,7 +308,7 @@ void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType << "Unrecognized 1-bit operation " << bit_op << " , can only be xor/and."; bool use_bit_op = !bit_op.empty(); if (use_bit_op) { - CHECK(dtype_a == DataType::kBit1) << "Bit operator is only compatible with 1bit multiplicand."; + CHECK(dtype_a == DataType::kBit1) << "Bit operator is only compatible with 1-bit multiplicand."; } CheckMMADTypeCompatible(dtype_a, dtype_b, dtype_c); if (saturate) { @@ -328,14 +340,14 @@ void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType */ class FragAttrs { public: - explicit FragAttrs(char reg_type, uint32_t size, std::string ptr_sig) - : reg_type(reg_type), size(size), ptr_sig(ptr_sig) {} + explicit FragAttrs(char reg_type, uint32_t size, std::string ptr_type) + : reg_type(reg_type), size(size), ptr_type(ptr_type) {} /*! \brief PTX register type */ char reg_type; /*! \brief Fragment size */ uint32_t size; - /*! \brief Fragment pointer signature */ - std::string ptr_sig; + /*! \brief Fragment pointer type */ + std::string ptr_type; }; /*! @@ -466,14 +478,15 @@ inline std::tuple GetMMAOperands(int m, i if (i != 0) { inputs << ", "; } - inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_sig << "(A))[" << i << "])"; + inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_type << "(A))[" << i + << "])"; } for (int i = 0; i < num_operands_b; ++i) { - inputs << ", \"" << frag_attr_b.reg_type << "\"((" << frag_attr_b.ptr_sig << "(B))[" << i + inputs << ", \"" << frag_attr_b.reg_type << "\"((" << frag_attr_b.ptr_type << "(B))[" << i << "])"; } for (int i = 0; i < num_operands_c; ++i) { - inputs << ", \"" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_sig << "(C))[" << i + inputs << ", \"" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type << "(C))[" << i << "])"; } // input of metadata for sparse mma. @@ -486,7 +499,7 @@ inline std::tuple GetMMAOperands(int m, i if (i != 0) { outputs << ","; } - outputs << " \"=" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_sig << "(D))[" << i + outputs << " \"=" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type << "(D))[" << i << "])"; } return std::make_tuple(templates.str(), inputs.str(), outputs.str()); @@ -512,7 +525,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo std::string asm_code = R"( { __asm__ __volatile__( - "mma{sparse}.sync.aligned.{shape}.{alayout}.{blayout}{satinite}{dtype}{atype}{btype}{ctype}{1bit}" + "mma{sparse}.sync.aligned.{shape}.{alayout}.{blayout}{saturate}{dtype}{atype}{btype}{ctype}{bitop}" "{templates};\n" : {outputs} : {inputs}); @@ -526,14 +539,14 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo Replacer replacer; replacer.register_rule("{sparse}", sparse ? ".sp" : ""); replacer.register_rule("{shape}", shape); - replacer.register_rule("{satinite}", saturate ? ".satfinite" : ""); + replacer.register_rule("{saturate}", saturate ? 
".satfinite" : ""); replacer.register_rule("{alayout}", A_layout); replacer.register_rule("{blayout}", B_layout); replacer.register_rule("{atype}", ptx::DTypeToString(dtype_a)); replacer.register_rule("{btype}", ptx::DTypeToString(dtype_b)); replacer.register_rule("{ctype}", ptx::DTypeToString(dtype_c)); replacer.register_rule("{dtype}", ptx::DTypeToString(dtype_c)); - replacer.register_rule("{1bit}", bit_op.empty() ? "" : "." + bit_op + ".popc"); + replacer.register_rule("{bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc"); replacer.register_rule("{templates}", templates_str); replacer.register_rule("{outputs}", outputs_str); replacer.register_rule("{inputs}", inputs_str); From b3f5751055bf59f2c5565d567c01b8e3c966e0b7 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 6 Mar 2022 19:30:07 -0800 Subject: [PATCH 15/15] compatible #9727 --- tests/python/unittest/test_tir_ptx_mma_sp.py | 40 ++++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/tests/python/unittest/test_tir_ptx_mma_sp.py b/tests/python/unittest/test_tir_ptx_mma_sp.py index 9c3e8bc4121f..321cd28ff6f7 100644 --- a/tests/python/unittest/test_tir_ptx_mma_sp.py +++ b/tests/python/unittest/test_tir_ptx_mma_sp.py @@ -75,13 +75,13 @@ def mma_sp_m16n8k16_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata: "fp16", "fp16", "fp16", - multi_a, + multi_a.data, 0, - multi_b, + multi_b.data, 0, - accum, + accum.data, 0, - meta_local, + meta_local.data, 0, 0, False, @@ -90,7 +90,7 @@ def mma_sp_m16n8k16_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata: ) for i in range(4): - C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = T.load("float16", accum, i) + C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i] @T.prim_func @@ -129,13 +129,13 @@ def mma_sp_m16n8k16_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata: "fp16", "fp16", "fp32", - multi_a, + multi_a.data, 0, - multi_b, + multi_b.data, 0, - accum, + accum.data, 0, - meta_local, + meta_local.data, 0, 0, False, @@ -144,7 +144,7 @@ def mma_sp_m16n8k16_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata: ) for i in range(4): - C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = T.load("float32", accum, i) + C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i] @T.prim_func @@ -183,13 +183,13 @@ def mma_sp_m16n8k32_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata: "fp16", "fp16", "fp16", - multi_a, + multi_a.data, 0, - multi_b, + multi_b.data, 0, - accum, + accum.data, 0, - meta_local, + meta_local.data, 0, 0, False, @@ -198,7 +198,7 @@ def mma_sp_m16n8k32_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata: ) for i in range(4): - C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = T.load("float16", accum, i) + C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i] @T.prim_func @@ -237,13 +237,13 @@ def mma_sp_m16n8k32_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata: "fp16", "fp16", "fp32", - multi_a, + multi_a.data, 0, - multi_b, + multi_b.data, 0, - accum, + accum.data, 0, - meta_local, + meta_local.data, 0, 0, False, @@ -252,7 +252,7 @@ def mma_sp_m16n8k32_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata: ) for i in range(4): - C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = T.load("float32", accum, i) + C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i] @tvm.testing.requires_cuda