diff --git a/.github/workflows/nv-a6000.yml b/.github/workflows/nv-a6000.yml index 38d4ffd1828b..e1f8ed41cebe 100644 --- a/.github/workflows/nv-a6000.yml +++ b/.github/workflows/nv-a6000.yml @@ -2,9 +2,9 @@ name: nv-a6000 on: pull_request: - paths-ignore: - - 'docs/**' - - 'blogs/**' + paths: + - "deepspeed/inference/v2/**" + - "tests/unit/inference/v2/**" workflow_dispatch: concurrency: diff --git a/.github/workflows/nv-accelerate-v100.yml b/.github/workflows/nv-accelerate-v100.yml index 5bd0c22c8b98..4525a8124dc2 100644 --- a/.github/workflows/nv-accelerate-v100.yml +++ b/.github/workflows/nv-accelerate-v100.yml @@ -5,6 +5,7 @@ on: paths-ignore: - 'docs/**' - 'blogs/**' + - 'deepspeed/inference/v2/**' merge_group: branches: [ master ] schedule: diff --git a/.github/workflows/nv-inference.yml b/.github/workflows/nv-inference.yml index b5cf46f79011..e9c63051cbdf 100644 --- a/.github/workflows/nv-inference.yml +++ b/.github/workflows/nv-inference.yml @@ -5,6 +5,7 @@ on: paths-ignore: - 'docs/**' - 'blogs/**' + - 'deepspeed/inference/v2/**' merge_group: branches: [ master ] schedule: diff --git a/.github/workflows/nv-lightning-v100.yml b/.github/workflows/nv-lightning-v100.yml index 75d2dc732d4d..b2b900e186f8 100644 --- a/.github/workflows/nv-lightning-v100.yml +++ b/.github/workflows/nv-lightning-v100.yml @@ -5,6 +5,7 @@ on: paths-ignore: - 'docs/**' - 'blogs/**' + - 'deepspeed/inference/v2/**' merge_group: branches: [ master ] schedule: diff --git a/.github/workflows/nv-megatron.yml b/.github/workflows/nv-megatron.yml index 2fb9e37e5e9c..7bd29bb14e07 100644 --- a/.github/workflows/nv-megatron.yml +++ b/.github/workflows/nv-megatron.yml @@ -5,6 +5,7 @@ on: paths-ignore: - 'docs/**' - 'blogs/**' + - 'deepspeed/inference/v2/**' merge_group: branches: [ master ] schedule: diff --git a/.github/workflows/nv-pre-compile-ops.yml b/.github/workflows/nv-pre-compile-ops.yml index f253340c6966..79bc36dac7ab 100644 --- a/.github/workflows/nv-pre-compile-ops.yml +++ b/.github/workflows/nv-pre-compile-ops.yml @@ -7,6 +7,7 @@ on: paths-ignore: - 'docs/**' - 'blogs/**' + - 'deepspeed/inference/v2/**' merge_group: branches: [ master ] schedule: diff --git a/.github/workflows/nv-torch-latest-cpu.yml b/.github/workflows/nv-torch-latest-cpu.yml index 3c2b7301acf0..b62d30e3621b 100644 --- a/.github/workflows/nv-torch-latest-cpu.yml +++ b/.github/workflows/nv-torch-latest-cpu.yml @@ -5,6 +5,7 @@ on: paths-ignore: - 'docs/**' - 'blogs/**' + - 'deepspeed/inference/v2/**' merge_group: branches: [ master ] schedule: diff --git a/.github/workflows/nv-torch-latest-v100.yml b/.github/workflows/nv-torch-latest-v100.yml index 2b91df3ae44c..2d396b79b14a 100644 --- a/.github/workflows/nv-torch-latest-v100.yml +++ b/.github/workflows/nv-torch-latest-v100.yml @@ -5,6 +5,7 @@ on: paths-ignore: - 'docs/**' - 'blogs/**' + - 'deepspeed/inference/v2/**' merge_group: branches: [ master ] schedule: diff --git a/.github/workflows/nv-transformers-v100.yml b/.github/workflows/nv-transformers-v100.yml index 1cc0c6588610..4ac66edb1e93 100644 --- a/.github/workflows/nv-transformers-v100.yml +++ b/.github/workflows/nv-transformers-v100.yml @@ -5,6 +5,7 @@ on: paths-ignore: - 'docs/**' - 'blogs/**' + - 'deepspeed/inference/v2/**' merge_group: branches: [ master ] schedule: diff --git a/deepspeed/inference/v2/kernels/includes/activation_type.h b/deepspeed/inference/v2/kernels/includes/activation_type.h new file mode 100644 index 000000000000..a44921d5d650 --- /dev/null +++ b/deepspeed/inference/v2/kernels/includes/activation_type.h @@ -0,0 
+1,17 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +enum ActivationType { + GELU = 0, + RELU = 1, + SILU = 2, + GEGLU = 3, + ReGLU = 4, + SiGLU = 5, + IDENTITY = 6, + InvalidType = -1 +}; diff --git a/deepspeed/inference/v2/kernels/includes/conversion_utils.h b/deepspeed/inference/v2/kernels/includes/conversion_utils.h new file mode 100644 index 000000000000..3a90a3e91ddf --- /dev/null +++ b/deepspeed/inference/v2/kernels/includes/conversion_utils.h @@ -0,0 +1,640 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" + +#include + +#ifdef BF16_AVAILABLE +#include +#endif + +namespace conversion { + +// Basic primitive for constructing conversions +template +DS_D_INLINE TO to(FROM val) +{ + return to(val); +} + +// Specializations + +/********************* Identity Conversions *********************/ +/* +Identity conversions are useful in templated functions where we might have +a fixed destination type. For example, I might have a kernel that accepts +__half, __nv_bfloat16, and float but always want to do the core computation +at floating point: + +T mem_value = input[idx]; +float compute_value = conversion::to(mem_value); + +In practice, we should be able to elide the second template parameter: +float compute_val = conversion::to(mem_value); + +In this case, we need an implementation to handle the T = float case + +NOTE: The type inferencing system appears to be unable to handle inferring the first +template parameter, even in the trivial case. +*/ + +// Floating point types +template <> +DS_D_INLINE double to(double val) +{ + return val; +} +template <> +DS_D_INLINE float to(float val) +{ + return val; +} +template <> +DS_D_INLINE __half to(__half val) +{ + return val; +} +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat16 to(__nv_bfloat16 val) +{ + return val; +} +#endif + +// Integer types +template <> +DS_D_INLINE int8_t to(int8_t val) +{ + return val; +} +template <> +DS_D_INLINE uint8_t to(uint8_t val) +{ + return val; +} +template <> +DS_D_INLINE int16_t to(int16_t val) +{ + return val; +} +template <> +DS_D_INLINE uint16_t to(uint16_t val) +{ + return val; +} +template <> +DS_D_INLINE int32_t to(int32_t val) +{ + return val; +} +template <> +DS_D_INLINE uint32_t to(uint32_t val) +{ + return val; +} +template <> +DS_D_INLINE int64_t to(int64_t val) +{ + return val; +} +template <> +DS_D_INLINE uint64_t to(uint64_t val) +{ + return val; +} + +// TODO: evaluate if we want bools + +/********************* To Double Conversions *********************/ + +// * to double variants + +// Would normally like to not use C cast, but this is an important enough conversion +// to keep +template <> +DS_D_INLINE double to(float val) +{ +#ifdef PTX_AVAILABLE + double ret_val; + asm("ctv.rn.f64.f32 %0, %1;\n" : "=d"(ret_val) : "f"(val)); + return ret_val; +#else + return double(val); +#endif +} +// Note: there is a CVT instruction for __half -> double, but there's no inline interface +// for passing a single half value +template <> +DS_D_INLINE double to(__half val) +{ + return to(__half2float(val)); +} +template <> +DS_D_INLINE double to(int64_t val) +{ + return __ll2double_rn(val); +} +template <> +DS_D_INLINE double to(int32_t val) +{ + return __int2double_rn(val); +} +template <> +DS_D_INLINE double to(int16_t val) +{ + return __int2double_rn(val); +} +template <> +DS_D_INLINE double to(int8_t val) +{ + return 
__int2double_rn(val); +} +template <> +DS_D_INLINE double to(uint64_t val) +{ + return __ull2double_rn(val); +} +template <> +DS_D_INLINE double to(uint32_t val) +{ + return __uint2double_rn(val); +} +template <> +DS_D_INLINE double to(uint16_t val) +{ + return __uint2double_rn(val); +} +template <> +DS_D_INLINE double to(uint8_t val) +{ + return __uint2double_rn(val); +} + +// Same applies here +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE double to(__nv_bfloat16 val) +{ + return to(__bfloat162float(val)); +} +#endif + +/********************* To Float Conversions *********************/ + +template <> +DS_D_INLINE float to(double val) +{ + return __double2float_rn(val); +} +template <> +DS_D_INLINE float to(__half val) +{ + return __half2float(val); +} +template <> +DS_D_INLINE float to(int64_t val) +{ + return __ll2float_rn(val); +} +template <> +DS_D_INLINE float to(int32_t val) +{ + return __int2float_rn(val); +} +template <> +DS_D_INLINE float to(int16_t val) +{ + return __int2float_rn(val); +} +template <> +DS_D_INLINE float to(int8_t val) +{ + return __int2float_rn(val); +} +template <> +DS_D_INLINE float to(uint64_t val) +{ + return __ull2float_rn(val); +} +template <> +DS_D_INLINE float to(uint32_t val) +{ + return __uint2float_rn(val); +} +template <> +DS_D_INLINE float to(uint16_t val) +{ + return __uint2float_rn(val); +} +template <> +DS_D_INLINE float to(uint8_t val) +{ + return __uint2float_rn(val); +} + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE float to(__nv_bfloat16 val) +{ + return __bfloat162float(val); +} +#endif + +/********************* To Float2 Conversions *********************/ +template <> +DS_D_INLINE float2 to(__half2 val) +{ + return __half22float2(val); +} + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE float2 to(__nv_bfloat162 val) +{ + return __bfloat1622float2(val); +} +#endif + +/********************* To Half Conversions *********************/ +template <> +DS_D_INLINE __half to(double val) +{ +#ifdef __HIP_PLATFORM_AMD__ + float val_f = __double2float_rn(val); + return __float2half(val_f); +#else + return __double2half(val); +#endif +} +template <> +DS_D_INLINE __half to(float val) +{ + return __float2half(val); +} +template <> +DS_D_INLINE __half to(int64_t val) +{ + return __ll2half_rn(val); +} +template <> +DS_D_INLINE __half to(int32_t val) +{ + return __int2half_rn(val); +} +template <> +DS_D_INLINE __half to(int16_t val) +{ + return __short2half_rn(val); +} +template <> +DS_D_INLINE __half to(int8_t val) +{ + return __int2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint64_t val) +{ + return __ull2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint32_t val) +{ + return __uint2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint16_t val) +{ + return __ushort2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint8_t val) +{ + return __uint2half_rn(val); +} + +#ifdef BF16_AVAILABLE +// No direct conversion +template <> +DS_D_INLINE __half to(__nv_bfloat16 val) +{ + return to<__half>(to(val)); +} +#endif + +/********************* To Half2 Conversions *********************/ +template <> +DS_D_INLINE __half2 to(float2 val) +{ + return __float22half2_rn(val); +} +template <> +DS_D_INLINE __half2 to(float val) +{ + return __float2half2_rn(val); +} + +#ifdef BF16_AVAILABLE +// No direct conversion +template <> +DS_D_INLINE __half2 to(__nv_bfloat162 val) +{ + return to<__half2>(to(val)); +} +#endif + +/********************* To BF16 Conversions *********************/ +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE 
__nv_bfloat16 to(double val) +{ + return __double2bfloat16(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(float val) +{ + return __float2bfloat16(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(int64_t val) +{ + return __ll2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(int32_t val) +{ + return __int2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(int16_t val) +{ + return __short2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(int8_t val) +{ + return __int2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint64_t val) +{ + return __ull2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint32_t val) +{ + return __uint2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint16_t val) +{ + return __ushort2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint8_t val) +{ + return __uint2bfloat16_rn(val); +} +#endif + +/********************* To BF162 Conversions *********************/ +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat162 to(float2 val) +{ + return __float22bfloat162_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat162 to(float val) +{ + return __float2bfloat162_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat162 to(__half2 val) +{ + return to<__nv_bfloat162>(to(val)); +} +#endif + +/********************* To INT64_T Conversions *********************/ +template <> +DS_D_INLINE int64_t to(double val) +{ + return __double2ll_rn(val); +} +template <> +DS_D_INLINE int64_t to(float val) +{ + return __float2ll_rn(val); +} +template <> +DS_D_INLINE int64_t to(__half val) +{ + return __half2ll_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE int64_t to(__nv_bfloat16 val) +{ + return __bfloat162ll_rn(val); +} +#endif + +/********************* To INT32_T Conversions *********************/ +template <> +DS_D_INLINE int32_t to(double val) +{ + return __double2int_rn(val); +} +template <> +DS_D_INLINE int32_t to(float val) +{ + return __float2int_rn(val); +} +template <> +DS_D_INLINE int32_t to(__half val) +{ + return __half2int_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE int32_t to(__nv_bfloat16 val) +{ + return __bfloat162int_rn(val); +} +#endif + +/********************* To INT16_T Conversions *********************/ +template <> +DS_D_INLINE int16_t to(double val) +{ + return __double2int_rn(val); +} +template <> +DS_D_INLINE int16_t to(float val) +{ + return __float2int_rn(val); +} +template <> +DS_D_INLINE int16_t to(__half val) +{ + return __half2int_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE int16_t to(__nv_bfloat16 val) +{ + return __bfloat162int_rn(val); +} +#endif + +/********************* To INT8_T Conversions *********************/ +template <> +DS_D_INLINE int8_t to(double val) +{ + return __double2int_rn(val); +} +template <> +DS_D_INLINE int8_t to(float val) +{ + return __float2int_rn(val); +} +template <> +DS_D_INLINE int8_t to(__half val) +{ + return __half2int_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE 
+template <> +DS_D_INLINE int8_t to(__nv_bfloat16 val) +{ + return __bfloat162int_rn(val); +} +#endif + +/********************* To UINT64_T Conversions *********************/ +template <> +DS_D_INLINE uint64_t to(double val) +{ + return __double2ull_rn(val); +} +template <> +DS_D_INLINE uint64_t to(float val) +{ + return __float2ull_rn(val); +} +template <> +DS_D_INLINE uint64_t to(__half val) +{ + return __half2ull_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint64_t to(__nv_bfloat16 val) +{ + return __bfloat162ull_rn(val); +} +#endif + +/********************* To UINT32_T Conversions *********************/ +template <> +DS_D_INLINE uint32_t to(double val) +{ + return __double2uint_rn(val); +} +template <> +DS_D_INLINE uint32_t to(float val) +{ + return __float2uint_rn(val); +} +template <> +DS_D_INLINE uint32_t to(__half val) +{ + return __half2uint_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint32_t to(__nv_bfloat16 val) +{ + return __bfloat162uint_rn(val); +} +#endif + +/********************* To UINT16_T Conversions *********************/ +template <> +DS_D_INLINE uint16_t to(double val) +{ + return __double2uint_rn(val); +} +template <> +DS_D_INLINE uint16_t to(float val) +{ + return __float2uint_rn(val); +} +template <> +DS_D_INLINE uint16_t to(__half val) +{ + return __half2uint_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint16_t to(__nv_bfloat16 val) +{ + return __bfloat162uint_rn(val); +} +#endif + +/********************* To UINT8_T Conversions *********************/ +template <> +DS_D_INLINE uint8_t to(double val) +{ + return __double2uint_rn(val); +} +template <> +DS_D_INLINE uint8_t to(float val) +{ + return __float2uint_rn(val); +} +template <> +DS_D_INLINE uint8_t to(__half val) +{ + return __half2uint_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint8_t to(__nv_bfloat16 val) +{ + return __bfloat162uint_rn(val); +} +#endif + +} // namespace conversion diff --git a/deepspeed/inference/v2/kernels/includes/ds_kernel_utils.h b/deepspeed/inference/v2/kernels/includes/ds_kernel_utils.h new file mode 100644 index 000000000000..8e4888109fcd --- /dev/null +++ b/deepspeed/inference/v2/kernels/includes/ds_kernel_utils.h @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Centralized header file for preprocessor macros and constants +used throughout the codebase. 
+*/ + +#pragma once + +#include +#include + +#ifdef BF16_AVAILABLE +#include +#endif + +#define DS_HD_INLINE __host__ __device__ __forceinline__ +#define DS_D_INLINE __device__ __forceinline__ + +#ifdef __HIP_PLATFORM_AMD__ + +// constexpr variant of warpSize for templating +constexpr int hw_warp_size = 64; +#define HALF_PRECISION_AVAILABLE = 1 +#include +#include + +#else // !__HIP_PLATFORM_AMD__ + +// constexpr variant of warpSize for templating +constexpr int hw_warp_size = 32; + +#if __CUDA_ARCH__ >= 530 +#define HALF_PRECISION_AVAILABLE = 1 +#define PTX_AVAILABLE +#endif // __CUDA_ARCH__ >= 530 + +#if __CUDA_ARCH__ >= 800 +#define ASYNC_COPY_AVAILABLE +#endif // __CUDA_ARCH__ >= 800 + +#include +#include + +#endif //__HIP_PLATFORM_AMD__ + +inline int next_pow2(const int val) +{ + int rounded_val = val - 1; + rounded_val |= rounded_val >> 1; + rounded_val |= rounded_val >> 2; + rounded_val |= rounded_val >> 4; + rounded_val |= rounded_val >> 8; + return rounded_val + 1; +} diff --git a/deepspeed/inference/v2/kernels/includes/memory_access_utils.h b/deepspeed/inference/v2/kernels/includes/memory_access_utils.h new file mode 100644 index 000000000000..6789714d27c7 --- /dev/null +++ b/deepspeed/inference/v2/kernels/includes/memory_access_utils.h @@ -0,0 +1,1115 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include "ds_kernel_utils.h" + +/////////////////////////////// Memory Access Utils /////////////////////////////// +namespace mem_access { + +enum class LoadPolicy { + CacheAll, // Cache at all levels + CacheGlobal, // Cache at L2 only + CacheStreaming // Cache with evict first policy +}; + +enum class StorePolicy { + Writeback, // Cache in L1, write-back on eviction + CacheGlobal, // Bypass L1, write-back on eviction + CacheStreaming // Allocate cache line with evict first policy +}; + +template +__device__ __forceinline__ void load_global(void* dst, const void* src); + +template +__device__ __forceinline__ void load_global(void* dst, const void* src, bool do_access); + +// Shared accesses have no cache policy +template +__device__ __forceinline__ void load_shared(void* dst, const void* src); + +template +__device__ __forceinline__ void load_shared(void* dst, const void* src, bool do_access); + +template +__device__ __forceinline__ void store_global(void* dst, const void* src); + +// Shared accesses have no cache policy +template +__device__ __forceinline__ void store_shared(void* dst, const void* src); + +#ifdef ASYNC_COPY_AVAILABLE +template +__device__ __forceinline__ void memcpy_async(void* shr, const void* gbl); + +template +__device__ __forceinline__ void memcpy_async_nop(void* shr, const void* gbl, bool predicate); + +template +__device__ __forceinline__ void memcpy_async_zero(void* shr, const void* gbl, bool predicate); + +__device__ __forceinline__ void memcpy_async_fence(); + +template +__device__ __forceinline__ void memcpy_async_wait(); + +template +__device__ __forceinline__ void tail_complete_wait(int remaining_stages); +#endif + +// Util for tracking pipeline buffers +// TODO: Evaluate whether this should also be guarded by ASYNC_COPY_AVAILABLE +template +class BufferTracker { +public: + int current_state; + + __device__ __forceinline__ BufferTracker() : current_state(0) {} + + __device__ __forceinline__ int get() + { + int return_val = current_state++; + current_state = (current_state == max ? 
0 : current_state); + return return_val; + } +}; + +__device__ __forceinline__ uint32_t lane_id() +{ +#ifdef PTX_AVAILABLE + unsigned int lane_id; + asm volatile("mov.u32 %0, %%laneid;" : "=r"(lane_id)); + return lane_id; +#else + return threadIdx.x & (warpSize - 1); // Portable +#endif +} + +/////////// Load Global /////////// +template <> +__device__ __forceinline__ void load_global<16>(void* dst, const void* src) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.ca.v4.u32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "l"(src)); +#else + const uint4* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<16>(void* dst, const void* src, bool do_access) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %5, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\tmov.b32 %2, 0;\n" + "\tmov.b32 %3, 0;\n" + "\t@p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "l"(src), "r"((int)do_access)); +#else + const uint4* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + data[0].z = 0; + data[0].w = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<16, LoadPolicy::CacheGlobal>(void* dst, const void* src) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "l"(src)); +#else + const uint4* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<16, LoadPolicy::CacheGlobal>(void* dst, + const void* src, + bool do_access) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %5, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\tmov.b32 %2, 0;\n" + "\tmov.b32 %3, 0;\n" + "\t@p ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "l"(src), "r"((int)do_access)); +#else + const uint4* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + data[0].z = 0; + data[0].w = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<16, LoadPolicy::CacheStreaming>(void* dst, + const void* src) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "l"(src)); +#else + const uint4* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<16, LoadPolicy::CacheStreaming>(void* dst, + const void* src, + bool do_access) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %5, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\tmov.b32 %2, 0;\n" + "\tmov.b32 %3, 0;\n" + "\t@p ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "l"(src), "r"((int)do_access)); 
+#else + const uint4* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + data[0].z = 0; + data[0].w = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<8>(void* dst, const void* src) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.ca.v2.u32 {%0, %1}, [%2];\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "l"(src)); +#else + const uint2* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<8>(void* dst, const void* src, bool do_access) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %3, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\t@p ld.global.v2.u32 {%0, %1}, [%2];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "l"(src), "r"((int)do_access)); +#else + const uint2* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<8, LoadPolicy::CacheGlobal>(void* dst, const void* src) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cg.v2.u32 {%0, %1}, [%2];\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "l"(src)); +#else + const uint2* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<8, LoadPolicy::CacheGlobal>(void* dst, + const void* src, + bool do_access) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %3, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\t@p ld.global.cg.v2.u32 {%0, %1}, [%2];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "l"(src), "r"((int)do_access)); +#else + const uint2* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<8, LoadPolicy::CacheStreaming>(void* dst, + const void* src) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cs.v2.u32 {%0, %1}, [%2];\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "l"(src)); +#else + const uint2* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<8, LoadPolicy::CacheStreaming>(void* dst, + const void* src, + bool do_access) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %3, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\t@p ld.global.cs.v2.u32 {%0, %1}, [%2];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "l"(src), "r"((int)do_access)); +#else + const uint2* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<4>(void* dst, const void* src) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.ca.u32 {%0}, [%1];\n" : "=r"(*data) : "l"(src)); +#else + const int32_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<4>(void* dst, const void* src, bool do_access) +{ + 
int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.b32 %0, 0;\n" + "\t@p ld.global.u32 {%0}, [%1];\n" + "}\n" + : "=r"(data[0]) + : "l"(src), "r"((int)do_access)); +#else + const int32_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<4, LoadPolicy::CacheGlobal>(void* dst, const void* src) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cg.u32 {%0}, [%1];\n" : "=r"(*data) : "l"(src)); +#else + const int32_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<4, LoadPolicy::CacheGlobal>(void* dst, + const void* src, + bool do_access) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.b32 %0, 0;\n" + "\t@p ld.global.cg.u32 {%0}, [%1];\n" + "}\n" + : "=r"(data[0]) + : "l"(src), "r"((int)do_access)); +#else + const int32_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<4, LoadPolicy::CacheStreaming>(void* dst, + const void* src) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cs.u32 {%0}, [%1];\n" : "=r"(*data) : "l"(src)); +#else + const int32_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<4, LoadPolicy::CacheStreaming>(void* dst, + const void* src, + bool do_access) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.b32 %0, 0;\n" + "\t@p ld.global.cs.u32 {%0}, [%1];\n" + "}\n" + : "=r"(data[0]) + : "l"(src), "r"((int)do_access)); +#else + const int32_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<2>(void* dst, const void* src) +{ + int16_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.ca.u16 {%0}, [%1];\n" : "=h"(*data) : "l"(src)); +#else + const int16_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<2>(void* dst, const void* src, bool do_access) +{ + int16_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.u16 %0, 0;\n" + "\t@p ld.global.u16 {%0}, [%1];\n" + "}\n" + : "=h"(*data) + : "l"(src), "r"((int)do_access)); +#else + const int16_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<2, LoadPolicy::CacheGlobal>(void* dst, const void* src) +{ + int16_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cg.u16 {%0}, [%1];\n" : "=h"(*data) : "l"(src)); +#else + const int16_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<2, LoadPolicy::CacheGlobal>(void* dst, + const void* src, + bool do_access) +{ + int16_t* data = 
reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.u16 %0, 0;\n" + "\t@p ld.global.cg.u16 {%0}, [%1];\n" + "}\n" + : "=h"(*data) + : "l"(src), "r"((int)do_access)); +#else + const int16_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_global<2, LoadPolicy::CacheStreaming>(void* dst, + const void* src) +{ + int16_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile("ld.global.cs.u16 {%0}, [%1];\n" : "=h"(*data) : "l"(src)); +#else + const int16_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_global<2, LoadPolicy::CacheStreaming>(void* dst, + const void* src, + bool do_access) +{ + int16_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.u16 %0, 0;\n" + "\t@p ld.global.cs.u16 {%0}, [%1];\n" + "}\n" + : "=h"(*data) + : "l"(src), "r"((int)do_access)); +#else + const int16_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +/////////// Load Shared /////////// +namespace internal { + +#ifdef PTX_AVAILABLE +__device__ __forceinline__ unsigned convert_to_shared(const void* ptr) +{ +#if __CUDACC_VER_MAJOR__ >= 11 + // In CUDA 11 we have a builtin intrinsic + return __cvta_generic_to_shared(ptr); +#else + unsigned ret_val; + asm volatile( + "{\n" + "\t.reg .u64 p1;\n" + "\tcvta.to.shared.u64 p1, %1\n" + "\tcvt.u32.u64 %0, p1;\n" + "}\n" + : "=r"(ret_val) + : "l"(ptr)); + return ret_val; +#endif +} +#endif + +} // namespace internal + +template <> +__device__ __forceinline__ void load_shared<16>(void* dst, const void* src) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + unsigned src_shr = internal::convert_to_shared(src); + + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "r"(src_shr)); +#else + const uint4* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_shared<16>(void* dst, const void* src, bool do_access) +{ + uint4* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + unsigned src_shr = internal::convert_to_shared(src); + + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %5, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\tmov.b32 %2, 0;\n" + "\tmov.b32 %3, 0;\n" + "\t@p ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w) + : "r"(src_shr), "r"((int)do_access)); +#else + const uint4* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + data[0].z = 0; + data[0].w = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_shared<8>(void* dst, const void* src) +{ + uint2* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + unsigned src_shr = internal::convert_to_shared(src); + + asm volatile("ld.shared.v2.u32 {%0, %1}, [%2];\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "r"(src_shr)); +#else + const uint2* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_shared<8>(void* dst, const void* src, bool do_access) +{ + uint2* 
data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + unsigned src_shr = internal::convert_to_shared(src); + + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %3, 0;\n" + "\tmov.b32 %0, 0;\n" + "\tmov.b32 %1, 0;\n" + "\t@p ld.shared.v2.u32 {%0, %1}, [%2];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y) + : "r"(src_shr), "r"((int)do_access)); +#else + const uint2* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0].x = 0; + data[0].y = 0; + } +#endif +} + +template <> +__device__ __forceinline__ void load_shared<4>(void* dst, const void* src) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + unsigned src_shr = internal::convert_to_shared(src); + + asm volatile("ld.shared.u32 {%0}, [%1];\n" : "=r"(*data) : "r"(src_shr)); +#else + const int32_t* src_cast = reinterpret_cast(src); + data[0] = src_cast[0]; +#endif +} + +template <> +__device__ __forceinline__ void load_shared<4>(void* dst, const void* src, bool do_access) +{ + int32_t* data = reinterpret_cast(dst); +#ifdef PTX_AVAILABLE + unsigned src_shr = internal::convert_to_shared(src); + + asm volatile( + "{\n" + "\t.reg .pred p;\n" + "\tsetp.ne.b32 p, %2, 0;\n" + "\tmov.b32 %0, 0;\n" + "\t@p ld.shared.u32 %0, [%1];\n" + "}\n" + : "=r"(data[0]) + : "r"(src_shr), "r"((int)do_access)); +#else + const int32_t* src_cast = reinterpret_cast(src); + if (do_access) { + data[0] = src_cast[0]; + } else { + data[0] = 0; + } +#endif +} + +/////////// Store Global /////////// + +template <> +__device__ __forceinline__ void store_global<16>(void* dst, const void* src) +{ + const uint4* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.wb.v4.u32 [%0], {%1, %2, %3, %4};\n" + : + : "l"(dst), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w) + : "memory"); +#else + uint4* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<16, StorePolicy::CacheGlobal>(void* dst, + const void* src) +{ + const uint4* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.cg.v4.u32 [%0], {%1, %2, %3, %4};\n" + : + : "l"(dst), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w) + : "memory"); +#else + uint4* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<16, StorePolicy::CacheStreaming>(void* dst, + const void* src) +{ + const uint4* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.cs.v4.u32 [%0], {%1, %2, %3, %4};\n" + : + : "l"(dst), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w) + : "memory"); +#else + uint4* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<8>(void* dst, const void* src) +{ + const uint2* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.wb.v2.u32 [%0], {%1, %2};\n" + : + : "l"(dst), "r"(data[0].x), "r"(data[0].y)); +#else + uint2* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<8, StorePolicy::CacheGlobal>(void* dst, + const void* src) +{ + const uint2* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.cg.v2.u32 [%0], {%1, %2};\n" + : + : "l"(dst), "r"(data[0].x), "r"(data[0].y)); +#else + uint2* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> 
+__device__ __forceinline__ void store_global<8, StorePolicy::CacheStreaming>(void* dst, + const void* src) +{ + const uint2* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.cs.v2.u32 [%0], {%1, %2};\n" + : + : "l"(dst), "r"(data[0].x), "r"(data[0].y)); +#else + uint2* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<4>(void* dst, const void* src) +{ + const int32_t* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.wb.u32 [%0], %1;\n" : : "l"(dst), "r"(*data)); +#else + int32_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<4, StorePolicy::CacheGlobal>(void* dst, + const void* src) +{ + const int32_t* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.cg.u32 [%0], %1;\n" : : "l"(dst), "r"(*data)); +#else + int32_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_global<4, StorePolicy::CacheStreaming>(void* dst, + const void* src) +{ + const int32_t* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + asm volatile("st.global.cs.u32 [%0], %1;\n" : : "l"(dst), "r"(*data)); +#else + int32_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +/////////// Store Shared /////////// + +template <> +__device__ __forceinline__ void store_shared<16>(void* dst, const void* src) +{ + const uint4* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + unsigned dst_int = internal::convert_to_shared(dst); + + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n" + : + : "r"(dst_int), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w)); +#else + uint4* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_shared<8>(void* dst, const void* src) +{ + const uint2* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + unsigned dst_int = internal::convert_to_shared(dst); + + asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" + : + : "r"(dst_int), "r"(data[0].x), "r"(data[0].y)); +#else + uint2* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +template <> +__device__ __forceinline__ void store_shared<4>(void* dst, const void* src) +{ + const int32_t* data = reinterpret_cast(src); +#ifdef PTX_AVAILABLE + unsigned dst_int = internal::convert_to_shared(dst); + + asm volatile("st.shared.u32 [%0], %1;\n" : : "r"(dst_int), "r"(*data)); +#else + int32_t* dst_cast = reinterpret_cast(dst); + dst_cast[0] = data[0]; +#endif +} + +/////////// Asynchronous Memory Copy /////////// + +#ifdef ASYNC_COPY_AVAILABLE +template +__device__ __forceinline__ void memcpy_async(void* shr, const void* gbl) +{ + static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16)); + unsigned shr_int = internal::convert_to_shared(shr); + + asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" + : + : "r"(shr_int), "l"(gbl), "n"(AccessSize)); +} + +template +__device__ __forceinline__ void memcpy_async_nop(void* shr, const void* gbl, bool predicate) +{ + static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16)); + unsigned shr_int = internal::convert_to_shared(shr); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" + : + : "r"((int)predicate), "r"(shr_int), "l"(gbl), "n"(AccessSize)); 
+} + +template +__device__ __forceinline__ void memcpy_async_zero(void* shr, const void* gbl, bool predicate) +{ + static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16)); + unsigned shr_int = internal::convert_to_shared(shr); + int bytes_to_copy = (predicate ? AccessSize : 0); + + asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n" + : + : "r"(shr_int), "l"(gbl), "n"(AccessSize), "r"(bytes_to_copy)); +} + +template +__device__ __forceinline__ void memcpy_async_zero_nop(void* shr, + const void* gbl, + bool zero_predicate, + bool nop_predicate) +{ + static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16)); + unsigned shr_int = internal::convert_to_shared(shr); + int bytes_to_copy = (zero_predicate ? AccessSize : 0); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3, %4;\n" + "}\n" + : + : "r"((int)nop_predicate), "r"(shr_int), "l"(gbl), "n"(AccessSize), "r"(bytes_to_copy)); +} + +// Cache global variants. Separate interface to require deliberate use of them. +__device__ __forceinline__ void memcpy_async_cg(void* shr, const void* gbl) +{ + unsigned shr_int = internal::convert_to_shared(shr); + + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" : : "r"(shr_int), "l"(gbl)); +} + +__device__ __forceinline__ void memcpy_async_nop_cg(void* shr, const void* gbl, bool predicate) +{ + unsigned shr_int = internal::convert_to_shared(shr); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], 16;\n" + "}\n" + : + : "r"((int)predicate), "r"(shr_int), "l"(gbl)); +} + +__device__ __forceinline__ void memcpy_async_zero_cg(void* shr, const void* gbl, bool predicate) +{ + unsigned shr_int = internal::convert_to_shared(shr); + int bytes_to_copy = (predicate ? 16 : 0); + + asm volatile("cp.async.cg.shared.global [%0], [%1], 16, %2;\n" + : + : "r"(shr_int), "l"(gbl), "r"(bytes_to_copy)); +} + +__device__ __forceinline__ void memcpy_async_zero_nop_cg(void* shr, + const void* gbl, + bool zero_predicate, + bool nop_predicate) +{ + unsigned shr_int = internal::convert_to_shared(shr); + int bytes_to_copy = (zero_predicate ? 16 : 0); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], 16, %3;\n" + "}\n" + : + : "r"((int)nop_predicate), "r"(shr_int), "l"(gbl), "r"(bytes_to_copy)); +} + +__device__ __forceinline__ void memcpy_async_fence() { asm volatile("cp.async.commit_group;\n"); } + +template +__device__ __forceinline__ void memcpy_async_wait() +{ + static_assert(stages <= 8); + + asm volatile("cp.async.wait_group %0;\n" : : "n"(stages)); +} + +// TODO: The tail complete should be a known compile time artifact, should try and induce this +// without all of the branches from the call-site. This is a hacky solution. 
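/*
Illustrative sketch only (assumes ASYNC_COPY_AVAILABLE): one way the cp.async wrappers above
might compose into a double-buffered global->shared prefetch loop. `THREADS`, `num_iters`, and
`src` are hypothetical names, and the final-iteration branch is the same drain logic that the
tail_complete_wait helpers below generalize for deeper pipelines.
``` cpp
constexpr int kStages = 2;
__shared__ float4 stage[kStages][THREADS];  // one 16B element per thread per stage
// src: assumed to be a const float4* into global memory, one tile of THREADS elements per iter

int write_stage = 0;
mem_access::memcpy_async<16>(&stage[write_stage][threadIdx.x], src + threadIdx.x);
mem_access::memcpy_async_fence();

for (int iter = 0; iter < num_iters; iter++) {
    const int read_stage = write_stage;
    write_stage ^= 1;
    if (iter + 1 < num_iters) {
        // Prefetch the next tile while the current one is consumed.
        mem_access::memcpy_async<16>(&stage[write_stage][threadIdx.x],
                                     src + (iter + 1) * THREADS + threadIdx.x);
        mem_access::memcpy_async_fence();
        mem_access::memcpy_async_wait<1>();  // current tile done, next may stay in flight
    } else {
        mem_access::memcpy_async_wait<0>();  // drain: the last tile must be complete
    }
    __syncthreads();
    // ... consume stage[read_stage] ...
}
```
*/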
+template <> +__device__ __forceinline__ void tail_complete_wait<1>(int remaining_stages) +{ + if (remaining_stages == 0) memcpy_async_wait<0>(); +} + +template <> +__device__ __forceinline__ void tail_complete_wait<2>(int remaining_stages) +{ + if (remaining_stages == 1) + memcpy_async_wait<1>(); + else if (remaining_stages == 0) + memcpy_async_wait<0>(); +} + +template <> +__device__ __forceinline__ void tail_complete_wait<3>(int remaining_stages) +{ + if (remaining_stages == 2) + memcpy_async_wait<2>(); + else if (remaining_stages == 1) + memcpy_async_wait<1>(); + else if (remaining_stages == 0) + memcpy_async_wait<0>(); +} + +template <> +__device__ __forceinline__ void tail_complete_wait<4>(int remaining_stages) +{ + if (remaining_stages == 3) + memcpy_async_wait<3>(); + else if (remaining_stages == 2) + memcpy_async_wait<2>(); + else if (remaining_stages == 1) + memcpy_async_wait<1>(); + else if (remaining_stages == 0) + memcpy_async_wait<0>(); +} + +template <> +__device__ __forceinline__ void tail_complete_wait<5>(int remaining_stages) +{ + if (remaining_stages == 4) + memcpy_async_wait<4>(); + else if (remaining_stages == 3) + memcpy_async_wait<3>(); + else if (remaining_stages == 2) + memcpy_async_wait<2>(); + else if (remaining_stages == 1) + memcpy_async_wait<1>(); + else if (remaining_stages == 0) + memcpy_async_wait<0>(); +} + +template <> +__device__ __forceinline__ void tail_complete_wait<6>(int remaining_stages) +{ + if (remaining_stages == 5) + memcpy_async_wait<5>(); + else if (remaining_stages == 4) + memcpy_async_wait<4>(); + else if (remaining_stages == 3) + memcpy_async_wait<3>(); + else if (remaining_stages == 2) + memcpy_async_wait<2>(); + else if (remaining_stages == 1) + memcpy_async_wait<1>(); + else if (remaining_stages == 0) + memcpy_async_wait<0>(); +} +#endif + +} // namespace mem_access diff --git a/deepspeed/inference/v2/kernels/includes/reduction_utils.h b/deepspeed/inference/v2/kernels/includes/reduction_utils.h new file mode 100644 index 000000000000..eb8efab77ac1 --- /dev/null +++ b/deepspeed/inference/v2/kernels/includes/reduction_utils.h @@ -0,0 +1,778 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "memory_access_utils.h" + +namespace cg = cooperative_groups; + +namespace reduce { + +enum class ROpType { + // Addition + Add, + + // Maximum reduction + Max, + + // Minimum reduction + Min, +}; + +constexpr int max_threads = 1024; +constexpr int max_warps = max_threads / hw_warp_size; + +/* +High level API. The API takes in a set of operations and variables +and performs that reduction operation on that variable. The reductions +of each of the arguments are completely independent of each other ( +i.e., the val1-op1 combination has no impact on val2-op2). + +Example usage: +``` cpp +float max_val; +float min_val; +reduce::block(tb, warp, max_val, min_val); +``` + +TODO(cmikeh2): In theory, we might be able to do this sequentially with +device functions and rely on the assembler correctly behaving. My initial +instinct is this won't work, but if it does it would reduce implementation +cost significantly. + +TODO(cmikeh2): We need to support sub-block reductions. The warp intrinsic +currently supports this (more incidentally than anything else). 
It is not +uncommon in something like softmax or a fused attention kernel to map multiple +reductions to a thread block, but each reduction itself is only scoped +to part of the threads (i.e block size = 512, 128 threads per reduction). +*/ +template +DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile& warp, float& val); + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2); + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3); + +template +DS_D_INLINE void block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3, + float& val4); + +/* +The partitioned block is a special case of the above where in the warps of a threadblock are +partitioned into separate independent reductions. For example, I might have an 8 warp thread block +in which each pair of warps is processing an independent piece of data. I would then reduce that +data with the something like the following: +``` cpp +float max_val; +reduce::partitioned_block(tb, warp, max_val); +``` +After which, each pair of warps would have coherent data with each other. Note, this API will not +provide correct results if the number of warps per partition is not a power of 2. +*/ +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val); + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2); + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3); + +template +DS_D_INLINE void partitioned_block(cg::thread_block& tb, + cg::thread_block_tile& warp, + float& val1, + float& val2, + float& val3, + float& val4); + +/* +Single element reduction primitives. Used inside serial collection +loops. + +Example usage: +using rop = reduce::OpType; +float min = init(); +for (int i = 0; i < 4; i++) { + min = reduce::element(min, data[i]); +} +*/ + +template +DS_D_INLINE T element(const T lhs, const T rhs); + +template +DS_D_INLINE T init(); + +/********************** Internal reduction APIs **********************/ + +/* +Single element "reductions". TODO(cmikeh2): this sort of "op" concept +should be refactored into its own implementation at some point. This interface +may be easily expanded for new types/operations, but the typical reductions +we need are covered with min/max/add on float. + +NOTE: there is no mean reduction because that relies on knowledge of how +many values were already reduced into each scalar. Implementing this on top +of reduce should be straightforward (can just wrap the sum reduction) and +would be a good extension of the header. 
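One possible wrapper is sketched below purely for illustration (`block_mean` is a hypothetical
name; it assumes the caller supplies the element count and that `block`'s template arguments are
the op followed by the warp bound, as declared above):
``` cpp
template <int total_warps = max_warps>
DS_D_INLINE void block_mean(cg::thread_block& tb,
                            cg::thread_block_tile<hw_warp_size>& warp,
                            float& val,
                            int num_values)
{
    block<ROpType::Add, total_warps>(tb, warp, val);
    val /= conversion::to<float>(num_values);
}
```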
+*/ + +DS_D_INLINE int _warp_rank() +{ + const int thread_rank = + threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y; + return thread_rank / hw_warp_size; +} + +/* Float element reduce implementations */ +template <> +DS_D_INLINE float element(const float lhs, const float rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE float element(const float lhs, const float rhs) +{ + return fmaxf(lhs, rhs); +} + +template <> +DS_D_INLINE float element(const float lhs, const float rhs) +{ + return fminf(lhs, rhs); +} + +/* __half element reduce implementation */ +template <> +DS_D_INLINE __half element(const __half lhs, const __half rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE __half element(const __half lhs, const __half rhs) +{ +#if __CUDA_ARCH__ >= 800 + // Intrinsic limited to Ampere + newer + return __hmax(lhs, rhs); +#else + return (lhs > rhs) ? lhs : rhs; +#endif +} + +template <> +DS_D_INLINE __half element(const __half lhs, const __half rhs) +{ +#if __CUDA_ARCH__ >= 800 + // Intrinsic limited to Ampere + newer + return __hmin(lhs, rhs); +#else + return (lhs < rhs) ? lhs : rhs; +#endif +} + +/* __half2 element reduce implementation */ +template <> +DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) +{ +#if __CUDA_ARCH__ >= 800 + return __hmax2(lhs, rhs); +#else + __half2 ret_val; + ret_val.x = (lhs.x > rhs.x) ? lhs.x : rhs.x; + ret_val.y = (lhs.y > rhs.y) ? lhs.y : rhs.y; + return ret_val; +#endif +} + +template <> +DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) +{ +#if __CUDA_ARCH__ >= 800 + return __hmin2(lhs, rhs); +#else + __half2 ret_val; + ret_val.x = (lhs.x < rhs.x) ? lhs.x : rhs.x; + ret_val.y = (lhs.y < rhs.y) ? lhs.y : rhs.y; + return ret_val; +#endif +} + +template <> +DS_D_INLINE int32_t element(const int32_t lhs, const int32_t rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE int32_t element(const int32_t lhs, const int32_t rhs) +{ + return (lhs > rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE int32_t element(const int32_t lhs, const int32_t rhs) +{ + return (lhs < rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE uint32_t element(const uint32_t lhs, const uint32_t rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE uint32_t element(const uint32_t lhs, const uint32_t rhs) +{ + return (lhs > rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE uint32_t element(const uint32_t lhs, const uint32_t rhs) +{ + return (lhs < rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE int64_t element(const int64_t lhs, const int64_t rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE int64_t element(const int64_t lhs, const int64_t rhs) +{ + return (lhs > rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE int64_t element(const int64_t lhs, const int64_t rhs) +{ + return (lhs < rhs) ? 
lhs : rhs; +} + +/* +Reduction initialization primitives +*/ +template <> +DS_D_INLINE float init() +{ + return 0.0f; +} + +template <> +DS_D_INLINE float init() +{ + // Positive infinity + return INFINITY; +} + +template <> +DS_D_INLINE float init() +{ + // Negative infinity + return -INFINITY; +} + +template <> +DS_D_INLINE __half init() +{ + constexpr __half_raw zero = {0x0000}; + return __half(zero); +} + +template <> +DS_D_INLINE __half init() +{ + constexpr __half_raw inf = {0x7C00}; + return __half(inf); +} + +template <> +DS_D_INLINE __half init() +{ + constexpr __half_raw neg_inf = {0xFC00}; + return __half(neg_inf); +} + +template <> +DS_D_INLINE __half2 init() +{ +#ifdef __HIP_PLATFORM_AMD__ + return __half2{_Float16_2{0x0000, 0x0000}}; +#else + constexpr __half2_raw zero = {0x0000, 0x0000}; + return __half2(zero); +#endif +} + +template <> +DS_D_INLINE __half2 init() +{ +#ifdef __HIP_PLATFORM_AMD__ + return __half2{_Float16_2{0x7C00, 0x7C00}}; +#else + constexpr __half2_raw inf = {0x7C00, 0x7C00}; + return __half2(inf); +#endif +} + +template <> +DS_D_INLINE __half2 init() +{ +#ifdef __HIP_PLATFORM_AMD__ + return __half2{_Float16_2{0xFC00, 0xFC00}}; +#else + constexpr __half2_raw neg_inf = {0xFC00, 0xFC00}; + return __half2(neg_inf); +#endif +} + +template <> +DS_D_INLINE int32_t init() +{ + return 0; +} + +template <> +DS_D_INLINE int32_t init() +{ + return 0x7FFFFFFF; +} + +template <> +DS_D_INLINE int32_t init() +{ + return 0x80000000; +} + +template <> +DS_D_INLINE uint32_t init() +{ + return 0; +} + +template <> +DS_D_INLINE uint32_t init() +{ + return 0xFFFFFFFF; +} + +template <> +DS_D_INLINE uint32_t init() +{ + return 0; +} + +template <> +DS_D_INLINE int64_t init() +{ + return 0; +} + +template <> +DS_D_INLINE int64_t init() +{ + return 0x7FFFFFFFFFFFFFFF; +} + +template <> +DS_D_INLINE int64_t init() +{ + return 0x8000000000000000; +} + +template <> +DS_D_INLINE uint64_t init() +{ + return 0; +} + +template <> +DS_D_INLINE uint64_t init() +{ + return 0xFFFFFFFFFFFFFFFF; +} + +template <> +DS_D_INLINE uint64_t init() +{ + return 0; +} + +template +DS_D_INLINE void init(T* data) +{ + data[0] = init(); +} + +template +DS_D_INLINE void init(T* data) +{ + data[0] = init(); + data[1] = init(); +} + +template +DS_D_INLINE void init(T* data) +{ + data[0] = init(); + data[1] = init(); + data[2] = init(); +} + +template +DS_D_INLINE void init(T* data) +{ + data[0] = init(); + data[1] = init(); + data[2] = init(); + data[3] = init(); +} + +/* +Warp reduction primitives + +`reduction_width` is an unsafe template parameter, that is that +when using `reduction_width` < hw_warp_size the warp is partitioned +into `hw_warp_size` / `reduction_width` groups of partial sums. + +If someone can figure out how to use variadic templates in a reasonable way +here (fold is C++17 only and I don't think helps and recursion feels like +huge overkill that harms readability) that would be wonderful. 
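As a concrete illustration (assuming the first two template arguments select the op and the
reduction width, as in the definitions below), a width-16 add on 32-wide hardware leaves two
independent partial sums per warp, one in lanes 0-15 and one in lanes 16-31:
``` cpp
float v = thread_value;             // hypothetical per-thread input
_warp<ROpType::Add, 16>(warp, &v);  // xor masks 1, 2, 4, 8 never cross a 16-lane boundary
```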
+*/
+
+template <ROpType Op, int reduce_width = hw_warp_size, typename T>
+DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
+{
+#pragma unroll
+    for (int i = 1; i < reduce_width; i *= 2) {
+        data[0] = element<Op>(data[0], warp.shfl_xor(data[0], i));
+    }
+}
+
+template <ROpType Op1, ROpType Op2, int reduce_width = hw_warp_size, typename T>
+DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
+{
+#pragma unroll
+    for (int i = 1; i < reduce_width; i *= 2) {
+        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
+        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
+    }
+}
+
+template <ROpType Op1, ROpType Op2, ROpType Op3, int reduce_width = hw_warp_size, typename T>
+DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
+{
+#pragma unroll
+    for (int i = 1; i < reduce_width; i *= 2) {
+        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
+        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
+        data[2] = element<Op3>(data[2], warp.shfl_xor(data[2], i));
+    }
+}
+
+template <ROpType Op1,
+          ROpType Op2,
+          ROpType Op3,
+          ROpType Op4,
+          int reduce_width = hw_warp_size,
+          typename T>
+DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
+{
+#pragma unroll
+    for (int i = 1; i < reduce_width; i *= 2) {
+        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
+        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
+        data[2] = element<Op3>(data[2], warp.shfl_xor(data[2], i));
+        data[3] = element<Op4>(data[3], warp.shfl_xor(data[3], i));
+    }
+}
+
+/*
+Implementation for primary block reduction that serves both `block` and
+`partitioned_block`.
+
+Total warps refers to the reduction width of the reduction, not
+the number of warps in the block (which may exceed that
+if the block is partitioned or if we do a conservative bound at
+compile time).
+*/
+template <int total_warps, ROpType... Ops, typename T>
+DS_D_INLINE void _block(cg::thread_block& tb,
+                        cg::thread_block_tile<hw_warp_size>& warp_arg,
+                        T* data)
+{
+    constexpr int elems = sizeof...(Ops);
+    constexpr int bytes = sizeof(T);
+    // Unused when `partition_size == 1` or total_warps == 1
+    __shared__ T reduce_buffer[max_warps * elems];
+
+#ifdef __HIP_PLATFORM_AMD__
+    const int total_threads = blockDim.x * blockDim.y * blockDim.z;
+    const int running_warps = total_threads / hw_warp_size;
+#else
+    const int running_warps = warp_arg.meta_group_size();
+#endif
+
+    // Always perform warp-scope reduction
+    _warp<Ops...>(warp_arg, data);
+
+    // If max_warps == 1 let's skip the runtime check
+    if (total_warps != 1) {
+        if (warp_arg.thread_rank() == 0) {
+#pragma unroll
+            for (int i = 0; i < elems; i++) {
+                mem_access::store_shared<bytes>(reduce_buffer + elems * _warp_rank() + i, data + i);
+            }
+        }
+
+        // Synchronization inside block-uniform conditional is safe
+        tb.sync();
+
+        if (_warp_rank() == 0) {
+            if (warp_arg.thread_rank() < running_warps) {
+#pragma unroll
+                for (int i = 0; i < elems; i++) {
+                    mem_access::load_shared<bytes>(
+                        data + i, reduce_buffer + elems * warp_arg.thread_rank() + i);
+                }
+            } else {
+                init<Ops...>(data);
+            }
+
+            _warp<Ops..., total_warps>(warp_arg, data);
+
+#pragma unroll
+            for (int i = 0; i < elems; i++) {
+                mem_access::store_shared<bytes>(reduce_buffer + elems * warp_arg.thread_rank() + i,
+                                                data + i);
+            }
+        }
+
+        // Synchronization inside block-uniform conditional is safe
+        tb.sync();
+
+#pragma unroll
+        for (int i = 0; i < elems; i++) {
+            mem_access::load_shared<bytes>(data + i, reduce_buffer + _warp_rank() * elems + i);
+        }
+    }
+}
+
+/*
+Main API implementations. For the most part, they just convert the individual
+variables into arrays, which makes working with them easier with a single
+implementation. In theory, we could use the `_block` implementation as another
+option, but the nature of using a pointer is a little less safe and this allows
+us to obfuscate the details of the partitioned implementation.
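+
+For illustration only, a typical call site (the local variable names below are just an
+example, not part of this header) would look like:
+
+    cg::thread_block tb = cg::this_thread_block();
+    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
+    float thread_max = ...;  // one value per thread
+    reduce::block<reduce::ROpType::Max>(tb, warp, thread_max);
+    // every thread in the block now observes the block-wide maximum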
+*/
+template <ROpType Op, int warp_bound = max_warps>
+DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float& val)
+{
+    _block<warp_bound, Op>(tb, warp, &val);
+}
+
+template <ROpType Op1, ROpType Op2, int warp_bound = max_warps>
+DS_D_INLINE void block(cg::thread_block& tb,
+                       cg::thread_block_tile<hw_warp_size>& warp,
+                       float& val1,
+                       float& val2)
+{
+    float data[2] = {val1, val2};
+    _block<warp_bound, Op1, Op2>(tb, warp, data);
+    val1 = data[0];
+    val2 = data[1];
+}
+
+template <ROpType Op1, ROpType Op2, ROpType Op3, int warp_bound = max_warps>
+DS_D_INLINE void block(cg::thread_block& tb,
+                       cg::thread_block_tile<hw_warp_size>& warp,
+                       float& val1,
+                       float& val2,
+                       float& val3)
+{
+    float data[3] = {val1, val2, val3};
+    _block<warp_bound, Op1, Op2, Op3>(tb, warp, data);
+    val1 = data[0];
+    val2 = data[1];
+    val3 = data[2];
+}
+
+template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int warp_bound = max_warps>
+DS_D_INLINE void block(cg::thread_block& tb,
+                       cg::thread_block_tile<hw_warp_size>& warp,
+                       float& val1,
+                       float& val2,
+                       float& val3,
+                       float& val4)
+{
+    float data[4] = {val1, val2, val3, val4};
+    _block<warp_bound, Op1, Op2, Op3, Op4>(tb, warp, data);
+    val1 = data[0];
+    val2 = data[1];
+    val3 = data[2];
+    val4 = data[3];
+}
+
+/*
+Note: for the partitioned blocks, the implementation does not support non-power of 2 blocks in order
+to shorten block scale reduction length.
+*/
+template <ROpType Op, int num_threads>
+DS_D_INLINE void partitioned_block(cg::thread_block& tb,
+                                   cg::thread_block_tile<hw_warp_size>& warp,
+                                   float& val)
+{
+    if (num_threads <= hw_warp_size) {
+        _warp<Op, num_threads>(warp, &val);
+    } else {
+        constexpr int num_warps = num_threads / hw_warp_size;
+        _block<num_warps, Op>(tb, warp, &val);
+    }
+}
+
+template <ROpType Op1, ROpType Op2, int num_threads>
+DS_D_INLINE void partitioned_block(cg::thread_block& tb,
+                                   cg::thread_block_tile<hw_warp_size>& warp,
+                                   float& val1,
+                                   float& val2)
+{
+    float data[2] = {val1, val2};
+
+    if (num_threads <= hw_warp_size) {
+        _warp<Op1, Op2, num_threads>(warp, data);
+    } else {
+        constexpr int num_warps = num_threads / hw_warp_size;
+        _block<num_warps, Op1, Op2>(tb, warp, data);
+    }
+
+    val1 = data[0];
+    val2 = data[1];
+}
+
+template <ROpType Op1, ROpType Op2, ROpType Op3, int num_threads>
+DS_D_INLINE void partitioned_block(cg::thread_block& tb,
+                                   cg::thread_block_tile<hw_warp_size>& warp,
+                                   float& val1,
+                                   float& val2,
+                                   float& val3)
+{
+    float data[3] = {val1, val2, val3};
+
+    if (num_threads <= hw_warp_size) {
+        _warp<Op1, Op2, Op3, num_threads>(warp, data);
+    } else {
+        constexpr int num_warps = num_threads / hw_warp_size;
+        _block<num_warps, Op1, Op2, Op3>(tb, warp, data);
+    }
+
+    val1 = data[0];
+    val2 = data[1];
+    val3 = data[2];
+}
+
+template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int num_threads>
+DS_D_INLINE void partitioned_block(cg::thread_block& tb,
+                                   cg::thread_block_tile<hw_warp_size>& warp,
+                                   float& val1,
+                                   float& val2,
+                                   float& val3,
+                                   float& val4)
+{
+    float data[4] = {val1, val2, val3, val4};
+
+    if (num_threads <= hw_warp_size) {
+        _warp<Op1, Op2, Op3, Op4, num_threads>(warp, data);
+    } else {
+        constexpr int num_warps = num_threads / hw_warp_size;
+        _block<num_warps, Op1, Op2, Op3, Op4>(tb, warp, data);
+    }
+
+    val1 = data[0];
+    val2 = data[1];
+    val3 = data[2];
+    val4 = data[3];
+}
+
+/*
+Arg-reduce is a specialization of the above. We only support this with a single reduction
+parameter. This only works for max/min reductions.
+*/
+
+__align__(8) struct IdxReduceResult {
+    /*
+    NOTE: ORDERING MATTERS HERE! The idx is the least significant set of bits
+    and the val is the most significant. Changing the order of this declaration
+    will break the code.
+    */
+    int idx;
+    float val;
+};
+
+template <ROpType Op, int warp_bound>
+DS_D_INLINE IdxReduceResult
+idx_reduce(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float val, int idx)
+{
+    IdxReduceResult res = {idx, val};
+
+    // Clear out the nan. This shouldn't be an issue for our initial applications
+    if (isnan(val)) res.val = init<Op>();
+
+    // Can do float compares as integers. By packing the index into the lower bits
+    // we can just do a single int64 rather than a branch, compare, and select.
+    // One side benefit of this is that it is by nature a stable algorithm and
+    // will always bias ties to the higher index.
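+    // Worked example of the packing (values chosen purely for illustration): val = 2.0f
+    // has bit pattern 0x40000000 and lands in the upper 32 bits, so {idx = 7, val = 2.0f}
+    // packs to 0x4000000000000007. For non-negative values a larger float therefore always
+    // wins the int64 comparison, and exactly equal floats fall back to comparing the index
+    // bits, which is what biases ties toward the higher index.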
+    int64_t* res_as_int = reinterpret_cast<int64_t*>(&res);
+
+    // The way floating point compare works is normally to perform a sign comparison
+    // and if they match, then do a comparison of the rest of the bits as unsigned
+    // integers. Since we are bundling these, that means for negative values we need
+    // to reverse the sort order, which we can do with an XOR.
+    if (val < 0) { *res_as_int ^= 0x7fffffff00000000; }
+
+    _block<warp_bound, Op>(tb, warp, res_as_int);
+
+    // Sign bit is preserved, so we can check if we need to invert the mantissa back
+    if (res.val < 0) { *res_as_int ^= 0x7fffffff00000000; }
+
+    return res;
+}
+
+}  // namespace reduce
diff --git a/op_builder/inference_core_ops.py b/op_builder/inference_core_ops.py
index b8ea54cd0b3f..229b500bebda 100755
--- a/op_builder/inference_core_ops.py
+++ b/op_builder/inference_core_ops.py
@@ -83,10 +83,10 @@ def include_paths(self):
             'inference/v2/kernels/core_ops/cuda_layer_norm',
             'inference/v2/kernels/core_ops/cuda_rms_norm',
             'inference/v2/kernels/core_ops/gated_activations',
+            'inference/v2/kernels/includes',
         ]
 
         prefix = self.get_prefix()
         sources = [os.path.join(prefix, src) for src in sources]
-        sources.append('csrc/includes')
 
         return sources
diff --git a/op_builder/inference_cutlass_builder.py b/op_builder/inference_cutlass_builder.py
index 73400b2ce10e..51f7931d9435 100644
--- a/op_builder/inference_cutlass_builder.py
+++ b/op_builder/inference_cutlass_builder.py
@@ -81,6 +81,7 @@ def extra_ldflags(self):
 
     def include_paths(self):
         sources = [
+            'inference/v2/kernels/includes',
             'inference/v2/kernels/cutlass_ops/mixed_gemm',
             'inference/v2/kernels/cutlass_ops/moe_gemm',
             'inference/v2/kernels/cutlass_ops/shared_resources/',
@@ -88,5 +89,4 @@ def include_paths(self):
         prefix = self.get_prefix()
         sources = [os.path.join(prefix, src) for src in sources]
-        sources.append('csrc/includes')
 
         return sources
diff --git a/op_builder/ragged_ops.py b/op_builder/ragged_ops.py
index 505aaa9708cf..13d71b476b5a 100644
--- a/op_builder/ragged_ops.py
+++ b/op_builder/ragged_ops.py
@@ -96,6 +96,7 @@ def extra_ldflags(self):
 
     def include_paths(self):
         sources = [
+            'inference/v2/kernels/includes',
             'inference/v2/kernels/ragged_ops',
             'inference/v2/kernels/ragged_ops/atom_builder',
             'inference/v2/kernels/ragged_ops/blocked_flash',
@@ -110,7 +111,4 @@ def include_paths(self):
 
         prefix = self.get_prefix()
         sources = [os.path.join(prefix, src) for src in sources]
-
-        sources.append('csrc/includes')
-
         return sources
diff --git a/op_builder/ragged_utils.py b/op_builder/ragged_utils.py
index c1d38eef1981..89450e1fd30d 100755
--- a/op_builder/ragged_utils.py
+++ b/op_builder/ragged_utils.py
@@ -70,8 +70,8 @@ def extra_ldflags(self):
         return []
 
     def include_paths(self):
-        include_dir = "inference/v2/ragged/includes"
+        include_dirs = ['inference/v2/ragged/includes', 'inference/v2/kernels/includes']
         prefix = self.get_prefix()
-        include_dir = os.path.join(prefix, include_dir)
+        includes = [os.path.join(prefix, include_dir) for include_dir in include_dirs]
 
-        return ['csrc/includes', include_dir]
+        return includes
diff --git a/tests/unit/inference/v2/__init__.py b/tests/unit/inference/v2/__init__.py
new file mode 100644
index 000000000000..208299fb8c50
--- /dev/null
+++ b/tests/unit/inference/v2/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/inference_test_utils.py b/tests/unit/inference/v2/inference_test_utils.py new file mode 100644 index 000000000000..d63c51267e51 --- /dev/null +++ b/tests/unit/inference/v2/inference_test_utils.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Tuple + +import torch +from deepspeed.accelerator import get_accelerator + +TOLERANCES = None + + +def get_tolerances(): + global TOLERANCES + if TOLERANCES is None: + TOLERANCES = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)} + if get_accelerator().is_bf16_supported(): + # Note: BF16 tolerance is higher than FP16 because of the lower precision (7 (+1) bits vs + # 10 (+1) bits) + TOLERANCES[torch.bfloat16] = (4.8e-1, 3.2e-2) + return TOLERANCES + + +DTYPES = None + + +def get_dtypes(include_float=True): + global DTYPES + if DTYPES is None: + DTYPES = [torch.float16, torch.float32] if include_float else [torch.float16] + try: + if get_accelerator().is_bf16_supported(): + DTYPES.append(torch.bfloat16) + except (AssertionError, AttributeError): + pass + return DTYPES + + +def allclose(x, y, tolerances: Tuple[int, int] = None): + assert x.dtype == y.dtype + if tolerances is None: + rtol, atol = get_tolerances()[x.dtype] + else: + rtol, atol = tolerances + return torch.allclose(x, y, rtol=rtol, atol=atol) diff --git a/tests/unit/inference/v2/kernels/__init__.py b/tests/unit/inference/v2/kernels/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/v2/kernels/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/kernels/core_ops/__init__.py b/tests/unit/inference/v2/kernels/core_ops/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/v2/kernels/core_ops/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/kernels/core_ops/test_bias_activation.py b/tests/unit/inference/v2/kernels/core_ops/test_bias_activation.py new file mode 100644 index 000000000000..376188b92565 --- /dev/null +++ b/tests/unit/inference/v2/kernels/core_ops/test_bias_activation.py @@ -0,0 +1,101 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum +from deepspeed.inference.v2.kernels.core_ops import CUDABiasActivation +from ....v2.inference_test_utils import get_dtypes, allclose + + +def reference_bias_act_implementation(input: torch.Tensor, bias: Optional[torch.Tensor], + act_type: ActivationType) -> torch.Tensor: + bias_func_map = { + ActivationType.RELU: torch.nn.functional.relu, + ActivationType.GELU: torch.nn.functional.gelu, + ActivationType.SILU: torch.nn.functional.silu, + ActivationType.IDENTITY: lambda x: x, + } + + dtype = input.dtype + input_f = input.to(torch.float32) + if bias is not None: + bias_f = bias.to(torch.float32) + output_f = input_f + bias_f + else: + output_f = input_f + output_f = bias_func_map[act_type](output_f) + + return output_f.to(dtype) + + +def _bias_activation_test_helper(tokens: int, + channels: int, + act_fn: ActivationType, + dtype: DtypeEnum, + use_bias: bool = True) -> None: + """ + Fully parameterized testing entry point. + """ + # Input vals + input_tensor = torch.randn((tokens, channels), dtype=dtype.value, device=get_accelerator().current_device_name()) + if use_bias: + bias = torch.randn((channels), dtype=dtype.value, device=get_accelerator().current_device_name()) + else: + bias = None + + # Reference output + ref_output = reference_bias_act_implementation(input_tensor, bias, act_fn) + + bias_act = CUDABiasActivation(channels, dtype, act_fn) + + # New output + ds_tensor = input_tensor.clone() + bias_act(ds_tensor, bias) + + # Check + assert allclose(ds_tensor, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, channels", [(1, 4096), (37, 2048), (112, 14432), (1024, 6144)]) +@pytest.mark.parametrize("dtype", get_dtypes(include_float=False)) +def test_token_channels_permutations(tokens: int, channels: int, dtype: torch.dtype) -> None: + """ + Validate bias activation kernel with different token and channel permutations when using the RELU + activation function. + """ + act_fn = ActivationType.RELU + dtype = DtypeEnum(dtype) + _bias_activation_test_helper(tokens, channels, act_fn, dtype) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("act_fn", + [ActivationType.RELU, ActivationType.GELU, ActivationType.SILU, ActivationType.IDENTITY]) +def test_act_fns(act_fn: ActivationType) -> None: + """ + Validate bias activation kernel with different activation functions. + """ + tokens = 223 + channels = 4096 + dtype = DtypeEnum.fp16 + _bias_activation_test_helper(tokens, channels, act_fn, dtype) + + +@pytest.mark.inference_v2_ops +def test_no_bias() -> None: + """ + Validate bias activation kernel with no bias. + """ + tokens = 223 + channels = 4096 + dtype = DtypeEnum.fp16 + act_fn = ActivationType.IDENTITY + _bias_activation_test_helper(tokens, channels, act_fn, dtype, use_bias=False) diff --git a/tests/unit/inference/v2/kernels/core_ops/test_blas_linear.py b/tests/unit/inference/v2/kernels/core_ops/test_blas_linear.py new file mode 100644 index 000000000000..864db6204a16 --- /dev/null +++ b/tests/unit/inference/v2/kernels/core_ops/test_blas_linear.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.core_ops import BlasLibLinear +from ....v2.inference_test_utils import allclose + +# Note: only testing with FP16 and BF16 because we use TF32 on Ampere and we don't have a good +# set of tolerances. Since this is just on top of BLAS though, the test is more about +# making sure the stride/contiguity is correct and that's data type agnostic. + + +def reference_implementation(hidden_states, weights): + return hidden_states @ weights.t() + + +problem_shapes = [ + (1, 1, 1024, 1024), + (1, 1024, 1024, 1024), + (2, 1024, 1024, 1024), + (1, 128, 768, 3072), + (1, 128, 3072, 768), + (1, 1024, 8192, 8192), + (1, 733, 8192, 32768), + (1, 13, 32768, 8192), +] + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("fp_dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("problem_shape", problem_shapes) +def test_blas_linear(fp_dtype: torch.dtype, problem_shape: Tuple[int, int, int, int]): + batch, seq_len, in_features, out_features = problem_shape + hidden_states = torch.randn(batch, seq_len, in_features, dtype=fp_dtype, + device=get_accelerator().current_device()) * 0.1 + weights = torch.randn(out_features, in_features, dtype=fp_dtype, device=get_accelerator().current_device()) * 0.01 + ds_output = torch.empty(batch, seq_len, out_features, dtype=fp_dtype, device=get_accelerator().current_device()) + + ds_kernel = BlasLibLinear(fp_dtype) + + ds_output = ds_kernel(ds_output, hidden_states, weights) + ref_output = reference_implementation(hidden_states, weights) + + assert allclose(ds_output, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("fp_dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("problem_shape", problem_shapes) +def test_blas_linear_t(fp_dtype: torch.dtype, problem_shape: Tuple[int, int, int, int]): + batch, seq_len, in_features, out_features = problem_shape + hidden_states = torch.randn(batch, seq_len, in_features, dtype=fp_dtype, + device=get_accelerator().current_device()) * 0.1 + weights = torch.randn(out_features, in_features, dtype=fp_dtype, device=get_accelerator().current_device()) * 0.01 + ds_output = torch.empty(batch, seq_len, out_features, dtype=fp_dtype, device=get_accelerator().current_device()) + + ds_kernel = BlasLibLinear(fp_dtype) + + # Transpose the weights then revert to the format we expect. + weights = weights.t().contiguous() + weights = weights.t() + ds_output = ds_kernel(ds_output, hidden_states, weights) + + ref_output = reference_implementation(hidden_states, weights) + + assert allclose(ds_output, ref_output) diff --git a/tests/unit/inference/v2/kernels/core_ops/test_gated_activation.py b/tests/unit/inference/v2/kernels/core_ops/test_gated_activation.py new file mode 100644 index 000000000000..8cb95a6cdcba --- /dev/null +++ b/tests/unit/inference/v2/kernels/core_ops/test_gated_activation.py @@ -0,0 +1,133 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.core_ops import CUDAGatedActivation +from deepspeed.inference.v2.inference_utils import ActivationType +from ....v2.inference_test_utils import get_dtypes, allclose + + +def reference_geglu_implementation(input: torch.Tensor, + bias: Optional[torch.Tensor] = None, + act_fn: Optional[ActivationType] = ActivationType.GEGLU) -> torch.Tensor: + act_func_map = { + ActivationType.ReGLU: torch.nn.functional.relu, + ActivationType.GEGLU: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.SiGLU: torch.nn.functional.silu, + } + + dtype = input.dtype + input = input.to(torch.float32) + + if bias is not None: + bias = bias.to(torch.float32) + input = input + bias + + act_act = input[..., ::2] + act_linear = input[..., 1::2] + + act_act = act_func_map[act_fn](act_act) + + return (act_act * act_linear).to(dtype) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("shape", [(1372, 16384), (2, 743, 22016)]) +@pytest.mark.parametrize("dtype", get_dtypes()) +def test_dtypes(shape: Iterable[int], dtype: torch.dtype) -> None: + input_tensor = torch.randn(shape, dtype=dtype, device=get_accelerator().current_device_name()) + + # Reference output + ref_output = reference_geglu_implementation(input_tensor, act_fn=ActivationType.GEGLU) + + # Build kernel + geglu = CUDAGatedActivation(input_tensor.size(-1), input_tensor.dtype, ActivationType.GEGLU) + + # New output + output_shape = list(input_tensor.shape) + output_shape[-1] //= 2 + output_tensor = torch.empty(output_shape, dtype=input_tensor.dtype, device=get_accelerator().current_device_name()) + geglu(output_tensor, input_tensor) + + # Check + assert allclose(output_tensor, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("act_fn", [ActivationType.GEGLU, ActivationType.ReGLU, ActivationType.SiGLU]) +def test_act_fn(act_fn: ActivationType) -> None: + input_tensor = torch.randn(832, 4096, dtype=torch.float16, device=get_accelerator().current_device()) + + # Reference output + ref_output = reference_geglu_implementation(input_tensor, act_fn=act_fn) + + cuda_act = CUDAGatedActivation(4096, torch.float16, act_fn) + + # New output + output_tensor = torch.empty(832, 2048, dtype=torch.float16, device=get_accelerator().current_device()) + cuda_act(output_tensor, input_tensor) + + assert allclose(output_tensor, ref_output) + + +@pytest.mark.inference_v2_ops +def test_act_with_bias(): + input_tensor = torch.randn(832, 4096, dtype=torch.float16, device=get_accelerator().current_device()) + bias = torch.randn(4096, dtype=torch.float16, device=get_accelerator().current_device()) + + # Reference output + ref_output = reference_geglu_implementation(input_tensor, bias=bias, act_fn=ActivationType.GEGLU) + + cuda_act = CUDAGatedActivation(4096, torch.float16, ActivationType.GEGLU) + + # New output + output_tensor = torch.empty(832, 2048, dtype=torch.float16, device=get_accelerator().current_device()) + + cuda_act(output_tensor, input_tensor, bias) + + assert allclose(output_tensor, ref_output) + + +@pytest.mark.inference_v2_ops +def test_max_channels(): + input_tensor = torch.randn(832, 48152, dtype=torch.float16, device=get_accelerator().current_device()) + + ref_output = reference_geglu_implementation(input_tensor, act_fn=ActivationType.GEGLU) + + cuda_act = CUDAGatedActivation(48152, torch.float16, 
ActivationType.GEGLU) + + output_tensor = torch.empty(832, 24076, dtype=torch.float16, device=get_accelerator().current_device()) + cuda_act(output_tensor, input_tensor) + + assert allclose(output_tensor, ref_output) + + +@pytest.mark.inference_v2_ops +def test_bad_dtype() -> None: + with pytest.raises(ValueError): + CUDAGatedActivation(128, torch.int8, ActivationType.GEGLU) + + +@pytest.mark.inference_v2_ops +def test_bad_act_fn() -> None: + with pytest.raises(ValueError): + CUDAGatedActivation(128, torch.float16, ActivationType.RELU) + + +@pytest.mark.inference_v2_ops +def test_bad_alignment() -> None: + with pytest.raises(ValueError): + CUDAGatedActivation(127, torch.float16, ActivationType.GEGLU) + + +@pytest.mark.inference_v2_ops +def test_too_many_channels() -> None: + with pytest.raises(ValueError): + CUDAGatedActivation(49160, torch.float16, ActivationType.GEGLU) diff --git a/tests/unit/inference/v2/kernels/core_ops/test_post_ln.py b/tests/unit/inference/v2/kernels/core_ops/test_post_ln.py new file mode 100644 index 000000000000..0b489894bb9b --- /dev/null +++ b/tests/unit/inference/v2/kernels/core_ops/test_post_ln.py @@ -0,0 +1,47 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.core_ops import CUDAFPPostLN +from ....v2.inference_test_utils import get_dtypes, allclose + + +def reference_implementation(residual: torch.Tensor, hidden_states: torch.Tensor, gamma: torch.Tensor, + beta: torch.Tensor, epsilon: float) -> torch.Tensor: + residual_f = residual.to(torch.float32) + hidden_states_f = hidden_states.to(torch.float32) + gamma_f = gamma.to(torch.float32) + beta_f = beta.to(torch.float32) + return torch.nn.functional.layer_norm(residual_f + hidden_states_f, (hidden_states_f.size(-1), ), + weight=gamma_f, + bias=beta_f, + eps=epsilon).to(hidden_states.dtype) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, channels", [(1, 4096), (37, 2048), (112, 14432), (1024, 6144)]) +@pytest.mark.parametrize("dtype", get_dtypes()) +def test_cuda_post_ln(tokens: int, channels: int, dtype: torch.dtype) -> None: + + # Input vals + hidden_states = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + residual = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + gamma = torch.randn((channels), dtype=dtype, device=get_accelerator().current_device_name()) + beta = torch.rand((channels), dtype=dtype, device=get_accelerator().current_device_name()) + epsilon = 1e-5 + + # Reference output + ref_output = reference_implementation(residual, hidden_states, gamma, beta, epsilon) + + # New output + post_ln_kernel = CUDAFPPostLN(hidden_states.size(-1), residual.dtype) + ds_output = torch.empty_like(residual) + post_ln_kernel(ds_output, residual, hidden_states, gamma, beta) + + # Check + assert allclose(ds_output, ref_output) diff --git a/tests/unit/inference/v2/kernels/core_ops/test_pre_ln.py b/tests/unit/inference/v2/kernels/core_ops/test_pre_ln.py new file mode 100644 index 000000000000..ffb748e57af2 --- /dev/null +++ b/tests/unit/inference/v2/kernels/core_ops/test_pre_ln.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.core_ops import CUDAFPPreLN +from ....v2.inference_test_utils import get_dtypes, allclose + + +def reference_implementation(residual: torch.Tensor, hidden_states: torch.Tensor, gamma: torch.Tensor, + beta: torch.Tensor, epsilon: float) -> torch.Tensor: + residual_f = residual.to(torch.float32) + hidden_states_f = hidden_states.to(torch.float32) + gamma_f = gamma.to(torch.float32) + beta_f = beta.to(torch.float32) + residual_out = residual_f + hidden_states_f + hidden_out = torch.nn.functional.layer_norm(residual_out, (hidden_states_f.size(-1), ), + weight=gamma_f, + bias=beta_f, + eps=epsilon) + return residual_out.to(hidden_states.dtype), hidden_out.to(hidden_states.dtype) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, channels", [(1, 4096), (37, 2048), (112, 14432), (1024, 6144)]) +@pytest.mark.parametrize("dtype", get_dtypes()) +def test_cuda_pre_ln(tokens: int, channels: int, dtype: torch.dtype) -> None: + + # Input vals + hidden_states = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + residual = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + gamma = torch.randn((channels), dtype=dtype, device=get_accelerator().current_device_name()) + beta = torch.rand((channels), dtype=dtype, device=get_accelerator().current_device_name()) + epsilon = 1e-5 + + # Reference output + ref_output_res, ref_output_hid = reference_implementation(residual, hidden_states, gamma, beta, epsilon) + + # New output + pre_ln_kernel = CUDAFPPreLN(hidden_states.size(-1), residual.dtype) + ds_output_res = torch.empty_like(residual) + ds_output_hid = torch.empty_like(hidden_states) + pre_ln_kernel(ds_output_res, ds_output_hid, residual, hidden_states, gamma, beta) + + # Check + assert allclose(ds_output_res, ref_output_res) + assert allclose(ds_output_hid, ref_output_hid) diff --git a/tests/unit/inference/v2/kernels/core_ops/test_rms_norm.py b/tests/unit/inference/v2/kernels/core_ops/test_rms_norm.py new file mode 100644 index 000000000000..63b16da171c9 --- /dev/null +++ b/tests/unit/inference/v2/kernels/core_ops/test_rms_norm.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import DtypeEnum +from deepspeed.inference.v2.kernels.core_ops import CUDARMSNorm, CUDARMSPreNorm +from ....v2.inference_test_utils import get_dtypes, allclose + + +def reference_rms_norm(vals: torch.Tensor, gamma: torch.Tensor, epsilon: float = 1e-5) -> torch.Tensor: + variance = vals.to(torch.float32).pow(2).mean(-1, keepdim=True) + vals = vals * torch.rsqrt(variance + epsilon) + + if gamma.dtype in [torch.float16, torch.bfloat16]: + vals = vals.to(gamma.dtype) + + return gamma * vals + + +def reference_rms_pre_norm(vals: torch.Tensor, + residual: torch.Tensor, + gamma: torch.Tensor, + epsilon: float = 1e-5) -> torch.Tensor: + residual = residual + vals + return residual, reference_rms_norm(residual, gamma, epsilon) + + +def _rms_norm_testing_helper(rows: int, channels: int, do_residual: bool, dtype: DtypeEnum) -> None: + device = get_accelerator().current_device_name() + t_dtype = dtype.value + + vals = torch.randn((rows, channels), dtype=t_dtype, device=device) + gamma = torch.randn((channels), dtype=t_dtype, device=device) + epsilon = 1e-5 + + if do_residual: + residual_in = torch.randn((rows, channels), dtype=t_dtype, device=device) + ds_residual = residual_in.clone() + + ref_residual, ref_output = reference_rms_pre_norm(vals, residual_in, gamma, epsilon) + + kernel = CUDARMSPreNorm(channels, t_dtype, epsilon=epsilon) + ds_out = torch.empty_like(ds_residual) + + kernel(ds_residual, ds_out, residual_in, vals, gamma) + + assert allclose(ds_out, ref_output) + assert allclose(ds_residual, ref_residual) + else: + + ref_output = reference_rms_norm(vals, gamma, epsilon) + + kernel = CUDARMSNorm(channels, t_dtype, epsilon=epsilon) + ds_out = torch.empty_like(vals) + + kernel(ds_out, vals, gamma) + + assert allclose(ds_out, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("dtype", get_dtypes()) +@pytest.mark.parametrize("do_residual", [True, False]) +def test_rms_dtypes(dtype: DtypeEnum, do_residual: bool) -> None: + _rms_norm_testing_helper(883, 1024, do_residual, DtypeEnum(dtype)) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("rows, cols", [(1, 4096), (37, 2048), (112, 14432), (1024, 6144)]) +@pytest.mark.parametrize("do_residual", [True, False]) +def test_rms_shapes(rows: int, cols: int, do_residual: bool) -> None: + _rms_norm_testing_helper(rows, cols, do_residual, DtypeEnum.fp16) diff --git a/tests/unit/inference/v2/kernels/cutlass_ops/__init__.py b/tests/unit/inference/v2/kernels/cutlass_ops/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/v2/kernels/cutlass_ops/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/kernels/cutlass_ops/test_moe_gemm.py b/tests/unit/inference/v2/kernels/cutlass_ops/test_moe_gemm.py new file mode 100644 index 000000000000..ed76dabe1f4c --- /dev/null +++ b/tests/unit/inference/v2/kernels/cutlass_ops/test_moe_gemm.py @@ -0,0 +1,113 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum +from deepspeed.inference.v2.kernels.cutlass_ops import MoEGEMM +from ....v2.inference_test_utils import allclose + +SINGLE_EXPERT_CASES = [(13, 2048, 2048), (256, 1024, 4096), (278, 5120, 2048), (893, 5120, 2560)] + +PYTORCH_ACT_FN_MAP = { + ActivationType.GELU: torch.nn.functional.gelu, + ActivationType.SILU: torch.nn.functional.silu, + ActivationType.RELU: torch.nn.functional.relu +} + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, in_neurons, out_neurons", SINGLE_EXPERT_CASES) +def test_single_expert(n_tokens: int, in_neurons: int, out_neurons: int) -> None: + """ + Validate that the GEMM kernel produces identical results for a single GEMM instance. + """ + device = get_accelerator().current_device() + + activations = torch.rand((n_tokens, in_neurons), device=device, dtype=torch.float16) - 0.5 + weights = torch.rand((1, in_neurons, out_neurons), device=device, dtype=torch.float16) - 0.5 + biases = torch.randn((1, out_neurons), device=device, dtype=torch.float16) + + weights_ref = weights.reshape(in_neurons, out_neurons) + biases_ref = biases.reshape(out_neurons) + ref_output = torch.matmul(activations, weights_ref) + biases_ref + + moe_gemm = MoEGEMM(DtypeEnum.fp16, ActivationType.IDENTITY) + output = torch.empty((n_tokens, out_neurons), device=device, dtype=torch.float16) + cumsum_rows = torch.tensor([n_tokens], dtype=torch.int64, device=device) + + moe_gemm(output, activations, weights, cumsum_rows, biases) + assert allclose(output, ref_output, tolerances=(1e-2, 1e-2)) + get_accelerator().synchronize() + + +def moe_test_helper(in_neurons: int, out_neurons: int, n_experts: int, max_tokens_per_expert: int, + act_fn: ActivationType, dtype: DtypeEnum) -> None: + """ + Helper function for validating the GEMM kernel for a single expert. 
+ """ + device = get_accelerator().current_device() + + expert_allocations = torch.randint(0, max_tokens_per_expert, (n_experts, ), device=device, dtype=torch.int32) + cumsum_rows = expert_allocations.cumsum(dim=0) + print(cumsum_rows.dtype) + + activations = torch.rand((cumsum_rows[-1], in_neurons), device=device, dtype=dtype.value) - 0.5 + weights = torch.rand((n_experts, in_neurons, out_neurons), device=device, dtype=dtype.value) - 0.5 + biases = torch.randn((n_experts, out_neurons), device=device, dtype=dtype.value) + + out_ref = torch.empty((cumsum_rows[-1], out_neurons), device=device, dtype=dtype.value) + + for expert_idx in range(n_experts): + start = cumsum_rows[expert_idx - 1] if expert_idx > 0 else 0 + end = cumsum_rows[expert_idx] + activations_slice = activations[start:end] + weights_slice = weights[expert_idx] + biases_slice = biases[expert_idx] + out_ref[start:end] = torch.matmul(activations_slice, weights_slice) + biases_slice + + if act_fn != ActivationType.IDENTITY: + act_fn_fn = PYTORCH_ACT_FN_MAP[act_fn] + out_ref = act_fn_fn(out_ref) + + moe_gemm = MoEGEMM(DtypeEnum.fp16, act_fn) + output = torch.empty((cumsum_rows[-1], out_neurons), device=device, dtype=dtype.value) + + moe_gemm(output, activations, weights, cumsum_rows, biases) + + if dtype == DtypeEnum.bf16: + assert allclose(output, out_ref, tolerances=(1e-1, 1e-1)) + else: + assert allclose(output, out_ref, tolerances=(1e-2, 1e-2)) + get_accelerator().synchronize() + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("max_tokens_per_expert", [1, 4, 16, 64, 128]) +def test_multi_expert(max_tokens_per_expert: int) -> None: + """ + Validate for multi-expert GEMM instances that the output is identical to the reference. + """ + moe_test_helper(5120, 2048, 64, max_tokens_per_expert, ActivationType.IDENTITY, DtypeEnum.fp16) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("act_fn", [ActivationType.GELU, ActivationType.SILU, ActivationType.RELU]) +def test_act_fns(act_fn: ActivationType) -> None: + """ + Validate activation function behavior. + """ + moe_test_helper(5120, 2048, 64, 32, act_fn, DtypeEnum.fp16) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("dtype", [DtypeEnum.fp16, DtypeEnum.bf16]) +def test_dtypes(dtype: DtypeEnum) -> None: + """ + Validate data type behavior. + """ + moe_test_helper(5120, 2048, 64, 32, ActivationType.IDENTITY, dtype) diff --git a/tests/unit/inference/v2/kernels/ragged_ops/__init__.py b/tests/unit/inference/v2/kernels/ragged_ops/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/kernels/ragged_ops/ragged_testing_utils.py b/tests/unit/inference/v2/kernels/ragged_ops/ragged_testing_utils.py new file mode 100644 index 000000000000..b262951000bb --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/ragged_testing_utils.py @@ -0,0 +1,300 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import random +from typing import List, Optional, Tuple + +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.ragged import ( + AllocationMode, + DSSequenceDescriptor, + DSStateManager, + DSStateManagerConfig, + KVCacheConfig, + MemoryConfig, + PlaceholderSequenceDescriptor, + RaggedBatchWrapper, +) +from ....v2.inference_test_utils import allclose + + +def build_simple_batch(seq_lens: List[int], + vocab_range: Optional[int] = 100, + padding: Optional[bool] = False) -> RaggedBatchWrapper: + """ + Construct a simple batch with the given sequence lengths. This method should not + be used for for testing scenarios that require information about KV or sequence + history. + """ + total_tokens = max(sum(seq_lens), 1024) + n_seqs = max(len(seq_lens), 128) + + config = DSStateManagerConfig(max_tracked_sequences=n_seqs, + max_ragged_sequence_count=n_seqs, + max_ragged_batch_size=total_tokens) + batch = RaggedBatchWrapper(config) + + batch.clear() + + for seq_len in seq_lens: + seq_desc = PlaceholderSequenceDescriptor() + tokens = torch.randint(0, vocab_range, (seq_len, )) + batch.insert_sequence(seq_desc, tokens) + + batch.finalize(padding=padding) + + return batch + + +def build_complex_batch(seq_params: List[Tuple[int, int, int]], + kv_block_size: int, + vocab_range: Optional[int] = 100, + padding: Optional[bool] = False) -> Tuple[RaggedBatchWrapper, int]: + """ + Construct a fully paramtrized batch with the given sequence lengths. This method + can be used to construct more realistic inputs for testing scenarios that will interact + with all the members of the RaggedBatchWrapper. + """ + seq_lens = [seq_param[0] for seq_param in seq_params] + total_tokens = max(sum(seq_lens), 1024) + n_seqs = max(len(seq_lens), 128) + + config = DSStateManagerConfig(max_tracked_sequences=n_seqs, + max_ragged_sequence_count=n_seqs, + max_ragged_batch_size=total_tokens) + batch = RaggedBatchWrapper(config) + + batch.clear() + + total_kv_blocks = 0 + + for seq_len, n_seen_tokens, kv_ptr in seq_params: + n_kv_blocks = (seq_len + n_seen_tokens + kv_block_size - 1) // kv_block_size + seq_desc = PlaceholderSequenceDescriptor(seen_tokens=n_seen_tokens, + cur_allocated_blocks=n_kv_blocks, + kv_blocks_ptr=kv_ptr) + tokens = torch.randint(0, vocab_range, (seq_len, )) + batch.insert_sequence(seq_desc, tokens) + total_kv_blocks += n_kv_blocks + + batch.finalize(padding=padding) + + return batch, total_kv_blocks + + +def build_batch_and_manager( + seq_params: List[Tuple[int, int]], + head_size: int, + n_heads_kv: int, + kv_block_size: int, + vocab_range: Optional[int] = 100, + padding: Optional[bool] = False, + kv_fill: Optional[List[torch.Tensor]] = None +) -> Tuple[RaggedBatchWrapper, DSStateManager, List[DSSequenceDescriptor]]: + """ + Will construct and populate a batch and KVCache with the given sequence parameters. + + Arguments: + seq_params (List[Tuple[int, int]]): A list of tuples containing the sequence length and + the number of tokens that have already been seen for that sequence. + head_size (int): The size of each attention head. + n_heads_kv (int): The number of attention heads for the KV-cache. + kv_block_size (int): The size of each block in the KV-cache. + vocab_range (Optional[int]): The range of the vocabulary. Defaults to 100. + padding (Optional[bool]): Whether to pad the batch. Defaults to False. + kv_fill (Optional[List[torch.Tensor]]): A list of tensors to use to populate the KV-cache. 
+ If this is not provided, the KV-cache will be treated as empty and the contents should + not be relied upon. NOTE(cmikeh2): This functionality relies on the functionality + of LinearBlockedKVCopy. If tests relying on this feature are failing, make sure that + LinearBlockedKVCopy is working correctly. + """ + seq_lens = [seq_param[0] for seq_param in seq_params] + fill_lens = [seq_param[1] for seq_param in seq_params] + max_created_batch_len = max(sum(seq_lens), sum(fill_lens)) + total_tokens = max(max_created_batch_len, 1024) + n_seqs = max(len(seq_lens), 128) + + req_kv_blocks = [None] * n_seqs + total_kv_blocks = 0 + for i, (seq_len, n_seen_tokens) in enumerate(seq_params): + req_kv_blocks[i] = (seq_len + n_seen_tokens + kv_block_size - 1) // kv_block_size + total_kv_blocks += req_kv_blocks[i] + + kv_config = KVCacheConfig(block_size=kv_block_size, + num_allocation_groups=1, + cache_shape=(1, n_heads_kv, head_size)) + memory_config = MemoryConfig(mode=AllocationMode.ALLOCATE, size=total_kv_blocks) + + config = DSStateManagerConfig(max_tracked_sequences=n_seqs, + max_ragged_sequence_count=n_seqs, + max_ragged_batch_size=total_tokens, + memory_config=memory_config) + + batch = RaggedBatchWrapper(config) + state_manager = DSStateManager(config, kv_config) + + # At the beginning of operation, the design of the allocator is such that it will return + # linear blocks of memory. The following will "warm up" the allocator so that we can be + # more certain that code is not dependent on this behavior. + all_allocs = [] + for _ in range(20): + decision = random.randint(0, 1) + + if decision == 0: + blocks_to_allocate = random.randint(0, total_kv_blocks) + if blocks_to_allocate <= state_manager.free_blocks and blocks_to_allocate > 0: + all_allocs.append(state_manager.allocate_blocks(blocks_to_allocate)) + else: + if len(all_allocs) > 0: + idx = random.randint(0, len(all_allocs) - 1) + state_manager._kv_cache.free(all_allocs[idx]) + + del all_allocs[idx] + + for alloc in all_allocs: + state_manager._kv_cache.free(alloc) + + assert state_manager.free_blocks == total_kv_blocks + + batch.clear() + seq_descs = [] + + if kv_fill is None or sum(fill_lens) == 0: + for i, (seq_len, n_seen_tokens) in enumerate(seq_params): + # Create empty descriptor + seq_desc = state_manager.get_or_create_sequence(i) + + # Update `seen_tokens` in the descriptor + seq_desc.pre_forward(n_seen_tokens) + seq_desc.post_forward() + + # Ensure there's enough KV-cache for the sequence + kv_block_ids = state_manager.allocate_blocks(req_kv_blocks[i]) + print(f"Allocated {req_kv_blocks[i]} blocks for sequence {i}: {kv_block_ids}") + seq_desc.extend_kv_cache(kv_block_ids) + + # Insert sequence into batch + tokens = torch.randint(0, vocab_range, (seq_len, )) + batch.insert_sequence(seq_desc, tokens) + seq_desc.pre_forward(seq_len) + seq_descs.append(seq_desc) + else: + qkv = torch.empty((total_tokens, (n_heads_kv * 3) * head_size), + dtype=torch.float16, + device=get_accelerator().current_device()) + fills_as_tensor = torch.tensor(fill_lens, dtype=torch.int32) + fill_cumsum = torch.cat((torch.tensor([0], dtype=torch.int32), torch.cumsum(fills_as_tensor, dim=0))) + + for i, (_, n_seen_tokens) in enumerate(seq_params): + # Create empty descriptor + seq_desc = state_manager.get_or_create_sequence(i) + + # Update `seen_tokens` in the descriptor + if n_seen_tokens > 0: + dummy_fill_toks = torch.randint(0, vocab_range, (n_seen_tokens, )) + batch.insert_sequence(seq_desc, dummy_fill_toks) + seq_desc.pre_forward(n_seen_tokens) + + # Ensure 
there's enough KV-cache for the sequence + kv_block_ids = state_manager.allocate_blocks(req_kv_blocks[i]) + print(f"Allocated {req_kv_blocks[i]} blocks for sequence {i}: {kv_block_ids}") + seq_desc.extend_kv_cache(kv_block_ids) + seq_descs.append(seq_desc) + + if n_seen_tokens == 0: + continue + + assert kv_fill[i].shape[0] == n_seen_tokens + assert kv_fill[i].shape[1] == n_heads_kv * head_size * 2 + + local_q = torch.randn((n_seen_tokens, n_heads_kv * head_size), dtype=torch.float16, device=qkv.device) + local_qkv = torch.cat((local_q, kv_fill[i]), dim=1) + qkv[fill_cumsum[i]:fill_cumsum[i + 1]] = local_qkv + + batch.finalize(padding=padding) + + from deepspeed.inference.v2.kernels.ragged_ops import LinearBlockedKVCopy + kv_copy = LinearBlockedKVCopy(head_size, n_heads_kv, n_heads_kv, torch.float16) + kv_cache = state_manager.get_cache(0) + kv_copy(kv_cache, qkv, batch) + + for seq_desc in seq_descs: + if seq_desc.in_flight_tokens > 0: + seq_desc.post_forward() + + batch.clear() + + for i, (seq_len, _) in enumerate(seq_params): + seq_desc = state_manager.get_or_create_sequence(i) + tokens = torch.randint(0, vocab_range, (seq_len, )) + batch.insert_sequence(seq_desc, tokens) + seq_desc.pre_forward(seq_len) + + # We will skip KV cache allocation here because we did a lump allocation above + # for both the fill and the sequence itself. + + batch.finalize(padding=padding) + + return batch, state_manager, seq_descs + + +def validate_kv_cache(kv_cache: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_descs: List[DSSequenceDescriptor], + batch: RaggedBatchWrapper, + exact: bool = True) -> None: + """ + Given a QKV tensor and a KV cache, validate that the cache contains the correct values. + """ + block_size = kv_cache.shape[1] + n_kv_heads = kv_cache.shape[3] + head_size = kv_cache.shape[4] + + inflight_descs = batch.inflight_seq_descriptors(on_device=False)[:batch.current_sequences] + + if inflight_descs.shape[0] != len(seq_descs): + raise ValueError("The number of sequence descriptors does not match the number of sequences in the batch.") + + for seq_desc, inflight_seq in zip(seq_descs, inflight_descs): + start_idx = inflight_seq[0] + assigned_kv_blocks = seq_desc.kv_cache_ids(on_device=False) + + real_k_values = k[start_idx:start_idx + seq_desc.in_flight_tokens] + real_v_values = v[start_idx:start_idx + seq_desc.in_flight_tokens] + + start_block_idx = seq_desc.seen_tokens // block_size + local_start_idx = 0 + cur_start_idx = seq_desc.seen_tokens + + for block_idx in range(start_block_idx, seq_desc.cur_allocated_blocks): + block = kv_cache[assigned_kv_blocks[0, block_idx].item()] + block_start_idx = cur_start_idx % block_size + n_tokens_to_check = min(block_size - block_start_idx, seq_desc.in_flight_tokens - local_start_idx) + block_end_idx = block_start_idx + n_tokens_to_check + + if exact: + assert torch.equal( + block[block_start_idx:block_end_idx, 0, :, :], + real_k_values[local_start_idx:local_start_idx + n_tokens_to_check].reshape( + n_tokens_to_check, n_kv_heads, head_size)) + assert torch.equal( + block[block_start_idx:block_end_idx, 1, :, :], + real_v_values[local_start_idx:local_start_idx + n_tokens_to_check].reshape( + n_tokens_to_check, n_kv_heads, head_size)) + else: + assert allclose( + block[block_start_idx:block_end_idx, 0, :, :], + real_k_values[local_start_idx:local_start_idx + n_tokens_to_check].reshape( + n_tokens_to_check, n_kv_heads, head_size)) + assert allclose( + block[block_start_idx:block_end_idx, 1, :, :], + real_v_values[local_start_idx:local_start_idx + 
n_tokens_to_check].reshape( + n_tokens_to_check, n_kv_heads, head_size)) + + local_start_idx += n_tokens_to_check + cur_start_idx += n_tokens_to_check diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_atom_builder.py b/tests/unit/inference/v2/kernels/ragged_ops/test_atom_builder.py new file mode 100644 index 000000000000..a33c938a0608 --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_atom_builder.py @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.inference.v2.kernels.ragged_ops import AtomBuilder +from .ragged_testing_utils import build_complex_batch + +Q_BLOCK_SIZE = 128 +KV_BLOCK_SIZE = 128 + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('seq_params', [(1, 0, 0), (1, 228, 0), (383, 0, 0), (1, 494, 0)]) +def test_single_sequence(seq_params) -> None: + seq_len, n_seen_tokens, _ = seq_params + + batch, _ = build_complex_batch([seq_params], kv_block_size=KV_BLOCK_SIZE, padding=False) + atom_builder = AtomBuilder() + + atoms = torch.empty((8, 8), dtype=torch.int32, device=torch.device("cpu")) + atoms, n_atoms = atom_builder(atoms, batch, Q_BLOCK_SIZE, KV_BLOCK_SIZE) + + calc_n_atoms = (seq_len + 127) // 128 + + assert n_atoms == calc_n_atoms + + for i, atom in enumerate(atoms[:n_atoms]): + # Since the ptr was 0, first 2 elements should be 0 + assert atom[0] == 0 + assert atom[1] == 0 + + # Since we have a single sequence, the q_start_idx should always be + # whichever atom we're on multiplied by the block size + assert atom[2] == i * Q_BLOCK_SIZE + assert atom[3] == min(Q_BLOCK_SIZE, seq_len - i * Q_BLOCK_SIZE) + total_toks = i * Q_BLOCK_SIZE + min(Q_BLOCK_SIZE, seq_len - i * Q_BLOCK_SIZE) + + assert atom[4] == (total_toks + n_seen_tokens + KV_BLOCK_SIZE - 1) // KV_BLOCK_SIZE + assert atom[5] == (total_toks + n_seen_tokens) + + assert atom[6] == n_seen_tokens + i * Q_BLOCK_SIZE diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_flash.py b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_flash.py new file mode 100644 index 000000000000..ce5a178c9548 --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_flash.py @@ -0,0 +1,197 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import itertools + +from typing import List, Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import DtypeEnum +from deepspeed.inference.v2.kernels.ragged_ops import ( + AtomBuilder, + BlockedFlashAttn, + get_q_block_size, + get_kv_block_size, + LinearBlockedKVCopy, +) +from deepspeed.inference.v2.ragged import split_kv +from deepspeed.ops.op_builder import RaggedUtilsBuilder + +from .ragged_testing_utils import build_batch_and_manager +from ....v2.inference_test_utils import allclose + +try: + from flash_attn.flash_attn_interface import flash_attn_varlen_func + validate_accuracy = True +except ImportError: + validate_accuracy = False +""" +NOTE(cmikeh2): These tests depend on atom construction and KV-cache copying to behave correctly. +If one or the other of those is not working, then these tests will fail. Before debugging here, +make sure that the atom construction and KV-cache copying tests are passing. 
+""" + + +def _blocked_flash_testing_helper(head_size: int, n_heads_q: int, n_heads_kv: int, + seq_params: List[Tuple[int, int]]) -> None: + """ + Helper function for testing blocked flash attention. Used to enable parametrize to only set up + a subset of parameters before being passed to the unified test function. + """ + q_block_size = get_q_block_size(head_size) + kv_block_size = get_kv_block_size(head_size) + + kvs = [] + for _, history_len in seq_params: + if history_len > 0: + kvs.append( + torch.randn((history_len, 2 * n_heads_kv * head_size), + device=get_accelerator().current_device(), + dtype=torch.float16)) + else: + kvs.append(None) + + batch, state_manager, _ = build_batch_and_manager(seq_params, head_size, n_heads_kv, kv_block_size, kv_fill=kvs) + + atom_builder = AtomBuilder() + kv_copy = LinearBlockedKVCopy(head_size, n_heads_q, n_heads_kv, DtypeEnum.fp16) + atom_flash = BlockedFlashAttn(head_size, DtypeEnum.fp16) + + total_atoms = sum((seq[0] + q_block_size - 1) // q_block_size for seq in seq_params) + atoms = torch.empty((total_atoms, 8), dtype=torch.int32, device=get_accelerator().current_device()) + alloc_func = RaggedUtilsBuilder().load().allocate_fast_host_buffer + atoms_host = alloc_func(atoms) + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=get_accelerator().current_device(), + dtype=torch.float16) + + atoms_host, n_atoms = atom_builder(atoms_host, batch, q_block_size, kv_block_size) + atoms.copy_(atoms_host[:n_atoms]) + + kv_cache = state_manager.get_cache(0) + kv_copy(kv_cache, qkv, batch) + + out = torch.empty((batch.current_tokens, head_size * n_heads_q), + device=get_accelerator().current_device(), + dtype=torch.float16) + k_cache, v_cache = split_kv(kv_cache) + q = qkv[:, :head_size * n_heads_q] + + atom_flash(out, q, k_cache, v_cache, atoms, 1.0) + + if validate_accuracy: + cu_seqlens_q = torch.tensor([0] + list(itertools.accumulate([seq[0] for seq in seq_params])), + dtype=torch.int32, + device=get_accelerator().current_device()) + cu_seqlens_kv = torch.tensor([0] + list(itertools.accumulate([seq[1] + seq[0] for seq in seq_params])), + dtype=torch.int32, + device=get_accelerator().current_device()) + + inflight_kv = qkv[:, head_size * n_heads_q:] + full_kvs = [] + for i, kv in enumerate(kvs): + if kv is not None: + full_kvs.append(torch.cat([kv, inflight_kv[cu_seqlens_q[i]:cu_seqlens_q[i + 1]]], dim=0)) + else: + full_kvs.append(inflight_kv[cu_seqlens_q[i]:cu_seqlens_q[i + 1]]) + run_kvs = torch.cat(full_kvs, dim=0) + k = run_kvs[:, :head_size * n_heads_kv] + v = run_kvs[:, head_size * n_heads_kv:] + + q_ref = q.reshape((batch.current_tokens, n_heads_q, head_size)) + k_ref = k.reshape((k.shape[0], n_heads_kv, head_size)) + v_ref = v.reshape((v.shape[0], n_heads_kv, head_size)) + + max_seqlen_q = max([seq[0] for seq in seq_params]) + max_seqlen_kv = max([seq[1] + seq[0] for seq in seq_params]) + + ref_o = flash_attn_varlen_func(q_ref, + k_ref, + v_ref, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + softmax_scale=1.0, + causal=True) + + ref_o = ref_o.reshape(batch.current_tokens, head_size * n_heads_q) + + assert allclose(out, ref_o) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens", [2, 33, 65, 128, 256, 2037]) +def test_single_prompt(n_tokens: int) -> None: + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(n_tokens, 0)] + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops 
+@pytest.mark.parametrize("prompt_lengths", [(128, 128), (192, 38), (514, 713), (83, 312, 610)]) +def test_multiple_prompts(prompt_lengths: Tuple[int, int]) -> None: + """ + Test multiple prompts in a single batch. + """ + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(prompt_lengths[i], 0) for i in range(len(prompt_lengths))] + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("seq_params", [(1, 34), (43, 40), (1, 144), (64, 128), (332, 628)]) +def test_continuation(seq_params: Tuple[int, int]) -> None: + """ + Test continued generation/prompt processing. + """ + head_size = 64 + n_heads_q = 32 + n_heads_kv = 32 + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, [seq_params]) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("head_size", [64, 128]) +def test_head_size(head_size: int) -> None: + n_heads_q = 16 + n_heads_kv = 16 + seq_params = [(128, 128), (192, 38), (1, 814)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("head_config", [(32, 8), (64, 16), (40, 8)]) +def test_gqa(head_config: Tuple[int, int]) -> None: + head_size = 128 + n_heads_q = head_config[0] + n_heads_kv = head_config[1] + + seq_params = [(128, 128), (192, 38), (1, 814)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +def test_fully_composed() -> None: + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(332, 628), (1, 718), (1, 323), (180, 5), (224, 0)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_kv_copy.py b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_kv_copy.py new file mode 100644 index 000000000000..90fe26eb4490 --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_kv_copy.py @@ -0,0 +1,112 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.ragged_ops import LinearBlockedKVCopy +from .ragged_testing_utils import build_batch_and_manager, validate_kv_cache + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, history_size", [(1, 0), (17, 0), (33, 8), (63, 1)]) +def test_single_sequence_single_block(n_tokens: int, history_size: int): + """ + Validate that the copy works correctly + """ + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch, state_manager, seq_descs = build_batch_and_manager([(n_tokens, history_size)], head_size, n_heads_kv, + kv_block_size) + + assert batch.current_sequences == 1 + assert batch.current_tokens == n_tokens + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + + kv_cache = state_manager.get_cache(0) + + copy_impl = LinearBlockedKVCopy(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch) + + k = qkv[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv[:, head_size * (n_heads_q + n_heads_kv):] + + validate_kv_cache(kv_cache, k, v, seq_descs, batch) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, history_size", [(128, 0), (177, 0), (169, 8), (117, 88)]) +def test_single_sequence_multiple_blocks(n_tokens: int, history_size: int): + """ + Validate that the copy works correctly + """ + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch, state_manager, seq_descs = build_batch_and_manager([(n_tokens, history_size)], head_size, n_heads_kv, + kv_block_size) + + assert batch.current_sequences == 1 + assert batch.current_tokens == n_tokens + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + + kv_cache = state_manager.get_cache(0) + + copy_impl = LinearBlockedKVCopy(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch) + + k = qkv[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv[:, head_size * (n_heads_q + n_heads_kv):] + + validate_kv_cache(kv_cache, k, v, seq_descs, batch) + + +@pytest.mark.inference_v2_ops +def test_multi_sequence() -> None: + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch_config = [ + (128, 0), + (177, 0), + (169, 8), + (117, 88), + (1, 293), + (1, 733), + (1, 33), + ] + + batch, state_manager, seq_descs = build_batch_and_manager(batch_config, head_size, n_heads_kv, kv_block_size) + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + + kv_cache = state_manager.get_cache(0) + + copy_impl = LinearBlockedKVCopy(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch) + + k = qkv[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv[:, head_size * (n_heads_q + n_heads_kv):] + + validate_kv_cache(kv_cache, k, v, seq_descs, batch) diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_rotary_emb.py b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_rotary_emb.py new file mode 100644 index 000000000000..618c2d3b87ec --- /dev/null +++ 
b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_rotary_emb.py @@ -0,0 +1,203 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List, Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.ragged_ops import BlockedRotaryEmbeddings, BlockedTrainedRotaryEmbeddings +from deepspeed.inference.v2.ragged import RaggedBatchWrapper, DSSequenceDescriptor +from .ragged_testing_utils import build_batch_and_manager, validate_kv_cache +from ....v2.inference_test_utils import allclose +""" +NOTE(cmikeh2): It is very possible to see unit test failures (even on FP16) depending on when +certain values are casted up to or down from float32. If we are seeing accuracy issues, we should +make sure we are aligning on the training implementation's cast pattern here, given these tolerances +tend to be sufficient elsewhere. +""" + + +def rotary_pos_embs(q: torch.Tensor, k: torch.Tensor, seq_descs: List[DSSequenceDescriptor], batch: RaggedBatchWrapper, + head_size: int): + + def make_cos_sin_emb(seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]: + t = torch.arange(seq_len, dtype=torch.float32, device=get_accelerator().current_device()) + inv_freq = (1.0 / (10000.0**(torch.arange( + 0, head_size, 2, dtype=torch.float32, device=get_accelerator().current_device()) / head_size))).half() + + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + + return torch.cos(emb)[:, None, :], torch.sin(emb)[:, None, :], inv_freq + + def rotate_half(x: torch.Tensor) -> torch.Tensor: + return torch.cat((-x[..., x.shape[-1] // 2:], x[..., :x.shape[-1] // 2]), dim=-1) + + cos, sin, freqs = make_cos_sin_emb(1024) + + q_out = torch.empty_like(q) + k_out = torch.empty_like(k) + n_heads_q = q.shape[1] // head_size + n_heads_kv = k.shape[1] // head_size + + inflight_descs = batch.inflight_seq_descriptors(on_device=False)[:batch.current_sequences] + + if inflight_descs.shape[0] != len(seq_descs): + raise ValueError("The number of sequence descriptors does not match the number of sequences in the batch.") + + for seq_desc, inflight_seq in zip(seq_descs, inflight_descs): + start_idx = inflight_seq[0] + n_tokens = seq_desc.in_flight_tokens + + q_src = q[start_idx:start_idx + n_tokens].reshape(n_tokens, n_heads_q, head_size).float() + k_src = k[start_idx:start_idx + n_tokens].reshape(n_tokens, n_heads_kv, head_size).float() + freq_start_offset = seq_desc.seen_tokens + + cos_chunk = cos[range(freq_start_offset, freq_start_offset + n_tokens)] + sin_chunk = sin[range(freq_start_offset, freq_start_offset + n_tokens)] + + q_emb = q_src * cos_chunk + rotate_half(q_src) * sin_chunk + k_emb = k_src * cos_chunk + rotate_half(k_src) * sin_chunk + + q_out[start_idx:start_idx + n_tokens] = q_emb.reshape(n_tokens, n_heads_q * head_size).to(q_out.dtype) + k_out[start_idx:start_idx + n_tokens] = k_emb.reshape(n_tokens, n_heads_kv * head_size).to(k_out.dtype) + + return q_out, k_out, freqs + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, history_size", [(1, 0), (17, 0), (33, 15), (1, 63)]) +@pytest.mark.parametrize("trained_emb", [False, True]) +def test_single_sequence_single_block(n_tokens: int, history_size: int, trained_emb: bool): + """ + Validate that the copy works correctly + """ + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch, state_manager, seq_descs = 
build_batch_and_manager([(n_tokens, history_size)], head_size, n_heads_kv, + kv_block_size) + + assert batch.current_sequences == 1 + assert batch.current_tokens == n_tokens + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + qkv_ref = qkv.clone() + + q = qkv_ref[:, :head_size * n_heads_q] + k = qkv_ref[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv_ref[:, head_size * (n_heads_q + n_heads_kv):] + + q_ref, k, freqs = rotary_pos_embs(q, k, seq_descs, batch, head_size) + freqs = freqs.half() + + kv_cache = state_manager.get_cache(0) + + if trained_emb: + copy_impl = BlockedTrainedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch, freqs) + else: + copy_impl = BlockedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch) + + assert allclose(qkv[:, :head_size * n_heads_q], q_ref) + validate_kv_cache(kv_cache, k, v, seq_descs, batch, exact=False) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, history_size", [(128, 0), (177, 0), (169, 8), (117, 88)]) +@pytest.mark.parametrize("trained_emb", [False, True]) +def test_single_sequence_multiple_blocks(n_tokens: int, history_size: int, trained_emb: bool): + """ + Validate that the copy works correctly + """ + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch, state_manager, seq_descs = build_batch_and_manager([(n_tokens, history_size)], head_size, n_heads_kv, + kv_block_size) + + assert batch.current_sequences == 1 + assert batch.current_tokens == n_tokens + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + qkv_ref = qkv.clone() + + q = qkv_ref[:, :head_size * n_heads_q] + k = qkv_ref[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv_ref[:, head_size * (n_heads_q + n_heads_kv):] + + q_ref, k, freqs = rotary_pos_embs(q, k, seq_descs, batch, head_size) + freqs = freqs.half() + + kv_cache = state_manager.get_cache(0) + + if trained_emb: + copy_impl = BlockedTrainedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch, freqs) + else: + copy_impl = BlockedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch) + + assert allclose(qkv[:, :head_size * n_heads_q], q_ref) + validate_kv_cache(kv_cache, k, v, seq_descs, batch, exact=False) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("trained_emb", [False, True]) +def test_multi_sequences(trained_emb: bool) -> None: + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch_config = [ + (128, 0), + (177, 0), + (169, 8), + (117, 88), + (1, 293), + (1, 733), + (1, 33), + ] + + batch, state_manager, seq_descs = build_batch_and_manager(batch_config, head_size, n_heads_kv, kv_block_size) + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + qkv_ref = qkv.clone() + + q = qkv_ref[:, :head_size * n_heads_q] + k = qkv_ref[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv_ref[:, head_size * (n_heads_q + n_heads_kv):] + + q_ref, k, freqs = rotary_pos_embs(q, k, seq_descs, batch, head_size) + freqs = freqs.half() + + kv_cache = state_manager.get_cache(0) + + if 
trained_emb: + copy_impl = BlockedTrainedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch, freqs) + else: + copy_impl = BlockedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch) + + assert allclose(qkv[:, :head_size * n_heads_q], q_ref) + validate_kv_cache(kv_cache, k, v, seq_descs, batch, exact=False) diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_logits_gather.py b/tests/unit/inference/v2/kernels/ragged_ops/test_logits_gather.py new file mode 100644 index 000000000000..1feefa9ee588 --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_logits_gather.py @@ -0,0 +1,96 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.ragged_ops import RaggedLogitsGather +from ....v2.inference_test_utils import allclose, get_dtypes +from .ragged_testing_utils import build_simple_batch + + +def baseline_implementation(hidden_states: torch.Tensor, seq_lens: List[int]) -> torch.Tensor: + output = torch.empty((len(seq_lens), hidden_states.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device) + + offset = 0 + for i, seq_len in enumerate(seq_lens): + output[i] = hidden_states[offset + seq_len - 1] + offset += seq_len + + return output + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('dtype', get_dtypes()) +def test_supported_dtypes(dtype: torch.dtype) -> None: + """ + Validate support on nominally supported data types. + """ + model_dim = 4096 + + batch = build_simple_batch([256], padding=False) + hidden_states = torch.randn((batch.current_tokens, model_dim), + dtype=dtype, + device=get_accelerator().current_device()) + + reference_result = baseline_implementation(hidden_states, [256]) + + kernel = RaggedLogitsGather(model_dim, dtype) + output = torch.empty_like(reference_result) + kernel(output, hidden_states, batch) + + assert allclose(output, reference_result) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('seq_lens', [[128, 64, 192, 32], [57, 112, 63, 89, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], + [63, 27, 74, 83, 32, 17, 1, 1, 1, 1, 1]]) +def test_multiple_sequences(seq_lens: List[int]) -> None: + """ + Validate support on more multi-sequence inputs. + """ + model_dim = 4096 + dtype = torch.float16 + + batch = build_simple_batch(seq_lens, padding=False) + hidden_states = torch.randn((batch.current_tokens, model_dim), + dtype=dtype, + device=get_accelerator().current_device()) + + reference_result = baseline_implementation(hidden_states, seq_lens) + + kernel = RaggedLogitsGather(model_dim, dtype) + output = torch.empty_like(reference_result) + kernel(output, hidden_states, batch) + + assert allclose(output, reference_result) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("model_dim", [1024, 6144, 6784]) +def test_problem_size_permutations(model_dim: int) -> None: + """ + Validate for different embedding sizes. 
+ """ + dtype = torch.float16 + seq_lens = [128, 64, 192, 32] + + batch = build_simple_batch(seq_lens, padding=False) + hidden_states = torch.randn((batch.current_tokens, model_dim), + dtype=dtype, + device=get_accelerator().current_device()) + + reference_result = baseline_implementation(hidden_states, seq_lens) + + kernel = RaggedLogitsGather(model_dim, dtype) + output = torch.empty_like(reference_result) + kernel(output, hidden_states, batch) + + assert allclose(output, reference_result) diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_moe_gather.py b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_gather.py new file mode 100644 index 000000000000..5fa375b49c19 --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_gather.py @@ -0,0 +1,83 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import DtypeEnum +from deepspeed.inference.v2.kernels.ragged_ops import ( + MoEGather, + MoEScatter, + RaggedTop1Gating, +) +from .ragged_testing_utils import build_simple_batch +""" +For simplicity's sake, these tests do rely on ``RaggedTop1Gating`` and +``MoEScatter`` to produce correct inputs. If either of these kernels is broken +these tests will fail, so double check the unit test results there before +debugging here. +""" + + +def build_inputs(n_tokens, n_experts, do_padding): + + assert n_tokens <= 2048, "This test will break if n_tokens > 2048" + + # Sequence composition shouldn't matter here + batch = build_simple_batch([n_tokens], padding=do_padding) + + logits = torch.randn((batch.tensor_toks, n_experts), + dtype=torch.float16, + device=get_accelerator().current_device()) + + # This will make each token's value equal to its index. NOTE: This will break for + # tokens with index > 2048. 
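A quick aside on the comment above: the index-encoding construction that follows caps out at 2048 because of fp16's 10-bit mantissa; integers above 2048 are no longer exactly representable, so the per-row "token id" becomes lossy. A minimal illustration in plain PyTorch, independent of the kernels under test:

import torch

# Each row of the reshaped tensor carries its own row index in every column.
ids = torch.arange(4, dtype=torch.float16).repeat_interleave(8).reshape(4, 8)
assert torch.all(ids[2] == 2.0)

# 2049 is not representable in fp16 and rounds back to 2048.
assert torch.tensor(2048.0, dtype=torch.float16) + 1 == 2048.0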
+ hidden_states = torch.arange(batch.tensor_toks, dtype=torch.float16, + device=get_accelerator().current_device()).repeat_interleave(4096, dim=0).reshape( + batch.tensor_toks, 4096).contiguous() + + gate = RaggedTop1Gating(DtypeEnum.fp16) + + # Gating outputs + expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((batch.tensor_toks, ), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((batch.tensor_toks, ), + dtype=torch.int32, + device=get_accelerator().current_device()) + expert_offset = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + + gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + + # Scatter outputs + moe_input = torch.empty((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) + expert_cumsum = torch.empty((n_experts, ), dtype=torch.int64, device=get_accelerator().current_device()) + mapped_slots = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + + scatter = MoEScatter(DtypeEnum.fp16, 4096) + scatter(moe_input, expert_cumsum, mapped_slots, hidden_states, expert_counts, expert_assignment, expert_offset) + + return batch, moe_input, scores, mapped_slots, expert_counts + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, n_experts", [(13, 64), (278, 64), (1977, 64)]) +@pytest.mark.parametrize("do_padding", [True, False]) +def test_moe_gather(n_tokens, n_experts, do_padding): + + batch, moe_input, scores, mapped_slots, expert_counts = build_inputs(n_tokens, n_experts, do_padding) + + output = torch.randn((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) + + gather = MoEGather(DtypeEnum.fp16, 4096) + gather(output, moe_input, scores, mapped_slots, expert_counts) + + for token_idx in range(n_tokens): + assert torch.equal( + output[token_idx], + torch.full((4096, ), + token_idx * scores[token_idx], + dtype=torch.float16, + device=get_accelerator().current_device())) diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_moe_scatter.py b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_scatter.py new file mode 100644 index 000000000000..4ca051410c1c --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_scatter.py @@ -0,0 +1,74 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import DtypeEnum +from deepspeed.inference.v2.kernels.ragged_ops import MoEScatter, RaggedTop1Gating +from .ragged_testing_utils import build_simple_batch +""" +For simplicity's sake, these tests do rely on ``RaggedTop1Gating`` to produce correct +inputs. If ``RaggedTop1Gating`` is broken, these tests will fail, so double check +the unit test results there before debugging here. 
+""" + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, n_experts", [(13, 64), (278, 64), (1977, 64)]) +@pytest.mark.parametrize("do_padding", [True, False]) +def test_moe_scatter(n_tokens, n_experts, do_padding): + + # Sequence composition shouldn't matter here + batch = build_simple_batch([n_tokens], padding=do_padding) + + logits = torch.randn((batch.tensor_toks, n_experts), + dtype=torch.float16, + device=get_accelerator().current_device()) + + # This will make each token's value equal to its index. NOTE: This will break for + # tokens with index > 2048. + hidden_states = torch.arange(batch.tensor_toks, dtype=torch.float16, + device=get_accelerator().current_device()).repeat_interleave(4096, dim=0).reshape( + batch.tensor_toks, 4096).contiguous() + + gate = RaggedTop1Gating(DtypeEnum.fp16) + + # Gating outputs + expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((batch.tensor_toks, ), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((batch.tensor_toks, ), + dtype=torch.int32, + device=get_accelerator().current_device()) + expert_offset = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + + gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + + # Scatter outputs + moe_input = torch.empty((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) + expert_cumsum = torch.empty((n_experts, ), dtype=torch.int64, device=get_accelerator().current_device()) + mapped_slots = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + + scatter = MoEScatter(DtypeEnum.fp16, 4096) + scatter(moe_input, expert_cumsum, mapped_slots, hidden_states, expert_counts, expert_assignment, expert_offset) + assert torch.equal(expert_cumsum, torch.cumsum(expert_counts, dim=0).to(torch.int64)) + + for token_idx in range(batch.tensor_toks): + if token_idx < n_tokens: + expert_idx = expert_assignment[token_idx].item() + if expert_idx == 0: + expert_cumsum_val = 0 + else: + expert_cumsum_val = expert_cumsum[expert_idx - 1] + offset = expert_offset[token_idx] + total_offset = offset + expert_cumsum_val + + assert total_offset == mapped_slots[token_idx].item() + assert torch.equal(moe_input[total_offset], hidden_states[token_idx]) + else: + assert mapped_slots[token_idx].item() == -1 + + assert expert_cumsum[-1] == n_tokens diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_ragged_embed.py b/tests/unit/inference/v2/kernels/ragged_ops/test_ragged_embed.py new file mode 100644 index 000000000000..f179f62a9b12 --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_ragged_embed.py @@ -0,0 +1,177 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List, Optional, Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.ragged_ops import RaggedEmbeddingKernel +from ....v2.inference_test_utils import allclose, get_dtypes +from .ragged_testing_utils import build_batch_and_manager + + +def baseline_implementation(token_ids: torch.Tensor, + embedding_table: torch.Tensor, + unpadded_size: int, + positional_embedding_table: Optional[torch.Tensor] = None, + positional_ids: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Baseline implementation for our ragged embedding kernel. + """ + if unpadded_size == token_ids.shape[0]: + token_embed = torch.nn.functional.embedding(token_ids, embedding_table) + + if positional_embedding_table is not None: + pos_embed = torch.nn.functional.embedding(positional_ids, positional_embedding_table) + token_embed += pos_embed + return token_embed + else: + real_token_ids = token_ids[:unpadded_size] + output = torch.empty((token_ids.shape[0], embedding_table.shape[1]), + dtype=embedding_table.dtype, + device=get_accelerator().current_device()) + unpadded_output = torch.nn.functional.embedding(real_token_ids, embedding_table) + + # Positional embeddings aren't padded because padding is only simulated for the token ids + if positional_embedding_table is not None: + pos_embed = torch.nn.functional.embedding(positional_ids, positional_embedding_table) + unpadded_output += pos_embed + + output[:unpadded_size] = unpadded_output + return output + + +def _ragged_embed_test_helper(sequence_config: List[Tuple[int, int]], + embed_dtype: torch.dtype, + token_dtype: torch.dtype, + embed_dim: int, + vocab_size: int, + do_padding: bool = False, + pos_embed_size: int = -1, + pos_embed_offset: int = 0) -> None: + """ + Helper for the embedding tests to limit the number of test permutations to run. + + Params: + embed_dim (int): Model dimension + vocab_size (int): Leading dimension of the embedding weight + pos_embed_size (int): Size of positional embedding. If negative, no positional embedding + is used. + pos_embed_offset (int): Offset for positional embedding. Effectively, the raw offsets + of a token into a sequence are offset by this amount into the embedding matrix + (i.e. the positional embedding table has shape (pos_embed_size + pos_embed_offset, embed_dim)). + """ + device = get_accelerator().current_device() + + # Heads/Block size are irrelevant here but need something.
+ batch, _, _, = build_batch_and_manager(sequence_config, 64, 16, 64, vocab_range=vocab_size, padding=do_padding) + + embedding_table = torch.randn((vocab_size, embed_dim), dtype=embed_dtype, device=device) + + if pos_embed_size > 0: + pos_embedding_table = torch.randn((pos_embed_size + pos_embed_offset, embed_dim), + dtype=embed_dtype, + device=device) + positional_ids = torch.cat([ + torch.arange(start_idx, start_idx + seq_len, dtype=token_dtype, device=device) + for seq_len, start_idx in sequence_config + ]) + pos_embed_offset + else: + pos_embedding_table = None + positional_ids = None + + baseline_output = baseline_implementation(batch.input_ids().to(token_dtype), embedding_table, batch.current_tokens, + pos_embedding_table, positional_ids) + + kernel = RaggedEmbeddingKernel(embed_dtype, token_dtype, embed_dim) + output = torch.empty_like(baseline_output) + + kernel(output, + batch, + embedding_table, + position_embed_weight=pos_embedding_table, + position_embed_offset=pos_embed_offset) + + if do_padding: + assert output.shape[0] != batch.current_tokens + + assert allclose(output[:batch.current_tokens], baseline_output[:batch.current_tokens]) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('token_dtype', [torch.int32, torch.int64]) +@pytest.mark.parametrize('embed_dtype', get_dtypes()) +def test_dtype_permutations(token_dtype: torch.dtype, embed_dtype: torch.dtype) -> None: + """ + Validate (on a single problem size) that the kernel support for different data types + is correct. + """ + embed_dim = 4096 + vocab_size = 50304 + + _ragged_embed_test_helper([(256, 0)], embed_dtype, token_dtype, embed_dim, vocab_size) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('vocab_size, embed_dim', [(1024, 1024), (32000, 5120), (50304, 6144)]) +def test_problem_size_permutations(vocab_size: int, embed_dim: int) -> None: + """ + Validate on wider range of problem sizes. + """ + + _ragged_embed_test_helper([(256, 0)], torch.float16, torch.int32, embed_dim, vocab_size) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('seq_lens', [[128, 64, 192, 32], [57, 112, 63, 89, 1, 1, 1, 1]]) +@pytest.mark.parametrize('do_padding', [True, False]) +def test_complex_sequences(seq_lens: List[int], do_padding: bool) -> None: + """ + Validate on different ragged batch construction scenarios. + """ + embed_dim = 4096 + vocab_size = 50304 + + _ragged_embed_test_helper([(seq_len, 0) for seq_len in seq_lens], + torch.float16, + torch.int32, + embed_dim, + vocab_size, + do_padding=do_padding) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("seq_lens", [[(256, 0)], [(256, 0), + (128, 0)], [(256, 0), (128, 0), + (64, 0)], [(1, 877), (619, 0), (213, 372), (1, 45)]]) +def test_positional_embedding(seq_lens: List[Tuple[int, int]]) -> None: + """ + Validate that positional embedding works correctly. + """ + embed_dim = 4096 + vocab_size = 50304 + + _ragged_embed_test_helper(seq_lens, torch.float16, torch.int32, embed_dim, vocab_size, pos_embed_size=2048) + + +@pytest.mark.inference_v2_ops +def test_positional_embedding_offset() -> None: + """ + Validate that positional embedding works correctly with an offset. 
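For reference, the offset semantics exercised here can be sketched in isolation: each token's raw position is shifted by pos_embed_offset before indexing a table whose leading dimension is padded by the same amount, mirroring how the helper above builds positional_ids (the sizes below are illustrative only):

import torch

pos_embed_size, pos_embed_offset, embed_dim = 8, 2, 4
table = torch.randn(pos_embed_size + pos_embed_offset, embed_dim)

raw_positions = torch.arange(4)  # positions of the tokens within their sequence
pos_embed = torch.nn.functional.embedding(raw_positions + pos_embed_offset, table)
assert pos_embed.shape == (4, embed_dim)  # rows 2..5 of the table are used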
+ """ + embed_dim = 4096 + vocab_size = 50304 + seq_config = [(1, 877), (619, 0), (213, 372), (1, 45)] + + _ragged_embed_test_helper(seq_config, + torch.float16, + torch.int32, + embed_dim, + vocab_size, + pos_embed_size=2048, + pos_embed_offset=2) diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_top_1_gating.py b/tests/unit/inference/v2/kernels/ragged_ops/test_top_1_gating.py new file mode 100644 index 000000000000..6ff2508bf320 --- /dev/null +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_top_1_gating.py @@ -0,0 +1,120 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import torch.nn.functional as F + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import DtypeEnum +from deepspeed.inference.v2.kernels.ragged_ops import RaggedTop1Gating +from .ragged_testing_utils import build_simple_batch +from ....v2.inference_test_utils import allclose + + +def _test_single_mapping_helper(n_tokens: int, + n_experts: int, + assigned_expert: int, + logit_fill: float = 0.0, + match_fill: float = 1.0) -> None: + logits = torch.full((n_tokens, n_experts), + logit_fill, + dtype=torch.float16, + device=get_accelerator().current_device()) + + logits[:, assigned_expert] = match_fill + + gate = RaggedTop1Gating(DtypeEnum.fp16) + + expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, ), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + batch = build_simple_batch([n_tokens], padding=False) + + gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + + assert expert_counts[assigned_expert] == n_tokens + assert torch.all(expert_assignment == assigned_expert) + assert torch.unique(expert_offset).shape[0] == n_tokens + assert allclose(scores, F.softmax(logits.float(), dim=1)[:, assigned_expert]) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('n_tokens, n_experts', [(1, 16), (17, 16), (32, 128), (89, 128), (433, 128)]) +def test_single_mapping_gating(n_tokens: int, n_experts: int) -> None: + """ + Evaluate our expert stacking behavior in complete isolation. This ensures all tokens + mapped to the same expert are getting unique offsets and identical scores. + """ + assigned_expert = 13 + _test_single_mapping_helper(n_tokens, n_experts, assigned_expert) + + +@pytest.mark.inference_v2_ops +def test_negative_logits(): + """ + Ensure that scores/values are propagated correctly when all the logits are negative. An + earlier implementation of the scoring would return NaN for this case. + """ + _test_single_mapping_helper(128, 32, 13, logit_fill=-2.0, match_fill=-1.0) + + +@pytest.mark.inference_v2_ops +def test_determinism(): + """ + Ensure that ties between two logits are broken deterministically. This is essential when + the gating is distributed across multiple devices that need to map the same token to + the same expert. 
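A rough dense reference for the gating semantics the assertions in this file encode: the score is the softmax probability of the winning expert, the assignment is its index, and the counts form a per-expert histogram. The kernel additionally emits per-token offsets that must be unique within an expert and, as noted above, must break ties deterministically; this sketch models neither of those:

import torch
import torch.nn.functional as F

logits = torch.randn(5, 4)  # (n_tokens, n_experts)
assignment = logits.argmax(dim=1)
scores = F.softmax(logits.float(), dim=1).max(dim=1).values
counts = torch.bincount(assignment, minlength=4)  # tokens routed to each expert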
+ """ + + n_tokens = 512 + n_experts = 64 + + logits = torch.zeros((n_tokens, n_experts), dtype=torch.float16, device=get_accelerator().current_device()) + batch = build_simple_batch([n_tokens], padding=False) + + logits[:, 19] = 1.0 + logits[:, 26] = 1.0 + + gate = RaggedTop1Gating(DtypeEnum.fp16) + + for _ in range(1024): + expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, ), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + batch = build_simple_batch([n_tokens], padding=False) + + gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + + assert expert_counts[19] == n_tokens + assert expert_counts[26] == 0 + assert torch.all(expert_assignment == 19) + assert torch.unique(expert_offset).shape[0] == n_tokens + assert allclose(scores, F.softmax(logits.float(), dim=1)[:, 19]) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('n_tokens, n_experts', [(1, 16), (17, 16), (32, 128), (89, 128), (433, 2)]) +def test_score_accuracy(n_tokens: int, n_experts: int) -> None: + """ + Validate expert scores are correct. + """ + logits = torch.randn((n_tokens, n_experts), dtype=torch.float16, device=get_accelerator().current_device()) + batch = build_simple_batch([n_tokens], padding=False) + + gate = RaggedTop1Gating(DtypeEnum.fp16) + + expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, ), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + + ref_scores = F.softmax(logits.float(), dim=1).max(dim=1).values + + gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + assert allclose(scores, ref_scores) + assert expert_counts.sum() == n_tokens diff --git a/tests/unit/inference/v2/model_implementations/__init__.py b/tests/unit/inference/v2/model_implementations/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/model_implementations/parameters/__init__.py b/tests/unit/inference/v2/model_implementations/parameters/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/parameters/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/model_implementations/parameters/test_layer_inheritance.py b/tests/unit/inference/v2/model_implementations/parameters/test_layer_inheritance.py new file mode 100644 index 000000000000..20803e53a320 --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/parameters/test_layer_inheritance.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer + +from .utils import validate_device, SimpleParam, DummyInferenceModel + + +class ParentLayer(LayerContainer): + """ + A layer that has a dependency on a simple parameter. + """ + + param_1: SimpleParam + + +class ChildLayer(ParentLayer): + """ + A layer that inherits from another layer. + """ + + param_2: SimpleParam + + +@pytest.mark.inference_v2 +def test_layer_inheritance(): + inference_model = DummyInferenceModel() + + multi_param_layer = ChildLayer(inference_model) + + assert multi_param_layer.n_params == 2 + assert multi_param_layer.is_initialized is False + + multi_param_layer.param_1.param = torch.ones(16, 16) + + assert multi_param_layer.is_initialized is False + + multi_param_layer.param_2.param = torch.full((16, 16), 2.0) + + assert multi_param_layer.is_initialized is True + assert isinstance(multi_param_layer.param_1, torch.Tensor) + assert isinstance(multi_param_layer.param_2, torch.Tensor) + + validate_device(multi_param_layer.param_1) + validate_device(multi_param_layer.param_2) diff --git a/tests/unit/inference/v2/model_implementations/parameters/test_mapping.py b/tests/unit/inference/v2/model_implementations/parameters/test_mapping.py new file mode 100644 index 000000000000..3c74d7a0479a --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/parameters/test_mapping.py @@ -0,0 +1,165 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.inference.v2.allocator import on_device +from deepspeed.inference.v2.model_implementations.parameter_base import ParameterBase, ParamList +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer + + +class MultiDependencyContainer(ParameterBase): + + dependency_1: torch.Tensor + + dependency_2: torch.Tensor + + @on_device + def finalize(self) -> torch.Tensor: + return torch.cat([self.dependency_1, self.dependency_2]) + + +class ListDependencyContainer(ParameterBase): + + dependencies: ParamList("list_items") # noqa: F821 + + @on_device + def finalize(self) -> torch.Tensor: + return torch.cat(tuple(self.dependencies)) + + +class MappingLayer(LayerContainer): + PARAM_MAPPING = { + "model.val.item.d_1": "multi_depend.dependency_1", + "model.val.item.d_2": "multi_depend.dependency_2", + "model.list_vals.*.d": "list_depend.dependencies" + } + + multi_depend: MultiDependencyContainer + + list_depend: ListDependencyContainer + + +class SubMappingLayer(MappingLayer): + PARAM_MAPPING = { + "model.val.item2.d_1": "multi_depend2.dependency_1", + "model.val.item2.d_2": "multi_depend2.dependency_2", + } + + multi_depend2: MultiDependencyContainer + + +class DoubleMappingLayer(LayerContainer): + PARAM_MAPPING = { + "model.val.item.d_1": ["multi_depend.dependency_1", "multi_depend.dependency_2"], + } + + multi_depend: MultiDependencyContainer + + +class InferenceModel: + + @property + def list_items(self) -> int: + return 16 + + +@pytest.mark.inference_v2 +def test_mapping_syntax(): + model = InferenceModel() + + mapping_layer = MappingLayer(model) + + mapping_layer.set_dependency("model.val.item.d_1", torch.ones(1)) + mapping_layer.set_dependency("model.val.item.d_2", torch.ones(1) * 2) + + assert isinstance(mapping_layer.multi_depend, torch.Tensor) + + for i in range(16): + mapping_layer.set_dependency(f"model.list_vals.{i}.d", 
torch.ones(1) * i) + if i != 16 - 1: + assert mapping_layer.is_initialized == False + + assert isinstance(mapping_layer.list_depend, torch.Tensor) + assert mapping_layer.is_initialized == True + + +@pytest.mark.inference_v2 +def test_sub_mapping_syntax(): + model = InferenceModel() + + mapping_layer = SubMappingLayer(model) + + mapping_layer.set_dependency("model.val.item.d_1", torch.ones(1)) + mapping_layer.set_dependency("model.val.item.d_2", torch.ones(1) * 2) + + assert isinstance(mapping_layer.multi_depend, torch.Tensor) + + mapping_layer.set_dependency("model.val.item2.d_1", torch.ones(1)) + mapping_layer.set_dependency("model.val.item2.d_2", torch.ones(1) * 2) + + assert isinstance(mapping_layer.multi_depend2, torch.Tensor) + + # We want to check into double digits to make sure that this isn't specific + # to single digit indexing. + for i in range(16): + mapping_layer.set_dependency(f"model.list_vals.{i}.d", torch.ones(1) * i) + if i != 16 - 1: + assert mapping_layer.is_initialized == False + + assert isinstance(mapping_layer.list_depend, torch.Tensor) + assert mapping_layer.is_initialized == True + + +@pytest.mark.inference_v2 +def test_double_mapping_syntax(): + model = InferenceModel() + + mapping_layer = DoubleMappingLayer(model) + mapping_layer.set_dependency("model.val.item.d_1", torch.ones(1)) + + # The single parameter setting should immediately make the parameter finalized + # and the whole layer initialized. + assert isinstance(mapping_layer.multi_depend, torch.Tensor) + assert mapping_layer.is_initialized == True + + +@pytest.mark.inference_v2 +def test_insufficient_mapping_syntax(): + """ + In the above example, we don't have a mapping for `multi_depend2.dependency_2`. + """ + + with pytest.raises(ValueError): + + class InsufficientMappingLayer(LayerContainer): + PARAM_MAPPING = { + "model.val.item.d_1": "multi_depend1.dependency_1", + "model.val.item.d_2": "multi_depend1.dependency_2", + "model.val.item2.d_1": "multi_depend2.dependency_1", + } + + multi_depend1: MultiDependencyContainer + + multi_depend2: MultiDependencyContainer + + +@pytest.mark.inference_v2 +def test_unknown_target_mapping_syntax(): + """ + In the above example, `multi_depend_unknown` does not exist. + """ + + with pytest.raises(ValueError): + + class UnknownTargetMappingLayer(LayerContainer): + PARAM_MAPPING = { + "model.val.item.d_1": "multi_depend1.dependency_1", + "model.val.item.d_2": "multi_depend1.dependency_2", + "model.val.item2.d_1": "multi_depend_unknown.dependency_1", + } + + multi_depend: MultiDependencyContainer diff --git a/tests/unit/inference/v2/model_implementations/parameters/test_multi_parameter_layer.py b/tests/unit/inference/v2/model_implementations/parameters/test_multi_parameter_layer.py new file mode 100644 index 000000000000..6bfc04e97c30 --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/parameters/test_multi_parameter_layer.py @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer + +from .utils import validate_device, SimpleParam, ListParam, DummyInferenceModel + + +class MultiParameterLayer(LayerContainer): + """ + Two dependencies, both of which are simple parameters. + """ + + param_1: SimpleParam + + param_2: SimpleParam + + +class MixedMultiParameterLayer(LayerContainer): + """ + Two dependencies, one of which is a simple parameter, the other is a list parameter.
+ """ + + param_1: SimpleParam + + param_2: ListParam + + +@pytest.mark.inference_v2 +def test_multi_parameter_layer(): + inference_model = DummyInferenceModel() + + multi_param_layer = MultiParameterLayer(inference_model) + + assert multi_param_layer.n_params == 2 + assert multi_param_layer.is_initialized is False + + multi_param_layer.param_1.param = torch.ones(16, 16) + + assert multi_param_layer.is_initialized is False + + multi_param_layer.param_2.param = torch.full((16, 16), 2.0) + + assert multi_param_layer.is_initialized is True + assert isinstance(multi_param_layer.param_1, torch.Tensor) + assert isinstance(multi_param_layer.param_2, torch.Tensor) + + validate_device(multi_param_layer.param_1) + validate_device(multi_param_layer.param_2) + + +@pytest.mark.inference_v2 +def test_mixed_multi_parameter_layer(): + inference_model = DummyInferenceModel() + + mixed_multi_param_layer = MixedMultiParameterLayer(inference_model) + + assert mixed_multi_param_layer.n_params == 2 + assert mixed_multi_param_layer.is_initialized is False + + mixed_multi_param_layer.param_2.params[1] = torch.full((16, 16), 2.0) + assert mixed_multi_param_layer.is_initialized is False + assert not isinstance(mixed_multi_param_layer.param_2, torch.Tensor) + + mixed_multi_param_layer.param_1.param = torch.ones(16, 16) + assert mixed_multi_param_layer.is_initialized is False + assert isinstance(mixed_multi_param_layer.param_1, torch.Tensor) + + validate_device(mixed_multi_param_layer.param_1) + + mixed_multi_param_layer.param_2.params[0] = torch.full((16, 16), 2.0) + + assert mixed_multi_param_layer.is_initialized is True + assert isinstance(mixed_multi_param_layer.param_2, torch.Tensor) + + validate_device(mixed_multi_param_layer.param_2) + + +class NoCopyInferenceModel: + + @property + def num_dependencies(self) -> int: + return 2 + + def transform(self, param: torch.Tensor) -> torch.Tensor: + return param + + +@pytest.mark.inference_v2 +def test_device_validation(): + inference_model = NoCopyInferenceModel() + + multi_param_layer = MultiParameterLayer(inference_model) + + assert multi_param_layer.n_params == 2 + assert multi_param_layer.is_initialized is False + + multi_param_layer.param_1.param = torch.ones(16, 16) + + assert multi_param_layer.is_initialized is False + + multi_param_layer.param_2.param = torch.full((16, 16), 2.0) + + with pytest.raises(RuntimeError): + # NoCopyInference model did not copy the parameters, so the device validation should fail. + assert multi_param_layer.is_initialized is True diff --git a/tests/unit/inference/v2/model_implementations/parameters/test_parameter_list.py b/tests/unit/inference/v2/model_implementations/parameters/test_parameter_list.py new file mode 100644 index 000000000000..42edd90595fa --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/parameters/test_parameter_list.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.inference.v2.model_implementations.parameter_base import ParameterBase, ParamList +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer +from deepspeed.inference.v2.model_implementations.common_parameters import * +from deepspeed.inference.v2.allocator import on_device + +from .utils import validate_device + + +class SimpleMoELayer(LayerContainer): + + moe_mlp_1: UnfusedMoEMLP1Parameter + + +class DummyInferenceModel: + + def __init__(self, experts_per_rank: int) -> None: + self._num_experts = experts_per_rank + + @property + def num_experts(self) -> int: + return self._num_experts + + @on_device + def transform_moe_mlp_1_param(self, param: torch.Tensor) -> torch.Tensor: + return param + + +@pytest.mark.inference_v2 +def test_simple_moe_layer(): + + inference_model = DummyInferenceModel(experts_per_rank=2) + + simple_moe_layer = SimpleMoELayer(inference_model) + + assert simple_moe_layer.moe_mlp_1.experts[0] is None + assert simple_moe_layer.moe_mlp_1.experts[1] is None + + # Set the first expert + simple_moe_layer.moe_mlp_1.experts[0] = torch.zeros(16, 16) + + assert simple_moe_layer.moe_mlp_1.experts[0] is not None + assert simple_moe_layer.moe_mlp_1.experts[1] is None + + assert not simple_moe_layer.is_initialized + + # Set the second expert + simple_moe_layer.moe_mlp_1.experts[1] = torch.ones(16, 16) + + # We have all the experts, so the layer should be initialized + assert simple_moe_layer.is_initialized + assert isinstance(simple_moe_layer.moe_mlp_1, torch.Tensor) + + validate_device(simple_moe_layer.moe_mlp_1) + + +""" +Check that we can mix the number of elements in lists in the same context and have that +be tracked correctly. +""" + + +class CustomListParam1(ParameterBase): + + deps: ParamList("attr_1") + + +class CustomListParam2(ParameterBase): + + deps: ParamList("attr_2") + + +class MixedLayer(LayerContainer): + + list_1: CustomListParam1 + list_2: CustomListParam2 + + +class MixedInferenceModel: + + @property + def attr_1(self) -> int: + return 1 + + @property + def attr_2(self) -> int: + return 2 + + +@pytest.mark.inference_v2 +def test_mixed_param_lists(): + model = MixedInferenceModel() + + layer = MixedLayer(model) + + assert layer.list_1.deps.n_params == 1 + assert layer.list_2.deps.n_params == 2 diff --git a/tests/unit/inference/v2/model_implementations/parameters/utils.py b/tests/unit/inference/v2/model_implementations/parameters/utils.py new file mode 100644 index 000000000000..0d2cbb27d40e --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/parameters/utils.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.allocator import on_device +from deepspeed.inference.v2.model_implementations.parameter_base import ParameterBase, ParametrizedList + + +class SimpleParam(ParameterBase): + """ + Parameter with single dependency. + """ + + param: torch.Tensor + + def finalize(self) -> torch.Tensor: + return self.inference_model.transform(self.param) + + +class SimpleParametrizedList(ParametrizedList): + """ + Parameter list based on `num_dependencies` attribute. + """ + + count_attr: str = "num_dependencies" + + +class ListParam(ParameterBase): + """ + Parameter with list dependency. 
+ + NOTE: This uses the tuple workaround for the `ParametrizedList` class + as described in the docstring of `ParametrizedList`. + """ + + params: SimpleParametrizedList + + def finalize(self) -> torch.Tensor: + return self.inference_model.transform(torch.cat(tuple(self.params))) + + +class DummyInferenceModel: + + @property + def num_dependencies(self) -> int: + return 2 + + @on_device + def transform(self, param: torch.Tensor) -> torch.Tensor: + return param + + +def validate_device(tensor: torch.Tensor): + assert tensor.device == torch.device(get_accelerator().current_device()) diff --git a/tests/unit/inference/v2/model_implementations/sharding/__init__.py b/tests/unit/inference/v2/model_implementations/sharding/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/sharding/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/model_implementations/sharding/test_attn_out_sharding.py b/tests/unit/inference/v2/model_implementations/sharding/test_attn_out_sharding.py new file mode 100644 index 000000000000..850c4c24fde6 --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/sharding/test_attn_out_sharding.py @@ -0,0 +1,129 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.model_implementations.sharding import * + +# None of the logic should be dependent on head size. +HEAD_SIZE = 64 + + +def fill_with_head_ids(head_size: int, n_heads: int) -> torch.Tensor: + """ + Fills a tensor with the associated head ids. All columns should have the same value. + """ + head_ids = torch.arange(n_heads, dtype=torch.half, device=get_accelerator().current_device()) + + head_ids = head_ids.repeat_interleave(head_size).repeat(head_size * n_heads).reshape(n_heads * head_size, -1) + return head_ids + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("n_heads, n_shards", [(1, 1), (8, 4), (32, 8)]) +def test_mha_even_sharding(n_heads: int, n_shards: int): + """ + Even head sharding for MHA. + + Args: + n_heads (int): The number QKV heads. + n_shards (int): The number of shards to test for. + """ + param = fill_with_head_ids(HEAD_SIZE, n_heads) + + n_local_heads = n_heads // n_shards + sharded_shape = (HEAD_SIZE * n_heads, HEAD_SIZE * n_local_heads) + + for shard_rank in range(n_shards): + sharded_param = shard_attn_out_param(param, shard_rank, n_shards, HEAD_SIZE) + n_heads_local_q, _ = get_local_heads(shard_rank, n_shards, n_heads) + + assert sharded_param.shape[-1] == HEAD_SIZE * n_heads_local_q + assert sharded_param.shape == sharded_shape + + heads = torch.chunk(sharded_param, n_local_heads, dim=1) + + for i, head in enumerate(heads): + assert torch.all(head == i + shard_rank * n_local_heads) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("n_heads, n_shards", [(3, 2), (20, 8)]) +def test_mha_unbalanced_sharding(n_heads: int, n_shards: int): + """ + Unbalanced head sharding for MHA. + + Args: + n_heads (int): The number QKV heads. + n_shards (int): The number of shards to test for. 
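For orientation, the invariants asserted in the unbalanced test below (every head mapped exactly once, shard sizes differing by at most one head) admit a simple balanced split; the actual rank-to-head assignment comes from get_local_heads, so this is only a sketch of the expected shape of the distribution:

n_heads, n_shards = 20, 8
base, rem = divmod(n_heads, n_shards)
local_heads = [base + (1 if rank < rem else 0) for rank in range(n_shards)]  # [3, 3, 3, 3, 2, 2, 2, 2]
assert sum(local_heads) == n_heads
assert max(local_heads) - min(local_heads) <= 1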
+ """ + param = fill_with_head_ids(HEAD_SIZE, n_heads) + + max_heads = 0 + min_heads = n_heads + seen_heads = set() + total_heads = 0 + + for shard_rank in range(n_shards): + sharded_param = shard_attn_out_param(param, shard_rank, n_shards, HEAD_SIZE) + n_heads_local_q, _ = get_local_heads(shard_rank, n_shards, n_heads) + + assert sharded_param.shape[-1] == HEAD_SIZE * n_heads_local_q + + n_local_heads = sharded_param.shape[1] // HEAD_SIZE + total_heads += n_local_heads + max_heads = max(max_heads, n_local_heads) + min_heads = min(min_heads, n_local_heads) + + for i in range(n_local_heads): + head_ids = torch.unique_consecutive(sharded_param[:, i * HEAD_SIZE:(i + 1) * HEAD_SIZE]) + assert len(head_ids) == 1 + seen_heads.add(head_ids.item()) + + assert max_heads == min_heads + 1 + assert total_heads == n_heads + assert len(seen_heads) == n_heads + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("n_heads_q, n_heads_kv, n_shards", [(20, 4, 8)]) +def test_gqa_uneven_sharding(n_heads_q: int, n_heads_kv: int, n_shards: int): + """ + We only test the uneven GQA test case because even GQA shards the attention output + in the exact same manner as MHA. + + Args: + n_heads_q (int): The number of query heads. + n_heads_kv (int): The number of key/value heads. + n_shards (int): The number of shards to test for. + """ + param = fill_with_head_ids(HEAD_SIZE, n_heads_q) + + min_heads = n_heads_q + max_heads = 0 + seen_heads = set() + total_heads = 0 + + for shard_rank in range(n_shards): + sharded_param = shard_attn_out_param(param, shard_rank, n_shards, HEAD_SIZE, n_heads_q, n_heads_kv) + n_heads_local_q, _ = get_local_heads(shard_rank, n_shards, n_heads_q, n_heads_kv) + + assert sharded_param.shape[-1] == HEAD_SIZE * n_heads_local_q + + n_local_heads = sharded_param.shape[1] // HEAD_SIZE + total_heads += n_local_heads + max_heads = max(max_heads, n_local_heads) + min_heads = min(min_heads, n_local_heads) + + for i in range(n_local_heads): + head_id = torch.unique_consecutive(sharded_param[:, i * HEAD_SIZE:(i + 1) * HEAD_SIZE]) + assert len(head_id) == 1 + seen_heads.add(head_id.item()) + + assert max_heads == min_heads + 1 + assert total_heads == n_heads_q + assert len(seen_heads) == n_heads_q diff --git a/tests/unit/inference/v2/model_implementations/sharding/test_mlp_sharding.py b/tests/unit/inference/v2/model_implementations/sharding/test_mlp_sharding.py new file mode 100644 index 000000000000..aac7e5391d8f --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/sharding/test_mlp_sharding.py @@ -0,0 +1,116 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.model_implementations.sharding import * + + +def round_up_to_256(x: int) -> int: + """ + Round up to the nearest multiple of 256. 
+ """ + return x + (256 - x % 256) + + +def make_params(model_dim: int, ffn_multiplier: int, n_experts: int, gated: bool = False) -> torch.Tensor: + """ + + """ + if gated: + mlp_1_intermediate = round_up_to_256(int(model_dim * ffn_multiplier * 4 / 3)) + mlp_2_intermediate = mlp_1_intermediate // 2 + else: + mlp_1_intermediate = ffn_multiplier * model_dim + mlp_2_intermediate = ffn_multiplier * model_dim + + mlp_1_shared_dim = torch.arange(mlp_1_intermediate, dtype=torch.float32, device=get_accelerator().current_device()) + + mlp_1_w = mlp_1_shared_dim.repeat_interleave(model_dim).reshape(mlp_1_intermediate, model_dim) + mlp_1_b = mlp_1_shared_dim + + mlp_2_shared_dim = torch.arange(mlp_2_intermediate, dtype=torch.float32, device=get_accelerator().current_device()) + mlp_2_w = mlp_2_shared_dim.repeat(model_dim).reshape(model_dim, mlp_2_intermediate) + mlp_2_b = torch.ones(model_dim, dtype=torch.float32, device=get_accelerator().current_device()) + + if n_experts > 1: + mlp_1_w = mlp_1_w.expand(n_experts, -1, -1) + mlp_1_b = mlp_1_b.expand(n_experts, -1) + mlp_2_w = mlp_2_w.expand(n_experts, -1, -1) + mlp_2_b = mlp_2_b.expand(n_experts, -1) + + return (mlp_1_w, mlp_1_b, mlp_2_w, mlp_2_b) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("model_dim, ffn_multiplier, n_shards", [(1024, 4, 1), (1024, 4, 8), (1024, 4, 6)]) +@pytest.mark.parametrize("n_experts", [1, 16]) +def test_even_ffn_sharding(model_dim: int, ffn_multiplier: int, n_shards: int, n_experts: int): + """ + FFN sharding tends to be much simpler than attention sharding since it works on larger granularities. + While the test case of (1024, 4, 6) is not a use case we're likely to see, this does ensure that + the sharding logic will round correctly for the alignments we care about. + """ + mlp_1_w, mlp_1_b, mlp_2_w, mlp_2_b = make_params(model_dim, ffn_multiplier, n_experts) + + total_ffn_dim = model_dim * ffn_multiplier + mapped_neurons = 0 + + is_moe = n_experts > 1 + + for shard_rank in range(n_shards): + shard_1_w = shard_mlp_1_param(mlp_1_w, shard_rank, n_shards, is_moe=is_moe) + shard_1_b = shard_mlp_1_param(mlp_1_b, shard_rank, n_shards, is_moe=is_moe) + shard_2_w = shard_mlp_2_param(mlp_2_w, shard_rank, n_shards, is_moe=is_moe) + shard_2_b = shard_mlp_2_param(mlp_2_b, shard_rank, n_shards, is_moe=is_moe) + + assert shard_1_w.shape[-2] == shard_2_w.shape[-1] + assert shard_1_w.shape[-2] % DEFAULT_SHARD_GRANULARITY == 0 + assert shard_1_w.shape[-2] == shard_1_b.shape[-1] + + mapped_neurons += shard_1_w.shape[-2] + + if shard_rank != 0: + assert shard_2_b is None + else: + assert shard_2_b.shape[-1] == model_dim + + assert mapped_neurons == total_ffn_dim + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("model_dim, ffn_multiplier, n_shards", [(1024, 4, 1), (1024, 4, 8), (1024, 4, 6)]) +@pytest.mark.parametrize("n_experts", [1, 16]) +def test_gated_ffn_sharding(model_dim: int, ffn_multiplier: int, n_shards: int, n_experts: int): + """ + Test the same cases assuming a gated regime. 
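As a concrete sizing example for the gated case, worked through with the same formula make_params uses above for model_dim=1024 and ffn_multiplier=4 (mlp_1 presumably carries both the gate and the up projection, so mlp_2 sees half of its width):

model_dim, ffn_multiplier = 1024, 4
raw = int(model_dim * ffn_multiplier * 4 / 3)  # 5461
mlp_1_intermediate = raw + (256 - raw % 256)   # 5632, rounded up to a multiple of 256
mlp_2_intermediate = mlp_1_intermediate // 2   # 2816
assert mlp_1_intermediate == 2 * mlp_2_intermediate  # matches the shard_1_w vs. shard_2_w check below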
+ """ + mlp_1_w, mlp_1_b, mlp_2_w, mlp_2_b = make_params(model_dim, ffn_multiplier, n_experts, gated=True) + + total_ffn_dim = round_up_to_256(int(model_dim * ffn_multiplier * 4 / 3)) + mapped_neurons = 0 + + is_moe = n_experts > 1 + + for shard_rank in range(n_shards): + shard_1_w = shard_mlp_1_param(mlp_1_w, shard_rank, n_shards, gated=True, is_moe=is_moe) + shard_1_b = shard_mlp_1_param(mlp_1_b, shard_rank, n_shards, gated=True, is_moe=is_moe) + shard_2_w = shard_mlp_2_param(mlp_2_w, shard_rank, n_shards, is_moe=is_moe) + shard_2_b = shard_mlp_2_param(mlp_2_b, shard_rank, n_shards, is_moe=is_moe) + + assert shard_1_w.shape[-2] == shard_2_w.shape[-1] * 2 + assert shard_1_w.shape[-2] % DEFAULT_SHARD_GRANULARITY == 0 + assert shard_1_w.shape[-2] == shard_1_b.shape[-1] + + mapped_neurons += shard_1_w.shape[-2] + + if shard_rank != 0: + assert shard_2_b is None + else: + assert shard_2_b.shape[-1] == model_dim + + assert mapped_neurons == total_ffn_dim diff --git a/tests/unit/inference/v2/model_implementations/sharding/test_qkv_sharding.py b/tests/unit/inference/v2/model_implementations/sharding/test_qkv_sharding.py new file mode 100644 index 000000000000..9a1cb9c09c64 --- /dev/null +++ b/tests/unit/inference/v2/model_implementations/sharding/test_qkv_sharding.py @@ -0,0 +1,251 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.model_implementations.sharding import * + + +def fill_with_head_ids(head_size: int, n_heads_q: int, n_heads_kv: Optional[int] = None) -> torch.Tensor: + """ + + """ + head_ids_q = torch.arange(n_heads_q, dtype=torch.half, device=get_accelerator().current_device()) + head_vals_q = head_ids_q.repeat_interleave(head_size * head_size * n_heads_q).reshape(n_heads_q * head_size, -1) + + if n_heads_kv is None: + return torch.cat([head_vals_q, head_vals_q, head_vals_q], dim=0) + + head_ids_k = torch.arange(n_heads_kv, dtype=torch.half, device=get_accelerator().current_device()) + head_vals_k = head_ids_k.repeat_interleave(head_size * head_size * n_heads_q).reshape(n_heads_kv * head_size, -1) + + return torch.cat([head_vals_q, head_vals_k, head_vals_k], dim=0) + + +def validate_inferred_shape(shard: torch.Tensor, head_size: int, n_local_q_heads: int, n_local_kv_heads: int): + """ + Validate that the leading dim of the shard is of the expected size and aligns with the sharding + logic for the attention computation itself. + """ + inferred_leading_dim = head_size * (n_local_q_heads + 2 * n_local_kv_heads) + assert shard.shape[0] == inferred_leading_dim + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("n_heads,n_shards", [(1, 1), (32, 1), (32, 8)]) +def test_even_mha_sharding(head_size: int, n_heads: int, n_shards: int): + """ + Test for MHA sharding. In these scenarios, we expect that each of the shards + should be the same size. 
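For orientation, fill_with_head_ids above stacks the Q, K and V blocks along the leading dimension, which is exactly what validate_inferred_shape checks per shard. A worked example of the fused width, matching the shape noted in test_gqa_input_shape_error at the end of this file:

head_size, n_heads_q, n_heads_kv = 64, 16, 4
fused_rows = (n_heads_q + 2 * n_heads_kv) * head_size
assert fused_rows == 1536  # 1024 Q rows + 256 K rows + 256 V rows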
+ """ + param = fill_with_head_ids(head_size, n_heads) + + heads_per_shard = n_heads // n_shards + + for shard_rank in range(n_shards): + + shard = shard_qkv_param(param, shard_rank, n_shards, head_size, n_heads, n_heads) + n_local_q_heads, n_local_kv_heads = get_local_heads(shard_rank, n_shards, n_heads, n_heads) + validate_inferred_shape(shard, head_size, n_local_q_heads, n_local_kv_heads) + + assert shard.shape == (3 * head_size * heads_per_shard, head_size * n_heads) + + heads = shard.chunk(heads_per_shard * 3, dim=0) + for i in range(heads_per_shard): + assert torch.all(heads[i] == i + shard_rank * heads_per_shard) + assert torch.all(heads[i + heads_per_shard] == i + shard_rank * heads_per_shard) + assert torch.all(heads[i + heads_per_shard * 2] == i + shard_rank * heads_per_shard) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("n_heads, n_shards", [(3, 2), (20, 8)]) +def test_unbalanced_mha_sharding(head_size: int, n_heads: int, n_shards: int): + """ + Test MHA sharding when the distribution of heads will not be equal across all ranks. + """ + param = fill_with_head_ids(head_size, n_heads) + + max_heads = 0 + min_heads = n_heads + total_heads = 0 + seen_heads = set() + + for shard_rank in range(n_shards): + shard = shard_qkv_param(param, shard_rank, n_shards, head_size, n_heads, n_heads) + n_local_q_heads, n_local_kv_heads = get_local_heads(shard_rank, n_shards, n_heads, n_heads) + validate_inferred_shape(shard, head_size, n_local_q_heads, n_local_kv_heads) + + n_heads_in_shard = shard.shape[0] // head_size // 3 + + max_heads = max(max_heads, n_heads_in_shard) + min_heads = min(min_heads, n_heads_in_shard) + total_heads += n_heads_in_shard + + heads = shard.chunk(n_heads_in_shard * 3, dim=0) + + for local_head_id in range(n_heads_in_shard): + head_qkv = torch.cat([ + heads[local_head_id], heads[local_head_id + n_heads_in_shard], + heads[local_head_id + 2 * n_heads_in_shard] + ], + dim=0) + assert head_qkv.shape == (3 * head_size, head_size * n_heads) + + global_head_id = torch.unique_consecutive(head_qkv) + assert len(global_head_id) == 1 + + seen_heads.add(global_head_id.item()) + + assert max_heads - min_heads <= 1 + assert total_heads == n_heads + assert len(seen_heads) == n_heads + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("n_heads_q, n_heads_kv, n_shards", [(4, 2, 1), (8, 2, 1), (64, 16, 8)]) +def test_gqa_even_sharding(head_size: int, n_heads_q: int, n_heads_kv: int, n_shards: int): + """ + Test GQA sharding when the KV heads are evenly divisible by the number of shards. 
+ """ + param = fill_with_head_ids(head_size, n_heads_q, n_heads_kv) + + n_kv_heads_in_shard = n_heads_kv // n_shards + n_q_heads_in_shard = n_heads_q // n_shards + + for shard_rank in range(n_shards): + shard = shard_qkv_param(param, shard_rank, n_shards, head_size, n_heads_q, n_heads_kv) + n_local_q_heads, n_local_kv_heads = get_local_heads(shard_rank, n_shards, n_heads_q, n_heads_kv) + validate_inferred_shape(shard, head_size, n_local_q_heads, n_local_kv_heads) + + assert shard.shape[0] == (n_q_heads_in_shard + n_kv_heads_in_shard * 2) * head_size + + q = shard[:n_q_heads_in_shard * head_size] + k = shard[n_q_heads_in_shard * head_size:(n_q_heads_in_shard + n_kv_heads_in_shard) * head_size] + v = shard[(n_q_heads_in_shard + n_kv_heads_in_shard) * head_size:] + + for local_head_id in range(n_q_heads_in_shard): + assert torch.all(q[local_head_id * head_size:(local_head_id + 1) * head_size] == local_head_id + + shard_rank * n_q_heads_in_shard) + + for local_head_id in range(n_kv_heads_in_shard): + assert torch.all(k[local_head_id * head_size:(local_head_id + 1) * head_size] == local_head_id + + shard_rank * n_kv_heads_in_shard) + assert torch.all(v[local_head_id * head_size:(local_head_id + 1) * head_size] == local_head_id + + shard_rank * n_kv_heads_in_shard) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("n_heads_q, n_heads_kv, n_shards", [(4, 2, 4), (20, 4, 8)]) +def test_gqa_uneven_sharding(head_size: int, n_heads_q: int, n_heads_kv: int, n_shards: int): + """ + Test GQA sharding when there are more shards than KV heads. + """ + param = fill_with_head_ids(head_size, n_heads_q, n_heads_kv) + + n_kv_heads_in_shard = 1 + n_shards_per_kv_head = n_shards // n_heads_kv + + max_heads = 0 + min_heads = n_heads_q + total_heads = 0 + seen_heads = set() + + for shard_rank in range(n_shards): + shard = shard_qkv_param(param, shard_rank, n_shards, head_size, n_heads_q, n_heads_kv) + n_local_q_heads, n_local_kv_heads = get_local_heads(shard_rank, n_shards, n_heads_q, n_heads_kv) + validate_inferred_shape(shard, head_size, n_local_q_heads, n_local_kv_heads) + + local_n_heads_q = (shard.shape[0] - 2 * n_kv_heads_in_shard * head_size) // head_size + + max_heads = max(max_heads, local_n_heads_q) + min_heads = min(min_heads, local_n_heads_q) + total_heads += local_n_heads_q + + q = shard[:local_n_heads_q * head_size] + kv = shard[local_n_heads_q * head_size:] + + for local_head_id in range(local_n_heads_q): + q_head_id = torch.unique_consecutive(q[local_head_id * head_size:(local_head_id + 1) * head_size]) + assert len(q_head_id) == 1 + + seen_heads.add(q_head_id.item()) + + kv_id_calc = shard_rank // n_shards_per_kv_head + kv_id = torch.unique_consecutive(kv) + assert len(kv_id) == 1 + assert kv_id.item() == kv_id_calc + + assert max_heads - min_heads <= 1 + assert total_heads == n_heads_q + assert len(seen_heads) == n_heads_q + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("n_heads, n_shards", [(6, 8)]) +def test_unsupported_mha_configs(head_size: int, n_heads: int, n_shards: int): + """ + Sharding should fail if there are fewer heads than shards. + + TODO(cmikeh2): Look to support this configuration. 
+ """ + param = fill_with_head_ids(head_size, n_heads) + + for shard_rank in range(n_shards): + with pytest.raises(ValueError): + shard_qkv_param(param, shard_rank, n_shards, head_size, n_heads, n_heads) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("n_heads_q, n_heads_kv, n_shards", [(5, 2, 1), (40, 10, 8), (30, 5, 8)]) +def test_unsupported_gqa_configs(head_size: int, n_heads_q: int, n_heads_kv: int, n_shards: int): + """ + GQA has stricter requirements. We must be able to evenly shard or distribute the KV heads. + + Test cases are to test the following preconditions specifically: + 1. n_heads_q % n_heads_kv == 0 + 2. We must be able to evenly distribute KV heads + 3. We must be able to evely split KV heads + """ + param = fill_with_head_ids(head_size, n_heads_q, n_heads_kv) + + for shard_rank in range(n_shards): + with pytest.raises(ValueError): + shard_qkv_param(param, shard_rank, n_shards, head_size, n_heads_q, n_heads_kv) + + +@pytest.mark.inference_v2 +def test_mha_input_shape_error(): + + param = torch.empty(256, 128) + + n_heads = 2 + head_size = 64 + + with pytest.raises(ValueError): + shard_qkv_param(param, 0, 1, 64) + + +@pytest.mark.inference_v2 +def test_gqa_input_shape_error(): + + head_size = 64 + n_heads_q = 16 + n_heads_kv = 4 + + # Correct shape is 1536 (=16 * 64 + 2 * 4 * 64), 1024 + param = torch.empty(2048, 1024) + + with pytest.raises(ValueError): + shard_qkv_param(param, 0, 1, head_size, n_heads_q, n_heads_kv) diff --git a/tests/unit/inference/v2/modules/__init__.py b/tests/unit/inference/v2/modules/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/v2/modules/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/modules/test_blas_linear_module.py b/tests/unit/inference/v2/modules/test_blas_linear_module.py new file mode 100644 index 000000000000..f4d0b1991238 --- /dev/null +++ b/tests/unit/inference/v2/modules/test_blas_linear_module.py @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum, is_gated +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSLinearConfig +from deepspeed.inference.v2.modules.interfaces import DSLinearRegistry +from ...v2.inference_test_utils import allclose + + +def reference_implementation(hidden_states: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], + act_type: ActivationType) -> torch.Tensor: + dtype = hidden_states.dtype + out_states = torch.nn.functional.linear(hidden_states, weight, bias) + out_states.float() + + if is_gated(act_type): + act_func_map = { + ActivationType.ReGLU: torch.nn.functional.relu, + ActivationType.GEGLU: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.SiGLU: torch.nn.functional.silu, + } + + act_act = out_states[..., ::2] + act_linear = out_states[..., 1::2] + + act_act = act_func_map[act_type](act_act) + out_states = act_act * act_linear + else: + act_func_map = { + ActivationType.RELU: torch.nn.functional.relu, + ActivationType.GELU: torch.nn.functional.gelu, + ActivationType.SILU: torch.nn.functional.silu, + ActivationType.IDENTITY: lambda x: x, + } + + out_states = act_func_map[act_type](out_states) + return out_states.to(dtype) + + +def _blas_linear_helper(tokens: int, + in_channels: int, + out_channels: int, + dtype: DtypeEnum, + act_fn: ActivationType, + use_bias: bool = True) -> None: + linear_config = DSLinearConfig(max_tokens=2048, + in_channels=in_channels, + out_channels=out_channels, + activation=act_fn, + input_dtype=dtype, + output_dtype=dtype) + + bundle = ConfigBundle(name='blas_fp_linear', config=linear_config) + + module = DSLinearRegistry.instantiate_config(bundle) + + # Input vals + hidden_states = torch.randn( + (tokens, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + + weight_out_channels = 2 * out_channels if is_gated(act_fn) else out_channels + weight = torch.randn( + (weight_out_channels, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + if use_bias: + bias = torch.randn( + (weight_out_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + else: + bias = None + + # Reference output + ref_output = reference_implementation(hidden_states, weight, bias, act_fn) + + # New output + ds_output = module(hidden_states, weight, bias) + + # Check + assert allclose(ds_output, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, in_channels, out_channels", [(1, 4608, 1728), (37, 8192, 4096), (1280, 3072, 6144)]) +def test_blas_linear_shapes(tokens: int, in_channels: int, out_channels: int) -> None: + + _blas_linear_helper(tokens, in_channels, out_channels, DtypeEnum.fp16, ActivationType.IDENTITY) + + +all_acts = [ + ActivationType.RELU, + ActivationType.GELU, + ActivationType.SILU, + ActivationType.GEGLU, + ActivationType.ReGLU, + ActivationType.SiGLU, +] + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("act_fn", all_acts) +@pytest.mark.parametrize("use_bias", [True, False]) +def test_blas_linear_act_fn(act_fn: ActivationType, use_bias: bool) -> None: + + _blas_linear_helper(283, 512, 4096, DtypeEnum.fp16, act_fn, use_bias=use_bias) diff --git a/tests/unit/inference/v2/modules/test_blocked_attn.py 
b/tests/unit/inference/v2/modules/test_blocked_attn.py new file mode 100644 index 000000000000..215ad64636b1 --- /dev/null +++ b/tests/unit/inference/v2/modules/test_blocked_attn.py @@ -0,0 +1,210 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import itertools + +from typing import List, Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSSelfAttentionConfig, PositionalEmbeddingType +from deepspeed.inference.v2.modules.interfaces import DSSelfAttentionRegistry, DSSelfAttentionBase + +from ..kernels.ragged_ops.ragged_testing_utils import build_batch_and_manager +from ...v2.inference_test_utils import allclose + +try: + from flash_attn.flash_attn_interface import flash_attn_varlen_func + validate_accuracy = True +except ImportError: + validate_accuracy = False + + +def _blocked_flash_testing_helper(head_size: int, + n_heads_q: int, + n_heads_kv: int, + seq_params: List[Tuple[int, int]], + trained_freqs: bool = None) -> None: + """ + Helper function for testing blocked flash attention. This implementation is based on + the implemnentation in ``unit.inference.kernels.ragged_ops.test_blocked_flash`` but + integrates functionality to validate the composability. + """ + if trained_freqs is None: + embed_type = PositionalEmbeddingType.none + embed_args = {} + else: + embed_type = PositionalEmbeddingType.rotate_half + if trained_freqs: + embed_args = {'trained_freqs': True} + else: + embed_args = {'trained_freqs': False} + + attn_config = DSSelfAttentionConfig(max_tokens=2048, + n_heads_q=n_heads_q, + n_heads_kv=n_heads_kv, + head_size=head_size, + max_sequences=32, + positional_embedding_type=embed_type, + positional_embedding_args=embed_args) + + config = ConfigBundle(name='dense_blocked_attention', config=attn_config) + attn_module: DSSelfAttentionBase = DSSelfAttentionRegistry.instantiate_config(config) + + kv_block_size = attn_module.kv_block_size + + kvs = [] + for _, history_len in seq_params: + if history_len > 0: + kvs.append( + torch.randn((history_len, 2 * n_heads_kv * head_size), + device=get_accelerator().current_device(), + dtype=torch.float16)) + else: + kvs.append(None) + + batch, state_manager, _ = build_batch_and_manager(seq_params, head_size, n_heads_kv, kv_block_size, kv_fill=kvs) + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=get_accelerator().current_device(), + dtype=torch.float16) + + kv_cache = state_manager.get_cache(0) + + attn_module.build_atoms(batch) + if not trained_freqs: + out = attn_module(qkv, kv_cache, batch) + else: + inv_freqs = torch.randn((head_size // 2, ), device=get_accelerator().current_device(), dtype=torch.float16) + out = attn_module(qkv, kv_cache, batch, inv_freqs) + + if validate_accuracy and trained_freqs is None: + cu_seqlens_q = torch.tensor([0] + list(itertools.accumulate([seq[0] for seq in seq_params])), + dtype=torch.int32, + device=get_accelerator().current_device()) + cu_seqlens_kv = torch.tensor([0] + list(itertools.accumulate([seq[1] + seq[0] for seq in seq_params])), + dtype=torch.int32, + device=get_accelerator().current_device()) + + inflight_kv = qkv[:, head_size * n_heads_q:] + full_kvs = [] + for i, kv in enumerate(kvs): + if kv is not None: + full_kvs.append(torch.cat([kv, inflight_kv[cu_seqlens_q[i]:cu_seqlens_q[i + 1]]], dim=0)) + else: + 
full_kvs.append(inflight_kv[cu_seqlens_q[i]:cu_seqlens_q[i + 1]]) + run_kvs = torch.cat(full_kvs, dim=0) + k = run_kvs[:, :head_size * n_heads_kv] + v = run_kvs[:, head_size * n_heads_kv:] + + q = qkv[:, :head_size * n_heads_q] + q_ref = q.reshape((batch.current_tokens, n_heads_q, head_size)) + k_ref = k.reshape((k.shape[0], n_heads_kv, head_size)) + v_ref = v.reshape((v.shape[0], n_heads_kv, head_size)) + + max_seqlen_q = max([seq[0] for seq in seq_params]) + max_seqlen_kv = max([seq[1] + seq[0] for seq in seq_params]) + + ref_o = flash_attn_varlen_func(q_ref, + k_ref, + v_ref, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + softmax_scale=1.0, + causal=True) + + ref_o = ref_o.reshape(batch.current_tokens, head_size * n_heads_q) + + assert allclose(out, ref_o) + + get_accelerator().synchronize() + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens", [2, 33, 65, 128, 256, 2037]) +def test_single_prompt(n_tokens: int) -> None: + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(n_tokens, 0)] + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("prompt_lengths", [(128, 128), (192, 38), (514, 713), (83, 312, 610)]) +def test_multiple_prompts(prompt_lengths: Tuple[int, int]) -> None: + """ + Test multiple prompts in a single batch. + """ + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(prompt_lengths[i], 0) for i in range(len(prompt_lengths))] + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("seq_params", [(1, 34), (43, 40), (1, 144), (64, 128), (332, 628)]) +def test_continuation(seq_params: Tuple[int, int]) -> None: + """ + Test continued generation/prompt processing. 
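+
+    Each tuple in seq_params is consumed by the helper as (new_tokens, history_len), so e.g.
+    (1, 144) models decoding a single new token against a 144-token KV history.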
+ """ + head_size = 64 + n_heads_q = 32 + n_heads_kv = 32 + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, [seq_params]) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("head_size", [64, 128]) +def test_head_size(head_size: int) -> None: + n_heads_q = 16 + n_heads_kv = 16 + seq_params = [(128, 128), (192, 38), (1, 814)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("head_config", [(32, 8), (64, 16), (40, 8)]) +def test_gqa(head_config: Tuple[int, int]) -> None: + head_size = 128 + n_heads_q = head_config[0] + n_heads_kv = head_config[1] + + seq_params = [(128, 128), (192, 38), (1, 814)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +def test_fully_composed() -> None: + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(332, 628), (1, 718), (1, 323), (180, 5), (224, 0)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("trained_freqs", [True, False]) +def test_rotary_emb(trained_freqs: bool) -> None: + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(332, 628), (1, 718), (1, 323), (180, 5), (224, 0)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params, trained_freqs=trained_freqs) diff --git a/tests/unit/inference/v2/modules/test_cuda_pre_ln_module.py b/tests/unit/inference/v2/modules/test_cuda_pre_ln_module.py new file mode 100644 index 000000000000..386f3b3ef0b3 --- /dev/null +++ b/tests/unit/inference/v2/modules/test_cuda_pre_ln_module.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional, Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSNormConfig +from deepspeed.inference.v2.modules.interfaces import DSPreNormRegistry +from ...v2.inference_test_utils import get_dtypes, allclose + + +def reference_implementation(residual: torch.Tensor, hidden_states: Optional[torch.Tensor], gamma: torch.Tensor, + beta: torch.Tensor, epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]: + dtype = residual.dtype + + residual = residual.to(torch.float32) + gamma = gamma.to(torch.float32) + beta = beta.to(torch.float32) + + if hidden_states is not None: + hidden_states = hidden_states.to(torch.float32) + residual = residual + hidden_states + hidden_states = torch.nn.functional.layer_norm(residual, (residual.size(-1), ), + weight=gamma, + bias=beta, + eps=epsilon) + return residual.to(dtype), hidden_states.to(dtype) + + +def _pre_ln_test_helper(n_tokens: int, n_channels: int, dtype: torch.dtype, res_add: bool = False): + config = DSNormConfig(max_tokens=2048, + type="layer_norm", + channels=n_channels, + residual_dtype=dtype, + input_dtype=dtype, + output_dtype=dtype, + eps=1e-5) + bundle = ConfigBundle(name='cuda_pre_ln', config=config) + + # Input vals + if res_add: + hidden_states = torch.randn((n_tokens, n_channels), + dtype=dtype, + device=get_accelerator().current_device_name()) + else: + hidden_states = None + + residual = torch.randn((n_tokens, n_channels), dtype=dtype, device=get_accelerator().current_device_name()) + gamma = torch.randn((n_channels), dtype=torch.float32, 
device=get_accelerator().current_device_name()) + beta = torch.rand((n_channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + epsilon = 1e-5 + + # Reference output + ref_residual, ref_output = reference_implementation(residual, hidden_states, gamma, beta, epsilon) + + # New output + pre_ln_module = DSPreNormRegistry.instantiate_config(bundle) + gamma = pre_ln_module.transform_param(gamma) + beta = pre_ln_module.transform_param(beta) + + ds_residual, ds_output = pre_ln_module(residual, hidden_states, gamma, beta) + + # Check + assert allclose(ds_residual, ref_residual) + assert allclose(ds_output, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, channels", [(1, 2048), (37, 8192), (1280, 768), (2048, 5120)]) +def test_token_channels(tokens: int, channels: int) -> None: + _pre_ln_test_helper(tokens, channels, torch.float16) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("dtype", get_dtypes(include_float=False)) +def test_dtype(dtype: torch.dtype) -> None: + _pre_ln_test_helper(733, 2560, dtype) + + +@pytest.mark.inference_v2_ops +def test_no_res_add(): + _pre_ln_test_helper(733, 2560, torch.float16, res_add=False) diff --git a/tests/unit/inference/v2/modules/test_custom_module.py b/tests/unit/inference/v2/modules/test_custom_module.py new file mode 100644 index 000000000000..eb54b7a913f2 --- /dev/null +++ b/tests/unit/inference/v2/modules/test_custom_module.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.interfaces import DSPostNormRegistry +from deepspeed.inference.v2.modules.configs import DSNormConfig +from deepspeed.inference.v2.modules.implementations import cuda_post_ln +from ...v2.inference_test_utils import allclose + + +def reference_implementation(residual: torch.Tensor, hidden_states: torch.Tensor, gamma: torch.Tensor, + beta: torch.Tensor, epsilon: float) -> torch.Tensor: + residual_f = residual.to(torch.float32) + hidden_states_f = hidden_states.to(torch.float32) + gamma_f = gamma.to(torch.float32) + beta_f = beta.to(torch.float32) + return torch.nn.functional.layer_norm(residual_f + hidden_states_f, (hidden_states_f.size(-1), ), + weight=gamma_f, + bias=beta_f, + eps=epsilon).to(hidden_states.dtype) + + +@DSPostNormRegistry.register_module +class CustomPostLNModule(cuda_post_ln.DSPostLNCUDAModule): + + @staticmethod + def name(): + return 'custom_post_ln' + + +""" +Here, we explicitly register an LN implementation outside the core deepspeed repo. This should +validate that the registry is working as expected and we can implement modules outside the core +repo. 
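+
+The only requirements are decorating the implementation with @DSPostNormRegistry.register_module
+and overriding the static name() method; a ConfigBundle can then select the custom module by that
+name, as test_custom_registration below does with name='custom_post_ln'.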
+""" + + +@pytest.mark.inference_v2_ops +def test_custom_registration(): + channels = 4096 + dtype = torch.float16 + tokens = 1024 + + config = DSNormConfig(max_tokens=2048, + type="layer_norm", + channels=channels, + residual_dtype=dtype, + input_dtype=dtype, + output_dtype=dtype, + eps=1e-5) + bundle = ConfigBundle(name='custom_post_ln', config=config) + + # Input vals + hidden_states = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + residual = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + gamma = torch.randn((channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + beta = torch.rand((channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + epsilon = 1e-5 + + # Reference output + ref_output = reference_implementation(residual, hidden_states, gamma, beta, epsilon) + + # New output + post_ln_module = DSPostNormRegistry.instantiate_config(bundle) + gamma = post_ln_module.transform_param(gamma) + beta = post_ln_module.transform_param(beta) + ds_output, _ = post_ln_module(residual, hidden_states, gamma, beta) + + # Check + assert allclose(ds_output, ref_output) diff --git a/tests/unit/inference/v2/modules/test_cutlass_moe.py b/tests/unit/inference/v2/modules/test_cutlass_moe.py new file mode 100644 index 000000000000..e21170c9ed8f --- /dev/null +++ b/tests/unit/inference/v2/modules/test_cutlass_moe.py @@ -0,0 +1,214 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSMoEConfig +from deepspeed.inference.v2.modules.interfaces import DSMoERegistry + +from ..kernels.ragged_ops.ragged_testing_utils import build_simple_batch +from ...v2.inference_test_utils import allclose, get_dtypes + + +def _gating_reference(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Reference gating code. + """ + logits = logits.float() + probs = torch.nn.functional.softmax(logits, dim=1) + + indices1_s = torch.argmax(probs, dim=-1) + mask1 = torch.nn.functional.one_hot(indices1_s, num_classes=logits.shape[-1]) + indices_mask = mask1.sum(dim=1) * logits.shape[-1] - 1 + indices1_s = torch.min(indices1_s, indices_mask) + + gates1_s = (probs * mask1).sum(dim=1) + + sorted_indices = indices1_s.sort()[1] + original_indices = sorted_indices.sort()[1] + + exp_count = torch.bincount(indices1_s, minlength=logits.shape[-1]).long() + exp_count_cumsum = exp_count.cumsum(dim=0) + + return sorted_indices, original_indices, exp_count_cumsum, gates1_s + + +def _reference_impl(hidden_states: torch.Tensor, gate_weight: torch.Tensor, mlp_1_w: torch.Tensor, + mlp_2_w: torch.Tensor, mlp_1_b: torch.Tensor, mlp_2_b: torch.Tensor, + act_fn: ActivationType) -> torch.Tensor: + """ + Reference implementation of the MoE module. 
+ """ + + act_fn_dict = { + ActivationType.GELU: torch.nn.functional.gelu, + ActivationType.RELU: torch.nn.functional.relu, + ActivationType.SILU: torch.nn.functional.silu, + ActivationType.IDENTITY: lambda x: x, + } + + logits = torch.matmul(hidden_states, gate_weight.t()) + sorted_indices, original_indices, exp_count_cumsum, gate_scales = _gating_reference(logits) + + moe_input = hidden_states[sorted_indices] + + output_unordered = torch.empty_like(hidden_states) + + for expert_idx in range(mlp_1_w.shape[0]): + min_bound = 0 if expert_idx == 0 else exp_count_cumsum[expert_idx - 1] + max_bound = exp_count_cumsum[expert_idx] + + input_slice = moe_input[min_bound:max_bound] + intermediate = torch.nn.functional.linear(input_slice, mlp_1_w[expert_idx], mlp_1_b[expert_idx]) + + intermediate = act_fn_dict[act_fn](intermediate) + output_slice = torch.nn.functional.linear(intermediate, mlp_2_w[expert_idx], mlp_2_b[expert_idx]) + + output_unordered[min_bound:max_bound] = output_slice + + output = output_unordered[original_indices] + + output.mul_(gate_scales.unsqueeze(-1)).reshape(hidden_states.shape) + return output + + +def _cutlass_moe_testing_helper(tokens: int, + in_channels: int, + intermediate_dim: int, + experts: int, + dtype: int, + activation_type: ActivationType = ActivationType.GELU, + use_bias: bool = True, + iters: int = 1) -> None: + + config = DSMoEConfig(max_tokens=4096, + model_dim=in_channels, + intermediate_features=intermediate_dim, + n_experts=experts, + activation=activation_type, + input_dtype=dtype, + output_dtype=dtype) + + implementation_config = {"weight_dtype": DtypeEnum(dtype)} + + bundle = ConfigBundle(name='cutlass_multi_gemm_moe', config=config, implementation_config=implementation_config) + moe_module = DSMoERegistry.instantiate_config(bundle) + + batch = build_simple_batch([tokens]) + + # Parameters + gate_weight = torch.randn( + (experts, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + mlp_1_w = torch.randn( + (experts, intermediate_dim, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + mlp_2_w = torch.randn( + (experts, in_channels, intermediate_dim), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + if use_bias: + mlp_1_b = torch.randn( + (experts, intermediate_dim), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + mlp_2_b = torch.randn( + (experts, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + else: + mlp_1_b = None + mlp_2_b = None + + gate_ds = moe_module.transform_gate_param(gate_weight) + mlp_1_w_ds = moe_module.transform_moe_mlp_1_param(mlp_1_w) + mlp_1_b_ds = moe_module.transform_moe_mlp_1_param(mlp_1_b) + mlp_2_w_ds = moe_module.transform_moe_mlp_2_param(mlp_2_w) + mlp_2_b_ds = moe_module.transform_moe_mlp_2_param(mlp_2_b) + + for _ in range(iters): + # Input vals + hidden_states = torch.randn( + (tokens, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + # Reference implementation + ref_output = _reference_impl(hidden_states, gate_weight, mlp_1_w, mlp_2_w, mlp_1_b, mlp_2_b, activation_type) + + output = moe_module(hidden_states, + batch, + gate_ds, + mlp_1_w_ds, + mlp_2_w_ds, + mlp_1_b=mlp_1_b_ds, + mlp_2_b=mlp_2_b_ds) + + # Increase the tolerance for larger meta ops since the error is additive + assert allclose(output, ref_output, tolerances=(1e-2, 1e-2)) + + get_accelerator().synchronize() + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("experts", [2, 32, 
64]) +def test_expert_variance(experts: int) -> None: + _cutlass_moe_testing_helper(tokens=876, + in_channels=4096, + intermediate_dim=2048, + experts=experts, + dtype=DtypeEnum.fp16, + activation_type=ActivationType.IDENTITY, + use_bias=True) + + +@pytest.mark.inference_v2_ops +def test_successive_inputs(): + """ + The CUTLASS MoE uses persistent state (expert counts) that is assumed to be cleared + on each forward pass. This ensures that the module is clearing that metadata. + """ + _cutlass_moe_testing_helper(tokens=876, + in_channels=4096, + intermediate_dim=2048, + experts=64, + dtype=DtypeEnum.fp16, + activation_type=ActivationType.IDENTITY, + use_bias=True, + iters=10) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("dtype", get_dtypes(include_float=False)) +def test_dtypes(dtype: torch.dtype) -> None: + _cutlass_moe_testing_helper(tokens=876, + in_channels=4096, + intermediate_dim=2048, + experts=64, + dtype=DtypeEnum(dtype), + activation_type=ActivationType.IDENTITY, + use_bias=True) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("activation_type", [ActivationType.GELU, ActivationType.RELU, ActivationType.SILU]) +def test_activation_types(activation_type: ActivationType) -> None: + _cutlass_moe_testing_helper(tokens=876, + in_channels=4096, + intermediate_dim=2048, + experts=64, + dtype=DtypeEnum.fp16, + activation_type=activation_type, + use_bias=True) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("in_channels, out_channels", [(4096, 2048), (2048, 8192), (6144, 3072)]) +def test_in_out_channels(in_channels: int, out_channels: int) -> None: + _cutlass_moe_testing_helper(tokens=876, + in_channels=in_channels, + intermediate_dim=out_channels, + experts=64, + dtype=DtypeEnum.fp16, + activation_type=ActivationType.IDENTITY, + use_bias=True) diff --git a/tests/unit/inference/v2/modules/test_post_ln_module.py b/tests/unit/inference/v2/modules/test_post_ln_module.py new file mode 100644 index 000000000000..f9dcfd272170 --- /dev/null +++ b/tests/unit/inference/v2/modules/test_post_ln_module.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSNormConfig +from deepspeed.inference.v2.modules.interfaces import DSPostNormRegistry +from ...v2.inference_test_utils import get_dtypes, allclose + + +def reference_implementation(residual: torch.Tensor, hidden_states: torch.Tensor, gamma: torch.Tensor, + beta: torch.Tensor, epsilon: float) -> torch.Tensor: + residual_f = residual.to(torch.float32) + hidden_states_f = hidden_states.to(torch.float32) + gamma_f = gamma.to(torch.float32) + beta_f = beta.to(torch.float32) + return torch.nn.functional.layer_norm(residual_f + hidden_states_f, (hidden_states_f.size(-1), ), + weight=gamma_f, + bias=beta_f, + eps=epsilon).to(hidden_states.dtype) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, channels", [(1, 2048), (37, 8192), (1280, 768), (2048, 5120)]) +@pytest.mark.parametrize("dtype", get_dtypes()) +def test_cuda_post_ln_module(tokens: int, channels: int, dtype: torch.dtype) -> None: + config = DSNormConfig(max_tokens=2048, + type="layer_norm", + channels=channels, + residual_dtype=dtype, + input_dtype=dtype, + output_dtype=dtype, + eps=1e-5) + bundle = ConfigBundle(name='cuda_post_ln', config=config) + + # Input vals + hidden_states = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + residual = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + gamma = torch.randn((channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + beta = torch.rand((channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + epsilon = 1e-5 + + # Reference output + ref_output = reference_implementation(residual, hidden_states, gamma, beta, epsilon) + + # New output + post_ln_module = DSPostNormRegistry.instantiate_config(bundle) + gamma = post_ln_module.transform_param(gamma) + beta = post_ln_module.transform_param(beta) + ds_output, _ = post_ln_module(residual, hidden_states, gamma, beta) + + # Check + assert allclose(ds_output, ref_output) diff --git a/tests/unit/inference/v2/modules/test_pre_rms_module.py b/tests/unit/inference/v2/modules/test_pre_rms_module.py new file mode 100644 index 000000000000..bbd108a35a5a --- /dev/null +++ b/tests/unit/inference/v2/modules/test_pre_rms_module.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional, Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSNormConfig +from deepspeed.inference.v2.modules.interfaces import DSPreNormRegistry +from ...v2.inference_test_utils import get_dtypes, allclose + + +def reference_implementation(residual: torch.Tensor, hidden_states: Optional[torch.Tensor], gamma: torch.Tensor, + epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]: + dtype = residual.dtype + + if hidden_states is not None: + hidden_states = hidden_states + residual = residual + hidden_states + + rms_vals = residual.to(torch.float32) + variance = rms_vals.pow(2).mean(-1, keepdim=True) + rms_vals = rms_vals * torch.rsqrt(variance + epsilon) + + if gamma.dtype in [torch.float16, torch.bfloat16]: + rms_vals = rms_vals.to(gamma.dtype) + + hidden_states = gamma * rms_vals + + return residual.to(dtype), hidden_states.to(dtype) + + +def _pre_rms_test_helper(n_tokens: int, n_channels: int, dtype: torch.dtype, res_add: bool = False): + config = DSNormConfig(max_tokens=2048, + type="rms_norm", + channels=n_channels, + residual_dtype=dtype, + input_dtype=dtype, + output_dtype=dtype, + eps=1e-5) + bundle = ConfigBundle(name='cuda_pre_rms', config=config) + + # Input vals + if res_add: + hidden_states = torch.randn((n_tokens, n_channels), + dtype=dtype, + device=get_accelerator().current_device_name()) + else: + hidden_states = None + + residual = torch.randn((n_tokens, n_channels), dtype=dtype, device=get_accelerator().current_device_name()) + gamma = torch.randn((n_channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + epsilon = 1e-5 + + # Reference output + ref_residual, ref_output = reference_implementation(residual, hidden_states, gamma, epsilon) + + # New output + pre_ln_module = DSPreNormRegistry.instantiate_config(bundle) + gamma = pre_ln_module.transform_param(gamma) + + ds_residual, ds_output = pre_ln_module(residual, hidden_states, gamma) + + # Check + assert allclose(ds_residual, ref_residual) + assert allclose(ds_output, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, channels", [(1, 2048), (37, 8192), (1280, 768), (2048, 5120)]) +def test_token_channels(tokens: int, channels: int) -> None: + _pre_rms_test_helper(tokens, channels, torch.float16) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("dtype", get_dtypes(include_float=False)) +def test_dtype(dtype: torch.dtype) -> None: + _pre_rms_test_helper(733, 2560, dtype) + + +@pytest.mark.inference_v2_ops +def test_no_res_add(): + _pre_rms_test_helper(733, 2560, torch.float16, res_add=False) diff --git a/tests/unit/inference/v2/ragged/__init__.py b/tests/unit/inference/v2/ragged/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/v2/ragged/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/v2/ragged/test_blocked_allocator.py b/tests/unit/inference/v2/ragged/test_blocked_allocator.py new file mode 100644 index 000000000000..4596e81c5652 --- /dev/null +++ b/tests/unit/inference/v2/ragged/test_blocked_allocator.py @@ -0,0 +1,166 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import random +from typing import List + +import pytest +import torch + +from deepspeed.inference.v2.ragged.blocked_allocator import BlockedAllocator + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('bad_size', [0, -1]) +def test_bad_initialization(bad_size: int) -> None: + with pytest.raises(ValueError): + BlockedAllocator(bad_size) + + +@pytest.mark.inference_v2 +def test_allocation() -> None: + + allocator = BlockedAllocator(16) + + a1 = allocator.allocate(4) + assert a1.numel() == 4 + assert allocator.free_blocks == 12 + + a2_allocs = [] + for i in range(3): + a2_allocs.append(allocator.allocate(2)) + assert allocator.free_blocks == 12 - (i + 1) * 2 + + a3 = allocator.allocate(6) + assert a3.numel() == 6 + + assert allocator.free_blocks == 0 + + # Test that we can't allocate more blocks than we have. + with pytest.raises(ValueError): + allocator.allocate(1) + + all_vals = torch.cat([a1, *a2_allocs, a3], dim=0) + unique_vals = torch.unique(all_vals, sorted=False) + assert unique_vals.numel() == all_vals.numel() + + +@pytest.mark.inference_v2 +def test_too_large_allocation(): + allocator = BlockedAllocator(16) + + with pytest.raises(ValueError): + allocator.allocate(17) + + +@pytest.mark.inference_v2 +def test_deallocation() -> None: + allocator = BlockedAllocator(16) + + # Allocate + all_blocks = allocator.allocate(16) + assert allocator.free_blocks == 0 + + # Deallocate all blocks + allocator.free(all_blocks) + assert allocator.free_blocks == 16 + + # Get all the blocks again + all_blocks = allocator.allocate(16) + + # Deallocate in chunks + c1 = all_blocks[:4] + c2 = all_blocks[4:8] + + allocator.free(c1) + assert allocator.free_blocks == 4 + + allocator.free(c2) + assert allocator.free_blocks == 8 + + with pytest.raises(ValueError): + allocator.free(c1) + + with pytest.raises(ValueError): + allocator.free(c2) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('index', [-1, 2]) +def test_invalid_dealloc_indices(index: int): + allocator = BlockedAllocator(1) + + with pytest.raises(ValueError): + allocator.free(torch.tensor([index])) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('index', [-1, 2]) +def test_invalid_alloc_indices(index: int): + allocator = BlockedAllocator(1) + allocator.allocate(1) + + to_free = [0, index] + + with pytest.raises(ValueError): + allocator.free(torch.tensor(to_free)) + + # Block 0 should not be freed if passed with an invalid index. + assert allocator.free_blocks == 0 + + allocator.free(torch.tensor([0])) + assert allocator.free_blocks == 1 + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('test_iters', [8192]) +def test_long_running_allocation(test_iters: int) -> None: + """ + Evaluate the stability of the allocator over a longer sequence of allocations/deallocations. 
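+
+    Two invariants are tracked: no block id is ever live in more than one allocation at a
+    time (checked after each step via validate_uniqueness), and block conservation at the
+    end of the run:
+
+        num_blocks_allocated == num_blocks_freed + (TOTAL_BLOCKS - allocator.free_blocks)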
+ """ + TOTAL_BLOCKS = 128 + + allocator = BlockedAllocator(TOTAL_BLOCKS) + + def validate_uniqueness(all_blocks: List[torch.Tensor]) -> None: + all_vals = torch.cat(all_blocks, dim=0) + assert all_vals.numel() <= TOTAL_BLOCKS + + unique_vals = torch.unique(all_vals, sorted=False) + assert unique_vals.numel() == all_vals.numel() + + all_allocs: List[torch.Tensor] = [] + num_allocs = 0 + num_frees = 0 + num_blocks_allocated = 0 + num_blocks_freed = 0 + + for _ in range(test_iters): + decision = random.randint(0, 1) + + if decision == 0: + blocks_to_allocate = random.randint(1, 24) + if blocks_to_allocate > allocator.free_blocks: + with pytest.raises(ValueError): + allocator.allocate(blocks_to_allocate) + else: + all_allocs.append(allocator.allocate(blocks_to_allocate)) + num_allocs += 1 + num_blocks_allocated += blocks_to_allocate + else: + if len(all_allocs) > 0: + idx = random.randint(0, len(all_allocs) - 1) + allocator.free(all_allocs[idx]) + + num_frees += 1 + num_blocks_freed += all_allocs[idx].numel() + + del all_allocs[idx] + + if len(all_allocs) > 0: + validate_uniqueness(all_allocs) + + assert num_allocs == num_frees + len(all_allocs) + assert num_blocks_allocated == num_blocks_freed + (TOTAL_BLOCKS - allocator.free_blocks) diff --git a/tests/unit/inference/v2/ragged/test_manager_configs.py b/tests/unit/inference/v2/ragged/test_manager_configs.py new file mode 100644 index 000000000000..bdd513445ddb --- /dev/null +++ b/tests/unit/inference/v2/ragged/test_manager_configs.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest + +from pydantic import ValidationError + +from deepspeed.inference.v2.ragged import DSStateManagerConfig + + +@pytest.mark.inference_v2 +def test_negative_max_tracked_sequences() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_tracked_sequences=-1) + + +@pytest.mark.inference_v2 +def test_zero_max_tracked_sequences() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_tracked_sequences=0) + + +@pytest.mark.inference_v2 +def test_negative_max_ragged_batch_size() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_ragged_batch_size=-1) + + +@pytest.mark.inference_v2 +def test_zero_max_ragged_batch_size() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_ragged_batch_size=0) + + +@pytest.mark.inference_v2 +def test_negative_max_ragged_sequence_count() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_ragged_sequence_count=-1) + + +@pytest.mark.inference_v2 +def test_zero_max_ragged_sequence_count() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_ragged_sequence_count=0) + + +@pytest.mark.inference_v2 +def test_too_small_max_ragged_batch_size() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_ragged_batch_size=512, max_ragged_sequence_count=1024) + + +@pytest.mark.inference_v2 +def test_too_small_max_tracked_sequences() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_tracked_sequences=512, max_ragged_sequence_count=1024) diff --git a/tests/unit/inference/v2/ragged/test_ragged_wrapper.py b/tests/unit/inference/v2/ragged/test_ragged_wrapper.py new file mode 100644 index 000000000000..3cb74f4c49d2 --- /dev/null +++ b/tests/unit/inference/v2/ragged/test_ragged_wrapper.py @@ -0,0 +1,112 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.ragged import ( + PlaceholderSequenceDescriptor, + RaggedBatchWrapper, + DSStateManagerConfig, +) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('max_ragged_sequence_count, max_ragged_batch_size', [(128, 512), (128, 1024)]) +def test_wrapper_initialization(max_ragged_sequence_count: int, max_ragged_batch_size: int) -> None: + config = DSStateManagerConfig(max_tracked_sequences=max_ragged_sequence_count, + max_ragged_batch_size=max_ragged_batch_size, + max_ragged_sequence_count=max_ragged_sequence_count) + + batch = RaggedBatchWrapper(config) + + assert batch.current_tokens == 0 + assert batch.current_sequences == 0 + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('seq_len', [1, 37, 128, 512]) +def test_single_sequence_batch(seq_len: int) -> None: + """ + Test we successfully construct single sequence batches and the on device metadata is accurate. + """ + + config = DSStateManagerConfig() + batch = RaggedBatchWrapper(config) + + batch.clear() + + assert batch.current_tokens == 0 + assert batch.current_sequences == 0 + + seq_desc = PlaceholderSequenceDescriptor() + tokens = torch.randint(0, 100, (seq_len, )) + batch.insert_sequence(seq_desc, tokens) + + batch.finalize() + + assert batch.current_tokens == seq_len + assert batch.current_sequences == 1 + assert torch.equal(batch.input_ids(), tokens.to(get_accelerator().current_device())) + assert torch.equal(batch.tokens_to_seq(), torch.zeros_like(tokens, device=get_accelerator().current_device())) + assert torch.equal(batch.batch_metadata_buffer(), + torch.tensor([seq_len, 1], device=get_accelerator().current_device())) + + batch.clear() + + assert batch.current_tokens == 0 + assert batch.current_sequences == 0 + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('seq_lens', [[128, 128], [1, 32, 243], [64, 1, 1, 1, 1, 393, 27, 2]]) +def test_multi_sequence_batch(seq_lens: List[int]) -> None: + """ + Test sequentially adding new tokens to a batch and validate device data structures hold + the appropriate data. 
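+
+    Concretely: input_ids() should be the flat concatenation of every sequence's tokens,
+    tokens_to_seq() should map each token position to the index of its sequence, and each row
+    of inflight_seq_descriptors() should begin with [start_offset, n_tokens, ...] for the
+    corresponding sequence.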
+ """ + config = DSStateManagerConfig() + batch = RaggedBatchWrapper(config) + + batch.clear() + + assert batch.current_tokens == 0 + assert batch.current_sequences == 0 + + all_toks = [torch.randint(0, 100, (seq_len, )) for seq_len in seq_lens] + + for i, toks in enumerate(all_toks): + seq_desc = PlaceholderSequenceDescriptor() + batch.insert_sequence(seq_desc, toks) + + assert batch.current_tokens == sum(seq_lens[:i + 1]) + assert batch.current_sequences == i + 1 + + batch.finalize() + + assert batch.current_tokens == sum(seq_lens) + assert batch.current_sequences == len(seq_lens) + + assert torch.equal(batch.input_ids(), torch.cat(all_toks, dim=0).to(get_accelerator().current_device())) + assert torch.equal( + batch.tokens_to_seq(), + torch.cat([torch.full((seq_len, ), i, dtype=torch.int32) for i, seq_len in enumerate(seq_lens)], + dim=0).to(get_accelerator().current_device())) + + for i, seq_len in enumerate(seq_lens): + assert batch.inflight_seq_descriptors()[i][0] == sum(seq_lens[:i]) + assert batch.inflight_seq_descriptors()[i][1] == seq_len + assert batch.inflight_seq_descriptors()[i][2] == 0 + + assert torch.equal(batch.batch_metadata_buffer(), + torch.tensor([sum(seq_lens), len(seq_lens)], device=get_accelerator().current_device())) + + batch.clear() + + assert batch.current_tokens == 0 + assert batch.current_sequences == 0