Skip to content

Commit

Permalink
more
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Aug 25, 2024
1 parent ec55434 commit 93bd9ca
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
8 changes: 1 addition & 7 deletions xarray/namedarray/_array_api/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,11 @@
_ShapeType,
_SupportsImag,
_SupportsReal,
_atleast_0d,
)
from xarray.namedarray.core import NamedArray


def _atleast_0d(x, xp):
"""
Workaround for numpy sometimes returning scalars instead of 0d arrays.
"""
return xp.asarray(x)


def abs(x: NamedArray, /) -> NamedArray:
xp = _get_data_namespace(x)
_data = _atleast_0d(xp.abs(x._data), xp)
Expand Down
2 changes: 1 addition & 1 deletion xarray/namedarray/_array_api/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def roll(
) -> NamedArray:
xp = _get_data_namespace(x)
_data = xp.roll(x._data, shift=shift, axis=axis)
return x._new(_data)
return x._new(data=_data)


def squeeze(x: NamedArray, /, axis: _Axes) -> NamedArray:
Expand Down
7 changes: 7 additions & 0 deletions xarray/namedarray/_array_api/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,10 @@ def _insert_dim(dims: _Dims, dim: _Dim | Default, axis: _Axis) -> _Dims:
d = list(dims)
d.insert(axis, dim)
return tuple(d)


def _atleast_0d(x, xp):
"""
Workaround for numpy sometimes returning scalars instead of 0d arrays.
"""
return xp.asarray(x)

0 comments on commit 93bd9ca

Please sign in to comment.