diff --git a/jaxparrow/cyclogeostrophy.py b/jaxparrow/cyclogeostrophy.py index 74904f7..99b5845 100644 --- a/jaxparrow/cyclogeostrophy.py +++ b/jaxparrow/cyclogeostrophy.py @@ -161,7 +161,7 @@ def cyclogeostrophy( raise TypeError("optim should be an optax.GradientTransformation optimizer, or a string referring to such " "an optimizer.") res = _variational(u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, is_land, - n_it, optim, return_losses) + n_it, optim) elif method == "iterative": if n_it is None: n_it = N_IT_IT @@ -208,11 +208,8 @@ def _it_step( u_cyclo: Float[Array, "lat lon"], v_cyclo: Float[Array, "lat lon"], mask_it: Float[Array, "lat lon"], - res_n: Float[Array, "lat lon"], - losses: Float[Array, "n_it"], - i: int -) -> [Float[Array, "lat lon"], Float[Array, "lat lon"], Float[Array, "lat lon"], Float[Array, "lat lon"], - Float[Array, "n_it"], int]: + res_n: Float[Array, "lat lon"] +) -> [[Float[Array, "lat lon"], Float[Array, "lat lon"], Float[Array, "lat lon"], Float[Array, "lat lon"], int], float]: # next it u_adv_v, v_adv_u = kinematics.advection(u_cyclo, v_cyclo, dx_u, dy_u, dx_v, dy_v, mask) u_np1 = u_geos_u - v_adv_u / coriolis_factor_u @@ -232,11 +229,12 @@ def _it_step( mask_n = jnp.where(res_np1 <= res_n, 0, 1) # nan comp. equiv. to jnp.where(res_np1 > res_n, 1, 0) # compute loss - losses = lax.cond( + loss = lax.cond( return_losses, - lambda operands: operands[0].at[operands[1]].set(_cyclogeostrophic_diff(*operands[2:])), - lambda operands: operands[0], - (losses, i, u_geos_u, v_geos_v, u_cyclo, v_cyclo, u_adv_v, v_adv_u, coriolis_factor_u, coriolis_factor_v) + lambda: _cyclogeostrophic_diff( + u_geos_u, v_geos_v, u_cyclo, v_cyclo, u_adv_v, v_adv_u, coriolis_factor_u, coriolis_factor_v + ), + lambda: jnp.nan ) # update cyclogeostrophic velocities @@ -247,9 +245,7 @@ def _it_step( mask_it = jnp.maximum(mask_it, jnp.maximum(mask_jnp1, mask_n)) res_n = res_np1 - i += 1 - - return u_cyclo, v_cyclo, mask_it, res_n, losses, i + return (u_cyclo, v_cyclo, mask_it, res_n), loss @partial(jit, static_argnames=("n_it", "res_filter_size")) @@ -274,22 +270,21 @@ def _iterative( res_weights = jsp.signal.convolve(jnp.ones_like(u_geos_u), res_filter, mode="same", method="fft") # define step partial: freeze constant over iterations - def step_fn(pytree): + def step_fn(carry, _): return _it_step( u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, mask, res_eps, res_filter, res_weights, use_res_filter, return_losses, - *pytree + *carry ) # apply updates - u_cyclo, v_cyclo, _, _, losses, _ = lax.while_loop( # noqa - lambda args: (args[-1] < n_it) | jnp.any(args[2] != 1), + (u_cyclo, v_cyclo, _, _), losses = lax.scan( step_fn, - (u_geos_u, v_geos_v, mask.astype(int), jnp.maximum(jnp.abs(u_geos_u), jnp.abs(v_geos_v)), - jnp.ones(n_it) * jnp.nan, 0) + (u_geos_u, v_geos_v, mask.astype(int), jnp.maximum(jnp.abs(u_geos_u), jnp.abs(v_geos_v))), + xs=None, length=n_it ) return u_cyclo, v_cyclo, losses @@ -321,13 +316,10 @@ def _var_step( mask: Float[Array, "lat lon"], loss_fn: Callable[[[Float[Array, "lat lon"], Float[Array, "lat lon"]]], Float[Scalar, ""]], optim: optax.GradientTransformation, - return_losses: bool, u_cyclo_u: Float[Array, "lat lon"], v_cyclo_v: Float[Array, "lat lon"], - opt_state: optax.OptState, - losses: Float[Array, "n_it"], - i: int -) -> [Float[Array, "lat lon"], ...]: + opt_state: optax.OptState +) -> [[Float[Array, "lat lon"], ...], float]: params = (u_cyclo_u, v_cyclo_v) # evaluate the cost function and compute its gradient loss, grads = value_and_grad(loss_fn)(params) @@ -338,16 +330,7 @@ def _var_step( # apply updates to the parameters u_n, v_n = optax.apply_updates(params, updates) - # store loss - losses = lax.cond( - return_losses, - lambda operands: operands[0].at[operands[1]].set(operands[2]), lambda operands: operands[0], - (losses, i, loss) - ) - - i += 1 - - return u_n, v_n, opt_state, losses, i + return (u_n, v_n, opt_state), loss def _solve( @@ -356,17 +339,16 @@ def _solve( mask: Float[Array, "lat lon"], loss_fn: Callable[[[Float[Array, "lat lon"], Float[Array, "lat lon"]]], Float[Scalar, ""]], n_it: int, - optim: optax.GradientTransformation, - return_losses: bool + optim: optax.GradientTransformation ) -> [Float[Array, "lat lon"], ...]: # define step partial: freeze constant over iterations - def step_fn(pytree): - return _var_step(mask, loss_fn, optim, return_losses, *pytree) + def step_fn(carry, _): + return _var_step(mask, loss_fn, optim, *carry) - u_cyclo_u, v_cyclo_v, opt_state, losses, i = lax.while_loop( # noqa - lambda args: args[-1] < n_it, + (u_cyclo_u, v_cyclo_v, _), losses = lax.scan( step_fn, - (u_geos_u, v_geos_v, optim.init((u_geos_u, v_geos_v)), jnp.ones(n_it) * jnp.nan, 0) + (u_geos_u, v_geos_v, optim.init((u_geos_u, v_geos_v))), + xs=None, length=n_it ) return u_cyclo_u, v_cyclo_v, losses @@ -384,8 +366,7 @@ def _variational( coriolis_factor_v: Float[Array, "lat lon"], mask: Float[Array, "lat lon"], n_it: int, - optim: optax.GradientTransformation, - return_losses: bool + optim: optax.GradientTransformation ) -> [Float[Array, "lat lon"], ...]: # define loss partial: freeze constant over iterations loss_fn = partial( @@ -393,7 +374,7 @@ def _variational( u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, mask ) - return _solve(u_geos_u, v_geos_v, mask, loss_fn, n_it, optim, return_losses) + return _solve(u_geos_u, v_geos_v, mask, loss_fn, n_it, optim) def _cyclogeostrophic_diff( diff --git a/notebooks/README.md b/notebooks/README.md index d44f99d..2387425 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -1,3 +1,4 @@ -We designed two notebooks, one with the extremely idealised scenario of a [gaussian eddy](gaussian_eddy), and a realistic one focusing on the [Alboran sea](alboran_sea) area. +We designed three notebooks, one with the idealised scenario of a [gaussian eddy](gaussian_eddy), +a realistic one (eNATL60 run) focusing on the [Alboran sea](alboran_sea) area, and its counterpart one using Duacs data. The goal is both to showcase step-by-step how **jaxparrow** can be used, and to demonstrate the interest of the variational approach we propose. \ No newline at end of file diff --git a/notebooks/alboran_sea.ipynb b/notebooks/alboran_sea.ipynb index c0bbd5a..f9b9370 100644 --- a/notebooks/alboran_sea.ipynb +++ b/notebooks/alboran_sea.ipynb @@ -62,7 +62,7 @@ "id": "8a6a01b9", "metadata": {}, "source": [ - "# Alboran sea\n", + "# Method validation using the eNATL60 run\n", "\n", "## Input data\n", "\n", diff --git a/notebooks/alboran_sea/alboran_sea.md b/notebooks/alboran_sea/alboran_sea.md index eb8cbc3..616c4e8 100644 --- a/notebooks/alboran_sea/alboran_sea.md +++ b/notebooks/alboran_sea/alboran_sea.md @@ -32,7 +32,7 @@ def get_figsize(width_ratio, wh_ratio=1): return fig_width, fig_height ``` -# Alboran sea +# Method validation using the eNATL60 run ## Input data diff --git a/notebooks/gaussian_eddy.ipynb b/notebooks/gaussian_eddy.ipynb index c8b3b5a..c0fa874 100644 --- a/notebooks/gaussian_eddy.ipynb +++ b/notebooks/gaussian_eddy.ipynb @@ -40,7 +40,7 @@ "collapsed": false }, "source": [ - "# Gaussian eddy\n", + "# Method validation in the idealized gaussian eddy scenario\n", "\n", "We want to use a gaussian eddy for our functional tests, as analytical solutions can be derived in that setting.\n", "\n", @@ -577,8 +577,7 @@ "optim = optax.sgd(learning_rate=5e-2)\n", "u_cyclo_est, v_cyclo_est, _ = _variational(u_geos_u, v_geos_v, dXY, dXY, dXY, dXY,\n", " coriolis_factor, coriolis_factor, mask,\n", - " n_it=20, optim=optim,\n", - " return_losses=False)\n", + " n_it=20, optim=optim)\n", "\n", "u_cyclo_est_t = interpolation(u_cyclo_est, axis=1, padding=\"left\")\n", "v_cyclo_est_t = interpolation(v_cyclo_est, axis=0, padding=\"left\")\n", diff --git a/notebooks/gaussian_eddy/gaussian_eddy.md b/notebooks/gaussian_eddy/gaussian_eddy.md index f51b641..888c1ea 100644 --- a/notebooks/gaussian_eddy/gaussian_eddy.md +++ b/notebooks/gaussian_eddy/gaussian_eddy.md @@ -21,7 +21,7 @@ autoreload 2 ``` -# Gaussian eddy +# Method validation in the idealized gaussian eddy scenario We want to use a gaussian eddy for our functional tests, as analytical solutions can be derived in that setting. @@ -283,8 +283,7 @@ mask = init_land_mask(u_geos_t) optim = optax.sgd(learning_rate=5e-2) u_cyclo_est, v_cyclo_est, _ = _variational(u_geos_u, v_geos_v, dXY, dXY, dXY, dXY, coriolis_factor, coriolis_factor, mask, - n_it=20, optim=optim, - return_losses=False) + n_it=20, optim=optim) u_cyclo_est_t = interpolation(u_cyclo_est, axis=1, padding="left") v_cyclo_est_t = interpolation(v_cyclo_est, axis=0, padding="left") diff --git a/tests/test_velocities.py b/tests/test_velocities.py index a2df222..a9261f7 100644 --- a/tests/test_velocities.py +++ b/tests/test_velocities.py @@ -49,7 +49,7 @@ def test_cyclogeostrophy_variational(self): u_cyclo_est, v_cyclo_est, _ = _variational(u_geos_est, v_geos_est, self.dXY, self.dXY, self.dXY, self.dXY, self.coriolis_factor, self.coriolis_factor, mask, - 1000, optim, False) + 1000, optim) u_cyclo_est_t = interpolation(u_cyclo_est, axis=1, padding="left") v_cyclo_est_t = interpolation(v_cyclo_est, axis=0, padding="left") cyclo_rmse = self.compute_rmse(self.u_cyclo, self.v_cyclo, u_cyclo_est_t, v_cyclo_est_t) # around .002