Various masked operations #2428

Open: wants to merge 4 commits into master
37 changes: 37 additions & 0 deletions g3doc/quick_reference.md
@@ -1050,6 +1050,9 @@ types, and on SVE/RVV.

* <code>V **AndNot**(V a, V b)</code>: returns `~a[i] & b[i]`.

* <code>V **MaskedOrOrZero**(M m, V a, V b)</code>: returns `a[i] || b[i]`

Review comment (Member):
How about a different naming convention here, which might be a bit more natural?
There is also a MaskedLoad, which returns 0 as the default, as opposed to MaskedLoadOr, which takes an explicit default value. If we apply that here, we can just call it MaskedOr(m, a, b). What do you think?

Review comment (Member):
I think we mean a[i] | b[i]?

or `zero` if `m[i]` is false.
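
As a rough scalar reading of this entry (not Highway code), the sketch below uses the bitwise `|` that the second review comment asks about. The name `MaskedOrOrZeroModel` and the use of `std::vector` for lanes and masks are assumptions made purely for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical scalar model, not part of Highway: lane i is a[i] | b[i]
// where m[i] is true, and zero where m[i] is false.
std::vector<uint32_t> MaskedOrOrZeroModel(const std::vector<bool>& m,
                                          const std::vector<uint32_t>& a,
                                          const std::vector<uint32_t>& b) {
  assert(m.size() == a.size() && a.size() == b.size());
  std::vector<uint32_t> out(a.size(), 0u);  // masked-off lanes stay zero
  for (size_t i = 0; i < a.size(); ++i) {
    if (m[i]) out[i] = a[i] | b[i];  // bitwise OR, not logical ||
  }
  return out;
}
```

With `a = {1, 2}`, `b = {4, 2}` and `m = {true, false}`, this model yields `{5, 0}`; a logical `||` would instead produce 1 in the first lane, which is why the distinction in the documentation matters.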

The following three-argument functions may be more efficient than assembling
them from 2-argument functions:

@@ -2237,6 +2240,22 @@ The following `ReverseN` must not be called if `Lanes(D()) < N`:
must be in the range `[0, 2 * Lanes(d))` but need not be unique. The index
type `TI` must be an integer of the same size as `TFromD<D>`.

* <code>V **TableLookupLanesOr**(M m, V a, V b, unspecified)</code> returns the

Review comment (Member):
It looks like we don't yet have an optimized version of these ops, and they are just convenience wrappers over IfThenElse. Would it be an option to move them into a utility function within your codebase? It's not clear whether they provide enough value to be documented ops that all readers must know.

result of `TableLookupLanes(a, unspecified)` where `m[i]` is true, and returns
`b[i]` where `m[i]` is false.

* <code>V **TableLookupLanesOrZero**(M m, V a, unspecified)</code> returns
the result of `TableLookupLanes(a, unspecified)` where `m[i]` is true, and
returns zero where `m[i]` is false.

* <code>V **TwoTablesLookupLanesOr**(D d, M m, V a, V b, unspecified)</code>
returns the result of `TwoTablesLookupLanes(V a, V b, unspecified)` where
`m[i]` is true, and `a[i]` where `m[i]` is false.

* <code>V **TwoTablesLookupLanesOrZero**(D d, M m, V a, V b, unspecified)</code>
returns the result of `TwoTablesLookupLanes(V a, V b, unspecified)` where
`m[i]` is true, and zero where `m[i]` is false.

* <code>V **Per4LaneBlockShuffle**&lt;size_t kIdx3, size_t kIdx2, size_t
kIdx1, size_t kIdx0&gt;(V v)</code> does a per 4-lane block shuffle of `v`
if `Lanes(DFromV<V>())` is greater than or equal to 4 or a shuffle of the
@@ -2377,6 +2396,24 @@ more efficient on some targets.
* <code>T **ReduceMin**(D, V v)</code>: returns the minimum of all lanes.
* <code>T **ReduceMax**(D, V v)</code>: returns the maximum of all lanes.

### Masked reductions

**Note**: Horizontal operations (across lanes of the same vector) such as
reductions are slower than normal SIMD operations and are typically used outside
critical loops.

All ops in this section ignore lanes where `mask=false`. These are equivalent
to, and potentially more efficient than, `GetLane(SumOfLanes(d,
IfThenElseZero(m, v)))` etc. The result is implementation-defined when all mask
elements are false.

* <code>T **MaskedReduceSum**(D, M m, V v)</code>: returns the sum of all lanes

Review comment (Member):
Nice! This looks useful.
Please add a TODO that we should also implement this for RVV.

where `m[i]` is `true`.
* <code>T **MaskedReduceMin**(D, M m, V v)</code>: returns the minimum of all
lanes where `m[i]` is `true`.
* <code>T **MaskedReduceMax**(D, M m, V v)</code>: returns the maximum of all
lanes where `m[i]` is `true`.
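
The equivalence stated above for the sum case can be illustrated with a scalar model. This is only a sketch: `MaskedReduceSumModel` and `SumAfterZeroingModel` are hypothetical names, lanes are modelled as a `std::vector<int32_t>`, and the accumulator is widened to `int64_t` to sidestep overflow, whereas the documented op returns the lane type `T`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical scalar model: add up only the lanes whose mask bit is true.
int64_t MaskedReduceSumModel(const std::vector<bool>& m,
                             const std::vector<int32_t>& v) {
  assert(m.size() == v.size());
  int64_t sum = 0;
  for (size_t i = 0; i < v.size(); ++i) {
    if (m[i]) sum += v[i];
  }
  return sum;
}

// Same result obtained the long way: zero the masked-off lanes first
// (the role of IfThenElseZero), then sum every lane.
int64_t SumAfterZeroingModel(const std::vector<bool>& m,
                             const std::vector<int32_t>& v) {
  assert(m.size() == v.size());
  std::vector<int32_t> zeroed(v.size(), 0);
  for (size_t i = 0; i < v.size(); ++i) {
    if (m[i]) zeroed[i] = v[i];
  }
  int64_t sum = 0;
  for (int32_t lane : zeroed) sum += lane;
  return sum;
}
```

For `v = {3, -7, 10, -2}` and `m = {true, false, true, true}`, both functions return 11. Zero works as the fill value here because it is the identity for addition; the same fill is not an identity for min or max.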

### Crypto

Ops in this section are only available if `HWY_TARGET != HWY_SCALAR`:

72 changes: 72 additions & 0 deletions hwy/ops/arm_sve-inl.h
@@ -219,6 +219,15 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_##CHAR##BITS(v); \
}
#define HWY_SVE_RETV_ARGMV_M(BASE, CHAR, BITS, HALF, NAME, OP) \

Review comment (Member):
Minor: we have the naming convention P for predicate, for example in HWY_SVE_RETV_ARGPVV. I'm fine with either P or M, but let's please be consistent, feel free to pick one.
This might actually replace the existing HWY_SVE_RETV_ARGPV.

HWY_API HWY_SVE_V(BASE, BITS) \
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS##_m(b, m, a); \
}
#define HWY_SVE_RETV_ARGMV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a) { \
return sv##OP##_##CHAR##BITS##_z(m, a); \
}

// vector = f(vector, scalar), e.g. detail::AddN
#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP) \
@@ -252,6 +261,17 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS##_x(m, a, b); \
}
#define HWY_SVE_RETV_ARGMVV_M(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS##_m(m, a, b); \
}
// User-specified mask. Mask=false value is zero.
#define HWY_SVE_RETV_ARGMVVZ(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS##_z(m, a, b); \
}

#define HWY_SVE_RETV_ARGVVV(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
@@ -260,6 +280,13 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
return sv##OP##_##CHAR##BITS(a, b, c); \
}

#define HWY_SVE_RETV_ARGMVVV(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \
HWY_SVE_V(BASE, BITS) c) { \
return sv##OP##_##CHAR##BITS##_m(m, a, b, c); \
}
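
The macros above differ mainly in the intrinsic suffix they append. In the Arm SVE intrinsics, `_m` is merging predication and `_z` is zeroing predication; for binary ops, inactive lanes of an `_m` intrinsic take the first vector operand, while unary `_m` intrinsics take a separate leading `inactive` vector, which would explain the `(b, m, a)` argument order in `HWY_SVE_RETV_ARGMV_M`. The scalar sketch below models those three cases with addition and negation as stand-ins; all function names are hypothetical and nothing here is Highway or ACLE code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Vec = std::vector<int32_t>;
using Mask = std::vector<bool>;

// Binary op, merging (_m): inactive lanes keep the first operand `a`.
Vec AddMergingModel(const Mask& m, const Vec& a, const Vec& b) {
  assert(m.size() == a.size() && a.size() == b.size());
  Vec out(a);
  for (size_t i = 0; i < a.size(); ++i) {
    if (m[i]) out[i] = a[i] + b[i];
  }
  return out;
}

// Binary op, zeroing (_z): inactive lanes become zero.
Vec AddZeroingModel(const Mask& m, const Vec& a, const Vec& b) {
  assert(m.size() == a.size() && a.size() == b.size());
  Vec out(a.size(), 0);
  for (size_t i = 0; i < a.size(); ++i) {
    if (m[i]) out[i] = a[i] + b[i];
  }
  return out;
}

// Unary op, merging (_m): inactive lanes come from a separate vector.
Vec NegMergingModel(const Vec& inactive, const Mask& m, const Vec& a) {
  assert(m.size() == a.size() && a.size() == inactive.size());
  Vec out(inactive);
  for (size_t i = 0; i < a.size(); ++i) {
    if (m[i]) out[i] = -a[i];
  }
  return out;
}
```

For example, `AddMergingModel({true, false}, {1, 2}, {10, 20})` gives `{11, 2}`, whereas `AddZeroingModel` with the same arguments gives `{11, 0}`.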

// ------------------------------ Lanes

namespace detail {
@@ -727,6 +754,9 @@ HWY_API V Or(const V a, const V b) {
return BitCast(df, Or(BitCast(du, a), BitCast(du, b)));
}

// ------------------------------ MaskedOrOrZero
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVVZ, MaskedOrOrZero, orr)

// ------------------------------ Xor

namespace detail {
@@ -3288,6 +3318,25 @@ HWY_API TFromD<D> ReduceMax(D d, VFromD<D> v) {
return detail::MaxOfLanesM(detail::MakeMask(d), v);
}

#ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR
#undef HWY_NATIVE_MASKED_REDUCE_SCALAR
#else
#define HWY_NATIVE_MASKED_REDUCE_SCALAR
#endif

template <class D, class M>

Review comment (Member):
Please add a TODO here that we can remove the SumOfLanesM in favor of using MaskedReduceSum directly. This entails adding the D arg to HWY_SVE_REDUCE_ADD as done in HWY_SVE_FIRSTN.

HWY_API TFromD<D> MaskedReduceSum(D /*d*/, M m, VFromD<D> v) {
return detail::SumOfLanesM(m, v);
}
template <class D, class M>
HWY_API TFromD<D> MaskedReduceMin(D /*d*/, M m, VFromD<D> v) {
return detail::MinOfLanesM(m, v);
}
template <class D, class M>
HWY_API TFromD<D> MaskedReduceMax(D /*d*/, M m, VFromD<D> v) {
return detail::MaxOfLanesM(m, v);
}

// ------------------------------ SumOfLanes

template <class D, HWY_IF_LANES_GT_D(D, 1)>
@@ -4755,6 +4804,23 @@ HWY_API V IfNegativeThenElse(V v, V yes, V no) {
static_assert(IsSigned<TFromV<V>>(), "Only works for signed/float");
return IfThenElse(IsNegative(v), yes, no);
}
// ------------------------------ IfNegativeThenNegOrUndefIfZero

Review comment (Member):
This op is undocumented; do we intend to add it? If so, let's add documentation and a test.

#ifdef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG
#undef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG
#else
#define HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG
#endif

#define HWY_SVE_NEG_IF(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_##CHAR##BITS##_m(v, IsNegative(mask), v); \
}

HWY_SVE_FOREACH_IF(HWY_SVE_NEG_IF, IfNegativeThenNegOrUndefIfZero, neg)

#undef HWY_SVE_NEG_IF
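
Since the review notes that this op is undocumented in the PR, here is one scalar reading of the `HWY_SVE_NEG_IF` macro as written above: lanes whose `mask` value is negative receive the negated `v`, and all other lanes keep `v`. The op name suggests callers should not rely on the result for `mask[i] == 0`; in this model those lanes simply return `v[i]`. `NegIfMaskNegativeModel` is a hypothetical name, and negation goes through unsigned arithmetic so that the most negative value wraps instead of overflowing.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical scalar model of the SVE macro: merging negate, where the
// predicate is "mask lane is negative" and inactive lanes keep v.
std::vector<int32_t> NegIfMaskNegativeModel(const std::vector<int32_t>& mask,
                                            const std::vector<int32_t>& v) {
  assert(mask.size() == v.size());
  std::vector<int32_t> out(v);
  for (size_t i = 0; i < v.size(); ++i) {
    if (mask[i] < 0) {
      // Wrapping negate via unsigned math to avoid signed overflow.
      out[i] = static_cast<int32_t>(0u - static_cast<uint32_t>(v[i]));
    }
  }
  return out;
}
```

With `mask = {-1, 2, 0, -7}` and `v = {5, 6, 7, -8}`, the model returns `{-5, 6, 7, 8}`.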

// ------------------------------ AverageRound (ShiftRight)

@@ -6291,13 +6357,19 @@ HWY_API V HighestSetBitIndex(V v) {
#undef HWY_SVE_IF_NOT_EMULATED_D
#undef HWY_SVE_PTRUE
#undef HWY_SVE_RETV_ARGMVV
#undef HWY_SVE_RETV_ARGMVVZ
#undef HWY_SVE_RETV_ARGPV
#undef HWY_SVE_RETV_ARGPVN
#undef HWY_SVE_RETV_ARGPVV
#undef HWY_SVE_RETV_ARGV
#undef HWY_SVE_RETV_ARGVN
#undef HWY_SVE_RETV_ARGMV
#undef HWY_SVE_RETV_ARGMV_M
#undef HWY_SVE_RETV_ARGMV_Z
#undef HWY_SVE_RETV_ARGVV
#undef HWY_SVE_RETV_ARGMVV_M
#undef HWY_SVE_RETV_ARGVVV
#undef HWY_SVE_RETV_ARGMVVV
#undef HWY_SVE_T
#undef HWY_SVE_UNDEFINED
#undef HWY_SVE_V

50 changes: 50 additions & 0 deletions hwy/ops/generic_ops-inl.h
@@ -882,6 +882,28 @@ HWY_API TFromD<D> ReduceMax(D d, VFromD<D> v) {
}
#endif // HWY_NATIVE_REDUCE_MINMAX_4_UI8

#if (defined(HWY_NATIVE_MASKED_REDUCE_SCALAR) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR
#undef HWY_NATIVE_MASKED_REDUCE_SCALAR
#else
#define HWY_NATIVE_MASKED_REDUCE_SCALAR
#endif

template <class D, class M>
HWY_API TFromD<D> MaskedReduceSum(D d, M m, VFromD<D> v) {
return ReduceSum(d, IfThenElseZero(m, v));
}
template <class D, class M>
HWY_API TFromD<D> MaskedReduceMin(D d, M m, VFromD<D> v) {
return ReduceMin(d, IfThenElse(m, v, MaxOfLanes(d, v)));

Review comment (Member):
This seems unnecessarily expensive, how about we replace MaxOfLanes with Set(d, hwy::HighestValue)?

}
template <class D, class M>
HWY_API TFromD<D> MaskedReduceMax(D d, M m, VFromD<D> v) {
return ReduceMax(d, IfThenElseZero(m, v));

Review comment (Member):
I think we can get into trouble for signed values. If all values are negative, the presence of mask=false elements changes the result. Can similarly use hwy::LowestValue here?

}

#endif // HWY_NATIVE_MASKED_REDUCE_SCALAR
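
The two review comments on the min and max fallbacks come down to choosing the right fill value for masked-off lanes. The scalar sketch below illustrates the point under stated assumptions: `MaxWithFill` and `MinWithFill` are hypothetical helpers, and `std::numeric_limits` stands in for the `hwy::LowestValue` and `hwy::HighestValue` constants the reviewer mentions.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Replace masked-off lanes with `fill`, then take the maximum of all lanes.
int32_t MaxWithFill(const std::vector<bool>& m, const std::vector<int32_t>& v,
                    int32_t fill) {
  assert(m.size() == v.size() && !v.empty());
  int32_t best = m[0] ? v[0] : fill;
  for (size_t i = 1; i < v.size(); ++i) {
    best = std::max(best, m[i] ? v[i] : fill);
  }
  return best;
}

// Replace masked-off lanes with `fill`, then take the minimum of all lanes.
int32_t MinWithFill(const std::vector<bool>& m, const std::vector<int32_t>& v,
                    int32_t fill) {
  assert(m.size() == v.size() && !v.empty());
  int32_t best = m[0] ? v[0] : fill;
  for (size_t i = 1; i < v.size(); ++i) {
    best = std::min(best, m[i] ? v[i] : fill);
  }
  return best;
}
```

Take `v = {-5, -3, -9, -1}` with only the first two lanes active. A zero fill makes `MaxWithFill` return 0, a value that no active lane holds, while filling with `std::numeric_limits<int32_t>::lowest()` returns -3. Likewise, `MinWithFill` with `std::numeric_limits<int32_t>::max()` returns -5 without needing the extra `MaxOfLanes` reduction.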

// ------------------------------ IsEitherNaN
#if (defined(HWY_NATIVE_IS_EITHER_NAN) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_IS_EITHER_NAN
@@ -6444,6 +6466,30 @@ HWY_API V ReverseBits(V v) {
}
#endif // HWY_NATIVE_REVERSE_BITS_UI16_32_64

// ------------------------------ TableLookupLanesOr
template <class V, class M>
HWY_API V TableLookupLanesOr(M m, V a, V b, IndicesFromD<DFromV<V>> idx) {
return IfThenElse(m, TableLookupLanes(a, idx), b);
}

// ------------------------------ TableLookupLanesOrZero
template <class V, class M>
HWY_API V TableLookupLanesOrZero(M m, V a, IndicesFromD<DFromV<V>> idx) {
return IfThenElseZero(m, TableLookupLanes(a, idx));
}

// ------------------------------ TwoTablesLookupLanesOr
template <class D, class V, class M>
HWY_API V TwoTablesLookupLanesOr(D d, M m, V a, V b, IndicesFromD<D> idx) {
return IfThenElse(m, TwoTablesLookupLanes(d, a, b, idx), a);
}

// ------------------------------ TwoTablesLookupLanesOrZero
template <class D, class V, class M>
HWY_API V TwoTablesLookupLanesOrZero(D d, M m, V a, V b, IndicesFromD<D> idx) {
return IfThenElse(m, TwoTablesLookupLanes(d, a, b, idx), Zero(d));
}
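
For readers less familiar with the lookup ops, the wrappers above can be modelled in scalar form as follows. This is a sketch only: the `...Model` names are hypothetical, lanes are a `std::vector<int32_t>`, and indices are plain `size_t` values rather than Highway's opaque index type. The `OrZero` variants follow the same pattern with zero as the fallback.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Vec = std::vector<int32_t>;

// out[i] = m[i] ? a[idx[i]] : b[i], with idx[i] in [0, N).
Vec TableLookupLanesOrModel(const std::vector<bool>& m, const Vec& a,
                            const Vec& b, const std::vector<size_t>& idx) {
  assert(m.size() == a.size() && a.size() == b.size() &&
         a.size() == idx.size());
  Vec out(b);  // fallback for masked-off lanes
  for (size_t i = 0; i < a.size(); ++i) {
    if (m[i]) {
      assert(idx[i] < a.size());
      out[i] = a[idx[i]];
    }
  }
  return out;
}

// idx[i] in [0, 2N): values below N select from a, the rest from b.
// Masked-off lanes fall back to a[i], matching the wrapper in this diff.
Vec TwoTablesLookupLanesOrModel(const std::vector<bool>& m, const Vec& a,
                                const Vec& b, const std::vector<size_t>& idx) {
  assert(m.size() == a.size() && a.size() == b.size() &&
         a.size() == idx.size());
  const size_t n = a.size();
  Vec out(a);  // fallback for masked-off lanes
  for (size_t i = 0; i < n; ++i) {
    if (m[i]) {
      assert(idx[i] < 2 * n);
      out[i] = (idx[i] < n) ? a[idx[i]] : b[idx[i] - n];
    }
  }
  return out;
}
```

With `a = {10, 20, 30, 40}`, `b = {1, 2, 3, 4}`, `idx = {3, 2, 1, 0}` and `m = {true, false, true, false}`, the single-table model returns `{40, 2, 20, 4}`.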

// ------------------------------ Per4LaneBlockShuffle

#if (defined(HWY_NATIVE_PER4LANEBLKSHUF_DUP32) == defined(HWY_TARGET_TOGGLE))
@@ -7299,6 +7345,10 @@ HWY_API V BitShuffle(V v, VI idx) {

#endif // HWY_NATIVE_BITSHUFFLE

template <class V, class M>
HWY_API V MaskedOrOrZero(M m, V a, V b) {
return IfThenElseZero(m, Or(a, b));
}
// ================================================== Operator wrapper

// SVE* and RVV currently cannot define operators and have already defined

23 changes: 23 additions & 0 deletions hwy/tests/logical_test.cc
@@ -146,6 +146,28 @@ HWY_NOINLINE void TestAllTestBit() {
ForIntegerTypes(ForPartialVectors<TestTestBit>());
}

struct TestMaskedOrOrZero {
template <typename T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
const MFromD<D> all_true = MaskTrue(d);
const auto v1 = Iota(d, 1);
const auto v2 = Iota(d, 2);

HWY_ASSERT_VEC_EQ(d, Or(v2, v1), MaskedOrOrZero(all_true, v1, v2));

const MFromD<D> first_five = FirstN(d, 5);
const Vec<D> v0 = Zero(d);

const Vec<D> v1_exp = IfThenElse(first_five, Or(v2, v1), v0);

HWY_ASSERT_VEC_EQ(d, v1_exp, MaskedOrOrZero(first_five, v1, v2));
}
};

HWY_NOINLINE void TestAllMaskedLogical() {
ForAllTypes(ForPartialVectors<TestMaskedOrOrZero>());
}

} // namespace
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
@@ -159,6 +181,7 @@ HWY_BEFORE_TEST(HwyLogicalTest);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllNot);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllLogical);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllTestBit);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllMaskedLogical);
HWY_AFTER_TEST();
} // namespace
} // namespace hwy