Skip to content

Commit

Permalink
[PTX] Support mma.sp to use Sparse Tensor Cores and refactor mma code…
Browse files Browse the repository at this point in the history
…gen (apache#10339)

* init

* upd

* upd

* lint

* lint again

* upd

* add m16n8k32 testcase

* format

* use make_tuple instead of initializer list

* add metadata offset

* upd

* docstring and sanity

* add u8s8s32 back

* improvement

* compatible apache#9727
  • Loading branch information
yzh119 authored and pfk-beta committed Apr 11, 2022
1 parent 060bdee commit 86a3c68
Show file tree
Hide file tree
Showing 7 changed files with 934 additions and 1,314 deletions.
13 changes: 13 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,19 @@ TVM_DLL const Op& tvm_store_matrix_sync();
*/
TVM_DLL const Op& ptx_mma();

/*!
* \brief tvm intrinsic for sparse tensor core ptx instructions (mma.sp).
*
* void ptx_mma_sp(StringImm shape, StringImm A_layout, StringImm B_layout,
* StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
* Var multiplicand_a, Expr a_index,
* Var multiplicand_b, Expr b_index,
* Var accumulator, Expr c_index,
* Var metadata, Expr meta_index,
* Var sparse_selector, bool saturate);
*
* The call carries 16 arguments, in the order listed above:
* - shape: MMA shape string of the form mXnXkX.
* - A_layout, B_layout: layout of each multiplicand (row/col).
* - A_dtype, B_dtype, C_dtype: precision of each operand (fp16, fp32, ...).
* - multiplicand_a / a_index: A multiplicand and the index into it.
* - multiplicand_b / b_index: B multiplicand and the index into it.
* - accumulator / c_index: C accumulator and the index into it.
* - metadata / meta_index: sparsity metadata and the index into it.
* - sparse_selector: sparsity selector forwarded to the generated instruction.
* - saturate: boolean flag; the CUDA codegen reads it as a Bool immediate.
*/
TVM_DLL const Op& ptx_mma_sp();

// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
/*!
* \brief Get the high level half of the vector
Expand Down
49 changes: 45 additions & 4 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
// arg 10: C accumulator
// arg 11: C accumulator index
// arg 12: saturate
ICHECK_EQ(op->args.size(), 13U);
// arg 13: (optional) 1-bit operator (xor or and)
ICHECK(op->args.size() == 13U || op->args.size() == 14U);
std::string shape = Downcast<StringImm>(op->args[0])->value;
std::string A_layout = Downcast<StringImm>(op->args[1])->value;
std::string B_layout = Downcast<StringImm>(op->args[2])->value;
Expand All @@ -757,11 +758,51 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string b_bias = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_bias = this->PrintExpr(op->args[11]);
bool saturate = (Downcast<IntImm>(op->args[12])->value != 0);
std::string asm_code = PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype,
a_ref, a_bias, b_ref, b_bias, c_ref, c_bias, saturate);
bool saturate = Downcast<Bool>(op->args[12])->value;
std::string bit_op = op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : "";
std::string asm_code =
PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, b_ref,
b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate);

this->stream << asm_code;
} else if (op->op.same_as(builtin::ptx_mma_sp())) {
// arg 0: shape: mXnXkX
// arg 1: A layout: row/col
// arg 2: B layout: row/col
// arg 3: A precision: fp16, fp32, ...
// arg 4: B precision: fp16, fp32, ...
// arg 5: C precision: fp16, fp32, ...
// arg 6: A multiplicand
// arg 7: A multiplicand index
// arg 8: B multiplicand
// arg 9: B multiplicand index
// arg 10: C accumulator
// arg 11: C accumulator index
// arg 12: metadata
// arg 13: metadata index
// arg 14: sparse_selector
// arg 15: saturate
ICHECK_EQ(op->args.size(), 16U);
std::string shape = Downcast<StringImm>(op->args[0])->value;
std::string A_layout = Downcast<StringImm>(op->args[1])->value;
std::string B_layout = Downcast<StringImm>(op->args[2])->value;
std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
std::string a_ref = this->PrintExpr(op->args[6]);
std::string a_offset = this->PrintExpr(op->args[7]);
std::string b_ref = this->PrintExpr(op->args[8]);
std::string b_offset = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_offset = this->PrintExpr(op->args[11]);
std::string metadata = this->PrintExpr(op->args[12]);
std::string metadata_offset = this->PrintExpr(op->args[13]);
std::string sparse_selector = this->PrintExpr(op->args[14]);
bool saturate = Downcast<Bool>(op->args[15])->value;
std::string asm_code = PrintMMAAssembly(
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset,
c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate);
this->stream << asm_code;
} else {
CodeGenC::VisitExpr_(op, os);
}
Expand Down
Loading

0 comments on commit 86a3c68

Please sign in to comment.