Skip to content

Commit

Permalink
Merge pull request #63 from meom-group/fix-init-mask
Browse files Browse the repository at this point in the history
fix how the mask is initialized when unspecified
  • Loading branch information
vadmbertr authored Apr 29, 2024
2 parents 4fdbb15 + da4e250 commit b369b3d
Show file tree
Hide file tree
Showing 22 changed files with 233 additions and 202 deletions.
24 changes: 12 additions & 12 deletions jaxparrow/cyclogeostrophy.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ def cyclogeostrophy(
Cyclogeostrophic imbalance evaluated at each iteration, if ``return_losses=True``
"""
# Make sure the mask is initialized
mask = sanitize.init_mask(ssh_t, mask)
is_land = sanitize.init_land_mask(ssh_t, mask)

# Compute geostrophic SSC velocity field
u_geos_u, v_geos_v, lat_u, lon_u, lat_v, lon_v = geostrophy(ssh_t, lat_t, lon_t, mask, return_grids=True)
u_geos_u, v_geos_v, lat_u, lon_u, lat_v, lon_v = geostrophy(ssh_t, lat_t, lon_t, is_land, return_grids=True)

# Compute spatial steps and Coriolis factors
dx_u, dy_u = geometry.compute_spatial_step(lat_u, lon_u)
Expand All @@ -143,12 +143,12 @@ def cyclogeostrophy(
coriolis_factor_v = geometry.compute_coriolis_factor(lat_v)

# Handle spurious and masked data
dx_u = sanitize.sanitize_data(dx_u, jnp.nan, mask)
dy_u = sanitize.sanitize_data(dy_u, jnp.nan, mask)
dx_v = sanitize.sanitize_data(dx_v, jnp.nan, mask)
dy_v = sanitize.sanitize_data(dy_v, jnp.nan, mask)
coriolis_factor_u = sanitize.sanitize_data(coriolis_factor_u, jnp.nan, mask)
coriolis_factor_v = sanitize.sanitize_data(coriolis_factor_v, jnp.nan, mask)
dx_u = sanitize.sanitize_data(dx_u, jnp.nan, is_land)
dy_u = sanitize.sanitize_data(dy_u, jnp.nan, is_land)
dx_v = sanitize.sanitize_data(dx_v, jnp.nan, is_land)
dy_v = sanitize.sanitize_data(dy_v, jnp.nan, is_land)
coriolis_factor_u = sanitize.sanitize_data(coriolis_factor_u, jnp.nan, is_land)
coriolis_factor_v = sanitize.sanitize_data(coriolis_factor_v, jnp.nan, is_land)

if method == "variational":
if n_it is None:
Expand All @@ -160,20 +160,20 @@ def cyclogeostrophy(
elif not isinstance(optim, optax.GradientTransformation):
raise TypeError("optim should be an optax.GradientTransformation optimizer, or a string referring to such "
"an optimizer.")
res = _variational(u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, mask,
res = _variational(u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, is_land,
n_it, optim, return_losses)
elif method == "iterative":
if n_it is None:
n_it = N_IT_IT
res = _iterative(u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, mask,
res = _iterative(u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, is_land,
n_it, res_eps, use_res_filter, res_filter_size, return_losses)
else:
raise ValueError("method should be one of [\"variational\", \"iterative\"]")

# Handle masked data
u_cyclo_u, v_cyclo_v, losses = res
u_cyclo_u = sanitize.sanitize_data(u_cyclo_u, jnp.nan, mask)
v_cyclo_v = sanitize.sanitize_data(v_cyclo_v, jnp.nan, mask)
u_cyclo_u = sanitize.sanitize_data(u_cyclo_u, jnp.nan, is_land)
v_cyclo_v = sanitize.sanitize_data(v_cyclo_v, jnp.nan, is_land)

res = (u_cyclo_u, v_cyclo_v)
if return_geos:
Expand Down
14 changes: 7 additions & 7 deletions jaxparrow/geostrophy.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,23 @@ def geostrophy(
Longitudes of the V grid, if ``return_grids=True``
"""
# Make sure the mask is initialized
mask = sanitize.init_mask(ssh_t, mask)
is_land = sanitize.init_land_mask(ssh_t, mask)

# Compute spatial steps and Coriolis factors
dx_t, dy_t = geometry.compute_spatial_step(lat_t, lon_t)
coriolis_factor_t = geometry.compute_coriolis_factor(lat_t)

# Handle spurious and masked data
ssh_t = sanitize.sanitize_data(ssh_t, jnp.nan, mask) # avoid spurious velocities near the coast
dx_t = sanitize.sanitize_data(dx_t, jnp.nan, mask)
dy_t = sanitize.sanitize_data(dy_t, jnp.nan, mask)
coriolis_factor_t = sanitize.sanitize_data(coriolis_factor_t, jnp.nan, mask)
ssh_t = sanitize.sanitize_data(ssh_t, jnp.nan, is_land) # avoid spurious velocities near the coast
dx_t = sanitize.sanitize_data(dx_t, jnp.nan, is_land)
dy_t = sanitize.sanitize_data(dy_t, jnp.nan, is_land)
coriolis_factor_t = sanitize.sanitize_data(coriolis_factor_t, jnp.nan, is_land)

u_geos_u, v_geos_v = _geostrophy(ssh_t, dx_t, dy_t, coriolis_factor_t)

# Handle masked data
u_geos_u = sanitize.sanitize_data(u_geos_u, jnp.nan, mask)
v_geos_v = sanitize.sanitize_data(v_geos_v, jnp.nan, mask)
u_geos_u = sanitize.sanitize_data(u_geos_u, jnp.nan, is_land)
v_geos_v = sanitize.sanitize_data(v_geos_v, jnp.nan, is_land)

# Compute U and V grids
lat_u, lon_u, lat_v, lon_v = geometry.compute_uv_grids(lat_t, lon_t)
Expand Down
4 changes: 2 additions & 2 deletions jaxparrow/tools/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from .geometry import compute_spatial_step, compute_coriolis_factor
from .operators import derivative, interpolation
from .sanitize import init_mask, sanitize_data
from .sanitize import init_land_mask, sanitize_data


def advection(
Expand Down Expand Up @@ -176,7 +176,7 @@ def normalized_relative_vorticity(
on the F grid (if ``interpolate=False``), or the T grid (if ``interpolate=True``)
"""
# Make sure the mask is initialized
mask = init_mask(u, mask)
mask = init_land_mask(u, mask)

# Compute spatial step and Coriolis factor
_, dy_u = compute_spatial_step(lat_u, lon_u)
Expand Down
4 changes: 2 additions & 2 deletions jaxparrow/tools/sanitize.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def sanitize_data(
return arr


def init_mask(
def init_land_mask(
field: Float[Array, "lat lon"],
mask: Float[Array, "lat lon"] = None
) -> Float[Array, "lat lon"]:
Expand All @@ -55,7 +55,7 @@ def init_mask(
Initialized (if needed) mask
"""
if mask is None:
mask = jnp.isfinite(field)
mask = ~jnp.isfinite(field)
return mask


Expand Down
244 changes: 130 additions & 114 deletions notebooks/gaussian_eddy.ipynb

Large diffs are not rendered by default.

102 changes: 60 additions & 42 deletions notebooks/gaussian_eddy/gaussian_eddy.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,19 @@ import matplotlib.pyplot as plt
import numpy as np
import optax

from jaxparrow.cyclogeostrophy import _iterative, _variational, LR_VAR
from jaxparrow.cyclogeostrophy import _iterative, _variational
from jaxparrow.geostrophy import _geostrophy
from jaxparrow.tools.kinematics import magnitude
from jaxparrow.tools.operators import interpolation
from jaxparrow.tools.sanitize import init_mask
from jaxparrow.tools.sanitize import init_land_mask

sys.path.extend([os.path.join(os.path.dirname(os.getcwd()), "tests")])
from tests import gaussian_eddy as ge # noqa

%reload_ext autoreload
%autoreload 2
%reload_ext
autoreload
%autoreload
2
```

# Gaussian eddy
Expand All @@ -30,10 +32,10 @@ We choose to use a constant spatial step in meters.
```python
# Alboran sea settings
R0 = 50e3
ETA0 = .1
ETA0 = .2
LAT = 36

dxy = 15e3
dxy = 10e3
```

## Simulating the eddy
Expand Down Expand Up @@ -177,13 +179,17 @@ plt.show()


```python
ax = plt.subplot()
ax.set_title("numerical geostrophy")
ax.set_xlabel("radial distance (m)")
ax.set_ylabel("azimuthal velocity (m/s)")
ax.scatter(R.flatten(), azim_geos_est.flatten(), s=1)
ax.vlines(R.flatten()[np.abs(azim_geos_est).flatten().argmax()],
ymin=azim_geos_est.min(), ymax=azim_geos_est.max(), colors="r", linestyles="dashed")
_, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
ax1.set_title("numerical geostrophy")
ax1.set_xlabel("radial distance (m)")
ax1.set_ylabel("azimuthal velocity (m/s)")
ax1.scatter(R.flatten(), azim_geos_est.flatten(), s=1)
ax1.vlines(R.flatten()[np.abs(azim_geos_est).flatten().argmax()],
ymin=azim_geos_est.min(), ymax=azim_geos_est.max(), colors="r", linestyles="dashed")
ax2.set_title("numerical error")
ax2.set_xlabel("radial distance (m)")
ax2.set_ylabel("absolute error (m/s)")
ax2.scatter(R.flatten(), azim_geos_est.flatten() - azim_geos.flatten(), s=1)
plt.show()
```

Expand All @@ -201,7 +207,7 @@ ge.compute_rmse(u_geos_t, u_geos_est_t), ge.compute_rmse(v_geos_t, v_geos_est_t)



(Array(0.00083074, dtype=float32), Array(0.00083074, dtype=float32))
(Array(0.0068815, dtype=float32), Array(0.0068815, dtype=float32))



Expand Down Expand Up @@ -265,16 +271,16 @@ $\mathbf{u} - \frac{\mathbf{k}}{f} \times (\mathbf{u} \cdot \nabla \mathbf{u}) =


```python
u_geos_u = interpolation(u_geos_t, axis=1, padding="right")
v_geos_v = interpolation(v_geos_t, axis=0, padding="right")
mask = init_mask(u_geos_t)
u_geos_u = u_geos_est
v_geos_v = v_geos_est
mask = init_land_mask(u_geos_t)
```

#### Variational estimation


```python
optim = optax.sgd(learning_rate=LR_VAR)
optim = optax.sgd(learning_rate=5e-2)
u_cyclo_est, v_cyclo_est, _ = _variational(u_geos_u, v_geos_v, dXY, dXY, dXY, dXY,
coriolis_factor, coriolis_factor, mask,
n_it=20, optim=optim,
Expand Down Expand Up @@ -310,13 +316,17 @@ plt.show()


```python
ax = plt.subplot()
ax.set_title("variational cyclogeostrophy")
ax.set_xlabel("radial distance (m)")
ax.set_ylabel("azimuthal velocity (m/s)")
ax.scatter(R.flatten(), azim_cyclo_est.flatten(), s=1)
ax.vlines(R.flatten()[np.abs(azim_cyclo_est).flatten().argmax()],
ymin=azim_cyclo_est.min(), ymax=azim_cyclo_est.max(), colors="r", linestyles="dashed")
_, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
ax1.set_title("variational cyclogeostrophy")
ax1.set_xlabel("radial distance (m)")
ax1.set_ylabel("azimuthal velocity (m/s)")
ax1.scatter(R.flatten(), azim_cyclo_est.flatten(), s=1)
ax1.vlines(R.flatten()[np.abs(azim_cyclo_est).flatten().argmax()],
ymin=azim_cyclo_est.min(), ymax=azim_cyclo_est.max(), colors="r", linestyles="dashed")
ax2.set_title("numerical error")
ax2.set_xlabel("radial distance (m)")
ax2.set_ylabel("absolute error (m/s)")
ax2.scatter(R.flatten(), azim_cyclo_est.flatten() - azim_cyclo.flatten(), s=1)
plt.show()
```

Expand All @@ -334,7 +344,7 @@ ge.compute_rmse(u_cyclo_t, u_cyclo_est_t), ge.compute_rmse(v_cyclo_t, v_cyclo_es



(Array(0.00273015, dtype=float32), Array(0.00273015, dtype=float32))
(Array(0.00562905, dtype=float32), Array(0.00562905, dtype=float32))



Expand Down Expand Up @@ -384,13 +394,17 @@ plt.show()


```python
ax = plt.subplot()
ax.set_title("iterative (filter) cyclogeostrophy")
ax.set_xlabel("radial distance (m)")
ax.set_ylabel("azimuthal velocity (m/s)")
ax.scatter(R.flatten(), azim_cyclo_est.flatten(), s=1)
ax.vlines(R.flatten()[np.abs(azim_cyclo_est).flatten().argmax()],
ymin=azim_cyclo_est.min(), ymax=azim_cyclo_est.max(), colors="r", linestyles="dashed")
_, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
ax1.set_title("iterative (filter) cyclogeostrophy")
ax1.set_xlabel("radial distance (m)")
ax1.set_ylabel("azimuthal velocity (m/s)")
ax1.scatter(R.flatten(), azim_cyclo_est.flatten(), s=1)
ax1.vlines(R.flatten()[np.abs(azim_cyclo_est).flatten().argmax()],
ymin=azim_cyclo_est.min(), ymax=azim_cyclo_est.max(), colors="r", linestyles="dashed")
ax2.set_title("numerical error")
ax2.set_xlabel("radial distance (m)")
ax2.set_ylabel("absolute error (m/s)")
ax2.scatter(R.flatten(), azim_cyclo_est.flatten() - azim_cyclo.flatten(), s=1)
plt.show()
```

Expand All @@ -408,7 +422,7 @@ ge.compute_rmse(u_cyclo_t, u_cyclo_est_t), ge.compute_rmse(v_cyclo_t, v_cyclo_es



(Array(0.00295354, dtype=float32), Array(0.00295354, dtype=float32))
(Array(0.00847729, dtype=float32), Array(0.00847729, dtype=float32))



Expand Down Expand Up @@ -454,13 +468,17 @@ plt.show()


```python
ax = plt.subplot()
ax.set_title("iterative cyclogeostrophy")
ax.set_xlabel("radial distance (m)")
ax.set_ylabel("azimuthal velocity (m/s)")
ax.scatter(R.flatten(), azim_cyclo_est.flatten(), s=1)
ax.vlines(R.flatten()[np.abs(azim_cyclo_est).flatten().argmax()],
ymin=azim_cyclo_est.min(), ymax=azim_cyclo_est.max(), colors="r", linestyles="dashed")
_, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
ax1.set_title("iterative cyclogeostrophy")
ax1.set_xlabel("radial distance (m)")
ax1.set_ylabel("azimuthal velocity (m/s)")
ax1.scatter(R.flatten(), azim_cyclo_est.flatten(), s=1)
ax1.vlines(R.flatten()[np.abs(azim_cyclo_est).flatten().argmax()],
ymin=azim_cyclo_est.min(), ymax=azim_cyclo_est.max(), colors="r", linestyles="dashed")
ax2.set_title("numerical error")
ax2.set_xlabel("radial distance (m)")
ax2.set_ylabel("absolute error (m/s)")
ax2.scatter(R.flatten(), azim_cyclo_est.flatten() - azim_cyclo.flatten(), s=1)
plt.show()
```

Expand All @@ -478,6 +496,6 @@ ge.compute_rmse(u_cyclo_t, u_cyclo_est_t), ge.compute_rmse(v_cyclo_t, v_cyclo_es



(Array(0.00295272, dtype=float32), Array(0.00295272, dtype=float32))
(Array(0.00861186, dtype=float32), Array(0.00861186, dtype=float32))


Binary file modified notebooks/gaussian_eddy/output_11_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/gaussian_eddy/output_12_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/gaussian_eddy/output_15_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/gaussian_eddy/output_16_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/gaussian_eddy/output_21_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/gaussian_eddy/output_22_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/gaussian_eddy/output_27_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/gaussian_eddy/output_28_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/gaussian_eddy/output_33_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/gaussian_eddy/output_34_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/gaussian_eddy/output_38_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/gaussian_eddy/output_39_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/gaussian_eddy/output_6_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/gaussian_eddy/output_7_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion tests/gaussian_eddy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def simulate_gaussian_eddy(
latitude: int
) -> [jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array,
jax.Array, jax.Array]:
l0 = r0 * 1.5
l0 = r0 * 2 # limit boundary impact
xy = jnp.arange(0, l0, dxy)
xy = jnp.concatenate((-xy[::-1][:-1], xy))
X, Y = jnp.meshgrid(xy, xy)
Expand Down
Loading

0 comments on commit b369b3d

Please sign in to comment.