-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add more dependencies * Add SDE solver.
- Loading branch information
1 parent
33be826
commit 761ba39
Showing
4 changed files
with
220 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Simulation utilities.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
import jax | ||
import jax.random as jrandom | ||
import jax.numpy as jnp | ||
from diffrax import ( | ||
diffeqsolve, | ||
ControlTerm, | ||
Euler, | ||
MultiTerm, | ||
ODETerm, | ||
SaveAt, | ||
Solution, | ||
VirtualBrownianTree, | ||
) | ||
|
||
|
||
def simplex_complete(y: jax.Array) -> jax.Array: | ||
"""Completes the parametrization to the point on the simplex. | ||
Args: | ||
y: shape (..., dim-1) | ||
Returns: | ||
y_ext, shape (..., dim) | ||
""" | ||
ones_minus_sum = 1 - y.sum(axis=-1) | ||
ones_minus_sum_expanded = ones_minus_sum[..., jnp.newaxis] | ||
return jnp.concatenate([y, ones_minus_sum_expanded], axis=-1) | ||
|
||
|
||
def simplex_truncate(y: jax.Array) -> jax.Array: | ||
"""Truncates the last entry. | ||
Args: | ||
y: shape (..., dim) | ||
Returns: | ||
y_trunc, shape (..., dim-1) | ||
""" | ||
return y[..., :-1] | ||
|
||
|
||
def _get_drift_term(fitness: jax.Array, sigmas: jax.Array): | ||
"""Generates the drift term of the SDE: | ||
drift(t, y, args) -> jnp.ndarray | ||
Note that `t` and `args` are ignored. Moreover, `y` is assumed | ||
to be of shape (variants-1,) as the last entry being implicitly | ||
defined by the other entries due to summing up to 1 constraint. | ||
Args: | ||
fitness: fitness vector, shape (variants,) | ||
sigmas: noise vector, shape (variants,) | ||
""" | ||
|
||
def drift(t, y, args): | ||
xs = simplex_complete(y) | ||
phi = jnp.sum(xs * fitness) | ||
|
||
square_term = jnp.sum(jnp.square(sigmas * xs)) | ||
return y * ( | ||
simplex_truncate(fitness) | ||
- phi | ||
- y * jnp.square(simplex_truncate(sigmas)) | ||
+ square_term | ||
) | ||
|
||
return drift | ||
|
||
|
||
def _get_diffusion_term(sigmas: jnp.ndarray): | ||
"""Generates the diffusion term of the SDE: | ||
diffusion(t, y, args) -> jnp.ndarray | ||
Note that `t` and `args` are ignored. Moreover, `y` is assumed | ||
to be of shape (variants-1,) as the last entry being implicitly | ||
defined by the other entries due to summing up to 1 constraint. | ||
The returned array is of shape (variants-1, variants) | ||
as we have independent noise for each variant, but we still | ||
write the SDEs only for the first `variants-1` entries of `y`. | ||
""" | ||
|
||
def diffusion(t, y, args): | ||
k = y.shape[0] + 1 | ||
|
||
term1 = (y * simplex_truncate(sigmas))[:, None] * jnp.eye(k)[:-1, :] | ||
term2 = jnp.outer(y, sigmas * simplex_complete(y)) | ||
|
||
return term1 - term2 | ||
|
||
return diffusion | ||
|
||
|
||
def solve_stochastic_replicator_dynamics( | ||
y0: jax.Array, | ||
t_span: jax.Array, | ||
fitness: jax.Array, | ||
noise: jax.Array | float = 0.05, | ||
brownian_tol: float = 1e-3, | ||
solver_dt: float = 1e-2, | ||
key: jax.Array | int = 42, | ||
jit_terms: bool = False, | ||
) -> tuple[jax.Array, Solution]: | ||
"""Solves | ||
Args: | ||
y0: starting point with | ||
positive entries summing up to 1, shape (variants,) | ||
t_span: time span, shape (steps,) | ||
fitness: fitness vector, shape (variants,) | ||
noise: noise level, float or array of shape (variants,) | ||
brownian_tol: tolerance for the Brownian tree, float | ||
solver_dt: default time step for the solver, float | ||
Returns: | ||
y, shape (steps, variants) | ||
sol, diffrax's Solution. Note that the `sol.ys` is of shape (steps, variants-1) | ||
as the last entry is implicitly defined by the summing up to 1 constraint. | ||
""" | ||
# Infer the number of variants and check the dimensions | ||
dim = y0.shape[0] | ||
assert fitness.shape == (dim,), "Fitness vector has wrong shape" | ||
noise = ( | ||
jnp.ones(dim) * noise | ||
) # This should work independently on whether `noise` is float or array | ||
assert noise.shape == (dim,), "Noise vector has wrong shape" | ||
|
||
# Make sure that `key` is JAX key | ||
if isinstance(key, int): | ||
key = jrandom.PRNGKey(key) | ||
|
||
# Check solver hyperparameters | ||
assert brownian_tol > 0, "Brownian tolerance must be positive" | ||
assert solver_dt > 0, "Solver time step must be positive" | ||
|
||
t0, t1 = t_span.min(), t_span.max() | ||
|
||
# Generate the drift and diffusion terms | ||
drift = _get_drift_term(fitness, noise) | ||
diffusion = _get_diffusion_term(noise) | ||
|
||
if jit_terms: | ||
drift = jax.jit(drift) | ||
diffusion = jax.jit(diffusion) | ||
|
||
brownian_motion = VirtualBrownianTree( | ||
t0, t1, tol=brownian_tol, shape=(dim,), key=key # pyright: ignore | ||
) | ||
terms = MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion)) | ||
|
||
solver = Euler() | ||
saveat = SaveAt(ts=t_span) # pyright: ignore | ||
|
||
# Note that we truncate the last entry of the solution as it is implicitly defined | ||
sol = diffeqsolve( | ||
terms, # pyright: ignore | ||
solver, | ||
t0, # pyright: ignore | ||
t1, # pyright: ignore | ||
dt0=solver_dt, # pyright: ignore | ||
y0=simplex_truncate(y0), # pyright: ignore | ||
saveat=saveat, # pyright: ignore | ||
) | ||
|
||
ys = simplex_complete(sol.ys) # pyright: ignore | ||
return ys, sol |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
import numpy.testing as npt | ||
|
||
from covvfit.simulation._sde import ( | ||
simplex_complete, | ||
solve_stochastic_replicator_dynamics, | ||
) | ||
|
||
|
||
def simplex_complete_single(y: jnp.ndarray) -> jnp.ndarray: | ||
return jnp.append(y, 1 - y.sum()) | ||
|
||
|
||
def test_simplex_complete() -> None: | ||
y0 = jnp.array([0.1, 0.2, 0.3]) | ||
|
||
npt.assert_allclose(simplex_complete(y0), simplex_complete_single(y0)) | ||
|
||
y1 = jnp.linspace(0, 0.3, 10).reshape(5, 2) | ||
|
||
npt.assert_allclose(jax.vmap(simplex_complete_single)(y1), simplex_complete(y1)) | ||
|
||
|
||
def test_solve_replicator(dim: int = 3) -> None: | ||
y0 = jnp.linspace(0.1, 0.9, dim) | ||
y0 = y0 / y0.sum() | ||
|
||
t_span = jnp.linspace(0, 0.5, 5) | ||
fitness = jnp.linspace(0.0, 2.0, dim) | ||
noise = 0.05 | ||
|
||
ys_solved, sol = solve_stochastic_replicator_dynamics( | ||
y0=y0, | ||
t_span=t_span, | ||
fitness=fitness, | ||
noise=noise, | ||
brownian_tol=0.05, | ||
solver_dt=0.05, | ||
key=42, | ||
) | ||
|
||
assert ys_solved.shape == (t_span.shape[0], dim) | ||
assert sol.ys.shape == (t_span.shape[0], dim - 1) # pyright: ignore | ||
|
||
assert ys_solved.min() >= 0 | ||
assert ys_solved.max() <= 1 |