Skip to content

Commit

Permalink
Add GroupBy.shuffle_to_chunks() (#9320)
Browse files Browse the repository at this point in the history
* Add GroupBy.shuffle()

* Cleanup

* Cleanup

* fix

* return groupby instance from shuffle

* Fix nD by

* Skip if no dask

* fix tests

* Add `chunks` to signature

* FIx self

* Another Self fix

* Forward chunks too

* [revert]

* undo flox limit

* [revert]

* fix types

* Add DataArray.shuffle_by, Dataset.shuffle_by

* Add doctest

* Refactor

* tweak docstrings

* fix typing

* Fix

* fix docstring

* bump min version to dask>=2024.08.1

* Fix typing

* Fix types

* remove shuffle_by for now.

* Add tests

* Support shuffling with multiple groupers

* Revert "remove shuffle_by for now."

This reverts commit 7a99c8f.

* bad merge

* Add a test

* Add docs

* bugfix

* Refactor out Dataset._shuffle

* fix types

* fix tests

* Handle by is chunked

* Some refactoring

* Remove shuffle_by

* shuffle -> distributed_shuffle

* return xarray object from distributed_shuffle

* fix

* fix doctest

* fix api

* Rename to `shuffle_to_chunks`

* update docs
  • Loading branch information
dcherian authored Nov 21, 2024
1 parent 077276a commit ad5c7ed
Show file tree
Hide file tree
Showing 14 changed files with 421 additions and 38 deletions.
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1210,6 +1210,7 @@ Dataset
DatasetGroupBy.var
DatasetGroupBy.dims
DatasetGroupBy.groups
DatasetGroupBy.shuffle_to_chunks

DataArray
---------
Expand Down Expand Up @@ -1241,6 +1242,7 @@ DataArray
DataArrayGroupBy.var
DataArrayGroupBy.dims
DataArrayGroupBy.groups
DataArrayGroupBy.shuffle_to_chunks

Grouper Objects
---------------
Expand Down
21 changes: 21 additions & 0 deletions doc/user-guide/groupby.rst
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,24 @@ Different groupers can be combined to construct sophisticated GroupBy operations
from xarray.groupers import BinGrouper
ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum()
Shuffling
~~~~~~~~~

Shuffling is a generalization of sorting a DataArray or Dataset by another DataArray, named ``label`` for example, that follows from the idea of grouping by ``label``.
Shuffling reorders the DataArray or the DataArrays in a Dataset such that all members of a group occur sequentially. For example,
Shuffle the object using either :py:class:`DatasetGroupBy` or :py:class:`DataArrayGroupBy` as appropriate.

.. ipython:: python
da = xr.DataArray(
dims="x",
data=[1, 2, 3, 4, 5, 6],
coords={"label": ("x", "a b c a b c".split(" "))},
)
da.groupby("label").shuffle_to_chunks()
For chunked array types (e.g. dask or cubed), shuffle may result in a more optimized communication pattern when compared to direct indexing by the appropriate indexer.
Shuffling also makes GroupBy operations on chunked arrays an embarrassingly parallel problem, and may significantly improve workloads that use :py:meth:`DatasetGroupBy.map` or :py:meth:`DataArrayGroupBy.map`.
8 changes: 8 additions & 0 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
Bins,
DaCompatible,
NetcdfWriteModes,
T_Chunks,
T_DataArray,
T_DataArrayOrSet,
ZarrWriteModes,
Expand Down Expand Up @@ -105,6 +106,7 @@
Dims,
ErrorOptions,
ErrorOptionsWithWarn,
GroupIndices,
GroupInput,
InterpOptions,
PadModeOptions,
Expand Down Expand Up @@ -1687,6 +1689,12 @@ def sel(
)
return self._from_temp_dataset(ds)

def _shuffle(
self, dim: Hashable, *, indices: GroupIndices, chunks: T_Chunks
) -> Self:
ds = self._to_temp_dataset()._shuffle(dim=dim, indices=indices, chunks=chunks)
return self._from_temp_dataset(ds)

def head(
self,
indexers: Mapping[Any, int] | int | None = None,
Expand Down
34 changes: 34 additions & 0 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
DsCompatible,
ErrorOptions,
ErrorOptionsWithWarn,
GroupIndices,
GroupInput,
InterpOptions,
JoinOptions,
Expand All @@ -166,6 +167,7 @@
ResampleCompatible,
SideOptions,
T_ChunkDimFreq,
T_Chunks,
T_DatasetPadConstantValues,
T_Xarray,
)
Expand Down Expand Up @@ -3237,6 +3239,38 @@ def sel(
result = self.isel(indexers=query_results.dim_indexers, drop=drop)
return result._overwrite_indexes(*query_results.as_tuple()[1:])

def _shuffle(self, dim, *, indices: GroupIndices, chunks: T_Chunks) -> Self:
# Shuffling is only different from `isel` for chunked arrays.
# Extract them out, and treat them specially. The rest, we route through isel.
# This makes it easy to ensure correct handling of indexes.
is_chunked = {
name: var
for name, var in self._variables.items()
if is_chunked_array(var._data)
}
subset = self[[name for name in self._variables if name not in is_chunked]]

no_slices: list[list[int]] = [
list(range(*idx.indices(self.sizes[dim])))
if isinstance(idx, slice)
else idx
for idx in indices
]
no_slices = [idx for idx in no_slices if idx]

shuffled = (
subset
if dim not in subset.dims
else subset.isel({dim: np.concatenate(no_slices)})
)
for name, var in is_chunked.items():
shuffled[name] = var._shuffle(
indices=no_slices,
dim=dim,
chunks=chunks,
)
return shuffled

def head(
self,
indexers: Mapping[Any, int] | int | None = None,
Expand Down
82 changes: 80 additions & 2 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,13 @@

from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import GroupIndex, GroupIndices, GroupInput, GroupKey
from xarray.core.types import (
GroupIndex,
GroupIndices,
GroupInput,
GroupKey,
T_Chunks,
)
from xarray.core.utils import Frozen
from xarray.groupers import EncodedGroups, Grouper

Expand Down Expand Up @@ -676,6 +682,76 @@ def sizes(self) -> Mapping[Hashable, int]:
self._sizes = self._obj.isel({self._group_dim: index}).sizes
return self._sizes

def shuffle_to_chunks(self, chunks: T_Chunks = None) -> T_Xarray:
"""
Sort or "shuffle" the underlying object.
"Shuffle" means the object is sorted so that all group members occur sequentially,
in the same chunk. Multiple groups may occur in the same chunk.
This method is particularly useful for chunked arrays (e.g. dask, cubed).
particularly when you need to map a function that requires all members of a group
to be present in a single chunk. For chunked array types, the order of appearance
is not guaranteed, but will depend on the input chunking.
Parameters
----------
chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional
How to adjust chunks along dimensions not present in the array being grouped by.
Returns
-------
DataArrayGroupBy or DatasetGroupBy
Examples
--------
>>> import dask.array
>>> da = xr.DataArray(
... dims="x",
... data=dask.array.arange(10, chunks=3),
... coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
... name="a",
... )
>>> shuffled = da.groupby("x").shuffle_to_chunks()
>>> shuffled
<xarray.DataArray 'a' (x: 10)> Size: 80B
dask.array<shuffle, shape=(10,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>
Coordinates:
* x (x) int64 80B 0 1 1 1 2 2 2 3 3 3
>>> shuffled.groupby("x").quantile(q=0.5).compute()
<xarray.DataArray 'a' (x: 4)> Size: 32B
array([9., 3., 4., 5.])
Coordinates:
quantile float64 8B 0.5
* x (x) int64 32B 0 1 2 3
See Also
--------
dask.dataframe.DataFrame.shuffle
dask.array.shuffle
"""
self._raise_if_by_is_chunked()
return self._shuffle_obj(chunks)

def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
from xarray.core.dataarray import DataArray

was_array = isinstance(self._obj, DataArray)
as_dataset = self._obj._to_temp_dataset() if was_array else self._obj

for grouper in self.groupers:
if grouper.name not in as_dataset._variables:
as_dataset.coords[grouper.name] = grouper.group

shuffled = as_dataset._shuffle(
dim=self._group_dim, indices=self.encoded.group_indices, chunks=chunks
)
unstacked: Dataset = self._maybe_unstack(shuffled)
if was_array:
return self._obj._from_temp_dataset(unstacked)
else:
return unstacked # type: ignore[return-value]

def map(
self,
func: Callable,
Expand Down Expand Up @@ -896,7 +972,9 @@ def _maybe_unstack(self, obj):
# and `inserted_dims`
# if multiple groupers all share the same single dimension, then
# we don't stack/unstack. Do that manually now.
obj = obj.unstack(*self.encoded.unique_coord.dims)
dims_to_unstack = self.encoded.unique_coord.dims
if all(dim in obj.dims for dim in dims_to_unstack):
obj = obj.unstack(*dims_to_unstack)
to_drop = [
grouper.name
for grouper in self.groupers
Expand Down
45 changes: 45 additions & 0 deletions xarray/core/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import T_Chunks

from xarray.groupers import RESAMPLE_DIM

Expand Down Expand Up @@ -58,6 +59,50 @@ def _flox_reduce(
result = result.rename({RESAMPLE_DIM: self._group_dim})
return result

def shuffle_to_chunks(self, chunks: T_Chunks = None):
"""
Sort or "shuffle" the underlying object.
"Shuffle" means the object is sorted so that all group members occur sequentially,
in the same chunk. Multiple groups may occur in the same chunk.
This method is particularly useful for chunked arrays (e.g. dask, cubed).
particularly when you need to map a function that requires all members of a group
to be present in a single chunk. For chunked array types, the order of appearance
is not guaranteed, but will depend on the input chunking.
Parameters
----------
chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional
How to adjust chunks along dimensions not present in the array being grouped by.
Returns
-------
DataArrayGroupBy or DatasetGroupBy
Examples
--------
>>> import dask.array
>>> da = xr.DataArray(
... dims="time",
... data=dask.array.arange(10, chunks=1),
... coords={"time": xr.date_range("2001-01-01", freq="12h", periods=10)},
... name="a",
... )
>>> shuffled = da.resample(time="2D").shuffle_to_chunks()
>>> shuffled
<xarray.DataArray 'a' (time: 10)> Size: 80B
dask.array<shuffle, shape=(10,), dtype=int64, chunksize=(4,), chunktype=numpy.ndarray>
Coordinates:
* time (time) datetime64[ns] 80B 2001-01-01 ... 2001-01-05T12:00:00
See Also
--------
dask.dataframe.DataFrame.shuffle
dask.array.shuffle
"""
(grouper,) = self.groupers
return self._shuffle_obj(chunks).drop_vars(RESAMPLE_DIM)

def _drop_coords(self) -> T_Xarray:
"""Drop non-dimension coordinates along the resampled dimension."""
obj = self._obj
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def read(self, __n: int = ...) -> AnyStr_co:
ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"]

GroupKey = Any
GroupIndex = Union[int, slice, list[int]]
GroupIndex = Union[slice, list[int]]
GroupIndices = tuple[GroupIndex, ...]
Bins = Union[
int, Sequence[int], Sequence[float], Sequence[pd.Timestamp], np.ndarray, pd.Index
Expand Down
26 changes: 25 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,13 @@
maybe_coerce_to_str,
)
from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import (
integer_types,
is_0d_dask_array,
is_chunked_array,
to_duck_array,
)
from xarray.namedarray.utils import module_available
from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims

Expand Down Expand Up @@ -1019,6 +1025,24 @@ def compute(self, **kwargs):
new = self.copy(deep=False)
return new.load(**kwargs)

def _shuffle(
self, indices: list[list[int]], dim: Hashable, chunks: T_Chunks
) -> Self:
# TODO (dcherian): consider making this public API
array = self._data
if is_chunked_array(array):
chunkmanager = get_chunked_array_type(array)
return self._replace(
data=chunkmanager.shuffle(
array,
indexer=indices,
axis=self.get_axis_num(dim),
chunks=chunks,
)
)
else:
return self.isel({dim: np.concatenate(indices)})

def isel(
self,
indexers: Mapping[Any, Any] | None = None,
Expand Down
Loading

0 comments on commit ad5c7ed

Please sign in to comment.