Skip to content

Commit

Permalink
Merge pull request #30 from meom-group/issue-#26
Browse files Browse the repository at this point in the history
Issue #26
  • Loading branch information
vadmbertr authored Nov 6, 2023
2 parents 94ef499 + 8be8dad commit e28d42b
Show file tree
Hide file tree
Showing 35 changed files with 799 additions and 452 deletions.
7 changes: 4 additions & 3 deletions jaxparrow/cyclogeostrophy.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def cyclogeostrophy(u_geos: Union[np.ndarray, np.ma.MaskedArray], v_geos: Union[
coriolis_factor_u = coriolis_factor_u.filled(1)
if isinstance(coriolis_factor_v, np.ma.MaskedArray):
coriolis_factor_v = coriolis_factor_v.filled(1)
u_geos = np.nan_to_num(u_geos, nan=0, posinf=0, neginf=0)
v_geos = np.nan_to_num(v_geos, nan=0, posinf=0, neginf=0)

if method == "variational":
u_cyclo, v_cyclo = _variational(u_geos, v_geos, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v,
Expand Down Expand Up @@ -163,8 +165,6 @@ def _iterative(u_geos: np.ndarray, v_geos: np.ndarray, dx_u: np.ndarray, dx_v: n
res_n = res_init * np.ones_like(u_geos)
else:
raise ValueError("res_init should be equal to \"same\" or be a number.")
if not use_res_filter:
res_filter_size = 1

u_cyclo, v_cyclo = np.copy(u_geos), np.copy(v_geos)
res_filter = np.ones((res_filter_size, res_filter_size))
Expand All @@ -178,7 +178,8 @@ def _iterative(u_geos: np.ndarray, v_geos: np.ndarray, dx_u: np.ndarray, dx_v: n

# compute dist to u_cyclo and v_cyclo
res_np1 = np.abs(u_np1 - u_cyclo) + np.abs(v_np1 - v_cyclo)
res_np1 = signal.correlate(res_np1, res_filter, mode="same", method="direct") / res_weights # apply filter
if use_res_filter:
res_np1 = signal.correlate(res_np1, res_filter, mode="same") / res_weights # apply filter
# compute intermediate masks
mask_jnp1 = np.where(res_np1 < res_eps, 1, 0)
mask_n = np.where(res_np1 > res_n, 1, 0)
Expand Down
2 changes: 1 addition & 1 deletion jaxparrow/geostrophy.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ def geostrophy(ssh: Union[np.ndarray, np.ma.MaskedArray],
u_geos = - tools.GRAVITY * grad_ssh_y / cu
v_geos = tools.GRAVITY * grad_ssh_x / cv

return u_geos, v_geos
return np.nan_to_num(u_geos, nan=0, posinf=0, neginf=0), np.nan_to_num(v_geos, nan=0, posinf=0, neginf=0)
4 changes: 2 additions & 2 deletions jaxparrow/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .tools import compute_coriolis_factor, compute_spatial_step
from .tools import compute_coriolis_factor, compute_derivative, compute_spatial_step
from . import tools

__all__ = ["compute_coriolis_factor", "compute_spatial_step", "tools"]
__all__ = ["compute_coriolis_factor", "compute_derivative", "compute_spatial_step", "tools"]
22 changes: 11 additions & 11 deletions jaxparrow/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
GRAVITY = 9.81
P0 = np.pi / 180

__all__ = ["compute_coriolis_factor", "compute_spatial_step"]
__all__ = ["compute_coriolis_factor", "compute_derivative", "compute_spatial_step"]


# =============================================================================
Expand Down Expand Up @@ -106,18 +106,18 @@ def interpolate(field: Union[np.ndarray, np.ma.MaskedArray], axis: int = 0) -> n
return f


def _compute_derivative(field: Union[np.ndarray, np.ma.MaskedArray], dxy: Union[np.ndarray, np.ma.MaskedArray],
axis: int = 0) -> np.ndarray:
def compute_derivative(field: Union[np.ndarray, np.ma.MaskedArray], dxy: Union[np.ndarray, np.ma.MaskedArray],
axis: int = 0) -> np.ndarray:
"""Computes the x or y derivatives of a 2D field using finite differences
:param field: field values
:param field: field values, NxM grid
:type field: Union[np.ndarray, np.ma.MaskedArray]
:param dxy: spatial steps
:param dxy: spatial steps, NxM grid
:type dxy: Union[np.ndarray, np.ma.MaskedArray]
:param axis: axis along which boundary conditions are applied, defaults to 0
:type axis: int, optional
:returns: derivatives
:returns: derivatives, NxM grid
:rtype: np.ndarray
"""
f = np.copy(field)
Expand All @@ -144,7 +144,7 @@ def compute_gradient(field: Union[np.ndarray, np.ma.MaskedArray],
:returns: gradients
:rtype: Tuple[np.ndarray, np.ndarray]
"""
fx, fy = _compute_derivative(field, dx, axis=1), _compute_derivative(field, dy, axis=0)
fx, fy = compute_derivative(field, dx, axis=1), compute_derivative(field, dy, axis=0)
return fx, fy


Expand All @@ -169,10 +169,10 @@ def compute_advection_u(u: Union[np.ndarray, np.ma.MaskedArray], v: Union[np.nda
u_adv = np.copy(u)
v_adv = np.copy(v)

dudx = _compute_derivative(u, dx, axis=1) # h points
dudx = compute_derivative(u, dx, axis=1) # h points
dudx = interpolate(dudx, axis=0) # v points

dudy = _compute_derivative(u, dy, axis=0) # vorticity points
dudy = compute_derivative(u, dy, axis=0) # vorticity points
dudy = interpolate(dudy, axis=1) # v points

u_adv = interpolate(u_adv, axis=1) # h points
Expand Down Expand Up @@ -203,10 +203,10 @@ def compute_advection_v(u: Union[np.ndarray, np.ma.MaskedArray], v: Union[np.nda
u_adv = np.copy(u)
v_adv = np.copy(v)

dvdx = _compute_derivative(v, dx, axis=1) # vorticity points
dvdx = compute_derivative(v, dx, axis=1) # vorticity points
dvdx = interpolate(dvdx, axis=0) # u points

dvdy = _compute_derivative(v, dy, axis=0) # h points
dvdy = compute_derivative(v, dy, axis=0) # h points
dvdy = interpolate(dvdy, axis=1) # u points

v_adv = interpolate(v_adv, axis=1) # vorticity points
Expand Down
700 changes: 395 additions & 305 deletions notebooks/alboran_sea.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit e28d42b

Please sign in to comment.