Skip to content

Commit

Permalink
add drift simulation
Browse files Browse the repository at this point in the history
  • Loading branch information
dnguyend committed Jul 24, 2024
1 parent 9102ef7 commit 829c06c
Show file tree
Hide file tree
Showing 4 changed files with 837 additions and 2 deletions.
37 changes: 37 additions & 0 deletions jax_rb/simulation/global_manifold_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@ def geodesic_move(mnf, x, unit_move, scale):
return mnf.retract(x, mnf.proj(x, mnf.sigma(x, unit_move.reshape(mnf.shape)*jnp.sqrt(scale))))


@partial(jit, static_argnums=(0,))
def geodesic_move_with_drift(mnf, x, unit_move, scale, additional_drift):
""" This method is used to simulate a Riemanian Brownian motion with drift. The additional_drift
is on top of the Brownian motion.
Simulate using a second order retraction.
The move is :math:`x_{new} = \\mathfrak{r}(x, \\Pi(x)\\sigma(x)(\\text{unit_move}(\\text{scale})^{\\frac{1}{2}}+\\text{scale}\\times\\text{additional_drift}))`
"""
return mnf.retract(x, mnf.proj(x, mnf.sigma(x, unit_move.reshape(mnf.shape)*jnp.sqrt(scale))
+ scale*additional_drift))



@partial(jit, static_argnums=(0,))
def geodesic_move_normalized(mnf, x, unit_move, scale):
""" similar to geodesic_move, but the move is normalized to have fixed length :math:`scale^{\\frac{1}{2}}`
Expand Down Expand Up @@ -52,6 +64,17 @@ def rbrownian_ito_move(mnf, x, unit_move, scale):
+ mnf.ito_drift(x)*scale)


@partial(jit, static_argnums=(0,))
def ito_move_with_drift(mnf, x, unit_move, scale, additional_drift):
""" This method is used to simulate a Riemanian Brownian motion with drift. The additional_drift
is on top of the Brownian motion.
Use Euler Maruyama and projection method to solve the Ito equation.
"""
return mnf.approx_nearest(
x + mnf.proj(x, mnf.sigma(x, unit_move.reshape(mnf.shape)*jnp.sqrt(scale)))
+ (additional_drift+mnf.ito_drift(x))*scale)


@partial(jit, static_argnums=(0,))
def rbrownian_stratonovich_move(mnf, x, unit_move, scale):
""" Use Euler Heun and projection method to solve the Stratonovich equation.
Expand All @@ -61,3 +84,17 @@ def rbrownian_stratonovich_move(mnf, x, unit_move, scale):
xbk = x + mnf.proj(x, dxs)
return mnf.approx_nearest(x + mnf.proj(0.5*(x + xbk), dxs)
+ mnf.proj(x, mnf.ito_drift(x)*scale))


@partial(jit, static_argnums=(0,))
def stratonovich_move_with_drift(mnf, x, unit_move, scale, additional_drift):
"""
This method is used to simulate a Riemanian Brownian motion with drift. The additional_drift
is on top of the Brownian motion.
Use Euler Heun and projection method to solve the Stratonovich equation.
"""
# stochastic dx
dxs = mnf.sigma(x, unit_move.reshape(mnf.shape)*jnp.sqrt(scale))
xbk = x + mnf.proj(x, dxs)
return mnf.approx_nearest(x + mnf.proj(0.5*(x + xbk), dxs)
+ mnf.proj(x, mnf.ito_drift(x)+additional_drift)*scale)
45 changes: 43 additions & 2 deletions jax_rb/simulation/matrix_group_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,26 @@

