Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: refactor geometry matching clauses with Python 3.10's pattern matching #5051

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 34 additions & 34 deletions yt/data_objects/selection_objects/data_selection_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,23 +609,22 @@ def to_frb(self, width, resolution, center=None, height=None, periodic=False):
>>> write_image(np.log10(frb["gas", "density"]), "density_100kpc.png")
"""

if (self.ds.geometry is Geometry.CYLINDRICAL and self.axis == 1) or (
self.ds.geometry is Geometry.POLAR and self.axis == 2
):
if center is not None and center != (0.0, 0.0):
raise NotImplementedError(
"Currently we only support images centered at R=0. "
+ "We plan to generalize this in the near future"
match (self.ds.geometry, self.axis):
case (Geometry.CYLINDRICAL, 1) | (Geometry.POLAR, 2):
if center is not None and center != (0.0, 0.0):
raise NotImplementedError(
"Currently we only support images centered at R=0. "
+ "We plan to generalize this in the near future"
)
from yt.visualization.fixed_resolution import (
CylindricalFixedResolutionBuffer,
)
from yt.visualization.fixed_resolution import (
CylindricalFixedResolutionBuffer,
)

validate_width_tuple(width)
if is_sequence(resolution):
resolution = max(resolution)
frb = CylindricalFixedResolutionBuffer(self, width, resolution)
return frb
validate_width_tuple(width)
if is_sequence(resolution):
resolution = max(resolution)
frb = CylindricalFixedResolutionBuffer(self, width, resolution)
return frb

if center is None:
center = self.center
Expand Down Expand Up @@ -1401,25 +1400,26 @@ def get_bbox(self) -> tuple[unyt_array, unyt_array]:
"""
Return the bounding box for this data container.
"""
geometry: Geometry = self.ds.geometry
if geometry is Geometry.CARTESIAN:
le, re = self._get_bbox()
le.convert_to_units("code_length")
re.convert_to_units("code_length")
return le, re
elif (
geometry is Geometry.CYLINDRICAL
or geometry is Geometry.POLAR
or geometry is Geometry.SPHERICAL
or geometry is Geometry.GEOGRAPHIC
or geometry is Geometry.INTERNAL_GEOGRAPHIC
or geometry is Geometry.SPECTRAL_CUBE
):
raise NotImplementedError(
f"get_bbox is currently not implemented for {geometry=}!"
)
else:
assert_never(geometry)
match self.ds.geometry:
case Geometry.CARTESIAN:
le, re = self._get_bbox()
le.convert_to_units("code_length")
re.convert_to_units("code_length")
return le, re
case (
Geometry.CYLINDRICAL
| Geometry.POLAR
| Geometry.SPHERICAL
| Geometry.GEOGRAPHIC
| Geometry.INTERNAL_GEOGRAPHIC
| Geometry.SPECTRAL_CUBE
):
geometry = self.ds.geometry
raise NotImplementedError(
f"get_bbox is currently not implemented for {geometry=}!"
)
case _:
assert_never(self.ds.geometry)

def volume(self):
"""
Expand Down
51 changes: 26 additions & 25 deletions yt/data_objects/static_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,28 +798,29 @@ def _setup_coordinate_handler(self, axis_order: AxisOrder | None) -> None:
f"Got {self.geometry=} with type {type(self.geometry)}"
)

if self.geometry is Geometry.CARTESIAN:
cls = CartesianCoordinateHandler
elif self.geometry is Geometry.CYLINDRICAL:
cls = CylindricalCoordinateHandler
elif self.geometry is Geometry.POLAR:
cls = PolarCoordinateHandler
elif self.geometry is Geometry.SPHERICAL:
cls = SphericalCoordinateHandler
# It shouldn't be required to reset self.no_cgs_equiv_length
# to the default value (False) here, but it's still necessary
# see https://github.com/yt-project/yt/pull/3618
self.no_cgs_equiv_length = False
elif self.geometry is Geometry.GEOGRAPHIC:
cls = GeographicCoordinateHandler
self.no_cgs_equiv_length = True
elif self.geometry is Geometry.INTERNAL_GEOGRAPHIC:
cls = InternalGeographicCoordinateHandler
self.no_cgs_equiv_length = True
elif self.geometry is Geometry.SPECTRAL_CUBE:
cls = SpectralCubeCoordinateHandler
else:
assert_never(self.geometry)
match self.geometry:
case Geometry.CARTESIAN:
cls = CartesianCoordinateHandler
case Geometry.CYLINDRICAL:
cls = CylindricalCoordinateHandler
case Geometry.POLAR:
cls = PolarCoordinateHandler
case Geometry.SPHERICAL:
cls = SphericalCoordinateHandler
# It shouldn't be required to reset self.no_cgs_equiv_length
# to the default value (False) here, but it's still necessary
# see https://github.com/yt-project/yt/pull/3618
self.no_cgs_equiv_length = False
case Geometry.GEOGRAPHIC:
cls = GeographicCoordinateHandler
self.no_cgs_equiv_length = True
case Geometry.INTERNAL_GEOGRAPHIC:
cls = InternalGeographicCoordinateHandler
self.no_cgs_equiv_length = True
case Geometry.SPECTRAL_CUBE:
cls = SpectralCubeCoordinateHandler
case _:
assert_never(self.geometry)

