Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: slight optimization on fragment layout swizzle #458

Merged
merged 1 commit into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions include/flashinfer/frag_layout_swizzle.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b(uint32_t x) {
}

__device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b_trans(uint32_t x) {
x = __byte_perm(x, x, 0x3120);
uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x4);
x = __byte_perm(x, tmp, ((threadIdx.x & 0x4) == 0) ? 0x5410 : 0x3276);
x = __byte_perm(x, tmp, ((threadIdx.x & 0x4) == 0) ? 0x6420 : 0x3175);
tmp = __shfl_xor_sync(0xffffffff, x, 0x8);
x = __byte_perm(x, tmp, ((threadIdx.x & 0x8) == 0) ? 0x5410 : 0x3276);
tmp = __shfl_xor_sync(0xffffffff, x, 0x10);
Expand Down
1 change: 0 additions & 1 deletion include/flashinfer/vec_dtypes.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ __device__ void fast_dequant_f8f16x4(uint32_t* input, uint2* output) {
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
constexpr int MASK3 = MASK2 & 0x7fffffff;
constexpr int MASK = MASK3 | (MASK3 >> 16);
// Final MASK value: 0x7F007F00
q = __byte_perm(q, q, 0x1302);

// Extract and shift FP8 values to FP16 format
Expand Down