diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp index b1aaf6636fb5..d2da94597e32 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp @@ -120,7 +120,7 @@ matchDAGForUKernel(RewriterBase &rewriter, linalg::Mmt4DOp op, uint32_t flags = 0; if (lhsElemType.isSignlessInteger(8) && rhsElemType.isSignlessInteger(8) && outElemType.isSignlessInteger(32)) { - flags = IREE_UK_FLAG_MMT4D_TYPE_I8I8I32; + flags = IREE_UK_FLAG_MMT4D_TYPE_S8S8S32; } else if (lhsElemType.isF32() && rhsElemType.isF32() && outElemType.isF32()) { flags = IREE_UK_FLAG_MMT4D_TYPE_F32F32F32; diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64.c index ed2d24dc648f..2c5222516e9a 100644 --- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64.c +++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64.c @@ -200,7 +200,7 @@ IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8( iree_uk_mmt4d_tile_f16f16f16_4x8x1_arm_64, iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64) -static inline void iree_uk_mmt4d_tile_i8i8i32_1x8x1_to_8x8x1_arm_64( +static inline void iree_uk_mmt4d_tile_s8s8s32_1x8x1_to_8x8x1_arm_64( void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel, const void* IREE_UK_RESTRICT rhs_panel, const iree_uk_mmt4d_params_t* params, int M0) { @@ -262,8 +262,8 @@ static inline void iree_uk_mmt4d_tile_i8i8i32_1x8x1_to_8x8x1_arm_64( } IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8( - iree_uk_mmt4d_tile_i8i8i32_1x8x1_to_8x8x1_arm_64, - iree_uk_mmt4d_tile_i8i8i32_1x8x1_arm_64, - iree_uk_mmt4d_tile_i8i8i32_2x8x1_arm_64, - iree_uk_mmt4d_tile_i8i8i32_4x8x1_arm_64, - iree_uk_mmt4d_tile_i8i8i32_8x8x1_arm_64) + iree_uk_mmt4d_tile_s8s8s32_1x8x1_to_8x8x1_arm_64, + iree_uk_mmt4d_tile_s8s8s32_1x8x1_arm_64, + iree_uk_mmt4d_tile_s8s8s32_2x8x1_arm_64, + iree_uk_mmt4d_tile_s8s8s32_4x8x1_arm_64, + iree_uk_mmt4d_tile_s8s8s32_8x8x1_arm_64) diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_dotprod.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_dotprod.c index 7a708bee7a94..dce52005f502 100644 --- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_dotprod.c +++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_dotprod.c @@ -7,7 +7,7 @@ #include "iree/builtins/ukernel/arch/arm_64/common_arm_64.h" #include "iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h" -static inline void iree_uk_mmt4d_tile_i8i8i32_1x8x4_to_8x8x4_arm_64_dotprod( +static inline void iree_uk_mmt4d_tile_s8s8s32_1x8x4_to_8x8x4_arm_64_dotprod( void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel, const void* IREE_UK_RESTRICT rhs_panel, const iree_uk_mmt4d_params_t* params, int M0) { @@ -71,8 +71,8 @@ static inline void iree_uk_mmt4d_tile_i8i8i32_1x8x4_to_8x8x4_arm_64_dotprod( } IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8( - iree_uk_mmt4d_tile_i8i8i32_1x8x4_to_8x8x4_arm_64_dotprod, - iree_uk_mmt4d_tile_i8i8i32_1x8x4_arm_64_dotprod, - iree_uk_mmt4d_tile_i8i8i32_2x8x4_arm_64_dotprod, - iree_uk_mmt4d_tile_i8i8i32_4x8x4_arm_64_dotprod, - iree_uk_mmt4d_tile_i8i8i32_8x8x4_arm_64_dotprod) + iree_uk_mmt4d_tile_s8s8s32_1x8x4_to_8x8x4_arm_64_dotprod, + iree_uk_mmt4d_tile_s8s8s32_1x8x4_arm_64_dotprod, + iree_uk_mmt4d_tile_s8s8s32_2x8x4_arm_64_dotprod, + iree_uk_mmt4d_tile_s8s8s32_4x8x4_arm_64_dotprod, + iree_uk_mmt4d_tile_s8s8s32_8x8x4_arm_64_dotprod) diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c index 0be1c66c6954..6ec64ca94810 100644 --- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c +++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_entry_point.c @@ -75,32 +75,32 @@ static iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arm_64_f16f16f16_M0x8x1( const iree_uk_mmt4d_params_t* params) { #ifdef IREE_UK_BUILD_ARM_64_FP16 - if (iree_uk_cpu_supports_fp16(params->cpu_data)) { - switch (params->M0) { - case 1: - return iree_uk_mmt4d_tile_f16f16f16_1x8x1_arm_64_fp16; - case 2: - return iree_uk_mmt4d_tile_f16f16f16_2x8x1_arm_64_fp16; - case 4: - return iree_uk_mmt4d_tile_f16f16f16_4x8x1_arm_64_fp16; - case 8: - return iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64_fp16; - } + if (iree_uk_cpu_supports_fp16(params->cpu_data)) { + switch (params->M0) { + case 1: + return iree_uk_mmt4d_tile_f16f16f16_1x8x1_arm_64_fp16; + case 2: + return iree_uk_mmt4d_tile_f16f16f16_2x8x1_arm_64_fp16; + case 4: + return iree_uk_mmt4d_tile_f16f16f16_4x8x1_arm_64_fp16; + case 8: + return iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64_fp16; } + } #endif - if (params->flags & IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS) { - switch (params->M0) { - case 1: - return iree_uk_mmt4d_tile_f16f16f16_1x8x1_arm_64; - case 2: - return iree_uk_mmt4d_tile_f16f16f16_2x8x1_arm_64; - case 4: - return iree_uk_mmt4d_tile_f16f16f16_4x8x1_arm_64; - case 8: - return iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64; - } + if (params->flags & IREE_UK_FLAG_MMT4D_SKIP_INTERMEDIATE_ROUNDINGS) { + switch (params->M0) { + case 1: + return iree_uk_mmt4d_tile_f16f16f16_1x8x1_arm_64; + case 2: + return iree_uk_mmt4d_tile_f16f16f16_2x8x1_arm_64; + case 4: + return iree_uk_mmt4d_tile_f16f16f16_4x8x1_arm_64; + case 8: + return iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64; } - return 0; + } + return 0; } static iree_uk_mmt4d_tile_func_t @@ -116,20 +116,20 @@ static iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arm_64_bf16bf16f32_M0x8x4( const iree_uk_mmt4d_params_t* params) { #ifdef IREE_UK_BUILD_ARM_64_BF16 - if (iree_uk_cpu_supports_bf16(params->cpu_data)) { - switch (params->M0) { - case 1: - return iree_uk_mmt4d_tile_bf16bf16f32_1x8x4_arm_64_bf16; - case 2: - return iree_uk_mmt4d_tile_bf16bf16f32_2x8x4_arm_64_bf16; - case 4: - return iree_uk_mmt4d_tile_bf16bf16f32_4x8x4_arm_64_bf16; - case 8: - return iree_uk_mmt4d_tile_bf16bf16f32_8x8x4_arm_64_bf16; - } + if (iree_uk_cpu_supports_bf16(params->cpu_data)) { + switch (params->M0) { + case 1: + return iree_uk_mmt4d_tile_bf16bf16f32_1x8x4_arm_64_bf16; + case 2: + return iree_uk_mmt4d_tile_bf16bf16f32_2x8x4_arm_64_bf16; + case 4: + return iree_uk_mmt4d_tile_bf16bf16f32_4x8x4_arm_64_bf16; + case 8: + return iree_uk_mmt4d_tile_bf16bf16f32_8x8x4_arm_64_bf16; } + } #endif - return 0; + return 0; } static iree_uk_mmt4d_tile_func_t @@ -152,13 +152,13 @@ iree_uk_mmt4d_select_tile_func_arm_64_i8i8i32_M0x8x1( const iree_uk_mmt4d_params_t* params) { switch (params->M0) { case 1: - return iree_uk_mmt4d_tile_i8i8i32_1x8x1_arm_64; + return iree_uk_mmt4d_tile_s8s8s32_1x8x1_arm_64; case 2: - return iree_uk_mmt4d_tile_i8i8i32_2x8x1_arm_64; + return iree_uk_mmt4d_tile_s8s8s32_2x8x1_arm_64; case 4: - return iree_uk_mmt4d_tile_i8i8i32_4x8x1_arm_64; + return iree_uk_mmt4d_tile_s8s8s32_4x8x1_arm_64; case 8: - return iree_uk_mmt4d_tile_i8i8i32_8x8x1_arm_64; + return iree_uk_mmt4d_tile_s8s8s32_8x8x1_arm_64; } return 0; } @@ -170,13 +170,13 @@ iree_uk_mmt4d_select_tile_func_arm_64_i8i8i32_M0x8x4( if (iree_uk_cpu_supports_dotprod(params->cpu_data)) { switch (params->M0) { case 1: - return iree_uk_mmt4d_tile_i8i8i32_1x8x4_arm_64_dotprod; + return iree_uk_mmt4d_tile_s8s8s32_1x8x4_arm_64_dotprod; case 2: - return iree_uk_mmt4d_tile_i8i8i32_2x8x4_arm_64_dotprod; + return iree_uk_mmt4d_tile_s8s8s32_2x8x4_arm_64_dotprod; case 4: - return iree_uk_mmt4d_tile_i8i8i32_4x8x4_arm_64_dotprod; + return iree_uk_mmt4d_tile_s8s8s32_4x8x4_arm_64_dotprod; case 8: - return iree_uk_mmt4d_tile_i8i8i32_8x8x4_arm_64_dotprod; + return iree_uk_mmt4d_tile_s8s8s32_8x8x4_arm_64_dotprod; } } #endif @@ -190,13 +190,13 @@ iree_uk_mmt4d_select_tile_func_arm_64_i8i8i32_M0x8x8( if (iree_uk_cpu_supports_i8mm(params->cpu_data)) { switch (params->M0) { case 1: - return iree_uk_mmt4d_tile_i8i8i32_1x8x8_arm_64_i8mm; + return iree_uk_mmt4d_tile_s8s8s32_1x8x8_arm_64_i8mm; case 2: - return iree_uk_mmt4d_tile_i8i8i32_2x8x8_arm_64_i8mm; + return iree_uk_mmt4d_tile_s8s8s32_2x8x8_arm_64_i8mm; case 4: - return iree_uk_mmt4d_tile_i8i8i32_4x8x8_arm_64_i8mm; + return iree_uk_mmt4d_tile_s8s8s32_4x8x8_arm_64_i8mm; case 8: - return iree_uk_mmt4d_tile_i8i8i32_8x8x8_arm_64_i8mm; + return iree_uk_mmt4d_tile_s8s8s32_8x8x8_arm_64_i8mm; } } #endif @@ -230,7 +230,7 @@ iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arch( return iree_uk_mmt4d_select_tile_func_arm_64_bf16bf16f32(params); case iree_uk_mmt4d_type_bf16bf16bf16: return iree_uk_mmt4d_select_tile_func_arm_64_bf16bf16bf16(params); - case iree_uk_mmt4d_type_i8i8i32: + case iree_uk_mmt4d_type_s8s8s32: return iree_uk_mmt4d_select_tile_func_arm_64_i8i8i32(params); default: IREE_UK_ASSUME_UNREACHABLE; diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c index e079e922e50c..97c6a5253980 100644 --- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c +++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c @@ -27,7 +27,7 @@ static inline int32x4_t iree_uk_neon_uzp2_s32_as_s64(int32x4_t a, int32x4_t b) { vuzp2q_s64(vreinterpretq_s64_s32(a), vreinterpretq_s64_s32(b))); } -void iree_uk_mmt4d_tile_i8i8i32_1x8x8_to_8x8x8_arm_64_i8mm( +void iree_uk_mmt4d_tile_s8s8s32_1x8x8_to_8x8x8_arm_64_i8mm( void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel, const void* IREE_UK_RESTRICT rhs_panel, const iree_uk_mmt4d_params_t* params, int M0) { @@ -99,8 +99,8 @@ void iree_uk_mmt4d_tile_i8i8i32_1x8x8_to_8x8x8_arm_64_i8mm( } IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8( - iree_uk_mmt4d_tile_i8i8i32_1x8x8_to_8x8x8_arm_64_i8mm, - iree_uk_mmt4d_tile_i8i8i32_1x8x8_arm_64_i8mm, - iree_uk_mmt4d_tile_i8i8i32_2x8x8_arm_64_i8mm, - iree_uk_mmt4d_tile_i8i8i32_4x8x8_arm_64_i8mm, - iree_uk_mmt4d_tile_i8i8i32_8x8x8_arm_64_i8mm) + iree_uk_mmt4d_tile_s8s8s32_1x8x8_to_8x8x8_arm_64_i8mm, + iree_uk_mmt4d_tile_s8s8s32_1x8x8_arm_64_i8mm, + iree_uk_mmt4d_tile_s8s8s32_2x8x8_arm_64_i8mm, + iree_uk_mmt4d_tile_s8s8s32_4x8x8_arm_64_i8mm, + iree_uk_mmt4d_tile_s8s8s32_8x8x8_arm_64_i8mm) diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h index 9bd33f01b164..3aa278ab2287 100644 --- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h +++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_internal.h @@ -33,17 +33,17 @@ IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_bf16bf16f32_1x8x4_arm_64_bf16) IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_bf16bf16f32_2x8x4_arm_64_bf16) IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_bf16bf16f32_4x8x4_arm_64_bf16) IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_bf16bf16f32_8x8x4_arm_64_bf16) -IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_1x8x1_arm_64) -IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_2x8x1_arm_64) -IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_4x8x1_arm_64) -IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_8x8x1_arm_64) -IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_1x8x4_arm_64_dotprod) -IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_2x8x4_arm_64_dotprod) -IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_4x8x4_arm_64_dotprod) -IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_8x8x4_arm_64_dotprod) -IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_1x8x8_arm_64_i8mm) -IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_2x8x8_arm_64_i8mm) -IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_4x8x8_arm_64_i8mm) -IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_8x8x8_arm_64_i8mm) +IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s8s32_1x8x1_arm_64) +IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s8s32_2x8x1_arm_64) +IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s8s32_4x8x1_arm_64) +IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s8s32_8x8x1_arm_64) +IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s8s32_1x8x4_arm_64_dotprod) +IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s8s32_2x8x4_arm_64_dotprod) +IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s8s32_4x8x4_arm_64_dotprod) +IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s8s32_8x8x4_arm_64_dotprod) +IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s8s32_1x8x8_arm_64_i8mm) +IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s8s32_2x8x8_arm_64_i8mm) +IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s8s32_4x8x8_arm_64_i8mm) +IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s8s32_8x8x8_arm_64_i8mm) #endif // IREE_BUILTINS_UKERNEL_ARCH_ARM_64_MMT4D_ARM_64_INTERNAL_H_ diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx2_fma.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx2_fma.c index 1bdf93a07aa4..bdf2edbbc59b 100644 --- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx2_fma.c +++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx2_fma.c @@ -127,7 +127,7 @@ IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8( iree_uk_mmt4d_tile_f16f16f16_4x8x1_x86_64_avx2_fma, iree_uk_mmt4d_tile_f16f16f16_8x8x1_x86_64_avx2_fma) -static inline void iree_uk_mmt4d_tile_i8i8i32_1x8x2_to_8x8x2_x86_64_avx2_fma( +static inline void iree_uk_mmt4d_tile_s8s8s32_1x8x2_to_8x8x2_x86_64_avx2_fma( void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel, const void* IREE_UK_RESTRICT rhs_panel, const iree_uk_mmt4d_params_t* params, int M0) { @@ -216,8 +216,8 @@ static inline void iree_uk_mmt4d_tile_i8i8i32_1x8x2_to_8x8x2_x86_64_avx2_fma( } IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8( - iree_uk_mmt4d_tile_i8i8i32_1x8x2_to_8x8x2_x86_64_avx2_fma, - iree_uk_mmt4d_tile_i8i8i32_1x8x2_x86_64_avx2_fma, - iree_uk_mmt4d_tile_i8i8i32_2x8x2_x86_64_avx2_fma, - iree_uk_mmt4d_tile_i8i8i32_4x8x2_x86_64_avx2_fma, - iree_uk_mmt4d_tile_i8i8i32_8x8x2_x86_64_avx2_fma) + iree_uk_mmt4d_tile_s8s8s32_1x8x2_to_8x8x2_x86_64_avx2_fma, + iree_uk_mmt4d_tile_s8s8s32_1x8x2_x86_64_avx2_fma, + iree_uk_mmt4d_tile_s8s8s32_2x8x2_x86_64_avx2_fma, + iree_uk_mmt4d_tile_s8s8s32_4x8x2_x86_64_avx2_fma, + iree_uk_mmt4d_tile_s8s8s32_8x8x2_x86_64_avx2_fma) diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_base.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_base.c index 996ecfbce9ce..d133ae19adc8 100644 --- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_base.c +++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_base.c @@ -179,7 +179,7 @@ IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8_16( iree_uk_mmt4d_tile_f16f16f16_16x16x1_x86_64_avx512_base) static inline void -iree_uk_mmt4d_tile_i8i8i32_1x16x2_to_16x16x2_x86_64_avx512_base( +iree_uk_mmt4d_tile_s8s8s32_1x16x2_to_16x16x2_x86_64_avx512_base( void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel, const void* IREE_UK_RESTRICT rhs_panel, const iree_uk_mmt4d_params_t* params, int M0) { @@ -300,9 +300,9 @@ iree_uk_mmt4d_tile_i8i8i32_1x16x2_to_16x16x2_x86_64_avx512_base( } IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8_16( - iree_uk_mmt4d_tile_i8i8i32_1x16x2_to_16x16x2_x86_64_avx512_base, - iree_uk_mmt4d_tile_i8i8i32_1x16x2_x86_64_avx512_base, - iree_uk_mmt4d_tile_i8i8i32_2x16x2_x86_64_avx512_base, - iree_uk_mmt4d_tile_i8i8i32_4x16x2_x86_64_avx512_base, - iree_uk_mmt4d_tile_i8i8i32_8x16x2_x86_64_avx512_base, - iree_uk_mmt4d_tile_i8i8i32_16x16x2_x86_64_avx512_base) \ No newline at end of file + iree_uk_mmt4d_tile_s8s8s32_1x16x2_to_16x16x2_x86_64_avx512_base, + iree_uk_mmt4d_tile_s8s8s32_1x16x2_x86_64_avx512_base, + iree_uk_mmt4d_tile_s8s8s32_2x16x2_x86_64_avx512_base, + iree_uk_mmt4d_tile_s8s8s32_4x16x2_x86_64_avx512_base, + iree_uk_mmt4d_tile_s8s8s32_8x16x2_x86_64_avx512_base, + iree_uk_mmt4d_tile_s8s8s32_16x16x2_x86_64_avx512_base) \ No newline at end of file diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c index 386d22564068..44d8fa9affb5 100644 --- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c +++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c @@ -8,7 +8,7 @@ #include "iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_internal.h" static inline void -iree_uk_mmt4d_tile_i8i8i32_1x16x2_to_16x16x2_x86_64_avx512_vnni( +iree_uk_mmt4d_tile_s8s8s32_1x16x2_to_16x16x2_x86_64_avx512_vnni( void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel, const void* IREE_UK_RESTRICT rhs_panel, const iree_uk_mmt4d_params_t* params, int M0) { @@ -129,9 +129,9 @@ iree_uk_mmt4d_tile_i8i8i32_1x16x2_to_16x16x2_x86_64_avx512_vnni( } IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0_1_2_4_8_16( - iree_uk_mmt4d_tile_i8i8i32_1x16x2_to_16x16x2_x86_64_avx512_vnni, - iree_uk_mmt4d_tile_i8i8i32_1x16x2_x86_64_avx512_vnni, - iree_uk_mmt4d_tile_i8i8i32_2x16x2_x86_64_avx512_vnni, - iree_uk_mmt4d_tile_i8i8i32_4x16x2_x86_64_avx512_vnni, - iree_uk_mmt4d_tile_i8i8i32_8x16x2_x86_64_avx512_vnni, - iree_uk_mmt4d_tile_i8i8i32_16x16x2_x86_64_avx512_vnni) + iree_uk_mmt4d_tile_s8s8s32_1x16x2_to_16x16x2_x86_64_avx512_vnni, + iree_uk_mmt4d_tile_s8s8s32_1x16x2_x86_64_avx512_vnni, + iree_uk_mmt4d_tile_s8s8s32_2x16x2_x86_64_avx512_vnni, + iree_uk_mmt4d_tile_s8s8s32_4x16x2_x86_64_avx512_vnni, + iree_uk_mmt4d_tile_s8s8s32_8x16x2_x86_64_avx512_vnni, + iree_uk_mmt4d_tile_s8s8s32_16x16x2_x86_64_avx512_vnni) diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c index cbc6b780066b..f31923e9611e 100644 --- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c +++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c @@ -215,15 +215,15 @@ iree_uk_mmt4d_select_tile_func_x86_64_i8i8i32_M0x16x2( if (params->cpu_data[0] & (IREE_CPU_DATA0_X86_64_AVX512VNNI)) { switch (params->M0) { case 1: - return iree_uk_mmt4d_tile_i8i8i32_1x16x2_x86_64_avx512_vnni; + return iree_uk_mmt4d_tile_s8s8s32_1x16x2_x86_64_avx512_vnni; case 2: - return iree_uk_mmt4d_tile_i8i8i32_2x16x2_x86_64_avx512_vnni; + return iree_uk_mmt4d_tile_s8s8s32_2x16x2_x86_64_avx512_vnni; case 4: - return iree_uk_mmt4d_tile_i8i8i32_4x16x2_x86_64_avx512_vnni; + return iree_uk_mmt4d_tile_s8s8s32_4x16x2_x86_64_avx512_vnni; case 8: - return iree_uk_mmt4d_tile_i8i8i32_8x16x2_x86_64_avx512_vnni; + return iree_uk_mmt4d_tile_s8s8s32_8x16x2_x86_64_avx512_vnni; case 16: - return iree_uk_mmt4d_tile_i8i8i32_16x16x2_x86_64_avx512_vnni; + return iree_uk_mmt4d_tile_s8s8s32_16x16x2_x86_64_avx512_vnni; } } #endif @@ -231,15 +231,15 @@ iree_uk_mmt4d_select_tile_func_x86_64_i8i8i32_M0x16x2( if (params->cpu_data[0] & (IREE_CPU_DATA0_X86_64_AVX512BW)) { switch (params->M0) { case 1: - return iree_uk_mmt4d_tile_i8i8i32_1x16x2_x86_64_avx512_base; + return iree_uk_mmt4d_tile_s8s8s32_1x16x2_x86_64_avx512_base; case 2: - return iree_uk_mmt4d_tile_i8i8i32_2x16x2_x86_64_avx512_base; + return iree_uk_mmt4d_tile_s8s8s32_2x16x2_x86_64_avx512_base; case 4: - return iree_uk_mmt4d_tile_i8i8i32_4x16x2_x86_64_avx512_base; + return iree_uk_mmt4d_tile_s8s8s32_4x16x2_x86_64_avx512_base; case 8: - return iree_uk_mmt4d_tile_i8i8i32_8x16x2_x86_64_avx512_base; + return iree_uk_mmt4d_tile_s8s8s32_8x16x2_x86_64_avx512_base; case 16: - return iree_uk_mmt4d_tile_i8i8i32_16x16x2_x86_64_avx512_base; + return iree_uk_mmt4d_tile_s8s8s32_16x16x2_x86_64_avx512_base; } } #endif @@ -253,13 +253,13 @@ iree_uk_mmt4d_select_tile_func_x86_64_i8i8i32_M0x8x2( if (iree_uk_cpu_supports_avx2_fma(params->cpu_data)) { switch (params->M0) { case 1: - return iree_uk_mmt4d_tile_i8i8i32_1x8x2_x86_64_avx2_fma; + return iree_uk_mmt4d_tile_s8s8s32_1x8x2_x86_64_avx2_fma; case 2: - return iree_uk_mmt4d_tile_i8i8i32_2x8x2_x86_64_avx2_fma; + return iree_uk_mmt4d_tile_s8s8s32_2x8x2_x86_64_avx2_fma; case 4: - return iree_uk_mmt4d_tile_i8i8i32_4x8x2_x86_64_avx2_fma; + return iree_uk_mmt4d_tile_s8s8s32_4x8x2_x86_64_avx2_fma; case 8: - return iree_uk_mmt4d_tile_i8i8i32_8x8x2_x86_64_avx2_fma; + return iree_uk_mmt4d_tile_s8s8s32_8x8x2_x86_64_avx2_fma; } } #endif @@ -290,7 +290,7 @@ iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_arch( return iree_uk_mmt4d_select_tile_func_x86_64_bf16bf16f32(params); case iree_uk_mmt4d_type_bf16bf16bf16: return iree_uk_mmt4d_select_tile_func_x86_64_bf16bf16bf16(params); - case iree_uk_mmt4d_type_i8i8i32: + case iree_uk_mmt4d_type_s8s8s32: return iree_uk_mmt4d_select_tile_func_x86_64_i8i8i32(params); default: IREE_UK_ASSUME_UNREACHABLE; diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_internal.h b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_internal.h index b98de8c58543..7be4f7578258 100644 --- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_internal.h +++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_internal.h @@ -9,10 +9,10 @@ #include "iree/builtins/ukernel/mmt4d_internal.h" -IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_1x8x2_x86_64_avx2_fma) -IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_2x8x2_x86_64_avx2_fma) -IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_4x8x2_x86_64_avx2_fma) -IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_i8i8i32_8x8x2_x86_64_avx2_fma) +IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s8s32_1x8x2_x86_64_avx2_fma) +IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s8s32_2x8x2_x86_64_avx2_fma) +IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s8s32_4x8x2_x86_64_avx2_fma) +IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s8s32_8x8x2_x86_64_avx2_fma) IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_f32f32f32_1x8x1_x86_64_avx2_fma) IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_f32f32f32_2x8x1_x86_64_avx2_fma) IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_f32f32f32_4x8x1_x86_64_avx2_fma) @@ -66,24 +66,24 @@ IREE_UK_MMT4D_TILE_FUNC_DECL( IREE_UK_MMT4D_TILE_FUNC_DECL( iree_uk_mmt4d_tile_f16f16f16_16x16x1_x86_64_avx512_base) IREE_UK_MMT4D_TILE_FUNC_DECL( - iree_uk_mmt4d_tile_i8i8i32_1x16x2_x86_64_avx512_base) + iree_uk_mmt4d_tile_s8s8s32_1x16x2_x86_64_avx512_base) IREE_UK_MMT4D_TILE_FUNC_DECL( - iree_uk_mmt4d_tile_i8i8i32_2x16x2_x86_64_avx512_base) + iree_uk_mmt4d_tile_s8s8s32_2x16x2_x86_64_avx512_base) IREE_UK_MMT4D_TILE_FUNC_DECL( - iree_uk_mmt4d_tile_i8i8i32_4x16x2_x86_64_avx512_base) + iree_uk_mmt4d_tile_s8s8s32_4x16x2_x86_64_avx512_base) IREE_UK_MMT4D_TILE_FUNC_DECL( - iree_uk_mmt4d_tile_i8i8i32_8x16x2_x86_64_avx512_base) + iree_uk_mmt4d_tile_s8s8s32_8x16x2_x86_64_avx512_base) IREE_UK_MMT4D_TILE_FUNC_DECL( - iree_uk_mmt4d_tile_i8i8i32_16x16x2_x86_64_avx512_base) + iree_uk_mmt4d_tile_s8s8s32_16x16x2_x86_64_avx512_base) IREE_UK_MMT4D_TILE_FUNC_DECL( - iree_uk_mmt4d_tile_i8i8i32_1x16x2_x86_64_avx512_vnni) + iree_uk_mmt4d_tile_s8s8s32_1x16x2_x86_64_avx512_vnni) IREE_UK_MMT4D_TILE_FUNC_DECL( - iree_uk_mmt4d_tile_i8i8i32_2x16x2_x86_64_avx512_vnni) + iree_uk_mmt4d_tile_s8s8s32_2x16x2_x86_64_avx512_vnni) IREE_UK_MMT4D_TILE_FUNC_DECL( - iree_uk_mmt4d_tile_i8i8i32_4x16x2_x86_64_avx512_vnni) + iree_uk_mmt4d_tile_s8s8s32_4x16x2_x86_64_avx512_vnni) IREE_UK_MMT4D_TILE_FUNC_DECL( - iree_uk_mmt4d_tile_i8i8i32_8x16x2_x86_64_avx512_vnni) + iree_uk_mmt4d_tile_s8s8s32_8x16x2_x86_64_avx512_vnni) IREE_UK_MMT4D_TILE_FUNC_DECL( - iree_uk_mmt4d_tile_i8i8i32_16x16x2_x86_64_avx512_vnni) + iree_uk_mmt4d_tile_s8s8s32_16x16x2_x86_64_avx512_vnni) #endif // foIREE_BUILTINS_UKERNEL_ARCH_X86_64_MMT4D_X86_64_INTERNAL_H_ diff --git a/runtime/src/iree/builtins/ukernel/common.h b/runtime/src/iree/builtins/ukernel/common.h index bd4b45128ace..b196d856776a 100644 --- a/runtime/src/iree/builtins/ukernel/common.h +++ b/runtime/src/iree/builtins/ukernel/common.h @@ -433,7 +433,7 @@ typedef iree_uk_uint8_t iree_uk_type_t; // Signless integers. Use in microkernels that perform same-bit-width integer // arithmetic that is insensitive to signedness. For example, same-bit-width // element-wise integer add and mul ops. -#define IREE_UK_TYPE_CATEGORY_INTEGER 0x20u +#define IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS 0x20u // Signed integers. Use in microkernels that are specifically performing signed // integer arithmetic. For example, any mixed-bit-width op that involves a // sign-extension (as in arith.extsi). @@ -461,10 +461,10 @@ enum { IREE_UK_TYPE_OPAQUE_16 = IREE_UK_TYPE_CATEGORY_OPAQUE | 4, IREE_UK_TYPE_OPAQUE_32 = IREE_UK_TYPE_CATEGORY_OPAQUE | 5, IREE_UK_TYPE_OPAQUE_64 = IREE_UK_TYPE_CATEGORY_OPAQUE | 6, - IREE_UK_TYPE_INT_8 = IREE_UK_TYPE_CATEGORY_INTEGER | 3, - IREE_UK_TYPE_INT_16 = IREE_UK_TYPE_CATEGORY_INTEGER | 4, - IREE_UK_TYPE_INT_32 = IREE_UK_TYPE_CATEGORY_INTEGER | 5, - IREE_UK_TYPE_INT_64 = IREE_UK_TYPE_CATEGORY_INTEGER | 6, + IREE_UK_TYPE_INT_8 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS | 3, + IREE_UK_TYPE_INT_16 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS | 4, + IREE_UK_TYPE_INT_32 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS | 5, + IREE_UK_TYPE_INT_64 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS | 6, IREE_UK_TYPE_SINT_8 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED | 3, IREE_UK_TYPE_SINT_16 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED | 4, IREE_UK_TYPE_SINT_32 = IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED | 5, @@ -490,6 +490,43 @@ static inline int iree_uk_type_bit_count_log2(iree_uk_type_t t) { return t & IREE_UK_TYPE_BIT_COUNT_LOG2_MASK; } +// Mutate type category while keeping the same bit-width +static inline iree_uk_uint8_t iree_uk_type_mutate_category( + iree_uk_type_t t, iree_uk_uint8_t new_category) { + return new_category | iree_uk_type_bit_count_log2(t); +} + +// Integer type helpers +static inline iree_uk_uint8_t iree_uk_type_is_integer(iree_uk_type_t t) { + switch (iree_uk_type_category(t)) { + case IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS: + case IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED: + case IREE_UK_TYPE_CATEGORY_INTEGER_UNSIGNED: + return true; + default: + return false; + } +} + +static inline iree_uk_uint8_t iree_uk_integer_type_as_signless( + iree_uk_type_t t) { + IREE_UK_ASSERT(iree_uk_type_is_integer(t)); + return iree_uk_type_mutate_category(t, + IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS); +} + +static inline iree_uk_uint8_t iree_uk_integer_type_as_signed(iree_uk_type_t t) { + IREE_UK_ASSERT(iree_uk_type_is_integer(t)); + return iree_uk_type_mutate_category(t, IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED); +} + +static inline iree_uk_uint8_t iree_uk_integer_type_as_unsigned( + iree_uk_type_t t) { + IREE_UK_ASSERT(iree_uk_type_is_integer(t)); + return iree_uk_type_mutate_category(t, + IREE_UK_TYPE_CATEGORY_INTEGER_UNSIGNED); +} + // Behavior is undefined if the bit-count is not a multiple of 8! // The current implementation might return a negative value, but don't rely on // that. diff --git a/runtime/src/iree/builtins/ukernel/exported_bits.h b/runtime/src/iree/builtins/ukernel/exported_bits.h index 68263790f43c..e2dfd6af1d33 100644 --- a/runtime/src/iree/builtins/ukernel/exported_bits.h +++ b/runtime/src/iree/builtins/ukernel/exported_bits.h @@ -44,7 +44,7 @@ #define IREE_UK_FLAG_MMT4D_TYPE_MASK 0xFF #define IREE_UK_FLAG_MMT4D_TYPE_NONE 0x00 #define IREE_UK_FLAG_MMT4D_TYPE_F32F32F32 0x01 -#define IREE_UK_FLAG_MMT4D_TYPE_I8I8I32 0x02 +#define IREE_UK_FLAG_MMT4D_TYPE_S8S8S32 0x02 #define IREE_UK_FLAG_MMT4D_TYPE_F16F16F32 0x03 #define IREE_UK_FLAG_MMT4D_TYPE_F16F16F16 0x04 #define IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32 0x05 diff --git a/runtime/src/iree/builtins/ukernel/mmt4d.c b/runtime/src/iree/builtins/ukernel/mmt4d.c index 7cc61391eebe..a626a8c7da08 100644 --- a/runtime/src/iree/builtins/ukernel/mmt4d.c +++ b/runtime/src/iree/builtins/ukernel/mmt4d.c @@ -16,7 +16,7 @@ static void iree_uk_mmt4d_validate(const iree_uk_mmt4d_params_t* params) { IREE_UK_ASSERT(!(params->flags & ~allflags)); iree_uk_uint32_t flags_type = params->flags & IREE_UK_FLAG_MMT4D_TYPE_MASK; IREE_UK_ASSERT(flags_type == IREE_UK_FLAG_MMT4D_TYPE_F32F32F32 || - flags_type == IREE_UK_FLAG_MMT4D_TYPE_I8I8I32 || + flags_type == IREE_UK_FLAG_MMT4D_TYPE_S8S8S32 || flags_type == IREE_UK_FLAG_MMT4D_TYPE_F16F16F32 || flags_type == IREE_UK_FLAG_MMT4D_TYPE_F16F16F16 || flags_type == IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32 || diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_internal.h b/runtime/src/iree/builtins/ukernel/mmt4d_internal.h index add9b7b876fc..ebdd673194fc 100644 --- a/runtime/src/iree/builtins/ukernel/mmt4d_internal.h +++ b/runtime/src/iree/builtins/ukernel/mmt4d_internal.h @@ -12,8 +12,8 @@ typedef enum iree_uk_mmt4d_type_t { iree_uk_mmt4d_type_f32f32f32 = IREE_UK_TIE_3_TYPES_LITERAL(FLOAT_32, FLOAT_32, FLOAT_32), - iree_uk_mmt4d_type_i8i8i32 = - IREE_UK_TIE_3_TYPES_LITERAL(INT_8, INT_8, INT_32), + iree_uk_mmt4d_type_s8s8s32 = + IREE_UK_TIE_3_TYPES_LITERAL(SINT_8, SINT_8, SINT_32), iree_uk_mmt4d_type_f16f16f32 = IREE_UK_TIE_3_TYPES_LITERAL(FLOAT_16, FLOAT_16, FLOAT_32), iree_uk_mmt4d_type_f16f16f16 = @@ -28,8 +28,8 @@ static inline iree_uk_mmt4d_type_t iree_uk_mmt4d_type(iree_uk_uint32_t flags) { switch (flags & IREE_UK_FLAG_MMT4D_TYPE_MASK) { case IREE_UK_FLAG_MMT4D_TYPE_F32F32F32: return iree_uk_mmt4d_type_f32f32f32; - case IREE_UK_FLAG_MMT4D_TYPE_I8I8I32: - return iree_uk_mmt4d_type_i8i8i32; + case IREE_UK_FLAG_MMT4D_TYPE_S8S8S32: + return iree_uk_mmt4d_type_s8s8s32; case IREE_UK_FLAG_MMT4D_TYPE_F16F16F32: return iree_uk_mmt4d_type_f16f16f32; case IREE_UK_FLAG_MMT4D_TYPE_F16F16F16: diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_tile.c b/runtime/src/iree/builtins/ukernel/mmt4d_tile.c index f121ebe4926b..622ae89863bd 100644 --- a/runtime/src/iree/builtins/ukernel/mmt4d_tile.c +++ b/runtime/src/iree/builtins/ukernel/mmt4d_tile.c @@ -7,7 +7,7 @@ #include "iree/builtins/ukernel/mmt4d_internal.h" // Generic implementation of matmul tile, i8*i8->i32 case. -static void iree_uk_mmt4d_tile_i8i8i32_generic( +static void iree_uk_mmt4d_tile_s8s8s32_generic( void* out_tile_untyped, const void* lhs_panel_untyped, const void* rhs_panel_untyped, const iree_uk_mmt4d_params_t* params) { iree_uk_int32_t* out_tile = out_tile_untyped; @@ -225,8 +225,8 @@ static iree_uk_mmt4d_tile_func_t iree_uk_mmt4d_select_tile_func_generic( switch (iree_uk_mmt4d_type(params->flags)) { case iree_uk_mmt4d_type_f32f32f32: return iree_uk_mmt4d_tile_f32f32f32_generic; - case iree_uk_mmt4d_type_i8i8i32: - return iree_uk_mmt4d_tile_i8i8i32_generic; + case iree_uk_mmt4d_type_s8s8s32: + return iree_uk_mmt4d_tile_s8s8s32_generic; case iree_uk_mmt4d_type_f16f16f32: return iree_uk_mmt4d_tile_f16f16f32_generic; case iree_uk_mmt4d_type_f16f16f16: diff --git a/runtime/src/iree/builtins/ukernel/tools/e2e_matmul_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/e2e_matmul_benchmark.c index b572d78d9b64..441322ebb7ac 100644 --- a/runtime/src/iree/builtins/ukernel/tools/e2e_matmul_benchmark.c +++ b/runtime/src/iree/builtins/ukernel/tools/e2e_matmul_benchmark.c @@ -43,7 +43,7 @@ typedef struct iree_uk_benchmark_e2e_matmul_params_t { static iree_uk_uint32_t iree_uk_qts_op_flag(iree_uk_mmt4d_type_t type) { if (type == iree_uk_mmt4d_type_f32f32f32) return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F32F32F32; - if (type == iree_uk_mmt4d_type_i8i8i32) + if (type == iree_uk_mmt4d_type_s8s8s32) return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_I8I8I32; iree_abort(); return 0; @@ -132,7 +132,7 @@ static void iree_uk_reference_rowmajor_matmul( iree_uk_reference_rowmajor_matmul_f32f32f32( params, (const float*)lhs, (const float*)rhs, (float*)out); break; - case IREE_UK_FLAG_MMT4D_TYPE_I8I8I32: + case IREE_UK_FLAG_MMT4D_TYPE_S8S8S32: iree_uk_reference_rowmajor_matmul_i8i8i32( params, (const iree_uk_int8_t*)lhs, (const iree_uk_int8_t*)rhs, (iree_uk_int32_t*)out); @@ -376,7 +376,7 @@ iree_uk_uint32_t iree_uk_mmt4d_parse_type_into_flag(const char* type) { return IREE_UK_FLAG_MMT4D_TYPE_F32F32F32; } if (!strcmp(type, "i8i8i32")) { - return IREE_UK_FLAG_MMT4D_TYPE_I8I8I32; + return IREE_UK_FLAG_MMT4D_TYPE_S8S8S32; } fprintf(stderr, "Unhandled type: %s\n", type); iree_abort(); diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c index 46556bb61435..ddcedb6f5855 100644 --- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c +++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c @@ -134,11 +134,11 @@ int main(int argc, char** argv) { "fp16"); iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32, 8, 8, 4, "bf16"); - iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 1, + iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 1, ""); - iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 4, + iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 4, "dotprod"); - iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 8, + iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 8, "i8mm"); #elif defined(IREE_ARCH_X86_64) iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 8, 1, @@ -155,18 +155,18 @@ int main(int argc, char** argv) { "avx512_base"); iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32, 16, 16, 2, "avx512_bf16"); - iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 2, + iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 2, "avx2_fma"); - iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 16, 16, 2, + iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 16, 16, 2, "avx512_base"); - iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 16, 16, 2, + iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 16, 16, 2, "avx512_vnni"); #else // defined(IREE_ARCH_ARM_64) // Architectures on which we do not have any optimized ukernel code. // Benchmark some arbitrary tile shape. iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 8, 1, ""); - iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 1, + iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 1, ""); #endif // defined(IREE_ARCH_ARM_64) diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c index eb5c28ee874f..84ae68f9e07b 100644 --- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c +++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c @@ -123,7 +123,7 @@ static void iree_mmt4d_reference_innerloop_bf16bf16bf16( *out_ptr = acc; } -static void iree_mmt4d_reference_innerloop_i8i8i32( +static void iree_mmt4d_reference_innerloop_s8s8s32( int32_t* out_ptr, const int8_t* lhs_ptr, const int8_t* rhs_ptr, const iree_uk_mmt4d_params_t* params) { int32_t acc = params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE ? *out_ptr : 0; @@ -191,8 +191,8 @@ static void iree_mmt4d_reference(const iree_uk_mmt4d_params_t* params) { (uint16_t*)out_ptr, (const uint16_t*)lhs_ptr, (const uint16_t*)rhs_ptr, params); break; - case IREE_UK_FLAG_MMT4D_TYPE_I8I8I32: - iree_mmt4d_reference_innerloop_i8i8i32( + case IREE_UK_FLAG_MMT4D_TYPE_S8S8S32: + iree_mmt4d_reference_innerloop_s8s8s32( (int32_t*)out_ptr, (const int8_t*)lhs_ptr, (const int8_t*)rhs_ptr, params); break; @@ -375,7 +375,7 @@ int main(int argc, char** argv) { // to test weird M0, N0, K0 to ensure e.g. that we haven't unwittingly baked // in a power-of-two assumption iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 3, 5, 7, ""); - iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 9, 6, 3, ""); + iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 9, 6, 3, ""); iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F32, 4, 6, 5, ""); iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F16, 3, 5, 8, ""); iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32, 11, 4, 1, ""); @@ -390,9 +390,9 @@ int main(int argc, char** argv) { iree_uk_test_mmt4d_default_and_skip_intermediate_roundings( IREE_UK_FLAG_MMT4D_TYPE_F16F16F16, 8, 8, 1, "fp16"); iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32, 8, 8, 4, "bf16"); - iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 1, ""); - iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 4, "dotprod"); - iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 8, "i8mm"); + iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 1, ""); + iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 4, "dotprod"); + iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 8, "i8mm"); #elif defined(IREE_ARCH_X86_64) iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 4, 1, ""); // SSE iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 8, 1, "avx2_fma"); @@ -407,10 +407,10 @@ int main(int argc, char** argv) { IREE_UK_FLAG_MMT4D_TYPE_F16F16F16, 16, 16, 1, "avx512_base"); iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32, 16, 16, 2, "avx512_bf16"); - iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 4, 2, ""); // SSE2 - iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 2, "avx2_fma"); - iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 16, 16, 2, "avx512_base"); - iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 16, 16, 2, "avx512_vnni"); + iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 4, 2, ""); // SSE2 + iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 2, "avx2_fma"); + iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 16, 16, 2, "avx512_base"); + iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 16, 16, 2, "avx512_vnni"); #endif // defined(IREE_ARCH_ARM_64) return iree_uk_test_exit_status(); diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c.orig b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c.orig index d674709a2b36..b3f9e2cb0400 100644 --- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c.orig +++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c.orig @@ -91,7 +91,7 @@ static void iree_mmt4d_reference_innerloop_bf16bf16bf16( *out_ptr = acc; } -static void iree_mmt4d_reference_innerloop_i8i8i32( +static void iree_mmt4d_reference_innerloop_s8s8s32( int32_t* out_ptr, const int8_t* lhs_ptr, const int8_t* rhs_ptr, const iree_uk_mmt4d_params_t* params) { int32_t acc = params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE ? *out_ptr : 0; @@ -159,8 +159,8 @@ static void iree_mmt4d_reference(const iree_uk_mmt4d_params_t* params) { (uint16_t*)out_ptr, (const uint16_t*)lhs_ptr, (const uint16_t*)rhs_ptr, params); break; - case IREE_UK_FLAG_MMT4D_TYPE_I8I8I32: - iree_mmt4d_reference_innerloop_i8i8i32( + case IREE_UK_FLAG_MMT4D_TYPE_S8S8S32: + iree_mmt4d_reference_innerloop_s8s8s32( (int32_t*)out_ptr, (const int8_t*)lhs_ptr, (const int8_t*)rhs_ptr, params); break; @@ -320,7 +320,7 @@ int main(int argc, char** argv) { // to test weird M0, N0, K0 to ensure e.g. that we haven't unwittingly baked // in a power-of-two assumption iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 3, 5, 7, ""); - iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 9, 6, 3, ""); + iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 9, 6, 3, ""); iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F32, 4, 6, 5, ""); iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F16, 3, 5, 8, ""); iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_BF16BF16F32, 11, 4, 1, ""); @@ -332,9 +332,9 @@ int main(int argc, char** argv) { iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 8, 1, ""); iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F32, 8, 8, 1, ""); iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F16, 8, 8, 1, ""); - iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 1, ""); - iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 4, "dotprod"); - iree_uk_test_mmt4d_default_and_intrinsics(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, + iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 1, ""); + iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 4, "dotprod"); + iree_uk_test_mmt4d_default_and_intrinsics(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 8, "i8mm"); #elif defined(IREE_ARCH_X86_64) iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 4, 1, ""); // SSE @@ -347,10 +347,10 @@ int main(int argc, char** argv) { iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F16, 8, 8, 1, "avx2_fma"); iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F16F16F16, 16, 16, 1, "avx512_base"); - iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 4, 2, ""); // SSE2 - iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 8, 8, 2, "avx2_fma"); - iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 16, 16, 2, "avx512_base"); - iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_I8I8I32, 16, 16, 2, "avx512_vnni"); + iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 4, 2, ""); // SSE2 + iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 2, "avx2_fma"); + iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 16, 16, 2, "avx512_base"); + iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 16, 16, 2, "avx512_vnni"); #endif // defined(IREE_ARCH_ARM_64) return iree_uk_test_exit_status(); diff --git a/runtime/src/iree/builtins/ukernel/tools/util.c b/runtime/src/iree/builtins/ukernel/tools/util.c index e5fd58a90eee..2f94964d7bb8 100644 --- a/runtime/src/iree/builtins/ukernel/tools/util.c +++ b/runtime/src/iree/builtins/ukernel/tools/util.c @@ -86,6 +86,17 @@ int iree_uk_random_engine_get_minus16_plus15(iree_uk_random_engine_t* e) { void iree_uk_write_random_buffer(void* buffer, iree_uk_index_t size_in_bytes, iree_uk_type_t type, iree_uk_random_engine_t* engine) { + if (iree_uk_type_category(type) == IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS) { + // Signless integers mean that the operation that will consume this buffer + // should not care if the data is signed or unsigned integers, so let's + // randomly exercise both and recurse so that the rest of this function + // doesn't have to deal with signless again. + iree_uk_type_t resolved_type = iree_uk_random_engine_get_0_1(engine) + ? iree_uk_integer_type_as_signed(type) + : iree_uk_integer_type_as_unsigned(type); + iree_uk_write_random_buffer(buffer, size_in_bytes, resolved_type, engine); + return; + } iree_uk_index_t elem_size = iree_uk_type_size(type); iree_uk_index_t size_in_elems = size_in_bytes / elem_size; for (iree_uk_index_t i = 0; i < size_in_elems; ++i) { @@ -103,12 +114,24 @@ void iree_uk_write_random_buffer(void* buffer, iree_uk_index_t size_in_bytes, case IREE_UK_TYPE_BFLOAT_16: ((uint16_t*)buffer)[i] = iree_math_f32_to_bf16((float)random_val); break; - case IREE_UK_TYPE_INT_32: + case IREE_UK_TYPE_SINT_32: ((int32_t*)buffer)[i] = random_val; break; - case IREE_UK_TYPE_INT_8: + case IREE_UK_TYPE_UINT_32: + ((uint32_t*)buffer)[i] = random_val; + break; + case IREE_UK_TYPE_SINT_16: + ((int16_t*)buffer)[i] = random_val; + break; + case IREE_UK_TYPE_UINT_16: + ((uint16_t*)buffer)[i] = random_val; + break; + case IREE_UK_TYPE_SINT_8: ((int8_t*)buffer)[i] = random_val; break; + case IREE_UK_TYPE_UINT_8: + ((uint8_t*)buffer)[i] = random_val; + break; default: IREE_UK_ASSERT(false && "unknown type"); } @@ -119,12 +142,12 @@ static const char* iree_uk_type_category_str(const iree_uk_type_t type) { switch (type & IREE_UK_TYPE_CATEGORY_MASK) { case IREE_UK_TYPE_CATEGORY_OPAQUE: return "x"; - case IREE_UK_TYPE_CATEGORY_INTEGER: + case IREE_UK_TYPE_CATEGORY_INTEGER_SIGNLESS: return "i"; case IREE_UK_TYPE_CATEGORY_INTEGER_SIGNED: - return "si"; + return "s"; case IREE_UK_TYPE_CATEGORY_INTEGER_UNSIGNED: - return "ui"; + return "u"; case IREE_UK_TYPE_CATEGORY_FLOAT_IEEE: return "f"; case IREE_UK_TYPE_CATEGORY_FLOAT_BRAIN: diff --git a/runtime/src/iree/modules/vmvx/module.c b/runtime/src/iree/modules/vmvx/module.c index a0b2ae604fa8..3a571618d1ea 100644 --- a/runtime/src/iree/modules/vmvx/module.c +++ b/runtime/src/iree/modules/vmvx/module.c @@ -542,7 +542,7 @@ IREE_VMVX_ABI_EXPORT(iree_vmvx_mmt4d, mmt4d, v) { in_elem_size = 4; out_elem_size = 4; break; - case IREE_UK_FLAG_MMT4D_TYPE_I8I8I32: + case IREE_UK_FLAG_MMT4D_TYPE_S8S8S32: in_elem_size = 1; out_elem_size = 4; break;