Skip to content

Commit

Permalink
upd
Browse files Browse the repository at this point in the history
ok

how come

upd
  • Loading branch information
yzh119 committed Jul 24, 2024
1 parent 2ab2bca commit a8bc999
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 17 deletions.
41 changes: 40 additions & 1 deletion include/flashinfer/attention/cascade.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,29 @@ __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S
uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t pos = blockIdx.x;
uint32_t head_idx = ty;
state_t<vec_size> st;

if (num_index_sets == 0) {
vec_t<DTypeOut, vec_size> v;
v.fill(DTypeOut(0));
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = -5e4;
}
return;
}

if (num_index_sets == 1) {
vec_t<DTypeOut, vec_size> v;
v.cast_load(V + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = S[pos * num_heads + head_idx];
}
return;
}

vec_t<float, vec_size> v_merged_vec;
state_t<vec_size> st;
v_merged_vec.fill(0.f);
#pragma unroll 2
for (uint32_t iter = 0; iter < num_index_sets; ++iter) {
Expand Down Expand Up @@ -296,6 +316,25 @@ __global__ void VariableLengthMergeStatesKernel(DTypeIn* __restrict__ V, float*
float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn));
const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];

if (num_index_sets == 0) {
vec_t<DTypeOut, vec_size> v;
v.fill(DTypeOut(0));
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = -5e4;
}
return;
}

if (num_index_sets == 1) {
vec_t<DTypeOut, vec_size> v;
v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size);
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx];
}
}

#pragma unroll
for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
Expand Down
36 changes: 20 additions & 16 deletions src/test_cascade.cu
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads,
template <typename T>
void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t num_heads,
size_t head_dim, bool sparse_s) {
EXPECT_GT(num_index_sets, 1) << "num_index_sets must be greater than 1";
std::vector<T> V_host(seq_len * num_index_sets * num_heads * head_dim);
std::vector<float> V_host_trans_f32(num_index_sets * seq_len * num_heads * head_dim);
std::vector<float> S_host(seq_len * num_index_sets * num_heads);
Expand Down Expand Up @@ -178,20 +177,25 @@ void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t n
thrust::device_vector<T> V_merged_1_device(seq_len * num_heads * head_dim);
thrust::device_vector<float> S_merged_1_device(seq_len * num_heads);

// Method 0: use MergeState
MergeState(thrust::raw_pointer_cast(V_device_trans_f32.data()),
thrust::raw_pointer_cast(S_device_trans.data()),
thrust::raw_pointer_cast(V_device_trans_f32.data() + seq_len * num_heads * head_dim),
thrust::raw_pointer_cast(S_device_trans.data() + seq_len * num_heads),
thrust::raw_pointer_cast(V_merged_0_device.data()),
thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, num_heads, head_dim);
for (uint i = 2; i < num_index_sets; ++i) {
MergeStateInPlace(
thrust::raw_pointer_cast(V_merged_0_device.data()),
thrust::raw_pointer_cast(S_merged_0_device.data()),
thrust::raw_pointer_cast(V_device_trans_f32.data() + i * seq_len * num_heads * head_dim),
thrust::raw_pointer_cast(S_device_trans.data() + i * seq_len * num_heads), seq_len,
num_heads, head_dim);
if (num_index_sets > 1) {
// Method 0: use MergeState
MergeState(thrust::raw_pointer_cast(V_device_trans_f32.data()),
thrust::raw_pointer_cast(S_device_trans.data()),
thrust::raw_pointer_cast(V_device_trans_f32.data() + seq_len * num_heads * head_dim),
thrust::raw_pointer_cast(S_device_trans.data() + seq_len * num_heads),
thrust::raw_pointer_cast(V_merged_0_device.data()),
thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, num_heads, head_dim);
for (uint i = 2; i < num_index_sets; ++i) {
MergeStateInPlace(
thrust::raw_pointer_cast(V_merged_0_device.data()),
thrust::raw_pointer_cast(S_merged_0_device.data()),
thrust::raw_pointer_cast(V_device_trans_f32.data() + i * seq_len * num_heads * head_dim),
thrust::raw_pointer_cast(S_device_trans.data() + i * seq_len * num_heads), seq_len,
num_heads, head_dim);
}
} else {
V_merged_0_device = V_device;
S_merged_0_device = S_device;
}

// Method 1: use MergeStates
Expand Down Expand Up @@ -479,7 +483,7 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size,

template <typename T>
void TestMergeKernelCorrectness() {
for (size_t num_index_sets : {2, 9, 81, 513}) {
for (size_t num_index_sets : {1, 2, 9, 81, 513}) {
for (size_t seq_len : {4, 16, 77}) {
for (size_t num_heads : {1, 21, 32}) {
for (size_t head_dim : {64, 128, 256}) {
Expand Down

0 comments on commit a8bc999

Please sign in to comment.