Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lax.scan #64

Merged
merged 2 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 25 additions & 44 deletions jaxparrow/cyclogeostrophy.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def cyclogeostrophy(
raise TypeError("optim should be an optax.GradientTransformation optimizer, or a string referring to such "
"an optimizer.")
res = _variational(u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, is_land,
n_it, optim, return_losses)
n_it, optim)
elif method == "iterative":
if n_it is None:
n_it = N_IT_IT
Expand Down Expand Up @@ -208,11 +208,8 @@ def _it_step(
u_cyclo: Float[Array, "lat lon"],
v_cyclo: Float[Array, "lat lon"],
mask_it: Float[Array, "lat lon"],
res_n: Float[Array, "lat lon"],
losses: Float[Array, "n_it"],
i: int
) -> [Float[Array, "lat lon"], Float[Array, "lat lon"], Float[Array, "lat lon"], Float[Array, "lat lon"],
Float[Array, "n_it"], int]:
res_n: Float[Array, "lat lon"]
) -> [[Float[Array, "lat lon"], Float[Array, "lat lon"], Float[Array, "lat lon"], Float[Array, "lat lon"], int], float]:
# next it
u_adv_v, v_adv_u = kinematics.advection(u_cyclo, v_cyclo, dx_u, dy_u, dx_v, dy_v, mask)
u_np1 = u_geos_u - v_adv_u / coriolis_factor_u
Expand All @@ -232,11 +229,12 @@ def _it_step(
mask_n = jnp.where(res_np1 <= res_n, 0, 1) # nan comp. equiv. to jnp.where(res_np1 > res_n, 1, 0)

# compute loss
losses = lax.cond(
loss = lax.cond(
return_losses,
lambda operands: operands[0].at[operands[1]].set(_cyclogeostrophic_diff(*operands[2:])),
lambda operands: operands[0],
(losses, i, u_geos_u, v_geos_v, u_cyclo, v_cyclo, u_adv_v, v_adv_u, coriolis_factor_u, coriolis_factor_v)
lambda: _cyclogeostrophic_diff(
u_geos_u, v_geos_v, u_cyclo, v_cyclo, u_adv_v, v_adv_u, coriolis_factor_u, coriolis_factor_v
),
lambda: jnp.nan
)

# update cyclogeostrophic velocities
Expand All @@ -247,9 +245,7 @@ def _it_step(
mask_it = jnp.maximum(mask_it, jnp.maximum(mask_jnp1, mask_n))
res_n = res_np1

i += 1

return u_cyclo, v_cyclo, mask_it, res_n, losses, i
return (u_cyclo, v_cyclo, mask_it, res_n), loss


@partial(jit, static_argnames=("n_it", "res_filter_size"))
Expand All @@ -274,22 +270,21 @@ def _iterative(
res_weights = jsp.signal.convolve(jnp.ones_like(u_geos_u), res_filter, mode="same", method="fft")

# define step partial: freeze constant over iterations
def step_fn(pytree):
def step_fn(carry, _):
return _it_step(
u_geos_u, v_geos_v,
dx_u, dx_v, dy_u, dy_v,
coriolis_factor_u, coriolis_factor_v, mask,
res_eps, res_filter, res_weights,
use_res_filter, return_losses,
*pytree
*carry
)

# apply updates
u_cyclo, v_cyclo, _, _, losses, _ = lax.while_loop( # noqa
lambda args: (args[-1] < n_it) | jnp.any(args[2] != 1),
(u_cyclo, v_cyclo, _, _), losses = lax.scan(
step_fn,
(u_geos_u, v_geos_v, mask.astype(int), jnp.maximum(jnp.abs(u_geos_u), jnp.abs(v_geos_v)),
jnp.ones(n_it) * jnp.nan, 0)
(u_geos_u, v_geos_v, mask.astype(int), jnp.maximum(jnp.abs(u_geos_u), jnp.abs(v_geos_v))),
xs=None, length=n_it
)

return u_cyclo, v_cyclo, losses
Expand Down Expand Up @@ -321,13 +316,10 @@ def _var_step(
mask: Float[Array, "lat lon"],
loss_fn: Callable[[[Float[Array, "lat lon"], Float[Array, "lat lon"]]], Float[Scalar, ""]],
optim: optax.GradientTransformation,
return_losses: bool,
u_cyclo_u: Float[Array, "lat lon"],
v_cyclo_v: Float[Array, "lat lon"],
opt_state: optax.OptState,
losses: Float[Array, "n_it"],
i: int
) -> [Float[Array, "lat lon"], ...]:
opt_state: optax.OptState
) -> [[Float[Array, "lat lon"], ...], float]:
params = (u_cyclo_u, v_cyclo_v)
# evaluate the cost function and compute its gradient
loss, grads = value_and_grad(loss_fn)(params)
Expand All @@ -338,16 +330,7 @@ def _var_step(
# apply updates to the parameters
u_n, v_n = optax.apply_updates(params, updates)

# store loss
losses = lax.cond(
return_losses,
lambda operands: operands[0].at[operands[1]].set(operands[2]), lambda operands: operands[0],
(losses, i, loss)
)

i += 1

return u_n, v_n, opt_state, losses, i
return (u_n, v_n, opt_state), loss


def _solve(
Expand All @@ -356,17 +339,16 @@ def _solve(
mask: Float[Array, "lat lon"],
loss_fn: Callable[[[Float[Array, "lat lon"], Float[Array, "lat lon"]]], Float[Scalar, ""]],
n_it: int,
optim: optax.GradientTransformation,
return_losses: bool
optim: optax.GradientTransformation
) -> [Float[Array, "lat lon"], ...]:
# define step partial: freeze constant over iterations
def step_fn(pytree):
return _var_step(mask, loss_fn, optim, return_losses, *pytree)
def step_fn(carry, _):
return _var_step(mask, loss_fn, optim, *carry)

u_cyclo_u, v_cyclo_v, opt_state, losses, i = lax.while_loop( # noqa
lambda args: args[-1] < n_it,
(u_cyclo_u, v_cyclo_v, _), losses = lax.scan(
step_fn,
(u_geos_u, v_geos_v, optim.init((u_geos_u, v_geos_v)), jnp.ones(n_it) * jnp.nan, 0)
(u_geos_u, v_geos_v, optim.init((u_geos_u, v_geos_v))),
xs=None, length=n_it
)

return u_cyclo_u, v_cyclo_v, losses
Expand All @@ -384,16 +366,15 @@ def _variational(
coriolis_factor_v: Float[Array, "lat lon"],
mask: Float[Array, "lat lon"],
n_it: int,
optim: optax.GradientTransformation,
return_losses: bool
optim: optax.GradientTransformation
) -> [Float[Array, "lat lon"], ...]:
# define loss partial: freeze constant over iterations
loss_fn = partial(
_var_loss_fn,
u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, mask
)

return _solve(u_geos_u, v_geos_v, mask, loss_fn, n_it, optim, return_losses)
return _solve(u_geos_u, v_geos_v, mask, loss_fn, n_it, optim)


def _cyclogeostrophic_diff(
Expand Down
3 changes: 2 additions & 1 deletion notebooks/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
We designed two notebooks, one with the extremely idealised scenario of a [gaussian eddy](gaussian_eddy), and a realistic one focusing on the [Alboran sea](alboran_sea) area.
We designed three notebooks, one with the idealised scenario of a [gaussian eddy](gaussian_eddy),
a realistic one (eNATL60 run) focusing on the [Alboran sea](alboran_sea) area, and its counterpart one using Duacs data.

The goal is both to showcase step-by-step how **jaxparrow** can be used, and to demonstrate the interest of the variational approach we propose.
2 changes: 1 addition & 1 deletion notebooks/alboran_sea.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
"id": "8a6a01b9",
"metadata": {},
"source": [
"# Alboran sea\n",
"# Method validation using the eNATL60 run\n",
"\n",
"## Input data\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/alboran_sea/alboran_sea.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_figsize(width_ratio, wh_ratio=1):
return fig_width, fig_height
```

# Alboran sea
# Method validation using the eNATL60 run

## Input data

Expand Down
5 changes: 2 additions & 3 deletions notebooks/gaussian_eddy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
"collapsed": false
},
"source": [
"# Gaussian eddy\n",
"# Method validation in the idealized gaussian eddy scenario\n",
"\n",
"We want to use a gaussian eddy for our functional tests, as analytical solutions can be derived in that setting.\n",
"\n",
Expand Down Expand Up @@ -577,8 +577,7 @@
"optim = optax.sgd(learning_rate=5e-2)\n",
"u_cyclo_est, v_cyclo_est, _ = _variational(u_geos_u, v_geos_v, dXY, dXY, dXY, dXY,\n",
" coriolis_factor, coriolis_factor, mask,\n",
" n_it=20, optim=optim,\n",
" return_losses=False)\n",
" n_it=20, optim=optim)\n",
"\n",
"u_cyclo_est_t = interpolation(u_cyclo_est, axis=1, padding=\"left\")\n",
"v_cyclo_est_t = interpolation(v_cyclo_est, axis=0, padding=\"left\")\n",
Expand Down
5 changes: 2 additions & 3 deletions notebooks/gaussian_eddy/gaussian_eddy.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ autoreload
2
```

# Gaussian eddy
# Method validation in the idealized gaussian eddy scenario

We want to use a gaussian eddy for our functional tests, as analytical solutions can be derived in that setting.

Expand Down Expand Up @@ -283,8 +283,7 @@ mask = init_land_mask(u_geos_t)
optim = optax.sgd(learning_rate=5e-2)
u_cyclo_est, v_cyclo_est, _ = _variational(u_geos_u, v_geos_v, dXY, dXY, dXY, dXY,
coriolis_factor, coriolis_factor, mask,
n_it=20, optim=optim,
return_losses=False)
n_it=20, optim=optim)

u_cyclo_est_t = interpolation(u_cyclo_est, axis=1, padding="left")
v_cyclo_est_t = interpolation(v_cyclo_est, axis=0, padding="left")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_velocities.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_cyclogeostrophy_variational(self):
u_cyclo_est, v_cyclo_est, _ = _variational(u_geos_est, v_geos_est,
self.dXY, self.dXY, self.dXY, self.dXY,
self.coriolis_factor, self.coriolis_factor, mask,
1000, optim, False)
1000, optim)
u_cyclo_est_t = interpolation(u_cyclo_est, axis=1, padding="left")
v_cyclo_est_t = interpolation(v_cyclo_est, axis=0, padding="left")
cyclo_rmse = self.compute_rmse(self.u_cyclo, self.v_cyclo, u_cyclo_est_t, v_cyclo_est_t) # around .002
Expand Down
Loading