Skip to content

Commit

Permalink
improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh119 committed Feb 25, 2022
1 parent ca92032 commit 89536ff
Showing 1 changed file with 25 additions and 12 deletions.
37 changes: 25 additions & 12 deletions src/target/source/ptx_mma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ namespace ptx {

/*!
* \brief PTX data type.
* \note
* PTX fundamental data types:
* https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types
* PTX matrix data types:
* https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types
*/
enum class DataType : int {
kInt4 = 0,
Expand Down Expand Up @@ -173,6 +178,11 @@ struct MMAConfig {
}
};

/*!
* \brief Valid MMA configurations
* \note Reference:
* https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-shape
*/
const MMAConfig valid_mma_configs[] = {
MMAConfig(8, 8, 4, DataType::kFloat64, false, false),
MMAConfig(8, 8, 4, DataType::kFloat16, false, false),
Expand Down Expand Up @@ -219,6 +229,8 @@ const MMAConfig valid_mma_configs[] = {
* \param dtype_a The data type of multiplicand a.
* \param dtype_b The data type of multiplicand b.
* \param dtype_c The data type of accumulator c.
* \note Reference:
* https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types
*/
void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_c) {
std::string ab_not_match_err_str = "The multiplicands' data type " + DTypeToString(dtype_a) +
Expand Down Expand Up @@ -296,7 +308,7 @@ void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType
<< "Unrecognized 1-bit operation " << bit_op << " , can only be xor/and.";
bool use_bit_op = !bit_op.empty();
if (use_bit_op) {
CHECK(dtype_a == DataType::kBit1) << "Bit operator is only compatible with 1bit multiplicand.";
CHECK(dtype_a == DataType::kBit1) << "Bit operator is only compatible with 1-bit multiplicand.";
}
CheckMMADTypeCompatible(dtype_a, dtype_b, dtype_c);
if (saturate) {
Expand Down Expand Up @@ -328,14 +340,14 @@ void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType
*/
class FragAttrs {
public:
explicit FragAttrs(char reg_type, uint32_t size, std::string ptr_sig)
: reg_type(reg_type), size(size), ptr_sig(ptr_sig) {}
explicit FragAttrs(char reg_type, uint32_t size, std::string ptr_type)
: reg_type(reg_type), size(size), ptr_type(ptr_type) {}
/*! \brief PTX register type */
char reg_type;
/*! \brief Fragment size */
uint32_t size;
/*! \brief Fragment pointer signature */
std::string ptr_sig;
/*! \brief Fragment pointer type */
std::string ptr_type;
};

/*!
Expand Down Expand Up @@ -466,14 +478,15 @@ inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, i
if (i != 0) {
inputs << ", ";
}
inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_sig << "(A))[" << i << "])";
inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_type << "(A))[" << i
<< "])";
}
for (int i = 0; i < num_operands_b; ++i) {
inputs << ", \"" << frag_attr_b.reg_type << "\"((" << frag_attr_b.ptr_sig << "(B))[" << i
inputs << ", \"" << frag_attr_b.reg_type << "\"((" << frag_attr_b.ptr_type << "(B))[" << i
<< "])";
}
for (int i = 0; i < num_operands_c; ++i) {
inputs << ", \"" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_sig << "(C))[" << i
inputs << ", \"" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type << "(C))[" << i
<< "])";
}
// input of metadata for sparse mma.
Expand All @@ -486,7 +499,7 @@ inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, i
if (i != 0) {
outputs << ",";
}
outputs << " \"=" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_sig << "(D))[" << i
outputs << " \"=" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type << "(D))[" << i
<< "])";
}
return std::make_tuple(templates.str(), inputs.str(), outputs.str());
Expand All @@ -512,7 +525,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
std::string asm_code = R"(
{
__asm__ __volatile__(
"mma{sparse}.sync.aligned.{shape}.{alayout}.{blayout}{satinite}{dtype}{atype}{btype}{ctype}{1bit}"
"mma{sparse}.sync.aligned.{shape}.{alayout}.{blayout}{saturate}{dtype}{atype}{btype}{ctype}{bitop}"
"{templates};\n"
: {outputs}
: {inputs});
Expand All @@ -526,14 +539,14 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
Replacer replacer;
replacer.register_rule("{sparse}", sparse ? ".sp" : "");
replacer.register_rule("{shape}", shape);
replacer.register_rule("{satinite}", saturate ? ".satfinite" : "");
replacer.register_rule("{saturate}", saturate ? ".satfinite" : "");
replacer.register_rule("{alayout}", A_layout);
replacer.register_rule("{blayout}", B_layout);
replacer.register_rule("{atype}", ptx::DTypeToString(dtype_a));
replacer.register_rule("{btype}", ptx::DTypeToString(dtype_b));
replacer.register_rule("{ctype}", ptx::DTypeToString(dtype_c));
replacer.register_rule("{dtype}", ptx::DTypeToString(dtype_c));
replacer.register_rule("{1bit}", bit_op.empty() ? "" : "." + bit_op + ".popc");
replacer.register_rule("{bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc");
replacer.register_rule("{templates}", templates_str);
replacer.register_rule("{outputs}", outputs_str);
replacer.register_rule("{inputs}", inputs_str);
Expand Down

0 comments on commit 89536ff

Please sign in to comment.