Skip to content

Commit

Permalink
Merge pull request #64 from meom-group/lax-scan
Browse files Browse the repository at this point in the history
lax.scan
  • Loading branch information
vadmbertr authored Apr 29, 2024
2 parents b369b3d + a24ba7d commit 2ea9e51
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 54 deletions.
69 changes: 25 additions & 44 deletions jaxparrow/cyclogeostrophy.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def cyclogeostrophy(
raise TypeError("optim should be an optax.GradientTransformation optimizer, or a string referring to such "
"an optimizer.")
res = _variational(u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, is_land,
n_it, optim, return_losses)
n_it, optim)
elif method == "iterative":
if n_it is None:
n_it = N_IT_IT
Expand Down Expand Up @@ -208,11 +208,8 @@ def _it_step(
u_cyclo: Float[Array, "lat lon"],
v_cyclo: Float[Array, "lat lon"],
mask_it: Float[Array, "lat lon"],
res_n: Float[Array, "lat lon"],
losses: Float[Array, "n_it"],
i: int
) -> [Float[Array, "lat lon"], Float[Array, "lat lon"], Float[Array, "lat lon"], Float[Array, "lat lon"],
Float[Array, "n_it"], int]:
res_n: Float[Array, "lat lon"]
) -> [[Float[Array, "lat lon"], Float[Array, "lat lon"], Float[Array, "lat lon"], Float[Array, "lat lon"], int], float]:
# next it
u_adv_v, v_adv_u = kinematics.advection(u_cyclo, v_cyclo, dx_u, dy_u, dx_v, dy_v, mask)
u_np1 = u_geos_u - v_adv_u / coriolis_factor_u
Expand All @@ -232,11 +229,12 @@ def _it_step(
mask_n = jnp.where(res_np1 <= res_n, 0, 1) # nan comp. equiv. to jnp.where(res_np1 > res_n, 1, 0)

# compute loss
losses = lax.cond(
loss = lax.cond(
return_losses,
lambda operands: operands[0].at[operands[1]].set(_cyclogeostrophic_diff(*operands[2:])),
lambda operands: operands[0],
(losses, i, u_geos_u, v_geos_v, u_cyclo, v_cyclo, u_adv_v, v_adv_u, coriolis_factor_u, coriolis_factor_v)
lambda: _cyclogeostrophic_diff(
u_geos_u, v_geos_v, u_cyclo, v_cyclo, u_adv_v, v_adv_u, coriolis_factor_u, coriolis_factor_v
),
lambda: jnp.nan
)

# update cyclogeostrophic velocities
Expand All @@ -247,9 +245,7 @@ def _it_step(
mask_it = jnp.maximum(mask_it, jnp.maximum(mask_jnp1, mask_n))
res_n = res_np1

i += 1

return u_cyclo, v_cyclo, mask_it, res_n, losses, i
return (u_cyclo, v_cyclo, mask_it, res_n), loss


@partial(jit, static_argnames=("n_it", "res_filter_size"))
Expand All @@ -274,22 +270,21 @@ def _iterative(
res_weights = jsp.signal.convolve(jnp.ones_like(u_geos_u), res_filter, mode="same", method="fft")

# define step partial: freeze constant over iterations
def step_fn(pytree):
def step_fn(carry, _):
return _it_step(
u_geos_u, v_geos_v,
dx_u, dx_v, dy_u, dy_v,
coriolis_factor_u, coriolis_factor_v, mask,
res_eps, res_filter, res_weights,
use_res_filter, return_losses,
*pytree
*carry
)

# apply updates
u_cyclo, v_cyclo, _, _, losses, _ = lax.while_loop( # noqa
lambda args: (args[-1] < n_it) | jnp.any(args[2] != 1),
(u_cyclo, v_cyclo, _, _), losses = lax.scan(
step_fn,
(u_geos_u, v_geos_v, mask.astype(int), jnp.maximum(jnp.abs(u_geos_u), jnp.abs(v_geos_v)),
jnp.ones(n_it) * jnp.nan, 0)
(u_geos_u, v_geos_v, mask.astype(int), jnp.maximum(jnp.abs(u_geos_u), jnp.abs(v_geos_v))),
xs=None, length=n_it
)

return u_cyclo, v_cyclo, losses
Expand Down Expand Up @@ -321,13 +316,10 @@ def _var_step(
mask: Float[Array, "lat lon"],
loss_fn: Callable[[[Float[Array, "lat lon"], Float[Array, "lat lon"]]], Float[Scalar, ""]],
optim: optax.GradientTransformation,
return_losses: bool,
u_cyclo_u: Float[Array, "lat lon"],
v_cyclo_v: Float[Array, "lat lon"],
opt_state: optax.OptState,
losses: Float[Array, "n_it"],
i: int
) -> [Float[Array, "lat lon"], ...]:
opt_state: optax.OptState
) -> [[Float[Array, "lat lon"], ...], float]:
params = (u_cyclo_u, v_cyclo_v)
# evaluate the cost function and compute its gradient
loss, grads = value_and_grad(loss_fn)(params)
Expand All @@ -338,16 +330,7 @@ def _var_step(
# apply updates to the parameters
u_n, v_n = optax.apply_updates(params, updates)

# store loss
losses = lax.cond(
return_losses,
lambda operands: operands[0].at[operands[1]].set(operands[2]), lambda operands: operands[0],
(losses, i, loss)
)

i += 1

return u_n, v_n, opt_state, losses, i
return (u_n, v_n, opt_state), loss


def _solve(
Expand All @@ -356,17 +339,16 @@ def _solve(
mask: Float[Array, "lat lon"],
loss_fn: Callable[[[Float[Array, "lat lon"], Float[Array, "lat lon"]]], Float[Scalar, ""]],
n_it: int,
optim: optax.GradientTransformation,
return_losses: bool
optim: optax.GradientTransformation
) -> [Float[Array, "lat lon"], ...]:
# define step partial: freeze constant over iterations
def step_fn(pytree):
return _var_step(mask, loss_fn, optim, return_losses, *pytree)
def step_fn(carry, _):
return _var_step(mask, loss_fn, optim, *carry)

u_cyclo_u, v_cyclo_v, opt_state, losses, i = lax.while_loop( # noqa
lambda args: args[-1] < n_it,
(u_cyclo_u, v_cyclo_v, _), losses = lax.scan(
step_fn,
(u_geos_u, v_geos_v, optim.init((u_geos_u, v_geos_v)), jnp.ones(n_it) * jnp.nan, 0)
(u_geos_u, v_geos_v, optim.init((u_geos_u, v_geos_v))),
xs=None, length=n_it
)

return u_cyclo_u, v_cyclo_v, losses
Expand All @@ -384,16 +366,15 @@ def _variational(
coriolis_factor_v: Float[Array, "lat lon"],
mask: Float[Array, "lat lon"],
n_it: int,
optim: optax.GradientTransformation,
return_losses: bool
optim: optax.GradientTransformation
) -> [Float[Array, "lat lon"], ...]:
# define loss partial: freeze constant over iterations
loss_fn = partial(
_var_loss_fn,
u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, mask
)

return _solve(u_geos_u, v_geos_v, mask, loss_fn, n_it, optim, return_losses)
return _solve(u_geos_u, v_geos_v, mask, loss_fn, n_it, optim)


def _cyclogeostrophic_diff(
Expand Down
3 changes: 2 additions & 1 deletion notebooks/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
We designed two notebooks, one with the extremely idealised scenario of a [gaussian eddy](gaussian_eddy), and a realistic one focusing on the [Alboran sea](alboran_sea) area.
We designed three notebooks, one with the idealised scenario of a [gaussian eddy](gaussian_eddy),
a realistic one (eNATL60 run) focusing on the [Alboran sea](alboran_sea) area, and its counterpart one using Duacs data.

The goal is both to showcase step-by-step how **jaxparrow** can be used, and to demonstrate the interest of the variational approach we propose.
2 changes: 1 addition & 1 deletion notebooks/alboran_sea.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
"id": "8a6a01b9",
"metadata": {},
"source": [
"# Alboran sea\n",
"# Method validation using the eNATL60 run\n",
"\n",
"## Input data\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/alboran_sea/alboran_sea.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_figsize(width_ratio, wh_ratio=1):
return fig_width, fig_height
```

# Alboran sea
# Method validation using the eNATL60 run

## Input data

Expand Down
5 changes: 2 additions & 3 deletions notebooks/gaussian_eddy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
"collapsed": false
},
"source": [
"# Gaussian eddy\n",
"# Method validation in the idealized gaussian eddy scenario\n",
"\n",
"We want to use a gaussian eddy for our functional tests, as analytical solutions can be derived in that setting.\n",
"\n",
Expand Down Expand Up @@ -577,8 +577,7 @@
"optim = optax.sgd(learning_rate=5e-2)\n",
"u_cyclo_est, v_cyclo_est, _ = _variational(u_geos_u, v_geos_v, dXY, dXY, dXY, dXY,\n",
" coriolis_factor, coriolis_factor, mask,\n",
" n_it=20, optim=optim,\n",
" return_losses=False)\n",
" n_it=20, optim=optim)\n",
"\n",
"u_cyclo_est_t = interpolation(u_cyclo_est, axis=1, padding=\"left\")\n",
"v_cyclo_est_t = interpolation(v_cyclo_est, axis=0, padding=\"left\")\n",
Expand Down
5 changes: 2 additions & 3 deletions notebooks/gaussian_eddy/gaussian_eddy.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ autoreload
2
```

# Gaussian eddy
# Method validation in the idealized gaussian eddy scenario

We want to use a gaussian eddy for our functional tests, as analytical solutions can be derived in that setting.

Expand Down Expand Up @@ -283,8 +283,7 @@ mask = init_land_mask(u_geos_t)
optim = optax.sgd(learning_rate=5e-2)
u_cyclo_est, v_cyclo_est, _ = _variational(u_geos_u, v_geos_v, dXY, dXY, dXY, dXY,
coriolis_factor, coriolis_factor, mask,
n_it=20, optim=optim,
return_losses=False)
n_it=20, optim=optim)

u_cyclo_est_t = interpolation(u_cyclo_est, axis=1, padding="left")
v_cyclo_est_t = interpolation(v_cyclo_est, axis=0, padding="left")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_velocities.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_cyclogeostrophy_variational(self):
u_cyclo_est, v_cyclo_est, _ = _variational(u_geos_est, v_geos_est,
self.dXY, self.dXY, self.dXY, self.dXY,
self.coriolis_factor, self.coriolis_factor, mask,
1000, optim, False)
1000, optim)
u_cyclo_est_t = interpolation(u_cyclo_est, axis=1, padding="left")
v_cyclo_est_t = interpolation(v_cyclo_est, axis=0, padding="left")
cyclo_rmse = self.compute_rmse(self.u_cyclo, self.v_cyclo, u_cyclo_est_t, v_cyclo_est_t) # around .002
Expand Down

0 comments on commit 2ea9e51

Please sign in to comment.