Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load/Store, masked set and counting operations #2430

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,10 @@ for comparisons, for example `Lt` instead of `operator<`.
the result, with `t0` in the least-significant (lowest-indexed) lane of each
128-bit block and `tK` in the most-significant (highest-indexed) lane of
each 128-bit block: `{t0, t1, ..., tK}`
* <code>V **MaskedSetOr**(V no, M m, T a)</code>: returns N-lane vector with lane
`i` equal to `a` if `m[i]` is true else `no[i]`.
* <code>V **MaskedSet**(D d, M m, T a)</code>: returns N-lane vector with lane
`i` equal to `a` if `m[i]` is true else 0.

### Getting/setting lanes

Expand Down Expand Up @@ -1065,6 +1069,10 @@ Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2):
leading zeros in each lane. For any lanes where ```a[i]``` is zero,
```sizeof(TFromV<V>) * 8``` is returned in the corresponding result lanes.

* `V`: `{u,i}` \
<code>V **MaskedLeadingZeroCountOrZero**(M m, `V a)</code>: returns the
jan-wassenberg marked this conversation as resolved.
Show resolved Hide resolved
result of LeadingZeroCount where `m[i]` is true, and zero otherwise.

* `V`: `{u,i}` \
<code>V **TrailingZeroCount**(V a)</code>: returns the number of
trailing zeros in each lane. For any lanes where ```a[i]``` is zero,
Expand All @@ -1079,6 +1087,12 @@ Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2):
```HighestValue<MakeSigned<TFromV<V>>>()``` is returned in the
corresponding result lanes.

* <code>bool **AllBits1**(D, V v)</code>: returns whether all bits in `v[i]`
jan-wassenberg marked this conversation as resolved.
Show resolved Hide resolved
are set.

* <code>bool **AllBits0**(D, V v)</code>: returns whether all bits in `v[i]`
are clear.

The following operate on individual bits within each lane. Note that the
non-operator functions (`And` instead of `&`) must be used for floating-point
types, and on SVE/RVV.
Expand Down Expand Up @@ -1593,6 +1607,10 @@ aligned memory at indices which are not a multiple of the vector length):
lanes from `p` to the first (lowest-index) lanes of the result vector and
fills the remaining lanes with `no`. Like LoadN, this does not fault.

* <code> Vec&lt;D&gt; **InsertIntoUpper**(D d, T* p, V v)</code>: Loads `Lanes(d)/2`
lanes from `p` into the upper lanes of the result vector and the lower half
of `v` into the lower lanes.

#### Store

* <code>void **Store**(Vec&lt;D&gt; v, D, T* aligned)</code>: copies `v[i]`
Expand Down Expand Up @@ -1632,6 +1650,10 @@ aligned memory at indices which are not a multiple of the vector length):
StoreN does not modify any memory past
`p + HWY_MIN(Lanes(d), max_lanes_to_store) - 1`.

* <code>void **TruncateStore**(Vec&lt;D&gt; v, D d, T* HWY_RESTRICT p)</code>:
Truncates elements of `v` to type `T` and stores on `p`. It is similar to
performing `TruncateTo` followed by `StoreU`.

#### Interleaved

* <code>void **LoadInterleaved2**(D, const T* p, Vec&lt;D&gt;&amp; v0,
Expand Down
68 changes: 68 additions & 0 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,27 @@ using VFromD = decltype(Set(D(), TFromD<D>()));

using VBF16 = VFromD<ScalableTag<bfloat16_t>>;

// ------------------------------ MaskedSetOr/MaskedSet

#define HWY_SVE_MASKED_SET_OR(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(HWY_SVE_V(BASE, BITS) no, svbool_t m, HWY_SVE_T(BASE, BITS) op) { \
return sv##OP##_##CHAR##BITS##_m(no, m, op); \
}

HWY_SVE_FOREACH(HWY_SVE_MASKED_SET_OR, MaskedSetOr, dup_n)
#undef HWY_SVE_MASKED_SET_OR

#define HWY_SVE_MASKED_SET(BASE, CHAR, BITS, HALF, NAME, OP) \
template <size_t N, int kPow2> \
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \
svbool_t m, HWY_SVE_T(BASE, BITS) op) { \
return sv##OP##_##CHAR##BITS##_z(m, op); \
}

HWY_SVE_FOREACH(HWY_SVE_MASKED_SET, MaskedSet, dup_n)
#undef HWY_SVE_MASKED_SET

// ------------------------------ Zero

template <class D>
Expand Down Expand Up @@ -2257,6 +2278,37 @@ HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {

#endif // HWY_TARGET != HWY_SVE2_128

// Truncate to smaller size and store
#ifdef HWY_NATIVE_STORE_TRUNCATED
#undef HWY_NATIVE_STORE_TRUNCATED
#else
#define HWY_NATIVE_STORE_TRUNCATED
#endif

#define HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, TO_BITS) \
template <size_t N, int kPow2> \
HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \
const HWY_SVE_D(BASE, BITS, N, kPow2) d, \
HWY_SVE_T(BASE, TO_BITS) * HWY_RESTRICT p) { \
sv##OP##_##CHAR##BITS(detail::PTrue(d), detail::NativeLanePointer(p), v); \
}

#define HWY_SVE_STORE_TRUNCATED_BYTE(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 8)
#define HWY_SVE_STORE_TRUNCATED_HALF(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 16)
#define HWY_SVE_STORE_TRUNCATED_WORD(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 32)

HWY_SVE_FOREACH_UI16(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b)
HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b)
HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b)
HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_HALF, TruncateStore, st1h)
HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_HALF, TruncateStore, st1h)
HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_WORD, TruncateStore, st1w)

#undef HWY_SVE_STORE_TRUNCATED

// ------------------------------ Load/Store

// SVE only requires lane alignment, not natural alignment of the entire
Expand Down Expand Up @@ -6442,6 +6494,22 @@ HWY_API V HighestSetBitIndex(V v) {
return BitCast(d, Sub(Set(d, T{sizeof(T) * 8 - 1}), LeadingZeroCount(v)));
}

#ifdef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#undef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#else
#define HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#endif

#define HWY_SVE_MASKED_LEADING_ZERO_COUNT(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \
const DFromV<decltype(v)> d; \
return BitCast(d, sv##OP##_##CHAR##BITS##_z(m, v)); \
}

HWY_SVE_FOREACH_UI(HWY_SVE_MASKED_LEADING_ZERO_COUNT,
MaskedLeadingZeroCountOrZero, clz)
#undef HWY_SVE_LEADING_ZERO_COUNT

// ================================================== END MACROS
#undef HWY_SVE_ALL_PTRUE
#undef HWY_SVE_D
Expand Down
92 changes: 92 additions & 0 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,21 @@ HWY_API Vec<D> Inf(D d) {
return BitCast(d, Set(du, max_x2 >> 1));
}

// ------------------------------ MaskedSetOr/MaskedSet

template <class V, typename T = TFromV<V>, typename D = DFromV<V>,
typename M = MFromD<D>>
HWY_API V MaskedSetOr(V no, M m, T a) {
D d;
return IfThenElse(m, Set(d, a), no);
}

template <class D, typename V = VFromD<D>, typename M = MFromD<D>,
typename T = TFromD<D>>
HWY_API V MaskedSet(D d, M m, T a) {
return IfThenElseZero(m, Set(d, a));
}

// ------------------------------ ZeroExtendResizeBitCast

// The implementation of detail::ZeroExtendResizeBitCast for the HWY_EMU128
Expand Down Expand Up @@ -336,6 +351,21 @@ HWY_API Mask<DTo> DemoteMaskTo(DTo d_to, DFrom d_from, Mask<DFrom> m) {

#endif // HWY_NATIVE_DEMOTE_MASK_TO

// ------------------------------ InsertIntoUpper
#if (defined(HWY_NATIVE_LOAD_HIGHER) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_LOAD_HIGHER
#undef HWY_NATIVE_LOAD_HIGHER
#else
#define HWY_NATIVE_LOAD_HIGHER
#endif
template <class D, typename T, class V = VFromD<D>(), HWY_IF_LANES_GT_D(D, 1)>
HWY_API V InsertIntoUpper(D d, T* p, V a) {
Half<D> dh;
const VFromD<decltype(dh)> b = LoadU(dh, p);
return Combine(d, b, LowerHalf(a));
}
#endif // HWY_NATIVE_LOAD_HIGHER

// ------------------------------ CombineMasks

#if (defined(HWY_NATIVE_COMBINE_MASKS) == defined(HWY_TARGET_TOGGLE))
Expand Down Expand Up @@ -2659,6 +2689,24 @@ HWY_API void StoreN(VFromD<D> v, D d, T* HWY_RESTRICT p,

#endif // (defined(HWY_NATIVE_STORE_N) == defined(HWY_TARGET_TOGGLE))

// ------------------------------ TruncateStore
#if (defined(HWY_NATIVE_STORE_TRUNCATED) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_STORE_TRUNCATED
#undef HWY_NATIVE_STORE_TRUNCATED
#else
#define HWY_NATIVE_STORE_TRUNCATED
#endif

template <class D, class T, HWY_IF_T_SIZE_GT_D(D, sizeof(T)),
HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API void TruncateStore(VFromD<D> v, const D /*d*/, T* HWY_RESTRICT p) {
using DTo = Rebind<T, D>;
DTo dsmall;
StoreU(TruncateTo(dsmall, v), dsmall, p);
}

#endif // (defined(HWY_NATIVE_STORE_TRUNCATED) == defined(HWY_TARGET_TOGGLE))

// ------------------------------ Scatter

#if (defined(HWY_NATIVE_SCATTER) == defined(HWY_TARGET_TOGGLE))
Expand Down Expand Up @@ -3886,6 +3934,21 @@ HWY_API V TrailingZeroCount(V v) {
}
#endif // HWY_NATIVE_LEADING_ZERO_COUNT

// ------------------------------ MaskedLeadingZeroCountOrZero
#if (defined(HWY_NATIVE_MASKED_LEADING_ZERO_COUNT) == \
defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#undef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#else
#define HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#endif

template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), class M>
HWY_API V MaskedLeadingZeroCountOrZero(M m, V v) {
return IfThenElseZero(m, LeadingZeroCount(v));
}
#endif // HWY_NATIVE_MASKED_LEADING_ZERO_COUNT

// ------------------------------ AESRound

// Cannot implement on scalar: need at least 16 bytes for TableLookupBytes.
Expand Down Expand Up @@ -7442,6 +7505,35 @@ HWY_API V BitShuffle(V v, VI idx) {

#endif // HWY_NATIVE_BITSHUFFLE

// ------------------------------ AllBits1/AllBits0
#if (defined(HWY_NATIVE_ALLONES) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_ALLONES
#undef HWY_NATIVE_ALLONES
#else
#define HWY_NATIVE_ALLONES
#endif

template <class V>
HWY_API bool AllBits1(V a) {
const RebindToUnsigned<DFromV<V>> du;
using TU = TFromD<decltype(du)>;
return AllTrue(du, Eq(BitCast(du, a), Set(du, hwy::HighestValue<TU>())));
}
#endif // HWY_NATIVE_ALLONES

#if (defined(HWY_NATIVE_ALLZEROS) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_ALLZEROS
#undef HWY_NATIVE_ALLZEROS
#else
#define HWY_NATIVE_ALLZEROS
#endif

template <class V>
HWY_API bool AllBits0(V a) {
DFromV<V> d;
return AllTrue(d, Eq(a, Zero(d)));
}
#endif // HWY_NATIVE_ALLZEROS
// ================================================== Operator wrapper

// SVE* and RVV currently cannot define operators and have already defined
Expand Down
43 changes: 43 additions & 0 deletions hwy/tests/count_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,48 @@ HWY_NOINLINE void TestAllLeadingZeroCount() {
ForIntegerTypes(ForPartialVectors<TestLeadingZeroCount>());
}

struct TestMaskedLeadingZeroCount {
template <class T, class D>
HWY_ATTR_NO_MSAN HWY_NOINLINE void operator()(T /*unused*/, D d) {
RandomState rng;
using TU = MakeUnsigned<T>;
const RebindToUnsigned<decltype(d)> du;
size_t N = Lanes(d);
const MFromD<D> first_3 = FirstN(d, 3);
auto data = AllocateAligned<T>(N);
auto lzcnt = AllocateAligned<T>(N);
HWY_ASSERT(data && lzcnt);

constexpr T kNumOfBitsInT = static_cast<T>(sizeof(T) * 8);
for (size_t j = 0; j < N; j++) {
if (j < 3) {
lzcnt[j] = static_cast<T>(kNumOfBitsInT - 2);
} else {
lzcnt[j] = static_cast<T>(0);
}
}
HWY_ASSERT_VEC_EQ(
d, lzcnt.get(),
MaskedLeadingZeroCountOrZero(first_3, Set(d, static_cast<T>(2))));

for (size_t j = 0; j < N; j++) {
if (j < 3) {
lzcnt[j] = static_cast<T>(1);
} else {
lzcnt[j] = static_cast<T>(0);
}
}
HWY_ASSERT_VEC_EQ(
d, lzcnt.get(),
MaskedLeadingZeroCountOrZero(
first_3, BitCast(d, Set(du, TU{1} << (kNumOfBitsInT - 2)))));
}
};

HWY_NOINLINE void TestAllMaskedLeadingZeroCount() {
ForIntegerTypes(ForPartialVectors<TestMaskedLeadingZeroCount>());
}

template <class T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T),
HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2) | (1 << 4))>
static HWY_INLINE T TrailingZeroCountOfValue(T val) {
Expand Down Expand Up @@ -303,6 +345,7 @@ namespace {
HWY_BEFORE_TEST(HwyCountTest);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllPopulationCount);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllLeadingZeroCount);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllMaskedLeadingZeroCount);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllTrailingZeroCount);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllHighestSetBitIndex);
HWY_AFTER_TEST();
Expand Down
28 changes: 28 additions & 0 deletions hwy/tests/logical_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,32 @@ HWY_NOINLINE void TestAllTestBit() {
ForIntegerTypes(ForPartialVectors<TestTestBit>());
}

struct TestAllBits {
template <class T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
auto v0s = Zero(d);
HWY_ASSERT(AllBits0(v0s));
auto v1s = Not(v0s);
HWY_ASSERT(AllBits1(v1s));
const size_t kNumBits = sizeof(T) * 8;
for (size_t i = 0; i < kNumBits; ++i) {
const Vec<D> bit1 = Set(d, static_cast<T>(1ull << i));
const Vec<D> bit2 = Set(d, static_cast<T>(1ull << ((i + 1) % kNumBits)));
const Vec<D> bits12 = Or(bit1, bit2);
HWY_ASSERT(!AllBits1(bit1));
HWY_ASSERT(!AllBits0(bit1));
HWY_ASSERT(!AllBits1(bit2));
HWY_ASSERT(!AllBits0(bit2));
HWY_ASSERT(!AllBits1(bits12));
HWY_ASSERT(!AllBits0(bits12));
}
}
};

HWY_NOINLINE void TestAllAllBits() {
ForIntegerTypes(ForPartialVectors<TestAllBits>());
}

} // namespace
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
Expand All @@ -159,6 +185,8 @@ HWY_BEFORE_TEST(HwyLogicalTest);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllNot);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllLogical);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllTestBit);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllAllBits);

HWY_AFTER_TEST();
} // namespace
} // namespace hwy
Expand Down
Loading
Loading