Skip to content

Commit

Permalink
Fix review comments
Browse files Browse the repository at this point in the history
Remove OrZero suffixes for consistency
Convert SqrtLower into MaskedSqrtOr
Add TODO comments about GetExponent to x86_512 and ppc_vsx
  • Loading branch information
wbb-ccl committed Jan 28, 2025
1 parent bb045cc commit e3c6c3b
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 78 deletions.
13 changes: 6 additions & 7 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -657,10 +657,6 @@ from left to right, of the arguments passed to `Create{2-4}`.
* `V`: `{f}` \
<code>V **Sqrt**(V a)</code>: returns `sqrt(a[i])`.

* `V`: `{f}` \
<code>V **SqrtLower**(V a)</code>: returns `sqrt(a[0])` in lowest lane and
`a[i]` elsewhere.

* `V`: `{f}` \
<code>V **ApproximateReciprocalSqrt**(V a)</code>: returns an approximation
of `1.0 / sqrt(a[i])`. `sqrt(a) ~= ApproximateReciprocalSqrt(a) * a`. x86
Expand Down Expand Up @@ -893,6 +889,9 @@ exceptions for those lanes if that is supported by the ISA. When exceptions are
not a concern, these are equivalent to, and potentially more efficient than,
`IfThenElse(m, Add(a, b), no);` etc.

* `V`: `{f}` \
<code>V **MaskedSqrtOr**(V no, M m, V a)</code>: returns `sqrt(a[i])` or
`no[i]` if `m[i]` is false.
* <code>V **MaskedMinOr**(V no, M m, V a, V b)</code>: returns `Min(a, b)[i]`
or `no[i]` if `m[i]` is false.
* <code>V **MaskedMaxOr**(V no, M m, V a, V b)</code>: returns `Max(a, b)[i]`
Expand Down Expand Up @@ -923,13 +922,13 @@ All ops in this section return `0` for `mask=false` lanes. These are equivalent
to, and potentially more efficient than, `IfThenElseZero(m, Add(a, b));` etc.

* `V`: `{f}` \
<code>V **MaskedSqrtOrZero**(M m, V a)</code>: returns `sqrt(a[i])` where
<code>V **MaskedSqrt**(M m, V a)</code>: returns `sqrt(a[i])` where
`m[i]` is true, and zero otherwise.
* `V`: `{f}` \
<code>V **MaskedApproximateReciprocalSqrtOrZero**(M m, V a)</code>: returns
<code>V **MaskedApproximateReciprocalSqrt**(M m, V a)</code>: returns
the result of `ApproximateReciprocalSqrt(a)[i]` where `m[i]` is true, and
zero otherwise.
* `V`: `{f}` \
<code>V **MaskedApproximateReciprocalOrZero**(M m, V a)</code>: returns the
<code>V **MaskedApproximateReciprocal**(M m, V a)</code>: returns the
result of `ApproximateReciprocal(a)[i]` where `m[i]` is true, and zero
otherwise.

#### Shifts
Expand Down
37 changes: 15 additions & 22 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,9 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_##CHAR##BITS(v); \
}
#define HWY_SVE_RETV_ARGMV_M(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS##_m(b, m, a); \
#define HWY_SVE_RETV_ARGMV(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_##CHAR##BITS##_x(m, v); \
}
#define HWY_SVE_RETV_ARGMV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a) { \
Expand Down Expand Up @@ -1244,27 +1243,13 @@ HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGV, ApproximateReciprocal, recpe)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Sqrt, sqrt)

// ------------------------------ MaskedSqrt
namespace detail {
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV_M, MaskedSqrt, sqrt)
}

// ------------------------------ SqrtLower
#ifdef HWY_NATIVE_SQRT_LOWER
#undef HWY_NATIVE_SQRT_LOWER
#ifdef HWY_NATIVE_MASKED_SQRT
#undef HWY_NATIVE_MASKED_SQRT
#else
#define HWY_NATIVE_SQRT_LOWER
#define HWY_NATIVE_MASKED_SQRT
#endif

// Generates NAME(a) for one float lane type. It forwards to the three-argument
// detail::MaskedSqrt with a predicate built from svptrue_pat_b##BITS(SV_VL1)
// and passes `a` as both the sqrt input and the fallback vector, matching the
// documented SqrtLower contract: sqrt(a[0]) in the lowest lane, a[i] elsewhere.
// The OP parameter is not referenced by the macro body.
#define HWY_SVE_SQRT_LOWER(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) a) { \
return detail::MaskedSqrt(svptrue_pat_b##BITS(SV_VL1), a, a); \
}

// Instantiate SqrtLower via HWY_SVE_FOREACH_F; `_` is a placeholder for the
// unused OP argument. The helper macro is local to this section, hence the
// #undef.
HWY_SVE_FOREACH_F(HWY_SVE_SQRT_LOWER, SqrtLower, _)
#undef HWY_SVE_SQRT_LOWER

