Skip to content

Commit

Permalink
Fix review comments
Browse files Browse the repository at this point in the history
Rename SetOr* ops for consistency
Rename AllOnes/AllZeros to AllBits1/0
Remove MaskedLoadU, this is covered by MaskedLoad
Rename LowerHigher to InsertIntoUpper
Rework StoreTruncated, rename to TruncateStore
Rename macro arg
Avoid full-length load in LoadHigher
Optimise AllBits1
  • Loading branch information
wbb-ccl committed Jan 30, 2025
1 parent ecb2f36 commit 6b90d90
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 113 deletions.
26 changes: 10 additions & 16 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -429,9 +429,9 @@ for comparisons, for example `Lt` instead of `operator<`.
the result, with `t0` in the least-significant (lowest-indexed) lane of each
128-bit block and `tK` in the most-significant (highest-indexed) lane of
each 128-bit block: `{t0, t1, ..., tK}`
* <code>V **SetOr**(V no, M m, T a)</code>: returns N-lane vector with lane
* <code>V **MaskedSetOr**(V no, M m, T a)</code>: returns N-lane vector with lane
`i` equal to `a` if `m[i]` is true else `no[i]`.
* <code>V **SetOrZero**(D d, M m, T a)</code>: returns N-lane vector with lane
* <code>V **MaskedSet**(D d, M m, T a)</code>: returns N-lane vector with lane
`i` equal to `a` if `m[i]` is true else 0.

### Getting/setting lanes
Expand Down Expand Up @@ -1087,10 +1087,10 @@ Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2):
```HighestValue<MakeSigned<TFromV<V>>>()``` is returned in the
corresponding result lanes.

* <code>bool **AllOnes**(D, V v)</code>: returns whether all bits in `v[i]`
* <code>bool **AllBits1**(V v)</code>: returns whether all bits in every lane of `v`
are set.

* <code>bool **AllZeros**(D, V v)</code>: returns whether all bits in `v[i]`
* <code>bool **AllBits0**(V v)</code>: returns whether all bits in every lane of `v`
are clear.

The following operate on individual bits within each lane. Note that the
Expand Down Expand Up @@ -1577,9 +1577,6 @@ aligned memory at indices which are not a multiple of the vector length):

* <code>Vec&lt;D&gt; **LoadU**(D, const T* p)</code>: returns `p[i]`.

* <code>Vec&lt;D&gt; **MaskedLoadU**(D, M m, const T* p)</code>: returns `p[i]`
where mask is true and returns zero otherwise.

* <code>Vec&lt;D&gt; **LoadDup128**(D, const T* p)</code>: returns one 128-bit
block loaded from `p` and broadcasted into all 128-bit block\[s\]. This may
be faster than broadcasting single values, and is more convenient than
Expand Down Expand Up @@ -1610,9 +1607,9 @@ aligned memory at indices which are not a multiple of the vector length):
lanes from `p` to the first (lowest-index) lanes of the result vector and
fills the remaining lanes with `no`. Like LoadN, this does not fault.

* <code> Vec&lt;D&gt; **LoadHigher**(D d, V v, T* p)</code>: Loads `Lanes(d)/2` lanes from
`p` into the upper lanes of the result vector and the lower half of `v` into
the lower lanes.
* <code>Vec&lt;D&gt; **InsertIntoUpper**(D d, T* p, V v)</code>: loads `Lanes(d)/2`
lanes from `p` into the upper half of the result vector and copies the lower
half of `v` into the lower half of the result.

#### Store

Expand Down Expand Up @@ -1653,12 +1650,9 @@ aligned memory at indices which are not a multiple of the vector length):
StoreN does not modify any memory past
`p + HWY_MIN(Lanes(d), max_lanes_to_store) - 1`.

* <code>void **StoreTruncated**(Vec&lt;DFrom&gt; v, DFrom d, To* HWY_RESTRICT
p)</code>: Truncates elements of `v` to type `To` and stores on `p`. It is
similar to performing `TruncateTo` followed by `StoreN` where
`max_lanes_to_store` is `Lanes(d)`.

StoreTruncated does not modify any memory past `p + Lanes(d) - 1`.
* <code>void **TruncateStore**(Vec&lt;D&gt; v, D d, T* HWY_RESTRICT p)</code>:
Truncates each lane of `v` to the narrower type `T` and stores the results to
`p`. Equivalent to `TruncateTo` into a `Rebind<T, D>` vector followed by `StoreU`.

#### Interleaved

