Skip to content

Commit

Permalink
Merge pull request #2425 from cambridgeconsultants:cc_up_float_operat…
Browse files Browse the repository at this point in the history
…ions

PiperOrigin-RevId: 720600683
  • Loading branch information
copybara-github committed Jan 28, 2025
2 parents 960f74d + e3c6c3b commit d547f91
Show file tree
Hide file tree
Showing 7 changed files with 316 additions and 1 deletion.
26 changes: 26 additions & 0 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,10 @@ from left to right, of the arguments passed to `Create{2-4}`.
<code>V **ApproximateReciprocal**(V a)</code>: returns an approximation of
`1.0 / a[i]`.

* `V`: `{f}` \
<code>V **GetExponent**(V v)</code>: returns the exponent of `v[i]` as a floating point value.
Essentially calculates `floor(log2(x))`.

#### Min/Max

**Note**: Min/Max corner cases are target-specific and may change. If either
Expand Down Expand Up @@ -864,6 +868,10 @@ variants are somewhat slower on Arm, and unavailable for integer inputs; if the
c))` or `MulAddSub(a, b, OddEven(c, Neg(c))`, but `MulSub(a, b, c)` is more
efficient on some targets (including AVX2/AVX3).

* <code>V **MulSubAdd**(V a, V b, V c)</code>: returns `a[i] * b[i] + c[i]` in
the even lanes and `a[i] * b[i] - c[i]` in the odd lanes. Essentially,
MulAddSub with `c[i]` negated.

* `V`: `bf16`, `D`: `RepartitionToWide<DFromV<V>>`, `VW`: `Vec<D>` \
<code>VW **MulEvenAdd**(D d, V a, V b, VW c)</code>: equivalent to and
potentially more efficient than `MulAdd(PromoteEvenTo(d, a),
Expand All @@ -881,6 +889,9 @@ exceptions for those lanes if that is supported by the ISA. When exceptions are
not a concern, these are equivalent to, and potentially more efficient than,
`IfThenElse(m, Add(a, b), no);` etc.

* `V`: `{f}` \
<code>V **MaskedSqrtOr**(V no, M m, V a)</code>: returns `sqrt(a[i])` or
`no[i]` if `m[i]` is false.
* <code>V **MaskedMinOr**(V no, M m, V a, V b)</code>: returns `Min(a, b)[i]`
or `no[i]` if `m[i]` is false.
* <code>V **MaskedMaxOr**(V no, M m, V a, V b)</code>: returns `Max(a, b)[i]`
Expand All @@ -905,6 +916,21 @@ not a concern, these are equivalent to, and potentially more efficient than,
b[i]` saturated to the minimum/maximum representable value, or `no[i]` if
`m[i]` is false.

#### Zero masked arithmetic

All ops in this section return `0` for `mask=false` lanes. These are equivalent
to, and potentially more efficient than, `IfThenElseZero(m, Add(a, b));` etc.

* `V`: `{f}` \
<code>V **MaskedSqrt**(M m, V a)</code>: returns `sqrt(a[i])` where
m is true, and zero otherwise.
* `V`: `{f}` \
<code>V **MaskedApproximateReciprocalSqrt**(M m, V a)</code>: returns
the result of ApproximateReciprocalSqrt where m is true and zero otherwise.
* `V`: `{f}` \
<code>V **MaskedApproximateReciprocal**(M m, V a)</code>: returns the
result of ApproximateReciprocal where m is true and zero otherwise.

#### Shifts

**Note**: Counts not in `[0, sizeof(T)*8)` yield implementation-defined results.
Expand Down
53 changes: 53 additions & 0 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,14 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_##CHAR##BITS(v); \
}
#define HWY_SVE_RETV_ARGMV(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_##CHAR##BITS##_x(m, v); \
}
#define HWY_SVE_RETV_ARGMV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a) { \
return sv##OP##_##CHAR##BITS##_z(m, a); \
}

// vector = f(vector, scalar), e.g. detail::AddN
#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP) \
Expand Down Expand Up @@ -1234,6 +1242,15 @@ HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGV, ApproximateReciprocal, recpe)
// ------------------------------ Sqrt
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Sqrt, sqrt)

// ------------------------------ MaskedSqrt
#ifdef HWY_NATIVE_MASKED_SQRT
#undef HWY_NATIVE_MASKED_SQRT
#else
#define HWY_NATIVE_MASKED_SQRT
#endif

HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV_Z, MaskedSqrt, sqrt)

// ------------------------------ ApproximateReciprocalSqrt
#ifdef HWY_NATIVE_F64_APPROX_RSQRT
#undef HWY_NATIVE_F64_APPROX_RSQRT
Expand Down Expand Up @@ -1521,6 +1538,7 @@ HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMul, mul)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMVV, MaskedDiv, div)
HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGMVV, MaskedDiv, div)
HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGMVV, MaskedDiv, div)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV, MaskedSqrt, sqrt)
#if HWY_SVE_HAVE_2
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedSatAdd, qadd)
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedSatSub, qsub)
Expand Down Expand Up @@ -1584,6 +1602,11 @@ HWY_API V MaskedSatSubOr(V no, M m, V a, V b) {
}
#endif

template <class V, HWY_IF_FLOAT_V(V), class M>
HWY_API V MaskedSqrtOr(V no, M m, V v) {
return IfThenElse(m, detail::MaskedSqrt(m, v), no);
}

// ================================================== REDUCE

#ifdef HWY_NATIVE_REDUCE_SCALAR
Expand Down Expand Up @@ -3094,6 +3117,34 @@ HWY_API VFromD<D> Iota(const D d, T2 first) {
ConvertScalarTo<TFromD<D>>(first));
}

// ------------------------------ GetExponent

#if HWY_SVE_HAVE_2 || HWY_IDE
#ifdef HWY_NATIVE_GET_EXPONENT
#undef HWY_NATIVE_GET_EXPONENT
#else
#define HWY_NATIVE_GET_EXPONENT
#endif

namespace detail {
#define HWY_SVE_GET_EXP(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(int, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \
}
HWY_SVE_FOREACH_F(HWY_SVE_GET_EXP, GetExponent, logb)
#undef HWY_SVE_GET_EXP
} // namespace detail

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V GetExponent(V v) {
const DFromV<V> d;
const RebindToSigned<decltype(d)> di;
const VFromD<decltype(di)> exponent_int = detail::GetExponent(v);
// convert integer to original type
return ConvertTo(d, exponent_int);
}
#endif // HWY_SVE_HAVE_2

// ------------------------------ InterleaveLower

template <class D, class V>
Expand Down Expand Up @@ -6352,6 +6403,8 @@ HWY_API V HighestSetBitIndex(V v) {
#undef HWY_SVE_IF_NOT_EMULATED_D
#undef HWY_SVE_PTRUE
#undef HWY_SVE_RETV_ARGMVV
#undef HWY_SVE_RETV_ARGMV_Z
#undef HWY_SVE_RETV_ARGMV
#undef HWY_SVE_RETV_ARGPV
#undef HWY_SVE_RETV_ARGPVN
#undef HWY_SVE_RETV_ARGPVV
Expand Down
73 changes: 73 additions & 0 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1175,6 +1175,34 @@ HWY_API V MulByFloorPow2(V v, V exp) {

#endif // HWY_NATIVE_MUL_BY_POW2

// ------------------------------ GetExponent

#if (defined(HWY_NATIVE_GET_EXPONENT) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_GET_EXPONENT
#undef HWY_NATIVE_GET_EXPONENT
#else
#define HWY_NATIVE_GET_EXPONENT
#endif

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V GetExponent(V v) {
const DFromV<V> d;
using T = TFromV<V>;
const RebindToUnsigned<decltype(d)> du;
const RebindToSigned<decltype(d)> di;

constexpr uint8_t mantissa_bits = MantissaBits<T>();
const auto exponent_offset = Set(di, MaxExponentField<T>() >> 1);

// extract exponent bits as integer
const auto encoded_exponent = ShiftRight<mantissa_bits>(BitCast(du, Abs(v)));
const auto exponent_int = Sub(BitCast(di, encoded_exponent), exponent_offset);

// convert integer to original type
return ConvertTo(d, exponent_int);
}

#endif // HWY_NATIVE_GET_EXPONENT
// ------------------------------ LoadInterleaved2

#if HWY_IDE || \
Expand Down Expand Up @@ -4409,6 +4437,19 @@ HWY_API V MulAddSub(V mul, V x, V sub_or_add) {
OddEven(sub_or_add, BitCast(d, Neg(BitCast(d_negate, sub_or_add))));
return MulAdd(mul, x, add);
}
// ------------------------------ MulSubAdd

template <class V>
HWY_API V MulSubAdd(V mul, V x, V sub_or_add) {
using D = DFromV<V>;
using T = TFromD<D>;
using TNegate = If<!IsSigned<T>(), MakeSigned<T>, T>;

const D d;
const Rebind<TNegate, D> d_negate;

return MulAddSub(mul, x, BitCast(d, Neg(BitCast(d_negate, sub_or_add))));
}

// ------------------------------ Integer division
#if (defined(HWY_NATIVE_INT_DIV) == defined(HWY_TARGET_TOGGLE))
Expand Down Expand Up @@ -5234,6 +5275,26 @@ HWY_API VFromD<DI32> SatWidenMulAccumFixedPoint(DI32 di32,

#endif // HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT

// ------------------------------ MaskedSqrt

#if (defined(HWY_NATIVE_MASKED_SQRT) == defined(HWY_TARGET_TOGGLE))

#ifdef HWY_NATIVE_MASKED_SQRT
#undef HWY_NATIVE_MASKED_SQRT
#else
#define HWY_NATIVE_MASKED_SQRT
#endif
template <class V, HWY_IF_FLOAT_V(V), class M>
HWY_API V MaskedSqrt(M m, V v) {
return IfThenElseZero(m, Sqrt(v));
}

template <class V, HWY_IF_FLOAT_V(V), class M>
HWY_API V MaskedSqrtOr(V no, M m, V v) {
return IfThenElse(m, Sqrt(v), no);
}
#endif

// ------------------------------ SumOfMulQuadAccumulate

#if (defined(HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE) == \
Expand Down Expand Up @@ -5418,6 +5479,12 @@ HWY_API V ApproximateReciprocal(V v) {

#endif // HWY_NATIVE_F64_APPROX_RECIP

// ------------------------------ MaskedApproximateReciprocal
template <class V, HWY_IF_FLOAT_V(V), class M>
HWY_API V MaskedApproximateReciprocal(M m, V v) {
return IfThenElseZero(m, ApproximateReciprocal(v));
}

// ------------------------------ F64 ApproximateReciprocalSqrt

#if (defined(HWY_NATIVE_F64_APPROX_RSQRT) == defined(HWY_TARGET_TOGGLE))
Expand All @@ -5443,6 +5510,12 @@ HWY_API V ApproximateReciprocalSqrt(V v) {

#endif // HWY_NATIVE_F64_APPROX_RSQRT

// ------------------------------ MaskedApproximateReciprocalSqrt
template <class V, HWY_IF_FLOAT_V(V), class M>
HWY_API V MaskedApproximateReciprocalSqrt(M m, V v) {
return IfThenElseZero(m, ApproximateReciprocalSqrt(v));
}

// ------------------------------ Compress*

#if (defined(HWY_NATIVE_COMPRESS8) == defined(HWY_TARGET_TOGGLE))
Expand Down
3 changes: 3 additions & 0 deletions hwy/ops/ppc_vsx-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1939,6 +1939,9 @@ HWY_API Vec128<T, N> ApproximateReciprocal(Vec128<T, N> v) {
#endif
}

// TODO: Implement GetExponent using vec_extract_exp (which returns the biased
// exponent) followed by a subtraction by MaxExponentField<T>() >> 1

// ------------------------------ Floating-point square root

#if HWY_S390X_HAVE_Z14
Expand Down
2 changes: 2 additions & 0 deletions hwy/ops/x86_512-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1842,6 +1842,8 @@ HWY_API Vec512<double> ApproximateReciprocal(Vec512<double> v) {
return Vec512<double>{_mm512_rcp14_pd(v.raw)};
}

// TODO: Implement GetExponent using _mm_getexp_ps/_mm_getexp_pd/_mm_getexp_ph

// ------------------------------ MaskedMinOr

template <typename T, HWY_IF_U8(T)>
Expand Down
Loading

0 comments on commit d547f91

Please sign in to comment.