Skip to content

Commit

Permalink
ukernels: sub-8-bit support, types s16 * u4 -> s32 and `s16 * s16 -…
Browse files Browse the repository at this point in the history
…> s32` (iree-org#15343)

`s16 * u4 -> s32` is the main goal in iree-org#15158, while `s16 * s16 -> s32`
is a less-optimized but generally useful fallback path: it lets us run
"some non-8-bit quantized matmuls that we don't have super specialized
code for".
  • Loading branch information
bjacob authored and ramiro050 committed Dec 19, 2023
1 parent 3a8bf8e commit 29e65cd
Show file tree
Hide file tree
Showing 9 changed files with 290 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arch(
return iree_uk_mmt4d_select_tile_func_arm_64_bf16bf16bf16(params);
case iree_uk_mmt4d_type_s8s8s32:
return iree_uk_mmt4d_select_tile_func_arm_64_i8i8i32(params);
case iree_uk_mmt4d_type_s16s16s32:
return 0;
case iree_uk_mmt4d_type_s16u4s32:
return 0;
default:
IREE_UK_ASSUME_UNREACHABLE;
return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,10 @@ iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arch(
return iree_uk_mmt4d_select_tile_func_x86_64_bf16bf16bf16(params);
case iree_uk_mmt4d_type_s8s8s32:
return iree_uk_mmt4d_select_tile_func_x86_64_i8i8i32(params);
case iree_uk_mmt4d_type_s16s16s32:
return 0;
case iree_uk_mmt4d_type_s16u4s32:
return 0;
default:
IREE_UK_ASSUME_UNREACHABLE;
return 0;
Expand Down
22 changes: 21 additions & 1 deletion runtime/src/iree/builtins/ukernel/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -461,14 +461,17 @@ enum {
IREE_UK_TYPE_OPAQUE_16 = IREE_UK_TYPE_CATEGORY_OPAQUE | 4,
IREE_UK_TYPE_OPAQUE_32 = IREE_UK_TYPE_CATEGORY_OPAQUE | 5,
IREE_UK_TYPE_OPAQUE_64 = IREE_UK_TYPE_CATEGORY_OPAQUE | 6,
IREE_UK_TYPE_INT_4 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS | 2,
IREE_UK_TYPE_INT_8 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS | 3,
IREE_UK_TYPE_INT_16 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS | 4,
IREE_UK_TYPE_INT_32 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS | 5,
IREE_UK_TYPE_INT_64 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS | 6,
IREE_UK_TYPE_SINT_4 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED | 2,
IREE_UK_TYPE_SINT_8 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED | 3,
IREE_UK_TYPE_SINT_16 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED | 4,
IREE_UK_TYPE_SINT_32 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED | 5,
IREE_UK_TYPE_SINT_64 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED | 6,
IREE_UK_TYPE_UINT_4 = IREE_UK_TYPE_CATEGORY_INTEGER_UNSIGNED | 2,
IREE_UK_TYPE_UINT_8 = IREE_UK_TYPE_CATEGORY_INTEGER_UNSIGNED | 3,
IREE_UK_TYPE_UINT_16 = IREE_UK_TYPE_CATEGORY_INTEGER_UNSIGNED | 4,
IREE_UK_TYPE_UINT_32 = IREE_UK_TYPE_CATEGORY_INTEGER_UNSIGNED | 5,
Expand Down Expand Up @@ -531,7 +534,9 @@ static inline iree_uk_uint8_t iree_uk_integer_type_as_unsigned(
// The current implementation might return a negative value, but don't rely on
// that.
static inline int iree_uk_type_size_log2(iree_uk_type_t t) {
  // Byte size is bit size divided by 8, i.e. log2(bits) - 3. Types narrower
  // than 8 bits have no whole-byte size, so assert instead of silently
  // handing a negative log2 back to the caller.
  int bit_count_log2 = iree_uk_type_bit_count_log2(t);
  IREE_UK_ASSERT(bit_count_log2 >= 3);
  return bit_count_log2 - 3;
}

static inline int iree_uk_type_bit_count(iree_uk_type_t t) {
Expand All @@ -545,6 +550,21 @@ static inline int iree_uk_type_size(iree_uk_type_t t) {
return 1 << iree_uk_type_size_log2(t);
}

// Converts a size expressed in bits to a size in bytes. A trailing partial
// byte counts as a full byte, i.e. the result is rounded up when `bits` is not
// a multiple of 8.
static inline iree_uk_index_t iree_uk_bits_to_bytes_rounding_up(
    iree_uk_index_t bits) {
  // Adding 7 before the truncating division pushes any partial byte over the
  // next multiple of 8.
  const iree_uk_index_t biased_bits = bits + 7;
  return biased_bits / 8;
}

// Converts a size expressed in bits to a size in bytes for callers that
// require a whole number of bytes. Asserts that `bits` is a multiple of 8.
static inline iree_uk_index_t iree_uk_bits_to_bytes_exact(
    iree_uk_index_t bits) {
  IREE_UK_ASSERT(!(bits % 8));
  const iree_uk_index_t whole_bytes = bits / 8;
  return whole_bytes;
}

//===----------------------------------------------------------------------===//
// Tuples of types, packed ("tied") into a word.
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions runtime/src/iree/builtins/ukernel/exported_bits.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
#define IREE_UK_FLAG_MMT4D_TYPE_F16F16F16 0x04
#define IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32 0x05
#define IREE_UK_FLAG_MMT4D_TYPE_BF16BF16BF16 0x06
// LHS s16, RHS s16, output s32.
#define IREE_UK_FLAG_MMT4D_TYPE_S16S16S32 0x07
// LHS s16, RHS u4 (sub-byte), output s32.
#define IREE_UK_FLAG_MMT4D_TYPE_S16U4S32 0x08
// Exclusive upper bound of the type codes: mmt4d validation asserts
// `flags_type < IREE_UK_FLAG_MMT4D_TYPE_END`, so this must stay one past the
// last type code when new types are added.
#define IREE_UK_FLAG_MMT4D_TYPE_END 0x09

// bit flags
#define IREE_UK_FLAG_MMT4D_ACCUMULATE 0x100
Expand Down
45 changes: 30 additions & 15 deletions runtime/src/iree/builtins/ukernel/mmt4d.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@ static void iree_uk_mmt4d_validate(const iree_uk_mmt4d_params_t* params) {
IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS;
IREE_UK_ASSERT(!(params->flags & ~allflags));
iree_uk_uint32_t flags_type = params->flags & IREE_UK_FLAG_MMT4D_TYPE_MASK;
IREE_UK_ASSERT(flags_type == IREE_UK_FLAG_MMT4D_TYPE_F32F32F32 ||
flags_type == IREE_UK_FLAG_MMT4D_TYPE_S8S8S32 ||
flags_type == IREE_UK_FLAG_MMT4D_TYPE_F16F16F32 ||
flags_type == IREE_UK_FLAG_MMT4D_TYPE_F16F16F16 ||
flags_type == IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32 ||
flags_type == IREE_UK_FLAG_MMT4D_TYPE_BF16BF16BF16);
IREE_UK_ASSERT(flags_type < IREE_UK_FLAG_MMT4D_TYPE_END);
// Some implementations may wish to avoid supporting absurdly wide types. For
// instance, K is the innermost (i.e. hottest) loop bound, so some 32bit
// targets may benefit from K being int32, not int64. We still let K be of
Expand All @@ -33,8 +28,22 @@ static void iree_uk_mmt4d_validate(const iree_uk_mmt4d_params_t* params) {
IREE_UK_ASSERT(IREE_UK_VALUE_IN_UNSIGNED_INT_RANGE(params->M0, 15));
IREE_UK_ASSERT(IREE_UK_VALUE_IN_UNSIGNED_INT_RANGE(params->N0, 15));
IREE_UK_ASSERT(IREE_UK_VALUE_IN_UNSIGNED_INT_RANGE(params->K0, 15));
// Ensure iree_uk_mmt4d_tile_generic_max_bytes large enough for this tile.

// Requirements on sub-byte element type cases
// - Ensure that the output type is not sub-byte.
iree_uk_mmt4d_type_t mmt4d_type = iree_uk_mmt4d_type(params->flags);
IREE_UK_ASSERT(iree_uk_type_bit_count(iree_uk_mmt4d_out_type(mmt4d_type)) >=
8);
// - Ensure that (K0 * {LHS,RHS} element bits) is a multiple of 8 bits.
//   Each operand is checked against its own element width: the two widths
//   differ for mixed types such as s16 * u4, so the RHS width must come from
//   the RHS type and not from the LHS type.
int lhs_bits = iree_uk_type_bit_count(iree_uk_mmt4d_lhs_type(mmt4d_type));
int rhs_bits = iree_uk_type_bit_count(iree_uk_mmt4d_rhs_type(mmt4d_type));
IREE_UK_ASSERT(!((params->K0 * lhs_bits) % 8));
IREE_UK_ASSERT(!((params->K0 * rhs_bits) % 8));
// - Ensure that {LHS,RHS} strides are multiples of 8 bits.
IREE_UK_ASSERT(!((params->lhs_stride0 * lhs_bits) % 8));
IREE_UK_ASSERT(!((params->rhs_stride0 * rhs_bits) % 8));

// Ensure iree_uk_mmt4d_tile_generic_max_bytes large enough for this tile.
IREE_UK_ASSERT(params->M0 * params->N0 *
iree_uk_type_size(iree_uk_mmt4d_out_type(mmt4d_type)) <=
iree_uk_mmt4d_tile_generic_max_bytes);
Expand All @@ -56,18 +65,24 @@ static void iree_uk_mmt4d_using_tile_func(const iree_uk_mmt4d_params_t* params,
const iree_uk_type_t lhs_type = iree_uk_mmt4d_lhs_type(mmt4d_type);
const iree_uk_type_t rhs_type = iree_uk_mmt4d_rhs_type(mmt4d_type);
const iree_uk_type_t out_type = iree_uk_mmt4d_out_type(mmt4d_type);
const iree_uk_int16_t lhs_elem_size_log2 = iree_uk_type_size_log2(lhs_type);
const iree_uk_int16_t rhs_elem_size_log2 = iree_uk_type_size_log2(rhs_type);
const iree_uk_int16_t lhs_elem_bits_log2 =
iree_uk_type_bit_count_log2(lhs_type);
const iree_uk_int16_t rhs_elem_bits_log2 =
iree_uk_type_bit_count_log2(rhs_type);
const iree_uk_int16_t out_elem_size_log2 = iree_uk_type_size_log2(out_type);
char* out_tile_row =
(char*)params->out_buffer + (params->out_offset << out_elem_size_log2);
const char* lhs_panel = (const char*)params->lhs_buffer +
(params->lhs_offset << lhs_elem_size_log2);
const char* rhs_panel_start = (const char*)params->rhs_buffer +
(params->rhs_offset << rhs_elem_size_log2);
const char* lhs_panel =
(const char*)params->lhs_buffer +
iree_uk_bits_to_bytes_exact(params->lhs_offset << lhs_elem_bits_log2);
const char* rhs_panel_start =
(const char*)params->rhs_buffer +
iree_uk_bits_to_bytes_exact(params->rhs_offset << rhs_elem_bits_log2);
iree_uk_int32_t out_tile_size = (M0 * N0) << out_elem_size_log2;
iree_uk_index_t lhs_panel_stride = params->lhs_stride0 << lhs_elem_size_log2;
iree_uk_index_t rhs_panel_stride = params->rhs_stride0 << rhs_elem_size_log2;
iree_uk_index_t lhs_panel_stride =
iree_uk_bits_to_bytes_exact(params->lhs_stride0 << lhs_elem_bits_log2);
iree_uk_index_t rhs_panel_stride =
iree_uk_bits_to_bytes_exact(params->rhs_stride0 << rhs_elem_bits_log2);
iree_uk_index_t out_stride = params->out_stride0 << out_elem_size_log2;
for (iree_uk_int32_t i = 0; i < M; ++i) {
char* out_tile = out_tile_row;
Expand Down
8 changes: 8 additions & 0 deletions runtime/src/iree/builtins/ukernel/mmt4d_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ typedef enum iree_uk_mmt4d_type_t {
IREE_UK_TIE_3_TYPES_LITERAL(FLOAT_32, FLOAT_32, FLOAT_32),
iree_uk_mmt4d_type_s8s8s32 =
IREE_UK_TIE_3_TYPES_LITERAL(SINT_8, SINT_8, SINT_32),
iree_uk_mmt4d_type_s16s16s32 =
IREE_UK_TIE_3_TYPES_LITERAL(SINT_16, SINT_16, SINT_32),
iree_uk_mmt4d_type_s16u4s32 =
IREE_UK_TIE_3_TYPES_LITERAL(SINT_16, UINT_4, SINT_32),
iree_uk_mmt4d_type_f16f16f32 =
IREE_UK_TIE_3_TYPES_LITERAL(FLOAT_16, FLOAT_16, FLOAT_32),
iree_uk_mmt4d_type_f16f16f16 =
Expand All @@ -30,6 +34,10 @@ static inline iree_uk_mmt4d_type_t iree_uk_mmt4d_type(iree_uk_uint32_t flags) {
return iree_uk_mmt4d_type_f32f32f32;
case IREE_UK_FLAG_MMT4D_TYPE_S8S8S32:
return iree_uk_mmt4d_type_s8s8s32;
case IREE_UK_FLAG_MMT4D_TYPE_S16S16S32:
return iree_uk_mmt4d_type_s16s16s32;
case IREE_UK_FLAG_MMT4D_TYPE_S16U4S32:
return iree_uk_mmt4d_type_s16u4s32;
case IREE_UK_FLAG_MMT4D_TYPE_F16F16F32:
return iree_uk_mmt4d_type_f16f16f32;
case IREE_UK_FLAG_MMT4D_TYPE_F16F16F16:
Expand Down
84 changes: 83 additions & 1 deletion runtime/src/iree/builtins/ukernel/mmt4d_tile.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#include "iree/builtins/ukernel/mmt4d_internal.h"

// Generic implementation of matmul tile, i8*i8->i32 case.
// Generic implementation of matmul tile, s8*s8->s32 case.
static void iree_uk_mmt4d_tile_s8s8s32_generic(
void* out_tile_untyped, const void* lhs_panel_untyped,
const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) {
Expand Down Expand Up @@ -41,6 +41,84 @@ static void iree_uk_mmt4d_tile_s8s8s32_generic(
for (int i = 0; i < M0 * N0; ++i) out_tile[i] = acc[i];
}

// Reference (non-architecture-specific) matmul tile kernel, s16*s16->s32 case.
//
// Computes one M0xN0 output tile, stored row-major in `out_tile_untyped`.
// `params->K` LHS tiles of M0xK0 int16 elements and RHS tiles of N0xK0 int16
// elements are read back to back from the two panels. When
// IREE_UK_FLAG_MMT4D_ACCUMULATE is set in `params->flags`, the products are
// added to the existing contents of the output tile; otherwise accumulation
// starts from zero.
static void iree_uk_mmt4d_tile_s16s16s32_generic(
    void* out_tile_untyped, const void* lhs_panel_untyped,
    const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) {
  iree_uk_int32_t* out_tile = out_tile_untyped;
  const iree_uk_int16_t* lhs_tile = lhs_panel_untyped;
  const iree_uk_int16_t* rhs_tile = rhs_panel_untyped;
  const iree_uk_int16_t M0 = params->M0;
  const iree_uk_int16_t N0 = params->N0;
  const iree_uk_int16_t K0 = params->K0;
  const int out_elem_count = M0 * N0;

  // Local accumulator tile. Its capacity is fixed by
  // iree_uk_mmt4d_tile_generic_max_bytes; M0 * N0 is not re-checked against
  // that capacity here.
  iree_uk_int32_t acc[iree_uk_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)];
  const int accumulate = (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) != 0;
  for (int e = 0; e < out_elem_count; ++e) {
    acc[e] = accumulate ? out_tile[e] : 0;
  }

  // Accumulation loop: one LHS tile and one RHS tile are consumed per step
  // along K.
  for (iree_uk_index_t k = 0; k < params->K; ++k) {
    for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) {
      const iree_uk_int16_t* lhs_row = lhs_tile + i0 * K0;
      for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) {
        const iree_uk_int16_t* rhs_row = rhs_tile + j0 * K0;
        iree_uk_int32_t* acc_elem = acc + (i0 * N0 + j0);
        for (iree_uk_index_t k0 = 0; k0 < K0; ++k0) {
          // Widen both operands to 32 bits before multiplying.
          const iree_uk_int32_t lhs_val = lhs_row[k0];
          const iree_uk_int32_t rhs_val = rhs_row[k0];
          *acc_elem += lhs_val * rhs_val;
        }
      }
    }
    lhs_tile += M0 * K0;
    rhs_tile += N0 * K0;
  }

  // Write the finished accumulators back to the destination tile.
  for (int e = 0; e < out_elem_count; ++e) {
    out_tile[e] = acc[e];
  }
}

// Reference (non-architecture-specific) matmul tile kernel, s16*u4->s32 case.
//
// Computes one M0xN0 output tile, stored row-major in `out_tile_untyped`.
// LHS elements are int16. RHS elements are unsigned 4-bit values packed two
// per byte: the low nibble of a byte holds the even-k0 element and the high
// nibble holds the following odd-k0 element. When
// IREE_UK_FLAG_MMT4D_ACCUMULATE is set in `params->flags`, the products are
// added to the existing contents of the output tile; otherwise accumulation
// starts from zero.
static void iree_uk_mmt4d_tile_s16u4s32_generic(
    void* out_tile_untyped, const void* lhs_panel_untyped,
    const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) {
  iree_uk_int32_t* out_tile = out_tile_untyped;
  const iree_uk_int16_t* lhs_tile = lhs_panel_untyped;
  const iree_uk_uint8_t* rhs_tile = rhs_panel_untyped;
  const iree_uk_int16_t M0 = params->M0;
  const iree_uk_int16_t N0 = params->N0;
  const iree_uk_int16_t K0 = params->K0;
  // Two u4 elements share a byte, so each RHS row of K0 elements must occupy
  // a whole number of bytes: K0 must be even.
  IREE_UK_ASSERT(!(K0 % 2));
  const iree_uk_int16_t rhs_row_bytes = K0 / 2;
  const int out_elem_count = M0 * N0;

  // Local accumulator tile. Its capacity is fixed by
  // iree_uk_mmt4d_tile_generic_max_bytes; M0 * N0 is not re-checked against
  // that capacity here.
  iree_uk_int32_t acc[iree_uk_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)];
  const int accumulate = (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) != 0;
  for (int e = 0; e < out_elem_count; ++e) {
    acc[e] = accumulate ? out_tile[e] : 0;
  }

  // Accumulation loop: one LHS tile and one RHS tile are consumed per step
  // along K.
  for (iree_uk_index_t k = 0; k < params->K; ++k) {
    for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) {
      const iree_uk_int16_t* lhs_row = lhs_tile + i0 * K0;
      for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) {
        const iree_uk_uint8_t* rhs_row = rhs_tile + j0 * rhs_row_bytes;
        iree_uk_int32_t* acc_elem = acc + (i0 * N0 + j0);
        // Each RHS byte yields two elements, so walk the row byte by byte and
        // add a 2-element dot product per step.
        for (iree_uk_index_t b = 0; b < rhs_row_bytes; ++b) {
          const iree_uk_int32_t lhs_even = lhs_row[2 * b];
          const iree_uk_int32_t lhs_odd = lhs_row[2 * b + 1];
          const iree_uk_uint8_t packed = rhs_row[b];
          const iree_uk_int32_t rhs_even = packed & 0xf;
          const iree_uk_int32_t rhs_odd = packed >> 4;
          *acc_elem += lhs_even * rhs_even + lhs_odd * rhs_odd;
        }
      }
    }
    lhs_tile += M0 * K0;
    rhs_tile += N0 * rhs_row_bytes;
  }

  // Write the finished accumulators back to the destination tile.
  for (int e = 0; e < out_elem_count; ++e) {
    out_tile[e] = acc[e];
  }
}

// Generic implementation of matmul tile, f32*f32->f32 case.
static void iree_uk_mmt4d_tile_f32f32f32_generic(
void* out_tile_untyped, const void* lhs_panel_untyped,
Expand Down Expand Up @@ -227,6 +305,10 @@ static iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_generic(
return iree_uk_mmt4d_tile_f32f32f32_generic;
case iree_uk_mmt4d_type_s8s8s32:
return iree_uk_mmt4d_tile_s8s8s32_generic;
case iree_uk_mmt4d_type_s16s16s32:
return iree_uk_mmt4d_tile_s16s16s32_generic;
case iree_uk_mmt4d_type_s16u4s32:
return iree_uk_mmt4d_tile_s16u4s32_generic;
case iree_uk_mmt4d_type_f16f16f32:
return iree_uk_mmt4d_tile_f16f16f32_generic;
case iree_uk_mmt4d_type_f16f16f16:
Expand Down
Loading

0 comments on commit 29e65cd

Please sign in to comment.