// ------------------------------ MaskedSqrtOrZero
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV_Z, MaskedSqrtOrZero, sqrt)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV_Z, MaskedSqrt, sqrt)

// ------------------------------ ApproximateReciprocalSqrt
#ifdef HWY_NATIVE_F64_APPROX_RSQRT
Expand Down Expand Up @@ -1553,6 +1538,7 @@ HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMul, mul)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMVV, MaskedDiv, div)
HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGMVV, MaskedDiv, div)
HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGMVV, MaskedDiv, div)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV, MaskedSqrt, sqrt)
#if HWY_SVE_HAVE_2
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedSatAdd, qadd)
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedSatSub, qsub)
Expand Down Expand Up @@ -1616,6 +1602,11 @@ HWY_API V MaskedSatSubOr(V no, M m, V a, V b) {
}
#endif

// Returns sqrt(v[i]) in lanes where m[i] is true and no[i] in all other lanes.
template <class V, HWY_IF_FLOAT_V(V), class M>
HWY_API V MaskedSqrtOr(V no, M m, V v) {
  // Take the roots under the mask first, then blend in `no` for every lane the
  // mask excludes.
  const V masked_root = detail::MaskedSqrt(m, v);
  return IfThenElse(m, masked_root, no);
}

// ================================================== REDUCE

