diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c
index 943d6074ffc5c..db47a07516320 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c
@@ -254,3 +254,148 @@ IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8_16(
     iree_uk_mmt4d_tile_s16s16s32_4x16x2_x86_64_avx512_vnni,
     iree_uk_mmt4d_tile_s16s16s32_8x16x2_x86_64_avx512_vnni,
     iree_uk_mmt4d_tile_s16s16s32_16x16x2_x86_64_avx512_vnni)
+
+// This kernel is parametrized in N0, allowing N0==16 and N0==32.
+// Performance on AMD Ryzen 9 7950X3D:
+//   - with N0=16: 180 Gop/s
+//   - with N0=32: 240 Gop/s
+// So there's a nice reward for going extra large, but that's also a liability
+// for vecmat shapes whose N dimension isn't a multiple of 32. Maybe we can
+// keep both for now.
+//
+// The idea of this kernel is to split the LHS s16 values into high and low
+// 8-bit components to be able to use _mm512_dpbusd_epi32.
+//
+// In itself, that doesn't reduce the number of arithmetic instructions: while
+// each now computes a 4D dot-product instead of a 2D one as in
+// _mm512_dpwssd_epi32, we now need twice as many of them to separately handle
+// the high and low 8bit parts of the LHS s16 values.
+//
+// The real benefit is that this removes the need to extend RHS u4 values to
+// s16. Since this is a vecmat kernel, the LHS is small and the RHS is big,
+// so it matters to avoid RHS-processing work.
+//
+// It's not trivial how to use _mm512_dpbusd_epi32, with its quirky
+// unsigned * signed semantics. We take advantage of the fact that our u4
+// RHS values, when extended to u8, do not use the top bit -- so they are
+// also interpretable as s8 values in place. So this is specific to RHS
+// being less-than-8-bit values (it's not specific beyond that to 4bit).
+// Meanwhile, when we split the LHS s16 values into high and low 8bit components
+// the high 8bits are signed s8 and the low 8bit are unsigned u8. So, for each
+// of the combinations of operands that we have to feed _mm512_dpbusd_epi32,
+// we manage to find an operand order that accommodates the instruction's
+// requirements on signedness.
+static inline void
+iree_uk_mmt4d_tile_s16u4s32_1x16x8_to_1x32x8_x86_64_avx512_vnni(
+    void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
+    const void* IREE_UK_RESTRICT rhs_panel,
+    const iree_uk_mmt4d_params_t* params, int N0) {
+  // Only N0 == 16 and N0 == 32 are supported: the loops below run N0 / 16
+  // times over local arrays sized for at most two groups of 16 columns, so a
+  // smaller N0 would compute nothing and a larger one would overrun them.
+  IREE_UK_ASSERT(N0 >= 16 && N0 <= 32 && iree_uk_is_po2_u32(N0));
+  iree_uk_int32_t* IREE_UK_RESTRICT out_ptr = out_tile;
+  const iree_uk_int16_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
+  const iree_uk_uint8_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
+  // acc[4 * i] is the actual accumulator.
+  // The other acc[4 * i + j] are only used internally in the accumulation loop.
+  __m512i acc[8];
+  if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
+    for (int i = 0; i < N0 / 16; ++i) {
+      acc[4 * i] = _mm512_loadu_si512((const __m512i*)(out_ptr + 16 * i));
+    }
+  } else {
+    for (int i = 0; i < N0 / 16; ++i) {
+      acc[4 * i] = _mm512_setzero_si512();
+    }
+  }
+  for (int i = 0; i < N0 / 16; ++i) {
+    for (int j = 1; j < 4; ++j) {
+      acc[4 * i + j] = _mm512_setzero_si512();
+    }
+  }
+
+  const __m128i idx_0_mod_4 = _mm_set1_epi32(0x0c080400);
+  const __m128i idx_1_mod_4 = _mm_set1_epi32(0x0d090501);
+  const __m128i idx_2_mod_4 = _mm_set1_epi32(0x0e0a0602);
+  const __m128i idx_3_mod_4 = _mm_set1_epi32(0x0f0b0703);
+  const __m512i mask_0f = _mm512_set1_epi8(0x0f);
+  IREE_UK_ASSUME(params->K >= 1);
+  for (iree_uk_int32_t k = 0; k < params->K; ++k) {
+    // Load 8xs16 LHS data.
+    __m128i lhs = _mm_loadu_si128((const __m128i*)lhs_ptr);
+    lhs_ptr += 8;
+    // Extract the even/odd s16 lanes and within them, the low/high 8bit parts,
+    // and broadcast into 512bit registers to multiply against RHS data.
+    __m512i lhs_even_s16_low_u8 =
+        _mm512_broadcastq_epi64(_mm_shuffle_epi8(lhs, idx_0_mod_4));
+    __m512i lhs_even_s16_high_s8 =
+        _mm512_broadcastq_epi64(_mm_shuffle_epi8(lhs, idx_1_mod_4));
+    __m512i lhs_odd_s16_low_u8 =
+        _mm512_broadcastq_epi64(_mm_shuffle_epi8(lhs, idx_2_mod_4));
+    __m512i lhs_odd_s16_high_s8 =
+        _mm512_broadcastq_epi64(_mm_shuffle_epi8(lhs, idx_3_mod_4));
+    // Load 8xN0xu4 RHS data, 64 bytes per group of 16 columns.
+    __m512i rhs[2];
+    for (int i = 0; i < N0 / 16; ++i) {
+      rhs[i] = _mm512_loadu_si512((const __m512i*)(rhs_ptr + 64 * i));
+    }
+    rhs_ptr += N0 * 4;
+    // Extract the even/odd u4 lanes.
+    __m512i rhs_even_u4[2];
+    __m512i rhs_odd_u4[2];
+    for (int i = 0; i < N0 / 16; ++i) {
+      rhs_even_u4[i] = _mm512_and_si512(mask_0f, rhs[i]);
+      rhs_odd_u4[i] = _mm512_and_si512(mask_0f, _mm512_srli_epi16(rhs[i], 4));
+    }
+    // Arithmetic. See the comment at the top of this kernel for an explanation.
+    // _mm512_dpbusd_epi32 multiplies unsigned bytes (first operand) by signed
+    // bytes (second operand). The argument order in each call reflects that.
+    for (int i = 0; i < N0 / 16; ++i) {
+      acc[4 * i + 0] = _mm512_dpbusd_epi32(acc[4 * i + 0], lhs_even_s16_low_u8,
+                                           rhs_even_u4[i]);
+      acc[4 * i + 1] = _mm512_dpbusd_epi32(acc[4 * i + 1], rhs_even_u4[i],
+                                           lhs_even_s16_high_s8);
+      acc[4 * i + 2] = _mm512_dpbusd_epi32(acc[4 * i + 2], lhs_odd_s16_low_u8,
+                                           rhs_odd_u4[i]);
+      acc[4 * i + 3] = _mm512_dpbusd_epi32(acc[4 * i + 3], rhs_odd_u4[i],
+                                           lhs_odd_s16_high_s8);
+    }
+  }
+
+  // The accumulators that contain products against high 8bit parts of s16 LHS
+  // values need to be left-shifted by 8 bits to account for that.
+  for (int i = 0; i < N0 / 16; ++i) {
+    acc[4 * i + 1] = _mm512_slli_epi32(acc[4 * i + 1], 8);
+    acc[4 * i + 3] = _mm512_slli_epi32(acc[4 * i + 3], 8);
+  }
+
+  // Add accumulators together.
+  for (int i = 0; i < N0 / 16; ++i) {
+    for (int j = 1; j <= 3; ++j) {
+      acc[4 * i + 0] = _mm512_add_epi32(acc[4 * i + 0], acc[4 * i + j]);
+    }
+  }
+
+  for (int i = 0; i < N0 / 16; ++i) {
+    _mm512_storeu_si512((__m512i*)(out_ptr + 16 * i), acc[4 * i]);
+  }
+}
+
+// Tile function for M0=1, N0=16, K0=8 (s16 LHS, u4 RHS, s32 accumulators).
+void iree_uk_mmt4d_tile_s16u4s32_1x16x8_x86_64_avx512_vnni(
+    void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
+    const void* IREE_UK_RESTRICT rhs_panel,
+    const iree_uk_mmt4d_params_t* params) {
+  iree_uk_mmt4d_tile_s16u4s32_1x16x8_to_1x32x8_x86_64_avx512_vnni(
+      out_tile, lhs_panel, rhs_panel, params, 16);
+}
+
+// Tile function for M0=1, N0=32, K0=8 (s16 LHS, u4 RHS, s32 accumulators).
+void iree_uk_mmt4d_tile_s16u4s32_1x32x8_x86_64_avx512_vnni(
+    void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
+    const void* IREE_UK_RESTRICT rhs_panel,
+    const iree_uk_mmt4d_params_t* params) {
+  iree_uk_mmt4d_tile_s16u4s32_1x16x8_to_1x32x8_x86_64_avx512_vnni(
+      out_tile, lhs_panel, rhs_panel, params, 32);
+}
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c
index b957957e13c05..b82451adf1e3a 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c
@@ -338,17 +338,43 @@ iree_uk_mmt4d_select_tile_func_x86_64_s16s16s32_M0x8x2(
 
 static iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_x86_64_s16s16s32(
     const iree_uk_mmt4d_params_t* params) {
-#if 1
   if (params->N0 == 16 && params->K0 == 2) {
     return iree_uk_mmt4d_select_tile_func_x86_64_s16s16s32_M0x16x2(params);
   }
-#endif
   if (params->N0 == 8 && params->K0 == 2) {
     return iree_uk_mmt4d_select_tile_func_x86_64_s16s16s32_M0x8x2(params);
   }
   return 0;
 }
 
+// Selects the s16u4s32 tile function for M0=1, K0=8 based on params->N0.
+// Returns 0 if no AVX-512-VNNI kernel is built/enabled or N0 is not 16 or 32.
+static iree_uk_mmt4d_tile_func_t
+iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32_1xN0x8(
+    const iree_uk_mmt4d_params_t* params) {
+#if defined(IREE_UK_BUILD_X86_64_AVX512_VNNI)
+  if (params->cpu_data[0] & IREE_CPU_DATA0_X86_64_AVX512VNNI) {
+    switch (params->N0) {
+      case 16:
+        return iree_uk_mmt4d_tile_s16u4s32_1x16x8_x86_64_avx512_vnni;
+      case 32:
+        return iree_uk_mmt4d_tile_s16u4s32_1x32x8_x86_64_avx512_vnni;
+    }
+  }
+#endif
+  return 0;
+}
+
+// Selects the s16u4s32 tile function. Only M0=1, K0=8 (vecmat) is currently
+// handled; returns 0 for every other tile shape.
+static iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32(
+    const iree_uk_mmt4d_params_t* params) {
+  if (params->M0 == 1 && params->K0 == 8) {
+    return iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32_1xN0x8(params);
+  }
+  return 0;
+}
+
 iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arch(
     const iree_uk_mmt4d_params_t* params) {
   switch (iree_uk_mmt4d_type(params->flags)) {
@@ -367,7 +393,7 @@ iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arch(
     case iree_uk_mmt4d_type_s16s16s32:
       return iree_uk_mmt4d_select_tile_func_x86_64_s16s16s32(params);
     case iree_uk_mmt4d_type_s16u4s32:
-      return 0;
+      return iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32(params);
     default:
       IREE_UK_ASSUME_UNREACHABLE;
       return 0;
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_internal.h b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_internal.h
index f82ff0c889b79..c295edfaf7536 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_internal.h
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_internal.h
@@ -109,5 +109,9 @@ IREE_UK_MMT4D_TILE_FUNC_DECL(
     iree_uk_mmt4d_tile_s16s16s32_8x16x2_x86_64_avx512_vnni)
 IREE_UK_MMT4D_TILE_FUNC_DECL(
     iree_uk_mmt4d_tile_s16s16s32_16x16x2_x86_64_avx512_vnni)
+IREE_UK_MMT4D_TILE_FUNC_DECL(
+    iree_uk_mmt4d_tile_s16u4s32_1x16x8_x86_64_avx512_vnni)
+IREE_UK_MMT4D_TILE_FUNC_DECL(
+    iree_uk_mmt4d_tile_s16u4s32_1x32x8_x86_64_avx512_vnni)
 
 #endif  // IREE_BUILTINS_UKERNEL_ARCH_X86_64_MMT4D_X86_64_INTERNAL_H_
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
index ffa40a775a876..c81ab9364443b 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
@@ -167,6 +167,10 @@ int main(int argc, char** argv) {
                                    "avx512_base");
   iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16S16S32, 16, 16, 2,
                                    "avx512_vnni");
+  iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 16, 8,
+                                   "avx512_vnni");
+  iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 32, 8,
+                                   "avx512_vnni");
 #else  // defined(IREE_ARCH_ARM_64)
   // Architectures on which we do not have any optimized ukernel code.
   // Benchmark some arbitrary tile shape.
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
index 995309fe374a9..440a028bd724d 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
@@ -498,7 +498,8 @@ int main(int argc, char** argv) {
                      "avx512_base");
   iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16S16S32, 16, 16, 2,
                      "avx512_vnni");
-
+  iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 16, 8, "avx512_vnni");
+  iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 32, 8, "avx512_vnni");
 #endif  // defined(IREE_ARCH_ARM_64)
 
   return iree_uk_test_exit_status();