Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Duck array ops for all and any #9883

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
10 changes: 5 additions & 5 deletions xarray/coding/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from xarray.core import indexing
from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like
from xarray.core.duck_array_ops import asarray, ravel, reshape
from xarray.core.duck_array_ops import array_all, array_any, asarray, ravel, reshape
from xarray.core.formatting import first_n_items, format_timestamp, last_item
from xarray.core.pdcompat import default_precision_timestamp, timestamp_as_unit
from xarray.core.utils import attempt_import, emit_user_level_warning
Expand Down Expand Up @@ -676,7 +676,7 @@ def _infer_time_units_from_diff(unique_timedeltas) -> str:
unit_timedelta = _unit_timedelta_numpy
zero_timedelta = np.timedelta64(0, "ns")
for time_unit in time_units:
if np.all(unique_timedeltas % unit_timedelta(time_unit) == zero_timedelta):
if array_all(unique_timedeltas % unit_timedelta(time_unit) == zero_timedelta):
return time_unit
return "seconds"

Expand Down Expand Up @@ -939,7 +939,7 @@ def encode_datetime(d):

def cast_to_int_if_safe(num) -> np.ndarray:
int_num = np.asarray(num, dtype=np.int64)
if (num == int_num).all():
if array_all(num == int_num):
num = int_num
return num

Expand All @@ -961,7 +961,7 @@ def _cast_to_dtype_if_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray:
cast_num = np.asarray(num, dtype=dtype)

if np.issubdtype(dtype, np.integer):
if not (num == cast_num).all():
if not array_all(num == cast_num):
if np.issubdtype(num.dtype, np.floating):
raise ValueError(
f"Not possible to cast all encoded times from "
Expand All @@ -979,7 +979,7 @@ def _cast_to_dtype_if_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray:
"a larger integer dtype."
)
else:
if np.isinf(cast_num).any():
if array_any(np.isinf(cast_num)):
raise OverflowError(
f"Not possible to cast encoded times from {num.dtype!r} to "
f"{dtype!r} without overflow. Consider removing the dtype "
Expand Down
20 changes: 15 additions & 5 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

import numpy as np
import pandas as pd
from numpy import all as array_all # noqa: F401
from numpy import any as array_any # noqa: F401
from numpy import ( # noqa: F401
isclose,
isnat,
Expand Down Expand Up @@ -231,7 +229,7 @@
xp = get_array_namespace(data)
if xp == np:
# numpy currently doesn't have a astype:
return data.astype(dtype, **kwargs)

Check warning on line 232 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.10

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.10

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.12

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.12

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.12

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.12

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / windows-latest py3.10

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / windows-latest py3.10

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / windows-latest py3.10

invalid value encountered in cast

Check warning on line 232 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / windows-latest py3.10

invalid value encountered in cast
return xp.astype(data, dtype, **kwargs)
return data.astype(dtype, **kwargs)

Expand Down Expand Up @@ -319,7 +317,9 @@
if lazy_equiv is None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all())
return bool(
array_all(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True))
)
else:
return lazy_equiv

Expand All @@ -333,7 +333,7 @@
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2))
return bool(flag_array.all())
return bool(array_all(flag_array))
else:
return lazy_equiv

Expand All @@ -349,7 +349,7 @@
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2)
return bool(flag_array.all())
return bool(array_all(flag_array))
else:
return lazy_equiv

Expand Down Expand Up @@ -536,6 +536,16 @@
cumsum_1d.numeric_only = True


def array_all(array, axis=None, keepdims=False, **kwargs):
xp = get_array_namespace(array)
return xp.all(array, axis=axis, keepdims=keepdims, **kwargs)


def array_any(array, axis=None, keepdims=False, **kwargs):
xp = get_array_namespace(array)
return xp.any(array, axis=axis, keepdims=keepdims, **kwargs)


_mean = _create_nan_agg_method("mean", invariant_0d=True)


Expand Down
8 changes: 4 additions & 4 deletions xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from pandas.errors import OutOfBoundsDatetime

