From d68360cebfb09d42d0df30d45d95d8719269f443 Mon Sep 17 00:00:00 2001 From: Mohammad Azim Khan Date: Sat, 16 Nov 2024 08:49:33 +0000 Subject: [PATCH 1/4] Various masked operations --- g3doc/quick_reference.md | 43 ++++++++++++ hwy/ops/arm_sve-inl.h | 119 +++++++++++++++++++++++++++++++ hwy/ops/generic_ops-inl.h | 67 ++++++++++++++++++ hwy/tests/logical_test.cc | 23 ++++++ hwy/tests/masked_minmax_test.cc | 25 +++++++ hwy/tests/reduction_test.cc | 120 ++++++++++++++++++++++++++++++++ hwy/tests/table_test.cc | 113 ++++++++++++++++++++++++++++++ 7 files changed, 510 insertions(+) diff --git a/g3doc/quick_reference.md b/g3doc/quick_reference.md index 8220e9b718..8b9cf41d7d 100644 --- a/g3doc/quick_reference.md +++ b/g3doc/quick_reference.md @@ -886,6 +886,9 @@ not a concern, these are equivalent to, and potentially more efficient than, V **MaskedSatSubOr**(V no, M m, V a, V b): returns `a[i] + b[i]` saturated to the minimum/maximum representable value, or `no[i]` if `m[i]` is false. +* `V`: `{i,f}` \ + V **MaskedAbsOr**(M m, V a, V b): returns the absolute value of + `a[i]` where m is active and returns `b[i]` otherwise. #### Shifts @@ -1050,6 +1053,9 @@ types, and on SVE/RVV. * V **AndNot**(V a, V b): returns `~a[i] & b[i]`. +* V **MaskedOrOrZero**(M m, V a, V b): returns `a[i] | b[i]` + or zero if `m[i]` is false. + The following three-argument functions may be more efficient than assembling them from 2-argument functions: @@ -1756,6 +1762,9 @@ All functions except `Stream` are defined in cache_control.h. `DemoteToNearestInt(d, v)` is more efficient on some targets, including x86 and RVV. +* Vec<D> **MaskedConvertToOrZero**(M m, D d, V v): returns `v[i]` + converted to `D` where m is active and returns zero otherwise. 
+ #### Single vector demotion These functions demote a full vector (or parts thereof) into a vector of half @@ -2237,6 +2246,22 @@ The following `ReverseN` must not be called if `Lanes(D()) < N`: must be in the range `[0, 2 * Lanes(d))` but need not be unique. The index type `TI` must be an integer of the same size as `TFromD`. +* V **TableLookupLanesOr**(M m, V a, V b, unspecified): returns the + result of `TableLookupLanes(a, unspecified)` where `m[i]` is true, and returns + `b[i]` where `m[i]` is false. + +* V **TableLookupLanesOrZero**(M m, V a, unspecified): returns + the result of `TableLookupLanes(a, unspecified)` where `m[i]` is true, and + returns zero where `m[i]` is false. + +* V **TwoTablesLookupLanesOr**(D d, M m, V a, V b, unspecified): + returns the result of `TwoTablesLookupLanes(d, a, b, unspecified)` where + `m[i]` is true, and `a[i]` where `m[i]` is false. + +* V **TwoTablesLookupLanesOrZero**(D d, M m, V a, V b, unspecified): + returns the result of `TwoTablesLookupLanes(d, a, b, unspecified)` where + `m[i]` is true, and zero where `m[i]` is false. + * V **Per4LaneBlockShuffle**<size_t kIdx3, size_t kIdx2, size_t kIdx1, size_t kIdx0>(V v) does a per 4-lane block shuffle of `v` if `Lanes(DFromV())` is greater than or equal to 4 or a shuffle of the @@ -2377,6 +2402,24 @@ more efficient on some targets. * T **ReduceMin**(D, V v): returns the minimum of all lanes. * T **ReduceMax**(D, V v): returns the maximum of all lanes. +### Masked reductions + +**Note**: Horizontal operations (across lanes of the same vector) such as +reductions are slower than normal SIMD operations and are typically used outside +critical loops. + +All ops in this section ignore lanes where `mask=false`. These are equivalent +to, and potentially more efficient than, `GetLane(SumOfLanes(d, +IfThenElseZero(m, v)))` etc. The result is implementation-defined when all mask +elements are false. 
+ +* T **MaskedReduceSum**(D, M m, V v): returns the sum of all lanes + where `m[i]` is `true`. +* T **MaskedReduceMin**(D, M m, V v): returns the minimum of all + lanes where `m[i]` is `true`. +* T **MaskedReduceMax**(D, M m, V v): returns the maximum of all + lanes where `m[i]` is `true`. + ### Crypto Ops in this section are only available if `HWY_TARGET != HWY_SCALAR`: diff --git a/hwy/ops/arm_sve-inl.h b/hwy/ops/arm_sve-inl.h index 2dde1479de..e744c376f3 100644 --- a/hwy/ops/arm_sve-inl.h +++ b/hwy/ops/arm_sve-inl.h @@ -219,6 +219,15 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _) HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ return sv##OP##_##CHAR##BITS(v); \ } +#define HWY_SVE_RETV_ARGMV_M(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_m(b, m, a); \ + } +#define HWY_SVE_RETV_ARGMV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a) { \ + return sv##OP##_##CHAR##BITS##_z(m, a); \ + } // vector = f(vector, scalar), e.g. detail::AddN #define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP) \ @@ -252,6 +261,17 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ return sv##OP##_##CHAR##BITS##_x(m, a, b); \ } +#define HWY_SVE_RETV_ARGMVV_M(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_m(m, a, b); \ + } +// User-specified mask. Mask=false value is zero. 
+#define HWY_SVE_RETV_ARGMVVZ(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_z(m, a, b); \ + } #define HWY_SVE_RETV_ARGVVV(BASE, CHAR, BITS, HALF, NAME, OP) \ HWY_API HWY_SVE_V(BASE, BITS) \ @@ -260,6 +280,13 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _) return sv##OP##_##CHAR##BITS(a, b, c); \ } +#define HWY_SVE_RETV_ARGMVVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \ + HWY_SVE_V(BASE, BITS) c) { \ + return sv##OP##_##CHAR##BITS##_m(m, a, b, c); \ + } + // ------------------------------ Lanes namespace detail { @@ -727,6 +754,9 @@ HWY_API V Or(const V a, const V b) { return BitCast(df, Or(BitCast(du, a), BitCast(du, b))); } +// ------------------------------ MaskedOrOrZero +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVVZ, MaskedOrOrZero, orr) + // ------------------------------ Xor namespace detail { @@ -862,6 +892,12 @@ HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Abs, abs) HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPV, SaturatedAbs, qabs) #endif // HWY_SVE_HAVE_2 +// ------------------------------ MaskedAbsOr +HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGMV_M, MaskedAbsOr, abs) + +// ------------------------------ MaskedAbsOrZero +HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGMV_Z, MaskedAbsOrZero, abs) + // ================================================== ARITHMETIC // Per-target flags to prevent generic_ops-inl.h defining Add etc. @@ -1272,6 +1308,11 @@ HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulSub, nmad) #undef HWY_SVE_FMA +// ------------------------------ MaskedMulAdd +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVV, MaskedMulAdd, mad) +} + // ------------------------------ Round etc. 
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Round, rintn) @@ -1515,6 +1556,7 @@ HWY_API svbool_t LowerHalfOfMask(D /*d*/, svbool_t m) { namespace detail { HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMin, min) HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMax, max) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVZ, MaskedMaxOrZero, max) HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedAdd, add) HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedSub, sub) HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMul, mul) @@ -2849,6 +2891,41 @@ HWY_API svfloat32_t DemoteTo(Simd d, const svuint64_t v) { HWY_SVE_FOREACH_F(HWY_SVE_CONVERT, ConvertTo, cvt) #undef HWY_SVE_CONVERT +// ------------------------------ MaskedConvertToOrZero F + +#define HWY_SVE_MASKED_CONVERT_TO_OR_ZERO(BASE, CHAR, BITS, HALF, NAME, OP) \ + /* Float from signed */ \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_V(int, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_s##BITS##_z(m, v); \ + } \ + /* Float from unsigned */ \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_V(uint, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_u##BITS##_z(m, v); \ + } \ + /* Signed from float, rounding toward zero */ \ + template \ + HWY_API HWY_SVE_V(int, BITS) \ + NAME(svbool_t m, HWY_SVE_D(int, BITS, N, kPow2) /* d */, \ + HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_s##BITS##_##CHAR##BITS##_z(m, v); \ + } \ + /* Unsigned from float, rounding toward zero */ \ + template \ + HWY_API HWY_SVE_V(uint, BITS) \ + NAME(svbool_t m, HWY_SVE_D(uint, BITS, N, kPow2) /* d */, \ + HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_u##BITS##_##CHAR##BITS##_z(m, v); \ + } + +HWY_SVE_FOREACH_F(HWY_SVE_MASKED_CONVERT_TO_OR_ZERO, MaskedConvertToOrZero, cvt) +#undef HWY_SVE_MASKED_CONVERT_TO_OR_ZERO + // ------------------------------ NearestInt (Round, ConvertTo) template >> HWY_API VFromD NearestInt(VF v) { @@ -3288,6 +3365,25 @@ HWY_API 
TFromD ReduceMax(D d, VFromD v) { return detail::MaxOfLanesM(detail::MakeMask(d), v); } +#ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR +#undef HWY_NATIVE_MASKED_REDUCE_SCALAR +#else +#define HWY_NATIVE_MASKED_REDUCE_SCALAR +#endif + +template +HWY_API TFromD MaskedReduceSum(D /*d*/, M m, VFromD v) { + return detail::SumOfLanesM(m, v); +} +template +HWY_API TFromD MaskedReduceMin(D /*d*/, M m, VFromD v) { + return detail::MinOfLanesM(m, v); +} +template +HWY_API TFromD MaskedReduceMax(D /*d*/, M m, VFromD v) { + return detail::MaxOfLanesM(m, v); +} + // ------------------------------ SumOfLanes template @@ -4755,6 +4851,23 @@ HWY_API V IfNegativeThenElse(V v, V yes, V no) { static_assert(IsSigned>(), "Only works for signed/float"); return IfThenElse(IsNegative(v), yes, no); } +// ------------------------------ IfNegativeThenNegOrUndefIfZero + +#ifdef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#undef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#else +#define HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#endif + +#define HWY_SVE_NEG_IF(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_m(v, IsNegative(mask), v); \ + } + +HWY_SVE_FOREACH_IF(HWY_SVE_NEG_IF, IfNegativeThenNegOrUndefIfZero, neg) + +#undef HWY_SVE_NEG_IF // ------------------------------ AverageRound (ShiftRight) @@ -6291,13 +6404,19 @@ HWY_API V HighestSetBitIndex(V v) { #undef HWY_SVE_IF_NOT_EMULATED_D #undef HWY_SVE_PTRUE #undef HWY_SVE_RETV_ARGMVV +#undef HWY_SVE_RETV_ARGMVVZ #undef HWY_SVE_RETV_ARGPV #undef HWY_SVE_RETV_ARGPVN #undef HWY_SVE_RETV_ARGPVV #undef HWY_SVE_RETV_ARGV #undef HWY_SVE_RETV_ARGVN +#undef HWY_SVE_RETV_ARGMV +#undef HWY_SVE_RETV_ARGMV_M +#undef HWY_SVE_RETV_ARGMV_Z #undef HWY_SVE_RETV_ARGVV +#undef HWY_SVE_RETV_ARGMVV_M #undef HWY_SVE_RETV_ARGVVV +#undef HWY_SVE_RETV_ARGMVVV #undef HWY_SVE_T #undef HWY_SVE_UNDEFINED #undef HWY_SVE_V diff --git a/hwy/ops/generic_ops-inl.h 
b/hwy/ops/generic_ops-inl.h index 99b518d99c..dbde1377f1 100644 --- a/hwy/ops/generic_ops-inl.h +++ b/hwy/ops/generic_ops-inl.h @@ -672,6 +672,18 @@ HWY_API V SaturatedAbs(V v) { #endif +// ------------------------------ MaskedAbsOr +template +HWY_API V MaskedAbsOr(M m, V v, V no) { + return IfThenElse(m, Abs(v), no); +} + +// ------------------------------ MaskedAbsOrZero +template +HWY_API V MaskedAbsOrZero(M m, V v) { + return IfThenElseZero(m, Abs(v)); +} + // ------------------------------ Reductions // Targets follow one of two strategies. If HWY_NATIVE_REDUCE_SCALAR is toggled, @@ -882,6 +894,28 @@ HWY_API TFromD ReduceMax(D d, VFromD v) { } #endif // HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#if (defined(HWY_NATIVE_MASKED_REDUCE_SCALAR) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR +#undef HWY_NATIVE_MASKED_REDUCE_SCALAR +#else +#define HWY_NATIVE_MASKED_REDUCE_SCALAR +#endif + +template +HWY_API TFromD MaskedReduceSum(D d, M m, VFromD v) { + return ReduceSum(d, IfThenElseZero(m, v)); +} +template +HWY_API TFromD MaskedReduceMin(D d, M m, VFromD v) { + return ReduceMin(d, IfThenElse(m, v, MaxOfLanes(d, v))); +} +template +HWY_API TFromD MaskedReduceMax(D d, M m, VFromD v) { + return ReduceMax(d, IfThenElse(m, v, MinOfLanes(d, v))); +} + +#endif // HWY_NATIVE_MASKED_REDUCE_SCALAR + // ------------------------------ IsEitherNaN #if (defined(HWY_NATIVE_IS_EITHER_NAN) == defined(HWY_TARGET_TOGGLE)) #ifdef HWY_NATIVE_IS_EITHER_NAN @@ -6444,6 +6478,30 @@ HWY_API V ReverseBits(V v) { } #endif // HWY_NATIVE_REVERSE_BITS_UI16_32_64 +// ------------------------------ TableLookupLanesOr +template +HWY_API V TableLookupLanesOr(M m, V a, V b, IndicesFromD> idx) { + return IfThenElse(m, TableLookupLanes(a, idx), b); +} + +// ------------------------------ TableLookupLanesOrZero +template +HWY_API V TableLookupLanesOrZero(M m, V a, IndicesFromD> idx) { + return IfThenElseZero(m, TableLookupLanes(a, idx)); +} + +// ------------------------------ TwoTablesLookupLanesOr 
+template +HWY_API V TwoTablesLookupLanesOr(D d, M m, V a, V b, IndicesFromD idx) { + return IfThenElse(m, TwoTablesLookupLanes(d, a, b, idx), a); +} + +// ------------------------------ TwoTablesLookupLanesOrZero +template +HWY_API V TwoTablesLookupLanesOrZero(D d, M m, V a, V b, IndicesFromD idx) { + return IfThenElse(m, TwoTablesLookupLanes(d, a, b, idx), Zero(d)); +} + // ------------------------------ Per4LaneBlockShuffle #if (defined(HWY_NATIVE_PER4LANEBLKSHUF_DUP32) == defined(HWY_TARGET_TOGGLE)) @@ -7299,6 +7357,15 @@ HWY_API V BitShuffle(V v, VI idx) { #endif // HWY_NATIVE_BITSHUFFLE +template +HWY_API V MaskedMaxOrZero(M m, V a, V b) { + return IfThenElseZero(m, (Max(a, b))); +} + +template +HWY_API V MaskedOrOrZero(M m, V a, V b) { + return IfThenElseZero(m, Or(a, b)); +} // ================================================== Operator wrapper // SVE* and RVV currently cannot define operators and have already defined diff --git a/hwy/tests/logical_test.cc b/hwy/tests/logical_test.cc index ecd7589c9e..5abc2277bc 100644 --- a/hwy/tests/logical_test.cc +++ b/hwy/tests/logical_test.cc @@ -146,6 +146,28 @@ HWY_NOINLINE void TestAllTestBit() { ForIntegerTypes(ForPartialVectors()); } +struct TestMaskedOrOrZero { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const MFromD all_true = MaskTrue(d); + const auto v1 = Iota(d, 1); + const auto v2 = Iota(d, 2); + + HWY_ASSERT_VEC_EQ(d, Or(v2, v1), MaskedOrOrZero(all_true, v1, v2)); + + const MFromD first_five = FirstN(d, 5); + const Vec v0 = Zero(d); + + const Vec v1_exp = IfThenElse(first_five, Or(v2, v1), v0); + + HWY_ASSERT_VEC_EQ(d, v1_exp, MaskedOrOrZero(first_five, v1, v2)); + } +}; + +HWY_NOINLINE void TestAllMaskedLogical() { + ForAllTypes(ForPartialVectors()); +} + } // namespace // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE @@ -159,6 +181,7 @@ HWY_BEFORE_TEST(HwyLogicalTest); HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllNot); 
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllLogical); HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllTestBit); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllMaskedLogical); HWY_AFTER_TEST(); } // namespace } // namespace hwy diff --git a/hwy/tests/masked_minmax_test.cc b/hwy/tests/masked_minmax_test.cc index 0e071b14c1..0f726bfac6 100644 --- a/hwy/tests/masked_minmax_test.cc +++ b/hwy/tests/masked_minmax_test.cc @@ -137,6 +137,30 @@ HWY_NOINLINE void TestAllSignedMinMax() { ForFloatTypes(ForPartialVectors()); } +struct TestMaskedMaxOrZero { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const MFromD all_true = MaskTrue(d); + const auto v1 = Iota(d, 1); + const auto v2 = Iota(d, 2); + + HWY_ASSERT_VEC_EQ(d, v2, MaskedMaxOrZero(all_true, v1, v2)); + + const MFromD first_five = FirstN(d, 5); + const Vec v0 = Zero(d); + + const Vec v1_exp = IfThenElse(first_five, v2, v0); + + auto output = MaskedMaxOrZero(first_five, v1, v2); + + HWY_ASSERT_VEC_EQ(d, v1_exp, output); + } +}; + +HWY_NOINLINE void TestAllMaskedMaxOrZero() { + ForAllTypes(ForPartialVectors()); +} + } // namespace // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE @@ -149,6 +173,7 @@ namespace { HWY_BEFORE_TEST(HwyMaskedMinMaxTest); HWY_EXPORT_AND_TEST_P(HwyMaskedMinMaxTest, TestAllUnsignedMinMax); HWY_EXPORT_AND_TEST_P(HwyMaskedMinMaxTest, TestAllSignedMinMax); +HWY_EXPORT_AND_TEST_P(HwyMaskedMinMaxTest, TestAllMaskedMaxOrZero); HWY_AFTER_TEST(); } // namespace } // namespace hwy diff --git a/hwy/tests/reduction_test.cc b/hwy/tests/reduction_test.cc index fffc4a7873..fd35f645f6 100644 --- a/hwy/tests/reduction_test.cc +++ b/hwy/tests/reduction_test.cc @@ -352,6 +352,122 @@ HWY_NOINLINE void TestAllSumsOf8() { ForGEVectors<64, TestSumsOf8>()(uint8_t()); } +struct TestMaskedReduceSum { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + const Vec v2 = Iota(d, 2); + + const size_t N = Lanes(d); + auto bool_lanes = 
AllocateAligned(N); + HWY_ASSERT(bool_lanes); + + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + T expected = 0; + for (size_t i = 0; i < N; ++i) { + bool_lanes[i] = (Random32(&rng) & 1024) ? T(1) : T(0); + if (bool_lanes[i]) { + expected += ConvertScalarTo(i + 2); + } + } + + const Vec mask_i = Load(d, bool_lanes.get()); + const Mask mask = RebindMask(d, Gt(mask_i, Zero(d))); + + // If all elements are disabled the result is implementation defined + if (AllFalse(d, mask)) { + continue; + } + + HWY_ASSERT_EQ(expected, MaskedReduceSum(d, mask, v2)); + } + } +}; + +HWY_NOINLINE void TestAllMaskedReduceSum() { + ForAllTypes(ForPartialVectors()); +} + +struct TestMaskedReduceMin { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + const Vec v2 = Iota(d, 2); + + const size_t N = Lanes(d); + auto bool_lanes = AllocateAligned(N); + HWY_ASSERT(bool_lanes); + + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + T expected = + ConvertScalarTo(N + 3); // larger than any values in the vector + for (size_t i = 0; i < N; ++i) { + bool_lanes[i] = (Random32(&rng) & 1024) ? 
T(1) : T(0); + if (bool_lanes[i]) { + if (expected > ConvertScalarTo(i + 2)) { + expected = ConvertScalarTo(i + 2); + } + } + } + + const Vec mask_i = Load(d, bool_lanes.get()); + const Mask mask = RebindMask(d, Gt(mask_i, Zero(d))); + + // If all elements are disabled the result is implementation defined + if (AllFalse(d, mask)) { + continue; + } + + HWY_ASSERT_EQ(expected, MaskedReduceMin(d, mask, v2)); + } + } +}; + +HWY_NOINLINE void TestAllMaskedReduceMin() { + ForAllTypes(ForPartialVectors()); +} + +struct TestMaskedReduceMax { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + const Vec v2 = Iota(d, 2); + + const size_t N = Lanes(d); + auto bool_lanes = AllocateAligned(N); + HWY_ASSERT(bool_lanes); + + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + T expected = 0; + for (size_t i = 0; i < N; ++i) { + bool_lanes[i] = (Random32(&rng) & 1024) ? T(1) : T(0); + if (bool_lanes[i]) { + if (expected < ConvertScalarTo(i + 2)) { + expected = ConvertScalarTo(i + 2); + } + } + } + + const Vec mask_i = Load(d, bool_lanes.get()); + const Mask mask = RebindMask(d, Gt(mask_i, Zero(d))); + + // If all elements are disabled the result is implementation defined + if (AllFalse(d, mask)) { + continue; + } + + HWY_ASSERT_EQ(expected, MaskedReduceMax(d, mask, v2)); + } + } +}; + +HWY_NOINLINE void TestAllMaskedReduceMax() { + ForAllTypes(ForPartialVectors()); +} + } // namespace // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE @@ -367,6 +483,10 @@ HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllMinMaxOfLanes); HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllSumsOf2); HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllSumsOf4); HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllSumsOf8); + +HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllMaskedReduceSum); +HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllMaskedReduceMin); +HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllMaskedReduceMax); HWY_AFTER_TEST(); } // 
namespace } // namespace hwy diff --git a/hwy/tests/table_test.cc b/hwy/tests/table_test.cc index 09fdd7eaf6..eb5b1a8644 100644 --- a/hwy/tests/table_test.cc +++ b/hwy/tests/table_test.cc @@ -103,6 +103,59 @@ HWY_NOINLINE void TestAllTableLookupLanes() { ForAllTypes(ForPartialVectors()); } +struct TestTableLookupLanesOr { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { +#if HWY_TARGET != HWY_SCALAR + const RebindToSigned di; + using TI = TFromD; + + const size_t N = Lanes(d); + // Select indices from N-1 counting down + auto indices = IndicesFromVec( + d, Sub(Set(di, ConvertScalarTo(N - 1)), Iota(di, 0))); + + auto expected = AllocateAligned(N); + auto expected_zero = AllocateAligned(N); + auto bool_lanes = AllocateAligned(N); + HWY_ASSERT(expected && expected_zero && bool_lanes); + + const auto v1 = Iota(d, 5); + const auto v2 = Iota(d, 8); + + RandomState rng; + + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + for (size_t i = 0; i < N; ++i) { + bool_lanes[i] = (Random32(&rng) & 1024) ? T(1) : T(0); + + if (bool_lanes[i]) { + expected[i] = ConvertScalarTo(N - i + 5 - 1); // v1[N-1, N-2, ...] + expected_zero[i] = + ConvertScalarTo(N - i + 5 - 1); // v1[N-1, N-2, ...] 
+ } else { + expected[i] = ConvertScalarTo(i + 8); // v2[i] + expected_zero[i] = ConvertScalarTo(0); + } + } + + const Vec mask_i = Load(d, bool_lanes.get()); + const Mask mask = RebindMask(d, Gt(mask_i, Zero(d))); + HWY_ASSERT_VEC_EQ(d, expected.get(), + TableLookupLanesOr(mask, v1, v2, indices)); + HWY_ASSERT_VEC_EQ(d, expected_zero.get(), + TableLookupLanesOrZero(mask, v1, indices)); + } +#else + (void) d; +#endif + } +}; + +HWY_NOINLINE void TestAllTableLookupLanesOr() { + ForAllTypes(ForPartialVectors()); +} + struct TestTwoTablesLookupLanes { template HWY_NOINLINE void operator()(T /*unused*/, D d) { @@ -194,6 +247,64 @@ HWY_NOINLINE void TestAllTwoTablesLookupLanes() { ForAllTypes(ForPartialVectors()); } +struct TestTwoTablesLookupLanesOr { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const RebindToSigned di; + using TI = TFromD; + + const size_t N = Lanes(d); + // Select indices from N-1 counting down + auto idx_lower = Sub(Set(di, ConvertScalarTo(N - 1)), Iota(di, 0)); + auto idx_upper = Add(idx_lower, Set(di, ConvertScalarTo(N))); + auto indices = IndicesFromVec(d, OddEven(idx_upper, idx_lower)); + + auto expected = AllocateAligned(N); + auto expected_zero = AllocateAligned(N); + auto bool_lanes = AllocateAligned(N); + HWY_ASSERT(expected && expected_zero && bool_lanes); + + const auto v1 = Iota(d, 5); + const auto v2 = Iota(d, 8); + + RandomState rng; + + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + for (size_t i = 0; i < N; ++i) { + bool_lanes[i] = (Random32(&rng) & 1024) ? T(1) : T(0); + + if (bool_lanes[i]) { + if (i % 2) { + expected[i] = + ConvertScalarTo(N - i + 8 - 1); // v2[N-1, N-2, ...] + expected_zero[i] = + ConvertScalarTo(N - i + 8 - 1); // v2[N-1, N-2, ...] + } else { + expected[i] = + ConvertScalarTo(N - i + 5 - 1); // v1[N-1, N-2, ...] + expected_zero[i] = + ConvertScalarTo(N - i + 5 - 1); // v1[N-1, N-2, ...] 
+ } + } else { + expected[i] = ConvertScalarTo(i + 5); // v1[i] + expected_zero[i] = ConvertScalarTo(0); + } + } + + const Vec mask_i = Load(d, bool_lanes.get()); + const Mask mask = RebindMask(d, Gt(mask_i, Zero(d))); + HWY_ASSERT_VEC_EQ(d, expected.get(), + TwoTablesLookupLanesOr(d, mask, v1, v2, indices)); + HWY_ASSERT_VEC_EQ(d, expected_zero.get(), + TwoTablesLookupLanesOrZero(d, mask, v1, v2, indices)); + } + } +}; + +HWY_NOINLINE void TestAllTwoTablesLookupLanesOr() { + ForAllTypes(ForPartialVectors()); +} + } // namespace // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE @@ -205,7 +316,9 @@ namespace hwy { namespace { HWY_BEFORE_TEST(HwyTableTest); HWY_EXPORT_AND_TEST_P(HwyTableTest, TestAllTableLookupLanes); +HWY_EXPORT_AND_TEST_P(HwyTableTest, TestAllTableLookupLanesOr); HWY_EXPORT_AND_TEST_P(HwyTableTest, TestAllTwoTablesLookupLanes); +HWY_EXPORT_AND_TEST_P(HwyTableTest, TestAllTwoTablesLookupLanesOr); HWY_AFTER_TEST(); } // namespace } // namespace hwy From 49433a86814a4b02970a644fcc8320a32b5eee37 Mon Sep 17 00:00:00 2001 From: Mohammad Azim Khan Date: Fri, 22 Nov 2024 18:32:47 +0000 Subject: [PATCH 2/4] Remove MaskedAbsOr(Zero) covered with arithmetic tests --- g3doc/quick_reference.md | 3 --- hwy/ops/arm_sve-inl.h | 6 ------ hwy/ops/generic_ops-inl.h | 12 ------------ 3 files changed, 21 deletions(-) diff --git a/g3doc/quick_reference.md b/g3doc/quick_reference.md index 8b9cf41d7d..0e4d224609 100644 --- a/g3doc/quick_reference.md +++ b/g3doc/quick_reference.md @@ -886,9 +886,6 @@ not a concern, these are equivalent to, and potentially more efficient than, V **MaskedSatSubOr**(V no, M m, V a, V b): returns `a[i] + b[i]` saturated to the minimum/maximum representable value, or `no[i]` if `m[i]` is false. -* `V`: `{i,f}` \ - V **MaskedAbsOr**(M m, V a, V b): returns the absolute value of - `a[i]` where m is active and returns `b[i]` otherwise. 
#### Shifts diff --git a/hwy/ops/arm_sve-inl.h b/hwy/ops/arm_sve-inl.h index e744c376f3..442020c487 100644 --- a/hwy/ops/arm_sve-inl.h +++ b/hwy/ops/arm_sve-inl.h @@ -892,12 +892,6 @@ HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Abs, abs) HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPV, SaturatedAbs, qabs) #endif // HWY_SVE_HAVE_2 -// ------------------------------ MaskedAbsOr -HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGMV_M, MaskedAbsOr, abs) - -// ------------------------------ MaskedAbsOrZero -HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGMV_Z, MaskedAbsOrZero, abs) - // ================================================== ARITHMETIC // Per-target flags to prevent generic_ops-inl.h defining Add etc. diff --git a/hwy/ops/generic_ops-inl.h b/hwy/ops/generic_ops-inl.h index dbde1377f1..16693bf9c5 100644 --- a/hwy/ops/generic_ops-inl.h +++ b/hwy/ops/generic_ops-inl.h @@ -672,18 +672,6 @@ HWY_API V SaturatedAbs(V v) { #endif -// ------------------------------ MaskedAbsOr -template -HWY_API V MaskedAbsOr(M m, V v, V no) { - return IfThenElse(m, Abs(v), no); -} - -// ------------------------------ MaskedAbsOrZero -template -HWY_API V MaskedAbsOrZero(M m, V v) { - return IfThenElseZero(m, Abs(v)); -} - // ------------------------------ Reductions // Targets follow one of two strategies. If HWY_NATIVE_REDUCE_SCALAR is toggled, From 78960957fe5eb2c05e1697be3e769601654ca79f Mon Sep 17 00:00:00 2001 From: Mohammad Azim Khan Date: Fri, 22 Nov 2024 18:39:13 +0000 Subject: [PATCH 3/4] Remove MaskedConvertToOrZero covered with convert and round ops --- g3doc/quick_reference.md | 3 --- hwy/ops/arm_sve-inl.h | 35 ----------------------------------- 2 files changed, 38 deletions(-) diff --git a/g3doc/quick_reference.md b/g3doc/quick_reference.md index 0e4d224609..2ee3dad3a9 100644 --- a/g3doc/quick_reference.md +++ b/g3doc/quick_reference.md @@ -1759,9 +1759,6 @@ All functions except `Stream` are defined in cache_control.h. `DemoteToNearestInt(d, v)` is more efficient on some targets, including x86 and RVV. 
-* Vec<D> **MaskedConvertToOrZero**(M m, D d, V v): returns `v[i]` - converted to `D` where m is active and returns zero otherwise. - #### Single vector demotion These functions demote a full vector (or parts thereof) into a vector of half diff --git a/hwy/ops/arm_sve-inl.h b/hwy/ops/arm_sve-inl.h index 442020c487..0143995100 100644 --- a/hwy/ops/arm_sve-inl.h +++ b/hwy/ops/arm_sve-inl.h @@ -2885,41 +2885,6 @@ HWY_API svfloat32_t DemoteTo(Simd d, const svuint64_t v) { HWY_SVE_FOREACH_F(HWY_SVE_CONVERT, ConvertTo, cvt) #undef HWY_SVE_CONVERT -// ------------------------------ MaskedConvertToOrZero F - -#define HWY_SVE_MASKED_CONVERT_TO_OR_ZERO(BASE, CHAR, BITS, HALF, NAME, OP) \ - /* Float from signed */ \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ - HWY_SVE_V(int, BITS) v) { \ - return sv##OP##_##CHAR##BITS##_s##BITS##_z(m, v); \ - } \ - /* Float from unsigned */ \ - template \ - HWY_API HWY_SVE_V(BASE, BITS) \ - NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ - HWY_SVE_V(uint, BITS) v) { \ - return sv##OP##_##CHAR##BITS##_u##BITS##_z(m, v); \ - } \ - /* Signed from float, rounding toward zero */ \ - template \ - HWY_API HWY_SVE_V(int, BITS) \ - NAME(svbool_t m, HWY_SVE_D(int, BITS, N, kPow2) /* d */, \ - HWY_SVE_V(BASE, BITS) v) { \ - return sv##OP##_s##BITS##_##CHAR##BITS##_z(m, v); \ - } \ - /* Unsigned from float, rounding toward zero */ \ - template \ - HWY_API HWY_SVE_V(uint, BITS) \ - NAME(svbool_t m, HWY_SVE_D(uint, BITS, N, kPow2) /* d */, \ - HWY_SVE_V(BASE, BITS) v) { \ - return sv##OP##_u##BITS##_##CHAR##BITS##_z(m, v); \ - } - -HWY_SVE_FOREACH_F(HWY_SVE_MASKED_CONVERT_TO_OR_ZERO, MaskedConvertToOrZero, cvt) -#undef HWY_SVE_MASKED_CONVERT_TO_OR_ZERO - // ------------------------------ NearestInt (Round, ConvertTo) template >> HWY_API VFromD NearestInt(VF v) { From d9a19a66d3c4632c76c368f35dce7dfc57421661 Mon Sep 17 00:00:00 2001 From: Mohammad Azim Khan Date: Fri, 22 Nov 2024 
18:49:03 +0000 Subject: [PATCH 4/4] Remove MaskedMaxOrZero as covered with Zero masked ops --- hwy/ops/arm_sve-inl.h | 6 ------ hwy/ops/generic_ops-inl.h | 5 ----- hwy/tests/masked_minmax_test.cc | 25 ------------------------- 3 files changed, 36 deletions(-) diff --git a/hwy/ops/arm_sve-inl.h b/hwy/ops/arm_sve-inl.h index 0143995100..66ad1dfe3b 100644 --- a/hwy/ops/arm_sve-inl.h +++ b/hwy/ops/arm_sve-inl.h @@ -1302,11 +1302,6 @@ HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulSub, nmad) #undef HWY_SVE_FMA -// ------------------------------ MaskedMulAdd -namespace detail { -HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVV, MaskedMulAdd, mad) -} - // ------------------------------ Round etc. HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Round, rintn) @@ -1550,7 +1545,6 @@ HWY_API svbool_t LowerHalfOfMask(D /*d*/, svbool_t m) { namespace detail { HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMin, min) HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMax, max) -HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVZ, MaskedMaxOrZero, max) HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedAdd, add) HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedSub, sub) HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMul, mul) diff --git a/hwy/ops/generic_ops-inl.h b/hwy/ops/generic_ops-inl.h index 16693bf9c5..efee4bc971 100644 --- a/hwy/ops/generic_ops-inl.h +++ b/hwy/ops/generic_ops-inl.h @@ -7345,11 +7345,6 @@ HWY_API V BitShuffle(V v, VI idx) { #endif // HWY_NATIVE_BITSHUFFLE -template -HWY_API V MaskedMaxOrZero(M m, V a, V b) { - return IfThenElseZero(m, (Max(a, b))); -} - template HWY_API V MaskedOrOrZero(M m, V a, V b) { return IfThenElseZero(m, Or(a, b)); diff --git a/hwy/tests/masked_minmax_test.cc b/hwy/tests/masked_minmax_test.cc index 0f726bfac6..0e071b14c1 100644 --- a/hwy/tests/masked_minmax_test.cc +++ b/hwy/tests/masked_minmax_test.cc @@ -137,30 +137,6 @@ HWY_NOINLINE void TestAllSignedMinMax() { ForFloatTypes(ForPartialVectors()); } -struct TestMaskedMaxOrZero { - template - HWY_NOINLINE void operator()(T /*unused*/, D d) { - const MFromD 
all_true = MaskTrue(d); - const auto v1 = Iota(d, 1); - const auto v2 = Iota(d, 2); - - HWY_ASSERT_VEC_EQ(d, v2, MaskedMaxOrZero(all_true, v1, v2)); - - const MFromD first_five = FirstN(d, 5); - const Vec v0 = Zero(d); - - const Vec v1_exp = IfThenElse(first_five, v2, v0); - - auto output = MaskedMaxOrZero(first_five, v1, v2); - - HWY_ASSERT_VEC_EQ(d, v1_exp, output); - } -}; - -HWY_NOINLINE void TestAllMaskedMaxOrZero() { - ForAllTypes(ForPartialVectors()); -} - } // namespace // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE @@ -173,7 +149,6 @@ namespace { HWY_BEFORE_TEST(HwyMaskedMinMaxTest); HWY_EXPORT_AND_TEST_P(HwyMaskedMinMaxTest, TestAllUnsignedMinMax); HWY_EXPORT_AND_TEST_P(HwyMaskedMinMaxTest, TestAllSignedMinMax); -HWY_EXPORT_AND_TEST_P(HwyMaskedMinMaxTest, TestAllMaskedMaxOrZero); HWY_AFTER_TEST(); } // namespace } // namespace hwy