Expand Down
44 changes: 16 additions & 28 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,26 +427,26 @@ using VFromD = decltype(Set(D(), TFromD<D>()));

using VBF16 = VFromD<ScalableTag<bfloat16_t>>;

// ------------------------------ SetOr/SetOrZero
// ------------------------------ MaskedSetOr/MaskedSet

#define HWY_SVE_SET_OR(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) inactive, \
svbool_t m, HWY_SVE_T(BASE, BITS) op) { \
return sv##OP##_##CHAR##BITS##_m(inactive, m, op); \
#define HWY_SVE_MASKED_SET_OR(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(HWY_SVE_V(BASE, BITS) no, svbool_t m, HWY_SVE_T(BASE, BITS) op) { \
return sv##OP##_##CHAR##BITS##_m(no, m, op); \
}

HWY_SVE_FOREACH(HWY_SVE_SET_OR, SetOr, dup_n)
#undef HWY_SVE_SET_OR
HWY_SVE_FOREACH(HWY_SVE_MASKED_SET_OR, MaskedSetOr, dup_n)
#undef HWY_SVE_MASKED_SET_OR

#define HWY_SVE_SET_OR_ZERO(BASE, CHAR, BITS, HALF, NAME, OP) \
#define HWY_SVE_MASKED_SET(BASE, CHAR, BITS, HALF, NAME, OP) \
template <size_t N, int kPow2> \
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \
svbool_t m, HWY_SVE_T(BASE, BITS) op) { \
return sv##OP##_##CHAR##BITS##_z(m, op); \
}

HWY_SVE_FOREACH(HWY_SVE_SET_OR_ZERO, SetOrZero, dup_n)
#undef HWY_SVE_SET_OR_ZERO
HWY_SVE_FOREACH(HWY_SVE_MASKED_SET, MaskedSet, dup_n)
#undef HWY_SVE_MASKED_SET

// ------------------------------ Zero

Expand Down Expand Up @@ -2232,18 +2232,6 @@ HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D d,

#undef HWY_SVE_MEM

#define HWY_SVE_MASKED_MEM(BASE, CHAR, BITS, HALF, NAME, OP) \
template <size_t N, int kPow2> \
HWY_API HWY_SVE_V(BASE, BITS) \
MaskedLoadU(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svbool_t m, \
const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \
return svld1_##CHAR##BITS(m, detail::NativeLanePointer(p)); \
}

HWY_SVE_FOREACH(HWY_SVE_MASKED_MEM, _, _)

#undef HWY_SVE_MASKED_MEM

#if HWY_TARGET != HWY_SVE2_128
namespace detail {
#define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, HALF, NAME, OP) \
Expand Down Expand Up @@ -2312,12 +2300,12 @@ HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {
#define HWY_SVE_STORE_TRUNCATED_WORD(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 32)

HWY_SVE_FOREACH_UI16(HWY_SVE_STORE_TRUNCATED_BYTE, StoreTruncated, st1b)
HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_BYTE, StoreTruncated, st1b)
HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_BYTE, StoreTruncated, st1b)
HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_HALF, StoreTruncated, st1h)
HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_HALF, StoreTruncated, st1h)
HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_WORD, StoreTruncated, st1w)
HWY_SVE_FOREACH_UI16(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b)
HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b)
HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b)
HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_HALF, TruncateStore, st1h)
HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_HALF, TruncateStore, st1h)
HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_WORD, TruncateStore, st1w)

#undef HWY_SVE_STORE_TRUNCATED

Expand Down
45 changes: 20 additions & 25 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,18 +97,18 @@ HWY_API Vec<D> Inf(D d) {
return BitCast(d, Set(du, max_x2 >> 1));
}

// ------------------------------ SetOr/SetOrZero
// ------------------------------ MaskedSetOr/MaskedSet

