Skip to content

Commit

Permalink
Merge pull request #29 from meom-group/issue-#24
Browse files Browse the repository at this point in the history
Issue #24
  • Loading branch information
vadmbertr authored Nov 6, 2023
2 parents bda1ce6 + 084781b commit 94ef499
Show file tree
Hide file tree
Showing 13 changed files with 424 additions and 214 deletions.
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

project = "jaxparrow"
copyright = "2023, Victor Zaia, Vadim Bertrand, Emmanuel Cosme, Julien Le Sommer"
author = "Victor Zaia, Vadim Bertrand, Emmanuel Cosme, Julien Le Sommer"
copyright = "2023, Victor E V Z De Almeida, Vadim Bertrand, Julien Le Sommer, Emmanuel Cosme"
author = "Victor E V Z De Almeida, Vadim Bertrand, Julien Le Sommer, Emmanuel Cosme"
release = __version__

# -- General configuration ---------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks.rst → docs/examples.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Notebooks
=========
Examples
========

.. include:: ../notebooks/README.md
:parser: myst_parser.sphinx_
Expand Down
4 changes: 2 additions & 2 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ Contents:
.. toctree::
:maxdepth: 2

description
notebooks
Overview <overview>
examples
api
File renamed without changes.
93 changes: 50 additions & 43 deletions jaxparrow/cyclogeostrophy.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,46 +37,53 @@ def cyclogeostrophy(u_geos: Union[np.ndarray, np.ma.MaskedArray], v_geos: Union[
dx_u: np.ndarray, dx_v: np.ndarray, dy_u: np.ndarray, dy_v: np.ndarray,
coriolis_factor_u: Union[np.ndarray, np.ma.MaskedArray],
coriolis_factor_v: Union[np.ndarray, np.ma.MaskedArray],
method: Literal["variational", "penven", "ioannou"] = "variational",
method: Literal["variational", "iterative"] = "variational",
n_it: int = None, lr: float = LR_VAR, res_eps: float = RES_EPS_IT,
res_init: Union[float, Literal["same"]] = RES_INIT_IT, res_filter_size: int = RES_FILTER_SIZE_IT) \
res_init: Union[float, Literal["same"]] = RES_INIT_IT,
use_res_filter: bool = False, res_filter_size: int = RES_FILTER_SIZE_IT) \
-> Tuple[np.ndarray, np.ndarray]:
"""
Computes velocities from cyclogeostrophic approximation using a variational (default) or iterative method.
:param u_geos: U geostrophic velocity value
:param u_geos: U geostrophic velocity, NxM grid
:type u_geos: Union[np.ndarray, np.ma.MaskedArray]
:param v_geos: V geostrophic velocity value
:param v_geos: V geostrophic velocity, NxM grid
:type v_geos: Union[np.ndarray, np.ma.MaskedArray]
:param dx_u: U spatial step along x
:param dx_u: U spatial step along x, NxM grid
:type dx_u: np.ndarray
:param dx_v: V spatial step along x
:param dx_v: V spatial step along x, NxM grid
:type dx_v: np.ndarray
:param dy_u: U spatial step along y
:param dy_u: U spatial step along y, NxM grid
:type dy_u: np.ndarray
:param dy_v: V spatial step along y
:param dy_v: V spatial step along y, NxM grid
:type dy_v: np.ndarray
:param coriolis_factor_u: U Coriolis factor
:param coriolis_factor_u: U Coriolis factor, NxM grid
:type coriolis_factor_u: Union[np.ndarray, np.ma.MaskedArray]
:param coriolis_factor_v: V Coriolis factor
:param coriolis_factor_v: V Coriolis factor, NxM grid
:type coriolis_factor_v: Union[np.ndarray, np.ma.MaskedArray]
:param method: estimation method to use, defaults to "variational"
:type method: Literal["variational", "penven", "ioannou"], optional
:param n_it: maximum number of iterations, defaults to N_IT_IT
:type method: Literal["variational", "iterative"], optional
:param n_it: maximum number of iterations, defaults to N_IT_VAR or N_IT_IT based on the method argument
:type n_it: int, optional
:param lr: gradient descent learning rate, defaults to LR_VAR
:param lr: gradient descent learning rate of the variational approach, defaults to LR_VAR
:type lr: float, optional
:param res_eps: residual tolerance: if residuals are smaller, we consider them as equal to 0, defaults to EPS_IT
:param res_eps: residual tolerance of the iterative approach.
When residuals are smaller, we consider them as equal to 0.
Defaults to EPS_IT
:type res_eps: float, optional
:param res_init: residual initial value: if residuals are larger at the first iteration, we consider that the
solution diverges. If equals to "same" (default) absolute values of the geostrophic velocities are
used. Defaults to RES_INIT_IT
:param res_init: residual initial value of the iterative approach.
When residuals are larger at the first iteration, we consider that the solution diverges.
If equals to "same" (default) absolute values of the geostrophic velocities are used.
Defaults to RES_INIT_IT
:type res_init: Union[float | Literal["same"]], optional
:param res_filter_size: size of the convolution filter (from Ioannou) used when computing the residuals,
:param use_res_filter: use of a convolution filter for the iterative approach when computing the residuals
(method from Ioannou et al.) or not (original method from Penven et al.), defaults to False
:type use_res_filter: bool, optional
:param res_filter_size: size of the convolution filter used for the iterative approach when computing the residuals,
defaults to RES_FILTER_SIZE_IT
:type res_filter_size: int, optional
:returns: U and V cyclogeostrophic velocities
:returns: U and V cyclogeostrophic velocities, NxM grids
:rtype: Tuple[np.ndarray, np.ndarray]
"""
mask = np.ma.getmaskarray(u_geos).astype(int)
Expand All @@ -92,14 +99,11 @@ def cyclogeostrophy(u_geos: Union[np.ndarray, np.ma.MaskedArray], v_geos: Union[
if method == "variational":
u_cyclo, v_cyclo = _variational(u_geos, v_geos, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v,
n_it, lr)
elif method == "penven":
elif method == "iterative":
u_cyclo, v_cyclo = _iterative(u_geos, v_geos, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v,
mask, n_it, res_eps, res_init, res_filter_size=1)
elif method == "ioannou":
u_cyclo, v_cyclo = _iterative(u_geos, v_geos, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v,
mask, n_it, res_eps, res_init, res_filter_size)
mask, n_it, res_eps, res_init, use_res_filter, res_filter_size)
else:
raise ValueError("method should be one of [\"variational\", \"penven\", \"ioannou\"]")
raise ValueError("method should be one of [\"variational\", \"iterative\"]")

return u_cyclo, v_cyclo

Expand All @@ -110,8 +114,8 @@ def cyclogeostrophy(u_geos: Union[np.ndarray, np.ma.MaskedArray], v_geos: Union[

def _iterative(u_geos: np.ndarray, v_geos: np.ndarray, dx_u: np.ndarray, dx_v: np.ndarray,
dy_u: np.ndarray, dy_v: np.ndarray, coriolis_factor_u: np.ndarray, coriolis_factor_v: np.ndarray,
mask: np.ndarray, n_it: int = N_IT_IT, res_eps: float = RES_EPS_IT,
res_init: Union[float, str] = RES_INIT_IT, res_filter_size: int = RES_FILTER_SIZE_IT) \
mask: np.ndarray, n_it: int, res_eps: float, res_init: Union[float, str], use_res_filter: bool,
res_filter_size: int) \
-> Tuple[np.ndarray, np.ndarray]:
"""
Computes velocities from cyclogeostrophic approximation using the iterative method from Penven et al. (2014)
Expand All @@ -134,17 +138,19 @@ def _iterative(u_geos: np.ndarray, v_geos: np.ndarray, dx_u: np.ndarray, dx_v: n
:type coriolis_factor_v: np.ndarray
:param mask: initial data mask
:type mask: np.ndarray
:param n_it: maximum number of iterations, defaults to N_IT_IT
:type n_it: int, optional
:param res_eps: residual tolerance: if residuals are smaller, we consider them as equal to 0, defaults to EPS_IT
:type res_eps: float, optional
:param n_it: maximum number of iterations
:type n_it: int
:param res_eps: residual tolerance: if residuals are smaller, we consider them as equal to 0
:type res_eps: float
:param res_init: residual initial value: if residuals are larger at the first iteration, we consider that the
solution diverges. If equals to "same" (default) absolute values of the geostrophic velocities are
used. Defaults to RES_INIT_IT
:type res_init: float | str, optional
:param res_filter_size: size of the convolution filter (from Ioannou) used when computing the residuals,
defaults to RES_FILTER_SIZE_IT
:type res_filter_size: int, optional
used
:type res_init: float | str
:param use_res_filter: use of a convolution filter when computing the residuals (method from Ioannou et al.) or not
(original method from Penven et al.)
:type use_res_filter: bool
:param res_filter_size: size of the convolution filter used when computing the residuals
:type res_filter_size: int
:returns: U and V cyclogeostrophic velocities
:rtype: Tuple[np.ndarray, np.ndarray]
Expand All @@ -157,11 +163,12 @@ def _iterative(u_geos: np.ndarray, v_geos: np.ndarray, dx_u: np.ndarray, dx_v: n
res_n = res_init * np.ones_like(u_geos)
else:
raise ValueError("res_init should be equal to \"same\" or be a number.")
if not use_res_filter:
res_filter_size = 1

u_cyclo, v_cyclo = np.copy(u_geos), np.copy(v_geos)
res_filter = np.ones((res_filter_size, res_filter_size))
res_weights = signal.correlate(np.ones_like(u_geos), res_filter, mode="same")
res_esp = res_eps * np.ones_like(u_geos)
for _ in tqdm(range(n_it)):
# next it
advec_v = tools.compute_advection_v(u_cyclo, v_cyclo, dx_v, dy_v)
Expand All @@ -173,7 +180,7 @@ def _iterative(u_geos: np.ndarray, v_geos: np.ndarray, dx_u: np.ndarray, dx_v: n
res_np1 = np.abs(u_np1 - u_cyclo) + np.abs(v_np1 - v_cyclo)
res_np1 = signal.correlate(res_np1, res_filter, mode="same", method="direct") / res_weights # apply filter
# compute intermediate masks
mask_jnp1 = np.where(res_np1 < res_esp, 1, 0)
mask_jnp1 = np.where(res_np1 < res_eps, 1, 0)
mask_n = np.where(res_np1 > res_n, 1, 0)

# update cyclogeostrophic velocities
Expand Down Expand Up @@ -244,7 +251,7 @@ def _gradient_descent(u_geos: np.ndarray, v_geos: np.ndarray, f: Callable[[jax.A

def _variational(u_geos: np.ndarray, v_geos: np.ndarray, dx_u: np.ndarray, dx_v: np.ndarray,
dy_u: np.ndarray, dy_v: np.ndarray, coriolis_factor_u: np.ndarray, coriolis_factor_v: np.ndarray,
n_it: int = N_IT_VAR, lr: float = LR_VAR) -> Tuple[np.ndarray, np.ndarray]:
n_it: int, lr: float) -> Tuple[np.ndarray, np.ndarray]:
"""Computes the cyclogeostrophic balance using the variational method
:param u_geos: U geostrophic velocity value
Expand All @@ -263,10 +270,10 @@ def _variational(u_geos: np.ndarray, v_geos: np.ndarray, dx_u: np.ndarray, dx_v:
:type coriolis_factor_u: np.ndarray
:param coriolis_factor_v: V Coriolis factor
:type coriolis_factor_v: np.ndarray
:param n_it: maximum number of iterations, defaults to N_IT_VAR
:type n_it: int, optional
:param lr: gradient descent learning rate, defaults to LR_VAR
:type lr: float, optional
:param n_it: maximum number of iterations
:type n_it: int
:param lr: gradient descent learning rate
:type lr: float
:returns: U and V cyclogeostrophic velocities
:rtype: Tuple[np.ndarray, np.ndarray]
Expand Down
12 changes: 6 additions & 6 deletions jaxparrow/geostrophy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@ def geostrophy(ssh: Union[np.ndarray, np.ma.MaskedArray],
-> Tuple[np.ndarray, np.ndarray]:
"""Computes the geostrophic balance
:param ssh: Sea Surface Height (SSH) value
:param ssh: Sea Surface Height (SSH), NxM grid
:type ssh: Union[np.ndarray, np.ma.MaskedArray]
:param dx_ssh: SSH spatial step along x
:param dx_ssh: SSH spatial step along x, NxM grid
:type dx_ssh: Union[np.ndarray, np.ma.MaskedArray]
:param dy_ssh: SSH spatial step along y
:param dy_ssh: SSH spatial step along y, NxM grid
:type dy_ssh: Union[np.ndarray, np.ma.MaskedArray]
:param coriolis_factor_u: U Coriolis factor
:param coriolis_factor_u: U Coriolis factor, NxM grid
:type coriolis_factor_u: Union[np.ndarray, np.ma.MaskedArray]
:param coriolis_factor_v: V Coriolis factor
:param coriolis_factor_v: V Coriolis factor, NxM grid
:type coriolis_factor_v: Union[np.ndarray, np.ma.MaskedArray]
:returns: U and V geostrophic velocities
:returns: U and V geostrophic velocities, NxM grids
:rtype: Tuple[np.ndarray, np.ndarray]
"""
# Computing the gradient of the ssh
Expand Down
11 changes: 6 additions & 5 deletions jaxparrow/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
def compute_coriolis_factor(lat: Union[int, np.ndarray, np.ma.MaskedArray]) -> Union[np.ndarray, np.ma.MaskedArray]:
"""Computes the Coriolis factor from latitudes
:param lat: latitude
:param lat: latitude, NxM grid
:type lat: Union[np.ndarray, np.ma.MaskedArray]
:returns: Coriolis factor
:returns: Coriolis factor, NxM grid
:rtype: Union[np.ndarray, np.ma.MaskedArray]
"""
return 2 * EARTH_ANG_SPEED * np.sin(lat * np.pi / 180)
Expand Down Expand Up @@ -56,17 +56,18 @@ def compute_spatial_step(lat: Union[np.ndarray, np.ma.MaskedArray], lon: Union[n
"""Computes dx and dy spatial steps of a grid defined by lat, lon.
It makes use of the distance-on-a-sphere formula with Taylor expansion approximations of cos and arccos functions
to avoid truncation issues.
Applies Von Neuman boundary conditions to the spatial steps fields.
:param lat: latitude
:param lat: latitude, NxM grid
:type lat: Union[np.ndarray, np.ma.MaskedArray]
:param lon: longitude
:param lon: longitude, NxM grid
:type lon: Union[np.ndarray, np.ma.MaskedArray]
:param bounds: range of acceptable values, defaults to (1e2, 1e4). Out of this range, set to fill_value
:type bounds: Tuple[float, float], optional
:param fill_value: fill value, defaults to 1e12
:type fill_value: float, optional
:returns: dx and dy spatial steps
:returns: dx and dy spatial steps, NxM grids
:rtype: Tuple[np.ndarray, np.ndarray]
"""
dx, dy = np.zeros_like(lon), np.zeros_like(lon)
Expand Down
347 changes: 283 additions & 64 deletions notebooks/alboran_sea.ipynb

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions notebooks/alboran_sea/alboran_sea.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ plt.show()
### Compute spatial steps

The netCDF files we use as input do not contain the spatial steps required to compute derivatives later.
The sub-module `tools` provides the utility function `compute_spatial_step` to compute them from our grids.
The sub-module `tools` provides the utility function `compute_spatial_step` to compute them from our grids. It applies Von Neuman boundary conditions to those fields.

```python
dx_ssh, dy_ssh = compute_spatial_step(lat_ssh, lon_ssh)
Expand Down Expand Up @@ -219,10 +219,10 @@ plt.colorbar(im, ax=ax2)

### Penven method

We use the same function, but with the argument `method="penven"`.
We use the same function, but with the argument `method="iterative"`.

```python
u_penven, v_penven = cyclogeostrophy(u_geos, v_geos, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, method="penven")
u_penven, v_penven = cyclogeostrophy(u_geos, v_geos, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, method="iterative")
```

100%|██████████| 100/100 [00:00<00:00, 213.83it/s]
Expand Down Expand Up @@ -264,10 +264,10 @@ plt.colorbar(im, ax=ax2)

### Ioannou method

We use the same function, but with the argument `method="ioannou"`.
We use the same function, but with the arguments `method="iterative"`, and `use_res_filter=True`.

```python
u_ioannou, v_ioannou = cyclogeostrophy(u_geos, v_geos, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, method="ioannou")
u_ioannou, v_ioannou = cyclogeostrophy(u_geos, v_geos, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, method="iterative", use_res_filter=True)
```

100%|██████████| 100/100 [00:00<00:00, 112.78it/s]
Expand Down
Loading

0 comments on commit 94ef499

Please sign in to comment.