Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimized vecmat ukernel tile functions for i16 x u4 -> i32 on AVX-512-VNNI #15525

Merged
merged 1 commit into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,143 @@ IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8_16(
iree_uk_mmt4d_tile_s16s16s32_4x16x2_x86_64_avx512_vnni,
iree_uk_mmt4d_tile_s16s16s32_8x16x2_x86_64_avx512_vnni,
iree_uk_mmt4d_tile_s16s16s32_16x16x2_x86_64_avx512_vnni)

// This kernel is parametrized in N0, allowing N0==16 and N0==32.
// Performance on AMD Ryzen 9 7950X3D:
// - with N0=16: 180 Gop/s
// - with N0=32: 240 Gop/s
// So there's a nice reward for going extra large, but that's also a liability
// for vecmat shapes whose N dimension isn't a multiple of 32. Maybe we can
// keep both for now.
//
// The idea of this kernel is to split the LHS s16 values into high and low
// 8-bit components to be able to use _mm512_dpbusd_epi32.
//
// In itself, that doesn't reduce the number of arithmetic instructions: while
// each now computes a 4D dot-product instead of a 2D one as in
// _mm512_dpwssd_epi32, we now need twice more of them to do separately the
// high and low 8bit parts of the LHS s16 values.
//
// The real benefit is that this removes the need to extend RHS u4 values to
// s16. Since this is a vecmat kernel, the LHS is small and the RHS is big,
// so it matters to avoid RHS-processing work.
//
// It's not trivial how to use _mm512_dpbusd_epi32, with its quirky
// unsigned * signed semantics. We take advantage of the fact that our u4
// RHS values, when extended to u8, do not use the top bit -- so they are
// also interpretable as s8 values in place. So this is specific to RHS
// being less-than-8-bit values (it's not specific beyond that to 4bit).
// Meanwhile, when we split the LHS s16 values into high and low 8bit components
// the high 8bits are signed s8 and the low 8bit are unsigned u8. So, for each
// of the combinations of operands that we have to feed _mm512_dpbusd_epi32,
// we manage to find an operand order that accomodates the instruction's
// requirements on signednesses.
static inline void
iree_uk_mmt4d_tile_s16u4s32_1x16x8_to_1x32x8_x86_64_avx512_vnni(
void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
const void* IREE_UK_RESTRICT rhs_panel,
const iree_uk_mmt4d_params_t* params, int N0) {
IREE_UK_ASSERT(N0 >= 1 && N0 <= 16 && iree_uk_is_po2_u32(N0));
iree_uk_int32_t* IREE_UK_RESTRICT out_ptr = out_tile;
const iree_uk_int16_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
const iree_uk_uint8_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
// acc[4 * i] is the actual accumulator.
// The other acc[4 * i + j] are only used internally in the accumulation loop.
__m512i acc[8];
if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
for (int i = 0; i < N0 / 16; ++i) {
acc[4 * i] = _mm512_loadu_si512((const __m512i*)(out_ptr + 16 * i));
}
} else {
for (int i = 0; i < N0 / 16; ++i) {
acc[4 * i] = _mm512_setzero_si512();
}
}
for (int i = 0; i < N0 / 16; ++i) {
for (int j = 1; j < 4; ++j) {
acc[4 * i + j] = _mm512_setzero_si512();
}
}

