Skip to content

Commit

Permalink
Merge pull request #4169 from meeseeksmachine/auto-backport-of-pr-416…
Browse files Browse the repository at this point in the history
…6-on-yt-4.1.x

Backport PR #4166 on branch yt-4.1.x (BUG: fix future incompatiblities with unyt 3.0 (2/n))
  • Loading branch information
neutrinoceros authored Oct 14, 2022
2 parents e4957fa + 12dda0e commit 4db7e00
Show file tree
Hide file tree
Showing 17 changed files with 239 additions and 26 deletions.
3 changes: 2 additions & 1 deletion yt/data_objects/construction_data_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@
from yt.geometry import particle_deposit as particle_deposit
from yt.geometry.coordinates.cartesian_coordinates import all_data
from yt.loaders import load_uniform_grid
from yt.units._numpy_wrapper_functions import uconcatenate
from yt.units.unit_object import Unit # type: ignore
from yt.units.yt_array import YTArray, uconcatenate # type: ignore
from yt.units.yt_array import YTArray
from yt.utilities.exceptions import (
YTNoAPIKey,
YTNotInsideNotebook,
Expand Down
3 changes: 2 additions & 1 deletion yt/data_objects/data_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from yt.fields.field_exceptions import NeedsGridType
from yt.frontends.ytdata.utilities import save_as_dataset
from yt.funcs import get_output_filename, is_sequence, iter_fields, mylog
from yt.units.yt_array import YTArray, YTQuantity, uconcatenate # type: ignore
from yt.units._numpy_wrapper_functions import uconcatenate
from yt.units.yt_array import YTArray, YTQuantity
from yt.utilities.amr_kdtree.api import AMRKDTree
from yt.utilities.exceptions import (
YTCouldNotGenerateField,
Expand Down
2 changes: 1 addition & 1 deletion yt/data_objects/selection_objects/ray.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
from unyt import udot, unorm

from yt.data_objects.selection_objects.data_selection_objects import (
YTSelectionContainer,
Expand All @@ -16,6 +15,7 @@
validate_sequence,
)
from yt.units import YTArray, YTQuantity
from yt.units._numpy_wrapper_functions import udot, unorm
from yt.utilities.lib.pixelization_routines import SPHKernelInterpolationTable
from yt.utilities.logger import ytLogger as mylog

Expand Down
2 changes: 1 addition & 1 deletion yt/data_objects/tests/test_chunking.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from yt.testing import assert_equal, assert_true, fake_random_ds
from yt.units.yt_array import uconcatenate
from yt.units._numpy_wrapper_functions import uconcatenate


def _get_dobjs(c):
Expand Down
3 changes: 2 additions & 1 deletion yt/data_objects/tests/test_compose.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np

from yt.testing import assert_array_equal, fake_amr_ds, fake_random_ds
from yt.units.yt_array import YTArray, uintersect1d
from yt.units._numpy_wrapper_functions import uintersect1d
from yt.units.yt_array import YTArray


def setup():
Expand Down
2 changes: 1 addition & 1 deletion yt/data_objects/tests/test_rays.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from yt import load
from yt.testing import assert_equal, assert_rel_equal, fake_random_ds, requires_file
from yt.units.yt_array import uconcatenate
from yt.units._numpy_wrapper_functions import uconcatenate


def test_ray():
Expand Down
2 changes: 1 addition & 1 deletion yt/fields/particle_fields.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from yt.fields.derived_field import ValidateParameter, ValidateSpatial
from yt.units.yt_array import uconcatenate, ucross # type: ignore
from yt.units._numpy_wrapper_functions import uconcatenate, ucross
from yt.utilities.lib.misc_utilities import (
obtain_position_vector,
obtain_relative_velocity_vector,
Expand Down
2 changes: 1 addition & 1 deletion yt/frontends/gadget/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

from yt.frontends.sph.io import IOHandlerSPH
from yt.units.yt_array import uconcatenate # type: ignore
from yt.units._numpy_wrapper_functions import uconcatenate
from yt.utilities.lib.particle_kdtree_tools import generate_smoothing_length
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.on_demand_imports import _h5py as h5py
Expand Down
3 changes: 2 additions & 1 deletion yt/frontends/ytdata/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
from yt.geometry.grid_geometry_handler import GridIndex
from yt.geometry.particle_geometry_handler import ParticleIndex
from yt.units import dimensions
from yt.units._numpy_wrapper_functions import uconcatenate
from yt.units.unit_registry import UnitRegistry # type: ignore
from yt.units.yt_array import YTQuantity, uconcatenate # type: ignore
from yt.units.yt_array import YTQuantity
from yt.utilities.exceptions import GenerationInProgress, YTFieldTypeNotFound
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.on_demand_imports import _h5py as h5py
Expand Down
3 changes: 2 additions & 1 deletion yt/geometry/coordinates/cartesian_coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from yt.data_objects.index_subobjects.unstructured_mesh import SemiStructuredMesh
from yt.funcs import mylog
from yt.units.yt_array import YTArray, uconcatenate, uvstack # type: ignore
from yt.units._numpy_wrapper_functions import uconcatenate, uvstack
from yt.units.yt_array import YTArray
from yt.utilities.lib.pixelization_routines import (
interpolate_sph_grid_gather,
normalization_2d_utility,
Expand Down
3 changes: 2 additions & 1 deletion yt/geometry/geometry_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import numpy as np

from yt.config import ytcfg
from yt.units.yt_array import YTArray, uconcatenate # type: ignore
from yt.units._numpy_wrapper_functions import uconcatenate
from yt.units.yt_array import YTArray
from yt.utilities.exceptions import YTFieldNotFound
from yt.utilities.io_handler import io_registry
from yt.utilities.logger import ytLogger as mylog
Expand Down
21 changes: 11 additions & 10 deletions yt/units/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,8 @@
from unyt.array import (
loadtxt,
savetxt,
uconcatenate,
ucross,
udot,
uhstack,
uintersect1d,
unorm,
unyt_array,
unyt_quantity,
ustack,
uunion1d,
uvstack,
)
from unyt.unit_object import Unit, define_unit # NOQA: F401
from unyt.unit_registry import UnitRegistry # NOQA: Ffg401
Expand All @@ -22,7 +13,17 @@
from yt.units.unit_symbols import *
from yt.units.unit_symbols import _SymbolContainer
from yt.utilities.exceptions import YTArrayTooLargeToDisplay

from yt.units._numpy_wrapper_functions import (
uconcatenate,
ucross,
udot,
uhstack,
uintersect1d,
unorm,
ustack,
uunion1d,
uvstack,
)
YTArray = unyt_array

YTQuantity = unyt_quantity
Expand Down
206 changes: 206 additions & 0 deletions yt/units/_numpy_wrapper_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# This module is not part of the public namespace `yt.units`
# It is home to wrapper functions that are directly copied from unyt 2.9.2
# We vendor them as a transition step towards unyt 3.0 (in devlopment),
# where these wrapper functions are deprecated and are should be replaced with vanilla numpy API
# FUTURE:
# - require unyt>=3.0
# - deprecate these functions in yt too

from unyt import unyt_array, unyt_quantity
import numpy as np


def _validate_numpy_wrapper_units(v, arrs):
if not any(isinstance(a, unyt_array) for a in arrs):
return v
if not all(isinstance(a, unyt_array) for a in arrs):
raise RuntimeError("Not all of your arrays are unyt_arrays.")
a1 = arrs[0]
if not all(a.units == a1.units for a in arrs[1:]):
raise RuntimeError("Your arrays must have identical units.")
v.units = a1.units
return v


def uconcatenate(arrs, axis=0):
"""Concatenate a sequence of arrays.
This wrapper around numpy.concatenate preserves units. All input arrays
must have the same units. See the documentation of numpy.concatenate for
full details.
Examples
--------
>>> from unyt import cm
>>> A = [1, 2, 3]*cm
>>> B = [2, 3, 4]*cm
>>> uconcatenate((A, B))
unyt_array([1, 2, 3, 2, 3, 4], 'cm')
"""
v = np.concatenate(arrs, axis=axis)
v = _validate_numpy_wrapper_units(v, arrs)
return v


def ucross(arr1, arr2, registry=None, axisa=-1, axisb=-1, axisc=-1, axis=None):
"""Applies the cross product to two YT arrays.
This wrapper around numpy.cross preserves units.
See the documentation of numpy.cross for full
details.
"""
v = np.cross(arr1, arr2, axisa=axisa, axisb=axisb, axisc=axisc, axis=axis)
units = arr1.units * arr2.units
arr = unyt_array(v, units, registry=registry)
return arr


def uintersect1d(arr1, arr2, assume_unique=False):
"""Find the sorted unique elements of the two input arrays.
A wrapper around numpy.intersect1d that preserves units. All input arrays
must have the same units. See the documentation of numpy.intersect1d for
full details.
Examples
--------
>>> from unyt import cm
>>> A = [1, 2, 3]*cm
>>> B = [2, 3, 4]*cm
>>> uintersect1d(A, B)
unyt_array([2, 3], 'cm')
"""
v = np.intersect1d(arr1, arr2, assume_unique=assume_unique)
v = _validate_numpy_wrapper_units(v, [arr1, arr2])
return v


def uunion1d(arr1, arr2):
"""Find the union of two arrays.
A wrapper around numpy.intersect1d that preserves units. All input arrays
must have the same units. See the documentation of numpy.intersect1d for
full details.
Examples
--------
>>> from unyt import cm
>>> A = [1, 2, 3]*cm
>>> B = [2, 3, 4]*cm
>>> uunion1d(A, B)
unyt_array([1, 2, 3, 4], 'cm')
"""
v = np.union1d(arr1, arr2)
v = _validate_numpy_wrapper_units(v, [arr1, arr2])
return v


def unorm(data, ord=None, axis=None, keepdims=False):
"""Matrix or vector norm that preserves units
This is a wrapper around np.linalg.norm that preserves units. See
the documentation for that function for descriptions of the keyword
arguments.
Examples
--------
>>> from unyt import km
>>> data = [1, 2, 3]*km
>>> print(unorm(data))
3.7416573867739413 km
"""
norm = np.linalg.norm(data, ord=ord, axis=axis, keepdims=keepdims)
if norm.shape == ():
return unyt_quantity(norm, data.units)
return unyt_array(norm, data.units)


def udot(op1, op2):
"""Matrix or vector dot product that preserves units
This is a wrapper around np.dot that preserves units.
Examples
--------
>>> from unyt import km, s
>>> a = np.eye(2)*km
>>> b = (np.ones((2, 2)) * 2)*s
>>> print(udot(a, b))
[[2. 2.]
[2. 2.]] km*s
"""
dot = np.dot(op1.d, op2.d)
units = op1.units * op2.units
if dot.shape == ():
return unyt_quantity(dot, units)
return unyt_array(dot, units)


def uvstack(arrs):
"""Stack arrays in sequence vertically (row wise) while preserving units
This is a wrapper around np.vstack that preserves units.
Examples
--------
>>> from unyt import km
>>> a = [1, 2, 3]*km
>>> b = [2, 3, 4]*km
>>> print(uvstack([a, b]))
[[1 2 3]
[2 3 4]] km
"""
v = np.vstack(arrs)
v = _validate_numpy_wrapper_units(v, arrs)
return v


def uhstack(arrs):
"""Stack arrays in sequence horizontally while preserving units
This is a wrapper around np.hstack that preserves units.
Examples
--------
>>> from unyt import km
>>> a = [1, 2, 3]*km
>>> b = [2, 3, 4]*km
>>> print(uhstack([a, b]))
[1 2 3 2 3 4] km
>>> a = [[1],[2],[3]]*km
>>> b = [[2],[3],[4]]*km
>>> print(uhstack([a, b]))
[[1 2]
[2 3]
[3 4]] km
"""
v = np.hstack(arrs)
v = _validate_numpy_wrapper_units(v, arrs)
return v


def ustack(arrs, axis=0):
"""Join a sequence of arrays along a new axis while preserving units
The axis parameter specifies the index of the new axis in the
dimensions of the result. For example, if ``axis=0`` it will be the
first dimension and if ``axis=-1`` it will be the last dimension.
This is a wrapper around np.stack that preserves units. See the
documentation for np.stack for full details.
Examples
--------
>>> from unyt import km
>>> a = [1, 2, 3]*km
>>> b = [2, 3, 4]*km
>>> print(ustack([a, b]))
[[1 2 3]
[2 3 4]] km
"""
v = np.stack(arrs, axis=axis)
v = _validate_numpy_wrapper_units(v, arrs)
return v
2 changes: 1 addition & 1 deletion yt/utilities/particle_generator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from yt.funcs import get_pbar
from yt.units.yt_array import uconcatenate # type: ignore
from yt.units._numpy_wrapper_functions import uconcatenate
from yt.utilities.lib.particle_mesh_operations import CICSample_3


Expand Down
2 changes: 1 addition & 1 deletion yt/utilities/tests/test_particle_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from yt.loaders import load_uniform_grid
from yt.testing import assert_almost_equal, assert_equal
from yt.units.yt_array import uconcatenate
from yt.units._numpy_wrapper_functions import uconcatenate
from yt.utilities.particle_generator import (
FromListParticleGenerator,
LatticeParticleGenerator,
Expand Down
2 changes: 1 addition & 1 deletion yt/visualization/volume_rendering/lens.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from yt.data_objects.image_array import ImageArray
from yt.units.yt_array import uhstack, unorm, uvstack # type: ignore
from yt.units._numpy_wrapper_functions import uhstack, unorm, uvstack
from yt.utilities.lib.grid_traversal import arr_fisheye_vectors
from yt.utilities.math_utils import get_rotation_matrix
from yt.utilities.parallel_tools.parallel_analysis_interface import (
Expand Down
4 changes: 2 additions & 2 deletions yt/visualization/volume_rendering/old_camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ def new_image(self):

def get_sampler_args(self, image):
rotp = np.concatenate(
[self.orienter.inv_mat.ravel("F"), self.back_center.ravel()]
[self.orienter.inv_mat.ravel("F"), self.back_center.ravel().ndview]
)
args = (
np.atleast_3d(rotp),
Expand Down Expand Up @@ -2125,7 +2125,7 @@ def initialize_source(self):

def get_sampler_args(self, image):
rotp = np.concatenate(
[self.orienter.inv_mat.ravel("F"), self.back_center.ravel()]
[self.orienter.inv_mat.ravel("F"), self.back_center.ravel().ndview]
)
args = (
np.atleast_3d(rotp),
Expand Down

0 comments on commit 4db7e00

Please sign in to comment.