Skip to content

Commit

Permalink
[PTX-MMA] Add full PTX MMA code generation support
Browse files Browse the repository at this point in the history
  • Loading branch information
KnowingNothing committed Jan 12, 2022
1 parent d1ee201 commit c97319a
Show file tree
Hide file tree
Showing 4 changed files with 2,781 additions and 0 deletions.
36 changes: 36 additions & 0 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <vector>

#include "literal/cuda_half_t.h"
#include "ptx_mma.h"

namespace tvm {
namespace codegen {
Expand Down Expand Up @@ -723,6 +724,41 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
this->PrintExpr(op->args[i * 2 + 1], os);
os << "]" << ((i < 3) ? ", " : ")");
}
} else if ((op->op.same_as(builtin::call_extern()) ||
op->op.same_as(builtin::call_pure_extern())) &&
Downcast<StringImm>(op->args[0])->value == "ptx_mma") {
// Positional arguments of the ptx_mma extern call (14 in total):
// arg 0: name: ptx_mma
// arg 1: shape: mXnXkX
// arg 2: A layout: row/col
// arg 3: B layout: row/col
// arg 4: A precision: fp16, fp64, ...
// arg 5: B precision: fp16, fp64, ...
// arg 6: C precision: fp32, fp64, ...
// arg 7: A multiplicand
// arg 8: A multiplicand index (printed below as a_bias)
// arg 9: B multiplicand
// arg 10: B multiplicand index (printed below as b_bias)
// arg 11: C accumulator
// arg 12: C accumulator index (printed below as c_bias)
// arg 13: saturate (integer immediate; any nonzero value sets the flag)
ICHECK_EQ(op->args.size(), 14U);
std::string shape = Downcast<StringImm>(op->args[1])->value;
std::string A_layout = Downcast<StringImm>(op->args[2])->value;
std::string B_layout = Downcast<StringImm>(op->args[3])->value;
std::string A_dtype = Downcast<StringImm>(op->args[4])->value;
std::string B_dtype = Downcast<StringImm>(op->args[5])->value;
std::string C_dtype = Downcast<StringImm>(op->args[6])->value;
std::string a_ref = this->PrintExpr(op->args[7]);
std::string a_bias = this->PrintExpr(op->args[8]);
std::string b_ref = this->PrintExpr(op->args[9]);
std::string b_bias = this->PrintExpr(op->args[10]);
std::string c_ref = this->PrintExpr(op->args[11]);
std::string c_bias = this->PrintExpr(op->args[12]);
bool saturate = (Downcast<IntImm>(op->args[13])->value != 0);
std::string asm_code = PrintPTXAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype,
a_ref, a_bias, b_ref, b_bias, c_ref, c_bias, saturate);

this->stream << asm_code;
} else {
CodeGenC::VisitExpr_(op, os);
}
Expand Down
Loading

0 comments on commit c97319a

Please sign in to comment.