[FlashAttention Forward Pass] QSkipLDS with Persistent register cache policy #22

Open
wants to merge 11 commits into base: main_old_1
1 change: 1 addition & 0 deletions example/91_tile_program/CMakeLists.txt
@@ -6,3 +6,4 @@ add_example_executable(example_softmax softmax.cpp)
add_example_executable(example_gemm_softmax_gemm gemm_softmax_gemm.cpp)
add_example_executable(example_batched_gemm_softmax_gemm batched_gemm_softmax_gemm.cpp)
add_example_executable(example_fmha_fwd fmha_fwd.cpp)
add_example_executable(example_flash_attention_fwd flash_attention_fwd.cpp)
225 changes: 225 additions & 0 deletions example/91_tile_program/flash_attention_fwd.cpp
@@ -0,0 +1,225 @@
#include <cstring>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor/tensor_view.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

#include "reference_batched_gemm.hpp"
#include "reference_batched_softmax.hpp"
#include "flash_attention_fwd.hpp"

int main(int argc, char* argv[])
{
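// fp16 inputs/outputs; fp32 accumulators for the Q*K GEMM (Sacc), the softmax
// (SMPLCompute), and the P*V GEMM (Oacc); P is the fp16 probability matrix fed
// into the second GEMM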
using QDataType = ck::half_t;
using KDataType = ck::half_t;
using VDataType = ck::half_t;
using SaccDataType = float;
using SMPLComputeDataType = float;
using PDataType = ck::half_t;
using OaccDataType = float;
using ODataType = ck::half_t;

ck::index_t Batch = 64;
ck::index_t M0 = 4096;
ck::index_t N0 = 4096;
ck::index_t K0 = 128;
ck::index_t N1 = 128;
ck::index_t init_method = 1;
ck::index_t time_kernel = 0;

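// optional command-line overrides:
//   argv[1] = init_method, argv[2] = time_kernel
//   argv[3..7] = Batch, M0, N0, K0, N1 (only when all seven arguments are passed)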
if(argc == 3)
{
init_method = std::stoi(argv[1]);
time_kernel = std::stoi(argv[2]);
}

if(argc == 8)
{
init_method = std::stoi(argv[1]);
time_kernel = std::stoi(argv[2]);
Batch = std::stoi(argv[3]);
M0 = std::stoi(argv[4]);
N0 = std::stoi(argv[5]);
K0 = std::stoi(argv[6]);
N1 = std::stoi(argv[7]);
}

std::array<ck::index_t, 3> q_lengths{Batch, M0, K0};
std::array<ck::index_t, 3> q_strides{M0 * K0, K0, 1};

std::array<ck::index_t, 3> k_lengths{Batch, N0, K0};
std::array<ck::index_t, 3> k_strides{N0 * K0, K0, 1};

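// V is stored as [Batch, N1, N0], so the second GEMM O[M0, N1] = P[M0, N0] * V[N1, N0]
// contracts over N0 without an extra transpose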
std::array<ck::index_t, 3> v_lengths{Batch, N1, N0};
std::array<ck::index_t, 3> v_strides{N1 * N0, N0, 1};

std::array<ck::index_t, 3> s_lengths{Batch, M0, N0};
std::array<ck::index_t, 3> s_strides{M0 * N0, N0, 1};

std::array<ck::index_t, 3> p_lengths{Batch, M0, N0};
std::array<ck::index_t, 3> p_strides{M0 * N0, N0, 1};

std::array<ck::index_t, 3> o_lengths{Batch, M0, N1};
std::array<ck::index_t, 3> o_strides{M0 * N1, N1, 1};

// host-side tensors for reference computation and verification
Tensor<QDataType> q_host(q_lengths, q_strides);
Tensor<KDataType> k_host(k_lengths, k_strides);
Tensor<VDataType> v_host(v_lengths, v_strides);
Tensor<SMPLComputeDataType> s_host_ref(s_lengths, s_strides);
Tensor<PDataType> p_host_ref(p_lengths, p_strides);
Tensor<ODataType> o_host_ref(o_lengths, o_strides);
Tensor<ODataType> o_host_dev(o_lengths, o_strides);

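// init_method: 0 = leave buffers untouched, 1 = small random integers, 2 = uniform
// floats in [-3, 3], 3 = all ones, 4-9 = mixed constant/random combinations for
// debugging individual operands, default = integers in [-2, 2]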
switch(init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
break;
case 2:
ck::utils::FillUniformDistribution<QDataType>{-3.f, 3.f}(q_host);
ck::utils::FillUniformDistribution<KDataType>{-3.f, 3.f}(k_host);
ck::utils::FillUniformDistribution<VDataType>{-3.f, 3.f}(v_host);
break;
case 3:
ck::utils::FillConstant<QDataType>{1.f}(q_host);
ck::utils::FillConstant<KDataType>{1.f}(k_host);
ck::utils::FillConstant<VDataType>{1.f}(v_host);
break;
case 4:
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
ck::utils::FillConstant<KDataType>{1.f}(k_host);
ck::utils::FillConstant<VDataType>{1.f}(v_host);
break;
case 5:
ck::utils::FillConstant<QDataType>{1.f}(q_host);
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
ck::utils::FillConstant<VDataType>{1.f}(v_host);
break;
case 6:
ck::utils::FillConstant<QDataType>{1.f}(q_host);
ck::utils::FillConstant<KDataType>{1.f}(k_host);
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
break;
case 7:
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
ck::utils::FillConstant<VDataType>{1.f}(v_host);
break;
case 8:
ck::utils::FillConstant<QDataType>{1.f}(q_host);
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
break;
case 9:
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
ck::utils::FillConstant<KDataType>{1.f}(k_host);
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
break;
default:
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f}(q_host);
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f}(k_host);
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f}(v_host);
}

// host reference: S[M0, N0] = Q[M0, K0] * K[N0, K0], P = Softmax(S), O[M0, N1] = P[M0, N0] * V[N1, N0]
reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
q_host, k_host, s_host_ref);
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(s_host_ref,
p_host_ref);
reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
p_host_ref, v_host, o_host_ref);

DeviceMem q_buf(sizeof(QDataType) * q_host.GetElementSpaceSize());
DeviceMem k_buf(sizeof(KDataType) * k_host.GetElementSpaceSize());
DeviceMem v_buf(sizeof(VDataType) * v_host.GetElementSpaceSize());
DeviceMem o_buf(sizeof(ODataType) * o_host_ref.GetElementSpaceSize());

q_buf.ToDevice(q_host.mData.data());
k_buf.ToDevice(k_host.mData.data());
v_buf.ToDevice(v_host.mData.data());

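// per-workgroup tile sizes for the two GEMMs; kK0PerBlock and kK1PerBlock are the
// per-iteration slices of their respective contracted dimensions (K0 and N0)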
constexpr ck::index_t kM0PerBlock = 128;
constexpr ck::index_t kN0PerBlock = 128;
constexpr ck::index_t kK0PerBlock = 32;
constexpr ck::index_t kN1PerBlock = 128;
constexpr ck::index_t kK1PerBlock = 32;

constexpr ck::index_t kBlockSize = 256;
constexpr ck::index_t kHeadDim = 128;

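// one workgroup per (batch, M0 tile, N1 tile) combination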
ck::index_t kGridSize = Batch * (M0 / kM0PerBlock) * (N1 / kN1PerBlock);

std::cout << "grid size " << kGridSize << std::endl;

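// occupancy hint: target kWarpPerCu resident warps per CU, i.e. kBlockPerCu
// workgroups of kWarpPerBlock warps each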
constexpr ck::index_t kWarpPerCu = 8; // 2 warps per SIMD
constexpr ck::index_t kWarpPerBlock = kBlockSize / warpSize;
constexpr ck::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;

float ave_time = launch_kernel<kBlockSize, kBlockPerCu>(
StreamConfig{nullptr, static_cast<bool>(time_kernel)},
FlashAttentionFwd<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType,
kBlockSize,
kHeadDim,
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock,
kK1PerBlock>{},
kGridSize,
kBlockSize,
0,
static_cast<QDataType*>(q_buf.GetDeviceBuffer()),
static_cast<KDataType*>(k_buf.GetDeviceBuffer()),
static_cast<VDataType*>(v_buf.GetDeviceBuffer()),
static_cast<ODataType*>(o_buf.GetDeviceBuffer()),
M0,
N0,
K0,
N1,
Batch,
K0, // StrideQ
K0, // StrideK
N0, // StrideV
N1, // StrideO
M0 * K0, // BatchStrideQ
N0 * K0, // BatchStrideK
N1 * N0, // BatchStrideV
M0 * N1); // BatchStrideO

o_buf.FromDevice(o_host_dev.mData.data());

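// 2*M0*N0*K0 flops for the Q*K GEMM plus 2*M0*N1*N0 flops for the P*V GEMM, per
// batch; ave_time is in ms, so flop / 1e9 / ave_time yields TFLOPS and
// num_btype / 1e6 / ave_time yields GB/s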
std::size_t flop =
std::size_t(2) * Batch * M0 * N0 * K0 + std::size_t(2) * Batch * M0 * N1 * N0;
std::size_t num_btype =
sizeof(QDataType) * Batch * M0 * K0 + sizeof(KDataType) * Batch * N0 * K0 +
sizeof(VDataType) * Batch * N1 * N0 + sizeof(ODataType) * Batch * M0 * N1;

float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

float gb_per_sec = num_btype / 1.E6 / ave_time;

std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;

return !ck::utils::check_err(o_host_dev, o_host_ref);
}
113 changes: 113 additions & 0 deletions example/91_tile_program/flash_attention_fwd.hpp
@@ -0,0 +1,113 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"

#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck/tile_program/block_tile/block_reduce.hpp"

#include "flash_attention_fwd_impl.hpp"

// S[M0, N0] = Q[M0, K0] * K[N0, K0]
// P[M0, N0] = Softmax(S[M0, N0])
// O[M0, N1] = P[M0, N0] * V[N1, N0]
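// (Softmax is taken row-wise, i.e. over the N0 dimension of S)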
template <typename QDataType,
typename KDataType,
typename VDataType,
typename SaccDataType,
typename SMPLComputeDataType,
typename PDataType,
typename OaccDataType,
typename ODataType,
ck::index_t kBlockSize,
ck::index_t kHeadDim,
ck::index_t kM0PerBlock,
ck::index_t kN0PerBlock,
ck::index_t kK0PerBlock,
ck::index_t kN1PerBlock,
ck::index_t kK1PerBlock>
struct FlashAttentionFwd
{
__device__ void operator()(const QDataType* q_ptr,
const KDataType* k_ptr,
const VDataType* v_ptr,
ODataType* o_ptr,
const ck::index_t M0,
const ck::index_t N0,
const ck::index_t K0,
const ck::index_t N1,
const ck::index_t /* Batch */,
const ck::index_t StrideQ,
const ck::index_t StrideK,
const ck::index_t StrideV,
const ck::index_t StrideO,
const ck::index_t BatchStrideQ,
const ck::index_t BatchStrideK,
const ck::index_t BatchStrideV,
const ck::index_t BatchStrideO) const
{
using namespace ck;

// divide problem
const index_t num_tile_m0 = M0 / kM0PerBlock;
const index_t num_tile_n1 = N1 / kN1PerBlock;

const index_t id_block = get_block_id();

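// decompose the linear block id into (batch, M0-tile, N1-tile) coordinates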
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;

return ck::make_tuple(quotient, modulus);
};

const auto [itmp, id_tile_n] = f(id_block, num_tile_n1);
const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0);

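// broadcast the (wave-uniform) tile offsets so they stay in scalar registers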
const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch);
const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock);
const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock);

const auto kernel_impl = FlashAttentionFwdImpl<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType,
kBlockSize,
kHeadDim,
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock,
kK1PerBlock>{};

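// shift the base pointers to this batch and run the per-tile implementation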
kernel_impl(q_ptr + iBatch * BatchStrideQ,
k_ptr + iBatch * BatchStrideK,
v_ptr + iBatch * BatchStrideV,
o_ptr + iBatch * BatchStrideO,
M0,
N0,
K0,
N1,
StrideQ,
StrideK,
StrideV,
StrideO,
iM0,
iN1);
}
};