BF16 GEMM microkernels
PiperOrigin-RevId: 465928283
Maratyszcza authored and xnnpack-bot committed Aug 8, 2022
1 parent 22fa3a3 commit fcedd82
Showing 43 changed files with 18,825 additions and 46 deletions.
BUILD.bazel: 180 additions, 0 deletions
@@ -4253,6 +4253,11 @@ PROD_NEONFMA_MICROKERNEL_SRCS = [
]

ALL_NEONFMA_MICROKERNEL_SRCS = [
"src/bf16-gemm/gen/1x4c8-minmax-neonfma-shland.c",
"src/bf16-gemm/gen/2x4c8-minmax-neonfma-shland.c",
"src/bf16-gemm/gen/3x4c8-minmax-neonfma-shland.c",
"src/bf16-gemm/gen/4x4c8-minmax-neonfma-shland.c",
"src/bf16-gemm/gen/5x4c8-minmax-neonfma-shland.c",
"src/f32-dwconv/gen/up4x3-minmax-neonfma-acc2.c",
"src/f32-dwconv/gen/up4x3-minmax-neonfma.c",
"src/f32-dwconv/gen/up4x4-minmax-neonfma-acc2.c",
@@ -4499,6 +4504,11 @@ PROD_AARCH64_NEON_MICROKERNEL_SRCS = [
]

ALL_AARCH64_NEON_MICROKERNEL_SRCS = [
"src/bf16-gemm/gen/1x4c8-minmax-neonfma-zip.c",
"src/bf16-gemm/gen/2x4c8-minmax-neonfma-zip.c",
"src/bf16-gemm/gen/3x4c8-minmax-neonfma-zip.c",
"src/bf16-gemm/gen/4x4c8-minmax-neonfma-zip.c",
"src/bf16-gemm/gen/5x4c8-minmax-neonfma-zip.c",
"src/f32-conv-hwc/gen/3x3s2p0p1c3x4-neonfma-2x1.c",
"src/f32-conv-hwc/gen/3x3s2p0p1c3x4-neonfma-2x2.c",
"src/f32-conv-hwc/gen/3x3s2p0p1c3x8-neonfma-2x1.c",
@@ -5180,6 +5190,32 @@ ALL_AARCH64_NEONFP16ARITH_MICROKERNEL_SRCS = [
"src/math/sigmoid-f16-neonfp16arith-rr2-p3-div.c",
]

PROD_NEONBF16_MICROKERNEL_SRCS = [
]

ALL_NEONBF16_MICROKERNEL_SRCS = [
"src/bf16-gemm/gen/1x8c2-minmax-neonbf16-bfdot-lane-ld128.c",
"src/bf16-gemm/gen/4x8c2-minmax-neonbf16-bfdot-lane-ld128.c",
"src/bf16-gemm/gen/5x8c2-minmax-neonbf16-bfdot-lane-ld128.c",
"src/bf16-gemm/gen/6x8c2-minmax-neonbf16-bfdot-lane-ld128.c",
"src/bf16-gemm/gen/1x4c8-minmax-neonbf16-bfdot.c",
"src/bf16-gemm/gen/2x4c8-minmax-neonbf16-bfdot.c",
"src/bf16-gemm/gen/3x4c8-minmax-neonbf16-bfdot.c",
"src/bf16-gemm/gen/4x4c8-minmax-neonbf16-bfdot.c",
"src/bf16-gemm/gen/5x4c8-minmax-neonbf16-bfdot.c",
"src/bf16-gemm/gen/1x4c8-minmax-neonbf16-bfmlal.c",
"src/bf16-gemm/gen/2x4c8-minmax-neonbf16-bfmlal.c",
"src/bf16-gemm/gen/3x4c8-minmax-neonbf16-bfmlal.c",
"src/bf16-gemm/gen/4x4c8-minmax-neonbf16-bfmlal.c",
"src/bf16-gemm/gen/5x4c8-minmax-neonbf16-bfmlal.c",
]

PROD_AARCH64_NEONBF16_MICROKERNEL_SRCS = [
]

ALL_AARCH64_NEONBF16_MICROKERNEL_SRCS = [
]

PROD_NEONDOT_MICROKERNEL_SRCS = [
"src/qc8-gemm/gen/1x8c4-minmax-fp32-neondot.c",
"src/qc8-gemm/gen/1x16c4-minmax-fp32-neondot.c",
@@ -9336,6 +9372,76 @@ xnnpack_cc_library(
],
)

xnnpack_cc_library(
name = "neonbf16_bench_microkernels",
aarch32_copts = [
"-marm",
"-march=armv8.2-a+bf16",
"-mfpu=neon-fp-armv8",
],
aarch32_srcs = ALL_NEONBF16_MICROKERNEL_SRCS,
aarch64_copts = ["-march=armv8.2-a+bf16"],
aarch64_srcs = ALL_NEONBF16_MICROKERNEL_SRCS + ALL_AARCH64_NEONBF16_MICROKERNEL_SRCS,
gcc_copts = xnnpack_gcc_std_copts(),
msvc_copts = xnnpack_msvc_std_copts(),
deps = [
":common",
":math",
":microkernels_h",
":params",
":tables",
":unaligned",
],
)

xnnpack_cc_library(
name = "neonbf16_prod_microkernels",
aarch32_copts = [
"-marm",
"-march=armv8.2-a+bf16",
"-mfpu=neon-fp-armv8",
],
aarch32_srcs = PROD_NEONBF16_MICROKERNEL_SRCS,
aarch64_copts = ["-march=armv8.2-a+bf16"],
aarch64_srcs = PROD_NEONBF16_MICROKERNEL_SRCS + PROD_AARCH64_NEONBF16_MICROKERNEL_SRCS,
gcc_copts = xnnpack_gcc_std_copts(),
msvc_copts = xnnpack_msvc_std_copts(),
deps = [
":common",
":math",
":microkernels_h",
":params",
":tables",
":unaligned",
],
)

xnnpack_cc_library(
name = "neonbf16_test_microkernels",
aarch32_copts = [
"-marm",
"-march=armv8.2-a+bf16",
"-mfpu=neon-fp-armv8",
],
aarch32_srcs = ALL_NEONBF16_MICROKERNEL_SRCS,
aarch64_copts = ["-march=armv8.2-a+bf16"],
aarch64_srcs = ALL_NEONBF16_MICROKERNEL_SRCS + ALL_AARCH64_NEONBF16_MICROKERNEL_SRCS,
copts = [
"-UNDEBUG",
"-DXNN_TEST_MODE=1",
],
gcc_copts = xnnpack_gcc_std_copts(),
msvc_copts = xnnpack_msvc_std_copts(),
deps = [
":common",
":math",
":microkernels_h",
":params",
":tables",
":unaligned",
],
)

xnnpack_cc_library(
name = "neondot_bench_microkernels",
aarch32_copts = [
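The three `neonbf16_*` rules build the same source list at bench, prod, and test strictness with BF16-capable target flags: `-march=armv8.2-a+bf16` on both architectures, plus `-marm -mfpu=neon-fp-armv8` on AArch32. As a hedged illustration of what the `+bf16` extension exposes to these sources, here is a minimal use of the standard ACLE BF16 vector intrinsics; the guard macro and intrinsic are ACLE-defined, the helper function name is made up for the example and is not an XNNPACK symbol, and the file only compiles with the same `+bf16` flags shown above.

```c
#include <arm_neon.h>

#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
/* BFDOT: each FP32 lane of the accumulator gains the dot product of the
 * corresponding pair of BF16 elements from a and b, i.e.
 *   acc[i] += a[2*i] * b[2*i] + a[2*i+1] * b[2*i+1].
 * This is the primitive behind the *-bfdot kernel variants; the *-bfmlal
 * variants instead widen and multiply-accumulate the even/odd BF16 halves
 * with BFMLALB/BFMLALT. */
float32x4_t bf16_dot_step(float32x4_t acc, bfloat16x8_t a, bfloat16x8_t b) {
  return vbfdotq_f32(acc, a, b);
}
#endif
```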
@@ -10381,6 +10487,9 @@ xnnpack_aggregate_library(
defines = select({
":arm_fp16_enabled": ["XNN_ENABLE_ARM_FP16=1"],
"//conditions:default": ["XNN_ENABLE_ARM_FP16=0"],
}) + select({
":arm_bf16_enabled": ["XNN_ENABLE_ARM_BF16=1"],
"//conditions:default": ["XNN_ENABLE_ARM_BF16=0"],
}) + select({
":arm_dotprod_enabled": ["XNN_ENABLE_ARM_DOTPROD=1"],
"//conditions:default": ["XNN_ENABLE_ARM_DOTPROD=0"],
@@ -10390,6 +10499,9 @@ xnnpack_aggregate_library(
] + select({
":arm_fp16_enabled": [":neonfp16arith_prod_microkernels"],
"//conditions:default": [],
}) + select({
":arm_bf16_enabled": [":neonbf16_prod_microkernels"],
"//conditions:default": [],
}) + select({
":arm_dotprod_enabled": [":neondot_prod_microkernels"],
"//conditions:default": [],
@@ -10440,6 +10552,9 @@ xnnpack_aggregate_library(
defines = select({
":arm_fp16_enabled": ["XNN_ENABLE_ARM_FP16=1"],
"//conditions:default": ["XNN_ENABLE_ARM_FP16=0"],
}) + select({
":arm_bf16_enabled": ["XNN_ENABLE_ARM_BF16=1"],
"//conditions:default": ["XNN_ENABLE_ARM_BF16=0"],
}) + select({
":arm_dotprod_enabled": ["XNN_ENABLE_ARM_DOTPROD=1"],
"//conditions:default": ["XNN_ENABLE_ARM_DOTPROD=0"],
@@ -10449,6 +10564,9 @@ xnnpack_aggregate_library(
] + select({
":arm_fp16_enabled": [":neonfp16arith_bench_microkernels"],
"//conditions:default": [],
}) + select({
":arm_bf16_enabled": [":neonbf16_bench_microkernels"],
"//conditions:default": [],
}) + select({
":arm_dotprod_enabled": [":neondot_bench_microkernels"],
"//conditions:default": [],
@@ -10499,6 +10617,9 @@ xnnpack_aggregate_library(
defines = select({
":arm_fp16_enabled": ["XNN_ENABLE_ARM_FP16=1"],
"//conditions:default": ["XNN_ENABLE_ARM_FP16=0"],
}) + select({
":arm_bf16_enabled": ["XNN_ENABLE_ARM_BF16=1"],
"//conditions:default": ["XNN_ENABLE_ARM_BF16=0"],
}) + select({
":arm_dotprod_enabled": ["XNN_ENABLE_ARM_DOTPROD=1"],
"//conditions:default": ["XNN_ENABLE_ARM_DOTPROD=0"],
@@ -10508,6 +10629,9 @@ xnnpack_aggregate_library(
] + select({
":arm_fp16_enabled": [":neonfp16arith_prod_microkernels"],
"//conditions:default": [],
}) + select({
":arm_bf16_enabled": [":neonbf16_prod_microkernels"],
"//conditions:default": [],
}) + select({
":arm_dotprod_enabled": [":neondot_prod_microkernels"],
"//conditions:default": [],
@@ -10558,6 +10682,9 @@ xnnpack_aggregate_library(
defines = select({
":arm_fp16_enabled": ["XNN_ENABLE_ARM_FP16=1"],
"//conditions:default": ["XNN_ENABLE_ARM_FP16=0"],
}) + select({
":arm_bf16_enabled": ["XNN_ENABLE_ARM_BF16=1"],
"//conditions:default": ["XNN_ENABLE_ARM_BF16=0"],
}) + select({
":arm_dotprod_enabled": ["XNN_ENABLE_ARM_DOTPROD=1"],
"//conditions:default": ["XNN_ENABLE_ARM_DOTPROD=0"],
@@ -10567,6 +10694,9 @@ xnnpack_aggregate_library(
] + select({
":arm_fp16_enabled": [":neonfp16arith_test_microkernels"],
"//conditions:default": [],
}) + select({
":arm_bf16_enabled": [":neonbf16_test_microkernels"],
"//conditions:default": [],
}) + select({
":arm_dotprod_enabled": [":neondot_test_microkernels"],
"//conditions:default": [],
@@ -11378,6 +11508,18 @@ xnnpack_benchmark(
deps = MICROKERNEL_BENCHMARK_DEPS,
)

xnnpack_benchmark(
name = "bf16_gemm_bench",
srcs = [
"bench/bf16-gemm.cc",
"bench/gemm.h",
],
deps = MICROKERNEL_BENCHMARK_DEPS + [
":math",
":packing",
],
)

xnnpack_benchmark(
name = "f16_igemm_bench",
srcs = [
@@ -12520,6 +12662,16 @@ xnnpack_cc_library(
],
)

xnnpack_unit_test(
name = "bf16_gemm_minmax_test",
srcs = [
"test/bf16-gemm-minmax.cc",
],
deps = MICROKERNEL_TEST_DEPS + [
":gemm_microkernel_tester",
],
)

xnnpack_unit_test(
name = "f16_f32_vcvt_test",
srcs = [
@@ -15574,6 +15726,18 @@ config_setting(
define_values = {"xnn_enable_arm_fp16": "false"},
)

# Enables usage of ARM BF16 (BF16 arithmetics) kernels.
config_setting(
name = "xnn_enable_arm_bf16_explicit_true",
define_values = {"xnn_enable_arm_bf16": "true"},
)

# Disables usage of ARM BF16 (BF16 arithmetics) kernels.
config_setting(
name = "xnn_enable_arm_bf16_explicit_false",
define_values = {"xnn_enable_arm_bf16": "false"},
)

# Enables usage of ARM DotProd (integer dot product) kernels.
config_setting(
name = "xnn_enable_arm_dotprod_explicit_true",
@@ -16011,6 +16175,22 @@ alias(
}),
)

selects.config_setting_group(
name = "arm_bf16_enabled_by_default",
match_any = [
":aarch64",
],
)

alias(
name = "arm_bf16_enabled",
actual = select({
":xnn_enable_arm_bf16_explicit_true": ":xnn_enable_arm_bf16_explicit_true",
":xnn_enable_arm_bf16_explicit_false": ":xnn_enable_arm_bf16_explicit_true",
"//conditions:default": ":arm_bf16_enabled_by_default",
}),
)

selects.config_setting_group(
name = "arm_dotprod_enabled_by_default",
match_any = [