Skip to content

Commit

Permalink
[CUTLASS] Add FP8 gemm kernels
Browse files Browse the repository at this point in the history
This PR introduces the sm90a FP8 kernels from CUTLASS. These kernels
are helpful in the cases of small `M`, where cuBLAS has unoptimized
performance.
  • Loading branch information
MasterJH5574 committed Sep 24, 2024
1 parent 2a87c4c commit 6972e95
Show file tree
Hide file tree
Showing 5 changed files with 349 additions and 15 deletions.
1 change: 1 addition & 0 deletions cmake/modules/contrib/CUTLASS.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ if(USE_CUDA AND USE_CUTLASS)
if (CMAKE_CUDA_ARCHITECTURES MATCHES "90a")
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm.cu)
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm.cu)
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_gemm.cu)
endif()
if(TVM_CUTLASS_RUNTIME_SRCS)
add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS})
Expand Down
6 changes: 4 additions & 2 deletions src/runtime/contrib/cublas/cublas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,13 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
&bias->data, sizeof(float*)));
}

if (scaleA != nullptr && scaleB != nullptr) {
if (scaleA != nullptr) {
auto scaleA_data = static_cast<char*>(scaleA->data) + scaleA->byte_offset;
auto scaleB_data = static_cast<char*>(scaleB->data) + scaleB->byte_offset;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&scaleA_data, sizeof(float*)));
}
if (scaleB != nullptr) {
auto scaleB_data = static_cast<char*>(scaleB->data) + scaleB->byte_offset;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&scaleB_data, sizeof(float*)));
}
Expand Down
95 changes: 95 additions & 0 deletions src/runtime/contrib/cutlass/fp8_gemm.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <cuda_fp16.h>
#include <float.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include "../cublas/cublas_utils.h"
#include "gemm_runner.cuh"

#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)

struct KernelTraitsM64 {
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using TileShape = Shape<_64, _64, _128>;
using ClusterShape = Shape<_1, _8, _1>;
};

namespace tvm {
namespace runtime {

template <typename ElementA, typename ElementB, typename ElementC>
void tvm_cutlass_fp8_gemm(NDArray x, NDArray weight, NDArray workspace, NDArray alpha,
NDArray out) {
// Workspace is used for storing device-side gemm arguments and cutlass internal workspace.
// Recommened size is 4MB.
auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
ICHECK(func != nullptr);
CHECK_GE(x->ndim, 2);
CHECK_EQ(weight->ndim, 2);
CHECK_EQ(workspace->ndim, 1);
CHECK_GE(out->ndim, 2);
CHECK_EQ(alpha->dtype.code, kDLFloat);
CHECK_EQ(alpha->dtype.bits, 32);
CHECK_EQ(alpha->ndim, 1);
CHECK_EQ(alpha->shape[0], 1);
int64_t m = 1;
for (int i = 0; i < x->ndim - 1; ++i) {
m *= x->shape[i];
}
int64_t n = weight->shape[0];
CHECK_EQ(x->shape[x->ndim - 1], weight->shape[1]) << "Only col-major weight is supported now.";
int64_t k = x->shape[x->ndim - 1];
const float* beta = nullptr;
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
if (m <= 64) {
cutlass_gemm<KernelTraitsM64>(
static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
static_cast<uint8_t*>(workspace->data), workspace->shape[0], m, n, k,
static_cast<float*>(alpha->data), beta, static_cast<ElementC*>(out->data), stream);
} else {
tvm::contrib::CuBlasLtThreadEntry* cublas_entry =
tvm::contrib::CuBlasLtThreadEntry::ThreadLocal();
tvm::contrib::CallCublasLt(cublas_entry->handle, stream, cublas_entry->matmul_pref_desc,
x.operator->(), weight.operator->(), nullptr, alpha.operator->(),
nullptr, out.operator->(), /*transa=*/false, /*transb=*/true,
cublas_entry->workspace_ptr, cublas_entry->workspace_size,
CUBLASLT_EPILOGUE_DEFAULT, std::nullopt);
}
}

TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e5m2_fp16")
.set_body_typed(
tvm_cutlass_fp8_gemm<cutlass::float_e5m2_t, cutlass::float_e5m2_t, cutlass::half_t>);

TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e4m3_fp16")
.set_body_typed(
tvm_cutlass_fp8_gemm<cutlass::float_e5m2_t, cutlass::float_e4m3_t, cutlass::half_t>);

TVM_REGISTER_GLOBAL("cutlass.gemm_e4m3_e4m3_fp16")
.set_body_typed(
tvm_cutlass_fp8_gemm<cutlass::float_e4m3_t, cutlass::float_e4m3_t, cutlass::half_t>);

} // namespace runtime
} // namespace tvm