template <class V, typename T = TFromV<V>, typename D = DFromV<V>,
typename M = MFromD<D>>
HWY_API V SetOr(V no, M m, T a) {
HWY_API V MaskedSetOr(V no, M m, T a) {
D d;
return IfThenElse(m, Set(d, a), no);
}

template <class D, typename V = VFromD<D>, typename M = MFromD<D>,
typename T = TFromD<D>>
HWY_API V SetOrZero(D d, M m, T a) {
HWY_API V MaskedSet(D d, M m, T a) {
return IfThenElseZero(m, Set(d, a));
}

Expand Down Expand Up @@ -351,17 +351,18 @@ HWY_API Mask<DTo> DemoteMaskTo(DTo d_to, DFrom d_from, Mask<DFrom> m) {

#endif // HWY_NATIVE_DEMOTE_MASK_TO

// ------------------------------ LoadHigher
// ------------------------------ InsertIntoUpper
#if (defined(HWY_NATIVE_LOAD_HIGHER) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_LOAD_HIGHER
#undef HWY_NATIVE_LOAD_HIGHER
#else
#define HWY_NATIVE_LOAD_HIGHER
#endif
template <class D, typename T, class V = VFromD<D>(), HWY_IF_LANES_GT_D(D, 1)>
HWY_API V LoadHigher(D d, V a, T* p) {
const V b = LoadU(d, p);
return ConcatLowerLower(d, b, a);
HWY_API V InsertIntoUpper(D d, T* p, V a) {
Half<D> dh;
const VFromD<decltype(dh)> b = LoadU(dh, p);
return Combine(d, b, LowerHalf(a));
}
#endif // HWY_NATIVE_LOAD_HIGHER

Expand Down Expand Up @@ -1204,12 +1205,6 @@ HWY_API V MulByFloorPow2(V v, V exp) {

#endif // HWY_NATIVE_MUL_BY_POW2

// ------------------------------ MaskedLoadU
template <class D, class M>
HWY_API VFromD<D> MaskedLoadU(D d, M m,
const TFromD<D>* HWY_RESTRICT unaligned) {
return IfThenElseZero(m, LoadU(d, unaligned));
}
// ------------------------------ GetExponent

#if (defined(HWY_NATIVE_GET_EXPONENT) == defined(HWY_TARGET_TOGGLE))
Expand Down Expand Up @@ -2694,21 +2689,20 @@ HWY_API void StoreN(VFromD<D> v, D d, T* HWY_RESTRICT p,

#endif // (defined(HWY_NATIVE_STORE_N) == defined(HWY_TARGET_TOGGLE))

// ------------------------------ StoreTruncated
// ------------------------------ TruncateStore
#if (defined(HWY_NATIVE_STORE_TRUNCATED) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_STORE_TRUNCATED
#undef HWY_NATIVE_STORE_TRUNCATED
#else
#define HWY_NATIVE_STORE_TRUNCATED
#endif

template <class DFrom, class To, class DTo = Rebind<To, DFrom>,
HWY_IF_T_SIZE_GT_D(DFrom, sizeof(To)),
HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(VFromD<DFrom>)>
HWY_API void StoreTruncated(VFromD<DFrom> v, const DFrom d,
To* HWY_RESTRICT p) {
template <class D, class T, HWY_IF_T_SIZE_GT_D(D, sizeof(T)),
HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API void TruncateStore(VFromD<D> v, const D /*d*/, T* HWY_RESTRICT p) {
using DTo = Rebind<T, D>;
DTo dsmall;
StoreN(TruncateTo(dsmall, v), dsmall, p, Lanes(d));
StoreU(TruncateTo(dsmall, v), dsmall, p);
}

#endif // (defined(HWY_NATIVE_STORE_TRUNCATED) == defined(HWY_TARGET_TOGGLE))
Expand Down Expand Up @@ -7511,7 +7505,7 @@ HWY_API V BitShuffle(V v, VI idx) {

#endif // HWY_NATIVE_BITSHUFFLE

// ------------------------------ AllOnes/AllZeros
// ------------------------------ AllBits1/AllBits0
#if (defined(HWY_NATIVE_ALLONES) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_ALLONES
#undef HWY_NATIVE_ALLONES
Expand All @@ -7520,9 +7514,10 @@ HWY_API V BitShuffle(V v, VI idx) {
#endif

template <class V>
HWY_API bool AllOnes(V a) {
DFromV<V> d;
return AllTrue(d, Eq(Not(a), Zero(d)));
HWY_API bool AllBits1(V a) {
const RebindToUnsigned<DFromV<V>> du;
using TU = TFromD<decltype(du)>;
return AllTrue(du, Eq(BitCast(du, a), Set(du, hwy::HighestValue<TU>())));
}
#endif // HWY_NATIVE_ALLONES

Expand All @@ -7534,7 +7529,7 @@ HWY_API bool AllOnes(V a) {
#endif

template <class V>
HWY_API bool AllZeros(V a) {
HWY_API bool AllBits0(V a) {
DFromV<V> d;
return AllTrue(d, Eq(a, Zero(d)));
}
Expand Down
24 changes: 12 additions & 12 deletions hwy/tests/logical_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,30 +146,30 @@ HWY_NOINLINE void TestAllTestBit() {
ForIntegerTypes(ForPartialVectors<TestTestBit>());
}

struct TestAllOnes {
struct TestAllBits {
template <class T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
auto v0s = Zero(d);
HWY_ASSERT(AllZeros(v0s));
HWY_ASSERT(AllBits0(v0s));
auto v1s = Not(v0s);
HWY_ASSERT(AllOnes(v1s));
HWY_ASSERT(AllBits1(v1s));
const size_t kNumBits = sizeof(T) * 8;
for (size_t i = 0; i < kNumBits; ++i) {
const Vec<D> bit1 = Set(d, static_cast<T>(1ull << i));
const Vec<D> bit2 = Set(d, static_cast<T>(1ull << ((i + 1) % kNumBits)));
const Vec<D> bits12 = Or(bit1, bit2);
HWY_ASSERT(!AllOnes(bit1));
HWY_ASSERT(!AllZeros(bit1));
HWY_ASSERT(!AllOnes(bit2));
HWY_ASSERT(!AllZeros(bit2));
HWY_ASSERT(!AllOnes(bits12));
HWY_ASSERT(!AllZeros(bits12));
HWY_ASSERT(!AllBits1(bit1));
HWY_ASSERT(!AllBits0(bit1));
HWY_ASSERT(!AllBits1(bit2));
HWY_ASSERT(!AllBits0(bit2));
HWY_ASSERT(!AllBits1(bits12));
HWY_ASSERT(!AllBits0(bits12));
}
}
};

HWY_NOINLINE void TestAllAllOnes() {
ForIntegerTypes(ForPartialVectors<TestAllOnes>());
HWY_NOINLINE void TestAllAllBits() {
ForIntegerTypes(ForPartialVectors<TestAllBits>());
}

} // namespace
Expand All @@ -185,7 +185,7 @@ HWY_BEFORE_TEST(HwyLogicalTest);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllNot);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllLogical);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllTestBit);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllAllOnes);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllAllBits);

HWY_AFTER_TEST();
} // namespace
Expand Down
20 changes: 10 additions & 10 deletions hwy/tests/mask_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -317,14 +317,14 @@ HWY_NOINLINE void TestAllLogicalMask() {
ForAllTypes(ForPartialVectors<TestLogicalMask>());
}

struct TestSetOr {
struct TestMaskedSetOr {
template <class D>
void testWithMask(D d, MFromD<D> m) {
TFromD<D> a = 1;
auto yes = Set(d, a);
auto no = Set(d, 2);
auto expected = IfThenElse(m, yes, no);
auto actual = SetOr(no, m, a);
auto actual = MaskedSetOr(no, m, a);
HWY_ASSERT_VEC_EQ(d, expected, actual);
}
template <class T, class D>
Expand All @@ -344,18 +344,18 @@ struct TestSetOr {
}
};

HWY_NOINLINE void TestAllSetOr() {
ForAllTypes(ForShrinkableVectors<TestSetOr>());
HWY_NOINLINE void TestAllMaskedSetOr() {
ForAllTypes(ForShrinkableVectors<TestMaskedSetOr>());
}

struct TestSetOrZero {
struct TestMaskedSet {
template <class D>
void testWithMask(D d, MFromD<D> m) {
TFromD<D> a = 1;
auto yes = Set(d, a);
auto no = Zero(d);
auto expected = IfThenElse(m, yes, no);
auto actual = SetOrZero(d, m, a);
auto actual = MaskedSet(d, m, a);
HWY_ASSERT_VEC_EQ(d, expected, actual);
}
template <class T, class D>
Expand All @@ -375,8 +375,8 @@ struct TestSetOrZero {
}
};

HWY_NOINLINE void TestAllSetOrZero() {
ForAllTypes(ForShrinkableVectors<TestSetOrZero>());
HWY_NOINLINE void TestAllMaskedSet() {
ForAllTypes(ForShrinkableVectors<TestMaskedSet>());
}

} // namespace
Expand All @@ -397,8 +397,8 @@ HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllCountTrue);
HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllFindFirstTrue);
HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllFindLastTrue);
HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllLogicalMask);
HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllSetOr);
HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllSetOrZero);
HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllMaskedSetOr);
HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllMaskedSet);
HWY_AFTER_TEST();
} // namespace
} // namespace hwy
Expand Down
Loading

0 comments on commit 6b90d90

Please sign in to comment.