const __m128i idx_0_mod_4 = _mm_set1_epi32(0x0c080400);
const __m128i idx_1_mod_4 = _mm_set1_epi32(0x0d090501);
const __m128i idx_2_mod_4 = _mm_set1_epi32(0x0e0a0602);
const __m128i idx_3_mod_4 = _mm_set1_epi32(0x0f0b0703);
const __m512i mask_0f = _mm512_set1_epi8(0x0f);
IREE_UK_ASSUME(params->K >= 1);
for (iree_uk_int32_t k = 0; k < params->K; ++k) {
// Load 8xs16 LHS data.
__m128i lhs = _mm_loadu_si128((const __m128i*)lhs_ptr);
lhs_ptr += 8;
// Extract the even/odd s16 lanes and within them, the low/high 8bit parts,
// and broadcast into 512bit registers to multiply against RHS data.
__m512i lhs_even_s16_low_u8 =
_mm512_broadcastq_epi64(_mm_shuffle_epi8(lhs, idx_0_mod_4));
__m512i lhs_even_s16_high_s8 =
_mm512_broadcastq_epi64(_mm_shuffle_epi8(lhs, idx_1_mod_4));
__m512i lhs_odd_s16_low_u8 =
_mm512_broadcastq_epi64(_mm_shuffle_epi8(lhs, idx_2_mod_4));
__m512i lhs_odd_s16_high_s8 =
_mm512_broadcastq_epi64(_mm_shuffle_epi8(lhs, idx_3_mod_4));
// Load 8x16xu4 RHS data.
__m512i rhs[2];
for (int i = 0; i < N0 / 16; ++i) {
rhs[i] = _mm512_loadu_si512((const __m512i*)(rhs_ptr + 64 * i));
}
rhs_ptr += N0 * 4;
// Extract the even/odd u4 lanes.
__m512i rhs_even_u4[2];
__m512i rhs_odd_u4[2];
for (int i = 0; i < N0 / 16; ++i) {
rhs_even_u4[i] = _mm512_and_si512(mask_0f, rhs[i]);
rhs_odd_u4[i] = _mm512_and_si512(mask_0f, _mm512_srli_epi16(rhs[i], 4));
}
// Arithmetic. See the comment at the top of this kernel for an explanation.
// _mm512_dpbusd_epi32 takes an unsigned LHS and a signed RHS. The parameter
// order in each call is adapted to that constraint.
for (int i = 0; i < N0 / 16; ++i) {
acc[4 * i + 0] = _mm512_dpbusd_epi32(acc[4 * i + 0], lhs_even_s16_low_u8,
rhs_even_u4[i]);
acc[4 * i + 1] = _mm512_dpbusd_epi32(acc[4 * i + 1], rhs_even_u4[i],
lhs_even_s16_high_s8);
acc[4 * i + 2] = _mm512_dpbusd_epi32(acc[4 * i + 2], lhs_odd_s16_low_u8,
rhs_odd_u4[i]);
acc[4 * i + 3] = _mm512_dpbusd_epi32(acc[4 * i + 3], rhs_odd_u4[i],
lhs_odd_s16_high_s8);
}
}

// The accumulators that contain products against high 8bit parts of s16 LHS
// values need to be left-shifted by 8 bits to account for that.
for (int i = 0; i < N0 / 16; ++i) {
acc[4 * i + 1] = _mm512_slli_epi32(acc[4 * i + 1], 8);
acc[4 * i + 3] = _mm512_slli_epi32(acc[4 * i + 3], 8);
}

// Add accumulators together.
for (int i = 0; i < N0 / 16; ++i) {
for (int j = 1; j <= 3; ++j) {
acc[4 * i + 0] = _mm512_add_epi32(acc[4 * i + 0], acc[4 * i + j]);
}
}

for (int i = 0; i < N0 / 16; ++i) {
_mm512_storeu_si512((__m512i*)(out_ptr + 16 * i), acc[4 * i]);
}
}

void iree_uk_mmt4d_tile_s16u4s32_1x16x8_x86_64_avx512_vnni(
void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
const void* IREE_UK_RESTRICT rhs_panel,
const iree_uk_mmt4d_params_t* params) {
iree_uk_mmt4d_tile_s16u4s32_1x16x8_to_1x32x8_x86_64_avx512_vnni(
out_tile, lhs_panel, rhs_panel, params, 16);
}