self.coordinates = cls(self, ordering=axis_order)

Expand Down Expand Up @@ -1948,9 +1949,9 @@ def add_gradient_fields(self, fields=None):
... ("gas", "density_gradient_magnitude"),
... ]

Note that the above example assumes ds.geometry == 'cartesian'. In general,
the function will create gradient components along the axes of the dataset
coordinate system.
Note that the above example assumes ds.geometry is Geometry.CARTESIAN.
In general, the function will create gradient components along the axes
of the dataset coordinate system.
For instance, with cylindrical data, one gets 'density_gradient_<r,theta,z>'

"""
Expand Down
101 changes: 47 additions & 54 deletions yt/fields/field_info_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,29 +222,26 @@ def get_aliases_gallery(self) -> list[FieldName]:
if self.ds is None:
return aliases_gallery

geometry: Geometry = self.ds.geometry
if (
geometry is Geometry.POLAR
or geometry is Geometry.CYLINDRICAL
or geometry is Geometry.SPHERICAL
):
aliases: list[FieldName]
for field in sorted(self.field_list):
if field[0] in self.ds.particle_types:
continue
args = known_other_fields.get(field[1], ("", [], None))
units, aliases, display_name = args
aliases_gallery.extend(aliases)
elif (
geometry is Geometry.CARTESIAN
or geometry is Geometry.GEOGRAPHIC
or geometry is Geometry.INTERNAL_GEOGRAPHIC
or geometry is Geometry.SPECTRAL_CUBE
):
# nothing to do
pass
else:
assert_never(geometry)
match self.ds.geometry:
case Geometry.POLAR | Geometry.CYLINDRICAL | Geometry.SPHERICAL:
aliases: list[FieldName]
for field in sorted(self.field_list):
if field[0] in self.ds.particle_types:
continue
args = known_other_fields.get(field[1], ("", [], None))
units, aliases, display_name = args
aliases_gallery.extend(aliases)
case (
Geometry.CARTESIAN
| Geometry.GEOGRAPHIC
| Geometry.INTERNAL_GEOGRAPHIC
| Geometry.SPECTRAL_CUBE
):
# nothing to do
pass
case _:
assert_never(self.ds.geometry)

return aliases_gallery

def setup_fluid_aliases(self, ftype: FieldType = "gas") -> None:
Expand Down Expand Up @@ -280,38 +277,34 @@ def setup_fluid_aliases(self, ftype: FieldType = "gas") -> None:
field, sampling_type="cell", units=units, display_name=display_name
)
axis_names = self.ds.coordinates.axis_order
geometry: Geometry = self.ds.geometry
for alias in aliases:
if (
geometry is Geometry.POLAR
or geometry is Geometry.CYLINDRICAL
or geometry is Geometry.SPHERICAL
):
if alias[-2:] not in ["_x", "_y", "_z"]:
to_convert = False
else:
for suffix in ["x", "y", "z"]:
if f"{alias[:-2]}_{suffix}" not in aliases_gallery:
to_convert = False
break
to_convert = True
if to_convert:
if alias[-2:] == "_x":
alias = f"{alias[:-2]}_{axis_names[0]}"
elif alias[-2:] == "_y":
alias = f"{alias[:-2]}_{axis_names[1]}"
elif alias[-2:] == "_z":
alias = f"{alias[:-2]}_{axis_names[2]}"
elif (
geometry is Geometry.CARTESIAN
or geometry is Geometry.GEOGRAPHIC
or geometry is Geometry.INTERNAL_GEOGRAPHIC
or geometry is Geometry.SPECTRAL_CUBE
):
# nothing to do
pass
else:
assert_never(geometry)
match self.ds.geometry:
case Geometry.POLAR | Geometry.CYLINDRICAL | Geometry.SPHERICAL:
if alias[-2:] not in ["_x", "_y", "_z"]:
to_convert = False
else:
for suffix in ["x", "y", "z"]:
if f"{alias[:-2]}_{suffix}" not in aliases_gallery:
to_convert = False
break
to_convert = True
if to_convert:
if alias[-2:] == "_x":
alias = f"{alias[:-2]}_{axis_names[0]}"
elif alias[-2:] == "_y":
alias = f"{alias[:-2]}_{axis_names[1]}"
elif alias[-2:] == "_z":
alias = f"{alias[:-2]}_{axis_names[2]}"
case (
Geometry.CARTESIAN
| Geometry.GEOGRAPHIC
| Geometry.INTERNAL_GEOGRAPHIC
| Geometry.SPECTRAL_CUBE
):
# nothing to do
pass
case _:
assert_never(self.ds.geometry)
self.alias((ftype, alias), field)

@staticmethod
Expand Down
Loading
Loading