Optimized vecmat ukernel tile functions for i16 x u4 -> i32 on AVX-512-VNNI (iree-org#15525)

This kernel is parametrized in N0, allowing N0==16 and N0==32.
Performance on AMD Ryzen 9 7950X3D:
  - With N0=16:  180 Gop/s.
  - With N0=32:  240 Gop/s.

These numbers show that there's a nice reward for going extra large, but
that's also a liability for vecmat shapes whose N dimension isn't a
multiple of 32. Maybe we can keep both for now.

This is currently by far our fastest vecmat tile function: it's fast
even by general-matmul standards, whereas vecmat's low arithmetic
intensity usually relegates it to lower performance levels. It shows
what's possible now that we've decoupled vecmat tile shapes from general
matmul tile shapes in iree-org#15431: this 32x8 tile shape is not a
truncation of a general matmul tile shape. Other element types and CPU
architectures all need to get the same treatment.

The idea of this kernel is to split the LHS s16 values into high and low
8-bit components to be able to use `_mm512_dpbusd_epi32`.

In itself, that doesn't reduce the number of arithmetic instructions:
while each one now computes a 4D dot-product instead of a 2D one as in
`_mm512_dpwssd_epi32`, we need twice as many of them to handle the high
and low 8-bit parts of the LHS s16 values separately.

The real benefit is that this removes the need to extend RHS u4 values
to s16. Since this is a vecmat kernel, the LHS is small and the RHS is
big, so it matters to avoid RHS-processing work.

It's not trivial how to use `_mm512_dpbusd_epi32`, with its quirky
unsigned * signed semantics. We take advantage of the fact that our u4
RHS values, when extended to u8, do not use the top bit -- so they are
also interpretable in place as s8 values. This is specific to the RHS
being less-than-8-bit values (beyond that, it's not specific to 4-bit).
Meanwhile, when we split the LHS s16 values into high and low 8-bit
components, the high 8 bits are signed (s8) and the low 8 bits are
unsigned (u8). So, for each combination of operands that we have to feed
`_mm512_dpbusd_epi32`, we manage to find an operand order that
accommodates the instruction's signedness requirements.
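To make the trick concrete, here is a scalar model of one s16 x u4
product, decomposed the same way the kernel decomposes it. This is an
illustrative sketch, not code from this commit; `model_s16_x_u4` and its
brute-force check are made up for this explanation:

    #include <assert.h>
    #include <stdint.h>

    // Scalar model: a == (a_high << 8) + a_low, with a_high signed (s8)
    // and a_low unsigned (u8). Both partial products are then u8 * s8,
    // exactly the signedness combination _mm512_dpbusd_epi32 supports:
    //   a_low (u8) * b (u4 zero-extends to u8 with top bit clear, so it
    //                  is also a valid s8)
    //   b (u8) * a_high (s8)
    // The high partial product is rescaled by << 8 at the end, matching
    // the kernel's final _mm512_slli_epi32 on acc[4*i+1] and acc[4*i+3].
    static int32_t model_s16_x_u4(int16_t a, uint8_t b /* in 0..15 */) {
      uint8_t a_low = (uint8_t)(a & 0xff);
      int8_t a_high = (int8_t)(a >> 8);  // arithmetic shift assumed
      int32_t acc_low = (int32_t)a_low * (int32_t)(int8_t)b;  // u8 * s8
      int32_t acc_high = (int32_t)b * (int32_t)a_high;        // u8 * s8
      // Two's-complement left-shift behavior assumed for negative values.
      return acc_low + (acc_high << 8);
    }

    int main(void) {
      for (int32_t a = INT16_MIN; a <= INT16_MAX; ++a)
        for (uint8_t b = 0; b < 16; ++b)
          assert(model_s16_x_u4((int16_t)a, b) == a * (int32_t)b);
      return 0;
    }

Since the << 8 rescaling commutes with the summation over k, it can be
applied once to the high-part accumulators after the whole K loop, which
is what the kernel does.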
bjacob authored and ramiro050 committed Dec 19, 2023
1 parent 3f8a13b commit ce4c630
Showing 5 changed files with 175 additions and 4 deletions.
@@ -254,3 +254,143 @@ IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8_16(
    iree_uk_mmt4d_tile_s16s16s32_4x16x2_x86_64_avx512_vnni,
    iree_uk_mmt4d_tile_s16s16s32_8x16x2_x86_64_avx512_vnni,
    iree_uk_mmt4d_tile_s16s16s32_16x16x2_x86_64_avx512_vnni)

// This kernel is parametrized in N0, allowing N0==16 and N0==32.
// Performance on AMD Ryzen 9 7950X3D:
// - with N0=16: 180 Gop/s
// - with N0=32: 240 Gop/s
// So there's a nice reward for going extra large, but that's also a liability
// for vecmat shapes whose N dimension isn't a multiple of 32. Maybe we can
// keep both for now.
//
// The idea of this kernel is to split the LHS s16 values into high and low
// 8-bit components to be able to use _mm512_dpbusd_epi32.
//
// In itself, that doesn't reduce the number of arithmetic instructions: while
// each now computes a 4D dot-product instead of a 2D one as in
// _mm512_dpwssd_epi32, we need twice as many of them to handle the high and
// low 8-bit parts of the LHS s16 values separately.
//
// The real benefit is that this removes the need to extend RHS u4 values to
// s16. Since this is a vecmat kernel, the LHS is small and the RHS is big,
// so it matters to avoid RHS-processing work.
//
// It's not trivial how to use _mm512_dpbusd_epi32, with its quirky
// unsigned * signed semantics. We take advantage of the fact that our u4
// RHS values, when extended to u8, do not use the top bit -- so they are
// also interpretable in place as s8 values. This is specific to the RHS
// being less-than-8-bit values (beyond that, it's not specific to 4-bit).
// Meanwhile, when we split the LHS s16 values into high and low 8-bit
// components, the high 8 bits are signed (s8) and the low 8 bits are
// unsigned (u8). So, for each combination of operands that we have to feed
// _mm512_dpbusd_epi32, we manage to find an operand order that accommodates
// the instruction's signedness requirements.
static inline void
iree_uk_mmt4d_tile_s16u4s32_1x16x8_to_1x32x8_x86_64_avx512_vnni(
    void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
    const void* IREE_UK_RESTRICT rhs_panel,
    const iree_uk_mmt4d_params_t* params, int N0) {
  IREE_UK_ASSERT(N0 == 16 || N0 == 32);
  iree_uk_int32_t* IREE_UK_RESTRICT out_ptr = out_tile;
  const iree_uk_int16_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
  const iree_uk_uint8_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
  // acc[4 * i] is the actual accumulator.
  // The other acc[4 * i + j] are only used internally in the accumulation
  // loop.
  __m512i acc[8];
  if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
    for (int i = 0; i < N0 / 16; ++i) {
      acc[4 * i] = _mm512_loadu_si512((const __m512i*)(out_ptr + 16 * i));
    }
  } else {
    for (int i = 0; i < N0 / 16; ++i) {
      acc[4 * i] = _mm512_setzero_si512();
    }
  }
  for (int i = 0; i < N0 / 16; ++i) {
    for (int j = 1; j < 4; ++j) {
      acc[4 * i + j] = _mm512_setzero_si512();
    }
  }

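  // Byte-shuffle controls used on the 16 LHS bytes (8 s16 values):
  // idx_k_mod_4 gathers the bytes at positions congruent to k (mod 4).
  // Positions 0/1 (mod 4) are the low/high bytes of the even-indexed s16
  // values; positions 2/3 (mod 4) are the low/high bytes of the odd-indexed
  // ones.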
  const __m128i idx_0_mod_4 = _mm_set1_epi32(0x0c080400);
  const __m128i idx_1_mod_4 = _mm_set1_epi32(0x0d090501);
  const __m128i idx_2_mod_4 = _mm_set1_epi32(0x0e0a0602);
  const __m128i idx_3_mod_4 = _mm_set1_epi32(0x0f0b0703);
  const __m512i mask_0f = _mm512_set1_epi8(0x0f);
  IREE_UK_ASSUME(params->K >= 1);
  for (iree_uk_int32_t k = 0; k < params->K; ++k) {
    // Load 8xs16 LHS data.
    __m128i lhs = _mm_loadu_si128((const __m128i*)lhs_ptr);
    lhs_ptr += 8;
    // Extract the even/odd s16 lanes and, within them, the low/high 8-bit
    // parts, and broadcast into 512-bit registers to multiply against RHS
    // data.
    __m512i lhs_even_s16_low_u8 =
        _mm512_broadcastq_epi64(_mm_shuffle_epi8(lhs, idx_0_mod_4));
    __m512i lhs_even_s16_high_s8 =
        _mm512_broadcastq_epi64(_mm_shuffle_epi8(lhs, idx_1_mod_4));
    __m512i lhs_odd_s16_low_u8 =
        _mm512_broadcastq_epi64(_mm_shuffle_epi8(lhs, idx_2_mod_4));
    __m512i lhs_odd_s16_high_s8 =
        _mm512_broadcastq_epi64(_mm_shuffle_epi8(lhs, idx_3_mod_4));
    // Load 8xN0 u4 RHS data (two u4 values per byte).
    __m512i rhs[2];
    for (int i = 0; i < N0 / 16; ++i) {
      rhs[i] = _mm512_loadu_si512((const __m512i*)(rhs_ptr + 64 * i));
    }
    rhs_ptr += N0 * 4;
    // Extract the even/odd u4 lanes: the low nibble of each byte holds the
    // even-index u4 value, the high nibble the odd-index one.
    __m512i rhs_even_u4[2];
    __m512i rhs_odd_u4[2];
    for (int i = 0; i < N0 / 16; ++i) {
      rhs_even_u4[i] = _mm512_and_si512(mask_0f, rhs[i]);
      rhs_odd_u4[i] = _mm512_and_si512(mask_0f, _mm512_srli_epi16(rhs[i], 4));
    }
    // Arithmetic. See the comment at the top of this kernel for an
    // explanation. _mm512_dpbusd_epi32 takes an unsigned LHS and a signed
    // RHS. The parameter order in each call is adapted to that constraint.
    for (int i = 0; i < N0 / 16; ++i) {
      acc[4 * i + 0] = _mm512_dpbusd_epi32(acc[4 * i + 0], lhs_even_s16_low_u8,
                                           rhs_even_u4[i]);
      acc[4 * i + 1] = _mm512_dpbusd_epi32(acc[4 * i + 1], rhs_even_u4[i],
                                           lhs_even_s16_high_s8);
      acc[4 * i + 2] = _mm512_dpbusd_epi32(acc[4 * i + 2], lhs_odd_s16_low_u8,
                                           rhs_odd_u4[i]);
      acc[4 * i + 3] = _mm512_dpbusd_epi32(acc[4 * i + 3], rhs_odd_u4[i],
                                           lhs_odd_s16_high_s8);
    }
  }

  // The accumulators that contain products against the high 8-bit parts of
  // the s16 LHS values need to be left-shifted by 8 bits to account for that.
  for (int i = 0; i < N0 / 16; ++i) {
    acc[4 * i + 1] = _mm512_slli_epi32(acc[4 * i + 1], 8);
    acc[4 * i + 3] = _mm512_slli_epi32(acc[4 * i + 3], 8);
  }

  // Add the accumulators together.
  for (int i = 0; i < N0 / 16; ++i) {
    for (int j = 1; j <= 3; ++j) {
      acc[4 * i + 0] = _mm512_add_epi32(acc[4 * i + 0], acc[4 * i + j]);
    }
  }

  for (int i = 0; i < N0 / 16; ++i) {
    _mm512_storeu_si512((__m512i*)(out_ptr + 16 * i), acc[4 * i]);
  }
}

void iree_uk_mmt4d_tile_s16u4s32_1x16x8_x86_64_avx512_vnni(
    void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
    const void* IREE_UK_RESTRICT rhs_panel,
    const iree_uk_mmt4d_params_t* params) {
  iree_uk_mmt4d_tile_s16u4s32_1x16x8_to_1x32x8_x86_64_avx512_vnni(
      out_tile, lhs_panel, rhs_panel, params, 16);
}

void iree_uk_mmt4d_tile_s16u4s32_1x32x8_x86_64_avx512_vnni(
    void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
    const void* IREE_UK_RESTRICT rhs_panel,
    const iree_uk_mmt4d_params_t* params) {
  iree_uk_mmt4d_tile_s16u4s32_1x16x8_to_1x32x8_x86_64_avx512_vnni(
      out_tile, lhs_panel, rhs_panel, params, 32);
}
@@ -338,17 +338,39 @@ iree_uk_mmt4d_select_tile_func_x86_64_s16s16s32_M0x8x2(
static iree_uk_mmt4d_tile_func_t
iree_uk_mmt4d_select_tile_func_x86_64_s16s16s32(
    const iree_uk_mmt4d_params_t* params) {
#if 1
  if (params->N0 == 16 && params->K0 == 2) {
    return iree_uk_mmt4d_select_tile_func_x86_64_s16s16s32_M0x16x2(params);
  }
#endif
  if (params->N0 == 8 && params->K0 == 2) {
    return iree_uk_mmt4d_select_tile_func_x86_64_s16s16s32_M0x8x2(params);
  }
  return 0;
}

static iree_uk_mmt4d_tile_func_t
iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32_1xN0x8(
    const iree_uk_mmt4d_params_t* params) {
#if defined(IREE_UK_BUILD_X86_64_AVX512_VNNI)
  if (params->cpu_data[0] & IREE_CPU_DATA0_X86_64_AVX512VNNI) {
    switch (params->N0) {
      case 16:
        return iree_uk_mmt4d_tile_s16u4s32_1x16x8_x86_64_avx512_vnni;
      case 32:
        return iree_uk_mmt4d_tile_s16u4s32_1x32x8_x86_64_avx512_vnni;
    }
  }
#endif
  return 0;
}

static iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32(
    const iree_uk_mmt4d_params_t* params) {
  if (params->M0 == 1 && params->K0 == 8) {
    return iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32_1xN0x8(params);
  }
  return 0;
}

iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arch(
    const iree_uk_mmt4d_params_t* params) {
  switch (iree_uk_mmt4d_type(params->flags)) {
@@ -367,7 +389,7 @@ iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arch(
    case iree_uk_mmt4d_type_s16s16s32:
      return iree_uk_mmt4d_select_tile_func_x86_64_s16s16s32(params);
    case iree_uk_mmt4d_type_s16u4s32:
-     return 0;
+     return iree_uk_mmt4d_select_tile_func_x86_64_s16u4s32(params);
    default:
      IREE_UK_ASSUME_UNREACHABLE;
      return 0;
@@ -109,5 +109,9 @@ IREE_UK_MMT4D_TILE_FUNC_DECL(
    iree_uk_mmt4d_tile_s16s16s32_8x16x2_x86_64_avx512_vnni)
IREE_UK_MMT4D_TILE_FUNC_DECL(
    iree_uk_mmt4d_tile_s16s16s32_16x16x2_x86_64_avx512_vnni)
IREE_UK_MMT4D_TILE_FUNC_DECL(
    iree_uk_mmt4d_tile_s16u4s32_1x16x8_x86_64_avx512_vnni)
IREE_UK_MMT4D_TILE_FUNC_DECL(
    iree_uk_mmt4d_tile_s16u4s32_1x32x8_x86_64_avx512_vnni)

#endif  // IREE_BUILTINS_UKERNEL_ARCH_X86_64_MMT4D_X86_64_INTERNAL_H_
4 changes: 4 additions & 0 deletions runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
@@ -167,6 +167,10 @@ int main(int argc, char** argv) {
"avx512_base");
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16S16S32, 16, 16, 2,
"avx512_vnni");
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 16, 8,
"avx512_vnni");
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 32, 8,
"avx512_vnni");
#else // defined(IREE_ARCH_ARM_64)
// Architectures on which we do not have any optimized ukernel code.
// Benchmark some arbitrary tile shape.
3 changes: 2 additions & 1 deletion runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
@@ -498,7 +498,8 @@ int main(int argc, char** argv) {
"avx512_base");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16S16S32, 16, 16, 2,
"avx512_vnni");

iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 16, 8, "avx512_vnni");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 1, 32, 8, "avx512_vnni");
#endif // defined(IREE_ARCH_ARM_64)

  return iree_uk_test_exit_status();
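For context, here is a minimal sketch of how the new tile functions are
reached through the selection logic above. This is not part of the commit:
the buffer shapes, K value, and the helper `run_one_tile` are illustrative
assumptions; only the field, flag, and function names visible in this diff
are taken from the source.

    // Hypothetical caller. Panel layouts follow the 1xN0xK0 tile: LHS is
    // 1 x (K*8) s16, RHS is N0 x (K*8) u4 packed two per byte, OUT is
    // 1 x N0 s32.
    static iree_uk_int32_t out_tile[32];
    static iree_uk_int16_t lhs_panel[128 * 8];
    static iree_uk_uint8_t rhs_panel[128 * 32 * 8 / 2];

    void run_one_tile(void) {
      iree_uk_mmt4d_params_t params = {
          .flags = IREE_UK_FLAG_MMT4D_TYPE_S16U4S32,
          .M0 = 1, .N0 = 32, .K0 = 8, .K = 128,
          // params.cpu_data must additionally report
          // IREE_CPU_DATA0_X86_64_AVX512VNNI for the fast path to be
          // selected (CPU feature detection omitted from this sketch).
      };
      iree_uk_mmt4d_tile_func_t tile_func =
          iree_uk_mmt4d_select_tile_func_arch(&params);
      if (tile_func) {
        // A single call runs the whole K loop for one 1x32 output tile.
        tile_func(out_tile, lhs_panel, rhs_panel, &params);
      }
    }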
