diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5bde8eae3..519a4c994 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -32,6 +32,8 @@ flashinfer_option(FLASHINFER_CASCADE "Whether to compile cascade kernel tests/be
 flashinfer_option(FLASHINFER_SAMPLING "Whether to compile sampling kernel tests/benchmarks or not." OFF)
 flashinfer_option(FLASHINFER_NORM "Whether to compile normalization kernel tests/benchmarks or not." OFF)
 flashinfer_option(FLASHINFER_DISTRIBUTED "Whether to compile distributed kernel tests/benchmarks or not." OFF)
+flashinfer_option(FLASHINFER_FASTDIV_TEST "Whether to compile fastdiv kernel tests or not." OFF)
+flashinfer_option(FLASHINFER_FASTDEQUANT_TEST "Whether to compile fast dequant kernel tests or not." OFF)
 flashinfer_option(FLASHINFER_TVM_BINDING "Whether to compile tvm binding or not." OFF)
 flashinfer_option(FLASHINFER_TVM_SOURCE_DIR "The path to tvm for building tvm binding." "")
 
@@ -477,6 +479,15 @@ if(FLASHINFER_FASTDIV_TEST)
   target_link_libraries(test_fastdiv PRIVATE gtest gtest_main)
 endif(FLASHINFER_FASTDIV_TEST)
 
+if(FLASHINFER_FASTDEQUANT_TEST)
+  message(STATUS "Compile fast dequant test.")
+  file(GLOB_RECURSE TEST_FAST_DEQUANT_SRCS ${PROJECT_SOURCE_DIR}/src/test_fast_dequant.cu)
+  add_executable(test_fast_dequant ${TEST_FAST_DEQUANT_SRCS})
+  target_include_directories(test_fast_dequant PRIVATE ${FLASHINFER_INCLUDE_DIR})
+  target_include_directories(test_fast_dequant PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
+  target_link_libraries(test_fast_dequant PRIVATE gtest gtest_main)
+endif(FLASHINFER_FASTDEQUANT_TEST)
+
 if (FLASHINFER_DISTRIBUTED)
   find_package(MPI REQUIRED)
diff --git a/cmake/config.cmake b/cmake/config.cmake
index 75ea4fc09..0d51e4916 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -18,6 +18,8 @@ set(FLASHINFER_SAMPLING ON)
 set(FLASHINFER_NORMALIZATION ON)
 # Whether to compile fastdiv tests
 set(FLASHINFER_FASTDIV_TEST ON)
+# Whether to compile fastdequant tests
+set(FLASHINFER_FASTDEQUANT_TEST ON)
 # Whether to compile distributed tests
 set(FLASHINFER_DISTRIBUTED ON)
 # The following configurations can impact the binary
diff --git a/include/flashinfer/vec_dtypes.cuh b/include/flashinfer/vec_dtypes.cuh
index a55e5990f..38719b43b 100644
--- a/include/flashinfer/vec_dtypes.cuh
+++ b/include/flashinfer/vec_dtypes.cuh
@@ -16,19 +16,19 @@
 #ifndef VEC_DTYPES_CUH_
 #define VEC_DTYPES_CUH_
 
-#ifdef FLASHINFER_ENABLE_BF16
 #include <cuda_bf16.h>
-#endif
 #include <cuda_fp16.h>
-#ifdef FLASHINFER_ENABLE_FP8
 #include <cuda_fp8.h>
-#endif
 #include <cuda_runtime.h>
 
 #include <type_traits>
 
 namespace flashinfer {
 
+#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 900))
+#define FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED
+#endif
+
 #define FLASHINFER_INLINE inline __attribute__((always_inline)) __device__
 
 /******************* vec_t type cast *******************/
@@ -74,11 +74,130 @@ struct vec_cast {
   }
 };
 
-#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 900))
+template <typename T>
+constexpr FLASHINFER_INLINE int get_exponent_bits() {
+  if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
+    return 4;
+  } else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
+    return 5;
+  } else if constexpr (std::is_same<T, half>::value) {
+    return 5;
+  } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
+    return 8;
+  }
+}
+
+template <typename T>
+constexpr FLASHINFER_INLINE int get_mantissa_bits() {
+  if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
+    return 3;
+  } else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
+    return 2;
+  } else if constexpr (std::is_same<T, half>::value) {
+    return 11;
+  } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
+    return 7;
+  }
+}
+
+/*!
+ * \brief Fallback to software fast dequant implementation if hardware dequantization is not
+ * available.
+ * \note Inspired by Marlin's fast dequantization, but here we don't have to permute
+ * weights order.
+ * \ref
+ * https://github.com/vllm-project/vllm/blob/6dffa4b0a6120159ef2fe44d695a46817aff65bc/csrc/quantization/fp8/fp8_marlin.cu#L120
+ */
+template <typename fp8_dtype, typename fp16_dtype>
+__device__ void fast_dequant_f8f16x4(uint32_t* input, uint2* output) {
+  uint32_t q = *input;
+  if constexpr (std::is_same<fp8_dtype, __nv_fp8_e5m2>::value &&
+                std::is_same<fp16_dtype, half>::value) {
+    output->x = __byte_perm(0U, q, 0x5140);
+    output->y = __byte_perm(0U, q, 0x7362);
+  } else {
+    constexpr int FP8_EXPONENT = get_exponent_bits<fp8_dtype>();
+    constexpr int FP8_MANTISSA = get_mantissa_bits<fp8_dtype>();
+    constexpr int FP16_EXPONENT = get_exponent_bits<fp16_dtype>();
+
+    constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
+    // Calculate MASK for extracting mantissa and exponent
+    constexpr int MASK1 = 0x80000000;
+    constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
+    constexpr int MASK3 = MASK2 & 0x7fffffff;
+    constexpr int MASK = MASK3 | (MASK3 >> 16);
+    // Final MASK value: 0x7F007F00
+    q = __byte_perm(q, q, 0x1302);
+
+    // Extract and shift FP8 values to FP16 format
+    uint32_t Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
+    uint32_t Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
+
+    constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
+    // Construct and apply exponent bias
+    if (std::is_same<fp16_dtype, half>::value) {
+      const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
+
+      // Convert to half2 and apply bias
+      *(half2*)&(output->x) = __hmul2(*reinterpret_cast<half2*>(&Out1), bias_reg);
+      *(half2*)&(output->y) = __hmul2(*reinterpret_cast<half2*>(&Out2), bias_reg);
+    } else {
+      constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
+      const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
+      // Convert to bfloat162 and apply bias
+      *(nv_bfloat162*)&(output->x) =
+          __hmul2(*reinterpret_cast<nv_bfloat162*>(&Out1), bias_reg);
+      *(nv_bfloat162*)&(output->y) =
+          __hmul2(*reinterpret_cast<nv_bfloat162*>(&Out2), bias_reg);
+    }
+  }
+}
+
+template <>
+struct vec_cast<nv_bfloat16, __nv_fp8_e4m3> {
+  template <size_t vec_size>
+  FLASHINFER_INLINE static void cast(nv_bfloat16* dst, const __nv_fp8_e4m3* src) {
+    if constexpr (vec_size == 1) {
+      dst[0] = nv_bfloat16(src[0]);
+    } else if constexpr (vec_size == 2) {
+      dst[0] = nv_bfloat16(src[0]);
+      dst[1] = nv_bfloat16(src[1]);
+    } else {
+      static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
+#pragma unroll
+      for (uint32_t i = 0; i < vec_size / 4; ++i) {
+        fast_dequant_f8f16x4<__nv_fp8_e4m3, nv_bfloat16>((uint32_t*)&src[i * 4],
+                                                         (uint2*)&dst[i * 4]);
+      }
+    }
+  }
+};
+
+template <>
+struct vec_cast<nv_bfloat16, __nv_fp8_e5m2> {
+  template <size_t vec_size>
+  FLASHINFER_INLINE static void cast(nv_bfloat16* dst, const __nv_fp8_e5m2* src) {
+    if constexpr (vec_size == 1) {
+      dst[0] = nv_bfloat16(src[0]);
+    } else if constexpr (vec_size == 2) {
+      dst[0] = nv_bfloat16(src[0]);
+      dst[1] = nv_bfloat16(src[1]);
+    } else {
+      static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
+#pragma unroll
+      for (uint32_t i = 0; i < vec_size / 4; ++i) {
+        fast_dequant_f8f16x4<__nv_fp8_e5m2, nv_bfloat16>((uint32_t*)&src[i * 4],
+                                                         (uint2*)&dst[i * 4]);
+      }
+    }
+  }
+};
+
 template <>
 struct vec_cast<__nv_fp8_e4m3, half> {
   template <size_t vec_size>
   FLASHINFER_INLINE static void cast(__nv_fp8_e4m3* dst, const half* src) {
+#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED
     if constexpr (vec_size == 1) {
       dst[0] = __nv_fp8_e4m3(src[0]);
     } else {
@@ -90,6 +209,12 @@ struct vec_cast<__nv_fp8_e4m3, half> {
       *(uint16_t*)&dst[i * 2] = y;
     }
   }
+#else
+#pragma unroll
+    for (size_t i = 0; i < vec_size; ++i) {
+      dst[i] = __nv_fp8_e4m3(src[i]);
+    }
+#endif  // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED
   }
 };
 
@@ -97,6 +222,7 @@ template <>
 struct vec_cast<__nv_fp8_e5m2, half> {
   template <size_t vec_size>
   FLASHINFER_INLINE static void cast(__nv_fp8_e5m2* dst, const half* src) {
+#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED
     if constexpr (vec_size == 1) {
       dst[0] = __nv_fp8_e5m2(src[0]);
     } else {
@@ -108,6 +234,12 @@ struct vec_cast<__nv_fp8_e5m2, half> {
       *(uint16_t*)&dst[i * 2] = y;
     }
   }
+#else
+#pragma unroll
+    for (size_t i = 0; i < vec_size; ++i) {
+      dst[i] = __nv_fp8_e5m2(src[i]);
+    }
+#endif  // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED
   }
 };
 
@@ -115,6 +247,7 @@ template <>
 struct vec_cast<half, __nv_fp8_e4m3> {
   template <size_t vec_size>
   FLASHINFER_INLINE static void cast(half* dst, const __nv_fp8_e4m3* src) {
+#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED
     if constexpr (vec_size == 1) {
       dst[0] = half(src[0]);
     } else {
@@ -126,6 +259,20 @@ struct vec_cast<half, __nv_fp8_e4m3> {
       *(uint32_t*)&dst[i * 2] = y;
     }
   }
+#else
+    if constexpr (vec_size == 1) {
+      dst[0] = half(src[0]);
+    } else if constexpr (vec_size == 2) {
+      dst[0] = half(src[0]);
+      dst[1] = half(src[1]);
+    } else {
+      static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
+#pragma unroll
+      for (uint32_t i = 0; i < vec_size / 4; ++i) {
+        fast_dequant_f8f16x4<__nv_fp8_e4m3, half>((uint32_t*)&src[i * 4], (uint2*)&dst[i * 4]);
+      }
+    }
+#endif  // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED
   }
 };
 
@@ -133,6 +280,7 @@ template <>
 struct vec_cast<half, __nv_fp8_e5m2> {
   template <size_t vec_size>
   FLASHINFER_INLINE static void cast(half* dst, const __nv_fp8_e5m2* src) {
+#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED
     if constexpr (vec_size == 1) {
       dst[0] = half(src[0]);
     } else {
@@ -144,13 +292,23 @@ struct vec_cast<half, __nv_fp8_e5m2> {
       *(uint32_t*)&dst[i * 2] = y;
     }
   }
+#else
+    if constexpr (vec_size == 1) {
+      dst[0] = half(src[0]);
+    } else if constexpr (vec_size == 2) {
+      dst[0] = half(src[0]);
+      dst[1] = half(src[1]);
+    } else {
+      static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
+#pragma unroll
+      for (uint32_t i = 0; i < vec_size / 4; ++i) {
+        fast_dequant_f8f16x4<__nv_fp8_e5m2, half>((uint32_t*)&src[i * 4], (uint2*)&dst[i * 4]);
+      }
+    }
+#endif  // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED
   }
 };
 
-#endif  // !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 900)
-
-#ifdef FLASHINFER_ENABLE_BF16
-
 template <>
 struct vec_cast<float, nv_bfloat16> {
   template <size_t vec_size>
@@ -180,7 +338,6 @@ struct vec_cast<nv_bfloat16, float> {
     }
   }
 };
-#endif  // FLASHINFER_ENABLE_BF16
 
 template <typename float_t, size_t vec_size>
 struct vec_t {
@@ -230,7 +387,6 @@ FLASHINFER_INLINE void cast_store_impl(tgt_float_t* dst_ptr,
   }
 }
 
-#ifdef FLASHINFER_ENABLE_FP8
 /******************* vec_t<__nv_fp8_e4m3> *******************/
 
 // __nv_fp8_e4m3 x 1
@@ -724,7 +880,6 @@ struct vec_t<__nv_fp8_e5m2, vec_size> {
     }
   }
 };
-#endif
 
 /******************* vec_t<half> *******************/
 
@@ -889,7 +1044,6 @@ struct vec_t<half, vec_size> {
   }
 };
 
-#ifdef FLASHINFER_ENABLE_BF16
 /******************* vec_t<nv_bfloat16> *******************/
 
 // nv_bfloat16 x 1
@@ -1071,8 +1225,6 @@ struct vec_t<nv_bfloat16, vec_size> {
   }
 };
 
-#endif
-
 /******************* vec_t<float> *******************/
 
 // float x 1
diff --git a/src/test_fast_dequant.cu b/src/test_fast_dequant.cu
new file mode 100644
index 000000000..2ffbdc1c1
--- /dev/null
+++ b/src/test_fast_dequant.cu
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2023 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <gtest/gtest.h>
+#include <thrust/device_vector.h>
+
+#include <flashinfer/vec_dtypes.cuh>
+#include <thrust/host_vector.h>
+
+#include "utils.h"
+
+using namespace flashinfer;
+
+template <typename dtype_f8, typename dtype_f16>
+__global__ void test_fast_f8_f16_dequant(dtype_f8* f8, dtype_f16* f16) {
+  size_t global_tidx = blockIdx.x * blockDim.x + threadIdx.x;
+  vec_cast<dtype_f16, dtype_f8>::cast<8>(f16 + global_tidx * 8, f8 + global_tidx * 8);
+}
+
+template <typename dtype_f8, typename dtype_f16>
+void TestFastDequant() {
+  std::vector<dtype_f8> f8_h(1024);
+  utils::vec_normal_(f8_h);
+  std::vector<dtype_f16> f16_h_ref(1024);
+  for (uint32_t i = 0; i < 1024; ++i) {
+    f16_h_ref[i] = static_cast<dtype_f16>(f8_h[i]);
+  }
+
+  thrust::device_vector<dtype_f8> f8_d(f8_h);
+  thrust::device_vector<dtype_f16> f16_d(1024);
+
+  test_fast_f8_f16_dequant<dtype_f8, dtype_f16>
+      <<<1, 128>>>(thrust::raw_pointer_cast(f8_d.data()), thrust::raw_pointer_cast(f16_d.data()));
+
+  cudaError_t err = cudaGetLastError();
+  EXPECT_EQ(err, cudaSuccess);
+
+  thrust::host_vector<dtype_f16> f16_h(f16_d);
+  for (uint32_t i = 0; i < 1024; ++i) {
+    if (f16_h[i] != f16_h_ref[i]) {
+      printf("mismatch at i=%u: out=%x ref=%x\n", i, *(uint16_t*)(f16_h.data() + i),
+             *(uint16_t*)(f16_h_ref.data() + i));
+    }
+    EXPECT_EQ(f16_h[i], f16_h_ref[i]);
+  }
+}
+
+TEST(FlashInferCorrectnessTest, TestFastDequantCorrectnessE4M3ToFloat16) {
+  TestFastDequant<__nv_fp8_e4m3, half>();
+}
+TEST(FlashInferCorrectnessTest, TestFastDequantCorrectnessE5M2ToFloat16) {
+  TestFastDequant<__nv_fp8_e5m2, half>();
+}
+TEST(FlashInferCorrectnessTest, TestFastDequantCorrectnessE4M3ToBFloat16) {
+  TestFastDequant<__nv_fp8_e4m3, __nv_bfloat16>();
+}
+TEST(FlashInferCorrectnessTest, TestFastDequantCorrectnessE5M2ToBFloat16) {
+  TestFastDequant<__nv_fp8_e5m2, __nv_bfloat16>();
+}
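Reviewer note: the bit manipulation in fast_dequant_f8f16x4 is easiest to sanity-check with a scalar model on the host. The sketch below is not part of the patch; the file and helper names (e4m3_bits_to_f32, f16_bits_to_f32) are made up for illustration. It mirrors the kernel's e4m3-to-fp16 path one byte at a time: keep the sign in bit 15, shift the exponent and mantissa bits into the fp16 layout, then multiply by 2^BIAS_OFFSET = 2^8 to correct the exponent bias. Because the rescale is an exact power of two, the result should match a plain reference decode for every finite e4m3 value, including subnormals.

// scalar_fast_dequant_check.cc -- illustrative only, not part of this PR.
#include <cmath>
#include <cstdint>
#include <cstdio>

// Slow reference decode of an fp16 bit pattern (5 exponent / 10 mantissa bits, bias 15).
static float f16_bits_to_f32(uint16_t h) {
  int sign = (h >> 15) & 1, exp = (h >> 10) & 0x1F, mant = h & 0x3FF;
  float v = (exp == 0) ? std::ldexp(mant / 1024.0f, -14)
                       : std::ldexp(1.0f + mant / 1024.0f, exp - 15);
  return sign ? -v : v;
}

// Slow reference decode of an e4m3 byte (4 exponent / 3 mantissa bits, bias 7;
// 0x7F and 0xFF encode NaN).
static float e4m3_bits_to_f32(uint8_t b) {
  int sign = (b >> 7) & 1, exp = (b >> 3) & 0xF, mant = b & 0x7;
  float v = (exp == 0) ? std::ldexp(mant / 8.0f, -6)
                       : std::ldexp(1.0f + mant / 8.0f, exp - 7);
  return sign ? -v : v;
}

int main() {
  for (int b = 0; b < 256; ++b) {
    if ((b & 0x7F) == 0x7F) continue;  // skip the NaN encodings
    // Same steps as the kernel, scalarized: sign stays in bit 15; exponent and
    // mantissa shift right by RIGHT_SHIFT = 5 - 4 = 1 within the 16-bit lane;
    // then rescale by 2^BIAS_OFFSET, where BIAS_OFFSET = (1 << 4) - (1 << 3) = 8.
    uint16_t h = (uint16_t)(((b & 0x80) << 8) | ((b & 0x7F) << 7));
    float fast = f16_bits_to_f32(h) * 256.0f;  // 256 = 2^8
    if (fast != e4m3_bits_to_f32((uint8_t)b)) printf("mismatch at 0x%02x\n", b);
  }
  printf("done\n");
  return 0;
}

The e5m2-to-half pair needs no bias fix-up at all: e5m2 has the same 5-bit exponent as fp16, so widening each byte into the high byte of a 16-bit lane is already exact, which is why the kernel special-cases that combination with two __byte_perm instructions.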