From 6d564e060d78185f174dd83fef93f02e771b3208 Mon Sep 17 00:00:00 2001 From: yzh119 Date: Fri, 19 Jul 2024 02:06:32 +0000 Subject: [PATCH 1/2] upd --- include/flashinfer/attention/prefill.cuh | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index e1c676bb..5b02b34e 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -894,13 +894,8 @@ __device__ __forceinline__ void write_o_reg_gmem( uint32_t o_frag_f16[4]; vec_cast((DTypeOut*)o_frag_f16, o_frag[fx][fy]); uint32_t o_smem_offset_w = smem_t::get_permuted_offset( - (warp_idx_x * num_frags_x + fx) * 16 + lane_idx / 4, fy * 2); - ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0]; - ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * channel_size_128b_out))[lane_idx % 4] = - o_frag_f16[1]; - ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2]; - ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) + - 8 * channel_size_128b_out))[lane_idx % 4] = o_frag_f16[3]; + (warp_idx_x * num_frags_x + fx) * 16 + lane_idx % 16, fy * 2 + lane_idx / 16); + o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16); } } From fcacb4d7e1a8ab6fa211db9b511bdc70643b5648 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Fri, 19 Jul 2024 02:38:25 +0000 Subject: [PATCH 2/2] upd --- include/flashinfer/attention/prefill.cuh | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 5b02b34e..13d7a54f 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -893,9 +893,20 @@ __device__ __forceinline__ void write_o_reg_gmem( for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t o_frag_f16[4]; vec_cast((DTypeOut*)o_frag_f16, o_frag[fx][fy]); +#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED uint32_t o_smem_offset_w = smem_t::get_permuted_offset( (warp_idx_x * num_frags_x + fx) * 16 + lane_idx % 16, fy * 2 + lane_idx / 16); o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16); +#else + uint32_t o_smem_offset_w = smem_t::get_permuted_offset( + (warp_idx_x * num_frags_x + fx) * 16 + lane_idx / 4, fy * 2); + ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0]; + ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * channel_size_128b_out))[lane_idx % 4] = + o_frag_f16[1]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) + + 8 * channel_size_128b_out))[lane_idx % 4] = o_frag_f16[3]; +#endif } }