Skip to content

Commit

Permalink
passing mypy, pytest locally
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishavlin committed Oct 16, 2024
1 parent f7b6e65 commit bc0d734
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 94 deletions.
6 changes: 3 additions & 3 deletions yt_xarray/accessor/_xr_to_yt.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Selection:

def __init__(
self,
xr_ds,
xr_ds: xr.Dataset,
fields: List[str] | None = None,
sel_dict: dict[str, Any] | None = None,
sel_dict_type: str | None = "isel",
Expand Down Expand Up @@ -100,7 +100,7 @@ def _find_starting_index(self, coordname, coord_da, coord_select) -> int:
search_for = selector.start
elif isinstance(selector, (float, np.datetime64, int)):
search_for = selector
elif isinstance(
elif isinstance( # type: ignore[unreachable]
selector, (collections.abc.Sequence, np.ndarray, xr.DataArray)
):
if _size_of_array_like(selector) > 1:
Expand Down Expand Up @@ -402,7 +402,7 @@ def _cf_xr_coord_disamb(


def _convert_to_yt_internal_coords(
coord_list: tuple[str] | list[str], xr_field: xr.DataArray
coord_list: tuple[str, ...] | list[str], xr_field: xr.DataArray
):
yt_coords = []
for c in coord_list:
Expand Down
3 changes: 2 additions & 1 deletion yt_xarray/tests/test_yt_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import os.path

import pytest
from yt.data_objects.static_output import Dataset as ytDataset

from yt_xarray.sample_data import load_random_xr_data


def get_xr_ds():
def get_xr_ds() -> ytDataset:
fields = {
"temperature": ("x", "y", "z"),
"pressure": ("x", "y", "z"),
Expand Down
104 changes: 59 additions & 45 deletions yt_xarray/transformations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import abc
from typing import Callable, List, Mapping, Optional, Tuple, Union
from typing import Any, Callable, Mapping, Union

import numpy as np
import unyt
Expand All @@ -20,10 +20,10 @@ class Transformer(abc.ABC):
Parameters
----------
native_coords: Tuple[str]
native_coords: tuple[str, ...]
the names of the native coordinates, e.g., ('x0', 'y0', 'z0'), on
which data is defined.
transformed_coords: Tuple[str]
transformed_coords: tuple[str, ...]
the names of the transformed coordinates, e.g., ('x1', 'y1', 'z1')
coord_aliases: dict
optional dictionary of coordinate aliases to map arbitrary keys to
Expand All @@ -36,9 +36,9 @@ class Transformer(abc.ABC):

def __init__(
self,
native_coords: Tuple[str],
transformed_coords: Tuple[str],
coord_aliases: Optional[dict] = None,
native_coords: tuple[str, ...],
transformed_coords: tuple[str, ...],
coord_aliases: dict[str, str] | None = None,
):
self.native_coords = native_coords
self._native_coords_set = set(native_coords)
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(
raise ValueError(msg)
self.coord_aliases = coord_aliases

def _disambiguate_coord(self, coord):
def _disambiguate_coord(self, coord: str) -> str:
if coord in self.native_coords or coord in self.transformed_coords:
return coord

Expand All @@ -79,14 +79,14 @@ def _disambiguate_coord(self, coord):
raise ValueError(msg)

@abc.abstractmethod
def _calculate_native(self, **coords):
def _calculate_native(self, **coords) -> list[npt.NDArray]:
"""
function to convert from transformed to native coordinates. Must be
implemented by each child class.
"""

@abc.abstractmethod
def _calculate_transformed(self, **coords):
def _calculate_transformed(self, **coords) -> list[npt.NDArray]:
"""
function to convert from native to transformed coordinates. Must be
implemented by each child class.
Expand Down Expand Up @@ -170,7 +170,9 @@ def to_transformed(self, **coords: npt.NDArray) -> list[npt.NDArray]:
return self._calculate_transformed(**new_coords)

@abc.abstractmethod
def calculate_transformed_bbox(self, bbox_dict: dict) -> np.ndarray:
def calculate_transformed_bbox(
self, bbox_dict: Mapping[str, npt.NDArray]
) -> npt.NDArray:
"""
Calculates a bounding box in transformed coordinates for a bounding box dictionary
in native coordinates.
Expand Down Expand Up @@ -198,7 +200,7 @@ class LinearScale(Transformer):
Parameters
----------
native_coords: Tuple[str]
native_coords: tuple[str, ...]
the names of the native coordinates, e.g., ('x', 'y', 'z'), on
which data is defined.
scale: dict
Expand All @@ -224,7 +226,9 @@ class LinearScale(Transformer):
"""

def __init__(self, native_coords: Tuple[str, ...], scale: Optional[dict] = None):
def __init__(
self, native_coords: tuple[str, ...], scale: dict[str, float] | None = None
):
if scale is None:
scale = {}

Expand Down Expand Up @@ -277,7 +281,9 @@ def calculate_transformed_bbox(
)


def _sphere_to_cart(r, theta, phi):
def _sphere_to_cart(
r: npt.NDArray, theta: npt.NDArray, phi: npt.NDArray
) -> tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
# r : radius
# theta: colatitude
# phi: azimuth
Expand All @@ -289,7 +295,9 @@ def _sphere_to_cart(r, theta, phi):
return x, y, z


def _cart_to_sphere(x, y, z):
def _cart_to_sphere(
x: npt.NDArray, y: npt.NDArray, z: npt.NDArray
) -> tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
# will return phi (azimuth) in +/- np.pi
r = np.sqrt(x * x + y * y + z * z)
theta = np.arccos(z / (r + 1e-12))
Expand Down Expand Up @@ -338,10 +346,10 @@ class GeocentricCartesian(Transformer):
def __init__(
self,
radial_type: str = "radius",
radial_axis: Optional[str] = None,
r_o: Optional[Union[float, unyt.unyt_quantity]] = None,
coord_aliases: Optional[dict] = None,
use_neg_lons: Optional[bool] = False,
radial_axis: str | None = None,
r_o: Union[float, unyt.unyt_quantity] | None = None,
coord_aliases: dict[str, str] | None = None,
use_neg_lons: bool = False,
):
transformed_coords = ("x", "y", "z")

Expand All @@ -366,7 +374,7 @@ def __init__(

super().__init__(native_coords, transformed_coords, coord_aliases=coord_aliases)

def _calculate_transformed(self, **coords):
def _calculate_transformed(self, **coords) -> list[npt.NDArray]:
if self.radial_type == "depth":
r_val = self._r_o - coords[self.radial_axis]
elif self.radial_type == "altitude":
Expand All @@ -378,9 +386,9 @@ def _calculate_transformed(self, **coords):
theta = (90.0 - lat) * np.pi / 180.0 # colatitude in radians
phi = lon * np.pi / 180.0 # azimuth in radians
x, y, z = _sphere_to_cart(r_val, theta, phi)
return x, y, z
return [x, y, z]

def _calculate_native(self, **coords):
def _calculate_native(self, **coords) -> list[npt.NDArray]:
r, theta, phi = _cart_to_sphere(coords["x"], coords["y"], coords["z"])
lat = 90.0 - theta * 180.0 / np.pi
lon = phi * 180.0 / np.pi
Expand All @@ -394,9 +402,11 @@ def _calculate_native(self, **coords):
r = r - self._r_o
elif self.radial_type == "depth":
r = self._r_o - r
return r, lat, lon
return [r, lat, lon]

def calculate_transformed_bbox(self, bbox_dict: dict) -> np.ndarray:
def calculate_transformed_bbox(
self, bbox_dict: Mapping[str, npt.NDArray]
) -> npt.NDArray:
"""
Calculates a bounding box in transformed coordinates for a bounding box dictionary
in native coordinates.
Expand Down Expand Up @@ -447,20 +457,20 @@ def calculate_transformed_bbox(self, bbox_dict: dict) -> np.ndarray:
def build_interpolated_cartesian_ds(
xr_ds: xr.Dataset,
transformer: Transformer,
fields: Optional[Union[str, Tuple[str]]] = None,
grid_resolution: Optional[List[int]] = None,
fill_value: Optional[float] = None,
length_unit: Optional[str] = "km",
refine_grid: Optional[bool] = False,
refine_by: Optional[int] = 2,
refine_max_iters: Optional[int] = 200,
refine_min_grid_size: Optional[int] = 10,
refinement_method: Optional[str] = "division",
sel_dict: Optional[dict] = None,
sel_dict_type: Optional[str] = "isel",
bbox_dict: Optional[dict] = None,
interp_method: Optional[str] = "nearest",
interp_func: Optional[Callable] = None,
fields: Union[str, tuple[str, ...], list[str]] | None = None,
grid_resolution: tuple[int, ...] | list[int] | None = None,
fill_value: float | None = None,
length_unit: str | float = "km",
refine_grid: bool = False,
refine_by: int = 2,
refine_max_iters: int = 200,
refine_min_grid_size: int = 10,
refinement_method: str = "division",
sel_dict: dict[str, Any] | None = None,
sel_dict_type: str = "isel",
bbox_dict: Mapping[str, npt.NDArray] | None = None,
interp_method: str = "nearest",
interp_func: Callable[..., npt.NDArray] | None = None,
):
"""
Build a yt cartesian dataset containing fields interpolated on demand
Expand Down Expand Up @@ -514,9 +524,6 @@ def build_interpolated_cartesian_ds(
"""

if fields is None:
fields = list(xr_ds.data_vars)

valid_methods = ("interpolate", "nearest")
if interp_method not in valid_methods:
msg = f"interp_method must be one of: {valid_methods}, found {interp_method}."
Expand All @@ -528,12 +535,19 @@ def build_interpolated_cartesian_ds(
)
interp_method = "interpolate"

if isinstance(fields, str):
fields = (fields,)
valid_fields: list[str]
if fields is None:
valid_fields = list(xr_ds.data_vars)
elif isinstance(fields, str):
valid_fields = [
fields,
]
else:
valid_fields = [f for f in fields]

sel_info = _xr_to_yt.Selection(
xr_ds,
fields=fields,
fields=valid_fields,
sel_dict=sel_dict,
sel_dict_type=sel_dict_type,
)
Expand Down Expand Up @@ -607,8 +621,8 @@ def _read_data(grid, field_name):

return output_vals

data_dict = {}
for field in fields:
data_dict: dict[str, Callable[..., npt.NDArray]] = {}
for field in valid_fields:
data_dict[field] = _read_data

if grid_resolution is None:
Expand Down
Loading

0 comments on commit bc0d734

Please sign in to comment.