#ifdef HWY_NATIVE_REDUCE_SCALAR
Expand Down Expand Up @@ -6412,6 +6403,8 @@ HWY_API V HighestSetBitIndex(V v) {
#undef HWY_SVE_IF_NOT_EMULATED_D
#undef HWY_SVE_PTRUE
#undef HWY_SVE_RETV_ARGMVV
#undef HWY_SVE_RETV_ARGMV_Z
#undef HWY_SVE_RETV_ARGMV
#undef HWY_SVE_RETV_ARGPV
#undef HWY_SVE_RETV_ARGPVN
#undef HWY_SVE_RETV_ARGPVV
Expand Down
38 changes: 17 additions & 21 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -5275,29 +5275,25 @@ HWY_API VFromD<DI32> SatWidenMulAccumFixedPoint(DI32 di32,

#endif // HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT

// ------------------------------ SqrtLower
#if (defined(HWY_NATIVE_SQRT_LOWER) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_SQRT_LOWER
#undef HWY_NATIVE_SQRT_LOWER
// ------------------------------ MaskedSqrt

#if (defined(HWY_NATIVE_MASKED_SQRT) == defined(HWY_TARGET_TOGGLE))

#ifdef HWY_NATIVE_MASKED_SQRT
#undef HWY_NATIVE_MASKED_SQRT
#else
#define HWY_NATIVE_SQRT_LOWER
#define HWY_NATIVE_MASKED_SQRT
#endif

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V SqrtLower(V a) {
const DFromV<V> d;
const auto first_mask = FirstN(d, 1);
return IfThenElse(first_mask, Sqrt(a), a);
template <class V, HWY_IF_FLOAT_V(V), class M>
HWY_API V MaskedSqrt(M m, V v) {
return IfThenElseZero(m, Sqrt(v));
}

#undef HWY_SVE_SQRT_LOWER
#endif // HWY_NATIVE_SQRT_LOWER

// ------------------------------ MaskedSqrtOrZero
template <class V, HWY_IF_FLOAT_V(V), class M>
HWY_API V MaskedSqrtOrZero(M m, V v) {
return IfThenElseZero(m, Sqrt(v));
HWY_API V MaskedSqrtOr(V no, M m, V v) {
return IfThenElse(m, Sqrt(v), no);
}
#endif

// ------------------------------ SumOfMulQuadAccumulate

Expand Down Expand Up @@ -5483,9 +5479,9 @@ HWY_API V ApproximateReciprocal(V v) {

#endif // HWY_NATIVE_F64_APPROX_RECIP

// ------------------------------ MaskedApproximateReciprocalOrZero
// ------------------------------ MaskedApproximateReciprocal
template <class V, HWY_IF_FLOAT_V(V), class M>
HWY_API V MaskedApproximateReciprocalOrZero(M m, V v) {
HWY_API V MaskedApproximateReciprocal(M m, V v) {
return IfThenElseZero(m, ApproximateReciprocal(v));
}

Expand Down Expand Up @@ -5514,9 +5510,9 @@ HWY_API V ApproximateReciprocalSqrt(V v) {

#endif // HWY_NATIVE_F64_APPROX_RSQRT

// ------------------------------ MaskedApproximateReciprocalSqrtOrZero
// ------------------------------ MaskedApproximateReciprocalSqrt
template <class V, HWY_IF_FLOAT_V(V), class M>
HWY_API V MaskedApproximateReciprocalSqrtOrZero(M m, V v) {
HWY_API V MaskedApproximateReciprocalSqrt(M m, V v) {
return IfThenElseZero(m, ApproximateReciprocalSqrt(v));
}

Expand Down
3 changes: 3 additions & 0 deletions hwy/ops/ppc_vsx-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1939,6 +1939,9 @@ HWY_API Vec128<T, N> ApproximateReciprocal(Vec128<T, N> v) {
#endif
}

// TODO: Implement GetExponent using vec_extract_exp (which returns the biased
// exponent), then subtract the exponent bias (MaxExponentField<T>() >> 1).

// ------------------------------ Floating-point square root

#if HWY_S390X_HAVE_Z14
Expand Down
2 changes: 2 additions & 0 deletions hwy/ops/x86_512-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1842,6 +1842,8 @@ HWY_API Vec512<double> ApproximateReciprocal(Vec512<double> v) {
return Vec512<double>{_mm512_rcp14_pd(v.raw)};
}

// TODO: Implement GetExponent using
// _mm512_getexp_ps/_mm512_getexp_pd/_mm512_getexp_ph

// ------------------------------ MaskedMinOr

template <typename T, HWY_IF_U8(T)>
Expand Down
34 changes: 6 additions & 28 deletions hwy/tests/float_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ struct TestMaskedApproximateReciprocal {
HWY_ASSERT(input && actual);

Store(nonzero, d, input.get());
Store(MaskedApproximateReciprocalOrZero(first_three, nonzero), d, actual.get());
Store(MaskedApproximateReciprocal(first_three, nonzero), d, actual.get());

double max_l1 = 0.0;
double worst_expected = 0.0;
Expand Down Expand Up @@ -224,40 +224,19 @@ HWY_NOINLINE void TestAllSquareRoot() {
ForFloatTypes(ForPartialVectors<TestSquareRoot>());
}

// Checks that SqrtLower replaces only lane 0 with its square root and passes
// every other lane through unchanged.
struct TestSqrtLower {
  template <typename T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    // Lanes hold 4, 5, 6, ...; lane 0 is a perfect square so the expected root
    // is exact.
    const auto vi = Iota(d, 4);

    const size_t N = Lanes(d);
    auto expected = AllocateAligned<T>(N);
    // Fail loudly on allocation failure instead of writing through a null
    // pointer, as the other tests in this file do.
    HWY_ASSERT(expected);

    for (size_t i = 0; i < N; ++i) {
      if (i == 0) {
        expected[i] = ConvertScalarTo<T>(2);  // sqrt(4)
      } else {
        expected[i] = ConvertScalarTo<T>(i + 4);  // untouched input lane
      }
    }

    HWY_ASSERT_VEC_EQ(d, expected.get(), SqrtLower(vi));
  }
};

// Runs TestSqrtLower through ForFloatTypes with the ForPartialVectors adapter.
HWY_NOINLINE void TestAllSqrtLower() {
ForFloatTypes(ForPartialVectors<TestSqrtLower>());
}

struct TestMaskedSqrt {
template <typename T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
const auto v0 = Zero(d);
const auto vi = Iota(d, 4);
const auto v2 = Iota(d, 5);

const MFromD<D> first_four = FirstN(d, 4);
const auto expected = IfThenElse(first_four, Sqrt(vi), v0);
const auto masked_expected = IfThenElse(first_four, Sqrt(vi), v2);

HWY_ASSERT_VEC_EQ(d, expected, MaskedSqrtOrZero(first_four, vi));
HWY_ASSERT_VEC_EQ(d, expected, MaskedSqrt(first_four, vi));
HWY_ASSERT_VEC_EQ(d, masked_expected, MaskedSqrtOr(v2, first_four, vi));
}
};

Expand Down Expand Up @@ -297,7 +276,7 @@ struct TestMaskedReciprocalSquareRoot {
const size_t N = Lanes(d);
auto lanes = AllocateAligned<T>(N);
HWY_ASSERT(lanes);
Store(MaskedApproximateReciprocalSqrtOrZero(first_three, v), d,
Store(MaskedApproximateReciprocalSqrt(first_three, v), d,
lanes.get());
for (size_t i = 0; i < N; ++i) {
T expected_val = i < 3 ? ConvertScalarTo<T>(1 / std::sqrt(123.0f))
Expand Down Expand Up @@ -660,7 +639,6 @@ HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllF32FromF16);
HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllDiv);
HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllApproximateReciprocal);
HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllSquareRoot);
HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllSqrtLower);
HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllMaskedSqrt);
HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllReciprocalSquareRoot);
HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllMaskedReciprocalSquareRoot);
Expand Down

0 comments on commit e3c6c3b

Please sign in to comment.