#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
155 changes: 155 additions & 0 deletions src/runtime/contrib/cutlass/gemm_runner.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <fstream>
#include <iostream>
#include <sstream>
#include <variant>
#include <vector>

#include "../../cuda/cuda_common.h"

// clang-format off
#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on

#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
CHECK(error == cutlass::Status::kSuccess) \
<< "Got cutlass error: " << cutlassGetStatusString(error); \
}

using namespace cute;
using ProblemShape = Shape<int, int, int>; // <M, N, K>

template <typename KernelTraits, typename ElementA, typename ElementB, typename ElementC,
typename LayoutA = cutlass::layout::RowMajor,
typename LayoutB = cutlass::layout::ColumnMajor,
typename LayoutC = cutlass::layout::RowMajor>
struct CutlassGemmRunner {
static constexpr int AlignmentA =
128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements
// (up to 16 bytes)

static constexpr int AlignmentB =
128 / cutlass::sizeof_bits<ElementB>::value; // Alignment of B matrix in units of elements
// (up to 16 bytes)

static constexpr int AlignmentC =
128 / cutlass::sizeof_bits<ElementC>::value; // Alignment of C matrix in units of elements
// (up to 16 bytes)

// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ScaleType = std::variant<ElementAccumulator, const ElementAccumulator*>;
using ArchTag =
cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = typename KernelTraits::TileShape;
using ClusterShape = typename KernelTraits::ClusterShape;
using StageCountType =
cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = typename KernelTraits::KernelSchedule; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; // Epilogue to launch

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC, ElementC, LayoutC, AlignmentC, EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB,
ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;

using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;

void run_gemm(const ElementA* ptr_A, const ElementB* ptr_B, const ElementC* ptr_C,
ElementC* ptr_D, ProblemShape* problem_size, StrideA* stride_A, StrideB* stride_B,
StrideC* stride_C, StrideD* stride_D, uint8_t* workspace, int64_t workspace_size,
ScaleType alpha, ScaleType beta, cudaStream_t stream) {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm,
*problem_size,
{ptr_A, *stride_A, ptr_B, *stride_B},
{{}, ptr_C, *stride_C, ptr_D, *stride_D},
// {epilogue_params, ptr_C, *stride_C, ptr_D, *stride_D},
hw_info};

ICHECK(alpha.index() == beta.index()) << "alpha and beta must have the same type";
if (std::holds_alternative<ElementAccumulator>(alpha)) {
arguments.epilogue.thread.alpha = std::get<ElementAccumulator>(alpha);
arguments.epilogue.thread.beta = std::get<ElementAccumulator>(beta);
} else if (std::holds_alternative<const ElementAccumulator*>(alpha)) {
arguments.epilogue.thread.alpha_ptr = std::get<const ElementAccumulator*>(alpha);
arguments.epilogue.thread.beta_ptr = std::get<const ElementAccumulator*>(beta);
} else {
LOG(FATAL) << "Unsupported alpha and beta type";
throw;
}

Gemm gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(arguments));
CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
CUTLASS_CHECK(gemm_op.run(stream));
}
};

template <typename KernelTraits, typename ElementA, typename ElementB, typename ElementC>
void cutlass_gemm(ElementA* x, ElementB* weight, uint8_t* workspace, int64_t workspace_size,
int64_t m, int64_t n, int64_t k, std::variant<float, const float*> alpha,
std::variant<float, const float*> beta, ElementC* out, cudaStream_t stream) {
using Runner = CutlassGemmRunner<KernelTraits, ElementA, ElementB, ElementC>;
using StrideA = typename Runner::StrideA;
using StrideB = typename Runner::StrideB;
using StrideC = typename Runner::StrideC;

Runner runner;
StrideA stride_A = cute::make_stride(k, Int<1>{}, int64_t{0});
StrideB stride_B = cute::make_stride(k, Int<1>{}, int64_t{0});
StrideC stride_D = cute::make_stride(n, Int<1>{}, int64_t{0});
ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), static_cast<int>(k)};
runner.run_gemm(x, weight, out, out, &problem_size, &stride_A, &stride_B, &stride_D, &stride_D,
workspace, workspace_size, alpha, beta, stream);
}
Loading

0 comments on commit 6972e95

Please sign in to comment.