From b9c169f40c4726320c54024f16023a9ede74f5f5 Mon Sep 17 00:00:00 2001 From: Marat Dukhan Date: Sun, 7 Aug 2022 17:51:09 -0700 Subject: [PATCH] BF16 GEMM microkernels for NEON & NEON-BF16 PiperOrigin-RevId: 465928283 --- BUILD.bazel | 180 + CMakeLists.txt | 46 + WORKSPACE | 14 +- bench/bf16-gemm.cc | 246 + bench/utils.cc | 8 + bench/utils.h | 4 + cmake/DownloadCLog.cmake | 4 +- cmake/DownloadCpuinfo.cmake | 4 +- scripts/build-android-armv7.sh | 3 + scripts/generate-bf16-gemm.sh | 41 + .../c2-neonbf16-bfdot-lane-ld128.c.in | 197 + src/bf16-gemm/c8-neon.c.in | 229 + src/bf16-gemm/c8-neonbf16.c.in | 177 + .../gen/1x4c8-minmax-neonbf16-bfdot.c | 128 + .../gen/1x4c8-minmax-neonbf16-bfmlal.c | 137 + .../gen/1x4c8-minmax-neonfma-shland.c | 174 + src/bf16-gemm/gen/1x4c8-minmax-neonfma-zip.c | 174 + .../1x8c2-minmax-neonbf16-bfdot-lane-ld128.c | 171 + .../gen/2x4c8-minmax-neonbf16-bfdot.c | 169 + .../gen/2x4c8-minmax-neonbf16-bfmlal.c | 186 + .../gen/2x4c8-minmax-neonfma-shland.c | 233 + src/bf16-gemm/gen/2x4c8-minmax-neonfma-zip.c | 233 + .../gen/3x4c8-minmax-neonbf16-bfdot.c | 210 + .../gen/3x4c8-minmax-neonbf16-bfmlal.c | 235 + .../gen/3x4c8-minmax-neonfma-shland.c | 292 + src/bf16-gemm/gen/3x4c8-minmax-neonfma-zip.c | 292 + .../gen/4x4c8-minmax-neonbf16-bfdot.c | 251 + .../gen/4x4c8-minmax-neonbf16-bfmlal.c | 284 + .../gen/4x4c8-minmax-neonfma-shland.c | 351 + src/bf16-gemm/gen/4x4c8-minmax-neonfma-zip.c | 351 + .../4x8c2-minmax-neonbf16-bfdot-lane-ld128.c | 330 + .../gen/5x4c8-minmax-neonbf16-bfdot.c | 292 + .../gen/5x4c8-minmax-neonbf16-bfmlal.c | 333 + .../gen/5x4c8-minmax-neonfma-shland.c | 410 + src/bf16-gemm/gen/5x4c8-minmax-neonfma-zip.c | 410 + .../5x8c2-minmax-neonbf16-bfdot-lane-ld128.c | 383 + .../6x8c2-minmax-neonbf16-bfdot-lane-ld128.c | 436 + src/microparams-init.c | 10 + src/xnnpack/gemm.h | 135 +- src/xnnpack/isa-checks.h | 7 + src/xnnpack/microparams-init.h | 9 + src/xnnpack/microparams.h | 7 + src/xnnpack/params.h | 17 + test/bf16-gemm-minmax.cc | 10967 ++++++++++++++++ test/bf16-gemm-minmax.yaml | 92 + test/gemm-microkernel-tester.cc | 76 + test/gemm-microkernel-tester.h | 2 + third_party/cpuinfo.patch | 595 - tools/xnncommon.py | 3 + 49 files changed, 18885 insertions(+), 653 deletions(-) create mode 100644 bench/bf16-gemm.cc create mode 100755 scripts/generate-bf16-gemm.sh create mode 100644 src/bf16-gemm/c2-neonbf16-bfdot-lane-ld128.c.in create mode 100644 src/bf16-gemm/c8-neon.c.in create mode 100644 src/bf16-gemm/c8-neonbf16.c.in create mode 100644 src/bf16-gemm/gen/1x4c8-minmax-neonbf16-bfdot.c create mode 100644 src/bf16-gemm/gen/1x4c8-minmax-neonbf16-bfmlal.c create mode 100644 src/bf16-gemm/gen/1x4c8-minmax-neonfma-shland.c create mode 100644 src/bf16-gemm/gen/1x4c8-minmax-neonfma-zip.c create mode 100644 src/bf16-gemm/gen/1x8c2-minmax-neonbf16-bfdot-lane-ld128.c create mode 100644 src/bf16-gemm/gen/2x4c8-minmax-neonbf16-bfdot.c create mode 100644 src/bf16-gemm/gen/2x4c8-minmax-neonbf16-bfmlal.c create mode 100644 src/bf16-gemm/gen/2x4c8-minmax-neonfma-shland.c create mode 100644 src/bf16-gemm/gen/2x4c8-minmax-neonfma-zip.c create mode 100644 src/bf16-gemm/gen/3x4c8-minmax-neonbf16-bfdot.c create mode 100644 src/bf16-gemm/gen/3x4c8-minmax-neonbf16-bfmlal.c create mode 100644 src/bf16-gemm/gen/3x4c8-minmax-neonfma-shland.c create mode 100644 src/bf16-gemm/gen/3x4c8-minmax-neonfma-zip.c create mode 100644 src/bf16-gemm/gen/4x4c8-minmax-neonbf16-bfdot.c create mode 100644 src/bf16-gemm/gen/4x4c8-minmax-neonbf16-bfmlal.c create mode 100644 
src/bf16-gemm/gen/4x4c8-minmax-neonfma-shland.c create mode 100644 src/bf16-gemm/gen/4x4c8-minmax-neonfma-zip.c create mode 100644 src/bf16-gemm/gen/4x8c2-minmax-neonbf16-bfdot-lane-ld128.c create mode 100644 src/bf16-gemm/gen/5x4c8-minmax-neonbf16-bfdot.c create mode 100644 src/bf16-gemm/gen/5x4c8-minmax-neonbf16-bfmlal.c create mode 100644 src/bf16-gemm/gen/5x4c8-minmax-neonfma-shland.c create mode 100644 src/bf16-gemm/gen/5x4c8-minmax-neonfma-zip.c create mode 100644 src/bf16-gemm/gen/5x8c2-minmax-neonbf16-bfdot-lane-ld128.c create mode 100644 src/bf16-gemm/gen/6x8c2-minmax-neonbf16-bfdot-lane-ld128.c create mode 100644 test/bf16-gemm-minmax.cc create mode 100644 test/bf16-gemm-minmax.yaml delete mode 100644 third_party/cpuinfo.patch diff --git a/BUILD.bazel b/BUILD.bazel index 6142bdaa17e3..7098bae0221d 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -4287,6 +4287,11 @@ PROD_NEONFMA_MICROKERNEL_SRCS = [ ] ALL_NEONFMA_MICROKERNEL_SRCS = [ + "src/bf16-gemm/gen/1x4c8-minmax-neonfma-shland.c", + "src/bf16-gemm/gen/2x4c8-minmax-neonfma-shland.c", + "src/bf16-gemm/gen/3x4c8-minmax-neonfma-shland.c", + "src/bf16-gemm/gen/4x4c8-minmax-neonfma-shland.c", + "src/bf16-gemm/gen/5x4c8-minmax-neonfma-shland.c", "src/f32-dwconv/gen/up4x3-minmax-neonfma-acc2.c", "src/f32-dwconv/gen/up4x3-minmax-neonfma.c", "src/f32-dwconv/gen/up4x4-minmax-neonfma-acc2.c", @@ -4533,6 +4538,11 @@ PROD_AARCH64_NEON_MICROKERNEL_SRCS = [ ] ALL_AARCH64_NEON_MICROKERNEL_SRCS = [ + "src/bf16-gemm/gen/1x4c8-minmax-neonfma-zip.c", + "src/bf16-gemm/gen/2x4c8-minmax-neonfma-zip.c", + "src/bf16-gemm/gen/3x4c8-minmax-neonfma-zip.c", + "src/bf16-gemm/gen/4x4c8-minmax-neonfma-zip.c", + "src/bf16-gemm/gen/5x4c8-minmax-neonfma-zip.c", "src/f32-conv-hwc/gen/3x3s2p0p1c3x4-neonfma-2x1.c", "src/f32-conv-hwc/gen/3x3s2p0p1c3x4-neonfma-2x2.c", "src/f32-conv-hwc/gen/3x3s2p0p1c3x8-neonfma-2x1.c", @@ -5214,6 +5224,32 @@ ALL_AARCH64_NEONFP16ARITH_MICROKERNEL_SRCS = [ "src/math/sigmoid-f16-neonfp16arith-rr2-p3-div.c", ] +PROD_NEONBF16_MICROKERNEL_SRCS = [ +] + +ALL_NEONBF16_MICROKERNEL_SRCS = [ + "src/bf16-gemm/gen/1x8c2-minmax-neonbf16-bfdot-lane-ld128.c", + "src/bf16-gemm/gen/4x8c2-minmax-neonbf16-bfdot-lane-ld128.c", + "src/bf16-gemm/gen/5x8c2-minmax-neonbf16-bfdot-lane-ld128.c", + "src/bf16-gemm/gen/6x8c2-minmax-neonbf16-bfdot-lane-ld128.c", + "src/bf16-gemm/gen/1x4c8-minmax-neonbf16-bfdot.c", + "src/bf16-gemm/gen/2x4c8-minmax-neonbf16-bfdot.c", + "src/bf16-gemm/gen/3x4c8-minmax-neonbf16-bfdot.c", + "src/bf16-gemm/gen/4x4c8-minmax-neonbf16-bfdot.c", + "src/bf16-gemm/gen/5x4c8-minmax-neonbf16-bfdot.c", + "src/bf16-gemm/gen/1x4c8-minmax-neonbf16-bfmlal.c", + "src/bf16-gemm/gen/2x4c8-minmax-neonbf16-bfmlal.c", + "src/bf16-gemm/gen/3x4c8-minmax-neonbf16-bfmlal.c", + "src/bf16-gemm/gen/4x4c8-minmax-neonbf16-bfmlal.c", + "src/bf16-gemm/gen/5x4c8-minmax-neonbf16-bfmlal.c", +] + +PROD_AARCH64_NEONBF16_MICROKERNEL_SRCS = [ +] + +ALL_AARCH64_NEONBF16_MICROKERNEL_SRCS = [ +] + PROD_NEONDOT_MICROKERNEL_SRCS = [ "src/qc8-gemm/gen/1x8c4-minmax-fp32-neondot.c", "src/qc8-gemm/gen/1x16c4-minmax-fp32-neondot.c", @@ -9319,6 +9355,76 @@ xnnpack_cc_library( ], ) +xnnpack_cc_library( + name = "neonbf16_bench_microkernels", + aarch32_copts = [ + "-marm", + "-march=armv8.2-a+bf16", + "-mfpu=neon-fp-armv8", + ], + aarch32_srcs = ALL_NEONBF16_MICROKERNEL_SRCS, + aarch64_copts = ["-march=armv8.2-a+bf16"], + aarch64_srcs = ALL_NEONBF16_MICROKERNEL_SRCS + ALL_AARCH64_NEONBF16_MICROKERNEL_SRCS, + gcc_copts = xnnpack_gcc_std_copts(), + msvc_copts = xnnpack_msvc_std_copts(), + 
deps = [ + ":common", + ":math", + ":microkernels_h", + ":params", + ":tables", + ":unaligned", + ], +) + +xnnpack_cc_library( + name = "neonbf16_prod_microkernels", + aarch32_copts = [ + "-marm", + "-march=armv8.2-a+bf16", + "-mfpu=neon-fp-armv8", + ], + aarch32_srcs = PROD_NEONBF16_MICROKERNEL_SRCS, + aarch64_copts = ["-march=armv8.2-a+bf16"], + aarch64_srcs = PROD_NEONBF16_MICROKERNEL_SRCS + PROD_AARCH64_NEONBF16_MICROKERNEL_SRCS, + gcc_copts = xnnpack_gcc_std_copts(), + msvc_copts = xnnpack_msvc_std_copts(), + deps = [ + ":common", + ":math", + ":microkernels_h", + ":params", + ":tables", + ":unaligned", + ], +) + +xnnpack_cc_library( + name = "neonbf16_test_microkernels", + aarch32_copts = [ + "-marm", + "-march=armv8.2-a+bf16", + "-mfpu=neon-fp-armv8", + ], + aarch32_srcs = ALL_NEONBF16_MICROKERNEL_SRCS, + aarch64_copts = ["-march=armv8.2-a+bf16"], + aarch64_srcs = ALL_NEONBF16_MICROKERNEL_SRCS + ALL_AARCH64_NEONBF16_MICROKERNEL_SRCS, + copts = [ + "-UNDEBUG", + "-DXNN_TEST_MODE=1", + ], + gcc_copts = xnnpack_gcc_std_copts(), + msvc_copts = xnnpack_msvc_std_copts(), + deps = [ + ":common", + ":math", + ":microkernels_h", + ":params", + ":tables", + ":unaligned", + ], +) + xnnpack_cc_library( name = "neondot_bench_microkernels", aarch32_copts = [ @@ -10280,6 +10386,9 @@ xnnpack_aggregate_library( defines = select({ ":arm_fp16_enabled": ["XNN_ENABLE_ARM_FP16=1"], "//conditions:default": ["XNN_ENABLE_ARM_FP16=0"], + }) + select({ + ":arm_bf16_enabled": ["XNN_ENABLE_ARM_BF16=1"], + "//conditions:default": ["XNN_ENABLE_ARM_BF16=0"], }) + select({ ":arm_dotprod_enabled": ["XNN_ENABLE_ARM_DOTPROD=1"], "//conditions:default": ["XNN_ENABLE_ARM_DOTPROD=0"], @@ -10289,6 +10398,9 @@ xnnpack_aggregate_library( ] + select({ ":arm_fp16_enabled": [":neonfp16arith_prod_microkernels"], "//conditions:default": [], + }) + select({ + ":arm_bf16_enabled": [":neonbf16_prod_microkernels"], + "//conditions:default": [], }) + select({ ":arm_dotprod_enabled": [":neondot_prod_microkernels"], "//conditions:default": [], @@ -10339,6 +10451,9 @@ xnnpack_aggregate_library( defines = select({ ":arm_fp16_enabled": ["XNN_ENABLE_ARM_FP16=1"], "//conditions:default": ["XNN_ENABLE_ARM_FP16=0"], + }) + select({ + ":arm_bf16_enabled": ["XNN_ENABLE_ARM_BF16=1"], + "//conditions:default": ["XNN_ENABLE_ARM_BF16=0"], }) + select({ ":arm_dotprod_enabled": ["XNN_ENABLE_ARM_DOTPROD=1"], "//conditions:default": ["XNN_ENABLE_ARM_DOTPROD=0"], @@ -10348,6 +10463,9 @@ xnnpack_aggregate_library( ] + select({ ":arm_fp16_enabled": [":neonfp16arith_bench_microkernels"], "//conditions:default": [], + }) + select({ + ":arm_bf16_enabled": [":neonbf16_bench_microkernels"], + "//conditions:default": [], }) + select({ ":arm_dotprod_enabled": [":neondot_bench_microkernels"], "//conditions:default": [], @@ -10398,6 +10516,9 @@ xnnpack_aggregate_library( defines = select({ ":arm_fp16_enabled": ["XNN_ENABLE_ARM_FP16=1"], "//conditions:default": ["XNN_ENABLE_ARM_FP16=0"], + }) + select({ + ":arm_bf16_enabled": ["XNN_ENABLE_ARM_BF16=1"], + "//conditions:default": ["XNN_ENABLE_ARM_BF16=0"], }) + select({ ":arm_dotprod_enabled": ["XNN_ENABLE_ARM_DOTPROD=1"], "//conditions:default": ["XNN_ENABLE_ARM_DOTPROD=0"], @@ -10407,6 +10528,9 @@ xnnpack_aggregate_library( ] + select({ ":arm_fp16_enabled": [":neonfp16arith_prod_microkernels"], "//conditions:default": [], + }) + select({ + ":arm_bf16_enabled": [":neonbf16_prod_microkernels"], + "//conditions:default": [], }) + select({ ":arm_dotprod_enabled": [":neondot_prod_microkernels"], "//conditions:default": 
[], @@ -10457,6 +10581,9 @@ xnnpack_aggregate_library( defines = select({ ":arm_fp16_enabled": ["XNN_ENABLE_ARM_FP16=1"], "//conditions:default": ["XNN_ENABLE_ARM_FP16=0"], + }) + select({ + ":arm_bf16_enabled": ["XNN_ENABLE_ARM_BF16=1"], + "//conditions:default": ["XNN_ENABLE_ARM_BF16=0"], }) + select({ ":arm_dotprod_enabled": ["XNN_ENABLE_ARM_DOTPROD=1"], "//conditions:default": ["XNN_ENABLE_ARM_DOTPROD=0"], @@ -10466,6 +10593,9 @@ xnnpack_aggregate_library( ] + select({ ":arm_fp16_enabled": [":neonfp16arith_test_microkernels"], "//conditions:default": [], + }) + select({ + ":arm_bf16_enabled": [":neonbf16_test_microkernels"], + "//conditions:default": [], }) + select({ ":arm_dotprod_enabled": [":neondot_test_microkernels"], "//conditions:default": [], @@ -11306,6 +11436,18 @@ xnnpack_benchmark( deps = MICROKERNEL_BENCHMARK_DEPS, ) +xnnpack_benchmark( + name = "bf16_gemm_bench", + srcs = [ + "bench/bf16-gemm.cc", + "bench/gemm.h", + ], + deps = MICROKERNEL_BENCHMARK_DEPS + [ + ":math", + ":packing", + ], +) + xnnpack_benchmark( name = "f16_igemm_bench", srcs = [ @@ -12464,6 +12606,16 @@ xnnpack_cc_library( ], ) +xnnpack_unit_test( + name = "bf16_gemm_minmax_test", + srcs = [ + "test/bf16-gemm-minmax.cc", + ], + deps = MICROKERNEL_TEST_DEPS + [ + ":gemm_microkernel_tester", + ], +) + xnnpack_unit_test( name = "f16_f32_vcvt_test", srcs = [ @@ -15528,6 +15680,18 @@ config_setting( define_values = {"xnn_enable_arm_fp16": "false"}, ) +# Enables usage of ARM BF16 (BF16 arithmetics) kernels. +config_setting( + name = "xnn_enable_arm_bf16_explicit_true", + define_values = {"xnn_enable_arm_bf16": "true"}, +) + +# Disables usage of ARM BF16 (BF16 arithmetics) kernels. +config_setting( + name = "xnn_enable_arm_bf16_explicit_false", + define_values = {"xnn_enable_arm_bf16": "false"}, +) + # Enables usage of ARM DotProd (integer dot product) kernels. 
config_setting( name = "xnn_enable_arm_dotprod_explicit_true", @@ -15965,6 +16129,22 @@ alias( }), ) +selects.config_setting_group( + name = "arm_bf16_enabled_by_default", + match_any = [ + ":aarch64", + ], +) + +alias( + name = "arm_bf16_enabled", + actual = select({ + ":xnn_enable_arm_bf16_explicit_true": ":xnn_enable_arm_bf16_explicit_true", + ":xnn_enable_arm_bf16_explicit_false": ":xnn_enable_arm_bf16_explicit_true", + "//conditions:default": ":arm_bf16_enabled_by_default", + }), +) + selects.config_setting_group( name = "arm_dotprod_enabled_by_default", match_any = [ diff --git a/CMakeLists.txt b/CMakeLists.txt index 2721c0170227..07b3fe6a6a1f 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,7 @@ ENDIF() # --- [ Processor-specific options OPTION(XNNPACK_ENABLE_ARM_FP16 "Build XNNPACK with ARM FP16 (FP16 data processing) micro-kernels" ON) +OPTION(XNNPACK_ENABLE_ARM_BF16 "Build XNNPACK with ARM BF16 (BFLOAT16) micro-kernels" ON) OPTION(XNNPACK_ENABLE_ARM_DOTPROD "Build XNNPACK with ARM DotProd (integer dot product) micro-kernels" ON) # ---[ CMake options @@ -50,6 +51,7 @@ IF(XNNPACK_BUILD_TESTS) ENDIF() ADD_COMPILE_DEFINITIONS("XNN_ENABLE_ARM_FP16=$") +ADD_COMPILE_DEFINITIONS("XNN_ENABLE_ARM_BF16=$") ADD_COMPILE_DEFINITIONS("XNN_ENABLE_ARM_DOTPROD=$") ADD_COMPILE_DEFINITIONS("XNN_ENABLE_ASSEMBLY=$") ADD_COMPILE_DEFINITIONS("XNN_ENABLE_JIT=$") @@ -2777,6 +2779,11 @@ SET(PROD_NEONFMA_MICROKERNEL_SRCS src/f32-vsigmoid/gen/vsigmoid-neonfma-rr1-lut64-p2-nr2recps-x16.c) SET(ALL_NEONFMA_MICROKERNEL_SRCS + src/bf16-gemm/gen/1x4c8-minmax-neonfma-shland.c + src/bf16-gemm/gen/2x4c8-minmax-neonfma-shland.c + src/bf16-gemm/gen/3x4c8-minmax-neonfma-shland.c + src/bf16-gemm/gen/4x4c8-minmax-neonfma-shland.c + src/bf16-gemm/gen/5x4c8-minmax-neonfma-shland.c src/f32-dwconv/gen/up4x3-minmax-neonfma-acc2.c src/f32-dwconv/gen/up4x3-minmax-neonfma.c src/f32-dwconv/gen/up4x4-minmax-neonfma-acc2.c @@ -3021,6 +3028,11 @@ SET(PROD_AARCH64_NEON_MICROKERNEL_SRCS src/x32-transposec/4x4-aarch64-tbl.c) SET(ALL_AARCH64_NEON_MICROKERNEL_SRCS + src/bf16-gemm/gen/1x4c8-minmax-neonfma-zip.c + src/bf16-gemm/gen/2x4c8-minmax-neonfma-zip.c + src/bf16-gemm/gen/3x4c8-minmax-neonfma-zip.c + src/bf16-gemm/gen/4x4c8-minmax-neonfma-zip.c + src/bf16-gemm/gen/5x4c8-minmax-neonfma-zip.c src/f32-conv-hwc/gen/3x3s2p0p1c3x4-neonfma-2x1.c src/f32-conv-hwc/gen/3x3s2p0p1c3x4-neonfma-2x2.c src/f32-conv-hwc/gen/3x3s2p0p1c3x8-neonfma-2x1.c @@ -3695,6 +3707,22 @@ SET(ALL_AARCH64_NEONFP16ARITH_MICROKERNEL_SRCS src/math/sigmoid-f16-neonfp16arith-rr2-p2-div.c src/math/sigmoid-f16-neonfp16arith-rr2-p3-div.c) +SET(ALL_NEONBF16_MICROKERNEL_SRCS + src/bf16-gemm/gen/1x8c2-minmax-neonbf16-bfdot-lane-ld128.c + src/bf16-gemm/gen/4x8c2-minmax-neonbf16-bfdot-lane-ld128.c + src/bf16-gemm/gen/5x8c2-minmax-neonbf16-bfdot-lane-ld128.c + src/bf16-gemm/gen/6x8c2-minmax-neonbf16-bfdot-lane-ld128.c + src/bf16-gemm/gen/1x4c8-minmax-neonbf16-bfdot.c + src/bf16-gemm/gen/2x4c8-minmax-neonbf16-bfdot.c + src/bf16-gemm/gen/3x4c8-minmax-neonbf16-bfdot.c + src/bf16-gemm/gen/4x4c8-minmax-neonbf16-bfdot.c + src/bf16-gemm/gen/5x4c8-minmax-neonbf16-bfdot.c + src/bf16-gemm/gen/1x4c8-minmax-neonbf16-bfmlal.c + src/bf16-gemm/gen/2x4c8-minmax-neonbf16-bfmlal.c + src/bf16-gemm/gen/3x4c8-minmax-neonbf16-bfmlal.c + src/bf16-gemm/gen/4x4c8-minmax-neonbf16-bfmlal.c + src/bf16-gemm/gen/5x4c8-minmax-neonbf16-bfmlal.c) + SET(PROD_NEONDOT_MICROKERNEL_SRCS src/qc8-gemm/gen/1x8c4-minmax-fp32-neondot.c src/qc8-gemm/gen/1x16c4-minmax-fp32-neondot.c @@ -6975,6 +7003,9 @@ 
IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]" OR IOS_ARCH MATCHES "^armv7") IF(XNNPACK_ENABLE_ARM_FP16) LIST(APPEND ALL_MICROKERNEL_SRCS ${ALL_NEONFP16ARITH_MICROKERNEL_SRCS}) ENDIF() + IF(XNNPACK_ENABLE_ARM_BF16) + LIST(APPEND ALL_MICROKERNEL_SRCS ${ALL_NEONBF16_MICROKERNEL_SRCS}) + ENDIF() IF(XNNPACK_ENABLE_ARM_DOTPROD) LIST(APPEND ALL_MICROKERNEL_SRCS ${ALL_NEONDOT_MICROKERNEL_SRCS}) ENDIF() @@ -7006,6 +7037,9 @@ IF(XNNPACK_TARGET_PROCESSOR MATCHES "^(aarch64|arm64)$" OR IOS_ARCH MATCHES "^ar LIST(APPEND ALL_MICROKERNEL_SRCS ${ALL_NEONFP16ARITH_MICROKERNEL_SRCS}) LIST(APPEND ALL_MICROKERNEL_SRCS ${ALL_AARCH64_NEONFP16ARITH_MICROKERNEL_SRCS}) ENDIF() + IF(XNNPACK_ENABLE_ARM_BF16) + LIST(APPEND ALL_MICROKERNEL_SRCS ${ALL_NEONBF16_MICROKERNEL_SRCS}) + ENDIF() IF(XNNPACK_ENABLE_ARM_DOTPROD) LIST(APPEND ALL_MICROKERNEL_SRCS ${ALL_NEONDOT_MICROKERNEL_SRCS}) ENDIF() @@ -7085,6 +7119,7 @@ IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]" OR IOS_ARCH MATCHES "^armv7") SET_PROPERTY(SOURCE ${ALL_NEONFMA_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=neon-vfpv4 ") SET_PROPERTY(SOURCE ${ALL_NEONV8_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv8-a -mfpu=neon-fp-armv8 ") SET_PROPERTY(SOURCE ${ALL_NEONFP16ARITH_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv8.2-a+fp16 -mfpu=neon-fp-armv8 ") + SET_PROPERTY(SOURCE ${ALL_NEONBF16_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv8.2-a+bf16 -mfpu=neon-fp-armv8 ") SET_PROPERTY(SOURCE ${ALL_NEONDOT_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv8.2-a+dotprod -mfpu=neon-fp-armv8 ") SET_PROPERTY(SOURCE ${AARCH32_ASM_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv8.2-a+dotprod -mfpu=neon-fp-armv8 ") # Workground the neon detection bug in ARM v8 @@ -7094,6 +7129,7 @@ IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]" OR IOS_ARCH MATCHES "^armv7") IF(ANDROID_NDK_MAJOR AND ANDROID_NDK_MAJOR LESS 21) SET_PROPERTY(SOURCE ${ALL_NEONV8_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -mfloat-abi=softfp ") SET_PROPERTY(SOURCE ${ALL_NEONFP16ARITH_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -mfloat-abi=softfp ") + SET_PROPERTY(SOURCE ${ALL_NEONBF16_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -mfloat-abi=softfp ") SET_PROPERTY(SOURCE ${ALL_NEONDOT_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -mfloat-abi=softfp ") SET_PROPERTY(SOURCE ${AARCH32_ASM_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -mfloat-abi=softfp ") ENDIF() @@ -7101,6 +7137,7 @@ ENDIF() IF(XNNPACK_TARGET_PROCESSOR MATCHES "^(aarch64|arm64)$" OR IOS_ARCH MATCHES "^arm64.*") SET_PROPERTY(SOURCE ${ALL_NEONFP16ARITH_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv8.2-a+fp16 ") SET_PROPERTY(SOURCE ${ALL_AARCH64_NEONFP16ARITH_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + SET_PROPERTY(SOURCE ${ALL_NEONBF16_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv8.2-a+bf16 ") SET_PROPERTY(SOURCE ${ALL_NEONDOT_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv8.2-a+dotprod ") SET_PROPERTY(SOURCE ${AARCH64_ASM_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv8.2-a+fp16+dotprod ") IF(IOS) @@ -7829,6 +7866,11 @@ IF(XNNPACK_BUILD_TESTS) ADD_TEST(NAME fusion-test COMMAND fusion-test) # ---[ Build microkernel-level unit tests + ADD_EXECUTABLE(bf16-gemm-minmax-test test/bf16-gemm-minmax.cc $ $) + 
TARGET_INCLUDE_DIRECTORIES(bf16-gemm-minmax-test PRIVATE include src test) + TARGET_LINK_LIBRARIES(bf16-gemm-minmax-test PRIVATE XNNPACK cpuinfo fp16 pthreadpool gtest gtest_main jit gemm-microkernel-tester microparams_init allocator) + ADD_TEST(NAME bf16-gemm-minmax-test COMMAND bf16-gemm-minmax-test) + ADD_EXECUTABLE(f16-f32-vcvt-test test/f16-f32-vcvt.cc $) TARGET_INCLUDE_DIRECTORIES(f16-f32-vcvt-test PRIVATE include src test) TARGET_LINK_LIBRARIES(f16-f32-vcvt-test PRIVATE cpuinfo fp16 pthreadpool gtest gtest_main microparams_init) @@ -9079,6 +9121,10 @@ IF(XNNPACK_BUILD_BENCHMARKS) TARGET_LINK_LIBRARIES(truncation-bench PRIVATE XNNPACK benchmark bench-utils microparams_init logging operators) # ---[ Build microkernel-level microbenchmarks + ADD_EXECUTABLE(bf16-gemm-bench bench/bf16-gemm.cc $ $) + TARGET_INCLUDE_DIRECTORIES(bf16-gemm-bench PRIVATE . include src) + TARGET_LINK_LIBRARIES(bf16-gemm-bench PRIVATE benchmark bench-utils cpuinfo fp16 pthreadpool microparams_init) + ADD_EXECUTABLE(f16-dwconv-bench bench/f16-dwconv.cc $ $ $) TARGET_INCLUDE_DIRECTORIES(f16-dwconv-bench PRIVATE . include src) TARGET_LINK_LIBRARIES(f16-dwconv-bench PRIVATE benchmark bench-utils cpuinfo fp16 pthreadpool microparams_init) diff --git a/WORKSPACE b/WORKSPACE index 1179b0c8ab35..2bcbbc6e0a02 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -63,25 +63,23 @@ http_archive( # clog library, used for logging http_archive( name = "clog", - strip_prefix = "cpuinfo-d5e37adf1406cf899d7d9ec1d317c47506ccb970", - sha256 = "3f2dc1970f397a0e59db72f9fca6ff144b216895c1d606f6c94a507c1e53a025", + strip_prefix = "cpuinfo-49610f89b8b1eb52d75d1eda7a2c40c1e86a78e7", + sha256 = "25843b5f21c32cba89f9b921c0500ab5cd0c2cb8fb0f345e5b5e4678329386c7", urls = [ - "https://github.com/pytorch/cpuinfo/archive/d5e37adf1406cf899d7d9ec1d317c47506ccb970.tar.gz", + "https://github.com/pytorch/cpuinfo/archive/49610f89b8b1eb52d75d1eda7a2c40c1e86a78e7.tar.gz", ], build_file = "@//third_party:clog.BUILD", ) - # cpuinfo library, used for detecting processor characteristics http_archive( name = "cpuinfo", - strip_prefix = "cpuinfo-ed8b86a253800bafdb7b25c5c399f91bff9cb1f3", - sha256 = "a7f9a188148a1660149878f737f42783e72f33a4f842f3e362fee2c981613e53", + strip_prefix = "cpuinfo-49610f89b8b1eb52d75d1eda7a2c40c1e86a78e7", + sha256 = "25843b5f21c32cba89f9b921c0500ab5cd0c2cb8fb0f345e5b5e4678329386c7", urls = [ - "https://github.com/pytorch/cpuinfo/archive/ed8b86a253800bafdb7b25c5c399f91bff9cb1f3.zip", + "https://github.com/pytorch/cpuinfo/archive/49610f89b8b1eb52d75d1eda7a2c40c1e86a78e7.zip", ], build_file = "@//third_party:cpuinfo.BUILD", - patches = ["@//third_party:cpuinfo.patch"], ) # Ruy library, used to benchmark against diff --git a/bench/bf16-gemm.cc b/bench/bf16-gemm.cc new file mode 100644 index 000000000000..ea330e85703c --- /dev/null +++ b/bench/bf16-gemm.cc @@ -0,0 +1,246 @@ +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
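The benchmark source that follows builds its bfloat16 operands by keeping only the upper 16 bits of an IEEE-754 binary32 value: that is what the fp32_to_bits(f32rng(rng)) >> 16 lambdas compute, and it is why the 16-bit constants 0x7FC0, 0xFF80 and 0x7F80 used further down stand for NaN, -infinity and +infinity. The standalone C sketch below is illustrative only and not part of the patch; the helper names are invented for the sketch.

  #include <stdint.h>
  #include <string.h>

  /* Narrow an IEEE-754 binary32 value to bfloat16 by keeping its upper 16 bits
   * (simple truncation, the same thing the benchmark's >> 16 shift does). */
  static uint16_t f32_to_bf16_bits(float x) {
    uint32_t bits;
    memcpy(&bits, &x, sizeof(bits));
    return (uint16_t) (bits >> 16);
  }

  /* Widen a bfloat16 bit pattern back to binary32 by placing it in the upper
   * half of the word and zero-filling the low 16 mantissa bits. */
  static float bf16_bits_to_f32(uint16_t h) {
    const uint32_t bits = (uint32_t) h << 16;
    float x;
    memcpy(&x, &bits, sizeof(x));
    return x;
  }

Widening 0x7F80 this way yields 0x7F800000, the binary32 +infinity pattern, which is why the min/max clamping parameters below can be initialized directly from those 16-bit values.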
+ +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include "bench/gemm.h" +#include "bench/utils.h" +#include +#include +#include +#include +#include +#include +#include + + +static void GEMMBenchmark(benchmark::State& state, + xnn_bf16_gemm_minmax_ukernel_function gemm, + size_t mr, size_t nr, size_t kr, size_t sr, + xnn_init_bf16_minmax_params_fn init_params, + benchmark::utils::IsaCheckFunction isa_check = nullptr) +{ + if (isa_check && !isa_check(state)) { + return; + } + + const size_t mc = state.range(0); + const size_t nc = state.range(1); + const size_t kc = state.range(2); + + const size_t nc_stride = benchmark::utils::RoundUp(nc, nr); + const size_t kc_stride = benchmark::utils::RoundUp(kc, kr * sr); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto f32rng = std::bind(std::uniform_real_distribution(), std::ref(rng)); + + std::vector a(mc * kc + XNN_EXTRA_BYTES / sizeof(uint16_t)); + std::generate(a.begin(), a.end(), [&] { return fp32_to_bits(f32rng(rng)) >> 16; }); + std::vector k(nc * kc); + std::generate(k.begin(), k.end(), [&] { return fp32_to_bits(f32rng(rng)) >> 16; }); + std::vector b(nc); + std::generate(b.begin(), b.end(), [&] { return fp32_to_bits(f32rng(rng)) >> 16; }); + + const size_t w_elements = nc_stride * kc_stride + nc_stride; + const size_t c_elements = mc * nc; + const size_t num_buffers = 1 + + benchmark::utils::DivideRoundUp(benchmark::utils::GetMaxCacheSize(), + sizeof(uint16_t) * (w_elements + c_elements)); + + std::vector> w(w_elements * num_buffers); + std::fill(w.begin(), w.end(), 0); + xnn_pack_f16_gemm_goi_w(1 /* groups */, nc, kc, nr, kr, sr, k.data(), b.data(), w.data(), 0, nullptr); + std::vector c(c_elements * num_buffers); + std::fill(c.begin(), c.end(), UINT16_C(0x7FC0) /* NaN */); + + // Prepare minmax parameters. 
+ xnn_bf16_minmax_params params; + init_params(¶ms, + UINT16_C(0xFF80) /* -inf */, UINT16_C(0x7F80) /* inf */); + + size_t buffer_index = 0; + for (auto _ : state) { + // Use circular buffers (exceeding cache size) and prefetch to control cache state: + // - A is always in L1 cache (if fits, otherwise L2, L3, etc) + // - W is not in cache (for any cache level) + // - C is not in cache (for any cache level) + state.PauseTiming(); + benchmark::utils::PrefetchToL1(a.data(), a.size() * sizeof(uint16_t)); + buffer_index = (buffer_index + 1) % num_buffers; + state.ResumeTiming(); + + for (uint32_t m = 0; m < mc; m += mr) { + const uint32_t mb = min(mc - m, mr); + for (uint32_t n = 0; n < nc; n += nr) { + const uint32_t nb = min(nc - n, nr); + gemm( + mb, nb, kc * sizeof(uint16_t), + a.data() + m * kc, kc * sizeof(uint16_t), + w.data() + (nc_stride * buffer_index + n) * (kc_stride + 1), + c.data() + (mc * buffer_index + m) * nc + n, nc * sizeof(uint16_t), nr * sizeof(uint16_t), + ¶ms); + } + } + } + + const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency(); + if (cpu_frequency != 0) { + state.counters["cpufreq"] = cpu_frequency; + } + + state.counters["FLOPS"] = benchmark::Counter( + uint64_t(state.iterations()) * 2 * mc * nc * kc, benchmark::Counter::kIsRate); +} + + +#if XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + static void bf16_gemm_1x8c2__neonbf16_bfdot_lane_ld128(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, 1, 8, 2, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16); + } + static void bf16_gemm_4x8c2__neonbf16_bfdot_lane_ld128(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, 4, 8, 2, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16); + } + static void bf16_gemm_5x8c2__neonbf16_bfdot_lane_ld128(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, 5, 8, 2, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16); + } + static void bf16_gemm_6x8c2__neonbf16_bfdot_lane_ld128(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, 6, 8, 2, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16); + } + + static void bf16_gemm_1x4c8__neonbf16_bfdot(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, 1, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16); + } + static void bf16_gemm_2x4c8__neonbf16_bfdot(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, 2, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16); + } + static void bf16_gemm_3x4c8__neonbf16_bfdot(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, 3, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16); + } + static void bf16_gemm_4x4c8__neonbf16_bfdot(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, 4, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16); + } + static void bf16_gemm_5x4c8__neonbf16_bfdot(benchmark::State& state, const 
char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, 5, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16); + } + + static void bf16_gemm_1x4c8__neonbf16_bfmlal(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, 1, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16); + } + static void bf16_gemm_2x4c8__neonbf16_bfmlal(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, 2, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16); + } + static void bf16_gemm_3x4c8__neonbf16_bfmlal(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, 3, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16); + } + static void bf16_gemm_4x4c8__neonbf16_bfmlal(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, 4, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16); + } + static void bf16_gemm_5x4c8__neonbf16_bfmlal(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, 5, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16); + } + + BENCHMARK_GEMM(bf16_gemm_1x8c2__neonbf16_bfdot_lane_ld128) + BENCHMARK_GEMM(bf16_gemm_4x8c2__neonbf16_bfdot_lane_ld128) + BENCHMARK_GEMM(bf16_gemm_5x8c2__neonbf16_bfdot_lane_ld128) + BENCHMARK_GEMM(bf16_gemm_6x8c2__neonbf16_bfdot_lane_ld128) + + BENCHMARK_GEMM(bf16_gemm_1x4c8__neonbf16_bfdot) + BENCHMARK_GEMM(bf16_gemm_2x4c8__neonbf16_bfdot) + BENCHMARK_GEMM(bf16_gemm_3x4c8__neonbf16_bfdot) + BENCHMARK_GEMM(bf16_gemm_4x4c8__neonbf16_bfdot) + BENCHMARK_GEMM(bf16_gemm_5x4c8__neonbf16_bfdot) + + BENCHMARK_GEMM(bf16_gemm_1x4c8__neonbf16_bfmlal) + BENCHMARK_GEMM(bf16_gemm_2x4c8__neonbf16_bfmlal) + BENCHMARK_GEMM(bf16_gemm_3x4c8__neonbf16_bfmlal) + BENCHMARK_GEMM(bf16_gemm_4x4c8__neonbf16_bfmlal) + BENCHMARK_GEMM(bf16_gemm_5x4c8__neonbf16_bfmlal) +#endif // XNN_ENABLE_ARM_FP16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + +#if XNN_ARCH_ARM64 + static void bf16_gemm_1x4c8__neonfma_zip(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, 1, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA); + } + static void bf16_gemm_2x4c8__neonfma_zip(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, 2, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA); + } + static void bf16_gemm_3x4c8__neonfma_zip(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, 3, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA); + } + static void bf16_gemm_4x4c8__neonfma_zip(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, 4, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA); + } + static void bf16_gemm_5x4c8__neonfma_zip(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, 5, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA); + } + + 
BENCHMARK_GEMM(bf16_gemm_1x4c8__neonfma_zip) + BENCHMARK_GEMM(bf16_gemm_2x4c8__neonfma_zip) + BENCHMARK_GEMM(bf16_gemm_3x4c8__neonfma_zip) + BENCHMARK_GEMM(bf16_gemm_4x4c8__neonfma_zip) + BENCHMARK_GEMM(bf16_gemm_5x4c8__neonfma_zip) +#endif // XNN_ARCH_ARM64 + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 + static void bf16_gemm_1x4c8__neonfma_shland(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, 1, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA); + } + static void bf16_gemm_2x4c8__neonfma_shland(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, 2, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA); + } + static void bf16_gemm_3x4c8__neonfma_shland(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, 3, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA); + } + static void bf16_gemm_4x4c8__neonfma_shland(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, 4, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA); + } + static void bf16_gemm_5x4c8__neonfma_shland(benchmark::State& state, const char* net) { + GEMMBenchmark(state, xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, 5, 4, 8, 1, + xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(bf16_gemm_1x4c8__neonfma_shland) + BENCHMARK_GEMM(bf16_gemm_2x4c8__neonfma_shland) + BENCHMARK_GEMM(bf16_gemm_3x4c8__neonfma_shland) + BENCHMARK_GEMM(bf16_gemm_4x4c8__neonfma_shland) + BENCHMARK_GEMM(bf16_gemm_5x4c8__neonfma_shland) +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + +#ifndef XNNPACK_BENCHMARK_NO_MAIN +BENCHMARK_MAIN(); +#endif diff --git a/bench/utils.cc b/bench/utils.cc index f43b762b3e9f..9797b328e58f 100644 --- a/bench/utils.cc +++ b/bench/utils.cc @@ -223,6 +223,14 @@ bool CheckNEONFP16ARITH(benchmark::State& state) { return true; } +bool CheckNEONBF16(benchmark::State& state) { + if (!cpuinfo_initialize() || !cpuinfo_has_arm_neon_bf16()) { + state.SkipWithError("no NEON-BF16 extension"); + return false; + } + return true; +} + bool CheckNEONDOT(benchmark::State& state) { if (!cpuinfo_initialize() || !cpuinfo_has_arm_neon_dot()) { state.SkipWithError("no NEON-DOT extension"); diff --git a/bench/utils.h b/bench/utils.h index 47ec5a22b78a..6c0a08c39f34 100644 --- a/bench/utils.h +++ b/bench/utils.h @@ -104,6 +104,10 @@ bool CheckNEONV8(benchmark::State& state); // If NEON-FP16-ARITH is unsupported, report error in benchmark state, and return false. bool CheckNEONFP16ARITH(benchmark::State& state); +// Check if ARM NEON-BF16 extension is supported. +// If NEON-BF16 is unsupported, report error in benchmark state, and return false. +bool CheckNEONBF16(benchmark::State& state); + // Check if ARM DOT extension is supported. // If DOT is unsupported, report error in benchmark state, and return false. 
bool CheckNEONDOT(benchmark::State& state); diff --git a/cmake/DownloadCLog.cmake b/cmake/DownloadCLog.cmake index 446f655b2c05..72a9d7393be3 100644 --- a/cmake/DownloadCLog.cmake +++ b/cmake/DownloadCLog.cmake @@ -12,8 +12,8 @@ PROJECT(clog-download NONE) INCLUDE(ExternalProject) ExternalProject_Add(clog - URL https://github.com/pytorch/cpuinfo/archive/d5e37adf1406cf899d7d9ec1d317c47506ccb970.tar.gz - URL_HASH SHA256=3f2dc1970f397a0e59db72f9fca6ff144b216895c1d606f6c94a507c1e53a025 + URL https://github.com/pytorch/cpuinfo/archive/49610f89b8b1eb52d75d1eda7a2c40c1e86a78e7.zip + URL_HASH SHA256=25843b5f21c32cba89f9b921c0500ab5cd0c2cb8fb0f345e5b5e4678329386c7 SOURCE_DIR "${CMAKE_BINARY_DIR}/clog-source" BINARY_DIR "${CMAKE_BINARY_DIR}/clog" CONFIGURE_COMMAND "" diff --git a/cmake/DownloadCpuinfo.cmake b/cmake/DownloadCpuinfo.cmake index a274c14194d9..92abfd906b08 100644 --- a/cmake/DownloadCpuinfo.cmake +++ b/cmake/DownloadCpuinfo.cmake @@ -12,8 +12,8 @@ PROJECT(cpuinfo-download NONE) INCLUDE(ExternalProject) ExternalProject_Add(cpuinfo - URL https://github.com/pytorch/cpuinfo/archive/5916273f79a21551890fd3d56fc5375a78d1598d.zip - URL_HASH SHA256=2a160c527d3c58085ce260f34f9e2b161adc009b34186a2baf24e74376e89e6d + URL https://github.com/pytorch/cpuinfo/archive/49610f89b8b1eb52d75d1eda7a2c40c1e86a78e7.zip + URL_HASH SHA256=25843b5f21c32cba89f9b921c0500ab5cd0c2cb8fb0f345e5b5e4678329386c7 SOURCE_DIR "${CMAKE_BINARY_DIR}/cpuinfo-source" BINARY_DIR "${CMAKE_BINARY_DIR}/cpuinfo" CONFIGURE_COMMAND "" diff --git a/scripts/build-android-armv7.sh b/scripts/build-android-armv7.sh index faedf781c78f..5942920cd43e 100755 --- a/scripts/build-android-armv7.sh +++ b/scripts/build-android-armv7.sh @@ -55,6 +55,9 @@ CMAKE_ARGS+=("-DANDROID_PIE=ON") CMAKE_ARGS+=("-DANDROID_STL=c++_static") CMAKE_ARGS+=("-DANDROID_CPP_FEATURES=exceptions") +# BF16 instructions cause ICE in Android NDK compiler +CMAKE_ARGS+=("-DXNNPACK_ENABLE_ARM_BF16=OFF") + # Use-specified CMake arguments go last to allow overridding defaults CMAKE_ARGS+=($@) diff --git a/scripts/generate-bf16-gemm.sh b/scripts/generate-bf16-gemm.sh new file mode 100755 index 000000000000..f221062714c2 --- /dev/null +++ b/scripts/generate-bf16-gemm.sh @@ -0,0 +1,41 @@ +#!/bin/sh +# Copyright 2022 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
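Every microkernel instantiated by the generator script and templates that follow implements the same contract: one call produces an mr-by-nc tile of C, accumulating kc bfloat16 products per output element in float32, starting from the per-column bias stored at the head of the packed weights, clamping the result to the [min, max] range in the params struct, and storing it back as bfloat16. The plain-C reference below is an illustrative sketch only, with invented names; it uses element counts (the real kernels take kc and the strides in bytes) and ignores the packed-weight layout.

  #include <stddef.h>
  #include <stdint.h>

  static float bf16_to_f32(uint16_t h) {   /* bfloat16 is the high half of binary32 */
    union { uint32_t u; float f; } v = { (uint32_t) h << 16 };
    return v.f;
  }

  /* Narrowing by truncation; the NEONFMA variants truncate like this, while the
   * NEONBF16 variants round to nearest even via vcvt_bf16_f32. */
  static uint16_t f32_to_bf16(float x) {
    union { float f; uint32_t u; } v = { x };
    return (uint16_t) (v.u >> 16);
  }

  /* Reference for one call: c[m][n] = clamp(bias[n] + sum_k a[m][k] * b[k][n], min, max). */
  static void bf16_gemm_minmax_ref(
      size_t mr, size_t nc, size_t kc,
      const uint16_t* a,     /* mr x kc, row-major, bfloat16 bits */
      const uint16_t* b,     /* kc x nc, row-major, bfloat16 bits */
      const uint16_t* bias,  /* nc bfloat16 bias values */
      uint16_t* c,           /* mr x nc, row-major, bfloat16 bits */
      float min, float max)
  {
    for (size_t m = 0; m < mr; m++) {
      for (size_t n = 0; n < nc; n++) {
        float acc = bf16_to_f32(bias[n]);
        for (size_t k = 0; k < kc; k++) {
          acc += bf16_to_f32(a[m * kc + k]) * bf16_to_f32(b[k * nc + n]);
        }
        if (acc > max) { acc = max; }
        if (acc < min) { acc = min; }
        c[m * nc + n] = f32_to_bf16(acc);
      }
    }
  }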
+ +################################### ARM NEON ################################## +### LD128 micro-kernels +tools/xngen src/bf16-gemm/c2-neonbf16-bfdot-lane-ld128.c.in -D MR=1 -D NR=8 -o src/bf16-gemm/gen/1x8c2-minmax-neonbf16-bfdot-lane-ld128.c & +tools/xngen src/bf16-gemm/c2-neonbf16-bfdot-lane-ld128.c.in -D MR=4 -D NR=8 -o src/bf16-gemm/gen/4x8c2-minmax-neonbf16-bfdot-lane-ld128.c & +tools/xngen src/bf16-gemm/c2-neonbf16-bfdot-lane-ld128.c.in -D MR=5 -D NR=8 -o src/bf16-gemm/gen/5x8c2-minmax-neonbf16-bfdot-lane-ld128.c & +tools/xngen src/bf16-gemm/c2-neonbf16-bfdot-lane-ld128.c.in -D MR=6 -D NR=8 -o src/bf16-gemm/gen/6x8c2-minmax-neonbf16-bfdot-lane-ld128.c & + +tools/xngen src/bf16-gemm/c8-neon.c.in -D MR=1 -D NR=4 -D EXTOPT=SHLAND -o src/bf16-gemm/gen/1x4c8-minmax-neonfma-shland.c & +tools/xngen src/bf16-gemm/c8-neon.c.in -D MR=2 -D NR=4 -D EXTOPT=SHLAND -o src/bf16-gemm/gen/2x4c8-minmax-neonfma-shland.c & +tools/xngen src/bf16-gemm/c8-neon.c.in -D MR=3 -D NR=4 -D EXTOPT=SHLAND -o src/bf16-gemm/gen/3x4c8-minmax-neonfma-shland.c & +tools/xngen src/bf16-gemm/c8-neon.c.in -D MR=4 -D NR=4 -D EXTOPT=SHLAND -o src/bf16-gemm/gen/4x4c8-minmax-neonfma-shland.c & +tools/xngen src/bf16-gemm/c8-neon.c.in -D MR=5 -D NR=4 -D EXTOPT=SHLAND -o src/bf16-gemm/gen/5x4c8-minmax-neonfma-shland.c & + +tools/xngen src/bf16-gemm/c8-neon.c.in -D MR=1 -D NR=4 -D EXTOPT=ZIP -o src/bf16-gemm/gen/1x4c8-minmax-neonfma-zip.c & +tools/xngen src/bf16-gemm/c8-neon.c.in -D MR=2 -D NR=4 -D EXTOPT=ZIP -o src/bf16-gemm/gen/2x4c8-minmax-neonfma-zip.c & +tools/xngen src/bf16-gemm/c8-neon.c.in -D MR=3 -D NR=4 -D EXTOPT=ZIP -o src/bf16-gemm/gen/3x4c8-minmax-neonfma-zip.c & +tools/xngen src/bf16-gemm/c8-neon.c.in -D MR=4 -D NR=4 -D EXTOPT=ZIP -o src/bf16-gemm/gen/4x4c8-minmax-neonfma-zip.c & +tools/xngen src/bf16-gemm/c8-neon.c.in -D MR=5 -D NR=4 -D EXTOPT=ZIP -o src/bf16-gemm/gen/5x4c8-minmax-neonfma-zip.c & + +tools/xngen src/bf16-gemm/c8-neonbf16.c.in -D MR=1 -D NR=4 -D BFOPT=BFDOT -o src/bf16-gemm/gen/1x4c8-minmax-neonbf16-bfdot.c & +tools/xngen src/bf16-gemm/c8-neonbf16.c.in -D MR=2 -D NR=4 -D BFOPT=BFDOT -o src/bf16-gemm/gen/2x4c8-minmax-neonbf16-bfdot.c & +tools/xngen src/bf16-gemm/c8-neonbf16.c.in -D MR=3 -D NR=4 -D BFOPT=BFDOT -o src/bf16-gemm/gen/3x4c8-minmax-neonbf16-bfdot.c & +tools/xngen src/bf16-gemm/c8-neonbf16.c.in -D MR=4 -D NR=4 -D BFOPT=BFDOT -o src/bf16-gemm/gen/4x4c8-minmax-neonbf16-bfdot.c & +tools/xngen src/bf16-gemm/c8-neonbf16.c.in -D MR=5 -D NR=4 -D BFOPT=BFDOT -o src/bf16-gemm/gen/5x4c8-minmax-neonbf16-bfdot.c & + +tools/xngen src/bf16-gemm/c8-neonbf16.c.in -D MR=1 -D NR=4 -D BFOPT=BFMLAL -o src/bf16-gemm/gen/1x4c8-minmax-neonbf16-bfmlal.c & +tools/xngen src/bf16-gemm/c8-neonbf16.c.in -D MR=2 -D NR=4 -D BFOPT=BFMLAL -o src/bf16-gemm/gen/2x4c8-minmax-neonbf16-bfmlal.c & +tools/xngen src/bf16-gemm/c8-neonbf16.c.in -D MR=3 -D NR=4 -D BFOPT=BFMLAL -o src/bf16-gemm/gen/3x4c8-minmax-neonbf16-bfmlal.c & +tools/xngen src/bf16-gemm/c8-neonbf16.c.in -D MR=4 -D NR=4 -D BFOPT=BFMLAL -o src/bf16-gemm/gen/4x4c8-minmax-neonbf16-bfmlal.c & +tools/xngen src/bf16-gemm/c8-neonbf16.c.in -D MR=5 -D NR=4 -D BFOPT=BFMLAL -o src/bf16-gemm/gen/5x4c8-minmax-neonbf16-bfmlal.c & + +################################## Unit tests ################################# +tools/generate-gemm-test.py --spec test/bf16-gemm-minmax.yaml --output test/bf16-gemm-minmax.cc & + +wait diff --git a/src/bf16-gemm/c2-neonbf16-bfdot-lane-ld128.c.in b/src/bf16-gemm/c2-neonbf16-bfdot-lane-ld128.c.in new file mode 100644 index 000000000000..9bd0196e3e02 
--- /dev/null +++ b/src/bf16-gemm/c2-neonbf16-bfdot-lane-ld128.c.in @@ -0,0 +1,197 @@ +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +$assert NR % 4 == 0 +$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" + +#include + +#include + +#include + + +void xnn_bf16_gemm_minmax_ukernel_${MR}x${NR}c2__neonbf16_bfdot_lane_ld128( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= ${MR}); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(bfloat16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const bfloat16_t* a0 = (const bfloat16_t*) a; + bfloat16_t* c0 = (bfloat16_t*) c; + $for M in range(1, MR): + const bfloat16_t* a${M} = (const bfloat16_t*) ((uintptr_t) a${M-1} + a_stride); + bfloat16_t* c${M} = (bfloat16_t*) ((uintptr_t) c${M-1} + cm_stride); + $if M % 2 == 0: + if XNN_UNPREDICTABLE(mr <= ${M}) { + a${M} = a${M-1}; + c${M} = c${M-1}; + } + $elif M + 1 == MR: + if XNN_UNPREDICTABLE(mr != ${M+1}) { + a${M} = a${M-1}; + c${M} = c${M-1}; + } + $else: + if XNN_UNPREDICTABLE(mr < ${M+1}) { + a${M} = a${M-1}; + c${M} = c${M-1}; + } + + const bfloat16_t* w = (const bfloat16_t*) w_ptr; + do { + $for N in range(0, NR, 4): + float32x4_t vacc0x${ABC[N:N+4]} = vcvt_f32_bf16(vld1_bf16(w)); w += 4; + $for M in range(1, MR): + $for N in range(0, NR, 4): + float32x4_t vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]}; + + size_t k = kc; + for (; k >= 8 * sizeof(bfloat16_t); k -= 8 * sizeof(bfloat16_t)) { + $for M in range(MR): + const bfloat16x8_t va${M} = vld1q_bf16(a${M}); a${M} += 8; + + $for K in range(4): + $for N in range(0, NR, 4): + const bfloat16x8_t vb${ABC[N:N+4]}c${ABC[2*K:2*K+2]} = vld1q_bf16(w); w += 8; + + $for N in range(0, NR, 4): + $for M in range(MR): + vacc${M}x${ABC[N:N+4]} = vbfdotq_laneq_f32(vacc${M}x${ABC[N:N+4]}, vb${ABC[N:N+4]}c${ABC[2*K:2*K+2]}, va${M}, ${K}); + } + if XNN_UNLIKELY(k != 0) { + $for M in range(MR): + const bfloat16x8_t va${M} = vld1q_bf16(a${M}); a${M} = (const bfloat16_t*) ((uintptr_t) a${M} + k); + + $for N in range(0, NR, 4): + const bfloat16x8_t vb${ABC[N:N+4]}c${ABC[0:2]} = vld1q_bf16(w); w += 8; + + $for M in range(MR): + const uint32x4_t va${M}c${ABC[0:2]} = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va${M})), 0); + + $for N in range(0, NR, 4): + const uint32x4_t vm${ABC[N:N+4]}c${ABC[0:2]} = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb${ABC[N:N+4]}c${ABC[0:2]}), vmovq_n_u16(0))); + + $for N in range(0, NR, 4): + $for M in range(MR): + const uint32x4_t va${M}x${ABC[N:N+4]}c${ABC[0:2]} = vbicq_u32(va${M}c${ABC[0:2]}, vm${ABC[N:N+4]}c${ABC[0:2]}); + vacc${M}x${ABC[N:N+4]} = vbfdotq_f32(vacc${M}x${ABC[N:N+4]}, vb${ABC[N:N+4]}c${ABC[0:2]}, vreinterpretq_bf16_u32(va${M}x${ABC[N:N+4]}c${ABC[0:2]})); + + if (k > 2 * sizeof(bfloat16_t)) { + $for N in range(0, NR, 4): + const bfloat16x8_t vb${ABC[N:N+4]}c${ABC[2:4]} = vld1q_bf16(w); w += 8; + + $for M in range(MR): + const uint32x4_t va${M}c${ABC[2:4]} = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va${M})), 1); + + $for N in range(0, NR, 4): + const uint32x4_t vm${ABC[N:N+4]}c${ABC[2:4]} = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb${ABC[N:N+4]}c${ABC[2:4]}), vmovq_n_u16(0))); + + $for N in 
range(0, NR, 4): + $for M in range(MR): + const uint32x4_t va${M}x${ABC[N:N+4]}c${ABC[2:4]} = vbicq_u32(va${M}c${ABC[2:4]}, vm${ABC[N:N+4]}c${ABC[2:4]}); + vacc${M}x${ABC[N:N+4]} = vbfdotq_f32(vacc${M}x${ABC[N:N+4]}, vb${ABC[N:N+4]}c${ABC[2:4]}, vreinterpretq_bf16_u32(va${M}x${ABC[N:N+4]}c${ABC[2:4]})); + + if (k > 4 * sizeof(bfloat16_t)) { + $for N in range(0, NR, 4): + const bfloat16x8_t vb${ABC[N:N+4]}c${ABC[4:6]} = vld1q_bf16(w); w += 8; + + $for M in range(MR): + const uint32x4_t va${M}c${ABC[4:6]} = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va${M})), 0); + + $for N in range(0, NR, 4): + const uint32x4_t vm${ABC[N:N+4]}c${ABC[4:6]} = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb${ABC[N:N+4]}c${ABC[4:6]}), vmovq_n_u16(0))); + + $for N in range(0, NR, 4): + $for M in range(MR): + const uint32x4_t va${M}x${ABC[N:N+4]}c${ABC[4:6]} = vbicq_u32(va${M}c${ABC[4:6]}, vm${ABC[N:N+4]}c${ABC[4:6]}); + vacc${M}x${ABC[N:N+4]} = vbfdotq_f32(vacc${M}x${ABC[N:N+4]}, vb${ABC[N:N+4]}c${ABC[4:6]}, vreinterpretq_bf16_u32(va${M}x${ABC[N:N+4]}c${ABC[4:6]})); + + if (k > 6 * sizeof(bfloat16_t)) { + $for N in range(0, NR, 4): + const bfloat16x8_t vb${ABC[N:N+4]}c${ABC[6:8]} = vld1q_bf16(w); w += 8; + + $for M in range(MR): + const uint32x4_t va${M}c${ABC[6:8]} = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va${M})), 1); + + $for N in range(0, NR, 4): + const uint32x4_t vm${ABC[N:N+4]}c${ABC[6:8]} = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb${ABC[N:N+4]}c${ABC[6:8]}), vmovq_n_u16(0))); + + $for N in range(0, NR, 4): + $for M in range(MR): + const uint32x4_t va${M}x${ABC[N:N+4]}c${ABC[6:8]} = vbicq_u32(va${M}c${ABC[6:8]}, vm${ABC[N:N+4]}c${ABC[6:8]}); + vacc${M}x${ABC[N:N+4]} = vbfdotq_f32(vacc${M}x${ABC[N:N+4]}, vb${ABC[N:N+4]}c${ABC[6:8]}, vreinterpretq_bf16_u32(va${M}x${ABC[N:N+4]}c${ABC[6:8]})); + } + } + } + } + + const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); + $for N in range(0, NR, 4): + $for M in range(MR): + vacc${M}x${ABC[N:N+4]} = vminq_f32(vacc${M}x${ABC[N:N+4]}, vmax); + + const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); + $for N in range(0, NR, 4): + $for M in range(MR): + vacc${M}x${ABC[N:N+4]} = vmaxq_f32(vacc${M}x${ABC[N:N+4]}, vmin); + + $for N in range(0, NR, 4): + $for M in range(MR): + bfloat16x4_t vout${M}x${ABC[N:N+4]} = vcvt_bf16_f32(vacc${M}x${ABC[N:N+4]}); + + if XNN_LIKELY(nc >= ${NR}) { + $for M in range(MR): + vst1_bf16(c${M}, vout${M}x${ABC[0:4]}); + $for N in range(4, NR, 4): + vst1_bf16(c${M} + ${N}, vout${M}x${ABC[N:N+4]}); + c${M} = (bfloat16_t*) ((uintptr_t) c${M} + cn_stride); + + $for M in range(MR): + a${M} = (const bfloat16_t*) ((uintptr_t) a${M} - kc); + + nc -= ${NR}; + } else { + $for LOG2N in reversed(range(NR.bit_length())): + $if NR != 1 << LOG2N: + if (nc & ${1 << LOG2N}) { + $if LOG2N >= 2: + $for N in range(0, 1 << LOG2N, 4): + $for M in range(MR): + vst1_bf16(c${M}, vout${M}x${ABC[N:N+4]}); c${M} += 4; + + $for M in range(MR): + $for N in range(0, 1 << (LOG2N - 1), 4): + vout${M}x${ABC[N:N+4]} = vout${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+4]}; + $elif LOG2N == 1: + $for M in range(MR): + vst1_lane_u32((void*) c${M}, vreinterpret_u32_bf16(vout${M}x${ABC[0:4]}), 0); c${M} += 2; + + $for M in range(MR): + vout${M}x${ABC[0:4]} = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout${M}x${ABC[0:4]}), vreinterpret_u16_bf16(vout${M}x${ABC[0:4]}), 2)); + $elif LOG2N == 0: + $for M in range(MR): + vst1_lane_bf16(c${M}, vout${M}x${ABC[0:4]}, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git 
a/src/bf16-gemm/c8-neon.c.in b/src/bf16-gemm/c8-neon.c.in new file mode 100644 index 000000000000..6fbf68645863 --- /dev/null +++ b/src/bf16-gemm/c8-neon.c.in @@ -0,0 +1,229 @@ +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +$assert NR % 4 == 0 +$assert EXTOPT in ["SHLAND", "ZIP", "MOVL"] +$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" + +#include + +#include + +#include + + +void xnn_bf16_gemm_minmax_ukernel_${MR}x${NR}c8__neonfma_${EXTOPT.lower()}( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= ${MR}); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(uint16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const uint16_t* a0 = (const uint16_t*) a; + uint16_t* c0 = (uint16_t*) c; + $for M in range(1, MR): + const uint16_t* a${M} = (const uint16_t*) ((uintptr_t) a${M-1} + a_stride); + uint16_t* c${M} = (uint16_t*) ((uintptr_t) c${M-1} + cm_stride); + $if M % 2 == 0: + if XNN_UNPREDICTABLE(mr <= ${M}) { + a${M} = a${M-1}; + c${M} = c${M-1}; + } + $elif M + 1 == MR: + if XNN_UNPREDICTABLE(mr != ${M+1}) { + a${M} = a${M-1}; + c${M} = c${M-1}; + } + $else: + if XNN_UNPREDICTABLE(mr < ${M+1}) { + a${M} = a${M-1}; + c${M} = c${M-1}; + } + + const uint16_t* w = (const uint16_t*) w_ptr; + $if EXTOPT == "SHLAND": + const uint16x8_t vmask = vreinterpretq_u16_u32(vmovq_n_u32(UINT32_C(0xFFFF0000))); + $elif EXTOPT == "ZIP": + const uint16x8_t vzero = vmovq_n_u16(0); + do { + $for N in range(NR): + float32x4_t vacc0x${ABC[N]} = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + $for M in range(1, MR): + $for N in range(NR): + float32x4_t vacc${M}x${ABC[N]} = vacc0x${ABC[N]}; + + size_t k = kc; + for (; k >= 8 * sizeof(uint16_t); k -= 8 * sizeof(uint16_t)) { + $for M in range(MR): + const uint16x8_t va${M} = vld1q_u16(a${M}); a${M} += 8; + + $for N in range(NR): + const uint16x8_t vb${ABC[N]} = vld1q_u16(w); w += 8; + + $for M in range(MR): + $if EXTOPT == "SHLAND": + const float32x4_t va${M}e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va${M}), 16)); + $elif EXTOPT == "ZIP": + const float32x4_t va${M}e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va${M})); + + $for N in range(NR): + $if EXTOPT == "SHLAND": + const float32x4_t vb${ABC[N]}e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb${ABC[N]}), 16)); + $elif EXTOPT == "ZIP": + const float32x4_t vb${ABC[N]}e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb${ABC[N]})); + + $for N in range(NR): + $for M in range(MR): + vacc${M}x${ABC[N]} = vfmaq_f32(vacc${M}x${ABC[N]}, va${M}e, vb${ABC[N]}e); + + $for M in range(MR): + $if EXTOPT == "SHLAND": + const float32x4_t va${M}o = vreinterpretq_f32_u16(vandq_u16(va${M}, vmask)); + $elif EXTOPT == "ZIP": + const float32x4_t va${M}o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va${M})); + + $for N in range(NR): + $if EXTOPT == "SHLAND": + const float32x4_t vb${ABC[N]}o = vreinterpretq_f32_u16(vandq_u16(vb${ABC[N]}, vmask)); + $elif EXTOPT == "ZIP": + const float32x4_t vb${ABC[N]}o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb${ABC[N]})); + + $for N in range(NR): + $for M in range(MR): + vacc${M}x${ABC[N]} = vfmaq_f32(vacc${M}x${ABC[N]}, va${M}o, vb${ABC[N]}o); + } + if 
XNN_UNLIKELY(k != 0) { + $for M in range(MR): + const uint16x8_t va${M} = vld1q_u16(a${M}); a${M} = (const uint16_t*) ((uintptr_t) a${M} + k); + + $for N in range(NR): + const uint16x8_t vb${ABC[N]} = vld1q_u16(w); w += 8; + + $for N in range(NR): + const uint16x8_t vm${ABC[N]} = vceqq_u16(vb${ABC[N]}, vmovq_n_u16(0)); + + $for N in range(NR): + $if EXTOPT == "SHLAND": + const float32x4_t vb${ABC[N]}e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb${ABC[N]}), 16)); + $elif EXTOPT == "ZIP": + const float32x4_t vb${ABC[N]}e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb${ABC[N]})); + + $for N in range(NR): + $for M in range(MR): + const uint16x8_t va${M}x${ABC[N]} = vbicq_u16(va${M}, vm${ABC[N]}); + + $for N in range(NR): + $for M in range(MR): + $if EXTOPT == "SHLAND": + const float32x4_t va${M}x${ABC[N]}e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va${M}x${ABC[N]}), 16)); + $elif EXTOPT == "ZIP": + const float32x4_t va${M}x${ABC[N]}e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va${M}x${ABC[N]})); + + $for N in range(NR): + $for M in range(MR): + vacc${M}x${ABC[N]} = vfmaq_f32(vacc${M}x${ABC[N]}, va${M}x${ABC[N]}e, vb${ABC[N]}e); + + $for N in range(NR): + $if EXTOPT == "SHLAND": + const float32x4_t vb${ABC[N]}o = vreinterpretq_f32_u16(vandq_u16(vb${ABC[N]}, vmask)); + $elif EXTOPT == "ZIP": + const float32x4_t vb${ABC[N]}o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb${ABC[N]})); + + $for N in range(NR): + $for M in range(MR): + $if EXTOPT == "SHLAND": + const float32x4_t va${M}x${ABC[N]}o = vreinterpretq_f32_u16(vandq_u16(va${M}x${ABC[N]}, vmask)); + $elif EXTOPT == "ZIP": + const float32x4_t va${M}x${ABC[N]}o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va${M}x${ABC[N]})); + + $for N in range(NR): + $for M in range(MR): + vacc${M}x${ABC[N]} = vfmaq_f32(vacc${M}x${ABC[N]}, va${M}x${ABC[N]}o, vb${ABC[N]}o); + } + +#if XNN_ARCH_ARM64 + $for N in range(0, NR, 2): + $for M in range(MR): + const float32x4_t vacc${M}x${ABC[N:N+2]} = vpaddq_f32(vacc${M}x${ABC[N]}, vacc${M}x${ABC[N+1]}); + + $for N in range(0, NR, 4): + $for M in range(MR): + float32x4_t vacc${M}x${ABC[N:N+4]} = vpaddq_f32(vacc${M}x${ABC[N:N+2]}, vacc${M}x${ABC[N+2:N+4]}); +#else + $for N in range(NR): + $for M in range(MR): + const float32x2_t vsum${M}x${ABC[N]} = vadd_f32(vget_low_f32(vacc${M}x${ABC[N]}), vget_high_f32(vacc${M}x${ABC[N]})); + + $for N in range(0, NR, 4): + $for M in range(MR): + float32x4_t vacc${M}x${ABC[N:N+4]} = vcombine_f32(vpadd_f32(vsum${M}x${ABC[N]}, vsum${M}x${ABC[N+1]}), vpadd_f32(vsum${M}x${ABC[N+2]}, vsum${M}x${ABC[N+3]})); +#endif + + const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); + $for N in range(0, NR, 4): + $for M in range(MR): + vacc${M}x${ABC[N:N+4]} = vminq_f32(vacc${M}x${ABC[N:N+4]}, vmax); + + const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); + $for N in range(0, NR, 4): + $for M in range(MR): + vacc${M}x${ABC[N:N+4]} = vmaxq_f32(vacc${M}x${ABC[N:N+4]}, vmin); + + $for N in range(0, NR, 4): + $for M in range(MR): + uint16x4_t vout${M}x${ABC[N:N+4]} = vshrn_n_u32(vreinterpretq_u32_f32(vacc${M}x${ABC[N:N+4]}), 16); + + if XNN_LIKELY(nc >= ${NR}) { + $for M in range(MR): + vst1_u16(c${M}, vout${M}x${ABC[0:4]}); + $for N in range(4, NR, 4): + vst1_u16(c${M} + ${N}, vout${M}x${ABC[N:N+4]}); + c${M} = (uint16_t*) ((uintptr_t) c${M} + cn_stride); + + $for M in range(MR): + a${M} = (const uint16_t*) ((uintptr_t) a${M} - kc); + + nc -= ${NR}; + } else { + $for LOG2N in reversed(range(NR.bit_length())): + $if NR != 1 << LOG2N: + if (nc & ${1 << LOG2N}) { + $if 
LOG2N >= 2: + $for N in range(0, 1 << LOG2N, 4): + $for M in range(MR): + vst1_u16(c${M}, vout${M}x${ABC[N:N+4]}); c${M} += 4; + + $for M in range(MR): + $for N in range(0, 1 << (LOG2N - 1), 4): + vout${M}x${ABC[N:N+4]} = vout${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+4]}; + $elif LOG2N == 1: + $for M in range(MR): + vst1_lane_u32((void*) c${M}, vreinterpret_u32_u16(vout${M}x${ABC[0:4]}), 0); c${M} += 2; + + $for M in range(MR): + vout${M}x${ABC[0:4]} = vext_u16(vout${M}x${ABC[0:4]}, vout${M}x${ABC[0:4]}, 2); + $elif LOG2N == 0: + $for M in range(MR): + vst1_lane_u16(c${M}, vout${M}x${ABC[0:4]}, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/c8-neonbf16.c.in b/src/bf16-gemm/c8-neonbf16.c.in new file mode 100644 index 000000000000..19a881a2e4d6 --- /dev/null +++ b/src/bf16-gemm/c8-neonbf16.c.in @@ -0,0 +1,177 @@ +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +$assert NR % 4 == 0 +$assert BFOPT in ["BFDOT", "BFMLAL"] +$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" + +#include + +#include + +#include + + +void xnn_bf16_gemm_minmax_ukernel_${MR}x${NR}c8__neonbf16_${BFOPT.lower()}( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= ${MR}); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(bfloat16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const bfloat16_t* a0 = (const bfloat16_t*) a; + bfloat16_t* c0 = (bfloat16_t*) c; + $for M in range(1, MR): + const bfloat16_t* a${M} = (const bfloat16_t*) ((uintptr_t) a${M-1} + a_stride); + bfloat16_t* c${M} = (bfloat16_t*) ((uintptr_t) c${M-1} + cm_stride); + $if M % 2 == 0: + if XNN_UNPREDICTABLE(mr <= ${M}) { + a${M} = a${M-1}; + c${M} = c${M-1}; + } + $elif M + 1 == MR: + if XNN_UNPREDICTABLE(mr != ${M+1}) { + a${M} = a${M-1}; + c${M} = c${M-1}; + } + $else: + if XNN_UNPREDICTABLE(mr < ${M+1}) { + a${M} = a${M-1}; + c${M} = c${M-1}; + } + + const bfloat16_t* w = (const bfloat16_t*) w_ptr; + do { + $for N in range(NR): + float32x4_t vacc0x${ABC[N]} = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + $for M in range(1, MR): + $for N in range(NR): + float32x4_t vacc${M}x${ABC[N]} = vacc0x${ABC[N]}; + + size_t k = kc; + for (; k >= 8 * sizeof(bfloat16_t); k -= 8 * sizeof(bfloat16_t)) { + $for M in range(MR): + const bfloat16x8_t va${M} = vld1q_bf16(a${M}); a${M} += 8; + + $for N in range(NR): + const bfloat16x8_t vb${ABC[N]} = vld1q_bf16(w); w += 8; + + $if BFOPT == "BFDOT": + $for N in range(NR): + $for M in range(MR): + vacc${M}x${ABC[N]} = vbfdotq_f32(vacc${M}x${ABC[N]}, va${M}, vb${ABC[N]}); + $elif BFOPT == "BFMLAL": + $for N in range(NR): + $for M in range(MR): + vacc${M}x${ABC[N]} = vbfmlalbq_f32(vacc${M}x${ABC[N]}, va${M}, vb${ABC[N]}); + + $for N in range(NR): + $for M in range(MR): + vacc${M}x${ABC[N]} = vbfmlaltq_f32(vacc${M}x${ABC[N]}, va${M}, vb${ABC[N]}); + } + if XNN_UNLIKELY(k != 0) { + $for M in range(MR): + const bfloat16x8_t va${M} = vld1q_bf16(a${M}); a${M} = (const bfloat16_t*) ((uintptr_t) a${M} + k); + + $for N in range(NR): + const bfloat16x8_t vb${ABC[N]} = vld1q_bf16(w); w += 8; + + $for N in range(NR): + const uint16x8_t vm${ABC[N]} = 
vceqq_u16(vreinterpretq_u16_bf16(vb${ABC[N]}), vmovq_n_u16(0)); + + $for N in range(NR): + $for M in range(MR): + const bfloat16x8_t va${M}x${ABC[N]} = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va${M}), vm${ABC[N]})); + $if BFOPT == "BFDOT": + vacc${M}x${ABC[N]} = vbfdotq_f32(vacc${M}x${ABC[N]}, va${M}x${ABC[N]}, vb${ABC[N]}); + $elif BFOPT == "BFMLAL": + vacc${M}x${ABC[N]} = vbfmlalbq_f32(vacc${M}x${ABC[N]}, va${M}x${ABC[N]}, vb${ABC[N]}); + vacc${M}x${ABC[N]} = vbfmlaltq_f32(vacc${M}x${ABC[N]}, va${M}x${ABC[N]}, vb${ABC[N]}); + } + +#if XNN_ARCH_ARM64 + $for N in range(0, NR, 2): + $for M in range(MR): + const float32x4_t vacc${M}x${ABC[N:N+2]} = vpaddq_f32(vacc${M}x${ABC[N]}, vacc${M}x${ABC[N+1]}); + + $for N in range(0, NR, 4): + $for M in range(MR): + float32x4_t vacc${M}x${ABC[N:N+4]} = vpaddq_f32(vacc${M}x${ABC[N:N+2]}, vacc${M}x${ABC[N+2:N+4]}); +#else + $for N in range(NR): + $for M in range(MR): + const float32x2_t vsum${M}x${ABC[N]} = vadd_f32(vget_low_f32(vacc${M}x${ABC[N]}), vget_high_f32(vacc${M}x${ABC[N]})); + + $for N in range(0, NR, 4): + $for M in range(MR): + float32x4_t vacc${M}x${ABC[N:N+4]} = vcombine_f32(vpadd_f32(vsum${M}x${ABC[N]}, vsum${M}x${ABC[N+1]}), vpadd_f32(vsum${M}x${ABC[N+2]}, vsum${M}x${ABC[N+3]})); +#endif + + const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); + $for N in range(0, NR, 4): + $for M in range(MR): + vacc${M}x${ABC[N:N+4]} = vminq_f32(vacc${M}x${ABC[N:N+4]}, vmax); + + const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); + $for N in range(0, NR, 4): + $for M in range(MR): + vacc${M}x${ABC[N:N+4]} = vmaxq_f32(vacc${M}x${ABC[N:N+4]}, vmin); + + $for N in range(0, NR, 4): + $for M in range(MR): + bfloat16x4_t vout${M}x${ABC[N:N+4]} = vcvt_bf16_f32(vacc${M}x${ABC[N:N+4]}); + + if XNN_LIKELY(nc >= ${NR}) { + $for M in range(MR): + vst1_bf16(c${M}, vout${M}x${ABC[0:4]}); + $for N in range(4, NR, 4): + vst1_bf16(c${M} + ${N}, vout${M}x${ABC[N:N+4]}); + c${M} = (bfloat16_t*) ((uintptr_t) c${M} + cn_stride); + + $for M in range(MR): + a${M} = (const bfloat16_t*) ((uintptr_t) a${M} - kc); + + nc -= ${NR}; + } else { + $for LOG2N in reversed(range(NR.bit_length())): + $if NR != 1 << LOG2N: + if (nc & ${1 << LOG2N}) { + $if LOG2N >= 2: + $for N in range(0, 1 << LOG2N, 4): + $for M in range(MR): + vst1_bf16(c${M}, vout${M}x${ABC[N:N+4]}); c${M} += 4; + + $for M in range(MR): + $for N in range(0, 1 << (LOG2N - 1), 4): + vout${M}x${ABC[N:N+4]} = vout${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+4]}; + $elif LOG2N == 1: + $for M in range(MR): + vst1_lane_u32((void*) c${M}, vreinterpret_u32_bf16(vout${M}x${ABC[0:4]}), 0); c${M} += 2; + + $for M in range(MR): + vout${M}x${ABC[0:4]} = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout${M}x${ABC[0:4]}), vreinterpret_u16_bf16(vout${M}x${ABC[0:4]}), 2)); + $elif LOG2N == 0: + $for M in range(MR): + vst1_lane_bf16(c${M}, vout${M}x${ABC[0:4]}, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/1x4c8-minmax-neonbf16-bfdot.c b/src/bf16-gemm/gen/1x4c8-minmax-neonbf16-bfdot.c new file mode 100644 index 000000000000..2701b69d5fbf --- /dev/null +++ b/src/bf16-gemm/gen/1x4c8-minmax-neonbf16-bfdot.c @@ -0,0 +1,128 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c8-neonbf16.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
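Note on the bfdot/bfmlal template above and the generated 1x4c8 kernel that follows: each f32 accumulator lane gathers the products of two adjacent bf16 elements from A and B; the BFMLAL variant reaches the same per-lane sum in two steps (bottom halves, then top halves). A minimal scalar sketch, not part of the patch and only an approximation of the hardware instruction's exact rounding, modeling bf16 as a uint16_t holding the upper 16 bits of an IEEE f32:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* Widen a bf16 value (upper 16 bits of an f32) to f32. */
static float bf16_to_f32(uint16_t h) {
  const uint32_t bits = (uint32_t) h << 16;
  float f;
  memcpy(&f, &bits, sizeof(f));
  return f;
}

/* One lane of a 2-way bf16 dot product accumulated in f32,
 * i.e. roughly what vbfdotq_f32 produces per 32-bit lane. */
static float bfdot_lane(float acc, const uint16_t a[2], const uint16_t b[2]) {
  acc += bf16_to_f32(a[0]) * bf16_to_f32(b[0]);
  acc += bf16_to_f32(a[1]) * bf16_to_f32(b[1]);
  return acc;
}

int main(void) {
  const uint16_t a[2] = { 0x3F80, 0x4000 };  /* 1.0, 2.0 in bf16 */
  const uint16_t b[2] = { 0x4040, 0x3F00 };  /* 3.0, 0.5 in bf16 */
  printf("%f\n", bfdot_lane(0.0f, a, b));    /* 1*3 + 2*0.5 = 4.000000 */
  return 0;
}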
+ + +#include + +#include + +#include + + +void xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 1); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(bfloat16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const bfloat16_t* a0 = (const bfloat16_t*) a; + bfloat16_t* c0 = (bfloat16_t*) c; + + const bfloat16_t* w = (const bfloat16_t*) w_ptr; + do { + float32x4_t vacc0x0 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x1 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x2 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x3 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + + size_t k = kc; + for (; k >= 8 * sizeof(bfloat16_t); k -= 8 * sizeof(bfloat16_t)) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 += 8; + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + vacc0x0 = vbfdotq_f32(vacc0x0, va0, vb0); + vacc0x1 = vbfdotq_f32(vacc0x1, va0, vb1); + vacc0x2 = vbfdotq_f32(vacc0x2, va0, vb2); + vacc0x3 = vbfdotq_f32(vacc0x3, va0, vb3); + } + if XNN_UNLIKELY(k != 0) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 = (const bfloat16_t*) ((uintptr_t) a0 + k); + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vreinterpretq_u16_bf16(vb0), vmovq_n_u16(0)); + const uint16x8_t vm1 = vceqq_u16(vreinterpretq_u16_bf16(vb1), vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vreinterpretq_u16_bf16(vb2), vmovq_n_u16(0)); + const uint16x8_t vm3 = vceqq_u16(vreinterpretq_u16_bf16(vb3), vmovq_n_u16(0)); + + const bfloat16x8_t va0x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm0)); + vacc0x0 = vbfdotq_f32(vacc0x0, va0x0, vb0); + const bfloat16x8_t va0x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm1)); + vacc0x1 = vbfdotq_f32(vacc0x1, va0x1, vb1); + const bfloat16x8_t va0x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm2)); + vacc0x2 = vbfdotq_f32(vacc0x2, va0x2, vb2); + const bfloat16x8_t va0x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm3)); + vacc0x3 = vbfdotq_f32(vacc0x3, va0x3, vb3); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + + float32x4_t vacc0x0123 = vpaddq_f32(vacc0x01, vacc0x23); +#else + const float32x2_t vsum0x0 = vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); +#endif + + const float32x4_t 
vmax = vld1q_dup_f32(¶ms->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + + bfloat16x4_t vout0x0123 = vcvt_bf16_f32(vacc0x0123); + + if XNN_LIKELY(nc >= 4) { + vst1_bf16(c0, vout0x0123); + c0 = (bfloat16_t*) ((uintptr_t) c0 + cn_stride); + + a0 = (const bfloat16_t*) ((uintptr_t) a0 - kc); + + nc -= 4; + } else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_bf16(vout0x0123), 0); c0 += 2; + + vout0x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout0x0123), vreinterpret_u16_bf16(vout0x0123), 2)); + } + if (nc & 1) { + vst1_lane_bf16(c0, vout0x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/1x4c8-minmax-neonbf16-bfmlal.c b/src/bf16-gemm/gen/1x4c8-minmax-neonbf16-bfmlal.c new file mode 100644 index 000000000000..5a0b8b61720c --- /dev/null +++ b/src/bf16-gemm/gen/1x4c8-minmax-neonbf16-bfmlal.c @@ -0,0 +1,137 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c8-neonbf16.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include + +#include + +#include + + +void xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 1); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(bfloat16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const bfloat16_t* a0 = (const bfloat16_t*) a; + bfloat16_t* c0 = (bfloat16_t*) c; + + const bfloat16_t* w = (const bfloat16_t*) w_ptr; + do { + float32x4_t vacc0x0 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x1 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x2 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x3 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + + size_t k = kc; + for (; k >= 8 * sizeof(bfloat16_t); k -= 8 * sizeof(bfloat16_t)) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 += 8; + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + vacc0x0 = vbfmlalbq_f32(vacc0x0, va0, vb0); + vacc0x1 = vbfmlalbq_f32(vacc0x1, va0, vb1); + vacc0x2 = vbfmlalbq_f32(vacc0x2, va0, vb2); + vacc0x3 = vbfmlalbq_f32(vacc0x3, va0, vb3); + + vacc0x0 = vbfmlaltq_f32(vacc0x0, va0, vb0); + vacc0x1 = vbfmlaltq_f32(vacc0x1, va0, vb1); + vacc0x2 = vbfmlaltq_f32(vacc0x2, va0, vb2); + vacc0x3 = vbfmlaltq_f32(vacc0x3, va0, vb3); + } + if XNN_UNLIKELY(k != 0) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 = (const bfloat16_t*) ((uintptr_t) a0 + k); + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vreinterpretq_u16_bf16(vb0), vmovq_n_u16(0)); + const uint16x8_t vm1 = 
vceqq_u16(vreinterpretq_u16_bf16(vb1), vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vreinterpretq_u16_bf16(vb2), vmovq_n_u16(0)); + const uint16x8_t vm3 = vceqq_u16(vreinterpretq_u16_bf16(vb3), vmovq_n_u16(0)); + + const bfloat16x8_t va0x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm0)); + vacc0x0 = vbfmlalbq_f32(vacc0x0, va0x0, vb0); + vacc0x0 = vbfmlaltq_f32(vacc0x0, va0x0, vb0); + const bfloat16x8_t va0x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm1)); + vacc0x1 = vbfmlalbq_f32(vacc0x1, va0x1, vb1); + vacc0x1 = vbfmlaltq_f32(vacc0x1, va0x1, vb1); + const bfloat16x8_t va0x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm2)); + vacc0x2 = vbfmlalbq_f32(vacc0x2, va0x2, vb2); + vacc0x2 = vbfmlaltq_f32(vacc0x2, va0x2, vb2); + const bfloat16x8_t va0x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm3)); + vacc0x3 = vbfmlalbq_f32(vacc0x3, va0x3, vb3); + vacc0x3 = vbfmlaltq_f32(vacc0x3, va0x3, vb3); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + + float32x4_t vacc0x0123 = vpaddq_f32(vacc0x01, vacc0x23); +#else + const float32x2_t vsum0x0 = vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); +#endif + + const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + + bfloat16x4_t vout0x0123 = vcvt_bf16_f32(vacc0x0123); + + if XNN_LIKELY(nc >= 4) { + vst1_bf16(c0, vout0x0123); + c0 = (bfloat16_t*) ((uintptr_t) c0 + cn_stride); + + a0 = (const bfloat16_t*) ((uintptr_t) a0 - kc); + + nc -= 4; + } else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_bf16(vout0x0123), 0); c0 += 2; + + vout0x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout0x0123), vreinterpret_u16_bf16(vout0x0123), 2)); + } + if (nc & 1) { + vst1_lane_bf16(c0, vout0x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/1x4c8-minmax-neonfma-shland.c b/src/bf16-gemm/gen/1x4c8-minmax-neonfma-shland.c new file mode 100644 index 000000000000..72bae24e4a3f --- /dev/null +++ b/src/bf16-gemm/gen/1x4c8-minmax-neonfma-shland.c @@ -0,0 +1,174 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c8-neon.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
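The neonfma "shland" kernel that follows has no bf16 arithmetic instructions, so it widens bf16 lanes to f32 by bit manipulation: a 32-bit lane holds two packed bf16 values, the even element becomes an f32 by shifting left 16 bits (vshlq_n_u32), and the odd element by masking with 0xFFFF0000 (vandq_u16 with vmask). A scalar sketch of that split, assuming the little-endian lane packing used on Arm; the sample values are mine:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

static float bits_to_f32(uint32_t bits) {
  float f;
  memcpy(&f, &bits, sizeof(f));
  return f;
}

int main(void) {
  /* One 32-bit lane packing two bf16 values:
   * element 0 = 1.5f (0x3FC0), element 1 = -2.0f (0xC000). */
  const uint32_t lane = ((uint32_t) 0xC000 << 16) | 0x3FC0;

  const float even = bits_to_f32(lane << 16);          /* vshlq_n_u32(lane, 16) */
  const float odd  = bits_to_f32(lane & 0xFFFF0000u);  /* vandq_u16(lane, vmask) */

  printf("even = %f, odd = %f\n", even, odd);  /* 1.500000, -2.000000 */
  return 0;
}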
+ + +#include + +#include + +#include + + +void xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 1); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(uint16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const uint16_t* a0 = (const uint16_t*) a; + uint16_t* c0 = (uint16_t*) c; + + const uint16_t* w = (const uint16_t*) w_ptr; + const uint16x8_t vmask = vreinterpretq_u16_u32(vmovq_n_u32(UINT32_C(0xFFFF0000))); + do { + float32x4_t vacc0x0 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x1 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x2 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x3 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + + size_t k = kc; + for (; k >= 8 * sizeof(uint16_t); k -= 8 * sizeof(uint16_t)) { + const uint16x8_t va0 = vld1q_u16(a0); a0 += 8; + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const float32x4_t va0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0), 16)); + + const float32x4_t vb0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb0), 16)); + const float32x4_t vb1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb1), 16)); + const float32x4_t vb2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb2), 16)); + const float32x4_t vb3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb3), 16)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0e, vb1e); + vacc0x2 = vfmaq_f32(vacc0x2, va0e, vb2e); + vacc0x3 = vfmaq_f32(vacc0x3, va0e, vb3e); + + const float32x4_t va0o = vreinterpretq_f32_u16(vandq_u16(va0, vmask)); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vandq_u16(vb0, vmask)); + const float32x4_t vb1o = vreinterpretq_f32_u16(vandq_u16(vb1, vmask)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vandq_u16(vb2, vmask)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vandq_u16(vb3, vmask)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0o, vb3o); + } + if XNN_UNLIKELY(k != 0) { + const uint16x8_t va0 = vld1q_u16(a0); a0 = (const uint16_t*) ((uintptr_t) a0 + k); + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vb0, vmovq_n_u16(0)); + const uint16x8_t vm1 = vceqq_u16(vb1, vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vb2, vmovq_n_u16(0)); + const uint16x8_t vm3 = vceqq_u16(vb3, vmovq_n_u16(0)); + + const float32x4_t vb0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb0), 16)); + const float32x4_t vb1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb1), 16)); + const float32x4_t vb2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb2), 16)); + const float32x4_t 
vb3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb3), 16)); + + const uint16x8_t va0x0 = vbicq_u16(va0, vm0); + const uint16x8_t va0x1 = vbicq_u16(va0, vm1); + const uint16x8_t va0x2 = vbicq_u16(va0, vm2); + const uint16x8_t va0x3 = vbicq_u16(va0, vm3); + + const float32x4_t va0x0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x0), 16)); + const float32x4_t va0x1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x1), 16)); + const float32x4_t va0x2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x2), 16)); + const float32x4_t va0x3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x3), 16)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1e, vb1e); + vacc0x2 = vfmaq_f32(vacc0x2, va0x2e, vb2e); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3e, vb3e); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vandq_u16(vb0, vmask)); + const float32x4_t vb1o = vreinterpretq_f32_u16(vandq_u16(vb1, vmask)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vandq_u16(vb2, vmask)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vandq_u16(vb3, vmask)); + + const float32x4_t va0x0o = vreinterpretq_f32_u16(vandq_u16(va0x0, vmask)); + const float32x4_t va0x1o = vreinterpretq_f32_u16(vandq_u16(va0x1, vmask)); + const float32x4_t va0x2o = vreinterpretq_f32_u16(vandq_u16(va0x2, vmask)); + const float32x4_t va0x3o = vreinterpretq_f32_u16(vandq_u16(va0x3, vmask)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0x2o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3o, vb3o); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + + float32x4_t vacc0x0123 = vpaddq_f32(vacc0x01, vacc0x23); +#else + const float32x2_t vsum0x0 = vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); +#endif + + const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + + uint16x4_t vout0x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc0x0123), 16); + + if XNN_LIKELY(nc >= 4) { + vst1_u16(c0, vout0x0123); + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + + a0 = (const uint16_t*) ((uintptr_t) a0 - kc); + + nc -= 4; + } else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_u16(vout0x0123), 0); c0 += 2; + + vout0x0123 = vext_u16(vout0x0123, vout0x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vout0x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/1x4c8-minmax-neonfma-zip.c b/src/bf16-gemm/gen/1x4c8-minmax-neonfma-zip.c new file mode 100644 index 000000000000..d73da776e601 --- /dev/null +++ b/src/bf16-gemm/gen/1x4c8-minmax-neonfma-zip.c @@ -0,0 +1,174 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c8-neon.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
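The "zip" kernel that follows produces the same even/odd split as the shland kernel above, but by interleaving a zero vector below each bf16 halfword (vzip1q_u16/vzip2q_u16, AArch64 only). Both share the remainder (k != 0) handling: A is over-read to a full 8-element vector, the packed weights are assumed zero past the valid elements, and lanes where B is zero are cleared in A with BIC before the multiply; my reading (an inference, not stated in the patch) is that this keeps over-read A lanes, which could hold Inf/NaN bit patterns, from reaching the accumulator. A scalar sketch of the masking with made-up lane values:

#include <stdint.h>
#include <stdio.h>

int main(void) {
  /* 8 bf16 lanes of A; the last four model garbage read past the row end
   * (0x7F80 is +Inf, 0x7FC0 is a NaN pattern). */
  const uint16_t a[8] = { 0x3F80, 0x4000, 0x4040, 0x4080, 0x7F80, 0x7FC0, 0x7F80, 0x7FC0 };
  /* 8 bf16 lanes of packed B; positions past kc are zero-padded. */
  const uint16_t b[8] = { 0x3F80, 0x3F80, 0x3F80, 0x3F80, 0x0000, 0x0000, 0x0000, 0x0000 };
  for (int i = 0; i < 8; i++) {
    const uint16_t m = (uint16_t) ((b[i] == 0) ? 0xFFFF : 0x0000);   /* vceqq_u16(vb, 0) */
    const uint16_t a_masked = (uint16_t) (a[i] & (uint16_t) ~m);     /* vbicq_u16(va, vm) */
    printf("lane %d: a=%04X b=%04X a_masked=%04X\n", i, a[i], b[i], a_masked);
  }
  return 0;
}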
+ + +#include + +#include + +#include + + +void xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 1); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(uint16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const uint16_t* a0 = (const uint16_t*) a; + uint16_t* c0 = (uint16_t*) c; + + const uint16_t* w = (const uint16_t*) w_ptr; + const uint16x8_t vzero = vmovq_n_u16(0); + do { + float32x4_t vacc0x0 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x1 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x2 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x3 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + + size_t k = kc; + for (; k >= 8 * sizeof(uint16_t); k -= 8 * sizeof(uint16_t)) { + const uint16x8_t va0 = vld1q_u16(a0); a0 += 8; + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const float32x4_t va0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0)); + + const float32x4_t vb0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb0)); + const float32x4_t vb1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb1)); + const float32x4_t vb2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb2)); + const float32x4_t vb3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0e, vb1e); + vacc0x2 = vfmaq_f32(vacc0x2, va0e, vb2e); + vacc0x3 = vfmaq_f32(vacc0x3, va0e, vb3e); + + const float32x4_t va0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0)); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb0)); + const float32x4_t vb1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb1)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb2)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0o, vb3o); + } + if XNN_UNLIKELY(k != 0) { + const uint16x8_t va0 = vld1q_u16(a0); a0 = (const uint16_t*) ((uintptr_t) a0 + k); + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vb0, vmovq_n_u16(0)); + const uint16x8_t vm1 = vceqq_u16(vb1, vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vb2, vmovq_n_u16(0)); + const uint16x8_t vm3 = vceqq_u16(vb3, vmovq_n_u16(0)); + + const float32x4_t vb0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb0)); + const float32x4_t vb1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb1)); + const float32x4_t vb2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb2)); + const float32x4_t vb3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb3)); + + const uint16x8_t va0x0 = vbicq_u16(va0, vm0); + const uint16x8_t va0x1 = vbicq_u16(va0, vm1); + const uint16x8_t va0x2 = vbicq_u16(va0, vm2); + const 
uint16x8_t va0x3 = vbicq_u16(va0, vm3); + + const float32x4_t va0x0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0x0)); + const float32x4_t va0x1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0x1)); + const float32x4_t va0x2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0x2)); + const float32x4_t va0x3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0x3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1e, vb1e); + vacc0x2 = vfmaq_f32(vacc0x2, va0x2e, vb2e); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3e, vb3e); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb0)); + const float32x4_t vb1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb1)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb2)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb3)); + + const float32x4_t va0x0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x0)); + const float32x4_t va0x1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x1)); + const float32x4_t va0x2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x2)); + const float32x4_t va0x3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0x2o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3o, vb3o); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + + float32x4_t vacc0x0123 = vpaddq_f32(vacc0x01, vacc0x23); +#else + const float32x2_t vsum0x0 = vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); +#endif + + const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + + uint16x4_t vout0x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc0x0123), 16); + + if XNN_LIKELY(nc >= 4) { + vst1_u16(c0, vout0x0123); + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + + a0 = (const uint16_t*) ((uintptr_t) a0 - kc); + + nc -= 4; + } else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_u16(vout0x0123), 0); c0 += 2; + + vout0x0123 = vext_u16(vout0x0123, vout0x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vout0x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/1x8c2-minmax-neonbf16-bfdot-lane-ld128.c b/src/bf16-gemm/gen/1x8c2-minmax-neonbf16-bfdot-lane-ld128.c new file mode 100644 index 000000000000..2b5ba0d23ae4 --- /dev/null +++ b/src/bf16-gemm/gen/1x8c2-minmax-neonbf16-bfdot-lane-ld128.c @@ -0,0 +1,171 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c2-neonbf16-bfdot-lane-ld128.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
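The 8c2 bfdot-lane kernel that follows walks K two elements at a time: each vbfdotq_laneq_f32 broadcasts one 32-bit pair of A values across the vector and dots it against a packed block of weights covering two k-values for four output columns. A scalar sketch of one such "c01" step across all 8 output columns; the helper names and sample values are mine, not the kernel's:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* Widen a bf16 value (upper 16 bits of an f32) to f32. */
static float bf16_to_f32(uint16_t h) {
  const uint32_t bits = (uint32_t) h << 16;
  float f;
  memcpy(&f, &bits, sizeof(f));
  return f;
}

/* One step: the same pair of A values (k and k+1) is broadcast to every
 * output column n, and each column accumulates a 2-element dot product
 * against its packed pair of weights. */
static void bfdot_lane_step(float acc[8], const uint16_t a_pair[2], const uint16_t b[8][2]) {
  for (int n = 0; n < 8; n++) {
    acc[n] += bf16_to_f32(a_pair[0]) * bf16_to_f32(b[n][0])
            + bf16_to_f32(a_pair[1]) * bf16_to_f32(b[n][1]);
  }
}

int main(void) {
  float acc[8] = { 0.0f };
  const uint16_t a_pair[2] = { 0x3F80, 0x4000 };  /* 1.0, 2.0 */
  uint16_t b[8][2];
  for (int n = 0; n < 8; n++) { b[n][0] = 0x3F80; b[n][1] = 0x3F00; }  /* 1.0, 0.5 */
  bfdot_lane_step(acc, a_pair, b);
  printf("%f\n", acc[0]);  /* 1*1 + 2*0.5 = 2.000000 */
  return 0;
}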
+ + +#include + +#include + +#include + + +void xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 1); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(bfloat16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const bfloat16_t* a0 = (const bfloat16_t*) a; + bfloat16_t* c0 = (bfloat16_t*) c; + + const bfloat16_t* w = (const bfloat16_t*) w_ptr; + do { + float32x4_t vacc0x0123 = vcvt_f32_bf16(vld1_bf16(w)); w += 4; + float32x4_t vacc0x4567 = vcvt_f32_bf16(vld1_bf16(w)); w += 4; + + size_t k = kc; + for (; k >= 8 * sizeof(bfloat16_t); k -= 8 * sizeof(bfloat16_t)) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 += 8; + + const bfloat16x8_t vb0123c01 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c01 = vld1q_bf16(w); w += 8; + + vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c01, va0, 0); + vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c01, va0, 0); + const bfloat16x8_t vb0123c23 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c23 = vld1q_bf16(w); w += 8; + + vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c23, va0, 1); + vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c23, va0, 1); + const bfloat16x8_t vb0123c45 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c45 = vld1q_bf16(w); w += 8; + + vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c45, va0, 2); + vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c45, va0, 2); + const bfloat16x8_t vb0123c67 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c67 = vld1q_bf16(w); w += 8; + + vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c67, va0, 3); + vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c67, va0, 3); + } + if XNN_UNLIKELY(k != 0) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 = (const bfloat16_t*) ((uintptr_t) a0 + k); + + const bfloat16x8_t vb0123c01 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c01 = vld1q_bf16(w); w += 8; + + const uint32x4_t va0c01 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va0)), 0); + + const uint32x4_t vm0123c01 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c01), vmovq_n_u16(0))); + const uint32x4_t vm4567c01 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c01), vmovq_n_u16(0))); + + const uint32x4_t va0x0123c01 = vbicq_u32(va0c01, vm0123c01); + vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c01, vreinterpretq_bf16_u32(va0x0123c01)); + const uint32x4_t va0x4567c01 = vbicq_u32(va0c01, vm4567c01); + vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c01, vreinterpretq_bf16_u32(va0x4567c01)); + + if (k > 2 * sizeof(bfloat16_t)) { + const bfloat16x8_t vb0123c23 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c23 = vld1q_bf16(w); w += 8; + + const uint32x4_t va0c23 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va0)), 1); + + const uint32x4_t vm0123c23 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c23), vmovq_n_u16(0))); + const uint32x4_t vm4567c23 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c23), vmovq_n_u16(0))); + + const uint32x4_t va0x0123c23 = vbicq_u32(va0c23, vm0123c23); + vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c23, vreinterpretq_bf16_u32(va0x0123c23)); + const uint32x4_t va0x4567c23 = vbicq_u32(va0c23, vm4567c23); + vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c23, 
vreinterpretq_bf16_u32(va0x4567c23)); + + if (k > 4 * sizeof(bfloat16_t)) { + const bfloat16x8_t vb0123c45 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c45 = vld1q_bf16(w); w += 8; + + const uint32x4_t va0c45 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va0)), 0); + + const uint32x4_t vm0123c45 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c45), vmovq_n_u16(0))); + const uint32x4_t vm4567c45 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c45), vmovq_n_u16(0))); + + const uint32x4_t va0x0123c45 = vbicq_u32(va0c45, vm0123c45); + vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c45, vreinterpretq_bf16_u32(va0x0123c45)); + const uint32x4_t va0x4567c45 = vbicq_u32(va0c45, vm4567c45); + vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c45, vreinterpretq_bf16_u32(va0x4567c45)); + + if (k > 6 * sizeof(bfloat16_t)) { + const bfloat16x8_t vb0123c67 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c67 = vld1q_bf16(w); w += 8; + + const uint32x4_t va0c67 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va0)), 1); + + const uint32x4_t vm0123c67 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c67), vmovq_n_u16(0))); + const uint32x4_t vm4567c67 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c67), vmovq_n_u16(0))); + + const uint32x4_t va0x0123c67 = vbicq_u32(va0c67, vm0123c67); + vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c67, vreinterpretq_bf16_u32(va0x0123c67)); + const uint32x4_t va0x4567c67 = vbicq_u32(va0c67, vm4567c67); + vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c67, vreinterpretq_bf16_u32(va0x4567c67)); + } + } + } + } + + const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc0x4567 = vminq_f32(vacc0x4567, vmax); + + const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc0x4567 = vmaxq_f32(vacc0x4567, vmin); + + bfloat16x4_t vout0x0123 = vcvt_bf16_f32(vacc0x0123); + bfloat16x4_t vout0x4567 = vcvt_bf16_f32(vacc0x4567); + + if XNN_LIKELY(nc >= 8) { + vst1_bf16(c0, vout0x0123); + vst1_bf16(c0 + 4, vout0x4567); + c0 = (bfloat16_t*) ((uintptr_t) c0 + cn_stride); + + a0 = (const bfloat16_t*) ((uintptr_t) a0 - kc); + + nc -= 8; + } else { + if (nc & 4) { + vst1_bf16(c0, vout0x0123); c0 += 4; + + vout0x0123 = vout0x4567; + } + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_bf16(vout0x0123), 0); c0 += 2; + + vout0x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout0x0123), vreinterpret_u16_bf16(vout0x0123), 2)); + } + if (nc & 1) { + vst1_lane_bf16(c0, vout0x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/2x4c8-minmax-neonbf16-bfdot.c b/src/bf16-gemm/gen/2x4c8-minmax-neonbf16-bfdot.c new file mode 100644 index 000000000000..8165d8277859 --- /dev/null +++ b/src/bf16-gemm/gen/2x4c8-minmax-neonbf16-bfdot.c @@ -0,0 +1,169 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c8-neonbf16.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
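After the horizontal reduction, every kernel in this patch clamps the f32 accumulators against params->scalar.max and params->scalar.min and narrows them to bf16. The neonfma variants truncate by keeping the top 16 bits (vshrn_n_u32 by 16), while the neonbf16 variants use the hardware conversion vcvt_bf16_f32. A scalar sketch of the truncating path, with stand-in clamp bounds of my choosing:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* Keep the top 16 bits of the f32, as vshrn_n_u32(..., 16) does. */
static uint16_t f32_to_bf16_trunc(float f) {
  uint32_t bits;
  memcpy(&bits, &f, sizeof(bits));
  return (uint16_t) (bits >> 16);
}

int main(void) {
  const float vmax = 1.0f;   /* stand-in for params->scalar.max */
  const float vmin = -1.0f;  /* stand-in for params->scalar.min */
  float acc = 1.7f;
  acc = acc > vmax ? vmax : acc;  /* vminq_f32(acc, vmax) */
  acc = acc < vmin ? vmin : acc;  /* vmaxq_f32(acc, vmin) */
  printf("clamped = %f, bf16 bits = 0x%04X\n", acc, f32_to_bf16_trunc(acc));
  return 0;
}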
+ + +#include + +#include + +#include + + +void xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 2); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(bfloat16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const bfloat16_t* a0 = (const bfloat16_t*) a; + bfloat16_t* c0 = (bfloat16_t*) c; + const bfloat16_t* a1 = (const bfloat16_t*) ((uintptr_t) a0 + a_stride); + bfloat16_t* c1 = (bfloat16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr != 2) { + a1 = a0; + c1 = c0; + } + + const bfloat16_t* w = (const bfloat16_t*) w_ptr; + do { + float32x4_t vacc0x0 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x1 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x2 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x3 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc1x0 = vacc0x0; + float32x4_t vacc1x1 = vacc0x1; + float32x4_t vacc1x2 = vacc0x2; + float32x4_t vacc1x3 = vacc0x3; + + size_t k = kc; + for (; k >= 8 * sizeof(bfloat16_t); k -= 8 * sizeof(bfloat16_t)) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 += 8; + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 += 8; + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + vacc0x0 = vbfdotq_f32(vacc0x0, va0, vb0); + vacc1x0 = vbfdotq_f32(vacc1x0, va1, vb0); + vacc0x1 = vbfdotq_f32(vacc0x1, va0, vb1); + vacc1x1 = vbfdotq_f32(vacc1x1, va1, vb1); + vacc0x2 = vbfdotq_f32(vacc0x2, va0, vb2); + vacc1x2 = vbfdotq_f32(vacc1x2, va1, vb2); + vacc0x3 = vbfdotq_f32(vacc0x3, va0, vb3); + vacc1x3 = vbfdotq_f32(vacc1x3, va1, vb3); + } + if XNN_UNLIKELY(k != 0) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 = (const bfloat16_t*) ((uintptr_t) a0 + k); + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 = (const bfloat16_t*) ((uintptr_t) a1 + k); + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vreinterpretq_u16_bf16(vb0), vmovq_n_u16(0)); + const uint16x8_t vm1 = vceqq_u16(vreinterpretq_u16_bf16(vb1), vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vreinterpretq_u16_bf16(vb2), vmovq_n_u16(0)); + const uint16x8_t vm3 = vceqq_u16(vreinterpretq_u16_bf16(vb3), vmovq_n_u16(0)); + + const bfloat16x8_t va0x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm0)); + vacc0x0 = vbfdotq_f32(vacc0x0, va0x0, vb0); + const bfloat16x8_t va1x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm0)); + vacc1x0 = vbfdotq_f32(vacc1x0, va1x0, vb0); + const bfloat16x8_t va0x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm1)); + vacc0x1 = vbfdotq_f32(vacc0x1, va0x1, vb1); + const bfloat16x8_t va1x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm1)); + vacc1x1 = vbfdotq_f32(vacc1x1, va1x1, vb1); + const bfloat16x8_t va0x2 = 
vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm2)); + vacc0x2 = vbfdotq_f32(vacc0x2, va0x2, vb2); + const bfloat16x8_t va1x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm2)); + vacc1x2 = vbfdotq_f32(vacc1x2, va1x2, vb2); + const bfloat16x8_t va0x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm3)); + vacc0x3 = vbfdotq_f32(vacc0x3, va0x3, vb3); + const bfloat16x8_t va1x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm3)); + vacc1x3 = vbfdotq_f32(vacc1x3, va1x3, vb3); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc1x01 = vpaddq_f32(vacc1x0, vacc1x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + const float32x4_t vacc1x23 = vpaddq_f32(vacc1x2, vacc1x3); + + float32x4_t vacc0x0123 = vpaddq_f32(vacc0x01, vacc0x23); + float32x4_t vacc1x0123 = vpaddq_f32(vacc1x01, vacc1x23); +#else + const float32x2_t vsum0x0 = vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum1x0 = vadd_f32(vget_low_f32(vacc1x0), vget_high_f32(vacc1x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum1x1 = vadd_f32(vget_low_f32(vacc1x1), vget_high_f32(vacc1x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum1x2 = vadd_f32(vget_low_f32(vacc1x2), vget_high_f32(vacc1x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + const float32x2_t vsum1x3 = vadd_f32(vget_low_f32(vacc1x3), vget_high_f32(vacc1x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); + float32x4_t vacc1x0123 = vcombine_f32(vpadd_f32(vsum1x0, vsum1x1), vpadd_f32(vsum1x2, vsum1x3)); +#endif + + const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + + bfloat16x4_t vout0x0123 = vcvt_bf16_f32(vacc0x0123); + bfloat16x4_t vout1x0123 = vcvt_bf16_f32(vacc1x0123); + + if XNN_LIKELY(nc >= 4) { + vst1_bf16(c0, vout0x0123); + c0 = (bfloat16_t*) ((uintptr_t) c0 + cn_stride); + vst1_bf16(c1, vout1x0123); + c1 = (bfloat16_t*) ((uintptr_t) c1 + cn_stride); + + a0 = (const bfloat16_t*) ((uintptr_t) a0 - kc); + a1 = (const bfloat16_t*) ((uintptr_t) a1 - kc); + + nc -= 4; + } else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_bf16(vout0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_bf16(vout1x0123), 0); c1 += 2; + + vout0x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout0x0123), vreinterpret_u16_bf16(vout0x0123), 2)); + vout1x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout1x0123), vreinterpret_u16_bf16(vout1x0123), 2)); + } + if (nc & 1) { + vst1_lane_bf16(c0, vout0x0123, 0); + vst1_lane_bf16(c1, vout1x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/2x4c8-minmax-neonbf16-bfmlal.c b/src/bf16-gemm/gen/2x4c8-minmax-neonbf16-bfmlal.c new file mode 100644 index 000000000000..283c3be197c3 --- /dev/null +++ b/src/bf16-gemm/gen/2x4c8-minmax-neonbf16-bfmlal.c @@ -0,0 +1,186 @@ +// Auto-generated file. Do not edit! 
+// Template: src/bf16-gemm/c8-neonbf16.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include + +#include + +#include + + +void xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 2); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(bfloat16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const bfloat16_t* a0 = (const bfloat16_t*) a; + bfloat16_t* c0 = (bfloat16_t*) c; + const bfloat16_t* a1 = (const bfloat16_t*) ((uintptr_t) a0 + a_stride); + bfloat16_t* c1 = (bfloat16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr != 2) { + a1 = a0; + c1 = c0; + } + + const bfloat16_t* w = (const bfloat16_t*) w_ptr; + do { + float32x4_t vacc0x0 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x1 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x2 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x3 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc1x0 = vacc0x0; + float32x4_t vacc1x1 = vacc0x1; + float32x4_t vacc1x2 = vacc0x2; + float32x4_t vacc1x3 = vacc0x3; + + size_t k = kc; + for (; k >= 8 * sizeof(bfloat16_t); k -= 8 * sizeof(bfloat16_t)) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 += 8; + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 += 8; + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + vacc0x0 = vbfmlalbq_f32(vacc0x0, va0, vb0); + vacc1x0 = vbfmlalbq_f32(vacc1x0, va1, vb0); + vacc0x1 = vbfmlalbq_f32(vacc0x1, va0, vb1); + vacc1x1 = vbfmlalbq_f32(vacc1x1, va1, vb1); + vacc0x2 = vbfmlalbq_f32(vacc0x2, va0, vb2); + vacc1x2 = vbfmlalbq_f32(vacc1x2, va1, vb2); + vacc0x3 = vbfmlalbq_f32(vacc0x3, va0, vb3); + vacc1x3 = vbfmlalbq_f32(vacc1x3, va1, vb3); + + vacc0x0 = vbfmlaltq_f32(vacc0x0, va0, vb0); + vacc1x0 = vbfmlaltq_f32(vacc1x0, va1, vb0); + vacc0x1 = vbfmlaltq_f32(vacc0x1, va0, vb1); + vacc1x1 = vbfmlaltq_f32(vacc1x1, va1, vb1); + vacc0x2 = vbfmlaltq_f32(vacc0x2, va0, vb2); + vacc1x2 = vbfmlaltq_f32(vacc1x2, va1, vb2); + vacc0x3 = vbfmlaltq_f32(vacc0x3, va0, vb3); + vacc1x3 = vbfmlaltq_f32(vacc1x3, va1, vb3); + } + if XNN_UNLIKELY(k != 0) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 = (const bfloat16_t*) ((uintptr_t) a0 + k); + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 = (const bfloat16_t*) ((uintptr_t) a1 + k); + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vreinterpretq_u16_bf16(vb0), vmovq_n_u16(0)); + const uint16x8_t vm1 = vceqq_u16(vreinterpretq_u16_bf16(vb1), vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vreinterpretq_u16_bf16(vb2), vmovq_n_u16(0)); + const uint16x8_t vm3 = vceqq_u16(vreinterpretq_u16_bf16(vb3), 
vmovq_n_u16(0)); + + const bfloat16x8_t va0x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm0)); + vacc0x0 = vbfmlalbq_f32(vacc0x0, va0x0, vb0); + vacc0x0 = vbfmlaltq_f32(vacc0x0, va0x0, vb0); + const bfloat16x8_t va1x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm0)); + vacc1x0 = vbfmlalbq_f32(vacc1x0, va1x0, vb0); + vacc1x0 = vbfmlaltq_f32(vacc1x0, va1x0, vb0); + const bfloat16x8_t va0x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm1)); + vacc0x1 = vbfmlalbq_f32(vacc0x1, va0x1, vb1); + vacc0x1 = vbfmlaltq_f32(vacc0x1, va0x1, vb1); + const bfloat16x8_t va1x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm1)); + vacc1x1 = vbfmlalbq_f32(vacc1x1, va1x1, vb1); + vacc1x1 = vbfmlaltq_f32(vacc1x1, va1x1, vb1); + const bfloat16x8_t va0x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm2)); + vacc0x2 = vbfmlalbq_f32(vacc0x2, va0x2, vb2); + vacc0x2 = vbfmlaltq_f32(vacc0x2, va0x2, vb2); + const bfloat16x8_t va1x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm2)); + vacc1x2 = vbfmlalbq_f32(vacc1x2, va1x2, vb2); + vacc1x2 = vbfmlaltq_f32(vacc1x2, va1x2, vb2); + const bfloat16x8_t va0x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm3)); + vacc0x3 = vbfmlalbq_f32(vacc0x3, va0x3, vb3); + vacc0x3 = vbfmlaltq_f32(vacc0x3, va0x3, vb3); + const bfloat16x8_t va1x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm3)); + vacc1x3 = vbfmlalbq_f32(vacc1x3, va1x3, vb3); + vacc1x3 = vbfmlaltq_f32(vacc1x3, va1x3, vb3); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc1x01 = vpaddq_f32(vacc1x0, vacc1x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + const float32x4_t vacc1x23 = vpaddq_f32(vacc1x2, vacc1x3); + + float32x4_t vacc0x0123 = vpaddq_f32(vacc0x01, vacc0x23); + float32x4_t vacc1x0123 = vpaddq_f32(vacc1x01, vacc1x23); +#else + const float32x2_t vsum0x0 = vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum1x0 = vadd_f32(vget_low_f32(vacc1x0), vget_high_f32(vacc1x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum1x1 = vadd_f32(vget_low_f32(vacc1x1), vget_high_f32(vacc1x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum1x2 = vadd_f32(vget_low_f32(vacc1x2), vget_high_f32(vacc1x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + const float32x2_t vsum1x3 = vadd_f32(vget_low_f32(vacc1x3), vget_high_f32(vacc1x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); + float32x4_t vacc1x0123 = vcombine_f32(vpadd_f32(vsum1x0, vsum1x1), vpadd_f32(vsum1x2, vsum1x3)); +#endif + + const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + + bfloat16x4_t vout0x0123 = vcvt_bf16_f32(vacc0x0123); + bfloat16x4_t vout1x0123 = vcvt_bf16_f32(vacc1x0123); + + if XNN_LIKELY(nc >= 4) { + vst1_bf16(c0, vout0x0123); + c0 = (bfloat16_t*) ((uintptr_t) c0 + cn_stride); + vst1_bf16(c1, vout1x0123); + c1 = (bfloat16_t*) ((uintptr_t) c1 + cn_stride); + + a0 = (const bfloat16_t*) ((uintptr_t) a0 - kc); + a1 = (const bfloat16_t*) 
((uintptr_t) a1 - kc); + + nc -= 4; + } else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_bf16(vout0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_bf16(vout1x0123), 0); c1 += 2; + + vout0x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout0x0123), vreinterpret_u16_bf16(vout0x0123), 2)); + vout1x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout1x0123), vreinterpret_u16_bf16(vout1x0123), 2)); + } + if (nc & 1) { + vst1_lane_bf16(c0, vout0x0123, 0); + vst1_lane_bf16(c1, vout1x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/2x4c8-minmax-neonfma-shland.c b/src/bf16-gemm/gen/2x4c8-minmax-neonfma-shland.c new file mode 100644 index 000000000000..4a40c548dc2c --- /dev/null +++ b/src/bf16-gemm/gen/2x4c8-minmax-neonfma-shland.c @@ -0,0 +1,233 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c8-neon.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include + +#include + +#include + + +void xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 2); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(uint16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const uint16_t* a0 = (const uint16_t*) a; + uint16_t* c0 = (uint16_t*) c; + const uint16_t* a1 = (const uint16_t*) ((uintptr_t) a0 + a_stride); + uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr != 2) { + a1 = a0; + c1 = c0; + } + + const uint16_t* w = (const uint16_t*) w_ptr; + const uint16x8_t vmask = vreinterpretq_u16_u32(vmovq_n_u32(UINT32_C(0xFFFF0000))); + do { + float32x4_t vacc0x0 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x1 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x2 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x3 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc1x0 = vacc0x0; + float32x4_t vacc1x1 = vacc0x1; + float32x4_t vacc1x2 = vacc0x2; + float32x4_t vacc1x3 = vacc0x3; + + size_t k = kc; + for (; k >= 8 * sizeof(uint16_t); k -= 8 * sizeof(uint16_t)) { + const uint16x8_t va0 = vld1q_u16(a0); a0 += 8; + const uint16x8_t va1 = vld1q_u16(a1); a1 += 8; + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const float32x4_t va0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0), 16)); + const float32x4_t va1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1), 16)); + + const float32x4_t vb0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb0), 16)); + const float32x4_t vb1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb1), 16)); + const float32x4_t vb2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb2), 16)); + const float32x4_t vb3e = 
vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb3), 16)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0e, vb0e); + vacc1x0 = vfmaq_f32(vacc1x0, va1e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0e, vb1e); + vacc1x1 = vfmaq_f32(vacc1x1, va1e, vb1e); + vacc0x2 = vfmaq_f32(vacc0x2, va0e, vb2e); + vacc1x2 = vfmaq_f32(vacc1x2, va1e, vb2e); + vacc0x3 = vfmaq_f32(vacc0x3, va0e, vb3e); + vacc1x3 = vfmaq_f32(vacc1x3, va1e, vb3e); + + const float32x4_t va0o = vreinterpretq_f32_u16(vandq_u16(va0, vmask)); + const float32x4_t va1o = vreinterpretq_f32_u16(vandq_u16(va1, vmask)); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vandq_u16(vb0, vmask)); + const float32x4_t vb1o = vreinterpretq_f32_u16(vandq_u16(vb1, vmask)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vandq_u16(vb2, vmask)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vandq_u16(vb3, vmask)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0o, vb0o); + vacc1x0 = vfmaq_f32(vacc1x0, va1o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0o, vb1o); + vacc1x1 = vfmaq_f32(vacc1x1, va1o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0o, vb2o); + vacc1x2 = vfmaq_f32(vacc1x2, va1o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0o, vb3o); + vacc1x3 = vfmaq_f32(vacc1x3, va1o, vb3o); + } + if XNN_UNLIKELY(k != 0) { + const uint16x8_t va0 = vld1q_u16(a0); a0 = (const uint16_t*) ((uintptr_t) a0 + k); + const uint16x8_t va1 = vld1q_u16(a1); a1 = (const uint16_t*) ((uintptr_t) a1 + k); + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vb0, vmovq_n_u16(0)); + const uint16x8_t vm1 = vceqq_u16(vb1, vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vb2, vmovq_n_u16(0)); + const uint16x8_t vm3 = vceqq_u16(vb3, vmovq_n_u16(0)); + + const float32x4_t vb0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb0), 16)); + const float32x4_t vb1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb1), 16)); + const float32x4_t vb2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb2), 16)); + const float32x4_t vb3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb3), 16)); + + const uint16x8_t va0x0 = vbicq_u16(va0, vm0); + const uint16x8_t va1x0 = vbicq_u16(va1, vm0); + const uint16x8_t va0x1 = vbicq_u16(va0, vm1); + const uint16x8_t va1x1 = vbicq_u16(va1, vm1); + const uint16x8_t va0x2 = vbicq_u16(va0, vm2); + const uint16x8_t va1x2 = vbicq_u16(va1, vm2); + const uint16x8_t va0x3 = vbicq_u16(va0, vm3); + const uint16x8_t va1x3 = vbicq_u16(va1, vm3); + + const float32x4_t va0x0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x0), 16)); + const float32x4_t va1x0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1x0), 16)); + const float32x4_t va0x1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x1), 16)); + const float32x4_t va1x1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1x1), 16)); + const float32x4_t va0x2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x2), 16)); + const float32x4_t va1x2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1x2), 16)); + const float32x4_t va0x3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x3), 16)); + const float32x4_t va1x3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1x3), 16)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0e, vb0e); + vacc1x0 = vfmaq_f32(vacc1x0, va1x0e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1e, vb1e); + vacc1x1 = 
vfmaq_f32(vacc1x1, va1x1e, vb1e); + vacc0x2 = vfmaq_f32(vacc0x2, va0x2e, vb2e); + vacc1x2 = vfmaq_f32(vacc1x2, va1x2e, vb2e); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3e, vb3e); + vacc1x3 = vfmaq_f32(vacc1x3, va1x3e, vb3e); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vandq_u16(vb0, vmask)); + const float32x4_t vb1o = vreinterpretq_f32_u16(vandq_u16(vb1, vmask)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vandq_u16(vb2, vmask)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vandq_u16(vb3, vmask)); + + const float32x4_t va0x0o = vreinterpretq_f32_u16(vandq_u16(va0x0, vmask)); + const float32x4_t va1x0o = vreinterpretq_f32_u16(vandq_u16(va1x0, vmask)); + const float32x4_t va0x1o = vreinterpretq_f32_u16(vandq_u16(va0x1, vmask)); + const float32x4_t va1x1o = vreinterpretq_f32_u16(vandq_u16(va1x1, vmask)); + const float32x4_t va0x2o = vreinterpretq_f32_u16(vandq_u16(va0x2, vmask)); + const float32x4_t va1x2o = vreinterpretq_f32_u16(vandq_u16(va1x2, vmask)); + const float32x4_t va0x3o = vreinterpretq_f32_u16(vandq_u16(va0x3, vmask)); + const float32x4_t va1x3o = vreinterpretq_f32_u16(vandq_u16(va1x3, vmask)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0o, vb0o); + vacc1x0 = vfmaq_f32(vacc1x0, va1x0o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1o, vb1o); + vacc1x1 = vfmaq_f32(vacc1x1, va1x1o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0x2o, vb2o); + vacc1x2 = vfmaq_f32(vacc1x2, va1x2o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3o, vb3o); + vacc1x3 = vfmaq_f32(vacc1x3, va1x3o, vb3o); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc1x01 = vpaddq_f32(vacc1x0, vacc1x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + const float32x4_t vacc1x23 = vpaddq_f32(vacc1x2, vacc1x3); + + float32x4_t vacc0x0123 = vpaddq_f32(vacc0x01, vacc0x23); + float32x4_t vacc1x0123 = vpaddq_f32(vacc1x01, vacc1x23); +#else + const float32x2_t vsum0x0 = vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum1x0 = vadd_f32(vget_low_f32(vacc1x0), vget_high_f32(vacc1x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum1x1 = vadd_f32(vget_low_f32(vacc1x1), vget_high_f32(vacc1x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum1x2 = vadd_f32(vget_low_f32(vacc1x2), vget_high_f32(vacc1x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + const float32x2_t vsum1x3 = vadd_f32(vget_low_f32(vacc1x3), vget_high_f32(vacc1x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); + float32x4_t vacc1x0123 = vcombine_f32(vpadd_f32(vsum1x0, vsum1x1), vpadd_f32(vsum1x2, vsum1x3)); +#endif + + const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + + uint16x4_t vout0x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc0x0123), 16); + uint16x4_t vout1x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc1x0123), 16); + + if XNN_LIKELY(nc >= 4) { + vst1_u16(c0, vout0x0123); + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + vst1_u16(c1, vout1x0123); + c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); + + a0 = (const uint16_t*) ((uintptr_t) a0 - kc); + a1 = (const uint16_t*) ((uintptr_t) a1 - kc); + + nc -= 4; + } 
else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_u16(vout0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_u16(vout1x0123), 0); c1 += 2; + + vout0x0123 = vext_u16(vout0x0123, vout0x0123, 2); + vout1x0123 = vext_u16(vout1x0123, vout1x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vout0x0123, 0); + vst1_lane_u16(c1, vout1x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/2x4c8-minmax-neonfma-zip.c b/src/bf16-gemm/gen/2x4c8-minmax-neonfma-zip.c new file mode 100644 index 000000000000..ffd9821b5291 --- /dev/null +++ b/src/bf16-gemm/gen/2x4c8-minmax-neonfma-zip.c @@ -0,0 +1,233 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c8-neon.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include + +#include + +#include + + +void xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 2); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(uint16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const uint16_t* a0 = (const uint16_t*) a; + uint16_t* c0 = (uint16_t*) c; + const uint16_t* a1 = (const uint16_t*) ((uintptr_t) a0 + a_stride); + uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr != 2) { + a1 = a0; + c1 = c0; + } + + const uint16_t* w = (const uint16_t*) w_ptr; + const uint16x8_t vzero = vmovq_n_u16(0); + do { + float32x4_t vacc0x0 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x1 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x2 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x3 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc1x0 = vacc0x0; + float32x4_t vacc1x1 = vacc0x1; + float32x4_t vacc1x2 = vacc0x2; + float32x4_t vacc1x3 = vacc0x3; + + size_t k = kc; + for (; k >= 8 * sizeof(uint16_t); k -= 8 * sizeof(uint16_t)) { + const uint16x8_t va0 = vld1q_u16(a0); a0 += 8; + const uint16x8_t va1 = vld1q_u16(a1); a1 += 8; + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const float32x4_t va0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0)); + const float32x4_t va1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1)); + + const float32x4_t vb0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb0)); + const float32x4_t vb1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb1)); + const float32x4_t vb2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb2)); + const float32x4_t vb3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0e, vb0e); + vacc1x0 = vfmaq_f32(vacc1x0, va1e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0e, vb1e); + vacc1x1 = vfmaq_f32(vacc1x1, va1e, vb1e); + vacc0x2 = vfmaq_f32(vacc0x2, va0e, vb2e); + vacc1x2 = vfmaq_f32(vacc1x2, va1e, vb2e); + vacc0x3 = vfmaq_f32(vacc0x3, va0e, vb3e); + vacc1x3 = 
vfmaq_f32(vacc1x3, va1e, vb3e); + + const float32x4_t va0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0)); + const float32x4_t va1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va1)); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb0)); + const float32x4_t vb1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb1)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb2)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0o, vb0o); + vacc1x0 = vfmaq_f32(vacc1x0, va1o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0o, vb1o); + vacc1x1 = vfmaq_f32(vacc1x1, va1o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0o, vb2o); + vacc1x2 = vfmaq_f32(vacc1x2, va1o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0o, vb3o); + vacc1x3 = vfmaq_f32(vacc1x3, va1o, vb3o); + } + if XNN_UNLIKELY(k != 0) { + const uint16x8_t va0 = vld1q_u16(a0); a0 = (const uint16_t*) ((uintptr_t) a0 + k); + const uint16x8_t va1 = vld1q_u16(a1); a1 = (const uint16_t*) ((uintptr_t) a1 + k); + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vb0, vmovq_n_u16(0)); + const uint16x8_t vm1 = vceqq_u16(vb1, vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vb2, vmovq_n_u16(0)); + const uint16x8_t vm3 = vceqq_u16(vb3, vmovq_n_u16(0)); + + const float32x4_t vb0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb0)); + const float32x4_t vb1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb1)); + const float32x4_t vb2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb2)); + const float32x4_t vb3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb3)); + + const uint16x8_t va0x0 = vbicq_u16(va0, vm0); + const uint16x8_t va1x0 = vbicq_u16(va1, vm0); + const uint16x8_t va0x1 = vbicq_u16(va0, vm1); + const uint16x8_t va1x1 = vbicq_u16(va1, vm1); + const uint16x8_t va0x2 = vbicq_u16(va0, vm2); + const uint16x8_t va1x2 = vbicq_u16(va1, vm2); + const uint16x8_t va0x3 = vbicq_u16(va0, vm3); + const uint16x8_t va1x3 = vbicq_u16(va1, vm3); + + const float32x4_t va0x0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0x0)); + const float32x4_t va1x0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1x0)); + const float32x4_t va0x1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0x1)); + const float32x4_t va1x1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1x1)); + const float32x4_t va0x2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0x2)); + const float32x4_t va1x2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1x2)); + const float32x4_t va0x3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0x3)); + const float32x4_t va1x3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1x3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0e, vb0e); + vacc1x0 = vfmaq_f32(vacc1x0, va1x0e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1e, vb1e); + vacc1x1 = vfmaq_f32(vacc1x1, va1x1e, vb1e); + vacc0x2 = vfmaq_f32(vacc0x2, va0x2e, vb2e); + vacc1x2 = vfmaq_f32(vacc1x2, va1x2e, vb2e); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3e, vb3e); + vacc1x3 = vfmaq_f32(vacc1x3, va1x3e, vb3e); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb0)); + const float32x4_t vb1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb1)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb2)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb3)); + + const float32x4_t va0x0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x0)); + const float32x4_t va1x0o = 
vreinterpretq_f32_u16(vzip2q_u16(vzero, va1x0)); + const float32x4_t va0x1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x1)); + const float32x4_t va1x1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va1x1)); + const float32x4_t va0x2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x2)); + const float32x4_t va1x2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va1x2)); + const float32x4_t va0x3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x3)); + const float32x4_t va1x3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va1x3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0o, vb0o); + vacc1x0 = vfmaq_f32(vacc1x0, va1x0o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1o, vb1o); + vacc1x1 = vfmaq_f32(vacc1x1, va1x1o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0x2o, vb2o); + vacc1x2 = vfmaq_f32(vacc1x2, va1x2o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3o, vb3o); + vacc1x3 = vfmaq_f32(vacc1x3, va1x3o, vb3o); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc1x01 = vpaddq_f32(vacc1x0, vacc1x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + const float32x4_t vacc1x23 = vpaddq_f32(vacc1x2, vacc1x3); + + float32x4_t vacc0x0123 = vpaddq_f32(vacc0x01, vacc0x23); + float32x4_t vacc1x0123 = vpaddq_f32(vacc1x01, vacc1x23); +#else + const float32x2_t vsum0x0 = vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum1x0 = vadd_f32(vget_low_f32(vacc1x0), vget_high_f32(vacc1x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum1x1 = vadd_f32(vget_low_f32(vacc1x1), vget_high_f32(vacc1x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum1x2 = vadd_f32(vget_low_f32(vacc1x2), vget_high_f32(vacc1x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + const float32x2_t vsum1x3 = vadd_f32(vget_low_f32(vacc1x3), vget_high_f32(vacc1x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); + float32x4_t vacc1x0123 = vcombine_f32(vpadd_f32(vsum1x0, vsum1x1), vpadd_f32(vsum1x2, vsum1x3)); +#endif + + const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + + uint16x4_t vout0x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc0x0123), 16); + uint16x4_t vout1x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc1x0123), 16); + + if XNN_LIKELY(nc >= 4) { + vst1_u16(c0, vout0x0123); + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + vst1_u16(c1, vout1x0123); + c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); + + a0 = (const uint16_t*) ((uintptr_t) a0 - kc); + a1 = (const uint16_t*) ((uintptr_t) a1 - kc); + + nc -= 4; + } else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_u16(vout0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_u16(vout1x0123), 0); c1 += 2; + + vout0x0123 = vext_u16(vout0x0123, vout0x0123, 2); + vout1x0123 = vext_u16(vout1x0123, vout1x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vout0x0123, 0); + vst1_lane_u16(c1, vout1x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/3x4c8-minmax-neonbf16-bfdot.c b/src/bf16-gemm/gen/3x4c8-minmax-neonbf16-bfdot.c new file mode 100644 index 000000000000..6c2bd774ccb0 --- /dev/null +++ 
b/src/bf16-gemm/gen/3x4c8-minmax-neonbf16-bfdot.c @@ -0,0 +1,210 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c8-neonbf16.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include + +#include + +#include + + +void xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 3); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(bfloat16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const bfloat16_t* a0 = (const bfloat16_t*) a; + bfloat16_t* c0 = (bfloat16_t*) c; + const bfloat16_t* a1 = (const bfloat16_t*) ((uintptr_t) a0 + a_stride); + bfloat16_t* c1 = (bfloat16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const bfloat16_t* a2 = (const bfloat16_t*) ((uintptr_t) a1 + a_stride); + bfloat16_t* c2 = (bfloat16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + + const bfloat16_t* w = (const bfloat16_t*) w_ptr; + do { + float32x4_t vacc0x0 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x1 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x2 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x3 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc1x0 = vacc0x0; + float32x4_t vacc1x1 = vacc0x1; + float32x4_t vacc1x2 = vacc0x2; + float32x4_t vacc1x3 = vacc0x3; + float32x4_t vacc2x0 = vacc0x0; + float32x4_t vacc2x1 = vacc0x1; + float32x4_t vacc2x2 = vacc0x2; + float32x4_t vacc2x3 = vacc0x3; + + size_t k = kc; + for (; k >= 8 * sizeof(bfloat16_t); k -= 8 * sizeof(bfloat16_t)) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 += 8; + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 += 8; + const bfloat16x8_t va2 = vld1q_bf16(a2); a2 += 8; + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + vacc0x0 = vbfdotq_f32(vacc0x0, va0, vb0); + vacc1x0 = vbfdotq_f32(vacc1x0, va1, vb0); + vacc2x0 = vbfdotq_f32(vacc2x0, va2, vb0); + vacc0x1 = vbfdotq_f32(vacc0x1, va0, vb1); + vacc1x1 = vbfdotq_f32(vacc1x1, va1, vb1); + vacc2x1 = vbfdotq_f32(vacc2x1, va2, vb1); + vacc0x2 = vbfdotq_f32(vacc0x2, va0, vb2); + vacc1x2 = vbfdotq_f32(vacc1x2, va1, vb2); + vacc2x2 = vbfdotq_f32(vacc2x2, va2, vb2); + vacc0x3 = vbfdotq_f32(vacc0x3, va0, vb3); + vacc1x3 = vbfdotq_f32(vacc1x3, va1, vb3); + vacc2x3 = vbfdotq_f32(vacc2x3, va2, vb3); + } + if XNN_UNLIKELY(k != 0) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 = (const bfloat16_t*) ((uintptr_t) a0 + k); + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 = (const bfloat16_t*) ((uintptr_t) a1 + k); + const bfloat16x8_t va2 = vld1q_bf16(a2); a2 = (const bfloat16_t*) ((uintptr_t) a2 + k); + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const 
bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vreinterpretq_u16_bf16(vb0), vmovq_n_u16(0)); + const uint16x8_t vm1 = vceqq_u16(vreinterpretq_u16_bf16(vb1), vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vreinterpretq_u16_bf16(vb2), vmovq_n_u16(0)); + const uint16x8_t vm3 = vceqq_u16(vreinterpretq_u16_bf16(vb3), vmovq_n_u16(0)); + + const bfloat16x8_t va0x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm0)); + vacc0x0 = vbfdotq_f32(vacc0x0, va0x0, vb0); + const bfloat16x8_t va1x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm0)); + vacc1x0 = vbfdotq_f32(vacc1x0, va1x0, vb0); + const bfloat16x8_t va2x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm0)); + vacc2x0 = vbfdotq_f32(vacc2x0, va2x0, vb0); + const bfloat16x8_t va0x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm1)); + vacc0x1 = vbfdotq_f32(vacc0x1, va0x1, vb1); + const bfloat16x8_t va1x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm1)); + vacc1x1 = vbfdotq_f32(vacc1x1, va1x1, vb1); + const bfloat16x8_t va2x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm1)); + vacc2x1 = vbfdotq_f32(vacc2x1, va2x1, vb1); + const bfloat16x8_t va0x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm2)); + vacc0x2 = vbfdotq_f32(vacc0x2, va0x2, vb2); + const bfloat16x8_t va1x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm2)); + vacc1x2 = vbfdotq_f32(vacc1x2, va1x2, vb2); + const bfloat16x8_t va2x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm2)); + vacc2x2 = vbfdotq_f32(vacc2x2, va2x2, vb2); + const bfloat16x8_t va0x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm3)); + vacc0x3 = vbfdotq_f32(vacc0x3, va0x3, vb3); + const bfloat16x8_t va1x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm3)); + vacc1x3 = vbfdotq_f32(vacc1x3, va1x3, vb3); + const bfloat16x8_t va2x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm3)); + vacc2x3 = vbfdotq_f32(vacc2x3, va2x3, vb3); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc1x01 = vpaddq_f32(vacc1x0, vacc1x1); + const float32x4_t vacc2x01 = vpaddq_f32(vacc2x0, vacc2x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + const float32x4_t vacc1x23 = vpaddq_f32(vacc1x2, vacc1x3); + const float32x4_t vacc2x23 = vpaddq_f32(vacc2x2, vacc2x3); + + float32x4_t vacc0x0123 = vpaddq_f32(vacc0x01, vacc0x23); + float32x4_t vacc1x0123 = vpaddq_f32(vacc1x01, vacc1x23); + float32x4_t vacc2x0123 = vpaddq_f32(vacc2x01, vacc2x23); +#else + const float32x2_t vsum0x0 = vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum1x0 = vadd_f32(vget_low_f32(vacc1x0), vget_high_f32(vacc1x0)); + const float32x2_t vsum2x0 = vadd_f32(vget_low_f32(vacc2x0), vget_high_f32(vacc2x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum1x1 = vadd_f32(vget_low_f32(vacc1x1), vget_high_f32(vacc1x1)); + const float32x2_t vsum2x1 = vadd_f32(vget_low_f32(vacc2x1), vget_high_f32(vacc2x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum1x2 = vadd_f32(vget_low_f32(vacc1x2), vget_high_f32(vacc1x2)); + const float32x2_t vsum2x2 = vadd_f32(vget_low_f32(vacc2x2), vget_high_f32(vacc2x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + const 
float32x2_t vsum1x3 = vadd_f32(vget_low_f32(vacc1x3), vget_high_f32(vacc1x3)); + const float32x2_t vsum2x3 = vadd_f32(vget_low_f32(vacc2x3), vget_high_f32(vacc2x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); + float32x4_t vacc1x0123 = vcombine_f32(vpadd_f32(vsum1x0, vsum1x1), vpadd_f32(vsum1x2, vsum1x3)); + float32x4_t vacc2x0123 = vcombine_f32(vpadd_f32(vsum2x0, vsum2x1), vpadd_f32(vsum2x2, vsum2x3)); +#endif + + const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + vacc2x0123 = vminq_f32(vacc2x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + vacc2x0123 = vmaxq_f32(vacc2x0123, vmin); + + bfloat16x4_t vout0x0123 = vcvt_bf16_f32(vacc0x0123); + bfloat16x4_t vout1x0123 = vcvt_bf16_f32(vacc1x0123); + bfloat16x4_t vout2x0123 = vcvt_bf16_f32(vacc2x0123); + + if XNN_LIKELY(nc >= 4) { + vst1_bf16(c0, vout0x0123); + c0 = (bfloat16_t*) ((uintptr_t) c0 + cn_stride); + vst1_bf16(c1, vout1x0123); + c1 = (bfloat16_t*) ((uintptr_t) c1 + cn_stride); + vst1_bf16(c2, vout2x0123); + c2 = (bfloat16_t*) ((uintptr_t) c2 + cn_stride); + + a0 = (const bfloat16_t*) ((uintptr_t) a0 - kc); + a1 = (const bfloat16_t*) ((uintptr_t) a1 - kc); + a2 = (const bfloat16_t*) ((uintptr_t) a2 - kc); + + nc -= 4; + } else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_bf16(vout0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_bf16(vout1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_bf16(vout2x0123), 0); c2 += 2; + + vout0x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout0x0123), vreinterpret_u16_bf16(vout0x0123), 2)); + vout1x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout1x0123), vreinterpret_u16_bf16(vout1x0123), 2)); + vout2x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout2x0123), vreinterpret_u16_bf16(vout2x0123), 2)); + } + if (nc & 1) { + vst1_lane_bf16(c0, vout0x0123, 0); + vst1_lane_bf16(c1, vout1x0123, 0); + vst1_lane_bf16(c2, vout2x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/3x4c8-minmax-neonbf16-bfmlal.c b/src/bf16-gemm/gen/3x4c8-minmax-neonbf16-bfmlal.c new file mode 100644 index 000000000000..097b07e75e13 --- /dev/null +++ b/src/bf16-gemm/gen/3x4c8-minmax-neonbf16-bfmlal.c @@ -0,0 +1,235 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c8-neonbf16.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ + +#include + +#include + +#include + + +void xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 3); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(bfloat16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const bfloat16_t* a0 = (const bfloat16_t*) a; + bfloat16_t* c0 = (bfloat16_t*) c; + const bfloat16_t* a1 = (const bfloat16_t*) ((uintptr_t) a0 + a_stride); + bfloat16_t* c1 = (bfloat16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const bfloat16_t* a2 = (const bfloat16_t*) ((uintptr_t) a1 + a_stride); + bfloat16_t* c2 = (bfloat16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + + const bfloat16_t* w = (const bfloat16_t*) w_ptr; + do { + float32x4_t vacc0x0 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x1 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x2 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x3 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc1x0 = vacc0x0; + float32x4_t vacc1x1 = vacc0x1; + float32x4_t vacc1x2 = vacc0x2; + float32x4_t vacc1x3 = vacc0x3; + float32x4_t vacc2x0 = vacc0x0; + float32x4_t vacc2x1 = vacc0x1; + float32x4_t vacc2x2 = vacc0x2; + float32x4_t vacc2x3 = vacc0x3; + + size_t k = kc; + for (; k >= 8 * sizeof(bfloat16_t); k -= 8 * sizeof(bfloat16_t)) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 += 8; + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 += 8; + const bfloat16x8_t va2 = vld1q_bf16(a2); a2 += 8; + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + vacc0x0 = vbfmlalbq_f32(vacc0x0, va0, vb0); + vacc1x0 = vbfmlalbq_f32(vacc1x0, va1, vb0); + vacc2x0 = vbfmlalbq_f32(vacc2x0, va2, vb0); + vacc0x1 = vbfmlalbq_f32(vacc0x1, va0, vb1); + vacc1x1 = vbfmlalbq_f32(vacc1x1, va1, vb1); + vacc2x1 = vbfmlalbq_f32(vacc2x1, va2, vb1); + vacc0x2 = vbfmlalbq_f32(vacc0x2, va0, vb2); + vacc1x2 = vbfmlalbq_f32(vacc1x2, va1, vb2); + vacc2x2 = vbfmlalbq_f32(vacc2x2, va2, vb2); + vacc0x3 = vbfmlalbq_f32(vacc0x3, va0, vb3); + vacc1x3 = vbfmlalbq_f32(vacc1x3, va1, vb3); + vacc2x3 = vbfmlalbq_f32(vacc2x3, va2, vb3); + + vacc0x0 = vbfmlaltq_f32(vacc0x0, va0, vb0); + vacc1x0 = vbfmlaltq_f32(vacc1x0, va1, vb0); + vacc2x0 = vbfmlaltq_f32(vacc2x0, va2, vb0); + vacc0x1 = vbfmlaltq_f32(vacc0x1, va0, vb1); + vacc1x1 = vbfmlaltq_f32(vacc1x1, va1, vb1); + vacc2x1 = vbfmlaltq_f32(vacc2x1, va2, vb1); + vacc0x2 = vbfmlaltq_f32(vacc0x2, va0, vb2); + vacc1x2 = vbfmlaltq_f32(vacc1x2, va1, vb2); + vacc2x2 = vbfmlaltq_f32(vacc2x2, va2, vb2); + vacc0x3 = vbfmlaltq_f32(vacc0x3, va0, vb3); + vacc1x3 = vbfmlaltq_f32(vacc1x3, va1, vb3); + vacc2x3 = vbfmlaltq_f32(vacc2x3, va2, vb3); + } + if XNN_UNLIKELY(k != 0) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 = (const bfloat16_t*) ((uintptr_t) a0 + k); + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 = (const bfloat16_t*) ((uintptr_t) a1 + k); + const bfloat16x8_t 
va2 = vld1q_bf16(a2); a2 = (const bfloat16_t*) ((uintptr_t) a2 + k); + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vreinterpretq_u16_bf16(vb0), vmovq_n_u16(0)); + const uint16x8_t vm1 = vceqq_u16(vreinterpretq_u16_bf16(vb1), vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vreinterpretq_u16_bf16(vb2), vmovq_n_u16(0)); + const uint16x8_t vm3 = vceqq_u16(vreinterpretq_u16_bf16(vb3), vmovq_n_u16(0)); + + const bfloat16x8_t va0x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm0)); + vacc0x0 = vbfmlalbq_f32(vacc0x0, va0x0, vb0); + vacc0x0 = vbfmlaltq_f32(vacc0x0, va0x0, vb0); + const bfloat16x8_t va1x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm0)); + vacc1x0 = vbfmlalbq_f32(vacc1x0, va1x0, vb0); + vacc1x0 = vbfmlaltq_f32(vacc1x0, va1x0, vb0); + const bfloat16x8_t va2x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm0)); + vacc2x0 = vbfmlalbq_f32(vacc2x0, va2x0, vb0); + vacc2x0 = vbfmlaltq_f32(vacc2x0, va2x0, vb0); + const bfloat16x8_t va0x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm1)); + vacc0x1 = vbfmlalbq_f32(vacc0x1, va0x1, vb1); + vacc0x1 = vbfmlaltq_f32(vacc0x1, va0x1, vb1); + const bfloat16x8_t va1x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm1)); + vacc1x1 = vbfmlalbq_f32(vacc1x1, va1x1, vb1); + vacc1x1 = vbfmlaltq_f32(vacc1x1, va1x1, vb1); + const bfloat16x8_t va2x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm1)); + vacc2x1 = vbfmlalbq_f32(vacc2x1, va2x1, vb1); + vacc2x1 = vbfmlaltq_f32(vacc2x1, va2x1, vb1); + const bfloat16x8_t va0x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm2)); + vacc0x2 = vbfmlalbq_f32(vacc0x2, va0x2, vb2); + vacc0x2 = vbfmlaltq_f32(vacc0x2, va0x2, vb2); + const bfloat16x8_t va1x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm2)); + vacc1x2 = vbfmlalbq_f32(vacc1x2, va1x2, vb2); + vacc1x2 = vbfmlaltq_f32(vacc1x2, va1x2, vb2); + const bfloat16x8_t va2x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm2)); + vacc2x2 = vbfmlalbq_f32(vacc2x2, va2x2, vb2); + vacc2x2 = vbfmlaltq_f32(vacc2x2, va2x2, vb2); + const bfloat16x8_t va0x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm3)); + vacc0x3 = vbfmlalbq_f32(vacc0x3, va0x3, vb3); + vacc0x3 = vbfmlaltq_f32(vacc0x3, va0x3, vb3); + const bfloat16x8_t va1x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm3)); + vacc1x3 = vbfmlalbq_f32(vacc1x3, va1x3, vb3); + vacc1x3 = vbfmlaltq_f32(vacc1x3, va1x3, vb3); + const bfloat16x8_t va2x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm3)); + vacc2x3 = vbfmlalbq_f32(vacc2x3, va2x3, vb3); + vacc2x3 = vbfmlaltq_f32(vacc2x3, va2x3, vb3); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc1x01 = vpaddq_f32(vacc1x0, vacc1x1); + const float32x4_t vacc2x01 = vpaddq_f32(vacc2x0, vacc2x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + const float32x4_t vacc1x23 = vpaddq_f32(vacc1x2, vacc1x3); + const float32x4_t vacc2x23 = vpaddq_f32(vacc2x2, vacc2x3); + + float32x4_t vacc0x0123 = vpaddq_f32(vacc0x01, vacc0x23); + float32x4_t vacc1x0123 = vpaddq_f32(vacc1x01, vacc1x23); + float32x4_t vacc2x0123 = vpaddq_f32(vacc2x01, vacc2x23); +#else + const float32x2_t vsum0x0 = 
vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum1x0 = vadd_f32(vget_low_f32(vacc1x0), vget_high_f32(vacc1x0)); + const float32x2_t vsum2x0 = vadd_f32(vget_low_f32(vacc2x0), vget_high_f32(vacc2x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum1x1 = vadd_f32(vget_low_f32(vacc1x1), vget_high_f32(vacc1x1)); + const float32x2_t vsum2x1 = vadd_f32(vget_low_f32(vacc2x1), vget_high_f32(vacc2x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum1x2 = vadd_f32(vget_low_f32(vacc1x2), vget_high_f32(vacc1x2)); + const float32x2_t vsum2x2 = vadd_f32(vget_low_f32(vacc2x2), vget_high_f32(vacc2x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + const float32x2_t vsum1x3 = vadd_f32(vget_low_f32(vacc1x3), vget_high_f32(vacc1x3)); + const float32x2_t vsum2x3 = vadd_f32(vget_low_f32(vacc2x3), vget_high_f32(vacc2x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); + float32x4_t vacc1x0123 = vcombine_f32(vpadd_f32(vsum1x0, vsum1x1), vpadd_f32(vsum1x2, vsum1x3)); + float32x4_t vacc2x0123 = vcombine_f32(vpadd_f32(vsum2x0, vsum2x1), vpadd_f32(vsum2x2, vsum2x3)); +#endif + + const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + vacc2x0123 = vminq_f32(vacc2x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + vacc2x0123 = vmaxq_f32(vacc2x0123, vmin); + + bfloat16x4_t vout0x0123 = vcvt_bf16_f32(vacc0x0123); + bfloat16x4_t vout1x0123 = vcvt_bf16_f32(vacc1x0123); + bfloat16x4_t vout2x0123 = vcvt_bf16_f32(vacc2x0123); + + if XNN_LIKELY(nc >= 4) { + vst1_bf16(c0, vout0x0123); + c0 = (bfloat16_t*) ((uintptr_t) c0 + cn_stride); + vst1_bf16(c1, vout1x0123); + c1 = (bfloat16_t*) ((uintptr_t) c1 + cn_stride); + vst1_bf16(c2, vout2x0123); + c2 = (bfloat16_t*) ((uintptr_t) c2 + cn_stride); + + a0 = (const bfloat16_t*) ((uintptr_t) a0 - kc); + a1 = (const bfloat16_t*) ((uintptr_t) a1 - kc); + a2 = (const bfloat16_t*) ((uintptr_t) a2 - kc); + + nc -= 4; + } else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_bf16(vout0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_bf16(vout1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_bf16(vout2x0123), 0); c2 += 2; + + vout0x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout0x0123), vreinterpret_u16_bf16(vout0x0123), 2)); + vout1x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout1x0123), vreinterpret_u16_bf16(vout1x0123), 2)); + vout2x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout2x0123), vreinterpret_u16_bf16(vout2x0123), 2)); + } + if (nc & 1) { + vst1_lane_bf16(c0, vout0x0123, 0); + vst1_lane_bf16(c1, vout1x0123, 0); + vst1_lane_bf16(c2, vout2x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/3x4c8-minmax-neonfma-shland.c b/src/bf16-gemm/gen/3x4c8-minmax-neonfma-shland.c new file mode 100644 index 000000000000..370ac8ab1a2c --- /dev/null +++ b/src/bf16-gemm/gen/3x4c8-minmax-neonfma-shland.c @@ -0,0 +1,292 @@ +// Auto-generated file. Do not edit! 
+// Template: src/bf16-gemm/c8-neon.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include + +#include + +#include + + +void xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 3); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(uint16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const uint16_t* a0 = (const uint16_t*) a; + uint16_t* c0 = (uint16_t*) c; + const uint16_t* a1 = (const uint16_t*) ((uintptr_t) a0 + a_stride); + uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const uint16_t* a2 = (const uint16_t*) ((uintptr_t) a1 + a_stride); + uint16_t* c2 = (uint16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + + const uint16_t* w = (const uint16_t*) w_ptr; + const uint16x8_t vmask = vreinterpretq_u16_u32(vmovq_n_u32(UINT32_C(0xFFFF0000))); + do { + float32x4_t vacc0x0 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x1 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x2 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x3 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc1x0 = vacc0x0; + float32x4_t vacc1x1 = vacc0x1; + float32x4_t vacc1x2 = vacc0x2; + float32x4_t vacc1x3 = vacc0x3; + float32x4_t vacc2x0 = vacc0x0; + float32x4_t vacc2x1 = vacc0x1; + float32x4_t vacc2x2 = vacc0x2; + float32x4_t vacc2x3 = vacc0x3; + + size_t k = kc; + for (; k >= 8 * sizeof(uint16_t); k -= 8 * sizeof(uint16_t)) { + const uint16x8_t va0 = vld1q_u16(a0); a0 += 8; + const uint16x8_t va1 = vld1q_u16(a1); a1 += 8; + const uint16x8_t va2 = vld1q_u16(a2); a2 += 8; + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const float32x4_t va0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0), 16)); + const float32x4_t va1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1), 16)); + const float32x4_t va2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va2), 16)); + + const float32x4_t vb0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb0), 16)); + const float32x4_t vb1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb1), 16)); + const float32x4_t vb2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb2), 16)); + const float32x4_t vb3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb3), 16)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0e, vb0e); + vacc1x0 = vfmaq_f32(vacc1x0, va1e, vb0e); + vacc2x0 = vfmaq_f32(vacc2x0, va2e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0e, vb1e); + vacc1x1 = vfmaq_f32(vacc1x1, va1e, vb1e); + vacc2x1 = vfmaq_f32(vacc2x1, va2e, vb1e); + vacc0x2 = vfmaq_f32(vacc0x2, va0e, vb2e); + vacc1x2 = vfmaq_f32(vacc1x2, va1e, vb2e); + vacc2x2 = vfmaq_f32(vacc2x2, va2e, vb2e); + 
vacc0x3 = vfmaq_f32(vacc0x3, va0e, vb3e); + vacc1x3 = vfmaq_f32(vacc1x3, va1e, vb3e); + vacc2x3 = vfmaq_f32(vacc2x3, va2e, vb3e); + + const float32x4_t va0o = vreinterpretq_f32_u16(vandq_u16(va0, vmask)); + const float32x4_t va1o = vreinterpretq_f32_u16(vandq_u16(va1, vmask)); + const float32x4_t va2o = vreinterpretq_f32_u16(vandq_u16(va2, vmask)); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vandq_u16(vb0, vmask)); + const float32x4_t vb1o = vreinterpretq_f32_u16(vandq_u16(vb1, vmask)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vandq_u16(vb2, vmask)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vandq_u16(vb3, vmask)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0o, vb0o); + vacc1x0 = vfmaq_f32(vacc1x0, va1o, vb0o); + vacc2x0 = vfmaq_f32(vacc2x0, va2o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0o, vb1o); + vacc1x1 = vfmaq_f32(vacc1x1, va1o, vb1o); + vacc2x1 = vfmaq_f32(vacc2x1, va2o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0o, vb2o); + vacc1x2 = vfmaq_f32(vacc1x2, va1o, vb2o); + vacc2x2 = vfmaq_f32(vacc2x2, va2o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0o, vb3o); + vacc1x3 = vfmaq_f32(vacc1x3, va1o, vb3o); + vacc2x3 = vfmaq_f32(vacc2x3, va2o, vb3o); + } + if XNN_UNLIKELY(k != 0) { + const uint16x8_t va0 = vld1q_u16(a0); a0 = (const uint16_t*) ((uintptr_t) a0 + k); + const uint16x8_t va1 = vld1q_u16(a1); a1 = (const uint16_t*) ((uintptr_t) a1 + k); + const uint16x8_t va2 = vld1q_u16(a2); a2 = (const uint16_t*) ((uintptr_t) a2 + k); + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vb0, vmovq_n_u16(0)); + const uint16x8_t vm1 = vceqq_u16(vb1, vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vb2, vmovq_n_u16(0)); + const uint16x8_t vm3 = vceqq_u16(vb3, vmovq_n_u16(0)); + + const float32x4_t vb0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb0), 16)); + const float32x4_t vb1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb1), 16)); + const float32x4_t vb2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb2), 16)); + const float32x4_t vb3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb3), 16)); + + const uint16x8_t va0x0 = vbicq_u16(va0, vm0); + const uint16x8_t va1x0 = vbicq_u16(va1, vm0); + const uint16x8_t va2x0 = vbicq_u16(va2, vm0); + const uint16x8_t va0x1 = vbicq_u16(va0, vm1); + const uint16x8_t va1x1 = vbicq_u16(va1, vm1); + const uint16x8_t va2x1 = vbicq_u16(va2, vm1); + const uint16x8_t va0x2 = vbicq_u16(va0, vm2); + const uint16x8_t va1x2 = vbicq_u16(va1, vm2); + const uint16x8_t va2x2 = vbicq_u16(va2, vm2); + const uint16x8_t va0x3 = vbicq_u16(va0, vm3); + const uint16x8_t va1x3 = vbicq_u16(va1, vm3); + const uint16x8_t va2x3 = vbicq_u16(va2, vm3); + + const float32x4_t va0x0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x0), 16)); + const float32x4_t va1x0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1x0), 16)); + const float32x4_t va2x0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va2x0), 16)); + const float32x4_t va0x1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x1), 16)); + const float32x4_t va1x1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1x1), 16)); + const float32x4_t va2x1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va2x1), 16)); + const float32x4_t va0x2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x2), 16)); + const 
float32x4_t va1x2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1x2), 16)); + const float32x4_t va2x2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va2x2), 16)); + const float32x4_t va0x3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x3), 16)); + const float32x4_t va1x3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1x3), 16)); + const float32x4_t va2x3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va2x3), 16)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0e, vb0e); + vacc1x0 = vfmaq_f32(vacc1x0, va1x0e, vb0e); + vacc2x0 = vfmaq_f32(vacc2x0, va2x0e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1e, vb1e); + vacc1x1 = vfmaq_f32(vacc1x1, va1x1e, vb1e); + vacc2x1 = vfmaq_f32(vacc2x1, va2x1e, vb1e); + vacc0x2 = vfmaq_f32(vacc0x2, va0x2e, vb2e); + vacc1x2 = vfmaq_f32(vacc1x2, va1x2e, vb2e); + vacc2x2 = vfmaq_f32(vacc2x2, va2x2e, vb2e); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3e, vb3e); + vacc1x3 = vfmaq_f32(vacc1x3, va1x3e, vb3e); + vacc2x3 = vfmaq_f32(vacc2x3, va2x3e, vb3e); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vandq_u16(vb0, vmask)); + const float32x4_t vb1o = vreinterpretq_f32_u16(vandq_u16(vb1, vmask)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vandq_u16(vb2, vmask)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vandq_u16(vb3, vmask)); + + const float32x4_t va0x0o = vreinterpretq_f32_u16(vandq_u16(va0x0, vmask)); + const float32x4_t va1x0o = vreinterpretq_f32_u16(vandq_u16(va1x0, vmask)); + const float32x4_t va2x0o = vreinterpretq_f32_u16(vandq_u16(va2x0, vmask)); + const float32x4_t va0x1o = vreinterpretq_f32_u16(vandq_u16(va0x1, vmask)); + const float32x4_t va1x1o = vreinterpretq_f32_u16(vandq_u16(va1x1, vmask)); + const float32x4_t va2x1o = vreinterpretq_f32_u16(vandq_u16(va2x1, vmask)); + const float32x4_t va0x2o = vreinterpretq_f32_u16(vandq_u16(va0x2, vmask)); + const float32x4_t va1x2o = vreinterpretq_f32_u16(vandq_u16(va1x2, vmask)); + const float32x4_t va2x2o = vreinterpretq_f32_u16(vandq_u16(va2x2, vmask)); + const float32x4_t va0x3o = vreinterpretq_f32_u16(vandq_u16(va0x3, vmask)); + const float32x4_t va1x3o = vreinterpretq_f32_u16(vandq_u16(va1x3, vmask)); + const float32x4_t va2x3o = vreinterpretq_f32_u16(vandq_u16(va2x3, vmask)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0o, vb0o); + vacc1x0 = vfmaq_f32(vacc1x0, va1x0o, vb0o); + vacc2x0 = vfmaq_f32(vacc2x0, va2x0o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1o, vb1o); + vacc1x1 = vfmaq_f32(vacc1x1, va1x1o, vb1o); + vacc2x1 = vfmaq_f32(vacc2x1, va2x1o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0x2o, vb2o); + vacc1x2 = vfmaq_f32(vacc1x2, va1x2o, vb2o); + vacc2x2 = vfmaq_f32(vacc2x2, va2x2o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3o, vb3o); + vacc1x3 = vfmaq_f32(vacc1x3, va1x3o, vb3o); + vacc2x3 = vfmaq_f32(vacc2x3, va2x3o, vb3o); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc1x01 = vpaddq_f32(vacc1x0, vacc1x1); + const float32x4_t vacc2x01 = vpaddq_f32(vacc2x0, vacc2x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + const float32x4_t vacc1x23 = vpaddq_f32(vacc1x2, vacc1x3); + const float32x4_t vacc2x23 = vpaddq_f32(vacc2x2, vacc2x3); + + float32x4_t vacc0x0123 = vpaddq_f32(vacc0x01, vacc0x23); + float32x4_t vacc1x0123 = vpaddq_f32(vacc1x01, vacc1x23); + float32x4_t vacc2x0123 = vpaddq_f32(vacc2x01, vacc2x23); +#else + const float32x2_t vsum0x0 = vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum1x0 = vadd_f32(vget_low_f32(vacc1x0), 
vget_high_f32(vacc1x0)); + const float32x2_t vsum2x0 = vadd_f32(vget_low_f32(vacc2x0), vget_high_f32(vacc2x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum1x1 = vadd_f32(vget_low_f32(vacc1x1), vget_high_f32(vacc1x1)); + const float32x2_t vsum2x1 = vadd_f32(vget_low_f32(vacc2x1), vget_high_f32(vacc2x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum1x2 = vadd_f32(vget_low_f32(vacc1x2), vget_high_f32(vacc1x2)); + const float32x2_t vsum2x2 = vadd_f32(vget_low_f32(vacc2x2), vget_high_f32(vacc2x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + const float32x2_t vsum1x3 = vadd_f32(vget_low_f32(vacc1x3), vget_high_f32(vacc1x3)); + const float32x2_t vsum2x3 = vadd_f32(vget_low_f32(vacc2x3), vget_high_f32(vacc2x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); + float32x4_t vacc1x0123 = vcombine_f32(vpadd_f32(vsum1x0, vsum1x1), vpadd_f32(vsum1x2, vsum1x3)); + float32x4_t vacc2x0123 = vcombine_f32(vpadd_f32(vsum2x0, vsum2x1), vpadd_f32(vsum2x2, vsum2x3)); +#endif + + const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + vacc2x0123 = vminq_f32(vacc2x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + vacc2x0123 = vmaxq_f32(vacc2x0123, vmin); + + uint16x4_t vout0x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc0x0123), 16); + uint16x4_t vout1x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc1x0123), 16); + uint16x4_t vout2x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc2x0123), 16); + + if XNN_LIKELY(nc >= 4) { + vst1_u16(c0, vout0x0123); + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + vst1_u16(c1, vout1x0123); + c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); + vst1_u16(c2, vout2x0123); + c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride); + + a0 = (const uint16_t*) ((uintptr_t) a0 - kc); + a1 = (const uint16_t*) ((uintptr_t) a1 - kc); + a2 = (const uint16_t*) ((uintptr_t) a2 - kc); + + nc -= 4; + } else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_u16(vout0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_u16(vout1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_u16(vout2x0123), 0); c2 += 2; + + vout0x0123 = vext_u16(vout0x0123, vout0x0123, 2); + vout1x0123 = vext_u16(vout1x0123, vout1x0123, 2); + vout2x0123 = vext_u16(vout2x0123, vout2x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vout0x0123, 0); + vst1_lane_u16(c1, vout1x0123, 0); + vst1_lane_u16(c2, vout2x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/3x4c8-minmax-neonfma-zip.c b/src/bf16-gemm/gen/3x4c8-minmax-neonfma-zip.c new file mode 100644 index 000000000000..e13b541fe3bb --- /dev/null +++ b/src/bf16-gemm/gen/3x4c8-minmax-neonfma-zip.c @@ -0,0 +1,292 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c8-neon.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ + +#include + +#include + +#include + + +void xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 3); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(uint16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const uint16_t* a0 = (const uint16_t*) a; + uint16_t* c0 = (uint16_t*) c; + const uint16_t* a1 = (const uint16_t*) ((uintptr_t) a0 + a_stride); + uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const uint16_t* a2 = (const uint16_t*) ((uintptr_t) a1 + a_stride); + uint16_t* c2 = (uint16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + + const uint16_t* w = (const uint16_t*) w_ptr; + const uint16x8_t vzero = vmovq_n_u16(0); + do { + float32x4_t vacc0x0 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x1 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x2 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x3 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc1x0 = vacc0x0; + float32x4_t vacc1x1 = vacc0x1; + float32x4_t vacc1x2 = vacc0x2; + float32x4_t vacc1x3 = vacc0x3; + float32x4_t vacc2x0 = vacc0x0; + float32x4_t vacc2x1 = vacc0x1; + float32x4_t vacc2x2 = vacc0x2; + float32x4_t vacc2x3 = vacc0x3; + + size_t k = kc; + for (; k >= 8 * sizeof(uint16_t); k -= 8 * sizeof(uint16_t)) { + const uint16x8_t va0 = vld1q_u16(a0); a0 += 8; + const uint16x8_t va1 = vld1q_u16(a1); a1 += 8; + const uint16x8_t va2 = vld1q_u16(a2); a2 += 8; + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const float32x4_t va0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0)); + const float32x4_t va1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1)); + const float32x4_t va2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va2)); + + const float32x4_t vb0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb0)); + const float32x4_t vb1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb1)); + const float32x4_t vb2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb2)); + const float32x4_t vb3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0e, vb0e); + vacc1x0 = vfmaq_f32(vacc1x0, va1e, vb0e); + vacc2x0 = vfmaq_f32(vacc2x0, va2e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0e, vb1e); + vacc1x1 = vfmaq_f32(vacc1x1, va1e, vb1e); + vacc2x1 = vfmaq_f32(vacc2x1, va2e, vb1e); + vacc0x2 = vfmaq_f32(vacc0x2, va0e, vb2e); + vacc1x2 = vfmaq_f32(vacc1x2, va1e, vb2e); + vacc2x2 = vfmaq_f32(vacc2x2, va2e, vb2e); + vacc0x3 = vfmaq_f32(vacc0x3, va0e, vb3e); + vacc1x3 = vfmaq_f32(vacc1x3, va1e, vb3e); + vacc2x3 = vfmaq_f32(vacc2x3, va2e, vb3e); + + const float32x4_t va0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0)); + const float32x4_t va1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va1)); + const float32x4_t va2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va2)); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb0)); + 
const float32x4_t vb1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb1)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb2)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0o, vb0o); + vacc1x0 = vfmaq_f32(vacc1x0, va1o, vb0o); + vacc2x0 = vfmaq_f32(vacc2x0, va2o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0o, vb1o); + vacc1x1 = vfmaq_f32(vacc1x1, va1o, vb1o); + vacc2x1 = vfmaq_f32(vacc2x1, va2o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0o, vb2o); + vacc1x2 = vfmaq_f32(vacc1x2, va1o, vb2o); + vacc2x2 = vfmaq_f32(vacc2x2, va2o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0o, vb3o); + vacc1x3 = vfmaq_f32(vacc1x3, va1o, vb3o); + vacc2x3 = vfmaq_f32(vacc2x3, va2o, vb3o); + } + if XNN_UNLIKELY(k != 0) { + const uint16x8_t va0 = vld1q_u16(a0); a0 = (const uint16_t*) ((uintptr_t) a0 + k); + const uint16x8_t va1 = vld1q_u16(a1); a1 = (const uint16_t*) ((uintptr_t) a1 + k); + const uint16x8_t va2 = vld1q_u16(a2); a2 = (const uint16_t*) ((uintptr_t) a2 + k); + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vb0, vmovq_n_u16(0)); + const uint16x8_t vm1 = vceqq_u16(vb1, vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vb2, vmovq_n_u16(0)); + const uint16x8_t vm3 = vceqq_u16(vb3, vmovq_n_u16(0)); + + const float32x4_t vb0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb0)); + const float32x4_t vb1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb1)); + const float32x4_t vb2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb2)); + const float32x4_t vb3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb3)); + + const uint16x8_t va0x0 = vbicq_u16(va0, vm0); + const uint16x8_t va1x0 = vbicq_u16(va1, vm0); + const uint16x8_t va2x0 = vbicq_u16(va2, vm0); + const uint16x8_t va0x1 = vbicq_u16(va0, vm1); + const uint16x8_t va1x1 = vbicq_u16(va1, vm1); + const uint16x8_t va2x1 = vbicq_u16(va2, vm1); + const uint16x8_t va0x2 = vbicq_u16(va0, vm2); + const uint16x8_t va1x2 = vbicq_u16(va1, vm2); + const uint16x8_t va2x2 = vbicq_u16(va2, vm2); + const uint16x8_t va0x3 = vbicq_u16(va0, vm3); + const uint16x8_t va1x3 = vbicq_u16(va1, vm3); + const uint16x8_t va2x3 = vbicq_u16(va2, vm3); + + const float32x4_t va0x0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0x0)); + const float32x4_t va1x0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1x0)); + const float32x4_t va2x0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va2x0)); + const float32x4_t va0x1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0x1)); + const float32x4_t va1x1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1x1)); + const float32x4_t va2x1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va2x1)); + const float32x4_t va0x2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0x2)); + const float32x4_t va1x2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1x2)); + const float32x4_t va2x2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va2x2)); + const float32x4_t va0x3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0x3)); + const float32x4_t va1x3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1x3)); + const float32x4_t va2x3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va2x3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0e, vb0e); + vacc1x0 = vfmaq_f32(vacc1x0, va1x0e, vb0e); + vacc2x0 = vfmaq_f32(vacc2x0, va2x0e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1e, vb1e); + vacc1x1 = vfmaq_f32(vacc1x1, va1x1e, vb1e); + vacc2x1 = vfmaq_f32(vacc2x1, va2x1e, vb1e); + 
vacc0x2 = vfmaq_f32(vacc0x2, va0x2e, vb2e); + vacc1x2 = vfmaq_f32(vacc1x2, va1x2e, vb2e); + vacc2x2 = vfmaq_f32(vacc2x2, va2x2e, vb2e); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3e, vb3e); + vacc1x3 = vfmaq_f32(vacc1x3, va1x3e, vb3e); + vacc2x3 = vfmaq_f32(vacc2x3, va2x3e, vb3e); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb0)); + const float32x4_t vb1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb1)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb2)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb3)); + + const float32x4_t va0x0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x0)); + const float32x4_t va1x0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va1x0)); + const float32x4_t va2x0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va2x0)); + const float32x4_t va0x1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x1)); + const float32x4_t va1x1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va1x1)); + const float32x4_t va2x1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va2x1)); + const float32x4_t va0x2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x2)); + const float32x4_t va1x2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va1x2)); + const float32x4_t va2x2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va2x2)); + const float32x4_t va0x3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x3)); + const float32x4_t va1x3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va1x3)); + const float32x4_t va2x3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va2x3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0o, vb0o); + vacc1x0 = vfmaq_f32(vacc1x0, va1x0o, vb0o); + vacc2x0 = vfmaq_f32(vacc2x0, va2x0o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1o, vb1o); + vacc1x1 = vfmaq_f32(vacc1x1, va1x1o, vb1o); + vacc2x1 = vfmaq_f32(vacc2x1, va2x1o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0x2o, vb2o); + vacc1x2 = vfmaq_f32(vacc1x2, va1x2o, vb2o); + vacc2x2 = vfmaq_f32(vacc2x2, va2x2o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3o, vb3o); + vacc1x3 = vfmaq_f32(vacc1x3, va1x3o, vb3o); + vacc2x3 = vfmaq_f32(vacc2x3, va2x3o, vb3o); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc1x01 = vpaddq_f32(vacc1x0, vacc1x1); + const float32x4_t vacc2x01 = vpaddq_f32(vacc2x0, vacc2x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + const float32x4_t vacc1x23 = vpaddq_f32(vacc1x2, vacc1x3); + const float32x4_t vacc2x23 = vpaddq_f32(vacc2x2, vacc2x3); + + float32x4_t vacc0x0123 = vpaddq_f32(vacc0x01, vacc0x23); + float32x4_t vacc1x0123 = vpaddq_f32(vacc1x01, vacc1x23); + float32x4_t vacc2x0123 = vpaddq_f32(vacc2x01, vacc2x23); +#else + const float32x2_t vsum0x0 = vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum1x0 = vadd_f32(vget_low_f32(vacc1x0), vget_high_f32(vacc1x0)); + const float32x2_t vsum2x0 = vadd_f32(vget_low_f32(vacc2x0), vget_high_f32(vacc2x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum1x1 = vadd_f32(vget_low_f32(vacc1x1), vget_high_f32(vacc1x1)); + const float32x2_t vsum2x1 = vadd_f32(vget_low_f32(vacc2x1), vget_high_f32(vacc2x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum1x2 = vadd_f32(vget_low_f32(vacc1x2), vget_high_f32(vacc1x2)); + const float32x2_t vsum2x2 = vadd_f32(vget_low_f32(vacc2x2), vget_high_f32(vacc2x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + const float32x2_t vsum1x3 = 
vadd_f32(vget_low_f32(vacc1x3), vget_high_f32(vacc1x3)); + const float32x2_t vsum2x3 = vadd_f32(vget_low_f32(vacc2x3), vget_high_f32(vacc2x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); + float32x4_t vacc1x0123 = vcombine_f32(vpadd_f32(vsum1x0, vsum1x1), vpadd_f32(vsum1x2, vsum1x3)); + float32x4_t vacc2x0123 = vcombine_f32(vpadd_f32(vsum2x0, vsum2x1), vpadd_f32(vsum2x2, vsum2x3)); +#endif + + const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + vacc2x0123 = vminq_f32(vacc2x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + vacc2x0123 = vmaxq_f32(vacc2x0123, vmin); + + uint16x4_t vout0x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc0x0123), 16); + uint16x4_t vout1x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc1x0123), 16); + uint16x4_t vout2x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc2x0123), 16); + + if XNN_LIKELY(nc >= 4) { + vst1_u16(c0, vout0x0123); + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + vst1_u16(c1, vout1x0123); + c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); + vst1_u16(c2, vout2x0123); + c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride); + + a0 = (const uint16_t*) ((uintptr_t) a0 - kc); + a1 = (const uint16_t*) ((uintptr_t) a1 - kc); + a2 = (const uint16_t*) ((uintptr_t) a2 - kc); + + nc -= 4; + } else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_u16(vout0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_u16(vout1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_u16(vout2x0123), 0); c2 += 2; + + vout0x0123 = vext_u16(vout0x0123, vout0x0123, 2); + vout1x0123 = vext_u16(vout1x0123, vout1x0123, 2); + vout2x0123 = vext_u16(vout2x0123, vout2x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vout0x0123, 0); + vst1_lane_u16(c1, vout1x0123, 0); + vst1_lane_u16(c2, vout2x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/4x4c8-minmax-neonbf16-bfdot.c b/src/bf16-gemm/gen/4x4c8-minmax-neonbf16-bfdot.c new file mode 100644 index 000000000000..8e62c2b4b07d --- /dev/null +++ b/src/bf16-gemm/gen/4x4c8-minmax-neonbf16-bfdot.c @@ -0,0 +1,251 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c8-neonbf16.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/gemm.h> + + +void xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 4); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(bfloat16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const bfloat16_t* a0 = (const bfloat16_t*) a; + bfloat16_t* c0 = (bfloat16_t*) c; + const bfloat16_t* a1 = (const bfloat16_t*) ((uintptr_t) a0 + a_stride); + bfloat16_t* c1 = (bfloat16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const bfloat16_t* a2 = (const bfloat16_t*) ((uintptr_t) a1 + a_stride); + bfloat16_t* c2 = (bfloat16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const bfloat16_t* a3 = (const bfloat16_t*) ((uintptr_t) a2 + a_stride); + bfloat16_t* c3 = (bfloat16_t*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr != 4) { + a3 = a2; + c3 = c2; + } + + const bfloat16_t* w = (const bfloat16_t*) w_ptr; + do { + float32x4_t vacc0x0 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x1 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x2 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x3 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc1x0 = vacc0x0; + float32x4_t vacc1x1 = vacc0x1; + float32x4_t vacc1x2 = vacc0x2; + float32x4_t vacc1x3 = vacc0x3; + float32x4_t vacc2x0 = vacc0x0; + float32x4_t vacc2x1 = vacc0x1; + float32x4_t vacc2x2 = vacc0x2; + float32x4_t vacc2x3 = vacc0x3; + float32x4_t vacc3x0 = vacc0x0; + float32x4_t vacc3x1 = vacc0x1; + float32x4_t vacc3x2 = vacc0x2; + float32x4_t vacc3x3 = vacc0x3; + + size_t k = kc; + for (; k >= 8 * sizeof(bfloat16_t); k -= 8 * sizeof(bfloat16_t)) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 += 8; + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 += 8; + const bfloat16x8_t va2 = vld1q_bf16(a2); a2 += 8; + const bfloat16x8_t va3 = vld1q_bf16(a3); a3 += 8; + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + vacc0x0 = vbfdotq_f32(vacc0x0, va0, vb0); + vacc1x0 = vbfdotq_f32(vacc1x0, va1, vb0); + vacc2x0 = vbfdotq_f32(vacc2x0, va2, vb0); + vacc3x0 = vbfdotq_f32(vacc3x0, va3, vb0); + vacc0x1 = vbfdotq_f32(vacc0x1, va0, vb1); + vacc1x1 = vbfdotq_f32(vacc1x1, va1, vb1); + vacc2x1 = vbfdotq_f32(vacc2x1, va2, vb1); + vacc3x1 = vbfdotq_f32(vacc3x1, va3, vb1); + vacc0x2 = vbfdotq_f32(vacc0x2, va0, vb2); + vacc1x2 = vbfdotq_f32(vacc1x2, va1, vb2); + vacc2x2 = vbfdotq_f32(vacc2x2, va2, vb2); + vacc3x2 = vbfdotq_f32(vacc3x2, va3, vb2); + vacc0x3 = vbfdotq_f32(vacc0x3, va0, vb3); + vacc1x3 = vbfdotq_f32(vacc1x3, va1, vb3); + vacc2x3 = vbfdotq_f32(vacc2x3, va2, vb3); + vacc3x3 = vbfdotq_f32(vacc3x3, va3, vb3); + } + if XNN_UNLIKELY(k != 0) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 = (const bfloat16_t*) ((uintptr_t) a0 + k); + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 = (const bfloat16_t*) ((uintptr_t) a1 + k); + const bfloat16x8_t va2 = vld1q_bf16(a2);
a2 = (const bfloat16_t*) ((uintptr_t) a2 + k); + const bfloat16x8_t va3 = vld1q_bf16(a3); a3 = (const bfloat16_t*) ((uintptr_t) a3 + k); + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vreinterpretq_u16_bf16(vb0), vmovq_n_u16(0)); + const uint16x8_t vm1 = vceqq_u16(vreinterpretq_u16_bf16(vb1), vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vreinterpretq_u16_bf16(vb2), vmovq_n_u16(0)); + const uint16x8_t vm3 = vceqq_u16(vreinterpretq_u16_bf16(vb3), vmovq_n_u16(0)); + + const bfloat16x8_t va0x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm0)); + vacc0x0 = vbfdotq_f32(vacc0x0, va0x0, vb0); + const bfloat16x8_t va1x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm0)); + vacc1x0 = vbfdotq_f32(vacc1x0, va1x0, vb0); + const bfloat16x8_t va2x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm0)); + vacc2x0 = vbfdotq_f32(vacc2x0, va2x0, vb0); + const bfloat16x8_t va3x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va3), vm0)); + vacc3x0 = vbfdotq_f32(vacc3x0, va3x0, vb0); + const bfloat16x8_t va0x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm1)); + vacc0x1 = vbfdotq_f32(vacc0x1, va0x1, vb1); + const bfloat16x8_t va1x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm1)); + vacc1x1 = vbfdotq_f32(vacc1x1, va1x1, vb1); + const bfloat16x8_t va2x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm1)); + vacc2x1 = vbfdotq_f32(vacc2x1, va2x1, vb1); + const bfloat16x8_t va3x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va3), vm1)); + vacc3x1 = vbfdotq_f32(vacc3x1, va3x1, vb1); + const bfloat16x8_t va0x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm2)); + vacc0x2 = vbfdotq_f32(vacc0x2, va0x2, vb2); + const bfloat16x8_t va1x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm2)); + vacc1x2 = vbfdotq_f32(vacc1x2, va1x2, vb2); + const bfloat16x8_t va2x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm2)); + vacc2x2 = vbfdotq_f32(vacc2x2, va2x2, vb2); + const bfloat16x8_t va3x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va3), vm2)); + vacc3x2 = vbfdotq_f32(vacc3x2, va3x2, vb2); + const bfloat16x8_t va0x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm3)); + vacc0x3 = vbfdotq_f32(vacc0x3, va0x3, vb3); + const bfloat16x8_t va1x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm3)); + vacc1x3 = vbfdotq_f32(vacc1x3, va1x3, vb3); + const bfloat16x8_t va2x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm3)); + vacc2x3 = vbfdotq_f32(vacc2x3, va2x3, vb3); + const bfloat16x8_t va3x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va3), vm3)); + vacc3x3 = vbfdotq_f32(vacc3x3, va3x3, vb3); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc1x01 = vpaddq_f32(vacc1x0, vacc1x1); + const float32x4_t vacc2x01 = vpaddq_f32(vacc2x0, vacc2x1); + const float32x4_t vacc3x01 = vpaddq_f32(vacc3x0, vacc3x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + const float32x4_t vacc1x23 = vpaddq_f32(vacc1x2, vacc1x3); + const float32x4_t vacc2x23 = vpaddq_f32(vacc2x2, vacc2x3); + const float32x4_t vacc3x23 = vpaddq_f32(vacc3x2, vacc3x3); + + float32x4_t vacc0x0123 = vpaddq_f32(vacc0x01, vacc0x23); + 
float32x4_t vacc1x0123 = vpaddq_f32(vacc1x01, vacc1x23); + float32x4_t vacc2x0123 = vpaddq_f32(vacc2x01, vacc2x23); + float32x4_t vacc3x0123 = vpaddq_f32(vacc3x01, vacc3x23); +#else + const float32x2_t vsum0x0 = vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum1x0 = vadd_f32(vget_low_f32(vacc1x0), vget_high_f32(vacc1x0)); + const float32x2_t vsum2x0 = vadd_f32(vget_low_f32(vacc2x0), vget_high_f32(vacc2x0)); + const float32x2_t vsum3x0 = vadd_f32(vget_low_f32(vacc3x0), vget_high_f32(vacc3x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum1x1 = vadd_f32(vget_low_f32(vacc1x1), vget_high_f32(vacc1x1)); + const float32x2_t vsum2x1 = vadd_f32(vget_low_f32(vacc2x1), vget_high_f32(vacc2x1)); + const float32x2_t vsum3x1 = vadd_f32(vget_low_f32(vacc3x1), vget_high_f32(vacc3x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum1x2 = vadd_f32(vget_low_f32(vacc1x2), vget_high_f32(vacc1x2)); + const float32x2_t vsum2x2 = vadd_f32(vget_low_f32(vacc2x2), vget_high_f32(vacc2x2)); + const float32x2_t vsum3x2 = vadd_f32(vget_low_f32(vacc3x2), vget_high_f32(vacc3x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + const float32x2_t vsum1x3 = vadd_f32(vget_low_f32(vacc1x3), vget_high_f32(vacc1x3)); + const float32x2_t vsum2x3 = vadd_f32(vget_low_f32(vacc2x3), vget_high_f32(vacc2x3)); + const float32x2_t vsum3x3 = vadd_f32(vget_low_f32(vacc3x3), vget_high_f32(vacc3x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); + float32x4_t vacc1x0123 = vcombine_f32(vpadd_f32(vsum1x0, vsum1x1), vpadd_f32(vsum1x2, vsum1x3)); + float32x4_t vacc2x0123 = vcombine_f32(vpadd_f32(vsum2x0, vsum2x1), vpadd_f32(vsum2x2, vsum2x3)); + float32x4_t vacc3x0123 = vcombine_f32(vpadd_f32(vsum3x0, vsum3x1), vpadd_f32(vsum3x2, vsum3x3)); +#endif + + const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + vacc2x0123 = vminq_f32(vacc2x0123, vmax); + vacc3x0123 = vminq_f32(vacc3x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + vacc2x0123 = vmaxq_f32(vacc2x0123, vmin); + vacc3x0123 = vmaxq_f32(vacc3x0123, vmin); + + bfloat16x4_t vout0x0123 = vcvt_bf16_f32(vacc0x0123); + bfloat16x4_t vout1x0123 = vcvt_bf16_f32(vacc1x0123); + bfloat16x4_t vout2x0123 = vcvt_bf16_f32(vacc2x0123); + bfloat16x4_t vout3x0123 = vcvt_bf16_f32(vacc3x0123); + + if XNN_LIKELY(nc >= 4) { + vst1_bf16(c0, vout0x0123); + c0 = (bfloat16_t*) ((uintptr_t) c0 + cn_stride); + vst1_bf16(c1, vout1x0123); + c1 = (bfloat16_t*) ((uintptr_t) c1 + cn_stride); + vst1_bf16(c2, vout2x0123); + c2 = (bfloat16_t*) ((uintptr_t) c2 + cn_stride); + vst1_bf16(c3, vout3x0123); + c3 = (bfloat16_t*) ((uintptr_t) c3 + cn_stride); + + a0 = (const bfloat16_t*) ((uintptr_t) a0 - kc); + a1 = (const bfloat16_t*) ((uintptr_t) a1 - kc); + a2 = (const bfloat16_t*) ((uintptr_t) a2 - kc); + a3 = (const bfloat16_t*) ((uintptr_t) a3 - kc); + + nc -= 4; + } else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_bf16(vout0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_bf16(vout1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_bf16(vout2x0123), 0); c2 += 2; + vst1_lane_u32((void*) c3,
vreinterpret_u32_bf16(vout3x0123), 0); c3 += 2; + + vout0x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout0x0123), vreinterpret_u16_bf16(vout0x0123), 2)); + vout1x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout1x0123), vreinterpret_u16_bf16(vout1x0123), 2)); + vout2x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout2x0123), vreinterpret_u16_bf16(vout2x0123), 2)); + vout3x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout3x0123), vreinterpret_u16_bf16(vout3x0123), 2)); + } + if (nc & 1) { + vst1_lane_bf16(c0, vout0x0123, 0); + vst1_lane_bf16(c1, vout1x0123, 0); + vst1_lane_bf16(c2, vout2x0123, 0); + vst1_lane_bf16(c3, vout3x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/4x4c8-minmax-neonbf16-bfmlal.c b/src/bf16-gemm/gen/4x4c8-minmax-neonbf16-bfmlal.c new file mode 100644 index 000000000000..22b8999b6d5f --- /dev/null +++ b/src/bf16-gemm/gen/4x4c8-minmax-neonbf16-bfmlal.c @@ -0,0 +1,284 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c8-neonbf16.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/gemm.h> + + +void xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 4); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(bfloat16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const bfloat16_t* a0 = (const bfloat16_t*) a; + bfloat16_t* c0 = (bfloat16_t*) c; + const bfloat16_t* a1 = (const bfloat16_t*) ((uintptr_t) a0 + a_stride); + bfloat16_t* c1 = (bfloat16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const bfloat16_t* a2 = (const bfloat16_t*) ((uintptr_t) a1 + a_stride); + bfloat16_t* c2 = (bfloat16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const bfloat16_t* a3 = (const bfloat16_t*) ((uintptr_t) a2 + a_stride); + bfloat16_t* c3 = (bfloat16_t*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr != 4) { + a3 = a2; + c3 = c2; + } + + const bfloat16_t* w = (const bfloat16_t*) w_ptr; + do { + float32x4_t vacc0x0 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x1 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x2 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x3 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc1x0 = vacc0x0; + float32x4_t vacc1x1 = vacc0x1; + float32x4_t vacc1x2 = vacc0x2; + float32x4_t vacc1x3 = vacc0x3; + float32x4_t vacc2x0 = vacc0x0; + float32x4_t vacc2x1 = vacc0x1; + float32x4_t vacc2x2 = vacc0x2; + float32x4_t vacc2x3 = vacc0x3; + float32x4_t vacc3x0 = vacc0x0; + float32x4_t vacc3x1 = vacc0x1; + float32x4_t vacc3x2 = vacc0x2; + float32x4_t vacc3x3 = vacc0x3; + + size_t k = kc; + for (; k >= 8 * sizeof(bfloat16_t); k -= 8 * sizeof(bfloat16_t)) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 += 8; + const bfloat16x8_t va1 =
vld1q_bf16(a1); a1 += 8; + const bfloat16x8_t va2 = vld1q_bf16(a2); a2 += 8; + const bfloat16x8_t va3 = vld1q_bf16(a3); a3 += 8; + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + vacc0x0 = vbfmlalbq_f32(vacc0x0, va0, vb0); + vacc1x0 = vbfmlalbq_f32(vacc1x0, va1, vb0); + vacc2x0 = vbfmlalbq_f32(vacc2x0, va2, vb0); + vacc3x0 = vbfmlalbq_f32(vacc3x0, va3, vb0); + vacc0x1 = vbfmlalbq_f32(vacc0x1, va0, vb1); + vacc1x1 = vbfmlalbq_f32(vacc1x1, va1, vb1); + vacc2x1 = vbfmlalbq_f32(vacc2x1, va2, vb1); + vacc3x1 = vbfmlalbq_f32(vacc3x1, va3, vb1); + vacc0x2 = vbfmlalbq_f32(vacc0x2, va0, vb2); + vacc1x2 = vbfmlalbq_f32(vacc1x2, va1, vb2); + vacc2x2 = vbfmlalbq_f32(vacc2x2, va2, vb2); + vacc3x2 = vbfmlalbq_f32(vacc3x2, va3, vb2); + vacc0x3 = vbfmlalbq_f32(vacc0x3, va0, vb3); + vacc1x3 = vbfmlalbq_f32(vacc1x3, va1, vb3); + vacc2x3 = vbfmlalbq_f32(vacc2x3, va2, vb3); + vacc3x3 = vbfmlalbq_f32(vacc3x3, va3, vb3); + + vacc0x0 = vbfmlaltq_f32(vacc0x0, va0, vb0); + vacc1x0 = vbfmlaltq_f32(vacc1x0, va1, vb0); + vacc2x0 = vbfmlaltq_f32(vacc2x0, va2, vb0); + vacc3x0 = vbfmlaltq_f32(vacc3x0, va3, vb0); + vacc0x1 = vbfmlaltq_f32(vacc0x1, va0, vb1); + vacc1x1 = vbfmlaltq_f32(vacc1x1, va1, vb1); + vacc2x1 = vbfmlaltq_f32(vacc2x1, va2, vb1); + vacc3x1 = vbfmlaltq_f32(vacc3x1, va3, vb1); + vacc0x2 = vbfmlaltq_f32(vacc0x2, va0, vb2); + vacc1x2 = vbfmlaltq_f32(vacc1x2, va1, vb2); + vacc2x2 = vbfmlaltq_f32(vacc2x2, va2, vb2); + vacc3x2 = vbfmlaltq_f32(vacc3x2, va3, vb2); + vacc0x3 = vbfmlaltq_f32(vacc0x3, va0, vb3); + vacc1x3 = vbfmlaltq_f32(vacc1x3, va1, vb3); + vacc2x3 = vbfmlaltq_f32(vacc2x3, va2, vb3); + vacc3x3 = vbfmlaltq_f32(vacc3x3, va3, vb3); + } + if XNN_UNLIKELY(k != 0) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 = (const bfloat16_t*) ((uintptr_t) a0 + k); + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 = (const bfloat16_t*) ((uintptr_t) a1 + k); + const bfloat16x8_t va2 = vld1q_bf16(a2); a2 = (const bfloat16_t*) ((uintptr_t) a2 + k); + const bfloat16x8_t va3 = vld1q_bf16(a3); a3 = (const bfloat16_t*) ((uintptr_t) a3 + k); + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vreinterpretq_u16_bf16(vb0), vmovq_n_u16(0)); + const uint16x8_t vm1 = vceqq_u16(vreinterpretq_u16_bf16(vb1), vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vreinterpretq_u16_bf16(vb2), vmovq_n_u16(0)); + const uint16x8_t vm3 = vceqq_u16(vreinterpretq_u16_bf16(vb3), vmovq_n_u16(0)); + + const bfloat16x8_t va0x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm0)); + vacc0x0 = vbfmlalbq_f32(vacc0x0, va0x0, vb0); + vacc0x0 = vbfmlaltq_f32(vacc0x0, va0x0, vb0); + const bfloat16x8_t va1x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm0)); + vacc1x0 = vbfmlalbq_f32(vacc1x0, va1x0, vb0); + vacc1x0 = vbfmlaltq_f32(vacc1x0, va1x0, vb0); + const bfloat16x8_t va2x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm0)); + vacc2x0 = vbfmlalbq_f32(vacc2x0, va2x0, vb0); + vacc2x0 = vbfmlaltq_f32(vacc2x0, va2x0, vb0); + const bfloat16x8_t va3x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va3), vm0)); + vacc3x0 = vbfmlalbq_f32(vacc3x0, va3x0, vb0); + vacc3x0 = vbfmlaltq_f32(vacc3x0, va3x0, vb0); + const bfloat16x8_t va0x1 = 
vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm1)); + vacc0x1 = vbfmlalbq_f32(vacc0x1, va0x1, vb1); + vacc0x1 = vbfmlaltq_f32(vacc0x1, va0x1, vb1); + const bfloat16x8_t va1x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm1)); + vacc1x1 = vbfmlalbq_f32(vacc1x1, va1x1, vb1); + vacc1x1 = vbfmlaltq_f32(vacc1x1, va1x1, vb1); + const bfloat16x8_t va2x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm1)); + vacc2x1 = vbfmlalbq_f32(vacc2x1, va2x1, vb1); + vacc2x1 = vbfmlaltq_f32(vacc2x1, va2x1, vb1); + const bfloat16x8_t va3x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va3), vm1)); + vacc3x1 = vbfmlalbq_f32(vacc3x1, va3x1, vb1); + vacc3x1 = vbfmlaltq_f32(vacc3x1, va3x1, vb1); + const bfloat16x8_t va0x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm2)); + vacc0x2 = vbfmlalbq_f32(vacc0x2, va0x2, vb2); + vacc0x2 = vbfmlaltq_f32(vacc0x2, va0x2, vb2); + const bfloat16x8_t va1x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm2)); + vacc1x2 = vbfmlalbq_f32(vacc1x2, va1x2, vb2); + vacc1x2 = vbfmlaltq_f32(vacc1x2, va1x2, vb2); + const bfloat16x8_t va2x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm2)); + vacc2x2 = vbfmlalbq_f32(vacc2x2, va2x2, vb2); + vacc2x2 = vbfmlaltq_f32(vacc2x2, va2x2, vb2); + const bfloat16x8_t va3x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va3), vm2)); + vacc3x2 = vbfmlalbq_f32(vacc3x2, va3x2, vb2); + vacc3x2 = vbfmlaltq_f32(vacc3x2, va3x2, vb2); + const bfloat16x8_t va0x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm3)); + vacc0x3 = vbfmlalbq_f32(vacc0x3, va0x3, vb3); + vacc0x3 = vbfmlaltq_f32(vacc0x3, va0x3, vb3); + const bfloat16x8_t va1x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm3)); + vacc1x3 = vbfmlalbq_f32(vacc1x3, va1x3, vb3); + vacc1x3 = vbfmlaltq_f32(vacc1x3, va1x3, vb3); + const bfloat16x8_t va2x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm3)); + vacc2x3 = vbfmlalbq_f32(vacc2x3, va2x3, vb3); + vacc2x3 = vbfmlaltq_f32(vacc2x3, va2x3, vb3); + const bfloat16x8_t va3x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va3), vm3)); + vacc3x3 = vbfmlalbq_f32(vacc3x3, va3x3, vb3); + vacc3x3 = vbfmlaltq_f32(vacc3x3, va3x3, vb3); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc1x01 = vpaddq_f32(vacc1x0, vacc1x1); + const float32x4_t vacc2x01 = vpaddq_f32(vacc2x0, vacc2x1); + const float32x4_t vacc3x01 = vpaddq_f32(vacc3x0, vacc3x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + const float32x4_t vacc1x23 = vpaddq_f32(vacc1x2, vacc1x3); + const float32x4_t vacc2x23 = vpaddq_f32(vacc2x2, vacc2x3); + const float32x4_t vacc3x23 = vpaddq_f32(vacc3x2, vacc3x3); + + float32x4_t vacc0x0123 = vpaddq_f32(vacc0x01, vacc0x23); + float32x4_t vacc1x0123 = vpaddq_f32(vacc1x01, vacc1x23); + float32x4_t vacc2x0123 = vpaddq_f32(vacc2x01, vacc2x23); + float32x4_t vacc3x0123 = vpaddq_f32(vacc3x01, vacc3x23); +#else + const float32x2_t vsum0x0 = vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum1x0 = vadd_f32(vget_low_f32(vacc1x0), vget_high_f32(vacc1x0)); + const float32x2_t vsum2x0 = vadd_f32(vget_low_f32(vacc2x0), vget_high_f32(vacc2x0)); + const float32x2_t vsum3x0 = vadd_f32(vget_low_f32(vacc3x0), vget_high_f32(vacc3x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum1x1 = vadd_f32(vget_low_f32(vacc1x1), 
vget_high_f32(vacc1x1)); + const float32x2_t vsum2x1 = vadd_f32(vget_low_f32(vacc2x1), vget_high_f32(vacc2x1)); + const float32x2_t vsum3x1 = vadd_f32(vget_low_f32(vacc3x1), vget_high_f32(vacc3x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum1x2 = vadd_f32(vget_low_f32(vacc1x2), vget_high_f32(vacc1x2)); + const float32x2_t vsum2x2 = vadd_f32(vget_low_f32(vacc2x2), vget_high_f32(vacc2x2)); + const float32x2_t vsum3x2 = vadd_f32(vget_low_f32(vacc3x2), vget_high_f32(vacc3x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + const float32x2_t vsum1x3 = vadd_f32(vget_low_f32(vacc1x3), vget_high_f32(vacc1x3)); + const float32x2_t vsum2x3 = vadd_f32(vget_low_f32(vacc2x3), vget_high_f32(vacc2x3)); + const float32x2_t vsum3x3 = vadd_f32(vget_low_f32(vacc3x3), vget_high_f32(vacc3x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); + float32x4_t vacc1x0123 = vcombine_f32(vpadd_f32(vsum1x0, vsum1x1), vpadd_f32(vsum1x2, vsum1x3)); + float32x4_t vacc2x0123 = vcombine_f32(vpadd_f32(vsum2x0, vsum2x1), vpadd_f32(vsum2x2, vsum2x3)); + float32x4_t vacc3x0123 = vcombine_f32(vpadd_f32(vsum3x0, vsum3x1), vpadd_f32(vsum3x2, vsum3x3)); +#endif + + const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + vacc2x0123 = vminq_f32(vacc2x0123, vmax); + vacc3x0123 = vminq_f32(vacc3x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + vacc2x0123 = vmaxq_f32(vacc2x0123, vmin); + vacc3x0123 = vmaxq_f32(vacc3x0123, vmin); + + bfloat16x4_t vout0x0123 = vcvt_bf16_f32(vacc0x0123); + bfloat16x4_t vout1x0123 = vcvt_bf16_f32(vacc1x0123); + bfloat16x4_t vout2x0123 = vcvt_bf16_f32(vacc2x0123); + bfloat16x4_t vout3x0123 = vcvt_bf16_f32(vacc3x0123); + + if XNN_LIKELY(nc >= 4) { + vst1_bf16(c0, vout0x0123); + c0 = (bfloat16_t*) ((uintptr_t) c0 + cn_stride); + vst1_bf16(c1, vout1x0123); + c1 = (bfloat16_t*) ((uintptr_t) c1 + cn_stride); + vst1_bf16(c2, vout2x0123); + c2 = (bfloat16_t*) ((uintptr_t) c2 + cn_stride); + vst1_bf16(c3, vout3x0123); + c3 = (bfloat16_t*) ((uintptr_t) c3 + cn_stride); + + a0 = (const bfloat16_t*) ((uintptr_t) a0 - kc); + a1 = (const bfloat16_t*) ((uintptr_t) a1 - kc); + a2 = (const bfloat16_t*) ((uintptr_t) a2 - kc); + a3 = (const bfloat16_t*) ((uintptr_t) a3 - kc); + + nc -= 4; + } else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_bf16(vout0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_bf16(vout1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_bf16(vout2x0123), 0); c2 += 2; + vst1_lane_u32((void*) c3, vreinterpret_u32_bf16(vout3x0123), 0); c3 += 2; + + vout0x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout0x0123), vreinterpret_u16_bf16(vout0x0123), 2)); + vout1x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout1x0123), vreinterpret_u16_bf16(vout1x0123), 2)); + vout2x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout2x0123), vreinterpret_u16_bf16(vout2x0123), 2)); + vout3x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout3x0123), vreinterpret_u16_bf16(vout3x0123), 2)); + } + if (nc & 1) { + vst1_lane_bf16(c0, vout0x0123, 0); + vst1_lane_bf16(c1, vout1x0123, 0); + vst1_lane_bf16(c2, vout2x0123, 0); + vst1_lane_bf16(c3, vout3x0123, 0); + } + +
nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/4x4c8-minmax-neonfma-shland.c b/src/bf16-gemm/gen/4x4c8-minmax-neonfma-shland.c new file mode 100644 index 000000000000..32aa26dce177 --- /dev/null +++ b/src/bf16-gemm/gen/4x4c8-minmax-neonfma-shland.c @@ -0,0 +1,351 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c8-neon.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/gemm.h> + + +void xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 4); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(uint16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const uint16_t* a0 = (const uint16_t*) a; + uint16_t* c0 = (uint16_t*) c; + const uint16_t* a1 = (const uint16_t*) ((uintptr_t) a0 + a_stride); + uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const uint16_t* a2 = (const uint16_t*) ((uintptr_t) a1 + a_stride); + uint16_t* c2 = (uint16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const uint16_t* a3 = (const uint16_t*) ((uintptr_t) a2 + a_stride); + uint16_t* c3 = (uint16_t*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr != 4) { + a3 = a2; + c3 = c2; + } + + const uint16_t* w = (const uint16_t*) w_ptr; + const uint16x8_t vmask = vreinterpretq_u16_u32(vmovq_n_u32(UINT32_C(0xFFFF0000))); + do { + float32x4_t vacc0x0 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x1 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x2 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x3 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc1x0 = vacc0x0; + float32x4_t vacc1x1 = vacc0x1; + float32x4_t vacc1x2 = vacc0x2; + float32x4_t vacc1x3 = vacc0x3; + float32x4_t vacc2x0 = vacc0x0; + float32x4_t vacc2x1 = vacc0x1; + float32x4_t vacc2x2 = vacc0x2; + float32x4_t vacc2x3 = vacc0x3; + float32x4_t vacc3x0 = vacc0x0; + float32x4_t vacc3x1 = vacc0x1; + float32x4_t vacc3x2 = vacc0x2; + float32x4_t vacc3x3 = vacc0x3; + + size_t k = kc; + for (; k >= 8 * sizeof(uint16_t); k -= 8 * sizeof(uint16_t)) { + const uint16x8_t va0 = vld1q_u16(a0); a0 += 8; + const uint16x8_t va1 = vld1q_u16(a1); a1 += 8; + const uint16x8_t va2 = vld1q_u16(a2); a2 += 8; + const uint16x8_t va3 = vld1q_u16(a3); a3 += 8; + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const float32x4_t va0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0), 16)); + const float32x4_t va1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1), 16)); + const float32x4_t va2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va2), 16)); + const float32x4_t va3e =
vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va3), 16)); + + const float32x4_t vb0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb0), 16)); + const float32x4_t vb1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb1), 16)); + const float32x4_t vb2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb2), 16)); + const float32x4_t vb3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb3), 16)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0e, vb0e); + vacc1x0 = vfmaq_f32(vacc1x0, va1e, vb0e); + vacc2x0 = vfmaq_f32(vacc2x0, va2e, vb0e); + vacc3x0 = vfmaq_f32(vacc3x0, va3e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0e, vb1e); + vacc1x1 = vfmaq_f32(vacc1x1, va1e, vb1e); + vacc2x1 = vfmaq_f32(vacc2x1, va2e, vb1e); + vacc3x1 = vfmaq_f32(vacc3x1, va3e, vb1e); + vacc0x2 = vfmaq_f32(vacc0x2, va0e, vb2e); + vacc1x2 = vfmaq_f32(vacc1x2, va1e, vb2e); + vacc2x2 = vfmaq_f32(vacc2x2, va2e, vb2e); + vacc3x2 = vfmaq_f32(vacc3x2, va3e, vb2e); + vacc0x3 = vfmaq_f32(vacc0x3, va0e, vb3e); + vacc1x3 = vfmaq_f32(vacc1x3, va1e, vb3e); + vacc2x3 = vfmaq_f32(vacc2x3, va2e, vb3e); + vacc3x3 = vfmaq_f32(vacc3x3, va3e, vb3e); + + const float32x4_t va0o = vreinterpretq_f32_u16(vandq_u16(va0, vmask)); + const float32x4_t va1o = vreinterpretq_f32_u16(vandq_u16(va1, vmask)); + const float32x4_t va2o = vreinterpretq_f32_u16(vandq_u16(va2, vmask)); + const float32x4_t va3o = vreinterpretq_f32_u16(vandq_u16(va3, vmask)); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vandq_u16(vb0, vmask)); + const float32x4_t vb1o = vreinterpretq_f32_u16(vandq_u16(vb1, vmask)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vandq_u16(vb2, vmask)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vandq_u16(vb3, vmask)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0o, vb0o); + vacc1x0 = vfmaq_f32(vacc1x0, va1o, vb0o); + vacc2x0 = vfmaq_f32(vacc2x0, va2o, vb0o); + vacc3x0 = vfmaq_f32(vacc3x0, va3o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0o, vb1o); + vacc1x1 = vfmaq_f32(vacc1x1, va1o, vb1o); + vacc2x1 = vfmaq_f32(vacc2x1, va2o, vb1o); + vacc3x1 = vfmaq_f32(vacc3x1, va3o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0o, vb2o); + vacc1x2 = vfmaq_f32(vacc1x2, va1o, vb2o); + vacc2x2 = vfmaq_f32(vacc2x2, va2o, vb2o); + vacc3x2 = vfmaq_f32(vacc3x2, va3o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0o, vb3o); + vacc1x3 = vfmaq_f32(vacc1x3, va1o, vb3o); + vacc2x3 = vfmaq_f32(vacc2x3, va2o, vb3o); + vacc3x3 = vfmaq_f32(vacc3x3, va3o, vb3o); + } + if XNN_UNLIKELY(k != 0) { + const uint16x8_t va0 = vld1q_u16(a0); a0 = (const uint16_t*) ((uintptr_t) a0 + k); + const uint16x8_t va1 = vld1q_u16(a1); a1 = (const uint16_t*) ((uintptr_t) a1 + k); + const uint16x8_t va2 = vld1q_u16(a2); a2 = (const uint16_t*) ((uintptr_t) a2 + k); + const uint16x8_t va3 = vld1q_u16(a3); a3 = (const uint16_t*) ((uintptr_t) a3 + k); + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vb0, vmovq_n_u16(0)); + const uint16x8_t vm1 = vceqq_u16(vb1, vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vb2, vmovq_n_u16(0)); + const uint16x8_t vm3 = vceqq_u16(vb3, vmovq_n_u16(0)); + + const float32x4_t vb0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb0), 16)); + const float32x4_t vb1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb1), 16)); + const float32x4_t vb2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb2), 16)); + const float32x4_t vb3e 
= vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb3), 16)); + + const uint16x8_t va0x0 = vbicq_u16(va0, vm0); + const uint16x8_t va1x0 = vbicq_u16(va1, vm0); + const uint16x8_t va2x0 = vbicq_u16(va2, vm0); + const uint16x8_t va3x0 = vbicq_u16(va3, vm0); + const uint16x8_t va0x1 = vbicq_u16(va0, vm1); + const uint16x8_t va1x1 = vbicq_u16(va1, vm1); + const uint16x8_t va2x1 = vbicq_u16(va2, vm1); + const uint16x8_t va3x1 = vbicq_u16(va3, vm1); + const uint16x8_t va0x2 = vbicq_u16(va0, vm2); + const uint16x8_t va1x2 = vbicq_u16(va1, vm2); + const uint16x8_t va2x2 = vbicq_u16(va2, vm2); + const uint16x8_t va3x2 = vbicq_u16(va3, vm2); + const uint16x8_t va0x3 = vbicq_u16(va0, vm3); + const uint16x8_t va1x3 = vbicq_u16(va1, vm3); + const uint16x8_t va2x3 = vbicq_u16(va2, vm3); + const uint16x8_t va3x3 = vbicq_u16(va3, vm3); + + const float32x4_t va0x0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x0), 16)); + const float32x4_t va1x0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1x0), 16)); + const float32x4_t va2x0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va2x0), 16)); + const float32x4_t va3x0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va3x0), 16)); + const float32x4_t va0x1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x1), 16)); + const float32x4_t va1x1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1x1), 16)); + const float32x4_t va2x1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va2x1), 16)); + const float32x4_t va3x1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va3x1), 16)); + const float32x4_t va0x2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x2), 16)); + const float32x4_t va1x2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1x2), 16)); + const float32x4_t va2x2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va2x2), 16)); + const float32x4_t va3x2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va3x2), 16)); + const float32x4_t va0x3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x3), 16)); + const float32x4_t va1x3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1x3), 16)); + const float32x4_t va2x3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va2x3), 16)); + const float32x4_t va3x3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va3x3), 16)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0e, vb0e); + vacc1x0 = vfmaq_f32(vacc1x0, va1x0e, vb0e); + vacc2x0 = vfmaq_f32(vacc2x0, va2x0e, vb0e); + vacc3x0 = vfmaq_f32(vacc3x0, va3x0e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1e, vb1e); + vacc1x1 = vfmaq_f32(vacc1x1, va1x1e, vb1e); + vacc2x1 = vfmaq_f32(vacc2x1, va2x1e, vb1e); + vacc3x1 = vfmaq_f32(vacc3x1, va3x1e, vb1e); + vacc0x2 = vfmaq_f32(vacc0x2, va0x2e, vb2e); + vacc1x2 = vfmaq_f32(vacc1x2, va1x2e, vb2e); + vacc2x2 = vfmaq_f32(vacc2x2, va2x2e, vb2e); + vacc3x2 = vfmaq_f32(vacc3x2, va3x2e, vb2e); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3e, vb3e); + vacc1x3 = vfmaq_f32(vacc1x3, va1x3e, vb3e); + vacc2x3 = vfmaq_f32(vacc2x3, va2x3e, vb3e); + vacc3x3 = vfmaq_f32(vacc3x3, va3x3e, vb3e); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vandq_u16(vb0, vmask)); + const float32x4_t vb1o = vreinterpretq_f32_u16(vandq_u16(vb1, vmask)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vandq_u16(vb2, vmask)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vandq_u16(vb3, vmask)); + + const float32x4_t va0x0o = vreinterpretq_f32_u16(vandq_u16(va0x0, vmask)); + const 
float32x4_t va1x0o = vreinterpretq_f32_u16(vandq_u16(va1x0, vmask)); + const float32x4_t va2x0o = vreinterpretq_f32_u16(vandq_u16(va2x0, vmask)); + const float32x4_t va3x0o = vreinterpretq_f32_u16(vandq_u16(va3x0, vmask)); + const float32x4_t va0x1o = vreinterpretq_f32_u16(vandq_u16(va0x1, vmask)); + const float32x4_t va1x1o = vreinterpretq_f32_u16(vandq_u16(va1x1, vmask)); + const float32x4_t va2x1o = vreinterpretq_f32_u16(vandq_u16(va2x1, vmask)); + const float32x4_t va3x1o = vreinterpretq_f32_u16(vandq_u16(va3x1, vmask)); + const float32x4_t va0x2o = vreinterpretq_f32_u16(vandq_u16(va0x2, vmask)); + const float32x4_t va1x2o = vreinterpretq_f32_u16(vandq_u16(va1x2, vmask)); + const float32x4_t va2x2o = vreinterpretq_f32_u16(vandq_u16(va2x2, vmask)); + const float32x4_t va3x2o = vreinterpretq_f32_u16(vandq_u16(va3x2, vmask)); + const float32x4_t va0x3o = vreinterpretq_f32_u16(vandq_u16(va0x3, vmask)); + const float32x4_t va1x3o = vreinterpretq_f32_u16(vandq_u16(va1x3, vmask)); + const float32x4_t va2x3o = vreinterpretq_f32_u16(vandq_u16(va2x3, vmask)); + const float32x4_t va3x3o = vreinterpretq_f32_u16(vandq_u16(va3x3, vmask)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0o, vb0o); + vacc1x0 = vfmaq_f32(vacc1x0, va1x0o, vb0o); + vacc2x0 = vfmaq_f32(vacc2x0, va2x0o, vb0o); + vacc3x0 = vfmaq_f32(vacc3x0, va3x0o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1o, vb1o); + vacc1x1 = vfmaq_f32(vacc1x1, va1x1o, vb1o); + vacc2x1 = vfmaq_f32(vacc2x1, va2x1o, vb1o); + vacc3x1 = vfmaq_f32(vacc3x1, va3x1o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0x2o, vb2o); + vacc1x2 = vfmaq_f32(vacc1x2, va1x2o, vb2o); + vacc2x2 = vfmaq_f32(vacc2x2, va2x2o, vb2o); + vacc3x2 = vfmaq_f32(vacc3x2, va3x2o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3o, vb3o); + vacc1x3 = vfmaq_f32(vacc1x3, va1x3o, vb3o); + vacc2x3 = vfmaq_f32(vacc2x3, va2x3o, vb3o); + vacc3x3 = vfmaq_f32(vacc3x3, va3x3o, vb3o); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc1x01 = vpaddq_f32(vacc1x0, vacc1x1); + const float32x4_t vacc2x01 = vpaddq_f32(vacc2x0, vacc2x1); + const float32x4_t vacc3x01 = vpaddq_f32(vacc3x0, vacc3x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + const float32x4_t vacc1x23 = vpaddq_f32(vacc1x2, vacc1x3); + const float32x4_t vacc2x23 = vpaddq_f32(vacc2x2, vacc2x3); + const float32x4_t vacc3x23 = vpaddq_f32(vacc3x2, vacc3x3); + + float32x4_t vacc0x0123 = vpaddq_f32(vacc0x01, vacc0x23); + float32x4_t vacc1x0123 = vpaddq_f32(vacc1x01, vacc1x23); + float32x4_t vacc2x0123 = vpaddq_f32(vacc2x01, vacc2x23); + float32x4_t vacc3x0123 = vpaddq_f32(vacc3x01, vacc3x23); +#else + const float32x2_t vsum0x0 = vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum1x0 = vadd_f32(vget_low_f32(vacc1x0), vget_high_f32(vacc1x0)); + const float32x2_t vsum2x0 = vadd_f32(vget_low_f32(vacc2x0), vget_high_f32(vacc2x0)); + const float32x2_t vsum3x0 = vadd_f32(vget_low_f32(vacc3x0), vget_high_f32(vacc3x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum1x1 = vadd_f32(vget_low_f32(vacc1x1), vget_high_f32(vacc1x1)); + const float32x2_t vsum2x1 = vadd_f32(vget_low_f32(vacc2x1), vget_high_f32(vacc2x1)); + const float32x2_t vsum3x1 = vadd_f32(vget_low_f32(vacc3x1), vget_high_f32(vacc3x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum1x2 = vadd_f32(vget_low_f32(vacc1x2), vget_high_f32(vacc1x2)); + const float32x2_t vsum2x2 = 
vadd_f32(vget_low_f32(vacc2x2), vget_high_f32(vacc2x2)); + const float32x2_t vsum3x2 = vadd_f32(vget_low_f32(vacc3x2), vget_high_f32(vacc3x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + const float32x2_t vsum1x3 = vadd_f32(vget_low_f32(vacc1x3), vget_high_f32(vacc1x3)); + const float32x2_t vsum2x3 = vadd_f32(vget_low_f32(vacc2x3), vget_high_f32(vacc2x3)); + const float32x2_t vsum3x3 = vadd_f32(vget_low_f32(vacc3x3), vget_high_f32(vacc3x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); + float32x4_t vacc1x0123 = vcombine_f32(vpadd_f32(vsum1x0, vsum1x1), vpadd_f32(vsum1x2, vsum1x3)); + float32x4_t vacc2x0123 = vcombine_f32(vpadd_f32(vsum2x0, vsum2x1), vpadd_f32(vsum2x2, vsum2x3)); + float32x4_t vacc3x0123 = vcombine_f32(vpadd_f32(vsum3x0, vsum3x1), vpadd_f32(vsum3x2, vsum3x3)); +#endif + + const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + vacc2x0123 = vminq_f32(vacc2x0123, vmax); + vacc3x0123 = vminq_f32(vacc3x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + vacc2x0123 = vmaxq_f32(vacc2x0123, vmin); + vacc3x0123 = vmaxq_f32(vacc3x0123, vmin); + + uint16x4_t vout0x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc0x0123), 16); + uint16x4_t vout1x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc1x0123), 16); + uint16x4_t vout2x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc2x0123), 16); + uint16x4_t vout3x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc3x0123), 16); + + if XNN_LIKELY(nc >= 4) { + vst1_u16(c0, vout0x0123); + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + vst1_u16(c1, vout1x0123); + c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); + vst1_u16(c2, vout2x0123); + c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride); + vst1_u16(c3, vout3x0123); + c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride); + + a0 = (const uint16_t*) ((uintptr_t) a0 - kc); + a1 = (const uint16_t*) ((uintptr_t) a1 - kc); + a2 = (const uint16_t*) ((uintptr_t) a2 - kc); + a3 = (const uint16_t*) ((uintptr_t) a3 - kc); + + nc -= 4; + } else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_u16(vout0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_u16(vout1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_u16(vout2x0123), 0); c2 += 2; + vst1_lane_u32((void*) c3, vreinterpret_u32_u16(vout3x0123), 0); c3 += 2; + + vout0x0123 = vext_u16(vout0x0123, vout0x0123, 2); + vout1x0123 = vext_u16(vout1x0123, vout1x0123, 2); + vout2x0123 = vext_u16(vout2x0123, vout2x0123, 2); + vout3x0123 = vext_u16(vout3x0123, vout3x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vout0x0123, 0); + vst1_lane_u16(c1, vout1x0123, 0); + vst1_lane_u16(c2, vout2x0123, 0); + vst1_lane_u16(c3, vout3x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/4x4c8-minmax-neonfma-zip.c b/src/bf16-gemm/gen/4x4c8-minmax-neonfma-zip.c new file mode 100644 index 000000000000..573427708184 --- /dev/null +++ b/src/bf16-gemm/gen/4x4c8-minmax-neonfma-zip.c @@ -0,0 +1,351 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c8-neon.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree.
+ + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/gemm.h> + + +void xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 4); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(uint16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const uint16_t* a0 = (const uint16_t*) a; + uint16_t* c0 = (uint16_t*) c; + const uint16_t* a1 = (const uint16_t*) ((uintptr_t) a0 + a_stride); + uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const uint16_t* a2 = (const uint16_t*) ((uintptr_t) a1 + a_stride); + uint16_t* c2 = (uint16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const uint16_t* a3 = (const uint16_t*) ((uintptr_t) a2 + a_stride); + uint16_t* c3 = (uint16_t*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr != 4) { + a3 = a2; + c3 = c2; + } + + const uint16_t* w = (const uint16_t*) w_ptr; + const uint16x8_t vzero = vmovq_n_u16(0); + do { + float32x4_t vacc0x0 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x1 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x2 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x3 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc1x0 = vacc0x0; + float32x4_t vacc1x1 = vacc0x1; + float32x4_t vacc1x2 = vacc0x2; + float32x4_t vacc1x3 = vacc0x3; + float32x4_t vacc2x0 = vacc0x0; + float32x4_t vacc2x1 = vacc0x1; + float32x4_t vacc2x2 = vacc0x2; + float32x4_t vacc2x3 = vacc0x3; + float32x4_t vacc3x0 = vacc0x0; + float32x4_t vacc3x1 = vacc0x1; + float32x4_t vacc3x2 = vacc0x2; + float32x4_t vacc3x3 = vacc0x3; + + size_t k = kc; + for (; k >= 8 * sizeof(uint16_t); k -= 8 * sizeof(uint16_t)) { + const uint16x8_t va0 = vld1q_u16(a0); a0 += 8; + const uint16x8_t va1 = vld1q_u16(a1); a1 += 8; + const uint16x8_t va2 = vld1q_u16(a2); a2 += 8; + const uint16x8_t va3 = vld1q_u16(a3); a3 += 8; + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const float32x4_t va0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0)); + const float32x4_t va1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1)); + const float32x4_t va2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va2)); + const float32x4_t va3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va3)); + + const float32x4_t vb0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb0)); + const float32x4_t vb1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb1)); + const float32x4_t vb2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb2)); + const float32x4_t vb3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0e, vb0e); + vacc1x0 = vfmaq_f32(vacc1x0, va1e, vb0e); + vacc2x0 = vfmaq_f32(vacc2x0, va2e, vb0e); + vacc3x0 = vfmaq_f32(vacc3x0, va3e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0e, vb1e); + vacc1x1 = vfmaq_f32(vacc1x1, va1e, vb1e); + vacc2x1 = vfmaq_f32(vacc2x1, va2e, vb1e); + vacc3x1 = vfmaq_f32(vacc3x1, va3e, vb1e); + vacc0x2 = vfmaq_f32(vacc0x2,
va0e, vb2e); + vacc1x2 = vfmaq_f32(vacc1x2, va1e, vb2e); + vacc2x2 = vfmaq_f32(vacc2x2, va2e, vb2e); + vacc3x2 = vfmaq_f32(vacc3x2, va3e, vb2e); + vacc0x3 = vfmaq_f32(vacc0x3, va0e, vb3e); + vacc1x3 = vfmaq_f32(vacc1x3, va1e, vb3e); + vacc2x3 = vfmaq_f32(vacc2x3, va2e, vb3e); + vacc3x3 = vfmaq_f32(vacc3x3, va3e, vb3e); + + const float32x4_t va0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0)); + const float32x4_t va1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va1)); + const float32x4_t va2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va2)); + const float32x4_t va3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va3)); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb0)); + const float32x4_t vb1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb1)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb2)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0o, vb0o); + vacc1x0 = vfmaq_f32(vacc1x0, va1o, vb0o); + vacc2x0 = vfmaq_f32(vacc2x0, va2o, vb0o); + vacc3x0 = vfmaq_f32(vacc3x0, va3o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0o, vb1o); + vacc1x1 = vfmaq_f32(vacc1x1, va1o, vb1o); + vacc2x1 = vfmaq_f32(vacc2x1, va2o, vb1o); + vacc3x1 = vfmaq_f32(vacc3x1, va3o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0o, vb2o); + vacc1x2 = vfmaq_f32(vacc1x2, va1o, vb2o); + vacc2x2 = vfmaq_f32(vacc2x2, va2o, vb2o); + vacc3x2 = vfmaq_f32(vacc3x2, va3o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0o, vb3o); + vacc1x3 = vfmaq_f32(vacc1x3, va1o, vb3o); + vacc2x3 = vfmaq_f32(vacc2x3, va2o, vb3o); + vacc3x3 = vfmaq_f32(vacc3x3, va3o, vb3o); + } + if XNN_UNLIKELY(k != 0) { + const uint16x8_t va0 = vld1q_u16(a0); a0 = (const uint16_t*) ((uintptr_t) a0 + k); + const uint16x8_t va1 = vld1q_u16(a1); a1 = (const uint16_t*) ((uintptr_t) a1 + k); + const uint16x8_t va2 = vld1q_u16(a2); a2 = (const uint16_t*) ((uintptr_t) a2 + k); + const uint16x8_t va3 = vld1q_u16(a3); a3 = (const uint16_t*) ((uintptr_t) a3 + k); + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vb0, vmovq_n_u16(0)); + const uint16x8_t vm1 = vceqq_u16(vb1, vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vb2, vmovq_n_u16(0)); + const uint16x8_t vm3 = vceqq_u16(vb3, vmovq_n_u16(0)); + + const float32x4_t vb0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb0)); + const float32x4_t vb1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb1)); + const float32x4_t vb2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb2)); + const float32x4_t vb3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb3)); + + const uint16x8_t va0x0 = vbicq_u16(va0, vm0); + const uint16x8_t va1x0 = vbicq_u16(va1, vm0); + const uint16x8_t va2x0 = vbicq_u16(va2, vm0); + const uint16x8_t va3x0 = vbicq_u16(va3, vm0); + const uint16x8_t va0x1 = vbicq_u16(va0, vm1); + const uint16x8_t va1x1 = vbicq_u16(va1, vm1); + const uint16x8_t va2x1 = vbicq_u16(va2, vm1); + const uint16x8_t va3x1 = vbicq_u16(va3, vm1); + const uint16x8_t va0x2 = vbicq_u16(va0, vm2); + const uint16x8_t va1x2 = vbicq_u16(va1, vm2); + const uint16x8_t va2x2 = vbicq_u16(va2, vm2); + const uint16x8_t va3x2 = vbicq_u16(va3, vm2); + const uint16x8_t va0x3 = vbicq_u16(va0, vm3); + const uint16x8_t va1x3 = vbicq_u16(va1, vm3); + const uint16x8_t va2x3 = vbicq_u16(va2, vm3); + const uint16x8_t va3x3 = vbicq_u16(va3, vm3); + + const float32x4_t va0x0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, 
va0x0)); + const float32x4_t va1x0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1x0)); + const float32x4_t va2x0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va2x0)); + const float32x4_t va3x0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va3x0)); + const float32x4_t va0x1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0x1)); + const float32x4_t va1x1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1x1)); + const float32x4_t va2x1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va2x1)); + const float32x4_t va3x1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va3x1)); + const float32x4_t va0x2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0x2)); + const float32x4_t va1x2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1x2)); + const float32x4_t va2x2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va2x2)); + const float32x4_t va3x2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va3x2)); + const float32x4_t va0x3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0x3)); + const float32x4_t va1x3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1x3)); + const float32x4_t va2x3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va2x3)); + const float32x4_t va3x3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va3x3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0e, vb0e); + vacc1x0 = vfmaq_f32(vacc1x0, va1x0e, vb0e); + vacc2x0 = vfmaq_f32(vacc2x0, va2x0e, vb0e); + vacc3x0 = vfmaq_f32(vacc3x0, va3x0e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1e, vb1e); + vacc1x1 = vfmaq_f32(vacc1x1, va1x1e, vb1e); + vacc2x1 = vfmaq_f32(vacc2x1, va2x1e, vb1e); + vacc3x1 = vfmaq_f32(vacc3x1, va3x1e, vb1e); + vacc0x2 = vfmaq_f32(vacc0x2, va0x2e, vb2e); + vacc1x2 = vfmaq_f32(vacc1x2, va1x2e, vb2e); + vacc2x2 = vfmaq_f32(vacc2x2, va2x2e, vb2e); + vacc3x2 = vfmaq_f32(vacc3x2, va3x2e, vb2e); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3e, vb3e); + vacc1x3 = vfmaq_f32(vacc1x3, va1x3e, vb3e); + vacc2x3 = vfmaq_f32(vacc2x3, va2x3e, vb3e); + vacc3x3 = vfmaq_f32(vacc3x3, va3x3e, vb3e); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb0)); + const float32x4_t vb1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb1)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb2)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb3)); + + const float32x4_t va0x0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x0)); + const float32x4_t va1x0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va1x0)); + const float32x4_t va2x0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va2x0)); + const float32x4_t va3x0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va3x0)); + const float32x4_t va0x1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x1)); + const float32x4_t va1x1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va1x1)); + const float32x4_t va2x1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va2x1)); + const float32x4_t va3x1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va3x1)); + const float32x4_t va0x2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x2)); + const float32x4_t va1x2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va1x2)); + const float32x4_t va2x2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va2x2)); + const float32x4_t va3x2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va3x2)); + const float32x4_t va0x3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x3)); + const float32x4_t va1x3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va1x3)); + const float32x4_t va2x3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va2x3)); + const float32x4_t va3x3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va3x3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0o, vb0o); + vacc1x0 = vfmaq_f32(vacc1x0, va1x0o, 
vb0o); + vacc2x0 = vfmaq_f32(vacc2x0, va2x0o, vb0o); + vacc3x0 = vfmaq_f32(vacc3x0, va3x0o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1o, vb1o); + vacc1x1 = vfmaq_f32(vacc1x1, va1x1o, vb1o); + vacc2x1 = vfmaq_f32(vacc2x1, va2x1o, vb1o); + vacc3x1 = vfmaq_f32(vacc3x1, va3x1o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0x2o, vb2o); + vacc1x2 = vfmaq_f32(vacc1x2, va1x2o, vb2o); + vacc2x2 = vfmaq_f32(vacc2x2, va2x2o, vb2o); + vacc3x2 = vfmaq_f32(vacc3x2, va3x2o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3o, vb3o); + vacc1x3 = vfmaq_f32(vacc1x3, va1x3o, vb3o); + vacc2x3 = vfmaq_f32(vacc2x3, va2x3o, vb3o); + vacc3x3 = vfmaq_f32(vacc3x3, va3x3o, vb3o); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc1x01 = vpaddq_f32(vacc1x0, vacc1x1); + const float32x4_t vacc2x01 = vpaddq_f32(vacc2x0, vacc2x1); + const float32x4_t vacc3x01 = vpaddq_f32(vacc3x0, vacc3x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + const float32x4_t vacc1x23 = vpaddq_f32(vacc1x2, vacc1x3); + const float32x4_t vacc2x23 = vpaddq_f32(vacc2x2, vacc2x3); + const float32x4_t vacc3x23 = vpaddq_f32(vacc3x2, vacc3x3); + + float32x4_t vacc0x0123 = vpaddq_f32(vacc0x01, vacc0x23); + float32x4_t vacc1x0123 = vpaddq_f32(vacc1x01, vacc1x23); + float32x4_t vacc2x0123 = vpaddq_f32(vacc2x01, vacc2x23); + float32x4_t vacc3x0123 = vpaddq_f32(vacc3x01, vacc3x23); +#else + const float32x2_t vsum0x0 = vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum1x0 = vadd_f32(vget_low_f32(vacc1x0), vget_high_f32(vacc1x0)); + const float32x2_t vsum2x0 = vadd_f32(vget_low_f32(vacc2x0), vget_high_f32(vacc2x0)); + const float32x2_t vsum3x0 = vadd_f32(vget_low_f32(vacc3x0), vget_high_f32(vacc3x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum1x1 = vadd_f32(vget_low_f32(vacc1x1), vget_high_f32(vacc1x1)); + const float32x2_t vsum2x1 = vadd_f32(vget_low_f32(vacc2x1), vget_high_f32(vacc2x1)); + const float32x2_t vsum3x1 = vadd_f32(vget_low_f32(vacc3x1), vget_high_f32(vacc3x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum1x2 = vadd_f32(vget_low_f32(vacc1x2), vget_high_f32(vacc1x2)); + const float32x2_t vsum2x2 = vadd_f32(vget_low_f32(vacc2x2), vget_high_f32(vacc2x2)); + const float32x2_t vsum3x2 = vadd_f32(vget_low_f32(vacc3x2), vget_high_f32(vacc3x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + const float32x2_t vsum1x3 = vadd_f32(vget_low_f32(vacc1x3), vget_high_f32(vacc1x3)); + const float32x2_t vsum2x3 = vadd_f32(vget_low_f32(vacc2x3), vget_high_f32(vacc2x3)); + const float32x2_t vsum3x3 = vadd_f32(vget_low_f32(vacc3x3), vget_high_f32(vacc3x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); + float32x4_t vacc1x0123 = vcombine_f32(vpadd_f32(vsum1x0, vsum1x1), vpadd_f32(vsum1x2, vsum1x3)); + float32x4_t vacc2x0123 = vcombine_f32(vpadd_f32(vsum2x0, vsum2x1), vpadd_f32(vsum2x2, vsum2x3)); + float32x4_t vacc3x0123 = vcombine_f32(vpadd_f32(vsum3x0, vsum3x1), vpadd_f32(vsum3x2, vsum3x3)); +#endif + + const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + vacc2x0123 = vminq_f32(vacc2x0123, vmax); + vacc3x0123 = vminq_f32(vacc3x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); +
vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + vacc2x0123 = vmaxq_f32(vacc2x0123, vmin); + vacc3x0123 = vmaxq_f32(vacc3x0123, vmin); + + uint16x4_t vout0x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc0x0123), 16); + uint16x4_t vout1x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc1x0123), 16); + uint16x4_t vout2x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc2x0123), 16); + uint16x4_t vout3x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc3x0123), 16); + + if XNN_LIKELY(nc >= 4) { + vst1_u16(c0, vout0x0123); + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + vst1_u16(c1, vout1x0123); + c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); + vst1_u16(c2, vout2x0123); + c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride); + vst1_u16(c3, vout3x0123); + c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride); + + a0 = (const uint16_t*) ((uintptr_t) a0 - kc); + a1 = (const uint16_t*) ((uintptr_t) a1 - kc); + a2 = (const uint16_t*) ((uintptr_t) a2 - kc); + a3 = (const uint16_t*) ((uintptr_t) a3 - kc); + + nc -= 4; + } else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_u16(vout0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_u16(vout1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_u16(vout2x0123), 0); c2 += 2; + vst1_lane_u32((void*) c3, vreinterpret_u32_u16(vout3x0123), 0); c3 += 2; + + vout0x0123 = vext_u16(vout0x0123, vout0x0123, 2); + vout1x0123 = vext_u16(vout1x0123, vout1x0123, 2); + vout2x0123 = vext_u16(vout2x0123, vout2x0123, 2); + vout3x0123 = vext_u16(vout3x0123, vout3x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vout0x0123, 0); + vst1_lane_u16(c1, vout1x0123, 0); + vst1_lane_u16(c2, vout2x0123, 0); + vst1_lane_u16(c3, vout3x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/4x8c2-minmax-neonbf16-bfdot-lane-ld128.c b/src/bf16-gemm/gen/4x8c2-minmax-neonbf16-bfdot-lane-ld128.c new file mode 100644 index 000000000000..b8e954515c14 --- /dev/null +++ b/src/bf16-gemm/gen/4x8c2-minmax-neonbf16-bfdot-lane-ld128.c @@ -0,0 +1,330 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c2-neonbf16-bfdot-lane-ld128.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/gemm.h> + + +void xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 4); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(bfloat16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const bfloat16_t* a0 = (const bfloat16_t*) a; + bfloat16_t* c0 = (bfloat16_t*) c; + const bfloat16_t* a1 = (const bfloat16_t*) ((uintptr_t) a0 + a_stride); + bfloat16_t* c1 = (bfloat16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const bfloat16_t* a2 = (const bfloat16_t*) ((uintptr_t) a1 + a_stride); + bfloat16_t* c2 = (bfloat16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const bfloat16_t* a3 = (const bfloat16_t*) ((uintptr_t) a2 + a_stride); + bfloat16_t* c3 = (bfloat16_t*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr != 4) { + a3 = a2; + c3 = c2; + } + + const bfloat16_t* w = (const bfloat16_t*) w_ptr; + do { + float32x4_t vacc0x0123 = vcvt_f32_bf16(vld1_bf16(w)); w += 4; + float32x4_t vacc0x4567 = vcvt_f32_bf16(vld1_bf16(w)); w += 4; + float32x4_t vacc1x0123 = vacc0x0123; + float32x4_t vacc1x4567 = vacc0x4567; + float32x4_t vacc2x0123 = vacc0x0123; + float32x4_t vacc2x4567 = vacc0x4567; + float32x4_t vacc3x0123 = vacc0x0123; + float32x4_t vacc3x4567 = vacc0x4567; + + size_t k = kc; + for (; k >= 8 * sizeof(bfloat16_t); k -= 8 * sizeof(bfloat16_t)) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 += 8; + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 += 8; + const bfloat16x8_t va2 = vld1q_bf16(a2); a2 += 8; + const bfloat16x8_t va3 = vld1q_bf16(a3); a3 += 8; + + const bfloat16x8_t vb0123c01 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c01 = vld1q_bf16(w); w += 8; + + vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c01, va0, 0); + vacc1x0123 = vbfdotq_laneq_f32(vacc1x0123, vb0123c01, va1, 0); + vacc2x0123 = vbfdotq_laneq_f32(vacc2x0123, vb0123c01, va2, 0); + vacc3x0123 = vbfdotq_laneq_f32(vacc3x0123, vb0123c01, va3, 0); + vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c01, va0, 0); + vacc1x4567 = vbfdotq_laneq_f32(vacc1x4567, vb4567c01, va1, 0); + vacc2x4567 = vbfdotq_laneq_f32(vacc2x4567, vb4567c01, va2, 0); + vacc3x4567 = vbfdotq_laneq_f32(vacc3x4567, vb4567c01, va3, 0); + const bfloat16x8_t vb0123c23 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c23 = vld1q_bf16(w); w += 8; + + vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c23, va0, 1); + vacc1x0123 = vbfdotq_laneq_f32(vacc1x0123, vb0123c23, va1, 1); + vacc2x0123 = vbfdotq_laneq_f32(vacc2x0123, vb0123c23, va2, 1); + vacc3x0123 = vbfdotq_laneq_f32(vacc3x0123, vb0123c23, va3, 1); + vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c23, va0, 1); + vacc1x4567 = vbfdotq_laneq_f32(vacc1x4567, vb4567c23, va1, 1); + vacc2x4567 = vbfdotq_laneq_f32(vacc2x4567, vb4567c23, va2, 1); + vacc3x4567 = vbfdotq_laneq_f32(vacc3x4567, vb4567c23, va3, 1); + const bfloat16x8_t vb0123c45 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c45 = vld1q_bf16(w); w += 8; + + vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c45, va0, 2); + vacc1x0123 = vbfdotq_laneq_f32(vacc1x0123, vb0123c45, va1, 2); + vacc2x0123 = vbfdotq_laneq_f32(vacc2x0123, vb0123c45, va2, 2); + vacc3x0123 = 
vbfdotq_laneq_f32(vacc3x0123, vb0123c45, va3, 2); + vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c45, va0, 2); + vacc1x4567 = vbfdotq_laneq_f32(vacc1x4567, vb4567c45, va1, 2); + vacc2x4567 = vbfdotq_laneq_f32(vacc2x4567, vb4567c45, va2, 2); + vacc3x4567 = vbfdotq_laneq_f32(vacc3x4567, vb4567c45, va3, 2); + const bfloat16x8_t vb0123c67 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c67 = vld1q_bf16(w); w += 8; + + vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c67, va0, 3); + vacc1x0123 = vbfdotq_laneq_f32(vacc1x0123, vb0123c67, va1, 3); + vacc2x0123 = vbfdotq_laneq_f32(vacc2x0123, vb0123c67, va2, 3); + vacc3x0123 = vbfdotq_laneq_f32(vacc3x0123, vb0123c67, va3, 3); + vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c67, va0, 3); + vacc1x4567 = vbfdotq_laneq_f32(vacc1x4567, vb4567c67, va1, 3); + vacc2x4567 = vbfdotq_laneq_f32(vacc2x4567, vb4567c67, va2, 3); + vacc3x4567 = vbfdotq_laneq_f32(vacc3x4567, vb4567c67, va3, 3); + } + if XNN_UNLIKELY(k != 0) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 = (const bfloat16_t*) ((uintptr_t) a0 + k); + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 = (const bfloat16_t*) ((uintptr_t) a1 + k); + const bfloat16x8_t va2 = vld1q_bf16(a2); a2 = (const bfloat16_t*) ((uintptr_t) a2 + k); + const bfloat16x8_t va3 = vld1q_bf16(a3); a3 = (const bfloat16_t*) ((uintptr_t) a3 + k); + + const bfloat16x8_t vb0123c01 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c01 = vld1q_bf16(w); w += 8; + + const uint32x4_t va0c01 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va0)), 0); + const uint32x4_t va1c01 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va1)), 0); + const uint32x4_t va2c01 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va2)), 0); + const uint32x4_t va3c01 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va3)), 0); + + const uint32x4_t vm0123c01 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c01), vmovq_n_u16(0))); + const uint32x4_t vm4567c01 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c01), vmovq_n_u16(0))); + + const uint32x4_t va0x0123c01 = vbicq_u32(va0c01, vm0123c01); + vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c01, vreinterpretq_bf16_u32(va0x0123c01)); + const uint32x4_t va1x0123c01 = vbicq_u32(va1c01, vm0123c01); + vacc1x0123 = vbfdotq_f32(vacc1x0123, vb0123c01, vreinterpretq_bf16_u32(va1x0123c01)); + const uint32x4_t va2x0123c01 = vbicq_u32(va2c01, vm0123c01); + vacc2x0123 = vbfdotq_f32(vacc2x0123, vb0123c01, vreinterpretq_bf16_u32(va2x0123c01)); + const uint32x4_t va3x0123c01 = vbicq_u32(va3c01, vm0123c01); + vacc3x0123 = vbfdotq_f32(vacc3x0123, vb0123c01, vreinterpretq_bf16_u32(va3x0123c01)); + const uint32x4_t va0x4567c01 = vbicq_u32(va0c01, vm4567c01); + vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c01, vreinterpretq_bf16_u32(va0x4567c01)); + const uint32x4_t va1x4567c01 = vbicq_u32(va1c01, vm4567c01); + vacc1x4567 = vbfdotq_f32(vacc1x4567, vb4567c01, vreinterpretq_bf16_u32(va1x4567c01)); + const uint32x4_t va2x4567c01 = vbicq_u32(va2c01, vm4567c01); + vacc2x4567 = vbfdotq_f32(vacc2x4567, vb4567c01, vreinterpretq_bf16_u32(va2x4567c01)); + const uint32x4_t va3x4567c01 = vbicq_u32(va3c01, vm4567c01); + vacc3x4567 = vbfdotq_f32(vacc3x4567, vb4567c01, vreinterpretq_bf16_u32(va3x4567c01)); + + if (k > 2 * sizeof(bfloat16_t)) { + const bfloat16x8_t vb0123c23 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c23 = vld1q_bf16(w); w += 8; + + const uint32x4_t va0c23 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va0)), 1); + const uint32x4_t va1c23 = 
vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va1)), 1); + const uint32x4_t va2c23 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va2)), 1); + const uint32x4_t va3c23 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va3)), 1); + + const uint32x4_t vm0123c23 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c23), vmovq_n_u16(0))); + const uint32x4_t vm4567c23 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c23), vmovq_n_u16(0))); + + const uint32x4_t va0x0123c23 = vbicq_u32(va0c23, vm0123c23); + vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c23, vreinterpretq_bf16_u32(va0x0123c23)); + const uint32x4_t va1x0123c23 = vbicq_u32(va1c23, vm0123c23); + vacc1x0123 = vbfdotq_f32(vacc1x0123, vb0123c23, vreinterpretq_bf16_u32(va1x0123c23)); + const uint32x4_t va2x0123c23 = vbicq_u32(va2c23, vm0123c23); + vacc2x0123 = vbfdotq_f32(vacc2x0123, vb0123c23, vreinterpretq_bf16_u32(va2x0123c23)); + const uint32x4_t va3x0123c23 = vbicq_u32(va3c23, vm0123c23); + vacc3x0123 = vbfdotq_f32(vacc3x0123, vb0123c23, vreinterpretq_bf16_u32(va3x0123c23)); + const uint32x4_t va0x4567c23 = vbicq_u32(va0c23, vm4567c23); + vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c23, vreinterpretq_bf16_u32(va0x4567c23)); + const uint32x4_t va1x4567c23 = vbicq_u32(va1c23, vm4567c23); + vacc1x4567 = vbfdotq_f32(vacc1x4567, vb4567c23, vreinterpretq_bf16_u32(va1x4567c23)); + const uint32x4_t va2x4567c23 = vbicq_u32(va2c23, vm4567c23); + vacc2x4567 = vbfdotq_f32(vacc2x4567, vb4567c23, vreinterpretq_bf16_u32(va2x4567c23)); + const uint32x4_t va3x4567c23 = vbicq_u32(va3c23, vm4567c23); + vacc3x4567 = vbfdotq_f32(vacc3x4567, vb4567c23, vreinterpretq_bf16_u32(va3x4567c23)); + + if (k > 4 * sizeof(bfloat16_t)) { + const bfloat16x8_t vb0123c45 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c45 = vld1q_bf16(w); w += 8; + + const uint32x4_t va0c45 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va0)), 0); + const uint32x4_t va1c45 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va1)), 0); + const uint32x4_t va2c45 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va2)), 0); + const uint32x4_t va3c45 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va3)), 0); + + const uint32x4_t vm0123c45 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c45), vmovq_n_u16(0))); + const uint32x4_t vm4567c45 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c45), vmovq_n_u16(0))); + + const uint32x4_t va0x0123c45 = vbicq_u32(va0c45, vm0123c45); + vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c45, vreinterpretq_bf16_u32(va0x0123c45)); + const uint32x4_t va1x0123c45 = vbicq_u32(va1c45, vm0123c45); + vacc1x0123 = vbfdotq_f32(vacc1x0123, vb0123c45, vreinterpretq_bf16_u32(va1x0123c45)); + const uint32x4_t va2x0123c45 = vbicq_u32(va2c45, vm0123c45); + vacc2x0123 = vbfdotq_f32(vacc2x0123, vb0123c45, vreinterpretq_bf16_u32(va2x0123c45)); + const uint32x4_t va3x0123c45 = vbicq_u32(va3c45, vm0123c45); + vacc3x0123 = vbfdotq_f32(vacc3x0123, vb0123c45, vreinterpretq_bf16_u32(va3x0123c45)); + const uint32x4_t va0x4567c45 = vbicq_u32(va0c45, vm4567c45); + vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c45, vreinterpretq_bf16_u32(va0x4567c45)); + const uint32x4_t va1x4567c45 = vbicq_u32(va1c45, vm4567c45); + vacc1x4567 = vbfdotq_f32(vacc1x4567, vb4567c45, vreinterpretq_bf16_u32(va1x4567c45)); + const uint32x4_t va2x4567c45 = vbicq_u32(va2c45, vm4567c45); + vacc2x4567 = vbfdotq_f32(vacc2x4567, vb4567c45, vreinterpretq_bf16_u32(va2x4567c45)); + const uint32x4_t va3x4567c45 = vbicq_u32(va3c45, 
vm4567c45); + vacc3x4567 = vbfdotq_f32(vacc3x4567, vb4567c45, vreinterpretq_bf16_u32(va3x4567c45)); + + if (k > 6 * sizeof(bfloat16_t)) { + const bfloat16x8_t vb0123c67 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c67 = vld1q_bf16(w); w += 8; + + const uint32x4_t va0c67 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va0)), 1); + const uint32x4_t va1c67 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va1)), 1); + const uint32x4_t va2c67 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va2)), 1); + const uint32x4_t va3c67 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va3)), 1); + + const uint32x4_t vm0123c67 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c67), vmovq_n_u16(0))); + const uint32x4_t vm4567c67 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c67), vmovq_n_u16(0))); + + const uint32x4_t va0x0123c67 = vbicq_u32(va0c67, vm0123c67); + vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c67, vreinterpretq_bf16_u32(va0x0123c67)); + const uint32x4_t va1x0123c67 = vbicq_u32(va1c67, vm0123c67); + vacc1x0123 = vbfdotq_f32(vacc1x0123, vb0123c67, vreinterpretq_bf16_u32(va1x0123c67)); + const uint32x4_t va2x0123c67 = vbicq_u32(va2c67, vm0123c67); + vacc2x0123 = vbfdotq_f32(vacc2x0123, vb0123c67, vreinterpretq_bf16_u32(va2x0123c67)); + const uint32x4_t va3x0123c67 = vbicq_u32(va3c67, vm0123c67); + vacc3x0123 = vbfdotq_f32(vacc3x0123, vb0123c67, vreinterpretq_bf16_u32(va3x0123c67)); + const uint32x4_t va0x4567c67 = vbicq_u32(va0c67, vm4567c67); + vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c67, vreinterpretq_bf16_u32(va0x4567c67)); + const uint32x4_t va1x4567c67 = vbicq_u32(va1c67, vm4567c67); + vacc1x4567 = vbfdotq_f32(vacc1x4567, vb4567c67, vreinterpretq_bf16_u32(va1x4567c67)); + const uint32x4_t va2x4567c67 = vbicq_u32(va2c67, vm4567c67); + vacc2x4567 = vbfdotq_f32(vacc2x4567, vb4567c67, vreinterpretq_bf16_u32(va2x4567c67)); + const uint32x4_t va3x4567c67 = vbicq_u32(va3c67, vm4567c67); + vacc3x4567 = vbfdotq_f32(vacc3x4567, vb4567c67, vreinterpretq_bf16_u32(va3x4567c67)); + } + } + } + } + + const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + vacc2x0123 = vminq_f32(vacc2x0123, vmax); + vacc3x0123 = vminq_f32(vacc3x0123, vmax); + vacc0x4567 = vminq_f32(vacc0x4567, vmax); + vacc1x4567 = vminq_f32(vacc1x4567, vmax); + vacc2x4567 = vminq_f32(vacc2x4567, vmax); + vacc3x4567 = vminq_f32(vacc3x4567, vmax); + + const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + vacc2x0123 = vmaxq_f32(vacc2x0123, vmin); + vacc3x0123 = vmaxq_f32(vacc3x0123, vmin); + vacc0x4567 = vmaxq_f32(vacc0x4567, vmin); + vacc1x4567 = vmaxq_f32(vacc1x4567, vmin); + vacc2x4567 = vmaxq_f32(vacc2x4567, vmin); + vacc3x4567 = vmaxq_f32(vacc3x4567, vmin); + + bfloat16x4_t vout0x0123 = vcvt_bf16_f32(vacc0x0123); + bfloat16x4_t vout1x0123 = vcvt_bf16_f32(vacc1x0123); + bfloat16x4_t vout2x0123 = vcvt_bf16_f32(vacc2x0123); + bfloat16x4_t vout3x0123 = vcvt_bf16_f32(vacc3x0123); + bfloat16x4_t vout0x4567 = vcvt_bf16_f32(vacc0x4567); + bfloat16x4_t vout1x4567 = vcvt_bf16_f32(vacc1x4567); + bfloat16x4_t vout2x4567 = vcvt_bf16_f32(vacc2x4567); + bfloat16x4_t vout3x4567 = vcvt_bf16_f32(vacc3x4567); + + if XNN_LIKELY(nc >= 8) { + vst1_bf16(c0, vout0x0123); + vst1_bf16(c0 + 4, vout0x4567); + c0 = (bfloat16_t*) ((uintptr_t) c0 + cn_stride); + vst1_bf16(c1, vout1x0123); + vst1_bf16(c1 + 4, vout1x4567); + c1 = 
(bfloat16_t*) ((uintptr_t) c1 + cn_stride); + vst1_bf16(c2, vout2x0123); + vst1_bf16(c2 + 4, vout2x4567); + c2 = (bfloat16_t*) ((uintptr_t) c2 + cn_stride); + vst1_bf16(c3, vout3x0123); + vst1_bf16(c3 + 4, vout3x4567); + c3 = (bfloat16_t*) ((uintptr_t) c3 + cn_stride); + + a0 = (const bfloat16_t*) ((uintptr_t) a0 - kc); + a1 = (const bfloat16_t*) ((uintptr_t) a1 - kc); + a2 = (const bfloat16_t*) ((uintptr_t) a2 - kc); + a3 = (const bfloat16_t*) ((uintptr_t) a3 - kc); + + nc -= 8; + } else { + if (nc & 4) { + vst1_bf16(c0, vout0x0123); c0 += 4; + vst1_bf16(c1, vout1x0123); c1 += 4; + vst1_bf16(c2, vout2x0123); c2 += 4; + vst1_bf16(c3, vout3x0123); c3 += 4; + + vout0x0123 = vout0x4567; + vout1x0123 = vout1x4567; + vout2x0123 = vout2x4567; + vout3x0123 = vout3x4567; + } + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_bf16(vout0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_bf16(vout1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_bf16(vout2x0123), 0); c2 += 2; + vst1_lane_u32((void*) c3, vreinterpret_u32_bf16(vout3x0123), 0); c3 += 2; + + vout0x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout0x0123), vreinterpret_u16_bf16(vout0x0123), 2)); + vout1x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout1x0123), vreinterpret_u16_bf16(vout1x0123), 2)); + vout2x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout2x0123), vreinterpret_u16_bf16(vout2x0123), 2)); + vout3x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout3x0123), vreinterpret_u16_bf16(vout3x0123), 2)); + } + if (nc & 1) { + vst1_lane_bf16(c0, vout0x0123, 0); + vst1_lane_bf16(c1, vout1x0123, 0); + vst1_lane_bf16(c2, vout2x0123, 0); + vst1_lane_bf16(c3, vout3x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/5x4c8-minmax-neonbf16-bfdot.c b/src/bf16-gemm/gen/5x4c8-minmax-neonbf16-bfdot.c new file mode 100644 index 000000000000..49e84cf543a5 --- /dev/null +++ b/src/bf16-gemm/gen/5x4c8-minmax-neonbf16-bfdot.c @@ -0,0 +1,292 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c8-neonbf16.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
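+// Kernel outline: a 5x4 output tile with 8 K-elements per step (c8).
+// Each A row is multiplied against four packed B vectors with BFDOT, so every accumulator
+// holds partial dot products that are reduced to a single float32x4 row result by pairwise
+// additions after the K loop. The K remainder masks A with the zero padding of B before
+// accumulating, and the clamped result is converted back to BF16 for the store.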
+ + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/gemm.h> + + +void xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 5); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(bfloat16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const bfloat16_t* a0 = (const bfloat16_t*) a; + bfloat16_t* c0 = (bfloat16_t*) c; + const bfloat16_t* a1 = (const bfloat16_t*) ((uintptr_t) a0 + a_stride); + bfloat16_t* c1 = (bfloat16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const bfloat16_t* a2 = (const bfloat16_t*) ((uintptr_t) a1 + a_stride); + bfloat16_t* c2 = (bfloat16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const bfloat16_t* a3 = (const bfloat16_t*) ((uintptr_t) a2 + a_stride); + bfloat16_t* c3 = (bfloat16_t*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr < 4) { + a3 = a2; + c3 = c2; + } + const bfloat16_t* a4 = (const bfloat16_t*) ((uintptr_t) a3 + a_stride); + bfloat16_t* c4 = (bfloat16_t*) ((uintptr_t) c3 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 4) { + a4 = a3; + c4 = c3; + } + + const bfloat16_t* w = (const bfloat16_t*) w_ptr; + do { + float32x4_t vacc0x0 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x1 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x2 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x3 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc1x0 = vacc0x0; + float32x4_t vacc1x1 = vacc0x1; + float32x4_t vacc1x2 = vacc0x2; + float32x4_t vacc1x3 = vacc0x3; + float32x4_t vacc2x0 = vacc0x0; + float32x4_t vacc2x1 = vacc0x1; + float32x4_t vacc2x2 = vacc0x2; + float32x4_t vacc2x3 = vacc0x3; + float32x4_t vacc3x0 = vacc0x0; + float32x4_t vacc3x1 = vacc0x1; + float32x4_t vacc3x2 = vacc0x2; + float32x4_t vacc3x3 = vacc0x3; + float32x4_t vacc4x0 = vacc0x0; + float32x4_t vacc4x1 = vacc0x1; + float32x4_t vacc4x2 = vacc0x2; + float32x4_t vacc4x3 = vacc0x3; + + size_t k = kc; + for (; k >= 8 * sizeof(bfloat16_t); k -= 8 * sizeof(bfloat16_t)) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 += 8; + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 += 8; + const bfloat16x8_t va2 = vld1q_bf16(a2); a2 += 8; + const bfloat16x8_t va3 = vld1q_bf16(a3); a3 += 8; + const bfloat16x8_t va4 = vld1q_bf16(a4); a4 += 8; + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + vacc0x0 = vbfdotq_f32(vacc0x0, va0, vb0); + vacc1x0 = vbfdotq_f32(vacc1x0, va1, vb0); + vacc2x0 = vbfdotq_f32(vacc2x0, va2, vb0); + vacc3x0 = vbfdotq_f32(vacc3x0, va3, vb0); + vacc4x0 = vbfdotq_f32(vacc4x0, va4, vb0); + vacc0x1 = vbfdotq_f32(vacc0x1, va0, vb1); + vacc1x1 = vbfdotq_f32(vacc1x1, va1, vb1); + vacc2x1 = vbfdotq_f32(vacc2x1, va2, vb1); + vacc3x1 = vbfdotq_f32(vacc3x1, va3, vb1); + vacc4x1 = vbfdotq_f32(vacc4x1, va4, vb1); + vacc0x2 = vbfdotq_f32(vacc0x2, va0, vb2); + vacc1x2 = vbfdotq_f32(vacc1x2, va1, vb2); + vacc2x2 = vbfdotq_f32(vacc2x2, va2, vb2); + 
vacc3x2 = vbfdotq_f32(vacc3x2, va3, vb2); + vacc4x2 = vbfdotq_f32(vacc4x2, va4, vb2); + vacc0x3 = vbfdotq_f32(vacc0x3, va0, vb3); + vacc1x3 = vbfdotq_f32(vacc1x3, va1, vb3); + vacc2x3 = vbfdotq_f32(vacc2x3, va2, vb3); + vacc3x3 = vbfdotq_f32(vacc3x3, va3, vb3); + vacc4x3 = vbfdotq_f32(vacc4x3, va4, vb3); + } + if XNN_UNLIKELY(k != 0) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 = (const bfloat16_t*) ((uintptr_t) a0 + k); + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 = (const bfloat16_t*) ((uintptr_t) a1 + k); + const bfloat16x8_t va2 = vld1q_bf16(a2); a2 = (const bfloat16_t*) ((uintptr_t) a2 + k); + const bfloat16x8_t va3 = vld1q_bf16(a3); a3 = (const bfloat16_t*) ((uintptr_t) a3 + k); + const bfloat16x8_t va4 = vld1q_bf16(a4); a4 = (const bfloat16_t*) ((uintptr_t) a4 + k); + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vreinterpretq_u16_bf16(vb0), vmovq_n_u16(0)); + const uint16x8_t vm1 = vceqq_u16(vreinterpretq_u16_bf16(vb1), vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vreinterpretq_u16_bf16(vb2), vmovq_n_u16(0)); + const uint16x8_t vm3 = vceqq_u16(vreinterpretq_u16_bf16(vb3), vmovq_n_u16(0)); + + const bfloat16x8_t va0x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm0)); + vacc0x0 = vbfdotq_f32(vacc0x0, va0x0, vb0); + const bfloat16x8_t va1x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm0)); + vacc1x0 = vbfdotq_f32(vacc1x0, va1x0, vb0); + const bfloat16x8_t va2x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm0)); + vacc2x0 = vbfdotq_f32(vacc2x0, va2x0, vb0); + const bfloat16x8_t va3x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va3), vm0)); + vacc3x0 = vbfdotq_f32(vacc3x0, va3x0, vb0); + const bfloat16x8_t va4x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va4), vm0)); + vacc4x0 = vbfdotq_f32(vacc4x0, va4x0, vb0); + const bfloat16x8_t va0x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm1)); + vacc0x1 = vbfdotq_f32(vacc0x1, va0x1, vb1); + const bfloat16x8_t va1x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm1)); + vacc1x1 = vbfdotq_f32(vacc1x1, va1x1, vb1); + const bfloat16x8_t va2x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm1)); + vacc2x1 = vbfdotq_f32(vacc2x1, va2x1, vb1); + const bfloat16x8_t va3x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va3), vm1)); + vacc3x1 = vbfdotq_f32(vacc3x1, va3x1, vb1); + const bfloat16x8_t va4x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va4), vm1)); + vacc4x1 = vbfdotq_f32(vacc4x1, va4x1, vb1); + const bfloat16x8_t va0x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm2)); + vacc0x2 = vbfdotq_f32(vacc0x2, va0x2, vb2); + const bfloat16x8_t va1x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm2)); + vacc1x2 = vbfdotq_f32(vacc1x2, va1x2, vb2); + const bfloat16x8_t va2x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm2)); + vacc2x2 = vbfdotq_f32(vacc2x2, va2x2, vb2); + const bfloat16x8_t va3x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va3), vm2)); + vacc3x2 = vbfdotq_f32(vacc3x2, va3x2, vb2); + const bfloat16x8_t va4x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va4), vm2)); + vacc4x2 = vbfdotq_f32(vacc4x2, va4x2, vb2); + const bfloat16x8_t va0x3 = 
vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm3)); + vacc0x3 = vbfdotq_f32(vacc0x3, va0x3, vb3); + const bfloat16x8_t va1x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm3)); + vacc1x3 = vbfdotq_f32(vacc1x3, va1x3, vb3); + const bfloat16x8_t va2x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm3)); + vacc2x3 = vbfdotq_f32(vacc2x3, va2x3, vb3); + const bfloat16x8_t va3x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va3), vm3)); + vacc3x3 = vbfdotq_f32(vacc3x3, va3x3, vb3); + const bfloat16x8_t va4x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va4), vm3)); + vacc4x3 = vbfdotq_f32(vacc4x3, va4x3, vb3); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc1x01 = vpaddq_f32(vacc1x0, vacc1x1); + const float32x4_t vacc2x01 = vpaddq_f32(vacc2x0, vacc2x1); + const float32x4_t vacc3x01 = vpaddq_f32(vacc3x0, vacc3x1); + const float32x4_t vacc4x01 = vpaddq_f32(vacc4x0, vacc4x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + const float32x4_t vacc1x23 = vpaddq_f32(vacc1x2, vacc1x3); + const float32x4_t vacc2x23 = vpaddq_f32(vacc2x2, vacc2x3); + const float32x4_t vacc3x23 = vpaddq_f32(vacc3x2, vacc3x3); + const float32x4_t vacc4x23 = vpaddq_f32(vacc4x2, vacc4x3); + + float32x4_t vacc0x0123 = vpaddq_f32(vacc0x01, vacc0x23); + float32x4_t vacc1x0123 = vpaddq_f32(vacc1x01, vacc1x23); + float32x4_t vacc2x0123 = vpaddq_f32(vacc2x01, vacc2x23); + float32x4_t vacc3x0123 = vpaddq_f32(vacc3x01, vacc3x23); + float32x4_t vacc4x0123 = vpaddq_f32(vacc4x01, vacc4x23); +#else + const float32x2_t vsum0x0 = vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum1x0 = vadd_f32(vget_low_f32(vacc1x0), vget_high_f32(vacc1x0)); + const float32x2_t vsum2x0 = vadd_f32(vget_low_f32(vacc2x0), vget_high_f32(vacc2x0)); + const float32x2_t vsum3x0 = vadd_f32(vget_low_f32(vacc3x0), vget_high_f32(vacc3x0)); + const float32x2_t vsum4x0 = vadd_f32(vget_low_f32(vacc4x0), vget_high_f32(vacc4x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum1x1 = vadd_f32(vget_low_f32(vacc1x1), vget_high_f32(vacc1x1)); + const float32x2_t vsum2x1 = vadd_f32(vget_low_f32(vacc2x1), vget_high_f32(vacc2x1)); + const float32x2_t vsum3x1 = vadd_f32(vget_low_f32(vacc3x1), vget_high_f32(vacc3x1)); + const float32x2_t vsum4x1 = vadd_f32(vget_low_f32(vacc4x1), vget_high_f32(vacc4x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum1x2 = vadd_f32(vget_low_f32(vacc1x2), vget_high_f32(vacc1x2)); + const float32x2_t vsum2x2 = vadd_f32(vget_low_f32(vacc2x2), vget_high_f32(vacc2x2)); + const float32x2_t vsum3x2 = vadd_f32(vget_low_f32(vacc3x2), vget_high_f32(vacc3x2)); + const float32x2_t vsum4x2 = vadd_f32(vget_low_f32(vacc4x2), vget_high_f32(vacc4x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + const float32x2_t vsum1x3 = vadd_f32(vget_low_f32(vacc1x3), vget_high_f32(vacc1x3)); + const float32x2_t vsum2x3 = vadd_f32(vget_low_f32(vacc2x3), vget_high_f32(vacc2x3)); + const float32x2_t vsum3x3 = vadd_f32(vget_low_f32(vacc3x3), vget_high_f32(vacc3x3)); + const float32x2_t vsum4x3 = vadd_f32(vget_low_f32(vacc4x3), vget_high_f32(vacc4x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); + float32x4_t vacc1x0123 = vcombine_f32(vpadd_f32(vsum1x0, vsum1x1), 
vpadd_f32(vsum1x2, vsum1x3)); + float32x4_t vacc2x0123 = vcombine_f32(vpadd_f32(vsum2x0, vsum2x1), vpadd_f32(vsum2x2, vsum2x3)); + float32x4_t vacc3x0123 = vcombine_f32(vpadd_f32(vsum3x0, vsum3x1), vpadd_f32(vsum3x2, vsum3x3)); + float32x4_t vacc4x0123 = vcombine_f32(vpadd_f32(vsum4x0, vsum4x1), vpadd_f32(vsum4x2, vsum4x3)); +#endif + + const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + vacc2x0123 = vminq_f32(vacc2x0123, vmax); + vacc3x0123 = vminq_f32(vacc3x0123, vmax); + vacc4x0123 = vminq_f32(vacc4x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + vacc2x0123 = vmaxq_f32(vacc2x0123, vmin); + vacc3x0123 = vmaxq_f32(vacc3x0123, vmin); + vacc4x0123 = vmaxq_f32(vacc4x0123, vmin); + + bfloat16x4_t vout0x0123 = vcvt_bf16_f32(vacc0x0123); + bfloat16x4_t vout1x0123 = vcvt_bf16_f32(vacc1x0123); + bfloat16x4_t vout2x0123 = vcvt_bf16_f32(vacc2x0123); + bfloat16x4_t vout3x0123 = vcvt_bf16_f32(vacc3x0123); + bfloat16x4_t vout4x0123 = vcvt_bf16_f32(vacc4x0123); + + if XNN_LIKELY(nc >= 4) { + vst1_bf16(c0, vout0x0123); + c0 = (bfloat16_t*) ((uintptr_t) c0 + cn_stride); + vst1_bf16(c1, vout1x0123); + c1 = (bfloat16_t*) ((uintptr_t) c1 + cn_stride); + vst1_bf16(c2, vout2x0123); + c2 = (bfloat16_t*) ((uintptr_t) c2 + cn_stride); + vst1_bf16(c3, vout3x0123); + c3 = (bfloat16_t*) ((uintptr_t) c3 + cn_stride); + vst1_bf16(c4, vout4x0123); + c4 = (bfloat16_t*) ((uintptr_t) c4 + cn_stride); + + a0 = (const bfloat16_t*) ((uintptr_t) a0 - kc); + a1 = (const bfloat16_t*) ((uintptr_t) a1 - kc); + a2 = (const bfloat16_t*) ((uintptr_t) a2 - kc); + a3 = (const bfloat16_t*) ((uintptr_t) a3 - kc); + a4 = (const bfloat16_t*) ((uintptr_t) a4 - kc); + + nc -= 4; + } else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_bf16(vout0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_bf16(vout1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_bf16(vout2x0123), 0); c2 += 2; + vst1_lane_u32((void*) c3, vreinterpret_u32_bf16(vout3x0123), 0); c3 += 2; + vst1_lane_u32((void*) c4, vreinterpret_u32_bf16(vout4x0123), 0); c4 += 2; + + vout0x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout0x0123), vreinterpret_u16_bf16(vout0x0123), 2)); + vout1x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout1x0123), vreinterpret_u16_bf16(vout1x0123), 2)); + vout2x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout2x0123), vreinterpret_u16_bf16(vout2x0123), 2)); + vout3x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout3x0123), vreinterpret_u16_bf16(vout3x0123), 2)); + vout4x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout4x0123), vreinterpret_u16_bf16(vout4x0123), 2)); + } + if (nc & 1) { + vst1_lane_bf16(c0, vout0x0123, 0); + vst1_lane_bf16(c1, vout1x0123, 0); + vst1_lane_bf16(c2, vout2x0123, 0); + vst1_lane_bf16(c3, vout3x0123, 0); + vst1_lane_bf16(c4, vout4x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/5x4c8-minmax-neonbf16-bfmlal.c b/src/bf16-gemm/gen/5x4c8-minmax-neonbf16-bfmlal.c new file mode 100644 index 000000000000..04d626e0746f --- /dev/null +++ b/src/bf16-gemm/gen/5x4c8-minmax-neonbf16-bfmlal.c @@ -0,0 +1,333 @@ +// Auto-generated file. Do not edit!
+// Template: src/bf16-gemm/c8-neonbf16.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/gemm.h> + + +void xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 5); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(bfloat16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const bfloat16_t* a0 = (const bfloat16_t*) a; + bfloat16_t* c0 = (bfloat16_t*) c; + const bfloat16_t* a1 = (const bfloat16_t*) ((uintptr_t) a0 + a_stride); + bfloat16_t* c1 = (bfloat16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const bfloat16_t* a2 = (const bfloat16_t*) ((uintptr_t) a1 + a_stride); + bfloat16_t* c2 = (bfloat16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const bfloat16_t* a3 = (const bfloat16_t*) ((uintptr_t) a2 + a_stride); + bfloat16_t* c3 = (bfloat16_t*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr < 4) { + a3 = a2; + c3 = c2; + } + const bfloat16_t* a4 = (const bfloat16_t*) ((uintptr_t) a3 + a_stride); + bfloat16_t* c4 = (bfloat16_t*) ((uintptr_t) c3 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 4) { + a4 = a3; + c4 = c3; + } + + const bfloat16_t* w = (const bfloat16_t*) w_ptr; + do { + float32x4_t vacc0x0 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x1 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x2 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc0x3 = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; + float32x4_t vacc1x0 = vacc0x0; + float32x4_t vacc1x1 = vacc0x1; + float32x4_t vacc1x2 = vacc0x2; + float32x4_t vacc1x3 = vacc0x3; + float32x4_t vacc2x0 = vacc0x0; + float32x4_t vacc2x1 = vacc0x1; + float32x4_t vacc2x2 = vacc0x2; + float32x4_t vacc2x3 = vacc0x3; + float32x4_t vacc3x0 = vacc0x0; + float32x4_t vacc3x1 = vacc0x1; + float32x4_t vacc3x2 = vacc0x2; + float32x4_t vacc3x3 = vacc0x3; + float32x4_t vacc4x0 = vacc0x0; + float32x4_t vacc4x1 = vacc0x1; + float32x4_t vacc4x2 = vacc0x2; + float32x4_t vacc4x3 = vacc0x3; + + size_t k = kc; + for (; k >= 8 * sizeof(bfloat16_t); k -= 8 * sizeof(bfloat16_t)) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 += 8; + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 += 8; + const bfloat16x8_t va2 = vld1q_bf16(a2); a2 += 8; + const bfloat16x8_t va3 = vld1q_bf16(a3); a3 += 8; + const bfloat16x8_t va4 = vld1q_bf16(a4); a4 += 8; + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + vacc0x0 = vbfmlalbq_f32(vacc0x0, va0, vb0); + vacc1x0 = vbfmlalbq_f32(vacc1x0, va1, vb0); + vacc2x0 = vbfmlalbq_f32(vacc2x0, va2, vb0); + vacc3x0 = vbfmlalbq_f32(vacc3x0, va3, vb0); + vacc4x0 = vbfmlalbq_f32(vacc4x0, va4, vb0); + vacc0x1 = vbfmlalbq_f32(vacc0x1, va0, vb1); + vacc1x1 = vbfmlalbq_f32(vacc1x1, va1, vb1); + vacc2x1 = 
vbfmlalbq_f32(vacc2x1, va2, vb1); + vacc3x1 = vbfmlalbq_f32(vacc3x1, va3, vb1); + vacc4x1 = vbfmlalbq_f32(vacc4x1, va4, vb1); + vacc0x2 = vbfmlalbq_f32(vacc0x2, va0, vb2); + vacc1x2 = vbfmlalbq_f32(vacc1x2, va1, vb2); + vacc2x2 = vbfmlalbq_f32(vacc2x2, va2, vb2); + vacc3x2 = vbfmlalbq_f32(vacc3x2, va3, vb2); + vacc4x2 = vbfmlalbq_f32(vacc4x2, va4, vb2); + vacc0x3 = vbfmlalbq_f32(vacc0x3, va0, vb3); + vacc1x3 = vbfmlalbq_f32(vacc1x3, va1, vb3); + vacc2x3 = vbfmlalbq_f32(vacc2x3, va2, vb3); + vacc3x3 = vbfmlalbq_f32(vacc3x3, va3, vb3); + vacc4x3 = vbfmlalbq_f32(vacc4x3, va4, vb3); + + vacc0x0 = vbfmlaltq_f32(vacc0x0, va0, vb0); + vacc1x0 = vbfmlaltq_f32(vacc1x0, va1, vb0); + vacc2x0 = vbfmlaltq_f32(vacc2x0, va2, vb0); + vacc3x0 = vbfmlaltq_f32(vacc3x0, va3, vb0); + vacc4x0 = vbfmlaltq_f32(vacc4x0, va4, vb0); + vacc0x1 = vbfmlaltq_f32(vacc0x1, va0, vb1); + vacc1x1 = vbfmlaltq_f32(vacc1x1, va1, vb1); + vacc2x1 = vbfmlaltq_f32(vacc2x1, va2, vb1); + vacc3x1 = vbfmlaltq_f32(vacc3x1, va3, vb1); + vacc4x1 = vbfmlaltq_f32(vacc4x1, va4, vb1); + vacc0x2 = vbfmlaltq_f32(vacc0x2, va0, vb2); + vacc1x2 = vbfmlaltq_f32(vacc1x2, va1, vb2); + vacc2x2 = vbfmlaltq_f32(vacc2x2, va2, vb2); + vacc3x2 = vbfmlaltq_f32(vacc3x2, va3, vb2); + vacc4x2 = vbfmlaltq_f32(vacc4x2, va4, vb2); + vacc0x3 = vbfmlaltq_f32(vacc0x3, va0, vb3); + vacc1x3 = vbfmlaltq_f32(vacc1x3, va1, vb3); + vacc2x3 = vbfmlaltq_f32(vacc2x3, va2, vb3); + vacc3x3 = vbfmlaltq_f32(vacc3x3, va3, vb3); + vacc4x3 = vbfmlaltq_f32(vacc4x3, va4, vb3); + } + if XNN_UNLIKELY(k != 0) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 = (const bfloat16_t*) ((uintptr_t) a0 + k); + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 = (const bfloat16_t*) ((uintptr_t) a1 + k); + const bfloat16x8_t va2 = vld1q_bf16(a2); a2 = (const bfloat16_t*) ((uintptr_t) a2 + k); + const bfloat16x8_t va3 = vld1q_bf16(a3); a3 = (const bfloat16_t*) ((uintptr_t) a3 + k); + const bfloat16x8_t va4 = vld1q_bf16(a4); a4 = (const bfloat16_t*) ((uintptr_t) a4 + k); + + const bfloat16x8_t vb0 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb1 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb2 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb3 = vld1q_bf16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vreinterpretq_u16_bf16(vb0), vmovq_n_u16(0)); + const uint16x8_t vm1 = vceqq_u16(vreinterpretq_u16_bf16(vb1), vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vreinterpretq_u16_bf16(vb2), vmovq_n_u16(0)); + const uint16x8_t vm3 = vceqq_u16(vreinterpretq_u16_bf16(vb3), vmovq_n_u16(0)); + + const bfloat16x8_t va0x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm0)); + vacc0x0 = vbfmlalbq_f32(vacc0x0, va0x0, vb0); + vacc0x0 = vbfmlaltq_f32(vacc0x0, va0x0, vb0); + const bfloat16x8_t va1x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm0)); + vacc1x0 = vbfmlalbq_f32(vacc1x0, va1x0, vb0); + vacc1x0 = vbfmlaltq_f32(vacc1x0, va1x0, vb0); + const bfloat16x8_t va2x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm0)); + vacc2x0 = vbfmlalbq_f32(vacc2x0, va2x0, vb0); + vacc2x0 = vbfmlaltq_f32(vacc2x0, va2x0, vb0); + const bfloat16x8_t va3x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va3), vm0)); + vacc3x0 = vbfmlalbq_f32(vacc3x0, va3x0, vb0); + vacc3x0 = vbfmlaltq_f32(vacc3x0, va3x0, vb0); + const bfloat16x8_t va4x0 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va4), vm0)); + vacc4x0 = vbfmlalbq_f32(vacc4x0, va4x0, vb0); + vacc4x0 = vbfmlaltq_f32(vacc4x0, va4x0, vb0); + const bfloat16x8_t va0x1 = 
vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm1)); + vacc0x1 = vbfmlalbq_f32(vacc0x1, va0x1, vb1); + vacc0x1 = vbfmlaltq_f32(vacc0x1, va0x1, vb1); + const bfloat16x8_t va1x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm1)); + vacc1x1 = vbfmlalbq_f32(vacc1x1, va1x1, vb1); + vacc1x1 = vbfmlaltq_f32(vacc1x1, va1x1, vb1); + const bfloat16x8_t va2x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm1)); + vacc2x1 = vbfmlalbq_f32(vacc2x1, va2x1, vb1); + vacc2x1 = vbfmlaltq_f32(vacc2x1, va2x1, vb1); + const bfloat16x8_t va3x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va3), vm1)); + vacc3x1 = vbfmlalbq_f32(vacc3x1, va3x1, vb1); + vacc3x1 = vbfmlaltq_f32(vacc3x1, va3x1, vb1); + const bfloat16x8_t va4x1 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va4), vm1)); + vacc4x1 = vbfmlalbq_f32(vacc4x1, va4x1, vb1); + vacc4x1 = vbfmlaltq_f32(vacc4x1, va4x1, vb1); + const bfloat16x8_t va0x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm2)); + vacc0x2 = vbfmlalbq_f32(vacc0x2, va0x2, vb2); + vacc0x2 = vbfmlaltq_f32(vacc0x2, va0x2, vb2); + const bfloat16x8_t va1x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm2)); + vacc1x2 = vbfmlalbq_f32(vacc1x2, va1x2, vb2); + vacc1x2 = vbfmlaltq_f32(vacc1x2, va1x2, vb2); + const bfloat16x8_t va2x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm2)); + vacc2x2 = vbfmlalbq_f32(vacc2x2, va2x2, vb2); + vacc2x2 = vbfmlaltq_f32(vacc2x2, va2x2, vb2); + const bfloat16x8_t va3x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va3), vm2)); + vacc3x2 = vbfmlalbq_f32(vacc3x2, va3x2, vb2); + vacc3x2 = vbfmlaltq_f32(vacc3x2, va3x2, vb2); + const bfloat16x8_t va4x2 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va4), vm2)); + vacc4x2 = vbfmlalbq_f32(vacc4x2, va4x2, vb2); + vacc4x2 = vbfmlaltq_f32(vacc4x2, va4x2, vb2); + const bfloat16x8_t va0x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va0), vm3)); + vacc0x3 = vbfmlalbq_f32(vacc0x3, va0x3, vb3); + vacc0x3 = vbfmlaltq_f32(vacc0x3, va0x3, vb3); + const bfloat16x8_t va1x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va1), vm3)); + vacc1x3 = vbfmlalbq_f32(vacc1x3, va1x3, vb3); + vacc1x3 = vbfmlaltq_f32(vacc1x3, va1x3, vb3); + const bfloat16x8_t va2x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va2), vm3)); + vacc2x3 = vbfmlalbq_f32(vacc2x3, va2x3, vb3); + vacc2x3 = vbfmlaltq_f32(vacc2x3, va2x3, vb3); + const bfloat16x8_t va3x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va3), vm3)); + vacc3x3 = vbfmlalbq_f32(vacc3x3, va3x3, vb3); + vacc3x3 = vbfmlaltq_f32(vacc3x3, va3x3, vb3); + const bfloat16x8_t va4x3 = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va4), vm3)); + vacc4x3 = vbfmlalbq_f32(vacc4x3, va4x3, vb3); + vacc4x3 = vbfmlaltq_f32(vacc4x3, va4x3, vb3); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc1x01 = vpaddq_f32(vacc1x0, vacc1x1); + const float32x4_t vacc2x01 = vpaddq_f32(vacc2x0, vacc2x1); + const float32x4_t vacc3x01 = vpaddq_f32(vacc3x0, vacc3x1); + const float32x4_t vacc4x01 = vpaddq_f32(vacc4x0, vacc4x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + const float32x4_t vacc1x23 = vpaddq_f32(vacc1x2, vacc1x3); + const float32x4_t vacc2x23 = vpaddq_f32(vacc2x2, vacc2x3); + const float32x4_t vacc3x23 = vpaddq_f32(vacc3x2, vacc3x3); + const float32x4_t vacc4x23 = vpaddq_f32(vacc4x2, vacc4x3); + + float32x4_t vacc0x0123 = 
vpaddq_f32(vacc0x01, vacc0x23); + float32x4_t vacc1x0123 = vpaddq_f32(vacc1x01, vacc1x23); + float32x4_t vacc2x0123 = vpaddq_f32(vacc2x01, vacc2x23); + float32x4_t vacc3x0123 = vpaddq_f32(vacc3x01, vacc3x23); + float32x4_t vacc4x0123 = vpaddq_f32(vacc4x01, vacc4x23); +#else + const float32x2_t vsum0x0 = vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum1x0 = vadd_f32(vget_low_f32(vacc1x0), vget_high_f32(vacc1x0)); + const float32x2_t vsum2x0 = vadd_f32(vget_low_f32(vacc2x0), vget_high_f32(vacc2x0)); + const float32x2_t vsum3x0 = vadd_f32(vget_low_f32(vacc3x0), vget_high_f32(vacc3x0)); + const float32x2_t vsum4x0 = vadd_f32(vget_low_f32(vacc4x0), vget_high_f32(vacc4x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum1x1 = vadd_f32(vget_low_f32(vacc1x1), vget_high_f32(vacc1x1)); + const float32x2_t vsum2x1 = vadd_f32(vget_low_f32(vacc2x1), vget_high_f32(vacc2x1)); + const float32x2_t vsum3x1 = vadd_f32(vget_low_f32(vacc3x1), vget_high_f32(vacc3x1)); + const float32x2_t vsum4x1 = vadd_f32(vget_low_f32(vacc4x1), vget_high_f32(vacc4x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum1x2 = vadd_f32(vget_low_f32(vacc1x2), vget_high_f32(vacc1x2)); + const float32x2_t vsum2x2 = vadd_f32(vget_low_f32(vacc2x2), vget_high_f32(vacc2x2)); + const float32x2_t vsum3x2 = vadd_f32(vget_low_f32(vacc3x2), vget_high_f32(vacc3x2)); + const float32x2_t vsum4x2 = vadd_f32(vget_low_f32(vacc4x2), vget_high_f32(vacc4x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + const float32x2_t vsum1x3 = vadd_f32(vget_low_f32(vacc1x3), vget_high_f32(vacc1x3)); + const float32x2_t vsum2x3 = vadd_f32(vget_low_f32(vacc2x3), vget_high_f32(vacc2x3)); + const float32x2_t vsum3x3 = vadd_f32(vget_low_f32(vacc3x3), vget_high_f32(vacc3x3)); + const float32x2_t vsum4x3 = vadd_f32(vget_low_f32(vacc4x3), vget_high_f32(vacc4x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); + float32x4_t vacc1x0123 = vcombine_f32(vpadd_f32(vsum1x0, vsum1x1), vpadd_f32(vsum1x2, vsum1x3)); + float32x4_t vacc2x0123 = vcombine_f32(vpadd_f32(vsum2x0, vsum2x1), vpadd_f32(vsum2x2, vsum2x3)); + float32x4_t vacc3x0123 = vcombine_f32(vpadd_f32(vsum3x0, vsum3x1), vpadd_f32(vsum3x2, vsum3x3)); + float32x4_t vacc4x0123 = vcombine_f32(vpadd_f32(vsum4x0, vsum4x1), vpadd_f32(vsum4x2, vsum4x3)); +#endif + + const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + vacc2x0123 = vminq_f32(vacc2x0123, vmax); + vacc3x0123 = vminq_f32(vacc3x0123, vmax); + vacc4x0123 = vminq_f32(vacc4x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + vacc2x0123 = vmaxq_f32(vacc2x0123, vmin); + vacc3x0123 = vmaxq_f32(vacc3x0123, vmin); + vacc4x0123 = vmaxq_f32(vacc4x0123, vmin); + + bfloat16x4_t vout0x0123 = vcvt_bf16_f32(vacc0x0123); + bfloat16x4_t vout1x0123 = vcvt_bf16_f32(vacc1x0123); + bfloat16x4_t vout2x0123 = vcvt_bf16_f32(vacc2x0123); + bfloat16x4_t vout3x0123 = vcvt_bf16_f32(vacc3x0123); + bfloat16x4_t vout4x0123 = vcvt_bf16_f32(vacc4x0123); + + if XNN_LIKELY(nc >= 4) { + vst1_bf16(c0, vout0x0123); + c0 = (bfloat16_t*) ((uintptr_t) c0 + cn_stride); + vst1_bf16(c1, vout1x0123); + c1 = (bfloat16_t*) ((uintptr_t) c1 + cn_stride); + 
vst1_bf16(c2, vout2x0123); + c2 = (bfloat16_t*) ((uintptr_t) c2 + cn_stride); + vst1_bf16(c3, vout3x0123); + c3 = (bfloat16_t*) ((uintptr_t) c3 + cn_stride); + vst1_bf16(c4, vout4x0123); + c4 = (bfloat16_t*) ((uintptr_t) c4 + cn_stride); + + a0 = (const bfloat16_t*) ((uintptr_t) a0 - kc); + a1 = (const bfloat16_t*) ((uintptr_t) a1 - kc); + a2 = (const bfloat16_t*) ((uintptr_t) a2 - kc); + a3 = (const bfloat16_t*) ((uintptr_t) a3 - kc); + a4 = (const bfloat16_t*) ((uintptr_t) a4 - kc); + + nc -= 4; + } else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_bf16(vout0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_bf16(vout1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_bf16(vout2x0123), 0); c2 += 2; + vst1_lane_u32((void*) c3, vreinterpret_u32_bf16(vout3x0123), 0); c3 += 2; + vst1_lane_u32((void*) c4, vreinterpret_u32_bf16(vout4x0123), 0); c4 += 2; + + vout0x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout0x0123), vreinterpret_u16_bf16(vout0x0123), 2)); + vout1x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout1x0123), vreinterpret_u16_bf16(vout1x0123), 2)); + vout2x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout2x0123), vreinterpret_u16_bf16(vout2x0123), 2)); + vout3x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout3x0123), vreinterpret_u16_bf16(vout3x0123), 2)); + vout4x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout4x0123), vreinterpret_u16_bf16(vout4x0123), 2)); + } + if (nc & 1) { + vst1_lane_bf16(c0, vout0x0123, 0); + vst1_lane_bf16(c1, vout1x0123, 0); + vst1_lane_bf16(c2, vout2x0123, 0); + vst1_lane_bf16(c3, vout3x0123, 0); + vst1_lane_bf16(c4, vout4x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/5x4c8-minmax-neonfma-shland.c b/src/bf16-gemm/gen/5x4c8-minmax-neonfma-shland.c new file mode 100644 index 000000000000..42e7f2814668 --- /dev/null +++ b/src/bf16-gemm/gen/5x4c8-minmax-neonfma-shland.c @@ -0,0 +1,410 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c8-neon.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
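+// Kernel outline: NEON-FMA fallback for cores without the BF16 extension.
+// Each BF16 vector is widened to float32 in two halves: even-indexed elements via a
+// 16-bit left shift of the 32-bit lanes, odd-indexed elements by masking with 0xFFFF0000
+// (shland). Two FMAs per B vector accumulate both halves; the reduction and clamping
+// below follow the same pattern as the BFDOT kernels above.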
+ + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/gemm.h> + + +void xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 5); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(uint16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const uint16_t* a0 = (const uint16_t*) a; + uint16_t* c0 = (uint16_t*) c; + const uint16_t* a1 = (const uint16_t*) ((uintptr_t) a0 + a_stride); + uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const uint16_t* a2 = (const uint16_t*) ((uintptr_t) a1 + a_stride); + uint16_t* c2 = (uint16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const uint16_t* a3 = (const uint16_t*) ((uintptr_t) a2 + a_stride); + uint16_t* c3 = (uint16_t*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr < 4) { + a3 = a2; + c3 = c2; + } + const uint16_t* a4 = (const uint16_t*) ((uintptr_t) a3 + a_stride); + uint16_t* c4 = (uint16_t*) ((uintptr_t) c3 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 4) { + a4 = a3; + c4 = c3; + } + + const uint16_t* w = (const uint16_t*) w_ptr; + const uint16x8_t vmask = vreinterpretq_u16_u32(vmovq_n_u32(UINT32_C(0xFFFF0000))); + do { + float32x4_t vacc0x0 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x1 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x2 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x3 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc1x0 = vacc0x0; + float32x4_t vacc1x1 = vacc0x1; + float32x4_t vacc1x2 = vacc0x2; + float32x4_t vacc1x3 = vacc0x3; + float32x4_t vacc2x0 = vacc0x0; + float32x4_t vacc2x1 = vacc0x1; + float32x4_t vacc2x2 = vacc0x2; + float32x4_t vacc2x3 = vacc0x3; + float32x4_t vacc3x0 = vacc0x0; + float32x4_t vacc3x1 = vacc0x1; + float32x4_t vacc3x2 = vacc0x2; + float32x4_t vacc3x3 = vacc0x3; + float32x4_t vacc4x0 = vacc0x0; + float32x4_t vacc4x1 = vacc0x1; + float32x4_t vacc4x2 = vacc0x2; + float32x4_t vacc4x3 = vacc0x3; + + size_t k = kc; + for (; k >= 8 * sizeof(uint16_t); k -= 8 * sizeof(uint16_t)) { + const uint16x8_t va0 = vld1q_u16(a0); a0 += 8; + const uint16x8_t va1 = vld1q_u16(a1); a1 += 8; + const uint16x8_t va2 = vld1q_u16(a2); a2 += 8; + const uint16x8_t va3 = vld1q_u16(a3); a3 += 8; + const uint16x8_t va4 = vld1q_u16(a4); a4 += 8; + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const float32x4_t va0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0), 16)); + const float32x4_t va1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1), 16)); + const float32x4_t va2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va2), 16)); + const float32x4_t va3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va3), 16)); + const float32x4_t va4e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va4), 16)); + + const float32x4_t vb0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb0), 
16)); + const float32x4_t vb1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb1), 16)); + const float32x4_t vb2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb2), 16)); + const float32x4_t vb3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb3), 16)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0e, vb0e); + vacc1x0 = vfmaq_f32(vacc1x0, va1e, vb0e); + vacc2x0 = vfmaq_f32(vacc2x0, va2e, vb0e); + vacc3x0 = vfmaq_f32(vacc3x0, va3e, vb0e); + vacc4x0 = vfmaq_f32(vacc4x0, va4e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0e, vb1e); + vacc1x1 = vfmaq_f32(vacc1x1, va1e, vb1e); + vacc2x1 = vfmaq_f32(vacc2x1, va2e, vb1e); + vacc3x1 = vfmaq_f32(vacc3x1, va3e, vb1e); + vacc4x1 = vfmaq_f32(vacc4x1, va4e, vb1e); + vacc0x2 = vfmaq_f32(vacc0x2, va0e, vb2e); + vacc1x2 = vfmaq_f32(vacc1x2, va1e, vb2e); + vacc2x2 = vfmaq_f32(vacc2x2, va2e, vb2e); + vacc3x2 = vfmaq_f32(vacc3x2, va3e, vb2e); + vacc4x2 = vfmaq_f32(vacc4x2, va4e, vb2e); + vacc0x3 = vfmaq_f32(vacc0x3, va0e, vb3e); + vacc1x3 = vfmaq_f32(vacc1x3, va1e, vb3e); + vacc2x3 = vfmaq_f32(vacc2x3, va2e, vb3e); + vacc3x3 = vfmaq_f32(vacc3x3, va3e, vb3e); + vacc4x3 = vfmaq_f32(vacc4x3, va4e, vb3e); + + const float32x4_t va0o = vreinterpretq_f32_u16(vandq_u16(va0, vmask)); + const float32x4_t va1o = vreinterpretq_f32_u16(vandq_u16(va1, vmask)); + const float32x4_t va2o = vreinterpretq_f32_u16(vandq_u16(va2, vmask)); + const float32x4_t va3o = vreinterpretq_f32_u16(vandq_u16(va3, vmask)); + const float32x4_t va4o = vreinterpretq_f32_u16(vandq_u16(va4, vmask)); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vandq_u16(vb0, vmask)); + const float32x4_t vb1o = vreinterpretq_f32_u16(vandq_u16(vb1, vmask)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vandq_u16(vb2, vmask)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vandq_u16(vb3, vmask)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0o, vb0o); + vacc1x0 = vfmaq_f32(vacc1x0, va1o, vb0o); + vacc2x0 = vfmaq_f32(vacc2x0, va2o, vb0o); + vacc3x0 = vfmaq_f32(vacc3x0, va3o, vb0o); + vacc4x0 = vfmaq_f32(vacc4x0, va4o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0o, vb1o); + vacc1x1 = vfmaq_f32(vacc1x1, va1o, vb1o); + vacc2x1 = vfmaq_f32(vacc2x1, va2o, vb1o); + vacc3x1 = vfmaq_f32(vacc3x1, va3o, vb1o); + vacc4x1 = vfmaq_f32(vacc4x1, va4o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0o, vb2o); + vacc1x2 = vfmaq_f32(vacc1x2, va1o, vb2o); + vacc2x2 = vfmaq_f32(vacc2x2, va2o, vb2o); + vacc3x2 = vfmaq_f32(vacc3x2, va3o, vb2o); + vacc4x2 = vfmaq_f32(vacc4x2, va4o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0o, vb3o); + vacc1x3 = vfmaq_f32(vacc1x3, va1o, vb3o); + vacc2x3 = vfmaq_f32(vacc2x3, va2o, vb3o); + vacc3x3 = vfmaq_f32(vacc3x3, va3o, vb3o); + vacc4x3 = vfmaq_f32(vacc4x3, va4o, vb3o); + } + if XNN_UNLIKELY(k != 0) { + const uint16x8_t va0 = vld1q_u16(a0); a0 = (const uint16_t*) ((uintptr_t) a0 + k); + const uint16x8_t va1 = vld1q_u16(a1); a1 = (const uint16_t*) ((uintptr_t) a1 + k); + const uint16x8_t va2 = vld1q_u16(a2); a2 = (const uint16_t*) ((uintptr_t) a2 + k); + const uint16x8_t va3 = vld1q_u16(a3); a3 = (const uint16_t*) ((uintptr_t) a3 + k); + const uint16x8_t va4 = vld1q_u16(a4); a4 = (const uint16_t*) ((uintptr_t) a4 + k); + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vb0, vmovq_n_u16(0)); + const uint16x8_t vm1 = vceqq_u16(vb1, vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vb2, vmovq_n_u16(0)); + const uint16x8_t 
vm3 = vceqq_u16(vb3, vmovq_n_u16(0)); + + const float32x4_t vb0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb0), 16)); + const float32x4_t vb1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb1), 16)); + const float32x4_t vb2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb2), 16)); + const float32x4_t vb3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb3), 16)); + + const uint16x8_t va0x0 = vbicq_u16(va0, vm0); + const uint16x8_t va1x0 = vbicq_u16(va1, vm0); + const uint16x8_t va2x0 = vbicq_u16(va2, vm0); + const uint16x8_t va3x0 = vbicq_u16(va3, vm0); + const uint16x8_t va4x0 = vbicq_u16(va4, vm0); + const uint16x8_t va0x1 = vbicq_u16(va0, vm1); + const uint16x8_t va1x1 = vbicq_u16(va1, vm1); + const uint16x8_t va2x1 = vbicq_u16(va2, vm1); + const uint16x8_t va3x1 = vbicq_u16(va3, vm1); + const uint16x8_t va4x1 = vbicq_u16(va4, vm1); + const uint16x8_t va0x2 = vbicq_u16(va0, vm2); + const uint16x8_t va1x2 = vbicq_u16(va1, vm2); + const uint16x8_t va2x2 = vbicq_u16(va2, vm2); + const uint16x8_t va3x2 = vbicq_u16(va3, vm2); + const uint16x8_t va4x2 = vbicq_u16(va4, vm2); + const uint16x8_t va0x3 = vbicq_u16(va0, vm3); + const uint16x8_t va1x3 = vbicq_u16(va1, vm3); + const uint16x8_t va2x3 = vbicq_u16(va2, vm3); + const uint16x8_t va3x3 = vbicq_u16(va3, vm3); + const uint16x8_t va4x3 = vbicq_u16(va4, vm3); + + const float32x4_t va0x0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x0), 16)); + const float32x4_t va1x0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1x0), 16)); + const float32x4_t va2x0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va2x0), 16)); + const float32x4_t va3x0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va3x0), 16)); + const float32x4_t va4x0e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va4x0), 16)); + const float32x4_t va0x1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x1), 16)); + const float32x4_t va1x1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1x1), 16)); + const float32x4_t va2x1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va2x1), 16)); + const float32x4_t va3x1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va3x1), 16)); + const float32x4_t va4x1e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va4x1), 16)); + const float32x4_t va0x2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x2), 16)); + const float32x4_t va1x2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1x2), 16)); + const float32x4_t va2x2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va2x2), 16)); + const float32x4_t va3x2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va3x2), 16)); + const float32x4_t va4x2e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va4x2), 16)); + const float32x4_t va0x3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va0x3), 16)); + const float32x4_t va1x3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va1x3), 16)); + const float32x4_t va2x3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va2x3), 16)); + const float32x4_t va3x3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va3x3), 16)); + const float32x4_t va4x3e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va4x3), 16)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0e, vb0e); + vacc1x0 = vfmaq_f32(vacc1x0, va1x0e, vb0e); + vacc2x0 = vfmaq_f32(vacc2x0, va2x0e, vb0e); + vacc3x0 = vfmaq_f32(vacc3x0, va3x0e, vb0e); + 
vacc4x0 = vfmaq_f32(vacc4x0, va4x0e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1e, vb1e); + vacc1x1 = vfmaq_f32(vacc1x1, va1x1e, vb1e); + vacc2x1 = vfmaq_f32(vacc2x1, va2x1e, vb1e); + vacc3x1 = vfmaq_f32(vacc3x1, va3x1e, vb1e); + vacc4x1 = vfmaq_f32(vacc4x1, va4x1e, vb1e); + vacc0x2 = vfmaq_f32(vacc0x2, va0x2e, vb2e); + vacc1x2 = vfmaq_f32(vacc1x2, va1x2e, vb2e); + vacc2x2 = vfmaq_f32(vacc2x2, va2x2e, vb2e); + vacc3x2 = vfmaq_f32(vacc3x2, va3x2e, vb2e); + vacc4x2 = vfmaq_f32(vacc4x2, va4x2e, vb2e); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3e, vb3e); + vacc1x3 = vfmaq_f32(vacc1x3, va1x3e, vb3e); + vacc2x3 = vfmaq_f32(vacc2x3, va2x3e, vb3e); + vacc3x3 = vfmaq_f32(vacc3x3, va3x3e, vb3e); + vacc4x3 = vfmaq_f32(vacc4x3, va4x3e, vb3e); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vandq_u16(vb0, vmask)); + const float32x4_t vb1o = vreinterpretq_f32_u16(vandq_u16(vb1, vmask)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vandq_u16(vb2, vmask)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vandq_u16(vb3, vmask)); + + const float32x4_t va0x0o = vreinterpretq_f32_u16(vandq_u16(va0x0, vmask)); + const float32x4_t va1x0o = vreinterpretq_f32_u16(vandq_u16(va1x0, vmask)); + const float32x4_t va2x0o = vreinterpretq_f32_u16(vandq_u16(va2x0, vmask)); + const float32x4_t va3x0o = vreinterpretq_f32_u16(vandq_u16(va3x0, vmask)); + const float32x4_t va4x0o = vreinterpretq_f32_u16(vandq_u16(va4x0, vmask)); + const float32x4_t va0x1o = vreinterpretq_f32_u16(vandq_u16(va0x1, vmask)); + const float32x4_t va1x1o = vreinterpretq_f32_u16(vandq_u16(va1x1, vmask)); + const float32x4_t va2x1o = vreinterpretq_f32_u16(vandq_u16(va2x1, vmask)); + const float32x4_t va3x1o = vreinterpretq_f32_u16(vandq_u16(va3x1, vmask)); + const float32x4_t va4x1o = vreinterpretq_f32_u16(vandq_u16(va4x1, vmask)); + const float32x4_t va0x2o = vreinterpretq_f32_u16(vandq_u16(va0x2, vmask)); + const float32x4_t va1x2o = vreinterpretq_f32_u16(vandq_u16(va1x2, vmask)); + const float32x4_t va2x2o = vreinterpretq_f32_u16(vandq_u16(va2x2, vmask)); + const float32x4_t va3x2o = vreinterpretq_f32_u16(vandq_u16(va3x2, vmask)); + const float32x4_t va4x2o = vreinterpretq_f32_u16(vandq_u16(va4x2, vmask)); + const float32x4_t va0x3o = vreinterpretq_f32_u16(vandq_u16(va0x3, vmask)); + const float32x4_t va1x3o = vreinterpretq_f32_u16(vandq_u16(va1x3, vmask)); + const float32x4_t va2x3o = vreinterpretq_f32_u16(vandq_u16(va2x3, vmask)); + const float32x4_t va3x3o = vreinterpretq_f32_u16(vandq_u16(va3x3, vmask)); + const float32x4_t va4x3o = vreinterpretq_f32_u16(vandq_u16(va4x3, vmask)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0o, vb0o); + vacc1x0 = vfmaq_f32(vacc1x0, va1x0o, vb0o); + vacc2x0 = vfmaq_f32(vacc2x0, va2x0o, vb0o); + vacc3x0 = vfmaq_f32(vacc3x0, va3x0o, vb0o); + vacc4x0 = vfmaq_f32(vacc4x0, va4x0o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1o, vb1o); + vacc1x1 = vfmaq_f32(vacc1x1, va1x1o, vb1o); + vacc2x1 = vfmaq_f32(vacc2x1, va2x1o, vb1o); + vacc3x1 = vfmaq_f32(vacc3x1, va3x1o, vb1o); + vacc4x1 = vfmaq_f32(vacc4x1, va4x1o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0x2o, vb2o); + vacc1x2 = vfmaq_f32(vacc1x2, va1x2o, vb2o); + vacc2x2 = vfmaq_f32(vacc2x2, va2x2o, vb2o); + vacc3x2 = vfmaq_f32(vacc3x2, va3x2o, vb2o); + vacc4x2 = vfmaq_f32(vacc4x2, va4x2o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3o, vb3o); + vacc1x3 = vfmaq_f32(vacc1x3, va1x3o, vb3o); + vacc2x3 = vfmaq_f32(vacc2x3, va2x3o, vb3o); + vacc3x3 = vfmaq_f32(vacc3x3, va3x3o, vb3o); + vacc4x3 = vfmaq_f32(vacc4x3, va4x3o, vb3o); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = 
vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc1x01 = vpaddq_f32(vacc1x0, vacc1x1); + const float32x4_t vacc2x01 = vpaddq_f32(vacc2x0, vacc2x1); + const float32x4_t vacc3x01 = vpaddq_f32(vacc3x0, vacc3x1); + const float32x4_t vacc4x01 = vpaddq_f32(vacc4x0, vacc4x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + const float32x4_t vacc1x23 = vpaddq_f32(vacc1x2, vacc1x3); + const float32x4_t vacc2x23 = vpaddq_f32(vacc2x2, vacc2x3); + const float32x4_t vacc3x23 = vpaddq_f32(vacc3x2, vacc3x3); + const float32x4_t vacc4x23 = vpaddq_f32(vacc4x2, vacc4x3); + + float32x4_t vacc0x0123 = vpaddq_f32(vacc0x01, vacc0x23); + float32x4_t vacc1x0123 = vpaddq_f32(vacc1x01, vacc1x23); + float32x4_t vacc2x0123 = vpaddq_f32(vacc2x01, vacc2x23); + float32x4_t vacc3x0123 = vpaddq_f32(vacc3x01, vacc3x23); + float32x4_t vacc4x0123 = vpaddq_f32(vacc4x01, vacc4x23); +#else + const float32x2_t vsum0x0 = vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum1x0 = vadd_f32(vget_low_f32(vacc1x0), vget_high_f32(vacc1x0)); + const float32x2_t vsum2x0 = vadd_f32(vget_low_f32(vacc2x0), vget_high_f32(vacc2x0)); + const float32x2_t vsum3x0 = vadd_f32(vget_low_f32(vacc3x0), vget_high_f32(vacc3x0)); + const float32x2_t vsum4x0 = vadd_f32(vget_low_f32(vacc4x0), vget_high_f32(vacc4x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum1x1 = vadd_f32(vget_low_f32(vacc1x1), vget_high_f32(vacc1x1)); + const float32x2_t vsum2x1 = vadd_f32(vget_low_f32(vacc2x1), vget_high_f32(vacc2x1)); + const float32x2_t vsum3x1 = vadd_f32(vget_low_f32(vacc3x1), vget_high_f32(vacc3x1)); + const float32x2_t vsum4x1 = vadd_f32(vget_low_f32(vacc4x1), vget_high_f32(vacc4x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum1x2 = vadd_f32(vget_low_f32(vacc1x2), vget_high_f32(vacc1x2)); + const float32x2_t vsum2x2 = vadd_f32(vget_low_f32(vacc2x2), vget_high_f32(vacc2x2)); + const float32x2_t vsum3x2 = vadd_f32(vget_low_f32(vacc3x2), vget_high_f32(vacc3x2)); + const float32x2_t vsum4x2 = vadd_f32(vget_low_f32(vacc4x2), vget_high_f32(vacc4x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + const float32x2_t vsum1x3 = vadd_f32(vget_low_f32(vacc1x3), vget_high_f32(vacc1x3)); + const float32x2_t vsum2x3 = vadd_f32(vget_low_f32(vacc2x3), vget_high_f32(vacc2x3)); + const float32x2_t vsum3x3 = vadd_f32(vget_low_f32(vacc3x3), vget_high_f32(vacc3x3)); + const float32x2_t vsum4x3 = vadd_f32(vget_low_f32(vacc4x3), vget_high_f32(vacc4x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); + float32x4_t vacc1x0123 = vcombine_f32(vpadd_f32(vsum1x0, vsum1x1), vpadd_f32(vsum1x2, vsum1x3)); + float32x4_t vacc2x0123 = vcombine_f32(vpadd_f32(vsum2x0, vsum2x1), vpadd_f32(vsum2x2, vsum2x3)); + float32x4_t vacc3x0123 = vcombine_f32(vpadd_f32(vsum3x0, vsum3x1), vpadd_f32(vsum3x2, vsum3x3)); + float32x4_t vacc4x0123 = vcombine_f32(vpadd_f32(vsum4x0, vsum4x1), vpadd_f32(vsum4x2, vsum4x3)); +#endif + + const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + vacc2x0123 = vminq_f32(vacc2x0123, vmax); + vacc3x0123 = vminq_f32(vacc3x0123, vmax); + vacc4x0123 = vminq_f32(vacc4x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123,
vmin); + vacc2x0123 = vmaxq_f32(vacc2x0123, vmin); + vacc3x0123 = vmaxq_f32(vacc3x0123, vmin); + vacc4x0123 = vmaxq_f32(vacc4x0123, vmin); + + uint16x4_t vout0x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc0x0123), 16); + uint16x4_t vout1x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc1x0123), 16); + uint16x4_t vout2x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc2x0123), 16); + uint16x4_t vout3x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc3x0123), 16); + uint16x4_t vout4x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc4x0123), 16); + + if XNN_LIKELY(nc >= 4) { + vst1_u16(c0, vout0x0123); + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + vst1_u16(c1, vout1x0123); + c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); + vst1_u16(c2, vout2x0123); + c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride); + vst1_u16(c3, vout3x0123); + c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride); + vst1_u16(c4, vout4x0123); + c4 = (uint16_t*) ((uintptr_t) c4 + cn_stride); + + a0 = (const uint16_t*) ((uintptr_t) a0 - kc); + a1 = (const uint16_t*) ((uintptr_t) a1 - kc); + a2 = (const uint16_t*) ((uintptr_t) a2 - kc); + a3 = (const uint16_t*) ((uintptr_t) a3 - kc); + a4 = (const uint16_t*) ((uintptr_t) a4 - kc); + + nc -= 4; + } else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_u16(vout0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_u16(vout1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_u16(vout2x0123), 0); c2 += 2; + vst1_lane_u32((void*) c3, vreinterpret_u32_u16(vout3x0123), 0); c3 += 2; + vst1_lane_u32((void*) c4, vreinterpret_u32_u16(vout4x0123), 0); c4 += 2; + + vout0x0123 = vext_u16(vout0x0123, vout0x0123, 2); + vout1x0123 = vext_u16(vout1x0123, vout1x0123, 2); + vout2x0123 = vext_u16(vout2x0123, vout2x0123, 2); + vout3x0123 = vext_u16(vout3x0123, vout3x0123, 2); + vout4x0123 = vext_u16(vout4x0123, vout4x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vout0x0123, 0); + vst1_lane_u16(c1, vout1x0123, 0); + vst1_lane_u16(c2, vout2x0123, 0); + vst1_lane_u16(c3, vout3x0123, 0); + vst1_lane_u16(c4, vout4x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/5x4c8-minmax-neonfma-zip.c b/src/bf16-gemm/gen/5x4c8-minmax-neonfma-zip.c new file mode 100644 index 000000000000..34581a829531 --- /dev/null +++ b/src/bf16-gemm/gen/5x4c8-minmax-neonfma-zip.c @@ -0,0 +1,410 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c8-neon.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
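The 5x4c8 neonfma "zip" variant below widens bf16 operands to f32 entirely in registers: interleaving a zero half-word below each bf16 element (vzip1q_u16/vzip2q_u16 against a zero vector) produces the exact binary32 encoding, because bf16 is simply the upper half of binary32. A minimal scalar sketch of that widening, for illustration only (bf16_to_f32 is a hypothetical helper, not part of the patch):

#include <stdint.h>
#include <string.h>

static float bf16_to_f32(uint16_t bits) {
  const uint32_t f32_bits = (uint32_t) bits << 16;  /* same effect as zipping a zero below the bf16 element */
  float value;
  memcpy(&value, &f32_bits, sizeof(value));  /* bit-exact reinterpretation, like vreinterpretq_f32_u16 */
  return value;
}

The "shland" variant reaches the same encoding with a 16-bit left shift for the even lanes and a mask for the odd lanes; both widenings are exact, so each multiply-accumulate is an ordinary float32 FMA.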
+ + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/gemm.h> + + +void xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 5); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(uint16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const uint16_t* a0 = (const uint16_t*) a; + uint16_t* c0 = (uint16_t*) c; + const uint16_t* a1 = (const uint16_t*) ((uintptr_t) a0 + a_stride); + uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const uint16_t* a2 = (const uint16_t*) ((uintptr_t) a1 + a_stride); + uint16_t* c2 = (uint16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const uint16_t* a3 = (const uint16_t*) ((uintptr_t) a2 + a_stride); + uint16_t* c3 = (uint16_t*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr < 4) { + a3 = a2; + c3 = c2; + } + const uint16_t* a4 = (const uint16_t*) ((uintptr_t) a3 + a_stride); + uint16_t* c4 = (uint16_t*) ((uintptr_t) c3 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 4) { + a4 = a3; + c4 = c3; + } + + const uint16_t* w = (const uint16_t*) w_ptr; + const uint16x8_t vzero = vmovq_n_u16(0); + do { + float32x4_t vacc0x0 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x1 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x2 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc0x3 = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1; + float32x4_t vacc1x0 = vacc0x0; + float32x4_t vacc1x1 = vacc0x1; + float32x4_t vacc1x2 = vacc0x2; + float32x4_t vacc1x3 = vacc0x3; + float32x4_t vacc2x0 = vacc0x0; + float32x4_t vacc2x1 = vacc0x1; + float32x4_t vacc2x2 = vacc0x2; + float32x4_t vacc2x3 = vacc0x3; + float32x4_t vacc3x0 = vacc0x0; + float32x4_t vacc3x1 = vacc0x1; + float32x4_t vacc3x2 = vacc0x2; + float32x4_t vacc3x3 = vacc0x3; + float32x4_t vacc4x0 = vacc0x0; + float32x4_t vacc4x1 = vacc0x1; + float32x4_t vacc4x2 = vacc0x2; + float32x4_t vacc4x3 = vacc0x3; + + size_t k = kc; + for (; k >= 8 * sizeof(uint16_t); k -= 8 * sizeof(uint16_t)) { + const uint16x8_t va0 = vld1q_u16(a0); a0 += 8; + const uint16x8_t va1 = vld1q_u16(a1); a1 += 8; + const uint16x8_t va2 = vld1q_u16(a2); a2 += 8; + const uint16x8_t va3 = vld1q_u16(a3); a3 += 8; + const uint16x8_t va4 = vld1q_u16(a4); a4 += 8; + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const float32x4_t va0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0)); + const float32x4_t va1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1)); + const float32x4_t va2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va2)); + const float32x4_t va3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va3)); + const float32x4_t va4e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va4)); + + const float32x4_t vb0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb0)); + const float32x4_t vb1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb1)); + const float32x4_t vb2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb2)); + const
float32x4_t vb3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0e, vb0e); + vacc1x0 = vfmaq_f32(vacc1x0, va1e, vb0e); + vacc2x0 = vfmaq_f32(vacc2x0, va2e, vb0e); + vacc3x0 = vfmaq_f32(vacc3x0, va3e, vb0e); + vacc4x0 = vfmaq_f32(vacc4x0, va4e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0e, vb1e); + vacc1x1 = vfmaq_f32(vacc1x1, va1e, vb1e); + vacc2x1 = vfmaq_f32(vacc2x1, va2e, vb1e); + vacc3x1 = vfmaq_f32(vacc3x1, va3e, vb1e); + vacc4x1 = vfmaq_f32(vacc4x1, va4e, vb1e); + vacc0x2 = vfmaq_f32(vacc0x2, va0e, vb2e); + vacc1x2 = vfmaq_f32(vacc1x2, va1e, vb2e); + vacc2x2 = vfmaq_f32(vacc2x2, va2e, vb2e); + vacc3x2 = vfmaq_f32(vacc3x2, va3e, vb2e); + vacc4x2 = vfmaq_f32(vacc4x2, va4e, vb2e); + vacc0x3 = vfmaq_f32(vacc0x3, va0e, vb3e); + vacc1x3 = vfmaq_f32(vacc1x3, va1e, vb3e); + vacc2x3 = vfmaq_f32(vacc2x3, va2e, vb3e); + vacc3x3 = vfmaq_f32(vacc3x3, va3e, vb3e); + vacc4x3 = vfmaq_f32(vacc4x3, va4e, vb3e); + + const float32x4_t va0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0)); + const float32x4_t va1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va1)); + const float32x4_t va2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va2)); + const float32x4_t va3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va3)); + const float32x4_t va4o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va4)); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb0)); + const float32x4_t vb1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb1)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb2)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0o, vb0o); + vacc1x0 = vfmaq_f32(vacc1x0, va1o, vb0o); + vacc2x0 = vfmaq_f32(vacc2x0, va2o, vb0o); + vacc3x0 = vfmaq_f32(vacc3x0, va3o, vb0o); + vacc4x0 = vfmaq_f32(vacc4x0, va4o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0o, vb1o); + vacc1x1 = vfmaq_f32(vacc1x1, va1o, vb1o); + vacc2x1 = vfmaq_f32(vacc2x1, va2o, vb1o); + vacc3x1 = vfmaq_f32(vacc3x1, va3o, vb1o); + vacc4x1 = vfmaq_f32(vacc4x1, va4o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0o, vb2o); + vacc1x2 = vfmaq_f32(vacc1x2, va1o, vb2o); + vacc2x2 = vfmaq_f32(vacc2x2, va2o, vb2o); + vacc3x2 = vfmaq_f32(vacc3x2, va3o, vb2o); + vacc4x2 = vfmaq_f32(vacc4x2, va4o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0o, vb3o); + vacc1x3 = vfmaq_f32(vacc1x3, va1o, vb3o); + vacc2x3 = vfmaq_f32(vacc2x3, va2o, vb3o); + vacc3x3 = vfmaq_f32(vacc3x3, va3o, vb3o); + vacc4x3 = vfmaq_f32(vacc4x3, va4o, vb3o); + } + if XNN_UNLIKELY(k != 0) { + const uint16x8_t va0 = vld1q_u16(a0); a0 = (const uint16_t*) ((uintptr_t) a0 + k); + const uint16x8_t va1 = vld1q_u16(a1); a1 = (const uint16_t*) ((uintptr_t) a1 + k); + const uint16x8_t va2 = vld1q_u16(a2); a2 = (const uint16_t*) ((uintptr_t) a2 + k); + const uint16x8_t va3 = vld1q_u16(a3); a3 = (const uint16_t*) ((uintptr_t) a3 + k); + const uint16x8_t va4 = vld1q_u16(a4); a4 = (const uint16_t*) ((uintptr_t) a4 + k); + + const uint16x8_t vb0 = vld1q_u16(w); w += 8; + const uint16x8_t vb1 = vld1q_u16(w); w += 8; + const uint16x8_t vb2 = vld1q_u16(w); w += 8; + const uint16x8_t vb3 = vld1q_u16(w); w += 8; + + const uint16x8_t vm0 = vceqq_u16(vb0, vmovq_n_u16(0)); + const uint16x8_t vm1 = vceqq_u16(vb1, vmovq_n_u16(0)); + const uint16x8_t vm2 = vceqq_u16(vb2, vmovq_n_u16(0)); + const uint16x8_t vm3 = vceqq_u16(vb3, vmovq_n_u16(0)); + + const float32x4_t vb0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb0)); + const float32x4_t vb1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb1)); + const float32x4_t vb2e = 
vreinterpretq_f32_u16(vzip1q_u16(vzero, vb2)); + const float32x4_t vb3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb3)); + + const uint16x8_t va0x0 = vbicq_u16(va0, vm0); + const uint16x8_t va1x0 = vbicq_u16(va1, vm0); + const uint16x8_t va2x0 = vbicq_u16(va2, vm0); + const uint16x8_t va3x0 = vbicq_u16(va3, vm0); + const uint16x8_t va4x0 = vbicq_u16(va4, vm0); + const uint16x8_t va0x1 = vbicq_u16(va0, vm1); + const uint16x8_t va1x1 = vbicq_u16(va1, vm1); + const uint16x8_t va2x1 = vbicq_u16(va2, vm1); + const uint16x8_t va3x1 = vbicq_u16(va3, vm1); + const uint16x8_t va4x1 = vbicq_u16(va4, vm1); + const uint16x8_t va0x2 = vbicq_u16(va0, vm2); + const uint16x8_t va1x2 = vbicq_u16(va1, vm2); + const uint16x8_t va2x2 = vbicq_u16(va2, vm2); + const uint16x8_t va3x2 = vbicq_u16(va3, vm2); + const uint16x8_t va4x2 = vbicq_u16(va4, vm2); + const uint16x8_t va0x3 = vbicq_u16(va0, vm3); + const uint16x8_t va1x3 = vbicq_u16(va1, vm3); + const uint16x8_t va2x3 = vbicq_u16(va2, vm3); + const uint16x8_t va3x3 = vbicq_u16(va3, vm3); + const uint16x8_t va4x3 = vbicq_u16(va4, vm3); + + const float32x4_t va0x0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0x0)); + const float32x4_t va1x0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1x0)); + const float32x4_t va2x0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va2x0)); + const float32x4_t va3x0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va3x0)); + const float32x4_t va4x0e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va4x0)); + const float32x4_t va0x1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0x1)); + const float32x4_t va1x1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1x1)); + const float32x4_t va2x1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va2x1)); + const float32x4_t va3x1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va3x1)); + const float32x4_t va4x1e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va4x1)); + const float32x4_t va0x2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0x2)); + const float32x4_t va1x2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1x2)); + const float32x4_t va2x2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va2x2)); + const float32x4_t va3x2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va3x2)); + const float32x4_t va4x2e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va4x2)); + const float32x4_t va0x3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va0x3)); + const float32x4_t va1x3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va1x3)); + const float32x4_t va2x3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va2x3)); + const float32x4_t va3x3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va3x3)); + const float32x4_t va4x3e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va4x3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0e, vb0e); + vacc1x0 = vfmaq_f32(vacc1x0, va1x0e, vb0e); + vacc2x0 = vfmaq_f32(vacc2x0, va2x0e, vb0e); + vacc3x0 = vfmaq_f32(vacc3x0, va3x0e, vb0e); + vacc4x0 = vfmaq_f32(vacc4x0, va4x0e, vb0e); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1e, vb1e); + vacc1x1 = vfmaq_f32(vacc1x1, va1x1e, vb1e); + vacc2x1 = vfmaq_f32(vacc2x1, va2x1e, vb1e); + vacc3x1 = vfmaq_f32(vacc3x1, va3x1e, vb1e); + vacc4x1 = vfmaq_f32(vacc4x1, va4x1e, vb1e); + vacc0x2 = vfmaq_f32(vacc0x2, va0x2e, vb2e); + vacc1x2 = vfmaq_f32(vacc1x2, va1x2e, vb2e); + vacc2x2 = vfmaq_f32(vacc2x2, va2x2e, vb2e); + vacc3x2 = vfmaq_f32(vacc3x2, va3x2e, vb2e); + vacc4x2 = vfmaq_f32(vacc4x2, va4x2e, vb2e); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3e, vb3e); + vacc1x3 = vfmaq_f32(vacc1x3, va1x3e, vb3e); + vacc2x3 = vfmaq_f32(vacc2x3, va2x3e, vb3e); + vacc3x3 = vfmaq_f32(vacc3x3, va3x3e, vb3e); + vacc4x3 = vfmaq_f32(vacc4x3, 
va4x3e, vb3e); + + const float32x4_t vb0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb0)); + const float32x4_t vb1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb1)); + const float32x4_t vb2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb2)); + const float32x4_t vb3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb3)); + + const float32x4_t va0x0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x0)); + const float32x4_t va1x0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va1x0)); + const float32x4_t va2x0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va2x0)); + const float32x4_t va3x0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va3x0)); + const float32x4_t va4x0o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va4x0)); + const float32x4_t va0x1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x1)); + const float32x4_t va1x1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va1x1)); + const float32x4_t va2x1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va2x1)); + const float32x4_t va3x1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va3x1)); + const float32x4_t va4x1o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va4x1)); + const float32x4_t va0x2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x2)); + const float32x4_t va1x2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va1x2)); + const float32x4_t va2x2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va2x2)); + const float32x4_t va3x2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va3x2)); + const float32x4_t va4x2o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va4x2)); + const float32x4_t va0x3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va0x3)); + const float32x4_t va1x3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va1x3)); + const float32x4_t va2x3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va2x3)); + const float32x4_t va3x3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va3x3)); + const float32x4_t va4x3o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va4x3)); + + vacc0x0 = vfmaq_f32(vacc0x0, va0x0o, vb0o); + vacc1x0 = vfmaq_f32(vacc1x0, va1x0o, vb0o); + vacc2x0 = vfmaq_f32(vacc2x0, va2x0o, vb0o); + vacc3x0 = vfmaq_f32(vacc3x0, va3x0o, vb0o); + vacc4x0 = vfmaq_f32(vacc4x0, va4x0o, vb0o); + vacc0x1 = vfmaq_f32(vacc0x1, va0x1o, vb1o); + vacc1x1 = vfmaq_f32(vacc1x1, va1x1o, vb1o); + vacc2x1 = vfmaq_f32(vacc2x1, va2x1o, vb1o); + vacc3x1 = vfmaq_f32(vacc3x1, va3x1o, vb1o); + vacc4x1 = vfmaq_f32(vacc4x1, va4x1o, vb1o); + vacc0x2 = vfmaq_f32(vacc0x2, va0x2o, vb2o); + vacc1x2 = vfmaq_f32(vacc1x2, va1x2o, vb2o); + vacc2x2 = vfmaq_f32(vacc2x2, va2x2o, vb2o); + vacc3x2 = vfmaq_f32(vacc3x2, va3x2o, vb2o); + vacc4x2 = vfmaq_f32(vacc4x2, va4x2o, vb2o); + vacc0x3 = vfmaq_f32(vacc0x3, va0x3o, vb3o); + vacc1x3 = vfmaq_f32(vacc1x3, va1x3o, vb3o); + vacc2x3 = vfmaq_f32(vacc2x3, va2x3o, vb3o); + vacc3x3 = vfmaq_f32(vacc3x3, va3x3o, vb3o); + vacc4x3 = vfmaq_f32(vacc4x3, va4x3o, vb3o); + } + +#if XNN_ARCH_ARM64 + const float32x4_t vacc0x01 = vpaddq_f32(vacc0x0, vacc0x1); + const float32x4_t vacc1x01 = vpaddq_f32(vacc1x0, vacc1x1); + const float32x4_t vacc2x01 = vpaddq_f32(vacc2x0, vacc2x1); + const float32x4_t vacc3x01 = vpaddq_f32(vacc3x0, vacc3x1); + const float32x4_t vacc4x01 = vpaddq_f32(vacc4x0, vacc4x1); + const float32x4_t vacc0x23 = vpaddq_f32(vacc0x2, vacc0x3); + const float32x4_t vacc1x23 = vpaddq_f32(vacc1x2, vacc1x3); + const float32x4_t vacc2x23 = vpaddq_f32(vacc2x2, vacc2x3); + const float32x4_t vacc3x23 = vpaddq_f32(vacc3x2, vacc3x3); + const float32x4_t vacc4x23 = vpaddq_f32(vacc4x2, vacc4x3); + + float32x4_t vacc0x0123 = vpaddq_f32(vacc0x01, vacc0x23); + float32x4_t vacc1x0123 = vpaddq_f32(vacc1x01, vacc1x23); + 
float32x4_t vacc2x0123 = vpaddq_f32(vacc2x01, vacc2x23); + float32x4_t vacc3x0123 = vpaddq_f32(vacc3x01, vacc3x23); + float32x4_t vacc4x0123 = vpaddq_f32(vacc4x01, vacc4x23); +#else + const float32x2_t vsum0x0 = vadd_f32(vget_low_f32(vacc0x0), vget_high_f32(vacc0x0)); + const float32x2_t vsum1x0 = vadd_f32(vget_low_f32(vacc1x0), vget_high_f32(vacc1x0)); + const float32x2_t vsum2x0 = vadd_f32(vget_low_f32(vacc2x0), vget_high_f32(vacc2x0)); + const float32x2_t vsum3x0 = vadd_f32(vget_low_f32(vacc3x0), vget_high_f32(vacc3x0)); + const float32x2_t vsum4x0 = vadd_f32(vget_low_f32(vacc4x0), vget_high_f32(vacc4x0)); + const float32x2_t vsum0x1 = vadd_f32(vget_low_f32(vacc0x1), vget_high_f32(vacc0x1)); + const float32x2_t vsum1x1 = vadd_f32(vget_low_f32(vacc1x1), vget_high_f32(vacc1x1)); + const float32x2_t vsum2x1 = vadd_f32(vget_low_f32(vacc2x1), vget_high_f32(vacc2x1)); + const float32x2_t vsum3x1 = vadd_f32(vget_low_f32(vacc3x1), vget_high_f32(vacc3x1)); + const float32x2_t vsum4x1 = vadd_f32(vget_low_f32(vacc4x1), vget_high_f32(vacc4x1)); + const float32x2_t vsum0x2 = vadd_f32(vget_low_f32(vacc0x2), vget_high_f32(vacc0x2)); + const float32x2_t vsum1x2 = vadd_f32(vget_low_f32(vacc1x2), vget_high_f32(vacc1x2)); + const float32x2_t vsum2x2 = vadd_f32(vget_low_f32(vacc2x2), vget_high_f32(vacc2x2)); + const float32x2_t vsum3x2 = vadd_f32(vget_low_f32(vacc3x2), vget_high_f32(vacc3x2)); + const float32x2_t vsum4x2 = vadd_f32(vget_low_f32(vacc4x2), vget_high_f32(vacc4x2)); + const float32x2_t vsum0x3 = vadd_f32(vget_low_f32(vacc0x3), vget_high_f32(vacc0x3)); + const float32x2_t vsum1x3 = vadd_f32(vget_low_f32(vacc1x3), vget_high_f32(vacc1x3)); + const float32x2_t vsum2x3 = vadd_f32(vget_low_f32(vacc2x3), vget_high_f32(vacc2x3)); + const float32x2_t vsum3x3 = vadd_f32(vget_low_f32(vacc3x3), vget_high_f32(vacc3x3)); + const float32x2_t vsum4x3 = vadd_f32(vget_low_f32(vacc4x3), vget_high_f32(vacc4x3)); + + float32x4_t vacc0x0123 = vcombine_f32(vpadd_f32(vsum0x0, vsum0x1), vpadd_f32(vsum0x2, vsum0x3)); + float32x4_t vacc1x0123 = vcombine_f32(vpadd_f32(vsum1x0, vsum1x1), vpadd_f32(vsum1x2, vsum1x3)); + float32x4_t vacc2x0123 = vcombine_f32(vpadd_f32(vsum2x0, vsum2x1), vpadd_f32(vsum2x2, vsum2x3)); + float32x4_t vacc3x0123 = vcombine_f32(vpadd_f32(vsum3x0, vsum3x1), vpadd_f32(vsum3x2, vsum3x3)); + float32x4_t vacc4x0123 = vcombine_f32(vpadd_f32(vsum4x0, vsum4x1), vpadd_f32(vsum4x2, vsum4x3)); +#endif + + const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + vacc2x0123 = vminq_f32(vacc2x0123, vmax); + vacc3x0123 = vminq_f32(vacc3x0123, vmax); + vacc4x0123 = vminq_f32(vacc4x0123, vmax); + + const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + vacc2x0123 = vmaxq_f32(vacc2x0123, vmin); + vacc3x0123 = vmaxq_f32(vacc3x0123, vmin); + vacc4x0123 = vmaxq_f32(vacc4x0123, vmin); + + uint16x4_t vout0x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc0x0123), 16); + uint16x4_t vout1x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc1x0123), 16); + uint16x4_t vout2x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc2x0123), 16); + uint16x4_t vout3x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc3x0123), 16); + uint16x4_t vout4x0123 = vshrn_n_u32(vreinterpretq_u32_f32(vacc4x0123), 16); + + if XNN_LIKELY(nc >= 4) { + vst1_u16(c0, vout0x0123); + c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); + vst1_u16(c1, vout1x0123); + c1 = (uint16_t*) ((uintptr_t) c1 +
cn_stride); + vst1_u16(c2, vout2x0123); + c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride); + vst1_u16(c3, vout3x0123); + c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride); + vst1_u16(c4, vout4x0123); + c4 = (uint16_t*) ((uintptr_t) c4 + cn_stride); + + a0 = (const uint16_t*) ((uintptr_t) a0 - kc); + a1 = (const uint16_t*) ((uintptr_t) a1 - kc); + a2 = (const uint16_t*) ((uintptr_t) a2 - kc); + a3 = (const uint16_t*) ((uintptr_t) a3 - kc); + a4 = (const uint16_t*) ((uintptr_t) a4 - kc); + + nc -= 4; + } else { + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_u16(vout0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_u16(vout1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_u16(vout2x0123), 0); c2 += 2; + vst1_lane_u32((void*) c3, vreinterpret_u32_u16(vout3x0123), 0); c3 += 2; + vst1_lane_u32((void*) c4, vreinterpret_u32_u16(vout4x0123), 0); c4 += 2; + + vout0x0123 = vext_u16(vout0x0123, vout0x0123, 2); + vout1x0123 = vext_u16(vout1x0123, vout1x0123, 2); + vout2x0123 = vext_u16(vout2x0123, vout2x0123, 2); + vout3x0123 = vext_u16(vout3x0123, vout3x0123, 2); + vout4x0123 = vext_u16(vout4x0123, vout4x0123, 2); + } + if (nc & 1) { + vst1_lane_u16(c0, vout0x0123, 0); + vst1_lane_u16(c1, vout1x0123, 0); + vst1_lane_u16(c2, vout2x0123, 0); + vst1_lane_u16(c3, vout3x0123, 0); + vst1_lane_u16(c4, vout4x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/5x8c2-minmax-neonbf16-bfdot-lane-ld128.c b/src/bf16-gemm/gen/5x8c2-minmax-neonbf16-bfdot-lane-ld128.c new file mode 100644 index 000000000000..b70a92126918 --- /dev/null +++ b/src/bf16-gemm/gen/5x8c2-minmax-neonbf16-bfdot-lane-ld128.c @@ -0,0 +1,383 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c2-neonbf16-bfdot-lane-ld128.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
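The 5x8c2 kernel below relies on the Armv8.2-A BFDOT instruction: each 32-bit lane of the f32 accumulator receives the sum of two bf16 products, taken from one column's pair of weights and a broadcast pair of activation elements. As a rough scalar model (illustrative only; the hardware forms the two products and their sum with its own internal rounding, so this is not guaranteed bit-exact), reusing the hypothetical bf16_to_f32 helper sketched earlier:

static float bfdot_lane(float acc, const uint16_t b[2], const uint16_t a[2]) {
  /* one output lane of vbfdotq_f32: acc += b[0]*a[0] + b[1]*a[1] */
  return acc + bf16_to_f32(b[0]) * bf16_to_f32(a[0])
             + bf16_to_f32(b[1]) * bf16_to_f32(a[1]);
}

The vbfdotq_laneq_f32 form used in the main loop broadcasts one 32-bit lane (one pair of consecutive K elements) of the activation vector to all four output lanes, which is why the unrolled body steps through lane indices 0..3 of va0..va4.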
+ + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/gemm.h> + + +void xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 5); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(bfloat16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const bfloat16_t* a0 = (const bfloat16_t*) a; + bfloat16_t* c0 = (bfloat16_t*) c; + const bfloat16_t* a1 = (const bfloat16_t*) ((uintptr_t) a0 + a_stride); + bfloat16_t* c1 = (bfloat16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const bfloat16_t* a2 = (const bfloat16_t*) ((uintptr_t) a1 + a_stride); + bfloat16_t* c2 = (bfloat16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const bfloat16_t* a3 = (const bfloat16_t*) ((uintptr_t) a2 + a_stride); + bfloat16_t* c3 = (bfloat16_t*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr < 4) { + a3 = a2; + c3 = c2; + } + const bfloat16_t* a4 = (const bfloat16_t*) ((uintptr_t) a3 + a_stride); + bfloat16_t* c4 = (bfloat16_t*) ((uintptr_t) c3 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 4) { + a4 = a3; + c4 = c3; + } + + const bfloat16_t* w = (const bfloat16_t*) w_ptr; + do { + float32x4_t vacc0x0123 = vcvt_f32_bf16(vld1_bf16(w)); w += 4; + float32x4_t vacc0x4567 = vcvt_f32_bf16(vld1_bf16(w)); w += 4; + float32x4_t vacc1x0123 = vacc0x0123; + float32x4_t vacc1x4567 = vacc0x4567; + float32x4_t vacc2x0123 = vacc0x0123; + float32x4_t vacc2x4567 = vacc0x4567; + float32x4_t vacc3x0123 = vacc0x0123; + float32x4_t vacc3x4567 = vacc0x4567; + float32x4_t vacc4x0123 = vacc0x0123; + float32x4_t vacc4x4567 = vacc0x4567; + + size_t k = kc; + for (; k >= 8 * sizeof(bfloat16_t); k -= 8 * sizeof(bfloat16_t)) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 += 8; + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 += 8; + const bfloat16x8_t va2 = vld1q_bf16(a2); a2 += 8; + const bfloat16x8_t va3 = vld1q_bf16(a3); a3 += 8; + const bfloat16x8_t va4 = vld1q_bf16(a4); a4 += 8; + + const bfloat16x8_t vb0123c01 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c01 = vld1q_bf16(w); w += 8; + + vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c01, va0, 0); + vacc1x0123 = vbfdotq_laneq_f32(vacc1x0123, vb0123c01, va1, 0); + vacc2x0123 = vbfdotq_laneq_f32(vacc2x0123, vb0123c01, va2, 0); + vacc3x0123 = vbfdotq_laneq_f32(vacc3x0123, vb0123c01, va3, 0); + vacc4x0123 = vbfdotq_laneq_f32(vacc4x0123, vb0123c01, va4, 0); + vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c01, va0, 0); + vacc1x4567 = vbfdotq_laneq_f32(vacc1x4567, vb4567c01, va1, 0); + vacc2x4567 = vbfdotq_laneq_f32(vacc2x4567, vb4567c01, va2, 0); + vacc3x4567 = vbfdotq_laneq_f32(vacc3x4567, vb4567c01, va3, 0); + vacc4x4567 = vbfdotq_laneq_f32(vacc4x4567, vb4567c01, va4, 0); + const bfloat16x8_t vb0123c23 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c23 = vld1q_bf16(w); w += 8; + + vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c23, va0, 1); + vacc1x0123 = vbfdotq_laneq_f32(vacc1x0123, vb0123c23, va1, 1); + vacc2x0123 = vbfdotq_laneq_f32(vacc2x0123, vb0123c23, va2, 1); + vacc3x0123 = vbfdotq_laneq_f32(vacc3x0123, vb0123c23, va3, 1); + vacc4x0123 = vbfdotq_laneq_f32(vacc4x0123, vb0123c23, va4, 1); + vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c23, va0, 1); + vacc1x4567 =
vbfdotq_laneq_f32(vacc1x4567, vb4567c23, va1, 1); + vacc2x4567 = vbfdotq_laneq_f32(vacc2x4567, vb4567c23, va2, 1); + vacc3x4567 = vbfdotq_laneq_f32(vacc3x4567, vb4567c23, va3, 1); + vacc4x4567 = vbfdotq_laneq_f32(vacc4x4567, vb4567c23, va4, 1); + const bfloat16x8_t vb0123c45 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c45 = vld1q_bf16(w); w += 8; + + vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c45, va0, 2); + vacc1x0123 = vbfdotq_laneq_f32(vacc1x0123, vb0123c45, va1, 2); + vacc2x0123 = vbfdotq_laneq_f32(vacc2x0123, vb0123c45, va2, 2); + vacc3x0123 = vbfdotq_laneq_f32(vacc3x0123, vb0123c45, va3, 2); + vacc4x0123 = vbfdotq_laneq_f32(vacc4x0123, vb0123c45, va4, 2); + vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c45, va0, 2); + vacc1x4567 = vbfdotq_laneq_f32(vacc1x4567, vb4567c45, va1, 2); + vacc2x4567 = vbfdotq_laneq_f32(vacc2x4567, vb4567c45, va2, 2); + vacc3x4567 = vbfdotq_laneq_f32(vacc3x4567, vb4567c45, va3, 2); + vacc4x4567 = vbfdotq_laneq_f32(vacc4x4567, vb4567c45, va4, 2); + const bfloat16x8_t vb0123c67 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c67 = vld1q_bf16(w); w += 8; + + vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c67, va0, 3); + vacc1x0123 = vbfdotq_laneq_f32(vacc1x0123, vb0123c67, va1, 3); + vacc2x0123 = vbfdotq_laneq_f32(vacc2x0123, vb0123c67, va2, 3); + vacc3x0123 = vbfdotq_laneq_f32(vacc3x0123, vb0123c67, va3, 3); + vacc4x0123 = vbfdotq_laneq_f32(vacc4x0123, vb0123c67, va4, 3); + vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c67, va0, 3); + vacc1x4567 = vbfdotq_laneq_f32(vacc1x4567, vb4567c67, va1, 3); + vacc2x4567 = vbfdotq_laneq_f32(vacc2x4567, vb4567c67, va2, 3); + vacc3x4567 = vbfdotq_laneq_f32(vacc3x4567, vb4567c67, va3, 3); + vacc4x4567 = vbfdotq_laneq_f32(vacc4x4567, vb4567c67, va4, 3); + } + if XNN_UNLIKELY(k != 0) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 = (const bfloat16_t*) ((uintptr_t) a0 + k); + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 = (const bfloat16_t*) ((uintptr_t) a1 + k); + const bfloat16x8_t va2 = vld1q_bf16(a2); a2 = (const bfloat16_t*) ((uintptr_t) a2 + k); + const bfloat16x8_t va3 = vld1q_bf16(a3); a3 = (const bfloat16_t*) ((uintptr_t) a3 + k); + const bfloat16x8_t va4 = vld1q_bf16(a4); a4 = (const bfloat16_t*) ((uintptr_t) a4 + k); + + const bfloat16x8_t vb0123c01 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c01 = vld1q_bf16(w); w += 8; + + const uint32x4_t va0c01 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va0)), 0); + const uint32x4_t va1c01 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va1)), 0); + const uint32x4_t va2c01 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va2)), 0); + const uint32x4_t va3c01 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va3)), 0); + const uint32x4_t va4c01 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va4)), 0); + + const uint32x4_t vm0123c01 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c01), vmovq_n_u16(0))); + const uint32x4_t vm4567c01 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c01), vmovq_n_u16(0))); + + const uint32x4_t va0x0123c01 = vbicq_u32(va0c01, vm0123c01); + vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c01, vreinterpretq_bf16_u32(va0x0123c01)); + const uint32x4_t va1x0123c01 = vbicq_u32(va1c01, vm0123c01); + vacc1x0123 = vbfdotq_f32(vacc1x0123, vb0123c01, vreinterpretq_bf16_u32(va1x0123c01)); + const uint32x4_t va2x0123c01 = vbicq_u32(va2c01, vm0123c01); + vacc2x0123 = vbfdotq_f32(vacc2x0123, vb0123c01, vreinterpretq_bf16_u32(va2x0123c01)); + const uint32x4_t va3x0123c01 = 
vbicq_u32(va3c01, vm0123c01); + vacc3x0123 = vbfdotq_f32(vacc3x0123, vb0123c01, vreinterpretq_bf16_u32(va3x0123c01)); + const uint32x4_t va4x0123c01 = vbicq_u32(va4c01, vm0123c01); + vacc4x0123 = vbfdotq_f32(vacc4x0123, vb0123c01, vreinterpretq_bf16_u32(va4x0123c01)); + const uint32x4_t va0x4567c01 = vbicq_u32(va0c01, vm4567c01); + vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c01, vreinterpretq_bf16_u32(va0x4567c01)); + const uint32x4_t va1x4567c01 = vbicq_u32(va1c01, vm4567c01); + vacc1x4567 = vbfdotq_f32(vacc1x4567, vb4567c01, vreinterpretq_bf16_u32(va1x4567c01)); + const uint32x4_t va2x4567c01 = vbicq_u32(va2c01, vm4567c01); + vacc2x4567 = vbfdotq_f32(vacc2x4567, vb4567c01, vreinterpretq_bf16_u32(va2x4567c01)); + const uint32x4_t va3x4567c01 = vbicq_u32(va3c01, vm4567c01); + vacc3x4567 = vbfdotq_f32(vacc3x4567, vb4567c01, vreinterpretq_bf16_u32(va3x4567c01)); + const uint32x4_t va4x4567c01 = vbicq_u32(va4c01, vm4567c01); + vacc4x4567 = vbfdotq_f32(vacc4x4567, vb4567c01, vreinterpretq_bf16_u32(va4x4567c01)); + + if (k > 2 * sizeof(bfloat16_t)) { + const bfloat16x8_t vb0123c23 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c23 = vld1q_bf16(w); w += 8; + + const uint32x4_t va0c23 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va0)), 1); + const uint32x4_t va1c23 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va1)), 1); + const uint32x4_t va2c23 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va2)), 1); + const uint32x4_t va3c23 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va3)), 1); + const uint32x4_t va4c23 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va4)), 1); + + const uint32x4_t vm0123c23 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c23), vmovq_n_u16(0))); + const uint32x4_t vm4567c23 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c23), vmovq_n_u16(0))); + + const uint32x4_t va0x0123c23 = vbicq_u32(va0c23, vm0123c23); + vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c23, vreinterpretq_bf16_u32(va0x0123c23)); + const uint32x4_t va1x0123c23 = vbicq_u32(va1c23, vm0123c23); + vacc1x0123 = vbfdotq_f32(vacc1x0123, vb0123c23, vreinterpretq_bf16_u32(va1x0123c23)); + const uint32x4_t va2x0123c23 = vbicq_u32(va2c23, vm0123c23); + vacc2x0123 = vbfdotq_f32(vacc2x0123, vb0123c23, vreinterpretq_bf16_u32(va2x0123c23)); + const uint32x4_t va3x0123c23 = vbicq_u32(va3c23, vm0123c23); + vacc3x0123 = vbfdotq_f32(vacc3x0123, vb0123c23, vreinterpretq_bf16_u32(va3x0123c23)); + const uint32x4_t va4x0123c23 = vbicq_u32(va4c23, vm0123c23); + vacc4x0123 = vbfdotq_f32(vacc4x0123, vb0123c23, vreinterpretq_bf16_u32(va4x0123c23)); + const uint32x4_t va0x4567c23 = vbicq_u32(va0c23, vm4567c23); + vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c23, vreinterpretq_bf16_u32(va0x4567c23)); + const uint32x4_t va1x4567c23 = vbicq_u32(va1c23, vm4567c23); + vacc1x4567 = vbfdotq_f32(vacc1x4567, vb4567c23, vreinterpretq_bf16_u32(va1x4567c23)); + const uint32x4_t va2x4567c23 = vbicq_u32(va2c23, vm4567c23); + vacc2x4567 = vbfdotq_f32(vacc2x4567, vb4567c23, vreinterpretq_bf16_u32(va2x4567c23)); + const uint32x4_t va3x4567c23 = vbicq_u32(va3c23, vm4567c23); + vacc3x4567 = vbfdotq_f32(vacc3x4567, vb4567c23, vreinterpretq_bf16_u32(va3x4567c23)); + const uint32x4_t va4x4567c23 = vbicq_u32(va4c23, vm4567c23); + vacc4x4567 = vbfdotq_f32(vacc4x4567, vb4567c23, vreinterpretq_bf16_u32(va4x4567c23)); + + if (k > 4 * sizeof(bfloat16_t)) { + const bfloat16x8_t vb0123c45 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c45 = vld1q_bf16(w); w += 8; + + const uint32x4_t va0c45 = 
vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va0)), 0); + const uint32x4_t va1c45 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va1)), 0); + const uint32x4_t va2c45 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va2)), 0); + const uint32x4_t va3c45 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va3)), 0); + const uint32x4_t va4c45 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va4)), 0); + + const uint32x4_t vm0123c45 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c45), vmovq_n_u16(0))); + const uint32x4_t vm4567c45 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c45), vmovq_n_u16(0))); + + const uint32x4_t va0x0123c45 = vbicq_u32(va0c45, vm0123c45); + vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c45, vreinterpretq_bf16_u32(va0x0123c45)); + const uint32x4_t va1x0123c45 = vbicq_u32(va1c45, vm0123c45); + vacc1x0123 = vbfdotq_f32(vacc1x0123, vb0123c45, vreinterpretq_bf16_u32(va1x0123c45)); + const uint32x4_t va2x0123c45 = vbicq_u32(va2c45, vm0123c45); + vacc2x0123 = vbfdotq_f32(vacc2x0123, vb0123c45, vreinterpretq_bf16_u32(va2x0123c45)); + const uint32x4_t va3x0123c45 = vbicq_u32(va3c45, vm0123c45); + vacc3x0123 = vbfdotq_f32(vacc3x0123, vb0123c45, vreinterpretq_bf16_u32(va3x0123c45)); + const uint32x4_t va4x0123c45 = vbicq_u32(va4c45, vm0123c45); + vacc4x0123 = vbfdotq_f32(vacc4x0123, vb0123c45, vreinterpretq_bf16_u32(va4x0123c45)); + const uint32x4_t va0x4567c45 = vbicq_u32(va0c45, vm4567c45); + vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c45, vreinterpretq_bf16_u32(va0x4567c45)); + const uint32x4_t va1x4567c45 = vbicq_u32(va1c45, vm4567c45); + vacc1x4567 = vbfdotq_f32(vacc1x4567, vb4567c45, vreinterpretq_bf16_u32(va1x4567c45)); + const uint32x4_t va2x4567c45 = vbicq_u32(va2c45, vm4567c45); + vacc2x4567 = vbfdotq_f32(vacc2x4567, vb4567c45, vreinterpretq_bf16_u32(va2x4567c45)); + const uint32x4_t va3x4567c45 = vbicq_u32(va3c45, vm4567c45); + vacc3x4567 = vbfdotq_f32(vacc3x4567, vb4567c45, vreinterpretq_bf16_u32(va3x4567c45)); + const uint32x4_t va4x4567c45 = vbicq_u32(va4c45, vm4567c45); + vacc4x4567 = vbfdotq_f32(vacc4x4567, vb4567c45, vreinterpretq_bf16_u32(va4x4567c45)); + + if (k > 6 * sizeof(bfloat16_t)) { + const bfloat16x8_t vb0123c67 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c67 = vld1q_bf16(w); w += 8; + + const uint32x4_t va0c67 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va0)), 1); + const uint32x4_t va1c67 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va1)), 1); + const uint32x4_t va2c67 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va2)), 1); + const uint32x4_t va3c67 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va3)), 1); + const uint32x4_t va4c67 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va4)), 1); + + const uint32x4_t vm0123c67 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c67), vmovq_n_u16(0))); + const uint32x4_t vm4567c67 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c67), vmovq_n_u16(0))); + + const uint32x4_t va0x0123c67 = vbicq_u32(va0c67, vm0123c67); + vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c67, vreinterpretq_bf16_u32(va0x0123c67)); + const uint32x4_t va1x0123c67 = vbicq_u32(va1c67, vm0123c67); + vacc1x0123 = vbfdotq_f32(vacc1x0123, vb0123c67, vreinterpretq_bf16_u32(va1x0123c67)); + const uint32x4_t va2x0123c67 = vbicq_u32(va2c67, vm0123c67); + vacc2x0123 = vbfdotq_f32(vacc2x0123, vb0123c67, vreinterpretq_bf16_u32(va2x0123c67)); + const uint32x4_t va3x0123c67 = vbicq_u32(va3c67, vm0123c67); + vacc3x0123 = 
vbfdotq_f32(vacc3x0123, vb0123c67, vreinterpretq_bf16_u32(va3x0123c67)); + const uint32x4_t va4x0123c67 = vbicq_u32(va4c67, vm0123c67); + vacc4x0123 = vbfdotq_f32(vacc4x0123, vb0123c67, vreinterpretq_bf16_u32(va4x0123c67)); + const uint32x4_t va0x4567c67 = vbicq_u32(va0c67, vm4567c67); + vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c67, vreinterpretq_bf16_u32(va0x4567c67)); + const uint32x4_t va1x4567c67 = vbicq_u32(va1c67, vm4567c67); + vacc1x4567 = vbfdotq_f32(vacc1x4567, vb4567c67, vreinterpretq_bf16_u32(va1x4567c67)); + const uint32x4_t va2x4567c67 = vbicq_u32(va2c67, vm4567c67); + vacc2x4567 = vbfdotq_f32(vacc2x4567, vb4567c67, vreinterpretq_bf16_u32(va2x4567c67)); + const uint32x4_t va3x4567c67 = vbicq_u32(va3c67, vm4567c67); + vacc3x4567 = vbfdotq_f32(vacc3x4567, vb4567c67, vreinterpretq_bf16_u32(va3x4567c67)); + const uint32x4_t va4x4567c67 = vbicq_u32(va4c67, vm4567c67); + vacc4x4567 = vbfdotq_f32(vacc4x4567, vb4567c67, vreinterpretq_bf16_u32(va4x4567c67)); + } + } + } + } + + const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + vacc2x0123 = vminq_f32(vacc2x0123, vmax); + vacc3x0123 = vminq_f32(vacc3x0123, vmax); + vacc4x0123 = vminq_f32(vacc4x0123, vmax); + vacc0x4567 = vminq_f32(vacc0x4567, vmax); + vacc1x4567 = vminq_f32(vacc1x4567, vmax); + vacc2x4567 = vminq_f32(vacc2x4567, vmax); + vacc3x4567 = vminq_f32(vacc3x4567, vmax); + vacc4x4567 = vminq_f32(vacc4x4567, vmax); + + const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + vacc2x0123 = vmaxq_f32(vacc2x0123, vmin); + vacc3x0123 = vmaxq_f32(vacc3x0123, vmin); + vacc4x0123 = vmaxq_f32(vacc4x0123, vmin); + vacc0x4567 = vmaxq_f32(vacc0x4567, vmin); + vacc1x4567 = vmaxq_f32(vacc1x4567, vmin); + vacc2x4567 = vmaxq_f32(vacc2x4567, vmin); + vacc3x4567 = vmaxq_f32(vacc3x4567, vmin); + vacc4x4567 = vmaxq_f32(vacc4x4567, vmin); + + bfloat16x4_t vout0x0123 = vcvt_bf16_f32(vacc0x0123); + bfloat16x4_t vout1x0123 = vcvt_bf16_f32(vacc1x0123); + bfloat16x4_t vout2x0123 = vcvt_bf16_f32(vacc2x0123); + bfloat16x4_t vout3x0123 = vcvt_bf16_f32(vacc3x0123); + bfloat16x4_t vout4x0123 = vcvt_bf16_f32(vacc4x0123); + bfloat16x4_t vout0x4567 = vcvt_bf16_f32(vacc0x4567); + bfloat16x4_t vout1x4567 = vcvt_bf16_f32(vacc1x4567); + bfloat16x4_t vout2x4567 = vcvt_bf16_f32(vacc2x4567); + bfloat16x4_t vout3x4567 = vcvt_bf16_f32(vacc3x4567); + bfloat16x4_t vout4x4567 = vcvt_bf16_f32(vacc4x4567); + + if XNN_LIKELY(nc >= 8) { + vst1_bf16(c0, vout0x0123); + vst1_bf16(c0 + 4, vout0x4567); + c0 = (bfloat16_t*) ((uintptr_t) c0 + cn_stride); + vst1_bf16(c1, vout1x0123); + vst1_bf16(c1 + 4, vout1x4567); + c1 = (bfloat16_t*) ((uintptr_t) c1 + cn_stride); + vst1_bf16(c2, vout2x0123); + vst1_bf16(c2 + 4, vout2x4567); + c2 = (bfloat16_t*) ((uintptr_t) c2 + cn_stride); + vst1_bf16(c3, vout3x0123); + vst1_bf16(c3 + 4, vout3x4567); + c3 = (bfloat16_t*) ((uintptr_t) c3 + cn_stride); + vst1_bf16(c4, vout4x0123); + vst1_bf16(c4 + 4, vout4x4567); + c4 = (bfloat16_t*) ((uintptr_t) c4 + cn_stride); + + a0 = (const bfloat16_t*) ((uintptr_t) a0 - kc); + a1 = (const bfloat16_t*) ((uintptr_t) a1 - kc); + a2 = (const bfloat16_t*) ((uintptr_t) a2 - kc); + a3 = (const bfloat16_t*) ((uintptr_t) a3 - kc); + a4 = (const bfloat16_t*) ((uintptr_t) a4 - kc); + + nc -= 8; + } else { + if (nc & 4) { + vst1_bf16(c0, vout0x0123); c0 += 4; + vst1_bf16(c1, vout1x0123); c1 += 4; + vst1_bf16(c2, vout2x0123); c2 += 4; +
vst1_bf16(c3, vout3x0123); c3 += 4; + vst1_bf16(c4, vout4x0123); c4 += 4; + + vout0x0123 = vout0x4567; + vout1x0123 = vout1x4567; + vout2x0123 = vout2x4567; + vout3x0123 = vout3x4567; + vout4x0123 = vout4x4567; + } + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_bf16(vout0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_bf16(vout1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_bf16(vout2x0123), 0); c2 += 2; + vst1_lane_u32((void*) c3, vreinterpret_u32_bf16(vout3x0123), 0); c3 += 2; + vst1_lane_u32((void*) c4, vreinterpret_u32_bf16(vout4x0123), 0); c4 += 2; + + vout0x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout0x0123), vreinterpret_u16_bf16(vout0x0123), 2)); + vout1x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout1x0123), vreinterpret_u16_bf16(vout1x0123), 2)); + vout2x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout2x0123), vreinterpret_u16_bf16(vout2x0123), 2)); + vout3x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout3x0123), vreinterpret_u16_bf16(vout3x0123), 2)); + vout4x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout4x0123), vreinterpret_u16_bf16(vout4x0123), 2)); + } + if (nc & 1) { + vst1_lane_bf16(c0, vout0x0123, 0); + vst1_lane_bf16(c1, vout1x0123, 0); + vst1_lane_bf16(c2, vout2x0123, 0); + vst1_lane_bf16(c3, vout3x0123, 0); + vst1_lane_bf16(c4, vout4x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/bf16-gemm/gen/6x8c2-minmax-neonbf16-bfdot-lane-ld128.c b/src/bf16-gemm/gen/6x8c2-minmax-neonbf16-bfdot-lane-ld128.c new file mode 100644 index 000000000000..e862fc71b87e --- /dev/null +++ b/src/bf16-gemm/gen/6x8c2-minmax-neonbf16-bfdot-lane-ld128.c @@ -0,0 +1,436 @@ +// Auto-generated file. Do not edit! +// Template: src/bf16-gemm/c2-neonbf16-bfdot-lane-ld128.c.in +// Generator: tools/xngen +// +// Copyright 2022 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
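As in the 1x8/4x8/5x8 variants, the 6x8c2 kernel below walks the packed weights in 128-bit loads: 8 bf16 biases first, then, for every pair of consecutive K elements, the pair for columns 0-3 followed by the pair for columns 4-7. A small sketch of the element offsets implied by those loads (inferred from the kernel's reads; packed_w_offset is a hypothetical helper, and the authoritative layout is defined by the packing code, which is not part of this hunk):

#include <stddef.h>

/* Offset, in bf16 elements, of weight (k, n) inside one packed 8-column panel. */
static size_t packed_w_offset(size_t k, size_t n /* 0..7 */) {
  const size_t k_pair = k / 2;   /* K is consumed two elements at a time ("c2") */
  const size_t half = n / 4;     /* 0 -> the vb0123* loads, 1 -> the vb4567* loads */
  return 8                       /* 8 bf16 biases at the start of the panel */
       + k_pair * 16             /* 8 columns * 2 K elements per pair */
       + half * 8 + (n % 4) * 2 + (k % 2);
}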
+ + +#include <assert.h> + +#include <arm_neon.h> + +#include <xnnpack/gemm.h> + + +void xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128( + size_t mr, + size_t nc, + size_t kc, + const void* restrict a, + size_t a_stride, + const void* restrict w_ptr, + void* restrict c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) +{ + assert(mr != 0); + assert(mr <= 6); + assert(nc != 0); + assert(kc != 0); + assert(kc % sizeof(bfloat16_t) == 0); + assert(a != NULL); + assert(w_ptr != NULL); + assert(c != NULL); + + const bfloat16_t* a0 = (const bfloat16_t*) a; + bfloat16_t* c0 = (bfloat16_t*) c; + const bfloat16_t* a1 = (const bfloat16_t*) ((uintptr_t) a0 + a_stride); + bfloat16_t* c1 = (bfloat16_t*) ((uintptr_t) c0 + cm_stride); + if XNN_UNPREDICTABLE(mr < 2) { + a1 = a0; + c1 = c0; + } + const bfloat16_t* a2 = (const bfloat16_t*) ((uintptr_t) a1 + a_stride); + bfloat16_t* c2 = (bfloat16_t*) ((uintptr_t) c1 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 2) { + a2 = a1; + c2 = c1; + } + const bfloat16_t* a3 = (const bfloat16_t*) ((uintptr_t) a2 + a_stride); + bfloat16_t* c3 = (bfloat16_t*) ((uintptr_t) c2 + cm_stride); + if XNN_UNPREDICTABLE(mr < 4) { + a3 = a2; + c3 = c2; + } + const bfloat16_t* a4 = (const bfloat16_t*) ((uintptr_t) a3 + a_stride); + bfloat16_t* c4 = (bfloat16_t*) ((uintptr_t) c3 + cm_stride); + if XNN_UNPREDICTABLE(mr <= 4) { + a4 = a3; + c4 = c3; + } + const bfloat16_t* a5 = (const bfloat16_t*) ((uintptr_t) a4 + a_stride); + bfloat16_t* c5 = (bfloat16_t*) ((uintptr_t) c4 + cm_stride); + if XNN_UNPREDICTABLE(mr != 6) { + a5 = a4; + c5 = c4; + } + + const bfloat16_t* w = (const bfloat16_t*) w_ptr; + do { + float32x4_t vacc0x0123 = vcvt_f32_bf16(vld1_bf16(w)); w += 4; + float32x4_t vacc0x4567 = vcvt_f32_bf16(vld1_bf16(w)); w += 4; + float32x4_t vacc1x0123 = vacc0x0123; + float32x4_t vacc1x4567 = vacc0x4567; + float32x4_t vacc2x0123 = vacc0x0123; + float32x4_t vacc2x4567 = vacc0x4567; + float32x4_t vacc3x0123 = vacc0x0123; + float32x4_t vacc3x4567 = vacc0x4567; + float32x4_t vacc4x0123 = vacc0x0123; + float32x4_t vacc4x4567 = vacc0x4567; + float32x4_t vacc5x0123 = vacc0x0123; + float32x4_t vacc5x4567 = vacc0x4567; + + size_t k = kc; + for (; k >= 8 * sizeof(bfloat16_t); k -= 8 * sizeof(bfloat16_t)) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 += 8; + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 += 8; + const bfloat16x8_t va2 = vld1q_bf16(a2); a2 += 8; + const bfloat16x8_t va3 = vld1q_bf16(a3); a3 += 8; + const bfloat16x8_t va4 = vld1q_bf16(a4); a4 += 8; + const bfloat16x8_t va5 = vld1q_bf16(a5); a5 += 8; + + const bfloat16x8_t vb0123c01 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c01 = vld1q_bf16(w); w += 8; + + vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c01, va0, 0); + vacc1x0123 = vbfdotq_laneq_f32(vacc1x0123, vb0123c01, va1, 0); + vacc2x0123 = vbfdotq_laneq_f32(vacc2x0123, vb0123c01, va2, 0); + vacc3x0123 = vbfdotq_laneq_f32(vacc3x0123, vb0123c01, va3, 0); + vacc4x0123 = vbfdotq_laneq_f32(vacc4x0123, vb0123c01, va4, 0); + vacc5x0123 = vbfdotq_laneq_f32(vacc5x0123, vb0123c01, va5, 0); + vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c01, va0, 0); + vacc1x4567 = vbfdotq_laneq_f32(vacc1x4567, vb4567c01, va1, 0); + vacc2x4567 = vbfdotq_laneq_f32(vacc2x4567, vb4567c01, va2, 0); + vacc3x4567 = vbfdotq_laneq_f32(vacc3x4567, vb4567c01, va3, 0); + vacc4x4567 = vbfdotq_laneq_f32(vacc4x4567, vb4567c01, va4, 0); + vacc5x4567 = vbfdotq_laneq_f32(vacc5x4567, vb4567c01, va5, 0); + const bfloat16x8_t vb0123c23 = vld1q_bf16(w); w += 8; + const
bfloat16x8_t vb4567c23 = vld1q_bf16(w); w += 8; + + vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c23, va0, 1); + vacc1x0123 = vbfdotq_laneq_f32(vacc1x0123, vb0123c23, va1, 1); + vacc2x0123 = vbfdotq_laneq_f32(vacc2x0123, vb0123c23, va2, 1); + vacc3x0123 = vbfdotq_laneq_f32(vacc3x0123, vb0123c23, va3, 1); + vacc4x0123 = vbfdotq_laneq_f32(vacc4x0123, vb0123c23, va4, 1); + vacc5x0123 = vbfdotq_laneq_f32(vacc5x0123, vb0123c23, va5, 1); + vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c23, va0, 1); + vacc1x4567 = vbfdotq_laneq_f32(vacc1x4567, vb4567c23, va1, 1); + vacc2x4567 = vbfdotq_laneq_f32(vacc2x4567, vb4567c23, va2, 1); + vacc3x4567 = vbfdotq_laneq_f32(vacc3x4567, vb4567c23, va3, 1); + vacc4x4567 = vbfdotq_laneq_f32(vacc4x4567, vb4567c23, va4, 1); + vacc5x4567 = vbfdotq_laneq_f32(vacc5x4567, vb4567c23, va5, 1); + const bfloat16x8_t vb0123c45 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c45 = vld1q_bf16(w); w += 8; + + vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c45, va0, 2); + vacc1x0123 = vbfdotq_laneq_f32(vacc1x0123, vb0123c45, va1, 2); + vacc2x0123 = vbfdotq_laneq_f32(vacc2x0123, vb0123c45, va2, 2); + vacc3x0123 = vbfdotq_laneq_f32(vacc3x0123, vb0123c45, va3, 2); + vacc4x0123 = vbfdotq_laneq_f32(vacc4x0123, vb0123c45, va4, 2); + vacc5x0123 = vbfdotq_laneq_f32(vacc5x0123, vb0123c45, va5, 2); + vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c45, va0, 2); + vacc1x4567 = vbfdotq_laneq_f32(vacc1x4567, vb4567c45, va1, 2); + vacc2x4567 = vbfdotq_laneq_f32(vacc2x4567, vb4567c45, va2, 2); + vacc3x4567 = vbfdotq_laneq_f32(vacc3x4567, vb4567c45, va3, 2); + vacc4x4567 = vbfdotq_laneq_f32(vacc4x4567, vb4567c45, va4, 2); + vacc5x4567 = vbfdotq_laneq_f32(vacc5x4567, vb4567c45, va5, 2); + const bfloat16x8_t vb0123c67 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c67 = vld1q_bf16(w); w += 8; + + vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c67, va0, 3); + vacc1x0123 = vbfdotq_laneq_f32(vacc1x0123, vb0123c67, va1, 3); + vacc2x0123 = vbfdotq_laneq_f32(vacc2x0123, vb0123c67, va2, 3); + vacc3x0123 = vbfdotq_laneq_f32(vacc3x0123, vb0123c67, va3, 3); + vacc4x0123 = vbfdotq_laneq_f32(vacc4x0123, vb0123c67, va4, 3); + vacc5x0123 = vbfdotq_laneq_f32(vacc5x0123, vb0123c67, va5, 3); + vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c67, va0, 3); + vacc1x4567 = vbfdotq_laneq_f32(vacc1x4567, vb4567c67, va1, 3); + vacc2x4567 = vbfdotq_laneq_f32(vacc2x4567, vb4567c67, va2, 3); + vacc3x4567 = vbfdotq_laneq_f32(vacc3x4567, vb4567c67, va3, 3); + vacc4x4567 = vbfdotq_laneq_f32(vacc4x4567, vb4567c67, va4, 3); + vacc5x4567 = vbfdotq_laneq_f32(vacc5x4567, vb4567c67, va5, 3); + } + if XNN_UNLIKELY(k != 0) { + const bfloat16x8_t va0 = vld1q_bf16(a0); a0 = (const bfloat16_t*) ((uintptr_t) a0 + k); + const bfloat16x8_t va1 = vld1q_bf16(a1); a1 = (const bfloat16_t*) ((uintptr_t) a1 + k); + const bfloat16x8_t va2 = vld1q_bf16(a2); a2 = (const bfloat16_t*) ((uintptr_t) a2 + k); + const bfloat16x8_t va3 = vld1q_bf16(a3); a3 = (const bfloat16_t*) ((uintptr_t) a3 + k); + const bfloat16x8_t va4 = vld1q_bf16(a4); a4 = (const bfloat16_t*) ((uintptr_t) a4 + k); + const bfloat16x8_t va5 = vld1q_bf16(a5); a5 = (const bfloat16_t*) ((uintptr_t) a5 + k); + + const bfloat16x8_t vb0123c01 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c01 = vld1q_bf16(w); w += 8; + + const uint32x4_t va0c01 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va0)), 0); + const uint32x4_t va1c01 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va1)), 0); + const uint32x4_t va2c01 = 
vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va2)), 0); + const uint32x4_t va3c01 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va3)), 0); + const uint32x4_t va4c01 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va4)), 0); + const uint32x4_t va5c01 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va5)), 0); + + const uint32x4_t vm0123c01 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c01), vmovq_n_u16(0))); + const uint32x4_t vm4567c01 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c01), vmovq_n_u16(0))); + + const uint32x4_t va0x0123c01 = vbicq_u32(va0c01, vm0123c01); + vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c01, vreinterpretq_bf16_u32(va0x0123c01)); + const uint32x4_t va1x0123c01 = vbicq_u32(va1c01, vm0123c01); + vacc1x0123 = vbfdotq_f32(vacc1x0123, vb0123c01, vreinterpretq_bf16_u32(va1x0123c01)); + const uint32x4_t va2x0123c01 = vbicq_u32(va2c01, vm0123c01); + vacc2x0123 = vbfdotq_f32(vacc2x0123, vb0123c01, vreinterpretq_bf16_u32(va2x0123c01)); + const uint32x4_t va3x0123c01 = vbicq_u32(va3c01, vm0123c01); + vacc3x0123 = vbfdotq_f32(vacc3x0123, vb0123c01, vreinterpretq_bf16_u32(va3x0123c01)); + const uint32x4_t va4x0123c01 = vbicq_u32(va4c01, vm0123c01); + vacc4x0123 = vbfdotq_f32(vacc4x0123, vb0123c01, vreinterpretq_bf16_u32(va4x0123c01)); + const uint32x4_t va5x0123c01 = vbicq_u32(va5c01, vm0123c01); + vacc5x0123 = vbfdotq_f32(vacc5x0123, vb0123c01, vreinterpretq_bf16_u32(va5x0123c01)); + const uint32x4_t va0x4567c01 = vbicq_u32(va0c01, vm4567c01); + vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c01, vreinterpretq_bf16_u32(va0x4567c01)); + const uint32x4_t va1x4567c01 = vbicq_u32(va1c01, vm4567c01); + vacc1x4567 = vbfdotq_f32(vacc1x4567, vb4567c01, vreinterpretq_bf16_u32(va1x4567c01)); + const uint32x4_t va2x4567c01 = vbicq_u32(va2c01, vm4567c01); + vacc2x4567 = vbfdotq_f32(vacc2x4567, vb4567c01, vreinterpretq_bf16_u32(va2x4567c01)); + const uint32x4_t va3x4567c01 = vbicq_u32(va3c01, vm4567c01); + vacc3x4567 = vbfdotq_f32(vacc3x4567, vb4567c01, vreinterpretq_bf16_u32(va3x4567c01)); + const uint32x4_t va4x4567c01 = vbicq_u32(va4c01, vm4567c01); + vacc4x4567 = vbfdotq_f32(vacc4x4567, vb4567c01, vreinterpretq_bf16_u32(va4x4567c01)); + const uint32x4_t va5x4567c01 = vbicq_u32(va5c01, vm4567c01); + vacc5x4567 = vbfdotq_f32(vacc5x4567, vb4567c01, vreinterpretq_bf16_u32(va5x4567c01)); + + if (k > 2 * sizeof(bfloat16_t)) { + const bfloat16x8_t vb0123c23 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c23 = vld1q_bf16(w); w += 8; + + const uint32x4_t va0c23 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va0)), 1); + const uint32x4_t va1c23 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va1)), 1); + const uint32x4_t va2c23 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va2)), 1); + const uint32x4_t va3c23 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va3)), 1); + const uint32x4_t va4c23 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va4)), 1); + const uint32x4_t va5c23 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va5)), 1); + + const uint32x4_t vm0123c23 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c23), vmovq_n_u16(0))); + const uint32x4_t vm4567c23 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c23), vmovq_n_u16(0))); + + const uint32x4_t va0x0123c23 = vbicq_u32(va0c23, vm0123c23); + vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c23, vreinterpretq_bf16_u32(va0x0123c23)); + const uint32x4_t va1x0123c23 = vbicq_u32(va1c23, vm0123c23); + vacc1x0123 = 
vbfdotq_f32(vacc1x0123, vb0123c23, vreinterpretq_bf16_u32(va1x0123c23)); + const uint32x4_t va2x0123c23 = vbicq_u32(va2c23, vm0123c23); + vacc2x0123 = vbfdotq_f32(vacc2x0123, vb0123c23, vreinterpretq_bf16_u32(va2x0123c23)); + const uint32x4_t va3x0123c23 = vbicq_u32(va3c23, vm0123c23); + vacc3x0123 = vbfdotq_f32(vacc3x0123, vb0123c23, vreinterpretq_bf16_u32(va3x0123c23)); + const uint32x4_t va4x0123c23 = vbicq_u32(va4c23, vm0123c23); + vacc4x0123 = vbfdotq_f32(vacc4x0123, vb0123c23, vreinterpretq_bf16_u32(va4x0123c23)); + const uint32x4_t va5x0123c23 = vbicq_u32(va5c23, vm0123c23); + vacc5x0123 = vbfdotq_f32(vacc5x0123, vb0123c23, vreinterpretq_bf16_u32(va5x0123c23)); + const uint32x4_t va0x4567c23 = vbicq_u32(va0c23, vm4567c23); + vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c23, vreinterpretq_bf16_u32(va0x4567c23)); + const uint32x4_t va1x4567c23 = vbicq_u32(va1c23, vm4567c23); + vacc1x4567 = vbfdotq_f32(vacc1x4567, vb4567c23, vreinterpretq_bf16_u32(va1x4567c23)); + const uint32x4_t va2x4567c23 = vbicq_u32(va2c23, vm4567c23); + vacc2x4567 = vbfdotq_f32(vacc2x4567, vb4567c23, vreinterpretq_bf16_u32(va2x4567c23)); + const uint32x4_t va3x4567c23 = vbicq_u32(va3c23, vm4567c23); + vacc3x4567 = vbfdotq_f32(vacc3x4567, vb4567c23, vreinterpretq_bf16_u32(va3x4567c23)); + const uint32x4_t va4x4567c23 = vbicq_u32(va4c23, vm4567c23); + vacc4x4567 = vbfdotq_f32(vacc4x4567, vb4567c23, vreinterpretq_bf16_u32(va4x4567c23)); + const uint32x4_t va5x4567c23 = vbicq_u32(va5c23, vm4567c23); + vacc5x4567 = vbfdotq_f32(vacc5x4567, vb4567c23, vreinterpretq_bf16_u32(va5x4567c23)); + + if (k > 4 * sizeof(bfloat16_t)) { + const bfloat16x8_t vb0123c45 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c45 = vld1q_bf16(w); w += 8; + + const uint32x4_t va0c45 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va0)), 0); + const uint32x4_t va1c45 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va1)), 0); + const uint32x4_t va2c45 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va2)), 0); + const uint32x4_t va3c45 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va3)), 0); + const uint32x4_t va4c45 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va4)), 0); + const uint32x4_t va5c45 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va5)), 0); + + const uint32x4_t vm0123c45 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c45), vmovq_n_u16(0))); + const uint32x4_t vm4567c45 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c45), vmovq_n_u16(0))); + + const uint32x4_t va0x0123c45 = vbicq_u32(va0c45, vm0123c45); + vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c45, vreinterpretq_bf16_u32(va0x0123c45)); + const uint32x4_t va1x0123c45 = vbicq_u32(va1c45, vm0123c45); + vacc1x0123 = vbfdotq_f32(vacc1x0123, vb0123c45, vreinterpretq_bf16_u32(va1x0123c45)); + const uint32x4_t va2x0123c45 = vbicq_u32(va2c45, vm0123c45); + vacc2x0123 = vbfdotq_f32(vacc2x0123, vb0123c45, vreinterpretq_bf16_u32(va2x0123c45)); + const uint32x4_t va3x0123c45 = vbicq_u32(va3c45, vm0123c45); + vacc3x0123 = vbfdotq_f32(vacc3x0123, vb0123c45, vreinterpretq_bf16_u32(va3x0123c45)); + const uint32x4_t va4x0123c45 = vbicq_u32(va4c45, vm0123c45); + vacc4x0123 = vbfdotq_f32(vacc4x0123, vb0123c45, vreinterpretq_bf16_u32(va4x0123c45)); + const uint32x4_t va5x0123c45 = vbicq_u32(va5c45, vm0123c45); + vacc5x0123 = vbfdotq_f32(vacc5x0123, vb0123c45, vreinterpretq_bf16_u32(va5x0123c45)); + const uint32x4_t va0x4567c45 = vbicq_u32(va0c45, vm4567c45); + vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c45, 
vreinterpretq_bf16_u32(va0x4567c45)); + const uint32x4_t va1x4567c45 = vbicq_u32(va1c45, vm4567c45); + vacc1x4567 = vbfdotq_f32(vacc1x4567, vb4567c45, vreinterpretq_bf16_u32(va1x4567c45)); + const uint32x4_t va2x4567c45 = vbicq_u32(va2c45, vm4567c45); + vacc2x4567 = vbfdotq_f32(vacc2x4567, vb4567c45, vreinterpretq_bf16_u32(va2x4567c45)); + const uint32x4_t va3x4567c45 = vbicq_u32(va3c45, vm4567c45); + vacc3x4567 = vbfdotq_f32(vacc3x4567, vb4567c45, vreinterpretq_bf16_u32(va3x4567c45)); + const uint32x4_t va4x4567c45 = vbicq_u32(va4c45, vm4567c45); + vacc4x4567 = vbfdotq_f32(vacc4x4567, vb4567c45, vreinterpretq_bf16_u32(va4x4567c45)); + const uint32x4_t va5x4567c45 = vbicq_u32(va5c45, vm4567c45); + vacc5x4567 = vbfdotq_f32(vacc5x4567, vb4567c45, vreinterpretq_bf16_u32(va5x4567c45)); + + if (k > 6 * sizeof(bfloat16_t)) { + const bfloat16x8_t vb0123c67 = vld1q_bf16(w); w += 8; + const bfloat16x8_t vb4567c67 = vld1q_bf16(w); w += 8; + + const uint32x4_t va0c67 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va0)), 1); + const uint32x4_t va1c67 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va1)), 1); + const uint32x4_t va2c67 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va2)), 1); + const uint32x4_t va3c67 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va3)), 1); + const uint32x4_t va4c67 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va4)), 1); + const uint32x4_t va5c67 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va5)), 1); + + const uint32x4_t vm0123c67 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c67), vmovq_n_u16(0))); + const uint32x4_t vm4567c67 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c67), vmovq_n_u16(0))); + + const uint32x4_t va0x0123c67 = vbicq_u32(va0c67, vm0123c67); + vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c67, vreinterpretq_bf16_u32(va0x0123c67)); + const uint32x4_t va1x0123c67 = vbicq_u32(va1c67, vm0123c67); + vacc1x0123 = vbfdotq_f32(vacc1x0123, vb0123c67, vreinterpretq_bf16_u32(va1x0123c67)); + const uint32x4_t va2x0123c67 = vbicq_u32(va2c67, vm0123c67); + vacc2x0123 = vbfdotq_f32(vacc2x0123, vb0123c67, vreinterpretq_bf16_u32(va2x0123c67)); + const uint32x4_t va3x0123c67 = vbicq_u32(va3c67, vm0123c67); + vacc3x0123 = vbfdotq_f32(vacc3x0123, vb0123c67, vreinterpretq_bf16_u32(va3x0123c67)); + const uint32x4_t va4x0123c67 = vbicq_u32(va4c67, vm0123c67); + vacc4x0123 = vbfdotq_f32(vacc4x0123, vb0123c67, vreinterpretq_bf16_u32(va4x0123c67)); + const uint32x4_t va5x0123c67 = vbicq_u32(va5c67, vm0123c67); + vacc5x0123 = vbfdotq_f32(vacc5x0123, vb0123c67, vreinterpretq_bf16_u32(va5x0123c67)); + const uint32x4_t va0x4567c67 = vbicq_u32(va0c67, vm4567c67); + vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c67, vreinterpretq_bf16_u32(va0x4567c67)); + const uint32x4_t va1x4567c67 = vbicq_u32(va1c67, vm4567c67); + vacc1x4567 = vbfdotq_f32(vacc1x4567, vb4567c67, vreinterpretq_bf16_u32(va1x4567c67)); + const uint32x4_t va2x4567c67 = vbicq_u32(va2c67, vm4567c67); + vacc2x4567 = vbfdotq_f32(vacc2x4567, vb4567c67, vreinterpretq_bf16_u32(va2x4567c67)); + const uint32x4_t va3x4567c67 = vbicq_u32(va3c67, vm4567c67); + vacc3x4567 = vbfdotq_f32(vacc3x4567, vb4567c67, vreinterpretq_bf16_u32(va3x4567c67)); + const uint32x4_t va4x4567c67 = vbicq_u32(va4c67, vm4567c67); + vacc4x4567 = vbfdotq_f32(vacc4x4567, vb4567c67, vreinterpretq_bf16_u32(va4x4567c67)); + const uint32x4_t va5x4567c67 = vbicq_u32(va5c67, vm4567c67); + vacc5x4567 = vbfdotq_f32(vacc5x4567, vb4567c67, vreinterpretq_bf16_u32(va5x4567c67)); + } + } + } + } 
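The remainder path above still loads a full 8-element A vector even when fewer than eight bf16 values remain in the row, and appears to rely on the packed weights being zero-padded up to the unroll boundary: the vceqq_u16/vbicq_u32 pair builds a mask of the zero B lanes and clears the matching A lanes before each BFDOT, so whatever bits sit past the valid part of A (possibly an Inf or NaN pattern, for which x*0 is NaN) never reach the accumulators. A minimal scalar sketch of that mask-and-clear idea follows; keep_if_b_nonzero is a hypothetical helper written only for illustration and is not part of this patch.

/* Scalar sketch of the vceqq_u16 + vbicq_u32 masking used above (illustrative only). */
#include <stdint.h>
#include <stdio.h>

static uint16_t keep_if_b_nonzero(uint16_t a_bits, uint16_t b_bits) {
  const uint16_t mask = (uint16_t) ((b_bits == 0) ? 0xFFFFu : 0u);  /* vceqq_u16 analogue */
  return (uint16_t) (a_bits & (uint16_t) ~mask);                    /* vbicq analogue: drop A where B == 0 */
}

int main(void) {
  /* An A lane holding an arbitrary (possibly out-of-range) bit pattern is dropped
     whenever the corresponding packed-B lane is a zero-padding element. */
  printf("0x%04X\n", keep_if_b_nonzero(0x7FC0 /* bf16 NaN */, 0x0000 /* padded B */));  /* -> 0x0000 */
  printf("0x%04X\n", keep_if_b_nonzero(0x3F80 /* bf16 1.0 */, 0x4000 /* bf16 2.0 */));  /* -> 0x3F80, kept */
  return 0;
}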
+ + const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + vacc2x0123 = vminq_f32(vacc2x0123, vmax); + vacc3x0123 = vminq_f32(vacc3x0123, vmax); + vacc4x0123 = vminq_f32(vacc4x0123, vmax); + vacc5x0123 = vminq_f32(vacc5x0123, vmax); + vacc0x4567 = vminq_f32(vacc0x4567, vmax); + vacc1x4567 = vminq_f32(vacc1x4567, vmax); + vacc2x4567 = vminq_f32(vacc2x4567, vmax); + vacc3x4567 = vminq_f32(vacc3x4567, vmax); + vacc4x4567 = vminq_f32(vacc4x4567, vmax); + vacc5x4567 = vminq_f32(vacc5x4567, vmax); + + const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + vacc2x0123 = vmaxq_f32(vacc2x0123, vmin); + vacc3x0123 = vmaxq_f32(vacc3x0123, vmin); + vacc4x0123 = vmaxq_f32(vacc4x0123, vmin); + vacc5x0123 = vmaxq_f32(vacc5x0123, vmin); + vacc0x4567 = vmaxq_f32(vacc0x4567, vmin); + vacc1x4567 = vmaxq_f32(vacc1x4567, vmin); + vacc2x4567 = vmaxq_f32(vacc2x4567, vmin); + vacc3x4567 = vmaxq_f32(vacc3x4567, vmin); + vacc4x4567 = vmaxq_f32(vacc4x4567, vmin); + vacc5x4567 = vmaxq_f32(vacc5x4567, vmin); + + bfloat16x4_t vout0x0123 = vcvt_bf16_f32(vacc0x0123); + bfloat16x4_t vout1x0123 = vcvt_bf16_f32(vacc1x0123); + bfloat16x4_t vout2x0123 = vcvt_bf16_f32(vacc2x0123); + bfloat16x4_t vout3x0123 = vcvt_bf16_f32(vacc3x0123); + bfloat16x4_t vout4x0123 = vcvt_bf16_f32(vacc4x0123); + bfloat16x4_t vout5x0123 = vcvt_bf16_f32(vacc5x0123); + bfloat16x4_t vout0x4567 = vcvt_bf16_f32(vacc0x4567); + bfloat16x4_t vout1x4567 = vcvt_bf16_f32(vacc1x4567); + bfloat16x4_t vout2x4567 = vcvt_bf16_f32(vacc2x4567); + bfloat16x4_t vout3x4567 = vcvt_bf16_f32(vacc3x4567); + bfloat16x4_t vout4x4567 = vcvt_bf16_f32(vacc4x4567); + bfloat16x4_t vout5x4567 = vcvt_bf16_f32(vacc5x4567); + + if XNN_LIKELY(nc >= 8) { + vst1_bf16(c0, vout0x0123); + vst1_bf16(c0 + 4, vout0x4567); + c0 = (bfloat16_t*) ((uintptr_t) c0 + cn_stride); + vst1_bf16(c1, vout1x0123); + vst1_bf16(c1 + 4, vout1x4567); + c1 = (bfloat16_t*) ((uintptr_t) c1 + cn_stride); + vst1_bf16(c2, vout2x0123); + vst1_bf16(c2 + 4, vout2x4567); + c2 = (bfloat16_t*) ((uintptr_t) c2 + cn_stride); + vst1_bf16(c3, vout3x0123); + vst1_bf16(c3 + 4, vout3x4567); + c3 = (bfloat16_t*) ((uintptr_t) c3 + cn_stride); + vst1_bf16(c4, vout4x0123); + vst1_bf16(c4 + 4, vout4x4567); + c4 = (bfloat16_t*) ((uintptr_t) c4 + cn_stride); + vst1_bf16(c5, vout5x0123); + vst1_bf16(c5 + 4, vout5x4567); + c5 = (bfloat16_t*) ((uintptr_t) c5 + cn_stride); + + a0 = (const bfloat16_t*) ((uintptr_t) a0 - kc); + a1 = (const bfloat16_t*) ((uintptr_t) a1 - kc); + a2 = (const bfloat16_t*) ((uintptr_t) a2 - kc); + a3 = (const bfloat16_t*) ((uintptr_t) a3 - kc); + a4 = (const bfloat16_t*) ((uintptr_t) a4 - kc); + a5 = (const bfloat16_t*) ((uintptr_t) a5 - kc); + + nc -= 8; + } else { + if (nc & 4) { + vst1_bf16(c0, vout0x0123); c0 += 4; + vst1_bf16(c1, vout1x0123); c1 += 4; + vst1_bf16(c2, vout2x0123); c2 += 4; + vst1_bf16(c3, vout3x0123); c3 += 4; + vst1_bf16(c4, vout4x0123); c4 += 4; + vst1_bf16(c5, vout5x0123); c5 += 4; + + vout0x0123 = vout0x4567; + vout1x0123 = vout1x4567; + vout2x0123 = vout2x4567; + vout3x0123 = vout3x4567; + vout4x0123 = vout4x4567; + vout5x0123 = vout5x4567; + } + if (nc & 2) { + vst1_lane_u32((void*) c0, vreinterpret_u32_bf16(vout0x0123), 0); c0 += 2; + vst1_lane_u32((void*) c1, vreinterpret_u32_bf16(vout1x0123), 0); c1 += 2; + vst1_lane_u32((void*) c2, vreinterpret_u32_bf16(vout2x0123), 0); c2 += 2; + vst1_lane_u32((void*)
c3, vreinterpret_u32_bf16(vout3x0123), 0); c3 += 2; + vst1_lane_u32((void*) c4, vreinterpret_u32_bf16(vout4x0123), 0); c4 += 2; + vst1_lane_u32((void*) c5, vreinterpret_u32_bf16(vout5x0123), 0); c5 += 2; + + vout0x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout0x0123), vreinterpret_u16_bf16(vout0x0123), 2)); + vout1x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout1x0123), vreinterpret_u16_bf16(vout1x0123), 2)); + vout2x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout2x0123), vreinterpret_u16_bf16(vout2x0123), 2)); + vout3x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout3x0123), vreinterpret_u16_bf16(vout3x0123), 2)); + vout4x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout4x0123), vreinterpret_u16_bf16(vout4x0123), 2)); + vout5x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout5x0123), vreinterpret_u16_bf16(vout5x0123), 2)); + } + if (nc & 1) { + vst1_lane_bf16(c0, vout0x0123, 0); + vst1_lane_bf16(c1, vout1x0123, 0); + vst1_lane_bf16(c2, vout2x0123, 0); + vst1_lane_bf16(c3, vout3x0123, 0); + vst1_lane_bf16(c4, vout4x0123, 0); + vst1_lane_bf16(c5, vout5x0123, 0); + } + + nc = 0; + } + } while (nc != 0); +} diff --git a/src/microparams-init.c b/src/microparams-init.c index 6f0d6d2bf5db..442ad87c7042 100644 --- a/src/microparams-init.c +++ b/src/microparams-init.c @@ -1934,6 +1934,16 @@ size_t xnn_init_scalar_f32_gavgpool_params( return sizeof(params->scalar); } +size_t xnn_init_bf16_minmax_scalar_params( + union xnn_bf16_minmax_params params[XNN_MIN_ELEMENTS(1)], + uint16_t output_min, + uint16_t output_max) +{ + params->scalar.min = uint32_as_float((uint32_t) output_min << 16); + params->scalar.max = uint32_as_float((uint32_t) output_max << 16); + return sizeof(params->scalar); +} + #if XNN_ARCH_ARM || XNN_ARCH_ARM64 size_t xnn_init_f16_minmax_neon_params( union xnn_f16_minmax_params params[XNN_MIN_ELEMENTS(1)], diff --git a/src/xnnpack/gemm.h b/src/xnnpack/gemm.h index 06a713eff6c4..b2ac9a1ee816 100644 --- a/src/xnnpack/gemm.h +++ b/src/xnnpack/gemm.h @@ -19,6 +19,95 @@ extern "C" { #endif +#define DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \ + void fn_name( \ + size_t mr, \ + size_t nr, \ + size_t k, \ + const void* a, \ + size_t a_stride, \ + const void* w, \ + void* c, \ + size_t cm_stride, \ + size_t cn_stride, \ + const union xnn_bf16_minmax_params* params); + +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland) +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland) +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland) +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland) +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland) + +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip) +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip) +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip) +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip) +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip) + +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128) 
+DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128) +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128) +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128) + +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot) +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot) +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot) +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot) +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot) + +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal) +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal) +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal) +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal) +DECLARE_BF16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal) + + +#define DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \ + void fn_name( \ + size_t mr, \ + size_t nr, \ + size_t k, \ + const void* a, \ + size_t a_stride, \ + const void* w, \ + void* c, \ + size_t cm_stride, \ + size_t cn_stride, \ + const union xnn_f16_minmax_params* params); + +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_1x8__aarch64_neonfp16arith_ld64) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_1x8__neonfp16arith_ld64) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_1x16__aarch64_neonfp16arith_ld32) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_1x16__aarch64_neonfp16arith_ld64) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_4x8__aarch64_neonfp16arith_ld64) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_4x8__neonfp16arith_ld64) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_4x16__aarch64_neonfp16arith_ld32) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_4x16__aarch64_neonfp16arith_ld64) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x8__aarch64_neonfp16arith_ld64) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x8__neonfp16arith_ld64) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x16__aarch64_neonfp16arith_cortex_a55) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x16__aarch64_neonfp16arith_cortex_a55r0) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x16__aarch64_neonfp16arith_cortex_a75) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x16__aarch64_neonfp16arith_ld32) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x16__aarch64_neonfp16arith_ld64) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_8x8__aarch64_neonfp16arith_ld64) 
+DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_8x8__neonfp16arith_ld64) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64) + +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_1x8__avx2_broadcast) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_1x16__avx2_broadcast) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_3x16__avx2_broadcast) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_4x8__avx2_broadcast) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_4x16__avx2_broadcast) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_5x8__avx2_broadcast) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_5x16__avx2_broadcast) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x8__avx2_broadcast) +DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_7x8__avx2_broadcast) + + #define DECLARE_F32_GEMM_UKERNEL_FUNCTION(fn_name) \ XNN_INTERNAL void fn_name( \ size_t mr, \ @@ -580,52 +669,6 @@ DECLARE_F32_GEMMINC_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemminc_minmax_ukernel_2x4__ DECLARE_F32_GEMMINC_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemminc_minmax_ukernel_4x4__scalar) -#define DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \ - void fn_name( \ - size_t mr, \ - size_t nr, \ - size_t k, \ - const void* a, \ - size_t a_stride, \ - const void* w, \ - void* c, \ - size_t cm_stride, \ - size_t cn_stride, \ - const union xnn_f16_minmax_params* params); - -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_1x8__aarch64_neonfp16arith_ld64) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_1x8__neonfp16arith_ld64) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_1x16__aarch64_neonfp16arith_ld32) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_1x16__aarch64_neonfp16arith_ld64) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_4x8__aarch64_neonfp16arith_ld64) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_4x8__neonfp16arith_ld64) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_4x16__aarch64_neonfp16arith_ld32) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_4x16__aarch64_neonfp16arith_ld64) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x8__aarch64_neonfp16arith_ld64) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x8__neonfp16arith_ld64) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x16__aarch64_neonfp16arith_cortex_a55) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x16__aarch64_neonfp16arith_cortex_a55r0) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x16__aarch64_neonfp16arith_cortex_a75) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x16__aarch64_neonfp16arith_ld32) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x16__aarch64_neonfp16arith_ld64) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_8x8__aarch64_neonfp16arith_ld64) 
-DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_8x8__neonfp16arith_ld64) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64) - -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_1x8__avx2_broadcast) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_1x16__avx2_broadcast) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_3x16__avx2_broadcast) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_4x8__avx2_broadcast) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_4x16__avx2_broadcast) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_5x8__avx2_broadcast) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_5x16__avx2_broadcast) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_6x8__avx2_broadcast) -DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_7x8__avx2_broadcast) - - #define DECLARE_QU8_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \ XNN_INTERNAL void fn_name( \ size_t mr, \ diff --git a/src/xnnpack/isa-checks.h b/src/xnnpack/isa-checks.h index 51ab0b7d8bfd..a18c2d7065e7 100644 --- a/src/xnnpack/isa-checks.h +++ b/src/xnnpack/isa-checks.h @@ -132,6 +132,13 @@ } \ } while (0) +#define TEST_REQUIRES_ARM_NEON_BF16 \ + do { \ + if (!cpuinfo_initialize() || !cpuinfo_has_arm_neon_bf16()) { \ + GTEST_SKIP(); \ + } \ + } while (0) + #define TEST_REQUIRES_ARM_NEON_DOT \ do { \ if (!cpuinfo_initialize() || !cpuinfo_has_arm_neon_dot()) { \ diff --git a/src/xnnpack/microparams-init.h b/src/xnnpack/microparams-init.h index 1af2d5906c0d..f5b965ad97a7 100644 --- a/src/xnnpack/microparams-init.h +++ b/src/xnnpack/microparams-init.h @@ -325,6 +325,15 @@ XNN_INTERNAL size_t xnn_init_scalar_f32_gavgpool_params( #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 +#define DECLARE_INIT_BF16_MINMAX_PARAMS_FUNCTION(fn_name) \ + XNN_INTERNAL size_t fn_name( \ + union xnn_bf16_minmax_params params[XNN_MIN_ELEMENTS(1)], \ + uint16_t output_min, \ + uint16_t output_max); + +DECLARE_INIT_BF16_MINMAX_PARAMS_FUNCTION(xnn_init_bf16_minmax_scalar_params) + + #define DECLARE_INIT_F16_MINMAX_PARAMS_FUNCTION(fn_name) \ XNN_INTERNAL size_t fn_name( \ union xnn_f16_minmax_params params[XNN_MIN_ELEMENTS(1)], \ diff --git a/src/xnnpack/microparams.h b/src/xnnpack/microparams.h index f25b788978b0..3a6ae2e27be8 100644 --- a/src/xnnpack/microparams.h +++ b/src/xnnpack/microparams.h @@ -72,6 +72,13 @@ union xnn_f32_scaleminmax_params { // Min+Max: used by VCLAMP and GEMM/IGEMM/DWCONV/MAXPOOL/etc with MINMAX activation. 
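The bf16 clamping parameters introduced below are stored as plain float: xnn_init_bf16_minmax_scalar_params (earlier in this patch) receives output_min/output_max as raw bf16 bit patterns and widens them with `(uint32_t) v << 16`, which works because bfloat16 is just the upper half of an IEEE binary32 value. The sketch below restates that conversion with memcpy type punning; it is an illustration under that assumption, not XNNPACK code, and the truncating narrowing is shown only for symmetry (the kernels round via vcvt_bf16_f32 when converting results back).

#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* Widen a bf16 bit pattern to float32: bf16 occupies the high 16 bits. */
static float bf16_bits_to_f32(uint16_t bits) {
  const uint32_t word = (uint32_t) bits << 16;
  float value;
  memcpy(&value, &word, sizeof(value));
  return value;
}

/* Narrow by truncation (illustration only; the kernels round, not truncate). */
static uint16_t f32_to_bf16_bits(float value) {
  uint32_t word;
  memcpy(&word, &value, sizeof(word));
  return (uint16_t) (word >> 16);
}

int main(void) {
  printf("%f\n", bf16_bits_to_f32(0x3F80));      /* 1.000000 */
  printf("0x%04X\n", f32_to_bf16_bits(-2.5f));   /* 0xC020 */
  return 0;
}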
+union xnn_bf16_minmax_params { + struct { + float min; + float max; + } scalar; +}; + union xnn_f16_minmax_params { char _; // Dummy member variable to comply with the C standard #if XNN_ARCH_ARM || XNN_ARCH_ARM64 diff --git a/src/xnnpack/params.h b/src/xnnpack/params.h index 45508decf35a..4e167a0bb785 100644 --- a/src/xnnpack/params.h +++ b/src/xnnpack/params.h @@ -168,6 +168,18 @@ typedef void (*xnn_f32_gemminc_minmax_ukernel_function)( const float* acc, const union xnn_f32_minmax_params* params); +typedef void (*xnn_bf16_gemm_minmax_ukernel_function)( + size_t mr, + size_t nr, + size_t k, + const void* a, + size_t a_stride, + const void* w, + void* c, + size_t cm_stride, + size_t cn_stride, + const union xnn_bf16_minmax_params* params); + typedef void (*xnn_f16_gemm_minmax_ukernel_function)( size_t mr, size_t nr, @@ -1771,6 +1783,11 @@ typedef size_t (*xnn_init_f16_chw_params_fn)( typedef size_t (*xnn_init_f16_hswish_params_fn)( union xnn_f16_hswish_params params[XNN_MIN_ELEMENTS(1)]); +typedef size_t (*xnn_init_bf16_minmax_params_fn)( + union xnn_bf16_minmax_params params[XNN_MIN_ELEMENTS(1)], + uint16_t min, + uint16_t max); + typedef size_t (*xnn_init_f16_minmax_params_fn)( union xnn_f16_minmax_params params[XNN_MIN_ELEMENTS(1)], uint16_t min, diff --git a/test/bf16-gemm-minmax.cc b/test/bf16-gemm-minmax.cc new file mode 100644 index 000000000000..e0c6abc5f3be --- /dev/null +++ b/test/bf16-gemm-minmax.cc @@ -0,0 +1,10967 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// Copyright 2019 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// +// Auto-generated file. Do not edit! +// Specification: test/bf16-gemm-minmax.yaml +// Generator: tools/generate-gemm-test.py + + +#include + +#include +#include +#include + +#include +#include +#include +#include "gemm-microkernel-tester.h" + + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, k_eq_8) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, 
xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, k_lt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, k_gt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, k_div_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + 
.Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, n_gt_4) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, n_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, n_div_4) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, n_div_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + 
.n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, qmin) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, qmax) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_SHLAND, strided_cm) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, k_eq_8) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, k_lt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + 
.kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, k_gt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, k_div_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, n_gt_4) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + 
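Each generated test is gated on the ISA features its kernel needs: the neonfma shland/zip emulation kernels only require TEST_REQUIRES_ARM_NEON_FMA, while the bfdot/bfmlal kernels later in this file use the TEST_REQUIRES_ARM_NEON_BF16 check added to isa-checks.h above. The sketch below shows how a caller could use the same cpuinfo predicates to pick between the two kernel families at run time; it is only a sketch, in which the 4x4c8 tile, the header paths, and the selection policy are illustrative, while the typedef and kernel names are the ones declared earlier in this patch.

#include <stddef.h>
#include <cpuinfo.h>
#include <xnnpack/gemm.h>
#include <xnnpack/params.h>

/* Pick a BF16 GEMM microkernel based on the CPU features cpuinfo reports
   (sketch only; XNNPACK's own kernel selection lives elsewhere). */
static xnn_bf16_gemm_minmax_ukernel_function select_bf16_gemm(void) {
  if (!cpuinfo_initialize()) {
    return NULL;  /* cannot query the CPU */
  }
  if (cpuinfo_has_arm_neon_bf16()) {
    return xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot;   /* native BFDOT */
  }
  if (cpuinfo_has_arm_neon_fma()) {
    return xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland;   /* bf16 emulated with shifts + FMA */
  }
  return NULL;  /* no suitable microkernel on this CPU */
}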
TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, n_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, n_div_4) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, n_div_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, qmin) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, qmax) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_SHLAND, strided_cm) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(2) + 
.nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, k_eq_8) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, k_lt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, k_gt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, 
xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, k_div_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, n_gt_4) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, n_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, n_div_4) { + 
TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, n_div_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, qmin) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, qmax) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_SHLAND, strided_cm) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, k_eq_8) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, 
k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, k_lt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, k_gt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, k_div_8) { + 
TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, n_gt_4) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, n_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, n_div_4) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, n_div_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + 
GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, qmin) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, qmax) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_SHLAND, strided_cm) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, k_eq_8) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + 
.Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, k_lt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, k_gt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, k_div_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + 
.sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, n_gt_4) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, n_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, n_div_4) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, n_div_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + 
GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, qmin) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, qmax) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_SHLAND, strided_cm) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM64 + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, k_eq_8) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, k_lt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(1) + 
.nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, k_gt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, k_div_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, n_gt_4) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, n_gt_4_strided_a) { + 
TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, n_div_4) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, n_div_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, qmin) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, qmax) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONFMA_ZIP, strided_cm) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, 
xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM64 + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, k_eq_8) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, k_lt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, k_gt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + 
.kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, k_div_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, n_gt_4) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, n_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, n_div_4) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + 
} + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, n_div_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, qmin) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, qmax) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONFMA_ZIP, strided_cm) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM64 + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, k_eq_8) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + for 
(uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, k_lt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, k_gt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, k_div_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + 
.sr(1) + .m(3) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, n_gt_4) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, n_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, n_div_4) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, n_div_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + 
.iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, qmin) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, qmax) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONFMA_ZIP, strided_cm) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM64 + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, k_eq_8) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, k_lt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) 
+ .sr(1) + .m(4) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, k_gt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, k_div_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, n_gt_4) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; 
k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, n_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, n_div_4) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, n_div_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, qmin) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, qmax) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, 
xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONFMA_ZIP, strided_cm) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ARCH_ARM64 + + +#if XNN_ARCH_ARM64 + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, k_eq_8) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, k_lt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, k_gt_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) 
+ .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, k_div_8) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, n_gt_4) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, n_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, n_div_4) { + 
TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, n_div_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_FMA; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, qmin) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, qmax) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONFMA_ZIP, strided_cm) { + TEST_REQUIRES_ARM_NEON_FMA; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ARCH_ARM64 + + +#if XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(1) + .n(8) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(1) + .n(8) + .k(8) + .cn_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + 
TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(1) + .n(8) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(8) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(1) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, k_lt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(1) + .n(8) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(1) + .n(8) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, k_gt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(1) + .n(8) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(1) + .n(8) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 1; m++) { + 
GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, k_div_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(1) + .n(8) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(1) + .n(8) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, n_gt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(1) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, n_gt_8_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(1) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, n_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(1) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, n_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, n_div_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(1) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, n_div_8_strided_cn) { + 
TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(1) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, n_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(1) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, n_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(11) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, qmin) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(1) + .n(8) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, qmax) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(1) + .n(8) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X8C2__NEONBF16_BFDOT_LANE_LD128, strided_cm) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(1) + .nr(8) + .kr(2) + .sr(1) + .m(1) + .n(8) + .k(8) + .cm_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + + +#if XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(4) + .n(8) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(4) + .n(8) + .k(8) + .cn_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(4) + 
.n(8) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(8) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(4) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, k_lt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(4) + .n(8) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, k_gt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(4) + .n(8) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, 
xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, k_div_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(4) + .n(8) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(4) + .n(8) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, n_gt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(4) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, n_gt_8_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, n_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(4) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, n_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, n_div_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(4) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, n_div_8_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(4) + .n(n) + 
.k(k) + .cn_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, n_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(4) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, n_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(11) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, qmin) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(4) + .n(8) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, qmax) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(4) + .n(8) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X8C2__NEONBF16_BFDOT_LANE_LD128, strided_cm) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .kr(2) + .sr(1) + .m(4) + .n(8) + .k(8) + .cm_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + + +#if XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(5) + .n(8) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(5) + .n(8) + .k(8) + .cn_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(5) + .n(8) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + 
TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(8) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(5) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, k_lt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(5) + .n(8) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(5) + .n(8) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, k_gt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(5) + .n(8) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(5) + .n(8) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, k_div_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for 
(size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(5) + .n(8) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(5) + .n(8) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, n_gt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(5) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, n_gt_8_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(5) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, n_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(5) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, n_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, n_div_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(5) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, n_div_8_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(5) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + 
TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, n_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(5) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, n_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(11) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, qmin) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(5) + .n(8) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, qmax) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(5) + .n(8) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X8C2__NEONBF16_BFDOT_LANE_LD128, strided_cm) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(5) + .nr(8) + .kr(2) + .sr(1) + .m(5) + .n(8) + .k(8) + .cm_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + + +#if XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(6) + .n(8) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(6) + .n(8) + .k(8) + .cn_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(6) + .n(8) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 6; m++) { + 
GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t m = 1; m <= 6; m++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(8) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(6) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, k_lt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(6) + .n(8) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(6) + .n(8) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 6; m++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, k_gt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(6) + .n(8) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(6) + .n(8) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 6; m++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, k_div_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(6) + .n(8) + .k(k) + 
.Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(6) + .n(8) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 6; m++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, n_gt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(6) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, n_gt_8_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(6) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, n_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(6) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, n_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 9; n < 16; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 6; m++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, n_div_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(6) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, n_div_8_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(6) + .n(n) + .k(k) + .cn_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, n_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 16; n <= 24; 
n += 8) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(6) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, n_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 16; n <= 24; n += 8) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 6; m++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 8; n++) { + for (uint32_t m = 1; m <= 6; m++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(11) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, qmin) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(6) + .n(8) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, qmax) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(6) + .n(8) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_6X8C2__NEONBF16_BFDOT_LANE_LD128, strided_cm) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(6) + .nr(8) + .kr(2) + .sr(1) + .m(6) + .n(8) + .k(8) + .cm_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + + +#if XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, k_eq_8) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + 
TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, k_lt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, k_gt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, k_div_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + 
TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, n_gt_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, n_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, n_div_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, n_div_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, 
xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, qmin) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, qmax) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFDOT, strided_cm) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + + +#if XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, k_eq_8) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, k_lt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + 
for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, k_gt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, k_div_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, n_gt_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + 
TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, n_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, n_div_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, n_div_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, qmin) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + 
TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, qmax) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFDOT, strided_cm) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + + +#if XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, k_eq_8) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, k_lt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 3; m++) { + 
GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, k_gt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, k_div_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, n_gt_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, n_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + 
TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, n_div_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, n_div_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, qmin) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, qmax) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFDOT, strided_cm) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + + +#if XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, k_eq_8) { + 
TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, k_lt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, k_gt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .a_stride(19) + 
.Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, k_div_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, n_gt_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, n_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, n_div_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, 
xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, n_div_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, qmin) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, qmax) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFDOT, strided_cm) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + + +#if XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, k_eq_8) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .a_stride(11) + 
.Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, k_lt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, k_gt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, k_div_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + 
.n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, n_gt_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, n_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, n_div_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, n_div_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .a_stride(43) + 
.Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, qmin) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, qmax) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFDOT, strided_cm) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + + +#if XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, k_eq_8) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + 
.Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, k_lt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, k_gt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, k_div_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + 
.mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, n_gt_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, n_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, n_div_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, n_div_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for 
(uint32_t m = 1; m <= 1; m++) { + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, qmin) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, qmax) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_1X4C8__NEONBF16_BFMLAL, strided_cm) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(1) + .nr(4) + .kr(8) + .sr(1) + .m(1) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + + +#if XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, k_eq_8) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, k_lt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } 
+ } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, k_gt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, k_div_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, n_gt_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) 
+ .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, n_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, n_div_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, n_div_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 2; m++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, qmin) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, qmax) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .qmax(128) + 
.Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_2X4C8__NEONBF16_BFMLAL, strided_cm) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .kr(8) + .sr(1) + .m(2) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + + +#if XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, k_eq_8) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, k_lt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, 
xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, k_gt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, k_div_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, n_gt_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, n_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 
1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, n_div_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, n_div_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 3; m++) { + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, qmin) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, qmax) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_3X4C8__NEONBF16_BFMLAL, strided_cm) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(3) + .nr(4) + .kr(8) + .sr(1) + .m(3) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + + +#if XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, k_eq_8) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + 
.Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, k_lt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, k_gt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + 
TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, k_div_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, n_gt_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, n_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, n_div_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, n_div_4_strided_cn) { + 
TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, qmin) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, qmax) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_4X4C8__NEONBF16_BFMLAL, strided_cm) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .kr(8) + .sr(1) + .m(4) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + + +#if XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, k_eq_8) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, 
k_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, k_eq_8_subtile_m) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(4) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, k_eq_8_subtile_n) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(8) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, k_lt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, k_lt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .a_stride(11) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, k_lt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k < 8; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, k_gt_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .a_stride(19) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 9; k < 16; k++) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, k_div_8) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + 
TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(k) + .a_stride(83) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 16; k <= 80; k += 8) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, n_gt_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, n_gt_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, n_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, n_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 5; n < 8; n++) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, n_div_4) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, n_div_4_strided_cn) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .cn_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, n_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(n) + .k(k) + .a_stride(43) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + + 
TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, n_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (uint32_t n = 8; n <= 12; n += 4) { + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, strided_cm_subtile) { + TEST_REQUIRES_ARM_NEON_BF16; + for (size_t k = 1; k <= 40; k += 9) { + for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 5; m++) { + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(m) + .n(n) + .k(k) + .cm_stride(7) + .iterations(1) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + } + } + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, qmin) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .qmin(128) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, qmax) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .qmax(128) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } + + TEST(BF16_GEMM_MINMAX_5X4C8__NEONBF16_BFMLAL, strided_cm) { + TEST_REQUIRES_ARM_NEON_BF16; + GemmMicrokernelTester() + .mr(5) + .nr(4) + .kr(8) + .sr(1) + .m(5) + .n(4) + .k(8) + .cm_stride(7) + .Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params); + } +#endif // XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64) diff --git a/test/bf16-gemm-minmax.yaml b/test/bf16-gemm-minmax.yaml new file mode 100644 index 000000000000..5f515eebf653 --- /dev/null +++ b/test/bf16-gemm-minmax.yaml @@ -0,0 +1,92 @@ +# Copyright 2022 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# ARM NEON +- name: xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 +- name: xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 +- name: xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 +- name: xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 +- name: xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 + +- name: xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 + arch: + - aarch64 +- name: xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 + arch: + - aarch64 +- name: xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 + arch: + - aarch64 +- name: xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 + arch: + - aarch64 +- name: xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 + arch: + - aarch64 + +- name: xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128 + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 +- name: xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128 + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 +- name: xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128 + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 +- name: xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128 + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 + +- name: xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 +- name: xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 +- name: xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 +- name: xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 +- name: xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 + +- name: xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 +- name: xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 +- name: xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 +- name: xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 +- name: xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal + init: xnn_init_bf16_minmax_scalar_params + k-block: 8 diff --git a/test/gemm-microkernel-tester.cc b/test/gemm-microkernel-tester.cc index 5ce72a926e8d..87077d6b784c 100644 --- a/test/gemm-microkernel-tester.cc +++ b/test/gemm-microkernel-tester.cc @@ -714,6 +714,82 @@ void GemmMicrokernelTester::Test( } } +void GemmMicrokernelTester::Test(xnn_bf16_gemm_minmax_ukernel_function gemm_minmax, xnn_init_bf16_minmax_params_fn init_params) const +{ + ASSERT_LE(m(), mr()); + ASSERT_GE(a_stride(), k()); + ASSERT_GE(cm_stride(), n()); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto f32rng = 
std::bind(std::uniform_real_distribution<float>(0.5f, 1.0f), std::ref(rng)); + + std::vector<uint16_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint16_t)); + std::vector<uint16_t> b(n() * k()); + std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> packed_w(packed_n() * packed_k() + packed_n()); + std::vector<uint16_t> bias(n()); + std::vector<uint16_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + std::vector<float> c_ref(m() * n()); + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(a.begin(), a.end(), [&] { return fp32_to_bits(f32rng(rng)) >> 16; }); + std::generate(b.begin(), b.end(), [&] { return fp32_to_bits(f32rng(rng)) >> 16; }); + std::generate(bias.begin(), bias.end(), [&] { return fp32_to_bits(f32rng(rng)) >> 16; }); + std::fill(c.begin(), c.end(), UINT32_C(0x7FC0) /* NaN */); + std::fill(c_ref.begin(), c_ref.end(), 0.0f); + + std::fill(packed_w.begin(), packed_w.end(), 0); + xnn_pack_f16_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr); + + for (size_t m_index = 0; m_index < m(); m_index++) { + for (size_t n_index = 0; n_index < n(); n_index++) { + c_ref[m_index * n() + n_index] = fp32_from_bits(uint32_t(bias[n_index]) << 16); + for (size_t k_index = 0; k_index < k(); k_index++) { + ASSERT_LE(n(), packed_n()); + ASSERT_LT(m_index * n() + n_index, c_ref.size()); + ASSERT_LT(m_index * k() + k_index, a.size()); + c_ref[m_index * n() + n_index] += + fp32_from_bits(uint32_t(a[m_index * a_stride() + k_index]) << 16) * + fp32_from_bits(uint32_t(b[n_index * k() + k_index]) << 16); + } + } + } + + const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend()); + const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend()); + const float c_min = fp32_from_bits(fp32_to_bits(accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin())) & UINT32_C(0xFFFF0000)); + const float c_max = fp32_from_bits(fp32_to_bits(accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax())) & UINT32_C(0xFFFF0000)); + + // Prepare parameters. + xnn_bf16_minmax_params params; + init_params(&params, + fp32_to_bits(c_min) >> 16, + fp32_to_bits(c_max) >> 16); + + for (float& c_value : c_ref) { + c_value = std::max(std::min(c_value, c_max), c_min); + } + + gemm_minmax(m(), n(), k() * sizeof(uint16_t), + a.data(), a_stride() * sizeof(uint16_t), + packed_w.data(), + c.data(), cm_stride() * sizeof(uint16_t), cn_stride() * sizeof(uint16_t), + &params); + + // Validate micro-kernel outputs. 
+ for (size_t i = 0; i < m(); i++) { + for (size_t j = 0; j < n(); j++) { + ASSERT_NEAR( + fp32_from_bits(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << 16), + c_ref[i * n() + j], + std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 3.0e-2f)) + << "at " << i << ", " << j << ": Mr x Nr x Kr = " << mr() << " x " << nr() + << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); + } + } + } +} + void GemmMicrokernelTester::Test(xnn_f16_gemm_minmax_ukernel_function gemm_minmax, xnn_init_f16_minmax_params_fn init_params) const { ASSERT_LE(m(), mr()); diff --git a/test/gemm-microkernel-tester.h b/test/gemm-microkernel-tester.h index d81eb889e116..8b0c207ab8e4 100644 --- a/test/gemm-microkernel-tester.h +++ b/test/gemm-microkernel-tester.h @@ -228,6 +228,8 @@ class GemmMicrokernelTester { xnn_init_qs8_conv_minmax_params_fn init_params, xnn_qs8_requantize_fn requantize) const; + void Test(xnn_bf16_gemm_minmax_ukernel_function gemm_minmax, xnn_init_bf16_minmax_params_fn init_params) const; + void Test(xnn_f16_gemm_minmax_ukernel_function gemm_minmax, xnn_init_f16_minmax_params_fn init_params) const; void Test(xnn_f16_igemm_minmax_ukernel_function igemm_minmax, xnn_init_f16_minmax_params_fn init_params) const; diff --git a/third_party/cpuinfo.patch b/third_party/cpuinfo.patch deleted file mode 100644 index 6b671d2aab4b..000000000000 --- a/third_party/cpuinfo.patch +++ /dev/null @@ -1,595 +0,0 @@ -diff --git CMakeLists.txt CMakeLists.txt -index 06aee4d..6e42ab9 100644 ---- CMakeLists.txt -+++ CMakeLists.txt -@@ -1,6 +1,4 @@ --CMAKE_MINIMUM_REQUIRED(VERSION 2.8.12 FATAL_ERROR) -- --INCLUDE(GNUInstallDirs) -+CMAKE_MINIMUM_REQUIRED(VERSION 3.5 FATAL_ERROR) - - # ---[ Project and semantic versioning. - PROJECT(cpuinfo C CXX) -@@ -18,32 +16,22 @@ OPTION(CPUINFO_BUILD_MOCK_TESTS "Build cpuinfo mock tests" ON) - OPTION(CPUINFO_BUILD_BENCHMARKS "Build cpuinfo micro-benchmarks" ON) - - # ---[ CMake options -+INCLUDE(GNUInstallDirs) -+ - IF(CPUINFO_BUILD_UNIT_TESTS OR CPUINFO_BUILD_MOCK_TESTS) - ENABLE_TESTING() - ENDIF() - - MACRO(CPUINFO_TARGET_ENABLE_C99 target) -- IF(${CMAKE_VERSION} VERSION_LESS "3.1") -- IF(NOT MSVC) -- TARGET_COMPILE_OPTIONS(${target} PRIVATE -std=c99) -- ENDIF() -- ELSE() -- SET_TARGET_PROPERTIES(${target} PROPERTIES -- C_STANDARD 99 -- C_EXTENSIONS NO) -- ENDIF() -+ SET_TARGET_PROPERTIES(${target} PROPERTIES -+ C_STANDARD 99 -+ C_EXTENSIONS NO) - ENDMACRO() - - MACRO(CPUINFO_TARGET_ENABLE_CXX11 target) -- IF(${CMAKE_VERSION} VERSION_LESS "3.1") -- IF(NOT MSVC) -- TARGET_COMPILE_OPTIONS(${target} PRIVATE -std=c++11) -- ENDIF() -- ELSE() -- SET_TARGET_PROPERTIES(${target} PROPERTIES -- CXX_STANDARD 11 -- CXX_EXTENSIONS NO) -- ENDIF() -+ SET_TARGET_PROPERTIES(${target} PROPERTIES -+ CXX_STANDARD 11 -+ CXX_EXTENSIONS NO) - ENDMACRO() - - MACRO(CPUINFO_TARGET_RUNTIME_LIBRARY target) -diff --git include/cpuinfo.h include/cpuinfo.h -index e2e6564..cffa299 100644 ---- include/cpuinfo.h -+++ include/cpuinfo.h -@@ -361,6 +361,8 @@ enum cpuinfo_uarch { - cpuinfo_uarch_zen = 0x00200109, - /** AMD Zen 2 microarchitecture (7 nm Ryzen and EPYC CPUs). */ - cpuinfo_uarch_zen2 = 0x0020010A, -+ /** AMD Zen 3 microarchitecture. */ -+ cpuinfo_uarch_zen3 = 0x0020010B, - - /** NSC Geode and AMD Geode GX and LX. */ - cpuinfo_uarch_geode = 0x00200200, -@@ -425,6 +427,9 @@ enum cpuinfo_uarch { - /** ARM Neoverse E1. */ - cpuinfo_uarch_neoverse_e1 = 0x00300401, - -+ /** ARM Cortex-X1. */ -+ cpuinfo_uarch_cortex_x1 = 0x00300500, -+ - /** Qualcomm Scorpion. 
*/ - cpuinfo_uarch_scorpion = 0x00400100, - /** Qualcomm Krait. */ -@@ -1455,6 +1460,8 @@ static inline bool cpuinfo_has_x86_sha(void) { - #endif - #if CPUINFO_ARCH_ARM64 - bool atomics; -+ bool sve; -+ bool sve2; - #endif - bool rdm; - bool fp16arith; -@@ -1770,6 +1777,22 @@ static inline bool cpuinfo_has_arm_crc32(void) { - #endif - } - -+static inline bool cpuinfo_has_arm_sve(void) { -+ #if CPUINFO_ARCH_ARM64 -+ return cpuinfo_isa.sve; -+ #else -+ return false; -+ #endif -+} -+ -+static inline bool cpuinfo_has_arm_sve2(void) { -+ #if CPUINFO_ARCH_ARM64 -+ return cpuinfo_isa.sve2; -+ #else -+ return false; -+ #endif -+} -+ - const struct cpuinfo_processor* CPUINFO_ABI cpuinfo_get_processors(void); - const struct cpuinfo_core* CPUINFO_ABI cpuinfo_get_cores(void); - const struct cpuinfo_cluster* CPUINFO_ABI cpuinfo_get_clusters(void); -diff --git src/arm/linux/aarch32-isa.c src/arm/linux/aarch32-isa.c -index 41f9972..df68aa1 100644 ---- src/arm/linux/aarch32-isa.c -+++ src/arm/linux/aarch32-isa.c -@@ -56,24 +56,37 @@ void cpuinfo_arm_linux_decode_isa_from_proc_cpuinfo( - /* - * NEON FP16 compute extension and VQRDMLAH/VQRDMLSH instructions are not indicated in /proc/cpuinfo. - * Use a MIDR-based heuristic to whitelist processors known to support it: -- * - Processors with Qualcomm-modified Cortex-A55 cores -- * - Processors with Qualcomm-modified Cortex-A75 cores -- * - Processors with Qualcomm-modified Cortex-A76 cores -- * - Kirin 980 processor -+ * - Processors with Cortex-A55 cores -+ * - Processors with Cortex-A65 cores -+ * - Processors with Cortex-A75 cores -+ * - Processors with Cortex-A76 cores -+ * - Processors with Cortex-A77 cores -+ * - Processors with Exynos M4 cores -+ * - Processors with Exynos M5 cores -+ * - Neoverse N1 cores - */ -- switch (midr & (CPUINFO_ARM_MIDR_IMPLEMENTER_MASK | CPUINFO_ARM_MIDR_PART_MASK)) { -- case UINT32_C(0x51008020): /* Kryo 385 Gold (Cortex-A75) */ -- case UINT32_C(0x51008030): /* Kryo 385 Silver (Cortex-A55) */ -- case UINT32_C(0x51008040): /* Kryo 485 Gold (Cortex-A76) */ -- isa->fp16arith = true; -- isa->rdm = true; -- break; -- default: -- if (chipset->series == cpuinfo_arm_chipset_series_hisilicon_kirin && chipset->model == 980) { -+ if (chipset->series == cpuinfo_arm_chipset_series_samsung_exynos && chipset->model == 9810) { -+ /* Only little cores of Exynos 9810 support FP16 & RDM */ -+ cpuinfo_log_warning("FP16 arithmetics and RDM disabled: only little cores in Exynos 9810 support these extensions"); -+ } else { -+ switch (midr & (CPUINFO_ARM_MIDR_IMPLEMENTER_MASK | CPUINFO_ARM_MIDR_PART_MASK)) { -+ case UINT32_C(0x4100D050): /* Cortex-A55 */ -+ case UINT32_C(0x4100D060): /* Cortex-A65 */ -+ case UINT32_C(0x4100D0B0): /* Cortex-A76 */ -+ case UINT32_C(0x4100D0C0): /* Neoverse N1 */ -+ case UINT32_C(0x4100D0D0): /* Cortex-A77 */ -+ case UINT32_C(0x4100D0E0): /* Cortex-A76AE */ -+ case UINT32_C(0x4800D400): /* Cortex-A76 (HiSilicon) */ -+ case UINT32_C(0x51008020): /* Kryo 385 Gold (Cortex-A75) */ -+ case UINT32_C(0x51008030): /* Kryo 385 Silver (Cortex-A55) */ -+ case UINT32_C(0x51008040): /* Kryo 485 Gold (Cortex-A76) */ -+ case UINT32_C(0x51008050): /* Kryo 485 Silver (Cortex-A55) */ -+ case UINT32_C(0x53000030): /* Exynos M4 */ -+ case UINT32_C(0x53000040): /* Exynos M5 */ - isa->fp16arith = true; - isa->rdm = true; -- } -- break; -+ break; -+ } - } - - /* -diff --git src/arm/linux/aarch64-isa.c src/arm/linux/aarch64-isa.c -index 619cda5..2000e1a 100644 ---- src/arm/linux/aarch64-isa.c -+++ src/arm/linux/aarch64-isa.c -@@ -6,6 +6,7 
@@ - - void cpuinfo_arm64_linux_decode_isa_from_proc_cpuinfo( - uint32_t features, -+ uint32_t features2, - uint32_t midr, - const struct cpuinfo_arm_chipset chipset[restrict static 1], - struct cpuinfo_arm_isa isa[restrict static 1]) -@@ -28,43 +29,56 @@ void cpuinfo_arm64_linux_decode_isa_from_proc_cpuinfo( - if (features & CPUINFO_ARM_LINUX_FEATURE_ATOMICS) { - isa->atomics = true; - } -- const uint32_t fp16arith_mask = CPUINFO_ARM_LINUX_FEATURE_FPHP | CPUINFO_ARM_LINUX_FEATURE_ASIMDHP; -- if ((features & fp16arith_mask) == fp16arith_mask) { -- if (chipset->series == cpuinfo_arm_chipset_series_samsung_exynos && chipset->model == 9810) { -- /* Exynos 9810 reports that it supports FP16 compute, but in fact only little cores do */ -- cpuinfo_log_warning("FP16 arithmetics disabled: only little cores of Exynos 9810 support FP16 compute"); -- } else { -- isa->fp16arith = true; -- } -- } else if (features & CPUINFO_ARM_LINUX_FEATURE_FPHP) { -- cpuinfo_log_warning("FP16 arithmetics disabled: detected support only for scalar operations"); -- } else if (features & CPUINFO_ARM_LINUX_FEATURE_ASIMDHP) { -- cpuinfo_log_warning("FP16 arithmetics disabled: detected support only for SIMD operations"); -- } -+ - /* -- * Many phones ship with an old kernel configuration that doesn't report -- * SQRDMLAH/SQRDMLSH/UQRDMLAH/UQRDMLSH instructions. -+ * Some phones ship with an old kernel configuration that doesn't report NEON FP16 compute extension and SQRDMLAH/SQRDMLSH/UQRDMLAH/UQRDMLSH instructions. - * Use a MIDR-based heuristic to whitelist processors known to support it: -- * - Processors with Qualcomm-modified Cortex-A55 cores -- * - Processors with Qualcomm-modified Cortex-A75 cores -- * - Processors with Qualcomm-modified Cortex-A76 cores -- * - Kirin 980 processor -+ * - Processors with Cortex-A55 cores -+ * - Processors with Cortex-A65 cores -+ * - Processors with Cortex-A75 cores -+ * - Processors with Cortex-A76 cores -+ * - Processors with Cortex-A77 cores -+ * - Processors with Exynos M4 cores -+ * - Processors with Exynos M5 cores -+ * - Neoverse N1 cores - */ -- switch (midr & (CPUINFO_ARM_MIDR_IMPLEMENTER_MASK | CPUINFO_ARM_MIDR_PART_MASK)) { -- case UINT32_C(0x51008020): /* Kryo 385 Gold (Cortex-A75) */ -- case UINT32_C(0x51008030): /* Kryo 385 Silver (Cortex-A55) */ -- case UINT32_C(0x51008040): /* Kryo 485 Gold (Cortex-A76) */ -- isa->rdm = true; -- break; -- default: -- if (features & CPUINFO_ARM_LINUX_FEATURE_ASIMDRDM) { -- isa->rdm = true; -- } -- if (chipset->series == cpuinfo_arm_chipset_series_hisilicon_kirin && chipset->model == 980) { -+ if (chipset->series == cpuinfo_arm_chipset_series_samsung_exynos && chipset->model == 9810) { -+ /* Exynos 9810 reports that it supports FP16 compute, but in fact only little cores do */ -+ cpuinfo_log_warning("FP16 arithmetics and RDM disabled: only little cores in Exynos 9810 support these extensions"); -+ } else { -+ const uint32_t fp16arith_mask = CPUINFO_ARM_LINUX_FEATURE_FPHP | CPUINFO_ARM_LINUX_FEATURE_ASIMDHP; -+ switch (midr & (CPUINFO_ARM_MIDR_IMPLEMENTER_MASK | CPUINFO_ARM_MIDR_PART_MASK)) { -+ case UINT32_C(0x4100D050): /* Cortex-A55 */ -+ case UINT32_C(0x4100D060): /* Cortex-A65 */ -+ case UINT32_C(0x4100D0B0): /* Cortex-A76 */ -+ case UINT32_C(0x4100D0C0): /* Neoverse N1 */ -+ case UINT32_C(0x4100D0D0): /* Cortex-A77 */ -+ case UINT32_C(0x4100D0E0): /* Cortex-A76AE */ -+ case UINT32_C(0x4800D400): /* Cortex-A76 (HiSilicon) */ -+ case UINT32_C(0x51008020): /* Kryo 385 Gold (Cortex-A75) */ -+ case UINT32_C(0x51008030): /* Kryo 385 
Silver (Cortex-A55) */ -+ case UINT32_C(0x51008040): /* Kryo 485 Gold (Cortex-A76) */ -+ case UINT32_C(0x51008050): /* Kryo 485 Silver (Cortex-A55) */ -+ case UINT32_C(0x53000030): /* Exynos M4 */ -+ case UINT32_C(0x53000040): /* Exynos M5 */ -+ isa->fp16arith = true; - isa->rdm = true; -- } -- break; -+ break; -+ default: -+ if ((features & fp16arith_mask) == fp16arith_mask) { -+ isa->fp16arith = true; -+ } else if (features & CPUINFO_ARM_LINUX_FEATURE_FPHP) { -+ cpuinfo_log_warning("FP16 arithmetics disabled: detected support only for scalar operations"); -+ } else if (features & CPUINFO_ARM_LINUX_FEATURE_ASIMDHP) { -+ cpuinfo_log_warning("FP16 arithmetics disabled: detected support only for SIMD operations"); -+ } -+ if (features & CPUINFO_ARM_LINUX_FEATURE_ASIMDRDM) { -+ isa->rdm = true; -+ } -+ break; -+ } - } -+ - /* - * Many phones ship with an old kernel configuration that doesn't report UDOT/SDOT instructions. - * Use a MIDR-based heuristic to whitelist processors known to support it. -@@ -98,13 +112,16 @@ void cpuinfo_arm64_linux_decode_isa_from_proc_cpuinfo( - if (features & CPUINFO_ARM_LINUX_FEATURE_JSCVT) { - isa->jscvt = true; - } -- if (features & CPUINFO_ARM_LINUX_FEATURE_ASIMDRDM) { -- isa->rdm = true; -- } - if (features & CPUINFO_ARM_LINUX_FEATURE_JSCVT) { - isa->jscvt = true; - } - if (features & CPUINFO_ARM_LINUX_FEATURE_FCMA) { - isa->fcma = true; - } -+ if (features & CPUINFO_ARM_LINUX_FEATURE_SVE) { -+ isa->sve = true; -+ } -+ if (features2 & CPUINFO_ARM_LINUX_FEATURE2_SVE2) { -+ isa->sve2 = true; -+ } - } -diff --git src/arm/linux/api.h src/arm/linux/api.h -index 2597e49..1c09f82 100644 ---- src/arm/linux/api.h -+++ src/arm/linux/api.h -@@ -111,6 +111,28 @@ struct cpuinfo_arm_linux_proc_cpuinfo_cache { - #define CPUINFO_ARM_LINUX_FEATURE_ILRCPC UINT32_C(0x04000000) - #define CPUINFO_ARM_LINUX_FEATURE_FLAGM UINT32_C(0x08000000) - #define CPUINFO_ARM_LINUX_FEATURE_SSBS UINT32_C(0x10000000) -+ #define CPUINFO_ARM_LINUX_FEATURE_SB UINT32_C(0x20000000) -+ #define CPUINFO_ARM_LINUX_FEATURE_PACA UINT32_C(0x40000000) -+ #define CPUINFO_ARM_LINUX_FEATURE_PACG UINT32_C(0x80000000) -+ -+ #define CPUINFO_ARM_LINUX_FEATURE2_DCPODP UINT32_C(0x00000001) -+ #define CPUINFO_ARM_LINUX_FEATURE2_SVE2 UINT32_C(0x00000002) -+ #define CPUINFO_ARM_LINUX_FEATURE2_SVEAES UINT32_C(0x00000004) -+ #define CPUINFO_ARM_LINUX_FEATURE2_SVEPMULL UINT32_C(0x00000008) -+ #define CPUINFO_ARM_LINUX_FEATURE2_SVEBITPERM UINT32_C(0x00000010) -+ #define CPUINFO_ARM_LINUX_FEATURE2_SVESHA3 UINT32_C(0x00000020) -+ #define CPUINFO_ARM_LINUX_FEATURE2_SVESM4 UINT32_C(0x00000040) -+ #define CPUINFO_ARM_LINUX_FEATURE2_FLAGM2 UINT32_C(0x00000080) -+ #define CPUINFO_ARM_LINUX_FEATURE2_FRINT UINT32_C(0x00000100) -+ #define CPUINFO_ARM_LINUX_FEATURE2_SVEI8MM UINT32_C(0x00000200) -+ #define CPUINFO_ARM_LINUX_FEATURE2_SVEF32MM UINT32_C(0x00000400) -+ #define CPUINFO_ARM_LINUX_FEATURE2_SVEF64MM UINT32_C(0x00000800) -+ #define CPUINFO_ARM_LINUX_FEATURE2_SVEBF16 UINT32_C(0x00001000) -+ #define CPUINFO_ARM_LINUX_FEATURE2_I8MM UINT32_C(0x00002000) -+ #define CPUINFO_ARM_LINUX_FEATURE2_BF16 UINT32_C(0x00004000) -+ #define CPUINFO_ARM_LINUX_FEATURE2_DGH UINT32_C(0x00008000) -+ #define CPUINFO_ARM_LINUX_FEATURE2_RNG UINT32_C(0x00010000) -+ #define CPUINFO_ARM_LINUX_FEATURE2_BTI UINT32_C(0x00020000) - #endif - - #define CPUINFO_ARM_LINUX_VALID_ARCHITECTURE UINT32_C(0x00010000) -@@ -146,9 +168,7 @@ struct cpuinfo_arm_linux_processor { - struct cpuinfo_arm_linux_proc_cpuinfo_cache proc_cpuinfo_cache; - #endif - uint32_t features; 
--#if CPUINFO_ARCH_ARM - uint32_t features2; --#endif - /** - * Main ID Register value. - */ -@@ -282,9 +302,13 @@ CPUINFO_INTERNAL bool cpuinfo_arm_linux_parse_proc_cpuinfo( - const struct cpuinfo_arm_chipset chipset[restrict static 1], - struct cpuinfo_arm_isa isa[restrict static 1]); - #elif CPUINFO_ARCH_ARM64 -- CPUINFO_INTERNAL uint32_t cpuinfo_arm_linux_hwcap_from_getauxval(void); -+ CPUINFO_INTERNAL void cpuinfo_arm_linux_hwcap_from_getauxval( -+ uint32_t hwcap[restrict static 1], -+ uint32_t hwcap2[restrict static 1]); -+ - CPUINFO_INTERNAL void cpuinfo_arm64_linux_decode_isa_from_proc_cpuinfo( - uint32_t features, -+ uint32_t features2, - uint32_t midr, - const struct cpuinfo_arm_chipset chipset[restrict static 1], - struct cpuinfo_arm_isa isa[restrict static 1]); -diff --git src/arm/linux/hwcap.c src/arm/linux/hwcap.c -index 36d0d91..35e9994 100644 ---- src/arm/linux/hwcap.c -+++ src/arm/linux/hwcap.c -@@ -29,12 +29,10 @@ - mock_hwcap = hwcap; - } - -- #if CPUINFO_ARCH_ARM -- static uint32_t mock_hwcap2 = 0; -- void cpuinfo_set_hwcap2(uint32_t hwcap2) { -- mock_hwcap2 = hwcap2; -- } -- #endif -+ static uint32_t mock_hwcap2 = 0; -+ void cpuinfo_set_hwcap2(uint32_t hwcap2) { -+ mock_hwcap2 = hwcap2; -+ } - #endif - - -@@ -145,11 +143,17 @@ - } - #endif /* __ANDROID__ */ - #elif CPUINFO_ARCH_ARM64 -- uint32_t cpuinfo_arm_linux_hwcap_from_getauxval(void) { -+ void cpuinfo_arm_linux_hwcap_from_getauxval( -+ uint32_t hwcap[restrict static 1], -+ uint32_t hwcap2[restrict static 1]) -+ { - #if CPUINFO_MOCK -- return mock_hwcap; -+ *hwcap = mock_hwcap; -+ *hwcap2 = mock_hwcap2; - #else -- return (uint32_t) getauxval(AT_HWCAP); -+ *hwcap = (uint32_t) getauxval(AT_HWCAP); -+ *hwcap2 = (uint32_t) getauxval(AT_HWCAP2); -+ return ; - #endif - } - #endif -diff --git src/arm/linux/init.c src/arm/linux/init.c -index 89d957e..23d8439 100644 ---- src/arm/linux/init.c -+++ src/arm/linux/init.c -@@ -277,10 +277,11 @@ void cpuinfo_arm_linux_init(void) { - last_midr, last_architecture_version, last_architecture_flags, - &chipset, &cpuinfo_isa); - #elif CPUINFO_ARCH_ARM64 -+ uint32_t isa_features = 0, isa_features2 = 0; - /* getauxval is always available on ARM64 Android */ -- const uint32_t isa_features = cpuinfo_arm_linux_hwcap_from_getauxval(); -+ cpuinfo_arm_linux_hwcap_from_getauxval(&isa_features, &isa_features2); - cpuinfo_arm64_linux_decode_isa_from_proc_cpuinfo( -- isa_features, last_midr, &chipset, &cpuinfo_isa); -+ isa_features, isa_features2, last_midr, &chipset, &cpuinfo_isa); - #endif - - /* Detect min/max frequency and package ID */ -diff --git src/arm/mach/init.c src/arm/mach/init.c -index d820744..dbea578 100644 ---- src/arm/mach/init.c -+++ src/arm/mach/init.c -@@ -24,7 +24,6 @@ - #ifndef CPUFAMILY_ARM_LIGHTNING_THUNDER - #define CPUFAMILY_ARM_LIGHTNING_THUNDER 0x462504D2 - #endif -- - #ifndef CPUFAMILY_ARM_FIRESTORM_ICESTORM - #define CPUFAMILY_ARM_FIRESTORM_ICESTORM 0x1B588BB3 - #endif -@@ -349,6 +348,7 @@ void cpuinfo_arm_mach_init(void) { - case CPUFAMILY_ARM_MONSOON_MISTRAL: - case CPUFAMILY_ARM_VORTEX_TEMPEST: - case CPUFAMILY_ARM_LIGHTNING_THUNDER: -+ case CPUFAMILY_ARM_FIRESTORM_ICESTORM: - #if CPUINFO_ARCH_ARM64 - cpuinfo_isa.atomics = true; - #endif -@@ -360,8 +360,10 @@ void cpuinfo_arm_mach_init(void) { - * ARMv8.2 optional dot-product instructions, so we currently whitelist CPUs - * known to support these instruction. 
- */ -- if (cpu_family == CPUFAMILY_ARM_LIGHTNING_THUNDER) { -- cpuinfo_isa.dot = true; -+ switch (cpu_family) { -+ case CPUFAMILY_ARM_LIGHTNING_THUNDER: -+ case CPUFAMILY_ARM_FIRESTORM_ICESTORM: -+ cpuinfo_isa.dot = true; - } - - uint32_t num_clusters = 1; -diff --git src/arm/midr.h src/arm/midr.h -index 2638517..739dc19 100644 ---- src/arm/midr.h -+++ src/arm/midr.h -@@ -171,9 +171,10 @@ inline static bool midr_is_kryo_gold(uint32_t midr) { - inline static uint32_t midr_score_core(uint32_t midr) { - const uint32_t core_mask = CPUINFO_ARM_MIDR_IMPLEMENTER_MASK | CPUINFO_ARM_MIDR_PART_MASK; - switch (midr & core_mask) { -- case UINT32_C(0x53000040): /* Exynos M5 */ - case UINT32_C(0x53000030): /* Exynos M4 */ -- /* These cores are in big role w.r.t Cortex-A75 or Cortex-A76 */ -+ case UINT32_C(0x53000040): /* Exynos M5 */ -+ case UINT32_C(0x4100D440): /* Cortex-X1 */ -+ /* These cores are in big role w.r.t Cortex-A75/-A76/-A77/-A78 */ - return 6; - case UINT32_C(0x4E000030): /* Denver 2 */ - case UINT32_C(0x53000010): /* Exynos M1 and Exynos M2 */ -diff --git src/arm/uarch.c src/arm/uarch.c -index 0d7a7d7..8b5362b 100644 ---- src/arm/uarch.c -+++ src/arm/uarch.c -@@ -94,6 +94,9 @@ void cpuinfo_arm_decode_vendor_uarch( - case 0xD41: /* Cortex-A78 */ - *uarch = cpuinfo_uarch_cortex_a78; - break; -+ case 0xD44: /* Cortex-X1 */ -+ *uarch = cpuinfo_uarch_cortex_x1; -+ break; - #if CPUINFO_ARCH_ARM64 && !defined(__ANDROID__) - case 0xD4A: - *uarch = cpuinfo_uarch_neoverse_e1; -diff --git src/init.c src/init.c -index f703e8e..d61e7be 100644 ---- src/init.c -+++ src/init.c -@@ -35,8 +35,6 @@ bool CPUINFO_ABI cpuinfo_initialize(void) { - #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 - #if defined(__linux__) - pthread_once(&init_guard, &cpuinfo_arm_linux_init); -- #elif defined(TARGET_OS_IPHONE) && TARGET_OS_IPHONE -- pthread_once(&init_guard, &cpuinfo_arm_mach_init); - #elif defined(__MACH__) && defined(__APPLE__) - pthread_once(&init_guard, &cpuinfo_arm_mach_init); - #else -diff --git src/x86/uarch.c src/x86/uarch.c -index ecaa762..3705499 100644 ---- src/x86/uarch.c -+++ src/x86/uarch.c -@@ -209,9 +209,23 @@ enum cpuinfo_uarch cpuinfo_x86_decode_uarch( - return cpuinfo_uarch_zen; - case 0x31: // Rome, Castle Peak - case 0x60: // Renoir -+ case 0x68: // Lucienne - case 0x71: // Matisse -+ case 0x90: // Van Gogh -+ case 0x98: // Mero - return cpuinfo_uarch_zen2; - } -+ break; -+ case 0x19: -+ switch (model_info->model) { -+ case 0x01: // Genesis -+ case 0x21: // Vermeer -+ case 0x30: // Badami, Trento -+ case 0x40: // Rembrandt -+ case 0x50: // Cezanne -+ return cpuinfo_uarch_zen3; -+ } -+ break; - } - break; - case cpuinfo_vendor_hygon: -diff --git src/x86/windows/init.c src/x86/windows/init.c -index 9a23bd7..274075c 100644 ---- src/x86/windows/init.c -+++ src/x86/windows/init.c -@@ -95,6 +95,15 @@ static void cpuinfo_x86_count_caches( - *l4_count_ptr = l4_count; - } - -+static bool cpuinfo_x86_windows_is_wine(void) { -+ HMODULE ntdll = GetModuleHandleW(L"ntdll.dll"); -+ if (ntdll == NULL) { -+ return false; -+ } -+ -+ return GetProcAddress(ntdll, "wine_get_version") != NULL; -+} -+ - BOOL CALLBACK cpuinfo_x86_windows_init(PINIT_ONCE init_once, PVOID parameter, PVOID* context) { - struct cpuinfo_processor* processors = NULL; - struct cpuinfo_core* cores = NULL; -@@ -108,6 +117,7 @@ BOOL CALLBACK cpuinfo_x86_windows_init(PINIT_ONCE init_once, PVOID parameter, PV - PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX processor_infos = NULL; - - HANDLE heap = GetProcessHeap(); -+ const bool is_wine = 
cpuinfo_x86_windows_is_wine(); - - struct cpuinfo_x86_processor x86_processor; - ZeroMemory(&x86_processor, sizeof(x86_processor)); -@@ -121,7 +131,8 @@ BOOL CALLBACK cpuinfo_x86_windows_init(PINIT_ONCE init_once, PVOID parameter, PV - x86_processor.topology.thread_bits_offset + x86_processor.topology.thread_bits_length, - x86_processor.topology.core_bits_offset + x86_processor.topology.core_bits_length); - -- const uint32_t max_group_count = (uint32_t) GetMaximumProcessorGroupCount(); -+ /* WINE doesn't implement GetMaximumProcessorGroupCount and aborts when calling it */ -+ const uint32_t max_group_count = is_wine ? 1 : (uint32_t) GetMaximumProcessorGroupCount(); - cpuinfo_log_debug("detected %"PRIu32" processor groups", max_group_count); - - uint32_t processors_count = 0; -diff --git test/mock/galaxy-s9-us.cc test/mock/galaxy-s9-us.cc -index ceea969..91c4868 100644 ---- test/mock/galaxy-s9-us.cc -+++ test/mock/galaxy-s9-us.cc -@@ -817,4 +817,4 @@ int main(int argc, char* argv[]) { - cpuinfo_initialize(); - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); --} -+} -\ No newline at end of file -diff --git tools/cpu-info.c tools/cpu-info.c -index 55d654f..30ec633 100644 ---- tools/cpu-info.c -+++ tools/cpu-info.c -@@ -129,6 +129,8 @@ static const char* uarch_to_string(enum cpuinfo_uarch uarch) { - return "Zen"; - case cpuinfo_uarch_zen2: - return "Zen 2"; -+ case cpuinfo_uarch_zen3: -+ return "Zen 3"; - case cpuinfo_uarch_geode: - return "Geode"; - case cpuinfo_uarch_bobcat: -@@ -185,6 +187,8 @@ static const char* uarch_to_string(enum cpuinfo_uarch uarch) { - return "Cortex-A77"; - case cpuinfo_uarch_cortex_a78: - return "Cortex-A78"; -+ case cpuinfo_uarch_cortex_x1: -+ return "Cortex-X1"; - case cpuinfo_uarch_scorpion: - return "Scorpion"; - case cpuinfo_uarch_krait: -diff --git tools/isa-info.c tools/isa-info.c -index 8365846..92abb57 100644 ---- tools/isa-info.c -+++ tools/isa-info.c -@@ -161,6 +161,10 @@ int main(int argc, char** argv) { - printf("\tARM v8.3 JS conversion: %s\n", cpuinfo_has_arm_jscvt() ? "yes" : "no"); - printf("\tARM v8.3 complex: %s\n", cpuinfo_has_arm_fcma() ? "yes" : "no"); - -+ printf("SIMD extensions:\n"); -+ printf("\tARM SVE: %s\n", cpuinfo_has_arm_sve() ? "yes" : "no"); -+ printf("\tARM SVE 2: %s\n", cpuinfo_has_arm_sve2() ? "yes" : "no"); -+ - printf("Cryptography extensions:\n"); - printf("\tAES: %s\n", cpuinfo_has_arm_aes() ? "yes" : "no"); - printf("\tSHA1: %s\n", cpuinfo_has_arm_sha1() ? "yes" : "no"); diff --git a/tools/xnncommon.py b/tools/xnncommon.py index 99a6c2b48380..d5ea7239d605 100644 --- a/tools/xnncommon.py +++ b/tools/xnncommon.py @@ -34,6 +34,7 @@ def _remove_duplicate_newlines(text): # status for the ISA. Only ISAs that can be enabled/disabled have an entry. _ISA_TO_MACRO_MAP = { "neonfp16arith": "XNN_ENABLE_ARM_FP16", + "neonbf16": "XNN_ENABLE_ARM_BF16", "neondot": "XNN_ENABLE_ARM_DOTPROD", } @@ -44,6 +45,7 @@ def _remove_duplicate_newlines(text): "neonfma": ["aarch32", "aarch64"], "neonv8": ["aarch32", "aarch64"], "neonfp16arith": ["aarch32", "aarch64"], + "neonbf16": ["aarch32", "aarch64"], "neondot": ["aarch32", "aarch64"], "sse": ["x86-32", "x86-64"], "sse2": ["x86-32", "x86-64"], @@ -69,6 +71,7 @@ def _remove_duplicate_newlines(text): "neonfma": "TEST_REQUIRES_ARM_NEON_FMA", "neonv8": "TEST_REQUIRES_ARM_NEON_V8", "neonfp16arith": "TEST_REQUIRES_ARM_NEON_FP16_ARITH", + "neonbf16": "TEST_REQUIRES_ARM_NEON_BF16", "neondot": "TEST_REQUIRES_ARM_NEON_DOT", "sse": "TEST_REQUIRES_X86_SSE", "sse2": "TEST_REQUIRES_X86_SSE2",