Skip to content

Commit

Permalink
feat: add mma instructions for fp8 (#179)
Browse files Browse the repository at this point in the history
MMA instructions for fp8 are available in PTX 8.4 and CUDA 12.4.
  • Loading branch information
yzh119 authored Mar 13, 2024
1 parent 238563f commit d305798
Showing 1 changed file with 108 additions and 0 deletions.
108 changes: 108 additions & 0 deletions include/flashinfer/mma.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ namespace flashinfer {

namespace mma {

#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120400)
#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890))
#define FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED
#endif
#endif

#if (__CUDACC_VER_MAJOR__ >= 11)
#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 900))
#define FLASHINFER_STMATRIX_M8N8X4_ENABLED
Expand Down Expand Up @@ -115,6 +121,105 @@ __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t* R, T* smem_ptr) {
#endif
}

/*!
* \brief Wrapper of two mma m16n8k32 instructions for row major and column major f8 matrix
* multiplication, accumulated in f32.
* \tparam T data type of the fragment
* \tparam mma_mode whether we are initializing the accumulator or updating it
* \param C pointer to the accumulator
* \param A pointer to the fragment of matrix A
* \param B pointer to the fragment of matrix B
*/
template <typename T, MMAMode mma_mode = MMAMode::kInplaceUpdate>
__device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32(float* C, uint32_t* A,
uint32_t* B) {
#if defined(FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED)
if constexpr (mma_mode == MMAMode::kInit) {
if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f),
"f"(0.f), "f"(0.f));
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f),
"f"(0.f), "f"(0.f));
} else { // e5m2
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f),
"f"(0.f), "f"(0.f));
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f),
"f"(0.f), "f"(0.f));
}
} else {
if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]),
"f"(C[2]), "f"(C[3]));
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]),
"f"(C[6]), "f"(C[7]));
} else { // e5m2
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]),
"f"(C[2]), "f"(C[3]));
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]),
"f"(C[6]), "f"(C[7]));
}
}
#else
static_assert(false, "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+");
#endif
}

/*!
* \brief Wrapper of two mma m16n8k16 instructions for row major and column major f16 matrix
* multiplication, accumulated in f32.
Expand Down Expand Up @@ -282,6 +387,9 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u
#endif
}

// template <typename DType>
// __device__ __forceinline__ void

/*!
* \brief Use mma instructions to compute rowsum.
*/
Expand Down

0 comments on commit d305798

Please sign in to comment.