diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp
index 5ea08643d0f8..7529c31688f4 100644
--- a/src/CodeGen_X86.cpp
+++ b/src/CodeGen_X86.cpp
@@ -189,6 +189,14 @@ const x86Intrinsic intrinsic_defs[] = {
     {"llvm.x86.avx2.pmadd.ub.sw", Int(16, 16), "saturating_dot_product", {UInt(8, 32), Int(8, 32)}, Target::AVX2},
     {"llvm.x86.ssse3.pmadd.ub.sw.128", Int(16, 8), "saturating_dot_product", {UInt(8, 16), Int(8, 16)}, Target::SSE41},
 
+    // Horizontal widening adds using 2-way dot products.
+    {"hadd_pmadd_u8_sse3", UInt(16, 8), "horizontal_widening_add", {UInt(8, 16)}, Target::SSE41},
+    {"hadd_pmadd_u8_sse3", Int(16, 8), "horizontal_widening_add", {UInt(8, 16)}, Target::SSE41},
+    {"hadd_pmadd_i8_sse3", Int(16, 8), "horizontal_widening_add", {Int(8, 16)}, Target::SSE41},
+    {"hadd_pmadd_u8_avx2", UInt(16, 16), "horizontal_widening_add", {UInt(8, 32)}, Target::AVX2},
+    {"hadd_pmadd_u8_avx2", Int(16, 16), "horizontal_widening_add", {UInt(8, 32)}, Target::AVX2},
+    {"hadd_pmadd_i8_avx2", Int(16, 16), "horizontal_widening_add", {Int(8, 32)}, Target::AVX2},
+
     {"llvm.x86.avx512.pmaddw.d.512", Int(32, 16), "dot_product", {Int(16, 32), Int(16, 32)}, Target::AVX512_Skylake},
     {"llvm.x86.avx512.pmaddw.d.512", Int(32, 16), "dot_product", {Int(16, 32), Int(16, 32)}, Target::AVX512_Cannonlake},
     {"llvm.x86.avx2.pmadd.wd", Int(32, 8), "dot_product", {Int(16, 16), Int(16, 16)}, Target::AVX2},
@@ -595,6 +603,7 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init
         enum {
             CombineInit = 1 << 0,
             SwapOperands = 1 << 1,
+            SingleArg = 1 << 2,
         };
     };
     // clang-format off
@@ -624,8 +633,12 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init
         {VectorReduce::Add, 2, wild_f32x_ * wild_f32x_, "dot_product", BFloat(16), Pattern::CombineInit},
 
         // One could do a horizontal widening addition with
-        // dot_product against a vector of ones. Currently disabled
-        // because I haven't found case where it's clearly better.
+        // other dot_products against a vector of ones. Currently disabled
+        // because I haven't found other cases where it's clearly better.
+
+        {VectorReduce::Add, 2, u16(wild_u8x_), "horizontal_widening_add", {}, Pattern::SingleArg},
+        {VectorReduce::Add, 2, i16(wild_u8x_), "horizontal_widening_add", {}, Pattern::SingleArg},
+        {VectorReduce::Add, 2, i16(wild_i8x_), "horizontal_widening_add", {}, Pattern::SingleArg},
     };
     // clang-format on
 
@@ -635,33 +648,61 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init
             continue;
         }
         if (expr_match(p.pattern, op->value, matches)) {
-            Expr a = matches[0];
-            Expr b = matches[1];
-            if (p.flags & Pattern::SwapOperands) {
-                std::swap(a, b);
-            }
-            if (p.narrow_type.bits() > 0) {
-                a = lossless_cast(p.narrow_type.with_lanes(a.type().lanes()), a);
-                b = lossless_cast(p.narrow_type.with_lanes(b.type().lanes()), b);
-            }
-            if (!a.defined() || !b.defined()) {
-                continue;
-            }
+            if (p.flags & Pattern::SingleArg) {
+                Expr a = matches[0];
 
-            if (init.defined() && (p.flags & Pattern::CombineInit)) {
-                value = call_overloaded_intrin(op->type, p.intrin, {init, a, b});
-                if (value) {
-                    return;
+                if (p.narrow_type.bits() > 0) {
+                    a = lossless_cast(p.narrow_type.with_lanes(a.type().lanes()), a);
+                }
+                if (!a.defined()) {
+                    continue;
+                }
+
+                if (init.defined() && (p.flags & Pattern::CombineInit)) {
+                    value = call_overloaded_intrin(op->type, p.intrin, {init, a});
+                    if (value) {
+                        return;
+                    }
+                } else {
+                    value = call_overloaded_intrin(op->type, p.intrin, {a});
+                    if (value) {
+                        if (init.defined()) {
+                            Value *x = value;
+                            Value *y = codegen(init);
+                            value = builder->CreateAdd(x, y);
+                        }
+                        return;
+                    }
                 }
             } else {
-                value = call_overloaded_intrin(op->type, p.intrin, {a, b});
-                if (value) {
-                    if (init.defined()) {
-                        Value *x = value;
-                        Value *y = codegen(init);
-                        value = builder->CreateAdd(x, y);
+                Expr a = matches[0];
+                Expr b = matches[1];
+                if (p.flags & Pattern::SwapOperands) {
+                    std::swap(a, b);
+                }
+                if (p.narrow_type.bits() > 0) {
+                    a = lossless_cast(p.narrow_type.with_lanes(a.type().lanes()), a);
+                    b = lossless_cast(p.narrow_type.with_lanes(b.type().lanes()), b);
+                }
+                if (!a.defined() || !b.defined()) {
+                    continue;
+                }
+
+                if (init.defined() && (p.flags & Pattern::CombineInit)) {
+                    value = call_overloaded_intrin(op->type, p.intrin, {init, a, b});
+                    if (value) {
+                        return;
+                    }
+                } else {
+                    value = call_overloaded_intrin(op->type, p.intrin, {a, b});
+                    if (value) {
+                        if (init.defined()) {
+                            Value *x = value;
+                            Value *y = codegen(init);
+                            value = builder->CreateAdd(x, y);
+                        }
+                        return;
                     }
-                    return;
                 }
             }
         }
diff --git a/src/runtime/x86_avx2.ll b/src/runtime/x86_avx2.ll
index a73736860682..1a80f5b583d3 100644
--- a/src/runtime/x86_avx2.ll
+++ b/src/runtime/x86_avx2.ll
@@ -61,3 +61,14 @@ define weak_odr <16 x i16> @saturating_pmulhrswx16(<16 x i16> %a, <16 x i16> %b)
   ret <16 x i16> %5
 }
 declare <16 x i16> @llvm.x86.avx2.pmul.hr.sw(<16 x i16>, <16 x i16>) nounwind readnone
+
+define weak_odr <16 x i16> @hadd_pmadd_u8_avx2(<32 x i8> %a) nounwind alwaysinline {
+  %1 = tail call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> %a, <32 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>)
+  ret <16 x i16> %1
+}
+
+define weak_odr <16 x i16> @hadd_pmadd_i8_avx2(<32 x i8> %a) nounwind alwaysinline {
+  %1 = tail call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>, <32 x i8> %a)
+  ret <16 x i16> %1
+}
+declare <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8>, <32 x i8>) nounwind readnone
diff --git a/src/runtime/x86_sse41.ll b/src/runtime/x86_sse41.ll
index 3ca654d0e874..f109ee37ec23 100644
--- a/src/runtime/x86_sse41.ll
+++ b/src/runtime/x86_sse41.ll
@@ -81,3 +81,14 @@ define weak_odr <8 x i16> @saturating_pmulhrswx8(<8 x i16> %a, <8 x i16> %b) nou
   ret <8 x i16> %5
 }
 declare <8 x i16> @llvm.x86.ssse3.pmul.hr.sw.128(<8 x i16>, <8 x i16>) nounwind readnone
+
+define weak_odr <8 x i16> @hadd_pmadd_u8_sse3(<16 x i8> %a) nounwind alwaysinline {
+  %1 = tail call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> %a, <16 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>)
+  ret <8 x i16> %1
+}
+
+define weak_odr <8 x i16> @hadd_pmadd_i8_sse3(<16 x i8> %a) nounwind alwaysinline {
+  %1 = tail call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>, <16 x i8> %a)
+  ret <8 x i16> %1
+}
+declare <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8>, <16 x i8>) nounwind readnone
diff --git a/test/correctness/simd_op_check.cpp b/test/correctness/simd_op_check.cpp
index fe3dae63009b..5f2258b65f55 100644
--- a/test/correctness/simd_op_check.cpp
+++ b/test/correctness/simd_op_check.cpp
@@ -308,6 +308,11 @@ class SimdOpCheck : public SimdOpCheckTest {
 
             RDom r2(0, 2);
             check(check_pmaddubsw, 4 * w, saturating_sum(i16(in_u8(2 * x + r2)) * in_i8(2 * x + r2 + 32)));
             check(check_pmaddubsw, 4 * w, saturating_sum(i16(in_i8(2 * x + r2)) * in_u8(2 * x + r2 + 32)));
+
+            // uint8 -> uint16 or int16 and int8 -> int16 horizontal widening adds should use pmaddubsw.
+            check(check_pmaddubsw, 4 * w, sum(u16(in_u8(2 * x + r2))));
+            check(check_pmaddubsw, 4 * w, sum(i16(in_u8(2 * x + r2))));
+            check(check_pmaddubsw, 4 * w, sum(i16(in_i8(2 * x + r2))));
         }
     }
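
Note on the approach (illustrative, not part of the patch): pmaddubsw multiplies a u8 vector element-wise by an i8 vector and adds adjacent pairs into saturating i16 lanes. Fixing one operand to a vector of ones turns it into a 2-way horizontal widening add, and the saturation can never fire: the extreme sums are 255 + 255 = 510 for the unsigned variant and -128 + -128 = -256 for the signed variant, both well inside i16 range. The sketch below is a minimal Halide pipeline of the kind of reduction the new "horizontal_widening_add" patterns are meant to catch; it mirrors the expressions added to simd_op_check.cpp, but the file name, target string, and schedule are assumptions rather than anything taken from the patch.

    #include "Halide.h"

    using namespace Halide;
    using namespace Halide::ConciseCasts;  // for u16()

    int main() {
        ImageParam in(UInt(8), 1, "in");
        Var x("x");
        RDom r(0, 2);

        // 2-way horizontal widening add: out(x) = u16(in(2*x)) + u16(in(2*x + 1)).
        Func out("out");
        out(x) = sum(u16(in(2 * x + r)));

        // Vectorize the pure loop; the inner 2-wide reduction should become a
        // VectorReduce node that the new patterns can map to pmaddubsw
        // (the simd_op_check harness applies its own schedule instead).
        out.vectorize(x, 16);

        // Hypothetical output path and target string, just to inspect the assembly.
        out.compile_to_assembly("hadd.s", {in}, Target("x86-64-linux-sse41"));
        return 0;
    }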