diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 6becb21ffca..649eeb60282 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -2126,18 +2126,6 @@ def to_floatable(x: DataArray) -> DataArray: return to_floatable(data) -def _apply_vectorized_indexer(indices, coord): - from xarray.core.indexing import ( - VectorizedIndexer, - apply_indexer, - as_indexable, - ) - - return apply_indexer( - as_indexable(coord), VectorizedIndexer((indices.squeeze(axis=-1),)) - ) - - def _calc_idxminmax( *, array, @@ -2182,28 +2170,14 @@ def _calc_idxminmax( indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna) # Handle chunked arrays (e.g. dask). + coord = array[dim]._variable.to_base_variable() if is_chunked_array(array.data): chunkmanager = get_chunked_array_type(array.data) - chunked_coord = chunkmanager.from_array(array[dim].data, chunks=((-1,),)) - - if indx.ndim == 0: - out = chunked_coord[indx.data] - else: - out = chunkmanager.map_blocks( - _apply_vectorized_indexer, - indx.data[..., np.newaxis], - chunked_coord, - chunks=indx.data.chunks, - drop_axis=-1, - dtype=chunked_coord.dtype, - ) - res = indx.copy(data=out) - # we need to attach back the dim name - res.name = dim - else: - res = array[dim][(indx,)] - # The dim is gone but we need to remove the corresponding coordinate. - del res.coords[dim] + coord_array = chunkmanager.from_array( + array[dim].data, chunks=((array.sizes[dim],),) + ) + coord = coord.copy(data=coord_array) + res = indx._replace(coord[(indx.variable,)]).rename(dim) if skipna or (skipna is None and array.dtype.kind in na_dtypes): # Put the NaN values back in after removing them diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 0a7b94a53c7..fde90ee71d0 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -2,6 +2,7 @@ import enum import functools +import math import operator from collections import Counter, defaultdict from collections.abc import Callable, Hashable, Iterable, Mapping @@ -472,12 +473,12 @@ def __init__(self, key: tuple[slice | np.ndarray[Any, np.dtype[np.generic]], ... for k in key: if isinstance(k, slice): k = as_integer_slice(k) - elif is_duck_dask_array(k): - raise ValueError( - "Vectorized indexing with Dask arrays is not supported. " - "Please pass a numpy array by calling ``.compute``. " - "See https://github.com/dask/dask/issues/8958." - ) + # elif is_duck_dask_array(k): + # raise ValueError( + # "Vectorized indexing with Dask arrays is not supported. " + # "Please pass a numpy array by calling ``.compute``. " + # "See https://github.com/dask/dask/issues/8958." + # ) elif is_duck_array(k): if not np.issubdtype(k.dtype, np.integer): raise TypeError( @@ -1607,6 +1608,18 @@ def transpose(self, order): return xp.permute_dims(self.array, order) +def _apply_vectorized_indexer_dask_wrapper(indices, coord): + from xarray.core.indexing import ( + VectorizedIndexer, + apply_indexer, + as_indexable, + ) + + return apply_indexer( + as_indexable(coord), VectorizedIndexer((indices.squeeze(axis=-1),)) + ) + + class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin): """Wrap a dask array to support explicit indexing.""" @@ -1630,7 +1643,29 @@ def _oindex_get(self, indexer: OuterIndexer): return value def _vindex_get(self, indexer: VectorizedIndexer): - return self.array.vindex[indexer.tuple] + try: + return self.array.vindex[indexer.tuple] + except IndexError as e: + # TODO: upstream to dask + has_dask = any(is_duck_dask_array(i) for i in indexer.tuple) + if not has_dask or (has_dask and len(indexer.tuple) > 1): + raise e + if math.prod(self.array.numblocks) > 1 or self.array.ndim > 1: + raise e + (idxr,) = indexer.tuple + if idxr.ndim == 0: + return self.array[idxr.data] + else: + import dask.array + + return dask.array.map_blocks( + _apply_vectorized_indexer_dask_wrapper, + idxr[..., np.newaxis], + self.array, + chunks=idxr.chunks, + drop_axis=-1, + dtype=self.array.dtype, + ) def __getitem__(self, indexer: ExplicitIndexer): self._check_and_raise_if_non_basic_indexer(indexer) diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 698849cb7fe..6c82112f233 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -1010,8 +1010,7 @@ def test_vectorized_indexing_dask_array(): coords={"y": range(4), "x": range(2)}, dims=("y", "x"), ) - with pytest.raises(ValueError, match="Vectorized indexing with Dask arrays"): - darr[indexer.chunk({"y": 2})] + darr[indexer.chunk({"y": 2})] @requires_dask