from xarray.core.datatree_render import RenderDataTree
from xarray.core.duck_array_ops import array_equiv, astype
from xarray.core.duck_array_ops import array_all, array_any, array_equiv, astype
from xarray.core.indexing import MemoryCachedArray
from xarray.core.options import OPTIONS, _get_boolean_with_default
from xarray.core.treenode import group_subtrees
Expand Down Expand Up @@ -204,9 +204,9 @@ def format_items(x):
day_part = x[~pd.isnull(x)].astype("timedelta64[D]").astype("timedelta64[ns]")
time_needed = x[~pd.isnull(x)] != day_part
day_needed = day_part != np.timedelta64(0, "ns")
if np.logical_not(day_needed).all():
if array_all(np.logical_not(day_needed)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

surprised we haven't updated the logical_not yet

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could just use

Suggested change
if array_all(np.logical_not(day_needed)):
if array_all(~day_needed):

Or not?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think ~ is np.invert/xp.bitwise_invert, so possibly different? I'll leave this for a future PR if that's OK.

timedelta_format = "time"
elif np.logical_not(time_needed).all():
elif array_all(np.logical_not(time_needed)):
timedelta_format = "date"

formatted = [format_item(xi, timedelta_format) for xi in x]
Expand All @@ -232,7 +232,7 @@ def format_array_flat(array, max_width: int):

cum_len = np.cumsum([len(s) + 1 for s in relevant_items]) - 1
if (array.size > 2) and (
(max_possibly_relevant < array.size) or (cum_len > max_width).any()
(max_possibly_relevant < array.size) or array_any(cum_len > max_width)
):
padding = " ... "
max_len = max(int(np.argmax(cum_len + len(padding) - 1 > max_width)), 2)
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs):
data = getattr(np, func)(value, axis=axis, **kwargs)

# TODO This will evaluate dask arrays and might be costly.
if (valid_count == 0).any():
if duck_array_ops.array_any(valid_count == 0):
raise ValueError("All-NaN slice encountered")

return data
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def __init__(self, obj: T_Xarray, weights: T_DataArray) -> None:

def _weight_check(w):
# Ref https://github.com/pydata/xarray/pull/4559/files#r515968670
if duck_array_ops.isnull(w).any():
if duck_array_ops.array_any(duck_array_ops.isnull(w)):
raise ValueError(
"`weights` cannot contain missing values. "
"Missing values can be replaced by `weights.fillna(0)`."
Expand Down
11 changes: 5 additions & 6 deletions xarray/groupers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@
from numpy.typing import ArrayLike

from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq
from xarray.core import duck_array_ops
from xarray.core.computation import apply_ufunc
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
from xarray.core.dataarray import DataArray
from xarray.core.duck_array_ops import isnull
from xarray.core.duck_array_ops import array_all, isnull
from xarray.core.groupby import T_Group, _DummyGroup
from xarray.core.indexes import safe_cast_to_index
from xarray.core.resample_cftime import CFTimeGrouper
Expand Down Expand Up @@ -235,7 +234,7 @@ def _factorize_unique(self) -> EncodedGroups:
# look through group to find the unique values
sort = not isinstance(self.group_as_index, pd.MultiIndex)
unique_values, codes_ = unique_value_groups(self.group_as_index, sort=sort)
if (codes_ == -1).all():
if array_all(codes_ == -1):
raise ValueError(
"Failed to group data. Are you grouping by a variable that is all NaN?"
)
Expand Down Expand Up @@ -347,7 +346,7 @@ def reset(self) -> Self:
)

def __post_init__(self) -> None:
if duck_array_ops.isnull(self.bins).all():
if array_all(isnull(self.bins)):
raise ValueError("All bin edges are NaN.")

def _cut(self, data):
Expand Down Expand Up @@ -381,7 +380,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
f"Bin edges must be provided when grouping by chunked arrays. Received {self.bins=!r} instead"
)
codes = self._factorize_lazy(group)
if not by_is_chunked and (codes == -1).all():
if not by_is_chunked and array_all(codes == -1):
raise ValueError(
f"None of the data falls within bins with edges {self.bins!r}"
)
Expand Down Expand Up @@ -547,7 +546,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
def _factorize_given_labels(data: np.ndarray, labels: np.ndarray) -> np.ndarray:
# Copied from flox
sorter = np.argsort(labels)
is_sorted = (sorter == np.arange(sorter.size)).all()
is_sorted = array_all(sorter == np.arange(sorter.size))
codes = np.searchsorted(labels, data, sorter=sorter)
mask = ~np.isin(data, labels) | isnull(data) | (codes == len(labels))
# codes is the index in to the sorted array.
Expand Down
2 changes: 1 addition & 1 deletion xarray/testing/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def assert_duckarray_equal(x, y, err_msg="", verbose=True):
if (utils.is_duck_array(x) and utils.is_scalar(y)) or (
utils.is_scalar(x) and utils.is_duck_array(y)
):
equiv = (x == y).all()
equiv = duck_array_ops.array_all(x == y)
else:
equiv = duck_array_ops.array_equiv(x, y)
assert equiv, _format_message(x, y, err_msg=err_msg, verbose=verbose)
Expand Down
Loading