Skip to content

Commit

Permalink
Use apply_ufunc instead
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Dec 13, 2024
1 parent 81b73b9 commit 652bcc1
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 174 deletions.
41 changes: 15 additions & 26 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4119,18 +4119,6 @@ def interp(

coords = either_dict_or_kwargs(coords, coords_kwargs, "interp")
indexers = dict(self._validate_interp_indexers(coords))

if coords:
# This avoids broadcasting over coordinates that are both in
# the original array AND in the indexing array. It essentially
# forces interpolation along the shared coordinates.
sdims = (
set(self.dims)
.intersection(*[set(nx.dims) for nx in indexers.values()])
.difference(coords.keys())
)
indexers.update({d: self.variables[d] for d in sdims})

obj = self if assume_sorted else self.sortby(list(coords))

def maybe_variable(obj, k):
Expand Down Expand Up @@ -4161,16 +4149,18 @@ def _validate_interp_indexer(x, new_x):
for k, v in indexers.items()
}

# optimization: subset to coordinate range of the target index
if method in ["linear", "nearest"]:
for k, v in validated_indexers.items():
obj, newidx = missing._localize(obj, {k: v})
validated_indexers[k] = newidx[k]

# optimization: create dask coordinate arrays once per Dataset
# rather than once per Variable when dask.array.unify_chunks is called later
# GH4739
if obj.__dask_graph__():
has_chunked_array = bool(
any(is_chunked_array(v._data) for v in obj._variables.values())
)
if has_chunked_array:
# optimization: subset to coordinate range of the target index
if method in ["linear", "nearest"]:
for k, v in validated_indexers.items():
obj, newidx = missing._localize(obj, {k: v})
validated_indexers[k] = newidx[k]
# optimization: create dask coordinate arrays once per Dataset
# rather than once per Variable when dask.array.unify_chunks is called later
# GH4739
dask_indexers = {
k: (index.to_base_variable().chunk(), dest.to_base_variable().chunk())
for k, (index, dest) in validated_indexers.items()
Expand All @@ -4182,10 +4172,9 @@ def _validate_interp_indexer(x, new_x):
if name in indexers:
continue

if is_duck_dask_array(var.data):
use_indexers = dask_indexers
else:
use_indexers = validated_indexers
use_indexers = (
dask_indexers if is_duck_dask_array(var.data) else validated_indexers
)

dtype_kind = var.dtype.kind
if dtype_kind in "uifc":
Expand Down
Loading

0 comments on commit 652bcc1

Please sign in to comment.