void iree_uk_mmt4d_tile_s16u4s32_1x32x8_x86_64_avx512_vnni(
void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
const void* IREE_UK_RESTRICT rhs_panel,
const iree_uk_mmt4d_params_t* params) {
iree_uk_mmt4d_tile_s16u4s32_1x16x8_to_1x32x8_x86_64_avx512_vnni(
out_tile, lhs_panel, rhs_panel, params, 32);
}
Original file line number Diff line number Diff line change
Expand Up @@ -338,17 +338,39 @@ iree_uk_mmt4d_select_tile_func_x86_64_s16s16s32_M0x8x2(
static iree_uk_mmt4d_tile_func_t
iree_uk_mmt4d_select_tile_func_x86_64_s16s16s32(
const iree_uk_mmt4d_params_t* params) {
#if 1
if (params->N0 == 16 && params->K0 == 2) {
return iree_uk_mmt4d_select_tile_func_x86_64_s16s16s32_M0x16x2(params);
}
#endif
if (params->N0 == 8 && params->K0 == 2) {
return iree_uk_mmt4d_select_tile_func_x86_64_s16s16s32_M0x8x2(params);
}
return 0;
}

static iree_uk_mmt4d_tile_func_t
iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32_1xN0x8(
const iree_uk_mmt4d_params_t* params) {
#if defined(IREE_UK_BUILD_X86_64_AVX512_VNNI)
if (params->cpu_data[0] & IREE_CPU_DATA0_X86_64_AVX512VNNI) {
switch (params->N0) {
case 16:
return iree_uk_mmt4d_tile_s16u4s32_1x16x8_x86_64_avx512_vnni;
case 32:
return iree_uk_mmt4d_tile_s16u4s32_1x32x8_x86_64_avx512_vnni;
}
}
#endif
return 0;
}

static iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32(
const iree_uk_mmt4d_params_t* params) {
if (params->M0 == 1 && params->K0 == 8) {
return iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32_1xN0x8(params);
}
return 0;
}

iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arch(
const iree_uk_mmt4d_params_t* params) {
switch (iree_uk_mmt4d_type(params->flags)) {
Expand All @@ -367,7 +389,7 @@ iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arch(
case iree_uk_mmt4d_type_s16s16s32:
return iree_uk_mmt4d_select_tile_func_x86_64_s16s16s32(params);
case iree_uk_mmt4d_type_s16u4s32:
return 0;
return iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32(params);
default:
IREE_UK_ASSUME_UNREACHABLE;
return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,5 +109,9 @@ IREE_UK_MMT4D_TILE_FUNC_DECL(
iree_uk_mmt4d_tile_s16s16s32_8x16x2_x86_64_avx512_vnni)
IREE_UK_MMT4D_TILE_FUNC_DECL(
iree_uk_mmt4d_tile_s16s16s32_16x16x2_x86_64_avx512_vnni)
IREE_UK_MMT4D_TILE_FUNC_DECL(
iree_uk_mmt4d_tile_s16u4s32_1x16x8_x86_64_avx512_vnni)
IREE_UK_MMT4D_TILE_FUNC_DECL(
iree_uk_mmt4d_tile_s16u4s32_1x32x8_x86_64_avx512_vnni)

#endif // foIREE_BUILTINS_UKERNEL_ARCH_X86_64_MMT4D_X86_64_INTERNAL_H_
4 changes: 4 additions & 0 deletions runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ int main(int argc, char** argv) {
"avx512_base");
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16S16S32, 16, 16, 2,
"avx512_vnni");
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 16, 8,
"avx512_vnni");
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 32, 8,
"avx512_vnni");
#else // defined(IREE_ARCH_ARM_64)
// Architectures on which we do not have any optimized ukernel code.
// Benchmark some arbitrary tile shape.
Expand Down
3 changes: 2 additions & 1 deletion runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,8 @@ int main(int argc, char** argv) {
"avx512_base");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16S16S32, 16, 16, 2,
"avx512_vnni");

iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 16, 8, "avx512_vnni");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 32, 8, "avx512_vnni");
#endif // defined(IREE_ARCH_ARM_64)

return iree_uk_test_exit_status();
Expand Down