@partial(jit, static_argnums=(0,))
def geodesic_move(mnf, x, unit_move, scale):
""" unit_move is reshaped to the shape conforming with sigma., usually the shape of the ambient space.
""" :math:`\\text{unit_move}` is reshaped to the shape conforming with sigma., usually the shape of the ambient space.
The move is :math:`x_{new} = \\mathfrak{r}(x, \\sigma(x)(\\text{unit_move}(\\text{scale})^{\\frac{1}{2}}))`
"""
return x@mnf.retract(jnp.eye(mnf.shape[0]),
mnf.sigma_id(
jnp.sqrt(scale)*unit_move.reshape(mnf.shape)))


@partial(jit, static_argnums=(0,))
def geodesic_move_with_drift(mnf, x, unit_move, scale, id_additional_drift):
""" This method is used to simulate a Riemanian Brownian motion with drift.
:math:`\\text{unit_move}` is reshaped to the shape conforming with sigma., usually the shape of the ambient space. :math:`\\text{id_additional_drift}` is an element of the Lie algebra.
The move is :math:`x_{new} = \\mathfrak{r}(x, \\sigma(x)((\\text{scale})^{\\frac{1}{2}}\\times \\text{unit_move})+\\text{scale}\\times x (\\text{id_additional_drift}))`
"""
return x@mnf.retract(jnp.eye(mnf.shape[0]),
mnf.sigma_id(
jnp.sqrt(scale)*unit_move.reshape(mnf.shape))
+ id_additional_drift*scale)


@partial(jit, static_argnums=(0,))
def geodesic_move_normalized(mnf, x, unit_move, scale):
""" Similar to geodesic_move, but unit move is rescaled to have fixed length 1
Expand All @@ -29,7 +41,7 @@ def geodesic_move_normalized(mnf, x, unit_move, scale):

@partial(jit, static_argnums=(0,))
def geodesic_move_dim_g(mnf, x, unit_move, scale):
"""Unit_move is of dimension :math:`\\dim \\mathrm{G}`.
""":math:`\\text{unit_move}` is of dimension :math:`\\dim \\mathrm{G}`.
The move is :math:`x_{new} = \\mathfrak{r}(x, \\sigma_{la}(x)(\\text{unit_move}(\\text{scale})^{\\frac{1}{2}}))`
"""
return x@mnf.retract(jnp.eye(mnf.shape[0]),
Expand Down Expand Up @@ -57,6 +69,18 @@ def rbrownian_ito_move(mnf, x, unit_move, scale):
+ x@mnf.id_drift*scale)


@partial(jit, static_argnums=(0,))
def ito_move_with_drift(mnf, x, unit_move, scale, id_additional_drift):
""" This method is used to simulate a Riemanian Brownian motion with drift given in Ito form. Use stochastic projection method to solve the Ito equation.
The drift is given as an element of the Lie algebra.
Use Euler Maruyama here.
"""
n = mnf.shape[0]
return mnf.approx_nearest(
x@jnp.eye(n) + x@mnf.sigma_id(unit_move.reshape(mnf.shape)*jnp.sqrt(scale))
+ x@mnf.id_drift*scale + x@id_additional_drift*scale)


@partial(jit, static_argnums=(0,))
def rbrownian_stratonovich_move(mnf, x, unit_move, scale):
""" Using projection method to solve the Stratonovich equation.
Expand All @@ -70,6 +94,23 @@ def rbrownian_stratonovich_move(mnf, x, unit_move, scale):
move = jnp.eye(n) + 0.5*(2*jnp.eye(n)+dxs)@dxs + mnf.v0*scale
return x@mnf.approx_nearest(move)


@partial(jit, static_argnums=(0,))
def stratonovich_move_with_drift(mnf, x, unit_move, scale, id_additional_drift):
""" This method is used to simulate a Riemanian Brownian motion with drift given in Stratonovich form.
Using projection method to solve the Stratonovich equation.
The additional drift is on top of the RB term, given as an element of the Lie algebra
In many cases :math:`v_0` is zero (unimodular group).
Use Euler Heun.
"""
n = mnf.shape[0]
# stochastic dx
dxs = mnf.sigma_id(unit_move.reshape(mnf.shape)*jnp.sqrt(scale))

move = jnp.eye(n) + 0.5*(2*jnp.eye(n)+dxs)@dxs + mnf.v0*scale + scale*id_additional_drift
return x@mnf.approx_nearest(move)


@partial(jit, static_argnums=(0,))
def ito_move_dim_g(mnf, x, unit_move, scale):
"""Similar to rbrownian_ito_move, but driven with a Wiener
Expand Down
Loading

0 comments on commit 829c06c

Please sign in to comment.