Skip to content

Commit

Permalink
Merge pull request #67 from meom-group/nan-handling
Browse files Browse the repository at this point in the history
Better handling of nan/land value
  • Loading branch information
vadmbertr authored Jun 25, 2024
2 parents ff25e3e + 124fe34 commit e7308ae
Show file tree
Hide file tree
Showing 13 changed files with 548 additions and 537 deletions.
25 changes: 7 additions & 18 deletions jaxparrow/cyclogeostrophy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
N_IT_IT = 20
#: Default residual tolerance for Penven and Ioannou approaches
RES_EPS_IT = 0.01
#: Default residual value used during the first iteration for Penven and Ioannou approaches
RES_INIT_IT = "same"
#: Default size of the grid points used to compute the residual for Ioannou's approach
RES_FILTER_SIZE_IT = 3

Expand Down Expand Up @@ -143,12 +141,8 @@ def cyclogeostrophy(
coriolis_factor_v = geometry.compute_coriolis_factor(lat_v)

# Handle spurious and masked data
dx_u = sanitize.sanitize_data(dx_u, jnp.nan, is_land)
dy_u = sanitize.sanitize_data(dy_u, jnp.nan, is_land)
dx_v = sanitize.sanitize_data(dx_v, jnp.nan, is_land)
dy_v = sanitize.sanitize_data(dy_v, jnp.nan, is_land)
coriolis_factor_u = sanitize.sanitize_data(coriolis_factor_u, jnp.nan, is_land)
coriolis_factor_v = sanitize.sanitize_data(coriolis_factor_v, jnp.nan, is_land)
u_geos_u = sanitize.sanitize_data(u_geos_u, 0, is_land)
v_geos_v = sanitize.sanitize_data(v_geos_v, 0, is_land)

if method == "variational":
if n_it is None:
Expand Down Expand Up @@ -217,7 +211,6 @@ def _it_step(

# compute dist to u_cyclo and v_cyclo
res_np1 = jnp.abs(u_np1 - u_cyclo) + jnp.abs(v_np1 - v_cyclo)
res_np1 = sanitize.sanitize_data(res_np1, 0., mask)
res_np1 = lax.cond(
use_res_filter, # apply filter
lambda operands: jsp.signal.convolve(operands[0], operands[1], mode="same", method="fft") / operands[2],
Expand Down Expand Up @@ -313,7 +306,6 @@ def _var_loss_fn(


def _var_step(
mask: Float[Array, "lat lon"],
loss_fn: Callable[[[Float[Array, "lat lon"], Float[Array, "lat lon"]]], Float[Scalar, ""]],
optim: optax.GradientTransformation,
u_cyclo_u: Float[Array, "lat lon"],
Expand All @@ -323,8 +315,6 @@ def _var_step(
params = (u_cyclo_u, v_cyclo_v)
# evaluate the cost function and compute its gradient
loss, grads = value_and_grad(loss_fn)(params)
# make sure to remove nan values
grads = (sanitize.sanitize_data(grads[0], 0., mask), sanitize.sanitize_data(grads[1], 0., mask))
# update the optimizer
updates, opt_state = optim.update(grads, opt_state, params)
# apply updates to the parameters
Expand All @@ -336,14 +326,13 @@ def _var_step(
def _solve(
u_geos_u: Float[Array, "lat lon"],
v_geos_v: Float[Array, "lat lon"],
mask: Float[Array, "lat lon"],
loss_fn: Callable[[[Float[Array, "lat lon"], Float[Array, "lat lon"]]], Float[Scalar, ""]],
n_it: int,
optim: optax.GradientTransformation
) -> [Float[Array, "lat lon"], ...]:
# define step partial: freeze constant over iterations
def step_fn(carry, _):
return _var_step(mask, loss_fn, optim, *carry)
return _var_step(loss_fn, optim, *carry)

(u_cyclo_u, v_cyclo_v, _), losses = lax.scan(
step_fn,
Expand Down Expand Up @@ -374,7 +363,7 @@ def _variational(
u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, mask
)

return _solve(u_geos_u, v_geos_v, mask, loss_fn, n_it, optim)
return _solve(u_geos_u, v_geos_v, loss_fn, n_it, optim)


def _cyclogeostrophic_diff(
Expand All @@ -387,6 +376,6 @@ def _cyclogeostrophic_diff(
coriolis_factor_u: Float[Array, "lat lon"],
coriolis_factor_v: Float[Array, "lat lon"]
) -> Float[Scalar, ""]:
J_u = jnp.nansum((u_cyclo_u + v_adv_u / coriolis_factor_u - u_geos_u) ** 2)
J_v = jnp.nansum((v_cyclo_v - u_adv_v / coriolis_factor_v - v_geos_v) ** 2)
return J_u + J_v
j_u = ((u_cyclo_u + v_adv_u / coriolis_factor_u - u_geos_u) ** 2).sum()
j_v = ((v_cyclo_v - u_adv_v / coriolis_factor_v - v_geos_v) ** 2).sum()
return j_u + j_v
34 changes: 18 additions & 16 deletions jaxparrow/geostrophy.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,9 @@ def geostrophy(
coriolis_factor_t = geometry.compute_coriolis_factor(lat_t)

# Handle spurious and masked data
ssh_t = sanitize.sanitize_data(ssh_t, jnp.nan, is_land) # avoid spurious velocities near the coast
dx_t = sanitize.sanitize_data(dx_t, jnp.nan, is_land)
dy_t = sanitize.sanitize_data(dy_t, jnp.nan, is_land)
coriolis_factor_t = sanitize.sanitize_data(coriolis_factor_t, jnp.nan, is_land)
ssh_t = sanitize.sanitize_data(ssh_t, 0, is_land)

u_geos_u, v_geos_v = _geostrophy(ssh_t, dx_t, dy_t, coriolis_factor_t)
u_geos_u, v_geos_v = _geostrophy(ssh_t, dx_t, dy_t, coriolis_factor_t, is_land)

# Handle masked data
u_geos_u = sanitize.sanitize_data(u_geos_u, jnp.nan, is_land)
Expand All @@ -87,21 +84,26 @@ def _geostrophy(
ssh_t: Float[Array, "lat lon"],
dx_t: Float[Array, "lat lon"],
dy_t: Float[Array, "lat lon"],
coriolis_factor_t: Float[Array, "lat lon"]
coriolis_factor_t: Float[Array, "lat lon"],
mask: Float[Array, "lat lon"]
) -> [Float[Array, "lat lon"], Float[Array, "lat lon"]]:
# Compute the gradient of the ssh
ssh_dx_u = operators.derivative(ssh_t, dx_t, axis=1, padding="right") # (T(i), T(i+1)) -> U(i)
ssh_dy_v = operators.derivative(ssh_t, dy_t, axis=0, padding="right") # (T(j), T(j+1)) -> V(j)
ssh_dx_u = operators.derivative(ssh_t, dx_t, mask, axis=1, padding="right") # (T(i), T(i+1)) -> U(i)
ssh_dy_v = operators.derivative(ssh_t, dy_t, mask, axis=0, padding="right") # (T(j), T(j+1)) -> V(j)

# Interpolate the data
ssh_dy_t = operators.interpolation(ssh_dy_v, axis=0, padding="left") # (V(j), V(j+1)) -> T(j+1)
ssh_dy_u = operators.interpolation(ssh_dy_t, axis=1, padding="right") # (T(i), T(i+1)) -> U(i)

ssh_dx_t = operators.interpolation(ssh_dx_u, axis=1, padding="left") # (U(i), U(i+1)) -> T(i+1)
ssh_dx_v = operators.interpolation(ssh_dx_t, axis=0, padding="right") # (T(j), T(j+1)) -> V(j)

coriolis_factor_u = operators.interpolation(coriolis_factor_t, axis=1, padding="right") # (T(i), T(i+1)) -> U(i)
coriolis_factor_v = operators.interpolation(coriolis_factor_t, axis=0, padding="right") # (T(j), T(j+1)) -> V(j)
ssh_dy_t = operators.interpolation(ssh_dy_v, mask, axis=0, padding="left") # (V(j), V(j+1)) -> T(j+1)
ssh_dy_u = operators.interpolation(ssh_dy_t, mask, axis=1, padding="right") # (T(i), T(i+1)) -> U(i)

ssh_dx_t = operators.interpolation(ssh_dx_u, mask, axis=1, padding="left") # (U(i), U(i+1)) -> T(i+1)
ssh_dx_v = operators.interpolation(ssh_dx_t, mask, axis=0, padding="right") # (T(j), T(j+1)) -> V(j)

coriolis_factor_u = operators.interpolation(
coriolis_factor_t, mask, axis=1, padding="right"
) # (T(i), T(i+1)) -> U(i)
coriolis_factor_v = operators.interpolation(
coriolis_factor_t, mask, axis=0, padding="right"
) # (T(j), T(j+1)) -> V(j)

# Computing the geostrophic velocities
u_geos_u = - geometry.GRAVITY * ssh_dy_u / coriolis_factor_u # U(i)
Expand Down
11 changes: 7 additions & 4 deletions jaxparrow/tools/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,17 @@ def compute_uv_grids(
lon_v : Float[Array, "lat lon"]
Longitudes of the V grid
"""
lat_u = interpolation(lat_t, axis=1, padding="right")
lat_mask = jnp.zeros_like(lat_t, dtype=bool)
lon_mask = jnp.zeros_like(lon_t, dtype=bool)

lat_u = interpolation(lat_t, lat_mask, axis=1, padding="right")
lat_u = lat_u.at[:, -1].set(2 * lat_t[:, -1] - lat_t[:, -2])
lon_u = interpolation(lon_t, axis=1, padding="right")
lon_u = interpolation(lon_t, lon_mask, axis=1, padding="right")
lon_u = lon_u.at[:, -1].set(2 * lon_t[:, -1] - lon_t[:, -2])

lat_v = interpolation(lat_t, axis=0, padding="right")
lat_v = interpolation(lat_t, lat_mask, axis=0, padding="right")
lat_v = lat_v.at[-1, :].set(2 * lat_t[-1, :] - lat_t[-2, :])
lon_v = interpolation(lon_t, axis=0, padding="right")
lon_v = interpolation(lon_t, lon_mask, axis=0, padding="right")
lon_v = lon_v.at[-1, :].set(2 * lon_t[-1, :] - lon_t[-2, :])

return lat_u, lon_u, lat_v, lon_v
67 changes: 41 additions & 26 deletions jaxparrow/tools/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,16 @@ def _u_advection_v(
dy_u: Float[Array, "lat lon"],
mask: Float[Array, "lat lon"]
) -> Float[Array, "lat lon"]:
dudx_t = derivative(u_u, dx_u, axis=1, padding="left") # (U(i), U(i+1)) -> T(i+1)
dudx_v = interpolation(dudx_t, axis=0, padding="right") # (T(j), T(j+1)) -> V(j)
dudx_t = derivative(u_u, dx_u, mask, axis=1, padding="left") # (U(i), U(i+1)) -> T(i+1)
dudx_v = interpolation(dudx_t, mask, axis=0, padding="right") # (T(j), T(j+1)) -> V(j)

dudy_f = derivative(u_u, dy_u, axis=0, padding="right") # (U(j), U(j+1)) -> F(j)
dudy_v = interpolation(dudy_f, axis=1, padding="left") # (F(i), F(i+1)) -> V(i+1)
dudy_f = derivative(u_u, dy_u, mask, axis=0, padding="right") # (U(j), U(j+1)) -> F(j)
dudy_v = interpolation(dudy_f, mask, axis=1, padding="left") # (F(i), F(i+1)) -> V(i+1)

u_t = interpolation(u_u, axis=1, padding="left") # (U(i), U(i+1)) -> T(i+1)
u_v = interpolation(u_t, axis=0, padding="right") # (T(j), T(j+1)) -> V(j)
u_t = interpolation(u_u, mask, axis=1, padding="left") # (U(i), U(i+1)) -> T(i+1)
u_v = interpolation(u_t, mask, axis=0, padding="right") # (T(j), T(j+1)) -> V(j)

u_adv_v = u_v * dudx_v + v_v * dudy_v # V(j)
u_adv_v = sanitize_data(u_adv_v, 0., mask)

return u_adv_v

Expand All @@ -77,24 +76,24 @@ def _v_advection_u(
dy_v: Float[Array, "lat lon"],
mask: Float[Array, "lat lon"]
) -> Float[Array, "lat lon"]:
dvdx_f = derivative(v_v, dx_v, axis=1, padding="right") # (V(i), V(i+1)) -> F(i)
dvdx_u = interpolation(dvdx_f, axis=0, padding="left") # (F(j), F(j+1)) -> U(j+1)
dvdx_f = derivative(v_v, dx_v, mask, axis=1, padding="right") # (V(i), V(i+1)) -> F(i)
dvdx_u = interpolation(dvdx_f, mask, axis=0, padding="left") # (F(j), F(j+1)) -> U(j+1)

dvdy_t = derivative(v_v, dy_v, axis=0, padding="left") # (V(j), V(j+1)) -> T(j+1)
dvdy_u = interpolation(dvdy_t, axis=1, padding="right") # (T(i), T(i+1)) -> U(i)
dvdy_t = derivative(v_v, dy_v, mask, axis=0, padding="left") # (V(j), V(j+1)) -> T(j+1)
dvdy_u = interpolation(dvdy_t, mask, axis=1, padding="right") # (T(i), T(i+1)) -> U(i)

v_t = interpolation(v_v, axis=0, padding="left") # (V(j), V(j+1)) -> T(j+1)
v_u = interpolation(v_t, axis=1, padding="right") # (T(i), T(i+1)) -> U(i)
v_t = interpolation(v_v, mask, axis=0, padding="left") # (V(j), V(j+1)) -> T(j+1)
v_u = interpolation(v_t, mask, axis=1, padding="right") # (T(i), T(i+1)) -> U(i)

v_adv_u = u_u * dvdx_u + v_u * dvdy_u # U(i)
v_adv_u = sanitize_data(v_adv_u, 0., mask)

return v_adv_u


def magnitude(
u: Float[Array, "lat lon"],
v: Float[Array, "lat lon"],
mask: Float[Array, "lat lon"] = None,
interpolate: bool = True
) -> Float[Array, "lat lon"]:
"""
Expand All @@ -107,6 +106,10 @@ def magnitude(
U component of the velocity field (on the U or T grid)
v : Float[Array, "lat lon"]
V component of the velocity field (on the V or T grid)
mask : Float[Array, "lat lon"], optional
Mask defining the marine area of the spatial domain; `1` or `True` stands for masked (i.e. land)
If not provided, inferred from ``u`` `nan` values
interpolate : bool, optional
If `True`, the velocity components are assumed to be located on the U and V grids,
and are interpolated to the T one (following NEMO convention [1]_).
Expand All @@ -119,14 +122,18 @@ def magnitude(
magn_t : Float[Array, "lat lon"]
Magnitude of the velocity field, on the T grid
"""
# Make sure the mask is initialized
mask = init_land_mask(u, mask)

if interpolate:
# interpolate to the T point
u_t = interpolation(u, axis=1, padding="left") # (U(i), U(i+1)) -> T(i+1)
v_t = interpolation(v, axis=0, padding="left") # (V(j), V(j+1)) -> T(j+1)
u_t = interpolation(u, mask, axis=1, padding="left") # (U(i), U(i+1)) -> T(i+1)
v_t = interpolation(v, mask, axis=0, padding="left") # (V(j), V(j+1)) -> T(j+1)
else:
u_t, v_t = u, v

magn_t = jnp.sqrt(u_t ** 2 + v_t ** 2)
magn_t = sanitize_data(magn_t, jnp.nan, mask)

return magn_t

Expand Down Expand Up @@ -184,19 +191,19 @@ def normalized_relative_vorticity(
f_u = compute_coriolis_factor(lat_u)

# Handle spurious data and apply mask
dy_u = sanitize_data(dy_u, jnp.nan, mask)
dx_v = sanitize_data(dx_v, jnp.nan, mask)
f_u = sanitize_data(f_u, jnp.nan, mask)
# dy_u = sanitize_data(dy_u, jnp.nan, mask)
# dx_v = sanitize_data(dx_v, jnp.nan, mask)
# f_u = sanitize_data(f_u, jnp.nan, mask)

# Compute the normalized relative vorticity
du_dy_f = derivative(u, dy_u, axis=0, padding="right") # (U(j), U(j+1)) -> F(j)
dv_dx_f = derivative(v, dx_v, axis=1, padding="right") # (V(i), V(i+1)) -> F(i)
f_f = interpolation(f_u, axis=0, padding="right") # (U(j), U(j+1)) -> F(j)
du_dy_f = derivative(u, dy_u, mask, axis=0, padding="right") # (U(j), U(j+1)) -> F(j)
dv_dx_f = derivative(v, dx_v, mask, axis=1, padding="right") # (V(i), V(i+1)) -> F(i)
f_f = interpolation(f_u, mask, axis=0, padding="right") # (U(j), U(j+1)) -> F(j)
w_f = (dv_dx_f - du_dy_f) / f_f # F(j)

if interpolate:
w_u = interpolation(w_f, axis=0, padding="left") # (F(j), F(j+1)) -> U(j+1)
w = interpolation(w_u, axis=1, padding="left") # (U(i), U(i+1)) -> T(i+1)
w_u = interpolation(w_f, mask, axis=0, padding="left") # (F(j), F(j+1)) -> U(j+1)
w = interpolation(w_u, mask, axis=1, padding="left") # (U(i), U(i+1)) -> T(i+1)
else:
w = w_f

Expand All @@ -208,6 +215,7 @@ def normalized_relative_vorticity(
def kinetic_energy(
u: Float[Array, "lat lon"],
v: Float[Array, "lat lon"],
mask: Float[Array, "lat lon"] = None,
interpolate: bool = True
) -> Float[Array, "lat lon"]:
"""
Expand All @@ -220,6 +228,10 @@ def kinetic_energy(
U component of the velocity field (on the U grid)
v : Float[Array, "lat lon"]
V component of the velocity field (on the V grid)
mask : Float[Array, "lat lon"], optional
Mask defining the marine area of the spatial domain; `1` or `True` stands for masked (i.e. land)
If not provided, inferred from ``u`` `nan` values
interpolate : bool, optional
If `True`, the velocity components are assumed to be located on the U and V grids,
and are interpolated to the T one (following NEMO convention [1]_).
Expand All @@ -232,10 +244,13 @@ def kinetic_energy(
eke : Float[Array, "lat lon"]
The Kinetic Energy on the T grid
"""
# Make sure the mask is initialized
mask = init_land_mask(u, mask)

if interpolate:
# interpolate to the T point
u_t = interpolation(u, axis=1, padding="left") # (U(i), U(i+1)) -> T(i+1)
v_t = interpolation(v, axis=0, padding="left") # (V(j), V(j+1)) -> T(j+1)
u_t = interpolation(u, mask, axis=1, padding="left") # (U(i), U(i+1)) -> T(i+1)
v_t = interpolation(v, mask, axis=0, padding="left") # (V(j), V(j+1)) -> T(j+1)
else:
u_t, v_t = u, v

Expand Down
Loading

0 comments on commit e7308ae

Please sign in to comment.