diff --git a/xarray/namedarray/_array_api/_elementwise_functions.py b/xarray/namedarray/_array_api/_elementwise_functions.py index 0d91b13f41d..0fc19a340c8 100644 --- a/xarray/namedarray/_array_api/_elementwise_functions.py +++ b/xarray/namedarray/_array_api/_elementwise_functions.py @@ -8,17 +8,11 @@ _ShapeType, _SupportsImag, _SupportsReal, + _atleast_0d, ) from xarray.namedarray.core import NamedArray -def _atleast_0d(x, xp): - """ - Workaround for numpy sometimes returning scalars instead of 0d arrays. - """ - return xp.asarray(x) - - def abs(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) _data = _atleast_0d(xp.abs(x._data), xp) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 5264786ded5..2bf0ef9dc3e 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -167,7 +167,7 @@ def roll( ) -> NamedArray: xp = _get_data_namespace(x) _data = xp.roll(x._data, shift=shift, axis=axis) - return x._new(_data) + return x._new(data=_data) def squeeze(x: NamedArray, /, axis: _Axes) -> NamedArray: diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 7f462bedb47..fb01a4ec56e 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -174,3 +174,10 @@ def _insert_dim(dims: _Dims, dim: _Dim | Default, axis: _Axis) -> _Dims: d = list(dims) d.insert(axis, dim) return tuple(d) + + +def _atleast_0d(x, xp): + """ + Workaround for numpy sometimes returning scalars instead of 0d arrays. + """ + return xp.asarray(x)