Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite interp to use apply_ufunc #9881

Merged
merged 26 commits into from
Dec 19, 2024
Merged

Conversation

dcherian
Copy link
Contributor

@dcherian dcherian commented Dec 13, 2024

  1. Removes a bunch of complexity around interpolating dask arrays by using apply_ufunc instead of blockwise directly.
  2. A major improvement is that we can now use vectorize=True to get sane dask graphs for vectorized interpolation to chunked arrays (interp performance with chunked dimensions #6799 (comment))
  3. Added a bunch of typing.
  4. Happily this fixes Interpolation with multiple mutlidimensional arrays sharing dims fails #4463

cc @ks905383 your vectorized interpolation example now has this graph:
image

instead of this quadratic monstrosity
image

@dcherian dcherian added needs review run-benchmark Run the ASV benchmark workflow labels Dec 13, 2024
@dcherian dcherian requested a review from Illviljan December 13, 2024 06:32
@@ -4127,18 +4119,6 @@ def interp(

coords = either_dict_or_kwargs(coords, coords_kwargs, "interp")
indexers = dict(self._validate_interp_indexers(coords))

if coords:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handled by vectorize=True. This is possibly a perf regression with numpy arrays, but a massive improvement with chunked arrays.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For posterity the bad thing about this approach is that it can greatly expand the number of core dimensions for the problem, limiting the potential for parallelism.

Consider the problem in #6799 (comment). In the following, dimension names are listed out in [].

da[time, q, lat, lon].interp(q=bar[lat,lon]) gets rewritten to da[time,q,lat,lon].interp(q=bar[lat, lon], lat=lat[lat], lon=lon[lon]) which thanks to our automatic rechunking, makes dask merge chunks in lat, lon too, for no benefit.

def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True):
"""Wrapper for `_interpnd` through `blockwise` for chunked arrays.

def _interpnd(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I merged in two functions to reduce indirection and make it easier to read.

xarray/tests/test_interp.py Outdated Show resolved Hide resolved
exclude_dims=all_in_core_dims,
dask="parallelized",
kwargs=dict(interp_func=func, interp_kwargs=kwargs),
dask_gufunc_kwargs=dict(output_sizes=output_sizes, allow_rechunk=True),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

allow_rechunk=True matches the current behaviour where we rechunk along all core dimensions to a single chunk.

@dcherian dcherian force-pushed the redo-blockwise-interp branch from 245697e to a5e1854 Compare December 14, 2024 00:06
@dcherian dcherian force-pushed the redo-blockwise-interp branch from 652a239 to 586f638 Compare December 14, 2024 04:02
@Illviljan Illviljan mentioned this pull request Dec 14, 2024
1 task
@dcherian
Copy link
Contributor Author

Merging on thursday if there are no comments.

IMO this is a big win for maintainability.

@dcherian dcherian added plan to merge Final call for comments and removed needs review labels Dec 17, 2024
@@ -566,29 +577,30 @@ def _get_valid_fill_mask(arr, dim, limit):
) <= limit


def _localize(var, indexes_coords):
def _localize(obj: T, indexes_coords: SourceDest) -> tuple[T, SourceDest]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably should use T_Xarray instead of a plain T to get rid of the type ignore at return.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That doesn't have Variable, so I'd have to make a new T_DatasetOrVariable or a protocol with .isel perhaps?

xarray/core/missing.py Outdated Show resolved Hide resolved
xarray/core/missing.py Outdated Show resolved Hide resolved
xarray/tests/test_interp.py Outdated Show resolved Hide resolved
xarray/tests/test_interp.py Outdated Show resolved Hide resolved
Copy link
Contributor

@Illviljan Illviljan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmarks still looks good. Nice work!

xarray/core/dataset.py Show resolved Hide resolved
xarray/core/missing.py Outdated Show resolved Hide resolved
xarray/core/missing.py Outdated Show resolved Hide resolved
xarray/core/missing.py Outdated Show resolved Hide resolved
xarray/core/missing.py Outdated Show resolved Hide resolved
Comment on lines +830 to +831
# TODO: narrow interp_func to interpolator here
return _interp1d(var, x_list, new_x_list, interp_func, interp_kwargs) # type: ignore[arg-type]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mypy is correct to error here right?
_interp1d calls interp_func(...)(....) and that should crash with a InterpCallable?
Is there a pytest with interp_func: InterpCallable?
Is InterpCallable necessary? Would be nice to just remove it...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it depends on whether we end up using get_interpolator or get_interpolator_nd. I'm sure there's a test but can't remember which off the top of my head.

@dcherian dcherian merged commit 29fe679 into pydata:main Dec 19, 2024
29 checks passed
@dcherian dcherian deleted the redo-blockwise-interp branch December 19, 2024 16:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
plan to merge Final call for comments run-benchmark Run the ASV benchmark workflow topic-interpolation
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Interpolation with multiple mutlidimensional arrays sharing dims fails
3 participants