diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c index 6ec64ca94810f..dbff2c795f30b 100644 --- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c +++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c @@ -232,6 +232,10 @@ iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arch( return iree_uk_mmt4d_select_tile_func_arm_64_bf16bf16bf16(params); case iree_uk_mmt4d_type_s8s8s32: return iree_uk_mmt4d_select_tile_func_arm_64_i8i8i32(params); + case iree_uk_mmt4d_type_s16s16s32: + return 0; + case iree_uk_mmt4d_type_s16u4s32: + return 0; default: IREE_UK_ASSUME_UNREACHABLE; return 0; diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c index f31923e9611ef..09317c6de926f 100644 --- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c +++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c @@ -292,6 +292,10 @@ iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arch( return iree_uk_mmt4d_select_tile_func_x86_64_bf16bf16bf16(params); case iree_uk_mmt4d_type_s8s8s32: return iree_uk_mmt4d_select_tile_func_x86_64_i8i8i32(params); + case iree_uk_mmt4d_type_s16s16s32: + return 0; + case iree_uk_mmt4d_type_s16u4s32: + return 0; default: IREE_UK_ASSUME_UNREACHABLE; return 0; diff --git a/runtime/src/iree/builtins/ukernel/common.h b/runtime/src/iree/builtins/ukernel/common.h index b196d856776ab..8283cda572848 100644 --- a/runtime/src/iree/builtins/ukernel/common.h +++ b/runtime/src/iree/builtins/ukernel/common.h @@ -461,14 +461,17 @@ enum { IREE_UK_TYPE_OPAQUE_16 = IREE_UK_TYPE_CATEGORY_OPAQUE | 4, IREE_UK_TYPE_OPAQUE_32 = IREE_UK_TYPE_CATEGORY_OPAQUE | 5, IREE_UK_TYPE_OPAQUE_64 = IREE_UK_TYPE_CATEGORY_OPAQUE | 6, + IREE_UK_TYPE_INT_4 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS | 2, IREE_UK_TYPE_INT_8 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS | 3, IREE_UK_TYPE_INT_16 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS | 4, IREE_UK_TYPE_INT_32 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS | 5, IREE_UK_TYPE_INT_64 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS | 6, + IREE_UK_TYPE_SINT_4 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED | 2, IREE_UK_TYPE_SINT_8 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED | 3, IREE_UK_TYPE_SINT_16 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED | 4, IREE_UK_TYPE_SINT_32 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED | 5, IREE_UK_TYPE_SINT_64 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED | 6, + IREE_UK_TYPE_UINT_4 = IREE_UK_TYPE_CATEGORY_INTEGER_UNSIGNED | 2, IREE_UK_TYPE_UINT_8 = IREE_UK_TYPE_CATEGORY_INTEGER_UNSIGNED | 3, IREE_UK_TYPE_UINT_16 = IREE_UK_TYPE_CATEGORY_INTEGER_UNSIGNED | 4, IREE_UK_TYPE_UINT_32 = IREE_UK_TYPE_CATEGORY_INTEGER_UNSIGNED | 5, @@ -531,7 +534,9 @@ static inline iree_uk_uint8_t iree_uk_integer_type_as_unsigned( // The current implementation might return a negative value, but don't rely on // that. static inline int iree_uk_type_size_log2(iree_uk_type_t t) { - return iree_uk_type_bit_count_log2(t) - 3; + int bit_count_log2 = iree_uk_type_bit_count_log2(t); + IREE_UK_ASSERT(bit_count_log2 >= 3); + return bit_count_log2 - 3; } static inline int iree_uk_type_bit_count(iree_uk_type_t t) { @@ -545,6 +550,21 @@ static inline int iree_uk_type_size(iree_uk_type_t t) { return 1 << iree_uk_type_size_log2(t); } +// Helper to correctly convert a bit-size to a byte-size, rounding up if the +// bit-size is not a multiple of 8. +static inline iree_uk_index_t iree_uk_bits_to_bytes_rounding_up( + iree_uk_index_t bits) { + return (bits + 7) / 8; +} + +// Helper to correctly convert a bit-size to a byte-size, asserting that the +// bit-size is a multiple of 8. +static inline iree_uk_index_t iree_uk_bits_to_bytes_exact( + iree_uk_index_t bits) { + IREE_UK_ASSERT(!(bits % 8)); + return bits / 8; +} + //===----------------------------------------------------------------------===// // Tuples of types, packed ("tied") into a word. //===----------------------------------------------------------------------===// diff --git a/runtime/src/iree/builtins/ukernel/exported_bits.h b/runtime/src/iree/builtins/ukernel/exported_bits.h index e2dfd6af1d33b..c027c5b3b4e10 100644 --- a/runtime/src/iree/builtins/ukernel/exported_bits.h +++ b/runtime/src/iree/builtins/ukernel/exported_bits.h @@ -49,6 +49,9 @@ #define IREE_UK_FLAG_MMT4D_TYPE_F16F16F16 0x04 #define IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32 0x05 #define IREE_UK_FLAG_MMT4D_TYPE_BF16BF16BF16 0x06 +#define IREE_UK_FLAG_MMT4D_TYPE_S16S16S32 0x07 +#define IREE_UK_FLAG_MMT4D_TYPE_S16U4S32 0x08 +#define IREE_UK_FLAG_MMT4D_TYPE_END 0x09 // bit flags #define IREE_UK_FLAG_MMT4D_ACCUMULATE 0x100 diff --git a/runtime/src/iree/builtins/ukernel/mmt4d.c b/runtime/src/iree/builtins/ukernel/mmt4d.c index a626a8c7da08e..e775ca1265145 100644 --- a/runtime/src/iree/builtins/ukernel/mmt4d.c +++ b/runtime/src/iree/builtins/ukernel/mmt4d.c @@ -15,12 +15,7 @@ static void iree_uk_mmt4d_validate(const iree_uk_mmt4d_params_t* params) { IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS; IREE_UK_ASSERT(!(params->flags & ~allflags)); iree_uk_uint32_t flags_type = params->flags & IREE_UK_FLAG_MMT4D_TYPE_MASK; - IREE_UK_ASSERT(flags_type == IREE_UK_FLAG_MMT4D_TYPE_F32F32F32 || - flags_type == IREE_UK_FLAG_MMT4D_TYPE_S8S8S32 || - flags_type == IREE_UK_FLAG_MMT4D_TYPE_F16F16F32 || - flags_type == IREE_UK_FLAG_MMT4D_TYPE_F16F16F16 || - flags_type == IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32 || - flags_type == IREE_UK_FLAG_MMT4D_TYPE_BF16BF16BF16); + IREE_UK_ASSERT(flags_type < IREE_UK_FLAG_MMT4D_TYPE_END); // Some implementations may wish to avoid supporting absurdly wide types. For // instance, K is the innermost (i.e. hottest) loop bound, so some 32bit // targets may benefit from K being int32, not int64. We still let K be of @@ -33,8 +28,22 @@ static void iree_uk_mmt4d_validate(const iree_uk_mmt4d_params_t* params) { IREE_UK_ASSERT(IREE_UK_VALUE_IN_UNSIGNED_INT_RANGE(params->M0, 15)); IREE_UK_ASSERT(IREE_UK_VALUE_IN_UNSIGNED_INT_RANGE(params->N0, 15)); IREE_UK_ASSERT(IREE_UK_VALUE_IN_UNSIGNED_INT_RANGE(params->K0, 15)); - // Ensure iree_uk_mmt4d_tile_generic_max_bytes large enough for this tile. + + // Requirements on sub-byte element type cases + // - Ensure that the output type is not sub-byte. iree_uk_mmt4d_type_t mmt4d_type = iree_uk_mmt4d_type(params->flags); + IREE_UK_ASSERT(iree_uk_type_bit_count(iree_uk_mmt4d_out_type(mmt4d_type)) >= + 8); + // - Ensure that (K0 * {LHS,RHS} element bits) is a multiple of 8 bits. + int lhs_bits = iree_uk_type_bit_count(iree_uk_mmt4d_lhs_type(mmt4d_type)); + int rhs_bits = iree_uk_type_bit_count(iree_uk_mmt4d_lhs_type(mmt4d_type)); + IREE_UK_ASSERT(!((params->K0 * lhs_bits) % 8)); + IREE_UK_ASSERT(!((params->K0 * rhs_bits) % 8)); + // - Ensure that {LHS,RHS} strides are multiples of 8 bits. + IREE_UK_ASSERT(!((params->lhs_stride0 * lhs_bits) % 8)); + IREE_UK_ASSERT(!((params->rhs_stride0 * rhs_bits) % 8)); + + // Ensure iree_uk_mmt4d_tile_generic_max_bytes large enough for this tile. IREE_UK_ASSERT(params->M0 * params->N0 * iree_uk_type_size(iree_uk_mmt4d_out_type(mmt4d_type)) <= iree_uk_mmt4d_tile_generic_max_bytes); @@ -56,18 +65,24 @@ static void iree_uk_mmt4d_using_tile_func(const iree_uk_mmt4d_params_t* params, const iree_uk_type_t lhs_type = iree_uk_mmt4d_lhs_type(mmt4d_type); const iree_uk_type_t rhs_type = iree_uk_mmt4d_rhs_type(mmt4d_type); const iree_uk_type_t out_type = iree_uk_mmt4d_out_type(mmt4d_type); - const iree_uk_int16_t lhs_elem_size_log2 = iree_uk_type_size_log2(lhs_type); - const iree_uk_int16_t rhs_elem_size_log2 = iree_uk_type_size_log2(rhs_type); + const iree_uk_int16_t lhs_elem_bits_log2 = + iree_uk_type_bit_count_log2(lhs_type); + const iree_uk_int16_t rhs_elem_bits_log2 = + iree_uk_type_bit_count_log2(rhs_type); const iree_uk_int16_t out_elem_size_log2 = iree_uk_type_size_log2(out_type); char* out_tile_row = (char*)params->out_buffer + (params->out_offset << out_elem_size_log2); - const char* lhs_panel = (const char*)params->lhs_buffer + - (params->lhs_offset << lhs_elem_size_log2); - const char* rhs_panel_start = (const char*)params->rhs_buffer + - (params->rhs_offset << rhs_elem_size_log2); + const char* lhs_panel = + (const char*)params->lhs_buffer + + iree_uk_bits_to_bytes_exact(params->lhs_offset << lhs_elem_bits_log2); + const char* rhs_panel_start = + (const char*)params->rhs_buffer + + iree_uk_bits_to_bytes_exact(params->rhs_offset << rhs_elem_bits_log2); iree_uk_int32_t out_tile_size = (M0 * N0) << out_elem_size_log2; - iree_uk_index_t lhs_panel_stride = params->lhs_stride0 << lhs_elem_size_log2; - iree_uk_index_t rhs_panel_stride = params->rhs_stride0 << rhs_elem_size_log2; + iree_uk_index_t lhs_panel_stride = + iree_uk_bits_to_bytes_exact(params->lhs_stride0 << lhs_elem_bits_log2); + iree_uk_index_t rhs_panel_stride = + iree_uk_bits_to_bytes_exact(params->rhs_stride0 << rhs_elem_bits_log2); iree_uk_index_t out_stride = params->out_stride0 << out_elem_size_log2; for (iree_uk_int32_t i = 0; i < M; ++i) { char* out_tile = out_tile_row; diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_internal.h b/runtime/src/iree/builtins/ukernel/mmt4d_internal.h index ebdd673194fc0..1b743c2593a1d 100644 --- a/runtime/src/iree/builtins/ukernel/mmt4d_internal.h +++ b/runtime/src/iree/builtins/ukernel/mmt4d_internal.h @@ -14,6 +14,10 @@ typedef enum iree_uk_mmt4d_type_t { IREE_UK_TIE_3_TYPES_LITERAL(FLOAT_32, FLOAT_32, FLOAT_32), iree_uk_mmt4d_type_s8s8s32 = IREE_UK_TIE_3_TYPES_LITERAL(SINT_8, SINT_8, SINT_32), + iree_uk_mmt4d_type_s16s16s32 = + IREE_UK_TIE_3_TYPES_LITERAL(SINT_16, SINT_16, SINT_32), + iree_uk_mmt4d_type_s16u4s32 = + IREE_UK_TIE_3_TYPES_LITERAL(SINT_16, UINT_4, SINT_32), iree_uk_mmt4d_type_f16f16f32 = IREE_UK_TIE_3_TYPES_LITERAL(FLOAT_16, FLOAT_16, FLOAT_32), iree_uk_mmt4d_type_f16f16f16 = @@ -30,6 +34,10 @@ static inline iree_uk_mmt4d_type_t iree_uk_mmt4d_type(iree_uk_uint32_t flags) { return iree_uk_mmt4d_type_f32f32f32; case IREE_UK_FLAG_MMT4D_TYPE_S8S8S32: return iree_uk_mmt4d_type_s8s8s32; + case IREE_UK_FLAG_MMT4D_TYPE_S16S16S32: + return iree_uk_mmt4d_type_s16s16s32; + case IREE_UK_FLAG_MMT4D_TYPE_S16U4S32: + return iree_uk_mmt4d_type_s16u4s32; case IREE_UK_FLAG_MMT4D_TYPE_F16F16F32: return iree_uk_mmt4d_type_f16f16f32; case IREE_UK_FLAG_MMT4D_TYPE_F16F16F16: diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_tile.c b/runtime/src/iree/builtins/ukernel/mmt4d_tile.c index 622ae89863bd4..6604d3baedb11 100644 --- a/runtime/src/iree/builtins/ukernel/mmt4d_tile.c +++ b/runtime/src/iree/builtins/ukernel/mmt4d_tile.c @@ -6,7 +6,7 @@ #include "iree/builtins/ukernel/mmt4d_internal.h" -// Generic implementation of matmul tile, i8*i8->i32 case. +// Generic implementation of matmul tile, s8*s8->s32 case. static void iree_uk_mmt4d_tile_s8s8s32_generic( void* out_tile_untyped, const void* lhs_panel_untyped, const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) { @@ -41,6 +41,84 @@ static void iree_uk_mmt4d_tile_s8s8s32_generic( for (int i = 0; i < M0 * N0; ++i) out_tile[i] = acc[i]; } +// Generic implementation of matmul tile, s16*s16->s32 case. +static void iree_uk_mmt4d_tile_s16s16s32_generic( + void* out_tile_untyped, const void* lhs_panel_untyped, + const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) { + iree_uk_int32_t* out_tile = out_tile_untyped; + const iree_uk_int16_t* lhs_panel = lhs_panel_untyped; + const iree_uk_int16_t* rhs_panel = rhs_panel_untyped; + iree_uk_int16_t M0 = params->M0; + iree_uk_int16_t N0 = params->N0; + iree_uk_int16_t K0 = params->K0; + // Initialize the local accumulator tile. + iree_uk_int32_t acc[iree_uk_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)]; + if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) { + for (int i = 0; i < M0 * N0; ++i) acc[i] = out_tile[i]; + } else { + for (int i = 0; i < M0 * N0; ++i) acc[i] = 0; + } + // Accumulation loop. + for (iree_uk_index_t k = 0; k < params->K; ++k) { + for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) { + for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) { + for (iree_uk_index_t k0 = 0; k0 < K0; ++k0) { + iree_uk_int32_t lhs_i32 = lhs_panel[i0 * K0 + k0]; + iree_uk_int32_t rhs_i32 = rhs_panel[j0 * K0 + k0]; + acc[i0 * N0 + j0] += lhs_i32 * rhs_i32; + } + } + } + lhs_panel += M0 * K0; + rhs_panel += N0 * K0; + } + // Store the local accumulator tile to the destination. + for (int i = 0; i < M0 * N0; ++i) out_tile[i] = acc[i]; +} + +// Generic implementation of matmul tile, s16*u4->s32 case. +static void iree_uk_mmt4d_tile_s16u4s32_generic( + void* out_tile_untyped, const void* lhs_panel_untyped, + const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) { + iree_uk_int32_t* out_tile = out_tile_untyped; + const iree_uk_int16_t* lhs_panel = lhs_panel_untyped; + const iree_uk_uint8_t* rhs_panel = rhs_panel_untyped; + iree_uk_int16_t M0 = params->M0; + iree_uk_int16_t N0 = params->N0; + iree_uk_int16_t K0 = params->K0; + // K0 must be even. + IREE_UK_ASSERT(!(K0 % 2)); + iree_uk_int16_t K0half = K0 / 2; + // Initialize the local accumulator tile. + iree_uk_int32_t acc[iree_uk_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)]; + if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) { + for (int i = 0; i < M0 * N0; ++i) acc[i] = out_tile[i]; + } else { + for (int i = 0; i < M0 * N0; ++i) acc[i] = 0; + } + // Accumulation loop. + for (iree_uk_index_t k = 0; k < params->K; ++k) { + for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) { + for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) { + // As K0 must be even, we 2x-unroll the K0 loop, writing a 2D dot + // product. + for (iree_uk_index_t k0h = 0; k0h < K0half; ++k0h) { + iree_uk_int32_t lhs_0 = lhs_panel[i0 * K0 + 2 * k0h]; + iree_uk_int32_t lhs_1 = lhs_panel[i0 * K0 + 2 * k0h + 1]; + iree_uk_uint8_t rhs_byte = rhs_panel[j0 * K0half + k0h]; + iree_uk_int32_t rhs_0 = rhs_byte & 0xf; + iree_uk_int32_t rhs_1 = rhs_byte >> 4; + acc[i0 * N0 + j0] += lhs_0 * rhs_0 + lhs_1 * rhs_1; + } + } + } + lhs_panel += M0 * K0; + rhs_panel += N0 * K0half; + } + // Store the local accumulator tile to the destination. + for (int i = 0; i < M0 * N0; ++i) out_tile[i] = acc[i]; +} + // Generic implementation of matmul tile, f32*f32->f32 case. static void iree_uk_mmt4d_tile_f32f32f32_generic( void* out_tile_untyped, const void* lhs_panel_untyped, @@ -227,6 +305,10 @@ static iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_generic( return iree_uk_mmt4d_tile_f32f32f32_generic; case iree_uk_mmt4d_type_s8s8s32: return iree_uk_mmt4d_tile_s8s8s32_generic; + case iree_uk_mmt4d_type_s16s16s32: + return iree_uk_mmt4d_tile_s16s16s32_generic; + case iree_uk_mmt4d_type_s16u4s32: + return iree_uk_mmt4d_tile_s16u4s32_generic; case iree_uk_mmt4d_type_f16f16f32: return iree_uk_mmt4d_tile_f16f16f32_generic; case iree_uk_mmt4d_type_f16f16f16: diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c index 84ae68f9e07bd..21d1ea7464f99 100644 --- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c +++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c @@ -137,12 +137,47 @@ static void iree_mmt4d_reference_innerloop_s8s8s32( *out_ptr = acc; } +static void iree_mmt4d_reference_innerloop_s16s16s32( + int32_t* out_ptr, const int16_t* lhs_ptr, const int16_t* rhs_ptr, + const iree_uk_mmt4d_params_t* params) { + int32_t acc = params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE ? *out_ptr : 0; + for (iree_uk_index_t k = 0; k < params->K; ++k) { + for (iree_uk_index_t k0 = 0; k0 < params->K0; ++k0) { + int32_t lhs_i32 = lhs_ptr[k * params->M0 * params->K0 + k0]; + int32_t rhs_i32 = rhs_ptr[k * params->N0 * params->K0 + k0]; + acc += lhs_i32 * rhs_i32; + } + } + *out_ptr = acc; +} + +static void iree_mmt4d_reference_innerloop_s16u4s32( + int32_t* out_ptr, const int16_t* lhs_ptr, const uint8_t* rhs_ptr, + const iree_uk_mmt4d_params_t* params) { + // K0 must be even. + IREE_UK_ASSERT(!(params->K0 % 2)); + iree_uk_int16_t K0half = params->K0 / 2; + int32_t acc = params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE ? *out_ptr : 0; + for (iree_uk_index_t k = 0; k < params->K; ++k) { + // As K0 must be even, we 2x-unroll the K0 loop, writing a 2D dot product. + for (iree_uk_index_t k0h = 0; k0h < K0half; ++k0h) { + int32_t lhs_0 = lhs_ptr[k * params->M0 * params->K0 + 2 * k0h]; + int32_t lhs_1 = lhs_ptr[k * params->M0 * params->K0 + 2 * k0h + 1]; + uint8_t rhs_byte = rhs_ptr[k * params->N0 * K0half + k0h]; + int32_t rhs_0 = rhs_byte & 0xf; + int32_t rhs_1 = rhs_byte >> 4; + acc += lhs_0 * rhs_0 + lhs_1 * rhs_1; + } + } + *out_ptr = acc; +} + static void iree_mmt4d_reference(const iree_uk_mmt4d_params_t* params) { iree_uk_mmt4d_type_t mmt4d_type = iree_uk_mmt4d_type(params->flags); - iree_uk_index_t lhs_elem_size = - iree_uk_type_size(iree_uk_mmt4d_lhs_type(mmt4d_type)); - iree_uk_index_t rhs_elem_size = - iree_uk_type_size(iree_uk_mmt4d_rhs_type(mmt4d_type)); + iree_uk_index_t lhs_elem_bits = + iree_uk_type_bit_count(iree_uk_mmt4d_lhs_type(mmt4d_type)); + iree_uk_index_t rhs_elem_bits = + iree_uk_type_bit_count(iree_uk_mmt4d_rhs_type(mmt4d_type)); iree_uk_index_t out_elem_size = iree_uk_type_size(iree_uk_mmt4d_out_type(mmt4d_type)); for (iree_uk_index_t i = 0; i < params->M; ++i) { @@ -153,18 +188,22 @@ static void iree_mmt4d_reference(const iree_uk_mmt4d_params_t* params) { out_elem_size; const void* lhs_panel_ptr = ((const char*)params->lhs_buffer) + - (params->lhs_offset + i * params->lhs_stride0) * lhs_elem_size; + iree_uk_bits_to_bytes_exact( + (params->lhs_offset + i * params->lhs_stride0) * lhs_elem_bits); const void* rhs_panel_ptr = ((const char*)params->rhs_buffer) + - (params->rhs_offset + j * params->rhs_stride0) * rhs_elem_size; + iree_uk_bits_to_bytes_exact( + (params->rhs_offset + j * params->rhs_stride0) * rhs_elem_bits); for (iree_uk_index_t i0 = 0; i0 < params->M0; ++i0) { for (iree_uk_index_t j0 = 0; j0 < params->N0; ++j0) { void* out_ptr = ((char*)out_tile_ptr) + (i0 * params->N0 + j0) * out_elem_size; const void* lhs_ptr = - ((char*)lhs_panel_ptr) + i0 * params->K0 * lhs_elem_size; + ((char*)lhs_panel_ptr) + + iree_uk_bits_to_bytes_exact(i0 * params->K0 * lhs_elem_bits); const void* rhs_ptr = - ((char*)rhs_panel_ptr) + j0 * params->K0 * rhs_elem_size; + ((char*)rhs_panel_ptr) + + iree_uk_bits_to_bytes_exact(j0 * params->K0 * rhs_elem_bits); switch (params->flags & IREE_UK_FLAG_MMT4D_TYPE_MASK) { case IREE_UK_FLAG_MMT4D_TYPE_F32F32F32: iree_mmt4d_reference_innerloop_f32f32f32( @@ -196,6 +235,16 @@ static void iree_mmt4d_reference(const iree_uk_mmt4d_params_t* params) { (int32_t*)out_ptr, (const int8_t*)lhs_ptr, (const int8_t*)rhs_ptr, params); break; + case IREE_UK_FLAG_MMT4D_TYPE_S16S16S32: + iree_mmt4d_reference_innerloop_s16s16s32( + (int32_t*)out_ptr, (const int16_t*)lhs_ptr, + (const int16_t*)rhs_ptr, params); + break; + case IREE_UK_FLAG_MMT4D_TYPE_S16U4S32: + iree_mmt4d_reference_innerloop_s16u4s32( + (int32_t*)out_ptr, (const int16_t*)lhs_ptr, + (const uint8_t*)rhs_ptr, params); + break; default: IREE_UK_ASSERT(false && "unhandled type"); } @@ -206,23 +255,47 @@ static void iree_mmt4d_reference(const iree_uk_mmt4d_params_t* params) { } } +static iree_uk_index_t iree_uk_test_round_up_to_ensure_multiple_of_8_bits( + iree_uk_index_t index, iree_uk_type_t type) { + // Honor the requirement that strides should be multiples of 8 bits. + while ((index << iree_uk_type_bit_count_log2(type)) & 7) { + ++index; + } + return index; +} + +static iree_uk_index_t iree_uk_test_random_stride( + iree_uk_index_t min_stride, iree_uk_type_t type, + iree_uk_random_engine_t* engine) { + // Randomly make strides either tight or not to exercise all cases. + iree_uk_index_t stride = min_stride + iree_uk_random_engine_get_0_1(engine); + return iree_uk_test_round_up_to_ensure_multiple_of_8_bits(stride, type); +} + +static iree_uk_index_t iree_uk_test_random_offset( + iree_uk_type_t type, iree_uk_random_engine_t* engine) { + // Randomly make strides either tight or not to exercise all cases. + iree_uk_index_t stride = iree_uk_random_engine_get_0_1(engine); + return iree_uk_test_round_up_to_ensure_multiple_of_8_bits(stride, type); +} + static void iree_uk_test_mmt4d_for_shape_params( iree_uk_test_t* test, const iree_uk_mmt4d_params_t* src_params) { iree_uk_mmt4d_params_t params; memcpy(¶ms, src_params, sizeof params); - // Populate strides first - we need them below to compute buffer lengths. - // Randomly make strides either tight or not to exercise all cases. - iree_uk_random_engine_t* engine = iree_uk_test_random_engine(test); - params.lhs_stride0 = - params.K * params.M0 * params.K0 + iree_uk_random_engine_get_0_1(engine); - params.rhs_stride0 = - params.K * params.N0 * params.K0 + iree_uk_random_engine_get_0_1(engine); - params.out_stride0 = - params.N * params.M0 * params.N0 + iree_uk_random_engine_get_0_1(engine); iree_uk_mmt4d_type_t mmt4d_type = iree_uk_mmt4d_type(params.flags); iree_uk_type_t lhs_type = iree_uk_mmt4d_lhs_type(mmt4d_type); iree_uk_type_t rhs_type = iree_uk_mmt4d_rhs_type(mmt4d_type); iree_uk_type_t out_type = iree_uk_mmt4d_out_type(mmt4d_type); + // Populate strides first - we need them below to compute buffer lengths. + // Randomly make strides either tight or not to exercise all cases. + iree_uk_random_engine_t* engine = iree_uk_test_random_engine(test); + params.lhs_stride0 = iree_uk_test_random_stride( + params.K * params.M0 * params.K0, lhs_type, engine); + params.rhs_stride0 = iree_uk_test_random_stride( + params.K * params.N0 * params.K0, rhs_type, engine); + params.out_stride0 = iree_uk_test_random_stride( + params.N * params.M0 * params.N0, out_type, engine); iree_uk_index_t lhs_buffer_size = iree_uk_2d_buffer_length(lhs_type, params.M, params.lhs_stride0); iree_uk_index_t rhs_buffer_size = @@ -231,13 +304,17 @@ static void iree_uk_test_mmt4d_for_shape_params( void* rhs_buffer = malloc(rhs_buffer_size); iree_uk_write_random_buffer(lhs_buffer, lhs_buffer_size, lhs_type, engine); iree_uk_write_random_buffer(rhs_buffer, rhs_buffer_size, rhs_type, engine); - params.lhs_offset = iree_uk_random_engine_get_0_65535(engine); - params.rhs_offset = iree_uk_random_engine_get_0_65535(engine); - params.out_offset = iree_uk_random_engine_get_0_65535(engine); - params.lhs_buffer = (const char*)lhs_buffer - - (params.lhs_offset * iree_uk_type_size(lhs_type)); - params.rhs_buffer = (const char*)rhs_buffer - - (params.rhs_offset * iree_uk_type_size(rhs_type)); + params.lhs_offset = iree_uk_test_random_offset(lhs_type, engine); + params.rhs_offset = iree_uk_test_random_offset(rhs_type, engine); + params.out_offset = iree_uk_test_random_offset(out_type, engine); + params.lhs_buffer = + (const char*)lhs_buffer - + iree_uk_bits_to_bytes_exact(params.lhs_offset + << iree_uk_type_bit_count_log2(lhs_type)); + params.rhs_buffer = + (const char*)rhs_buffer - + iree_uk_bits_to_bytes_exact(params.rhs_offset + << iree_uk_type_bit_count_log2(rhs_type)); iree_uk_mmt4d_params_t reference_params; memcpy(&reference_params, ¶ms, sizeof params); @@ -250,14 +327,17 @@ static void iree_uk_test_mmt4d_for_shape_params( memcpy(reference_out_buffer, init_out_buffer, out_buffer_size); reference_params.out_buffer = (char*)reference_out_buffer - - (params.out_offset * iree_uk_type_size(out_type)); + iree_uk_bits_to_bytes_exact(params.out_offset + << iree_uk_type_bit_count_log2(out_type)); iree_uk_mmt4d_params_t actual_params; memcpy(&actual_params, ¶ms, sizeof params); void* actual_out_buffer = malloc(out_buffer_size); memcpy(actual_out_buffer, init_out_buffer, out_buffer_size); - actual_params.out_buffer = (char*)actual_out_buffer - - (params.out_offset * iree_uk_type_size(out_type)); + actual_params.out_buffer = + (char*)actual_out_buffer - + iree_uk_bits_to_bytes_exact(params.out_offset + << iree_uk_type_bit_count_log2(out_type)); iree_mmt4d_reference(&reference_params); iree_uk_mmt4d(&actual_params); @@ -376,6 +456,8 @@ int main(int argc, char** argv) { // in a power-of-two assumption iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 3, 5, 7, ""); iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 9, 6, 3, ""); + iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16S16S32, 7, 3, 6, ""); + iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S16U4S32, 5, 3, 2, ""); iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F32, 4, 6, 5, ""); iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F16, 3, 5, 8, ""); iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32, 11, 4, 1, ""); diff --git a/runtime/src/iree/builtins/ukernel/tools/util.c b/runtime/src/iree/builtins/ukernel/tools/util.c index 2f94964d7bb86..8c8b545abf912 100644 --- a/runtime/src/iree/builtins/ukernel/tools/util.c +++ b/runtime/src/iree/builtins/ukernel/tools/util.c @@ -33,20 +33,27 @@ void iree_uk_assert_fail(const char* file, int line, const char* function, iree_uk_index_t iree_uk_2d_buffer_length(iree_uk_type_t type, iree_uk_index_t size0, iree_uk_index_t stride0) { - // Just for testing purposes, so it's OK to overestimate size. - return size0 * stride0 << iree_uk_type_size_log2(type); + // As we require strides to be multiples of 8 bits, the stride value in bytes + // is exact. + return size0 * iree_uk_bits_to_bytes_exact( + stride0 << iree_uk_type_bit_count_log2(type)); } bool iree_uk_2d_buffers_equal(const void* buf1, const void* buf2, iree_uk_type_t type, iree_uk_index_t size0, iree_uk_index_t size1, iree_uk_index_t stride0) { - iree_uk_index_t elem_size = iree_uk_type_size(type); + // Sizes don't have to be multiples of 8 bits. + iree_uk_index_t size1_bytes = iree_uk_bits_to_bytes_rounding_up( + size1 << iree_uk_type_bit_count_log2(type)); + // Strides are required to be multiples of 8 bits. + iree_uk_index_t stride0_bytes = + iree_uk_bits_to_bytes_exact(stride0 << iree_uk_type_bit_count_log2(type)); const char* buf1_ptr = buf1; const char* buf2_ptr = buf2; for (iree_uk_index_t i0 = 0; i0 < size0; ++i0) { - if (memcmp(buf1_ptr, buf2_ptr, elem_size * size1)) return false; - buf1_ptr += elem_size * stride0; - buf2_ptr += elem_size * stride0; + if (memcmp(buf1_ptr, buf2_ptr, size1_bytes)) return false; + buf1_ptr += stride0_bytes; + buf2_ptr += stride0_bytes; } return true; } @@ -73,6 +80,11 @@ int iree_uk_random_engine_get_0_65535(iree_uk_random_engine_t* e) { return (v >> 8) & 0xffff; } +int iree_uk_random_engine_get_0_255(iree_uk_random_engine_t* e) { + int v = iree_uk_random_engine_get_0_65535(e); + return v & 0xff; +} + int iree_uk_random_engine_get_0_1(iree_uk_random_engine_t* e) { int v = iree_uk_random_engine_get_0_65535(e); return v & 1; @@ -97,6 +109,16 @@ void iree_uk_write_random_buffer(void* buffer, iree_uk_index_t size_in_bytes, iree_uk_write_random_buffer(buffer, size_in_bytes, resolved_type, engine); return; } + // Special-case sub-byte-size integer types. Due to their narrow range, we + // want to generate values over their entire range, and then it's down to + // just generating random bytes. + if (iree_uk_type_is_integer(type) && iree_uk_type_bit_count(type) < 8) { + for (iree_uk_index_t i = 0; i < size_in_bytes; ++i) { + ((uint8_t*)buffer)[i] = iree_uk_random_engine_get_0_255(engine); + } + return; + } + // All other element types. iree_uk_index_t elem_size = iree_uk_type_size(type); iree_uk_index_t size_in_elems = size_in_bytes / elem_size; for (iree_uk_index_t i = 0; i < size_in_elems; ++i) {