From 829c06c2301ca7671986bc311b70b19267494bf9 Mon Sep 17 00:00:00 2001
From: "dnguyend@gmail.com"
Date: Wed, 24 Jul 2024 19:02:28 -0400
Subject: [PATCH] add drift simulation

---
 .../simulation/global_manifold_integrator.py |  37 ++
 jax_rb/simulation/matrix_group_integrator.py |  45 +-
 tests/langevin_group.py                      | 393 ++++++++++++++++++
 tests/langevin_stiefel.py                    | 364 ++++++++++++++++
 4 files changed, 837 insertions(+), 2 deletions(-)
 create mode 100644 tests/langevin_group.py
 create mode 100644 tests/langevin_stiefel.py

diff --git a/jax_rb/simulation/global_manifold_integrator.py b/jax_rb/simulation/global_manifold_integrator.py
index 5bb04de..972be66 100644
--- a/jax_rb/simulation/global_manifold_integrator.py
+++ b/jax_rb/simulation/global_manifold_integrator.py
@@ -15,6 +15,18 @@ def geodesic_move(mnf, x, unit_move, scale):
     return mnf.retract(x, mnf.proj(x, mnf.sigma(x, unit_move.reshape(mnf.shape)*jnp.sqrt(scale))))
 
 
+@partial(jit, static_argnums=(0,))
+def geodesic_move_with_drift(mnf, x, unit_move, scale, additional_drift):
+    """ Simulate a Riemannian Brownian motion with drift; additional_drift
+    is added on top of the Brownian motion.
+    Simulate using a second-order retraction.
+    The move is :math:`x_{new} = \\mathfrak{r}(x, \\Pi(x)(\\sigma(x)\\text{unit_move}\\times\\text{scale}^{\\frac{1}{2}}+\\text{scale}\\times\\text{additional_drift}))`
+    """
+    return mnf.retract(x, mnf.proj(x, mnf.sigma(x, unit_move.reshape(mnf.shape)*jnp.sqrt(scale))
+                                   + scale*additional_drift))
+
+
 @partial(jit, static_argnums=(0,))
 def geodesic_move_normalized(mnf, x, unit_move, scale):
     """ similar to geodesic_move, but the move is normalized to have fixed length :math:`scale^{\\frac{1}{2}}`
@@ -52,6 +64,17 @@ def rbrownian_ito_move(mnf, x, unit_move, scale):
                               + mnf.ito_drift(x)*scale)
 
 
+@partial(jit, static_argnums=(0,))
+def ito_move_with_drift(mnf, x, unit_move, scale, additional_drift):
+    """ Simulate a Riemannian Brownian motion with drift; additional_drift
+    is added on top of the Brownian motion.
+    Use Euler-Maruyama and the projection method to solve the Ito equation.
+    """
+    return mnf.approx_nearest(
+        x + mnf.proj(x, mnf.sigma(x, unit_move.reshape(mnf.shape)*jnp.sqrt(scale)))
+        + (additional_drift+mnf.ito_drift(x))*scale)
+
+
 @partial(jit, static_argnums=(0,))
 def rbrownian_stratonovich_move(mnf, x, unit_move, scale):
     """ Use Euler Heun and projection method to solve the Stratonovich equation.
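The three `*_move_with_drift` integrators share the calling convention of the existing moves, with one extra argument for the drift evaluated at the current point. A minimal sketch of driving one of them along a path, assuming `jax_rb` is installed (`RealStiefelAlpha` and `rand_point` are used exactly as in the tests added below; the constant `drift` field here is a hypothetical placeholder, which the tests replace by `0.5*grad_log_v`):

```python
# Sketch: one sample path of a Riemannian Brownian motion with drift,
# stepped with geodesic_move_with_drift.  dt = scale.
import jax.numpy as jnp
from jax import random
import jax_rb.manifolds.stiefel as stm
import jax_rb.simulation.global_manifold_integrator as gmi

key = random.PRNGKey(0)
stf = stm.RealStiefelAlpha((5, 3), jnp.array([1, .6]))
x, key = stf.rand_point(key)

scale = 1e-3                      # time step
drift = lambda y: -y              # hypothetical ambient drift field; replace as needed
for _ in range(1000):
    key, sk = random.split(key)
    unit_move = random.normal(sk, (stf.shape[0]*stf.shape[1],))
    x = gmi.geodesic_move_with_drift(stf, x, unit_move, scale, drift(x))
```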
+ """ + # stochastic dx + dxs = mnf.sigma(x, unit_move.reshape(mnf.shape)*jnp.sqrt(scale)) + xbk = x + mnf.proj(x, dxs) + return mnf.approx_nearest(x + mnf.proj(0.5*(x + xbk), dxs) + + mnf.proj(x, mnf.ito_drift(x)+additional_drift)*scale) diff --git a/jax_rb/simulation/matrix_group_integrator.py b/jax_rb/simulation/matrix_group_integrator.py index d9a96fa..e55c5ac 100644 --- a/jax_rb/simulation/matrix_group_integrator.py +++ b/jax_rb/simulation/matrix_group_integrator.py @@ -9,7 +9,7 @@ @partial(jit, static_argnums=(0,)) def geodesic_move(mnf, x, unit_move, scale): - """ unit_move is reshaped to the shape conforming with sigma., usually the shape of the ambient space. + """ :math:`\\text{unit_move}` is reshaped to the shape conforming with sigma., usually the shape of the ambient space. The move is :math:`x_{new} = \\mathfrak{r}(x, \\sigma(x)(\\text{unit_move}(\\text{scale})^{\\frac{1}{2}}))` """ return x@mnf.retract(jnp.eye(mnf.shape[0]), @@ -17,6 +17,18 @@ def geodesic_move(mnf, x, unit_move, scale): jnp.sqrt(scale)*unit_move.reshape(mnf.shape))) +@partial(jit, static_argnums=(0,)) +def geodesic_move_with_drift(mnf, x, unit_move, scale, id_additional_drift): + """ This method is used to simulate a Riemanian Brownian motion with drift. + :math:`\\text{unit_move}` is reshaped to the shape conforming with sigma., usually the shape of the ambient space. :math:`\\text{id_additional_drift}` is an element of the Lie algebra. + The move is :math:`x_{new} = \\mathfrak{r}(x, \\sigma(x)((\\text{scale})^{\\frac{1}{2}}\\times \\text{unit_move})+\\text{scale}\\times x (\\text{id_additional_drift}))` + """ + return x@mnf.retract(jnp.eye(mnf.shape[0]), + mnf.sigma_id( + jnp.sqrt(scale)*unit_move.reshape(mnf.shape)) + + id_additional_drift*scale) + + @partial(jit, static_argnums=(0,)) def geodesic_move_normalized(mnf, x, unit_move, scale): """ Similar to geodesic_move, but unit move is rescaled to have fixed length 1 @@ -29,7 +41,7 @@ def geodesic_move_normalized(mnf, x, unit_move, scale): @partial(jit, static_argnums=(0,)) def geodesic_move_dim_g(mnf, x, unit_move, scale): - """Unit_move is of dimension :math:`\\dim \\mathrm{G}`. + """:math:`\\text{unit_move}` is of dimension :math:`\\dim \\mathrm{G}`. The move is :math:`x_{new} = \\mathfrak{r}(x, \\sigma_{la}(x)(\\text{unit_move}(\\text{scale})^{\\frac{1}{2}}))` """ return x@mnf.retract(jnp.eye(mnf.shape[0]), @@ -57,6 +69,18 @@ def rbrownian_ito_move(mnf, x, unit_move, scale): + x@mnf.id_drift*scale) +@partial(jit, static_argnums=(0,)) +def ito_move_with_drift(mnf, x, unit_move, scale, id_additional_drift): + """ This method is used to simulate a Riemanian Brownian motion with drift given in Ito form. Use stochastic projection method to solve the Ito equation. + The drift is given as an element of the Lie algebra. + Use Euler Maruyama here. + """ + n = mnf.shape[0] + return mnf.approx_nearest( + x@jnp.eye(n) + x@mnf.sigma_id(unit_move.reshape(mnf.shape)*jnp.sqrt(scale)) + + x@mnf.id_drift*scale + x@id_additional_drift*scale) + + @partial(jit, static_argnums=(0,)) def rbrownian_stratonovich_move(mnf, x, unit_move, scale): """ Using projection method to solve the Stratonovich equation. 
@@ -70,6 +94,23 @@ def rbrownian_stratonovich_move(mnf, x, unit_move, scale):
     move = jnp.eye(n) + 0.5*(2*jnp.eye(n)+dxs)@dxs + mnf.v0*scale
     return x@mnf.approx_nearest(move)
 
+
+@partial(jit, static_argnums=(0,))
+def stratonovich_move_with_drift(mnf, x, unit_move, scale, id_additional_drift):
+    """ Simulate a Riemannian Brownian motion with drift given in Stratonovich form.
+    Use the projection method to solve the Stratonovich equation.
+    The additional drift is on top of the RB term, given as an element of the Lie algebra.
+    In many cases :math:`v_0` is zero (unimodular group).
+    Use Euler-Heun.
+    """
+    n = mnf.shape[0]
+    # stochastic dx
+    dxs = mnf.sigma_id(unit_move.reshape(mnf.shape)*jnp.sqrt(scale))
+
+    move = jnp.eye(n) + 0.5*(2*jnp.eye(n)+dxs)@dxs + mnf.v0*scale + scale*id_additional_drift
+    return x@mnf.approx_nearest(move)
+
+
 @partial(jit, static_argnums=(0,))
 def ito_move_dim_g(mnf, x, unit_move, scale):
     """Similar to rbrownian_ito_move, but driven with a Wiener
diff --git a/tests/langevin_group.py b/tests/langevin_group.py
new file mode 100644
index 0000000..6ea9fe3
--- /dev/null
+++ b/tests/langevin_group.py
@@ -0,0 +1,393 @@
+""" test Riemannian Langevin for SO and SE
+"""
+
+
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+import jax.numpy.linalg as jla
+from jax import random, vmap, jit
+import jax_rb.manifolds.so_left_invariant as som
+import jax_rb.manifolds.se_left_invariant as sem
+
+from jax_rb.utils.utils import (rand_positive_definite, sym, vcat, grand)
+import jax_rb.simulation.simulator as sim
+import jax_rb.simulation.matrix_group_integrator as mi
+
+
+jax.config.update("jax_enable_x64", True)
+
+
+def sqr(x):
+    return x@x
+
+
+def cz(mat):
+    return jnp.max(jnp.abs(mat))
+
+
+class cayley_so_retraction():
+    """Cayley retraction of a matrix Lie group.
+    This is the most general, but not an efficient, implementation;
+    each Lie group should have a custom implementation of this.
+    """
+    def __init__(self, mnf):
+        self.mnf = mnf
+
+    def retract(self, x, v):
+        """rescaling :math:`x+v` to be on the manifold
+        """
+        ixv = x.T@v
+        return x + x@jla.solve(jnp.eye(ixv.shape[0]) - 0.5*ixv, ixv)
+
+    def inverse_retract(self, x, y):
+        u = x.T@y
+        n = self.mnf.shape[0]
+        return 2*x@jla.solve(jnp.eye(n)+u, u-jnp.eye(n))
+
+    def drift_adjust(self, x, driver_dim):
+        """return the adjustment :math:`\\mu_{adj}`
+        so that :math:`\\mu + \\mu_{adj} = \\mu_{\\mathfrak{r}}`
+        """
+        return -0.5*jnp.sum(vmap(lambda seq:
+                                 x@sqr(self.mnf.sigma_id(seq.reshape(x.shape)))
+                                 )(jnp.eye(driver_dim)),
+                            axis=0)
+
+
+def uniform_sample(key, shape, pay_off, n_samples):
+    """ Sample the manifold uniformly
+    """
+    x_all, key = grand(key, (shape[0], shape[1], n_samples))
+
+    def do_one_point(seq):
+        # ei, ev = jla.eigh(seq.T@seq)
+        # return pay_off(seq@ev@((1/jnp.sqrt(ei))[:, None]*ev.T))
+        u, _, vt = jla.svd(seq)
+        return pay_off(u[:, :shape[0]]@vt)
+
+    s = jax.vmap(do_one_point, in_axes=2)(x_all)
+    return jnp.nanmean(s)
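A quick consistency check for `cayley_so_retraction`, mirroring the lines kept commented out in `test_langevin_so` below: retracting a random tangent vector and applying the inverse retraction should round-trip (a sketch; `rand_point` and `rand_vec` are used as elsewhere in this file):

```python
# Sketch: inverse_retract(x, retract(x, eta)) should recover eta.
def check_cayley_round_trip(key, so):
    crtr = cayley_so_retraction(so)
    x, key = so.rand_point(key)
    eta, key = so.rand_vec(key, x)
    x1 = crtr.retract(x, eta)
    eta1 = crtr.inverse_retract(x, x1)
    print(cz(eta1 - eta))   # should be near zero
```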
+def test_langevin_so():
+    # test Langevin on SO(n) with vfunc = e^{-\frac{1}{2}v^T\Lambda v}
+    # jax.config.update('jax_default_device', jax.devices('cpu')[0])
+    n = 4
+    so_dim = n*(n-1)//2
+
+    lbd = 2.1*jnp.arange(1, so_dim+1)
+
+    def log_v(_, x):
+        return -jnp.sum(x[jnp.triu_indices(n, 1)].reshape(-1)**2*lbd)
+
+    def grad_log_v(mnf, x):
+        idx = jnp.triu_indices(n, 1)
+        return mnf.proj(x, mnf.inv_g_metric(
+            x,
+            jnp.zeros_like(x).at[idx].set(-2*lbd*x[idx].reshape(-1))))
+
+    key = random.PRNGKey(0)
+
+    metric_mat, key = rand_positive_definite(key, so_dim, (.1, 30.))
+
+    print("Doing SO")
+
+    # metric_mat = jnp.eye(so_dim)
+    so = som.SOLeftInvariant(n, metric_mat)
+    crtr = cayley_so_retraction(so)
+    # x, key = so.rand_point(key)
+    # eta, key = so.rand_vec(key, x)
+    # x1 = crtr.retract(x, eta)
+    # eta1 = crtr.inverse_retract(x, x1)
+    # print(cz(eta1-eta))
+
+    # pay_offs = [None, lambda x: jnp.sqrt(jnp.sum(x*x))]
+    pay_offs = [None, lambda x: jnp.sqrt(1+jnp.sum(jnp.abs(x)))]
+    # lbd1 = jnp.arange(1, n**2+1)
+    lbd1, key = grand(key, (n**2,))
+    pay_offs = [None, lambda x: jnp.sqrt(1+jnp.sum(jnp.abs(lbd1*x.reshape(-1))))]
+
+    x_0 = jnp.eye(n)
+
+    key, sk = random.split(key)
+    t_final = 40.
+    # t_final = 1.5
+    n_path = 1000
+    n_div = 1000
+    d_coeff = .5
+
+    wiener_dim = n**2
+
+    ret_rtr = sim.simulate(x_0,
+                           lambda x, unit_move, scale: crtr.retract(
+                               x,
+                               x@so.sigma_id(unit_move.reshape(x.shape))*scale**.5
+                               + 0.5*grad_log_v(so, x)*scale),
+                           pay_offs[0],
+                           pay_offs[1],
+                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
+
+    print("SO Cayley retract %.3f" % jnp.nanmean(ret_rtr[0]))
+
+    # The more effective sample
+    n_samples = 1000**2
+    ret_denom = uniform_sample(
+        key, so.shape,
+        lambda x: jnp.exp(log_v(so, x)),
+        n_samples)
+
+    ret_num = uniform_sample(
+        key, so.shape,
+        lambda x: pay_offs[1](x)*jnp.exp(log_v(so, x)),
+        n_samples)
+    print("SO sampling with density %.3f" % (ret_num/ret_denom))
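All of these tests rely on the same standard fact: with diffusion coefficient 1/2 (`d_coeff = .5` throughout), the Riemannian Langevin diffusion whose drift is half the Riemannian gradient of `log V` leaves the density proportional to `V` invariant, which is why each long-time Langevin average is compared against a `V`-weighted uniform-sampling estimate. In LaTeX form:

```latex
dX_t = \tfrac{1}{2}\nabla_g \log V(X_t)\,dt + dW_t^g,
\qquad
\partial_t p = \tfrac{1}{2}\Delta_g p
             - \operatorname{div}_g\!\big(\tfrac{1}{2}\,p\,\nabla_g \log V\big).
```

Substituting p proportional to V makes the right-hand side vanish, since \nabla_g V = V\,\nabla_g \log V, so that density is stationary.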
+def test_langevin_se():
+    # test Langevin on SE(n) with vfunc = e^{-\frac{1}{2}v^T\Lambda v}
+    # jax.config.update('jax_default_device', jax.devices('cpu')[0])
+    n = 3
+    # lbd = 0.5*jnp.arange(1, n+1)
+    lbd = 0.5*jnp.ones(n)
+    # lbd = jnp.array([1., 100.])
+    lbd = 10.*jnp.arange(1, n+1)
+
+    @partial(jit, static_argnums=(0,))
+    def log_v(_, x):
+        return -0.5*jnp.sum(x[:-1, -1]*lbd*x[:-1, -1])
+
+    @partial(jit, static_argnums=(0,))
+    def grad_log_v(mnf, x):
+        return mnf.proj(x, mnf.inv_g_metric(
+            x,
+            jnp.zeros_like(x).at[:-1, -1].set(-lbd*x[:-1, -1])))
+
+    key = random.PRNGKey(0)
+
+    se_dim = n*(n+1)//2
+    n1 = n+1
+    metric_mat, key = rand_positive_definite(key, se_dim, (.1, 30.))
+
+    # convergence appears to reach the same value for different metrics, but at different rates
+
+    # metric_mat = jnp.eye(se_dim)
+    # metric_mat = metric_mat.at[0, 0].set(1.)
+    se = sem.SELeftInvariant(n, metric_mat)
+    # x, key = se.rand_point(key)
+    # eta, key = se.rand_vec(key, x)
+
+    # print(jax.jvp(lambda x: log_v(se, x), (x,), (eta,))[1])
+    # print(se.inner(x, grad_log_v(se, x), eta))
+
+    # pay_offs = [None, lambda x: jnp.sqrt(jnp.sum(x[:-1, -1]**2))]
+
+    # pay_offs = [None, lambda x: jnp.sum(jnp.abs(x[:-1, -1]))]
+
+    # pay_offs = [None, lambda x: jnp.sqrt(jnp.sum(x*x))]
+    print("Test SE with n=%d, expectation of sum |x|" % (n))
+    pay_offs = [None, lambda x: jnp.sum(jnp.abs(x))]
+
+    x_0 = jnp.eye(n1)
+    key, sk = random.split(key)
+    t_final = 100.
+    n_path = 5000
+    n_div = 1000
+    d_coeff = .5
+
+    wiener_dim = n1**2
+    ret_rtr1 = sim.simulate(x_0,
+                            lambda x, unit_move, scale: mi.ito_move_with_drift(
+                                se, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(se, x))),
+                            pay_offs[0],
+                            # lambda x: x[1, -1]*x[1, -1],
+                            pay_offs[1],
+                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
+
+    print("Ito Langevin %.3f" % jnp.nanmean(ret_rtr1[0]))
+
+    ret_rtr2 = sim.simulate(x_0,
+                            lambda x, unit_move, scale: mi.stratonovich_move_with_drift(
+                                se, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(se, x))),
+                            pay_offs[0],
+                            # lambda x: x[1, -1]*x[1, -1],
+                            pay_offs[1],
+                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
+
+    print("Stratonovich Langevin %.3f" % jnp.nanmean(ret_rtr2[0]))
+
+    ret_rtr3 = sim.simulate(x_0,
+                            lambda x, unit_move, scale: mi.geodesic_move_with_drift(
+                                se, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(se, x))),
+                            pay_offs[0],
+                            # lambda x: x[1, -1]*x[1, -1],
+                            pay_offs[1],
+                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
+
+    print("Geodesic 2nd-order Langevin %.3f" % jnp.nanmean(ret_rtr3[0]))
+
+    def se_sample(key, shape, pay_off, n_samples):
+        """ Sample the rotation part uniformly (via SVD of a Gaussian matrix)
+        and the translation part from a standard Gaussian,
+        then reweight to the target density.
+        """
+        x_all, key = grand(key, (shape[0]-1, shape[1], n_samples))
+
+        def do_one_point(seq):
+            # ei, ev = jla.eigh(seq.T@seq)
+            # return pay_off(seq@ev@((1/jnp.sqrt(ei))[:, None]*ev.T))
+            u, _, vt = jla.svd(seq[:, :-1])
+            x = vcat(jnp.concatenate(
+                [u@vt, seq[:, -1][:, None]], axis=1),
+                jnp.zeros((1, shape[1])).at[0, -1].set(1.))
+            return pay_off(x)*jnp.exp(log_v(se, x)+0.5*jnp.sum(x[:-1, -1]**2))
+
+        s = jax.vmap(do_one_point, in_axes=2)(x_all)
+        return jnp.nanmean(s)
+
+    n_samples = 1000**2
+    ret_denom = se_sample(
+        key, se.shape,
+        lambda x: 1.,
+        n_samples)
+    ret_num = se_sample(
+        key, se.shape,
+        pay_offs[1],
+        n_samples)
+
+    print("uniform sampling with density %.3f" % (ret_num/ret_denom))
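The reweighting in `se_sample` above works because the translation block is drawn from a standard Gaussian rather than uniformly: multiplying the integrand by `exp(log_v(x) + 0.5*|x[:-1, -1]|^2)` cancels the Gaussian density up to a constant that drops out of the numerator/denominator ratio. A one-dimensional illustration of the same identity (the closed form on the last line is only for comparison):

```python
# Sketch: estimate E|v| under p(v) proportional to exp(-lbd*v^2/2) using
# standard normal draws, reweighted by exp(log_v(v) + 0.5*v^2) to cancel
# the N(0,1) density.
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
v = random.normal(key, (1_000_000,))
lbd = 10.
w = jnp.exp(-0.5*lbd*v**2 + 0.5*v**2)        # exp(log_v + 0.5 v^2)
est = jnp.mean(jnp.abs(v)*w)/jnp.mean(w)     # importance-weighted E|v|
print(est, jnp.sqrt(2/(jnp.pi*lbd)))         # closed form sqrt(2/(pi*lbd))
```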
+def test_langevin_se2():
+    # jax.config.update('jax_default_device', jax.devices('cpu')[0])
+    n = 4
+    se_dim = n*(n+1)//2
+    n1 = n+1
+
+    lbd = 0.5*jnp.arange(1, n+1)
+    lbd = 10.*jnp.ones(n)
+
+    @partial(jit, static_argnums=(0,))
+    def log_v(_, x):
+        return -0.5*jnp.sum(x[:-1, -1]*lbd*x[:-1, -1])
+
+    @partial(jit, static_argnums=(0,))
+    def grad_log_v(mnf, x):
+        return mnf.proj(x, mnf.inv_g_metric(
+            x,
+            jnp.zeros_like(x).at[:-1, -1].set(-lbd*x[:-1, -1])))
+
+    key = random.PRNGKey(0)
+
+    # metric_mat, key = rand_positive_definite(key, se_dim, (.1, 30.))
+    A, key = grand(key, (n*n1, n*n1))
+    A = sym(A@A.T)
+    # convergence appears to reach the same value for different metrics, but at different rates
+
+    metric_mat = jnp.eye(se_dim)
+    # metric_mat = metric_mat.at[0, 0].set(1.)
+    se = sem.SELeftInvariant(n, metric_mat)
+    # x, key = se.rand_point(key)
+    # eta, key = se.rand_vec(key, x)
+
+    # print(jax.jvp(lambda x: log_v(se, x), (x,), (eta,))[1])
+    # print(se.inner(x, grad_log_v(se, x), eta))
+    print("Test SE n=%d, expectation of |x^TAx|^(1/2) for a positive definite matrix A" % (n))
+
+    pay_offs = [None, lambda x: jnp.sqrt(jnp.abs(jnp.sum(x[:-1, :].reshape(-1)*(A@x[:-1, :].reshape(-1)))))]
+    # pay_offs = [None, lambda x: jnp.sqrt(jnp.sum(x*x*jnp.arange(1, n1+1)[None, :]))]
+    # pay_offs = [None, lambda x: jnp.sum(x[0, :-1]*x[:-1, -1])]
+
+    x_0 = jnp.eye(n1)
+    key, sk = random.split(key)
+    t_final = 50.
+    n_path = 5000
+    n_div = 1000
+    d_coeff = .5
+
+    wiener_dim = n1**2
+
+    ret_rtr1 = sim.simulate(x_0,
+                            lambda x, unit_move, scale: mi.ito_move_with_drift(
+                                se, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(se, x))),
+                            pay_offs[0],
+                            # lambda x: x[1, -1]*x[1, -1],
+                            pay_offs[1],
+                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
+
+    print("Ito Langevin %.3f" % jnp.nanmean(ret_rtr1[0]))
+
+    ret_rtr2 = sim.simulate(x_0,
+                            lambda x, unit_move, scale: mi.stratonovich_move_with_drift(
+                                se, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(se, x))),
+                            pay_offs[0],
+                            # lambda x: x[1, -1]*x[1, -1],
+                            pay_offs[1],
+                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
+
+    print("Stratonovich Langevin %.3f" % jnp.nanmean(ret_rtr2[0]))
+
+    ret_rtr3 = sim.simulate(x_0,
+                            lambda x, unit_move, scale: mi.geodesic_move_with_drift(
+                                se, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(se, x))),
+                            pay_offs[0],
+                            # lambda x: x[1, -1]*x[1, -1],
+                            pay_offs[1],
+                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
+
+    print("Geodesic 2nd-order Langevin %.3f" % jnp.nanmean(ret_rtr3[0]))
+
+    def se_sample(key, shape, pay_off, n_samples):
+        """ Sample the rotation part uniformly (via SVD of a Gaussian matrix)
+        and the translation part from a standard Gaussian,
+        then reweight to the target density.
+        """
+        x_all, key = grand(key, (shape[0]-1, shape[1], n_samples))
+
+        def do_one_point(seq):
+            # ei, ev = jla.eigh(seq.T@seq)
+            # return pay_off(seq@ev@((1/jnp.sqrt(ei))[:, None]*ev.T))
+            u, _, vt = jla.svd(seq[:, :-1])
+            x = vcat(jnp.concatenate(
+                [u@vt, seq[:, -1][:, None]], axis=1),
+                jnp.zeros((1, shape[1])).at[0, -1].set(1.))
+            return pay_off(x)*jnp.exp(log_v(se, x)+0.5*jnp.sum(x[:-1, -1]**2))
+            # return jnp.sqrt(3+jnp.sum(x[:-1, -1]**2))*jnp.exp(log_v(se, x)+0.5*jnp.sum(x[:-1, -1]**2))
+
+        s = jax.vmap(do_one_point, in_axes=2)(x_all)
+        # ret = []
+        # for i in range(x_all.shape[2]):
+        #     ret.append(do_one_point(x_all[:, :, i]))
+        # s = jnp.array(ret)
+        return jnp.nanmean(s)
+
+    n_samples = 1000**2
+
+    ret_denom = se_sample(
+        key, se.shape,
+        lambda x: 1.,
+        n_samples)
+    """
+    ret_num = se_sample(
+        key, se.shape,
+        lambda x: pay_offs[1](x),
+        n_path*500)
+    """
+    ret_num = se_sample(
+        key, se.shape,
+        # lambda x: x[1, -1]*x[1, -1],
+        pay_offs[1],
+        # lambda x: pay_offs[1](x) - jnp.sqrt(3+jnp.sum(x[:-1, -1]**2)),
+        n_samples)
+
+    print("uniform sampling with density %.3f" % (ret_num/ret_denom))
+
+
+if __name__ == '__main__':
+    test_langevin_so()
+    test_langevin_se()
+    test_langevin_se2()
diff --git a/tests/langevin_stiefel.py b/tests/langevin_stiefel.py
new file mode 100644
index 0000000..1ad0ecd
--- /dev/null
+++ b/tests/langevin_stiefel.py
@@ -0,0 +1,364 @@
+""" test Riemannian Langevin for Stiefel manifolds
+"""
+
+from functools import partial
+from time import perf_counter
+
+import jax
+import jax.numpy as jnp
+import jax.numpy.linalg as jla
+from jax import random, vmap, jit
+from jax.scipy.linalg import expm
+import jax_rb.manifolds.stiefel as stm
+
+from jax_rb.utils.utils import (sym, grand)
+import jax_rb.simulation.simulator as sim
+import jax_rb.simulation.global_manifold_integrator as gmi
+
+
+jax.config.update("jax_enable_x64", True)
+
+
+def sqr(x):
+    return x@x
+
+
+def cz(mat):
+    return jnp.max(jnp.abs(mat))
+
+
+class stiefel_polar_retraction():
+    def __init__(self, mnf):
+        self.mnf = mnf
+
+    def retract(self, x, v):
+        """rescaling :math:`x+v` to be on the manifold
+        """
+        u, _, vt = jla.svd(x+v, full_matrices=False)
+        return u@vt
+
+    def drift_adjust(self, x):
+        n, d, alp1 = self.mnf.shape[0], self.mnf.shape[1], self.mnf.alpha[1]
+        return -0.5*(n-d+0.5*(d-1)/alp1)*x
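`drift_adjust` above is a closed form for the Ito drift correction of the polar retraction; `drift_adjust_verify` further down in this file recomputes the same quantity as a sum over the driving basis. A sketch of comparing the two at a random point (assuming the class and helpers defined in this file):

```python
# Sketch: the closed-form adjustment should agree with the basis-sum version
# (test_polar_retract_adjust keeps the analogous print commented out).
def check_drift_adjust(key, stf):
    x, key = stf.rand_point(key)
    prtr = stiefel_polar_retraction(stf)
    mu_closed = prtr.drift_adjust(x)
    mu_sum = drift_adjust_verify(stf, x, stf.sigma, stf.shape[0]*stf.shape[1])
    print(cz(mu_closed - mu_sum))   # expected near zero if the closed form is right
```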
+
+
+def uniform_sampling(key, shape, pay_off, n_samples):
+    """ Sample the manifold uniformly
+    """
+    x_all, key = grand(key, (shape[0], shape[1], n_samples))
+
+    def do_one_point(seq):
+        # ei, ev = jla.eigh(seq.T@seq)
+        # return pay_off(seq@ev@((1/jnp.sqrt(ei))[:, None]*ev.T))
+        u, _, vt = jla.svd(seq, full_matrices=False)
+        return pay_off(u[:, :shape[0]]@vt)
+
+    s = jax.vmap(do_one_point, in_axes=2)(x_all)
+    return jnp.nanmean(s)
+
+
+def test_stiefel_langevin_von_mises_fisher(key, stf, kp, F, func):
+    # test Langevin on Stiefel with the von Mises-Fisher density vfunc = e^{kp*tr(F^T x)}
+    # jax.config.update('jax_default_device', jax.devices('cpu')[0])
+    print("Doing Stiefel von Mises-Fisher (n, d)=%s alpha=%s" % (str(stf.shape), str(stf.alpha)))
+
+    @partial(jit, static_argnums=(0,))
+    def log_v(_, x):
+        return kp*jnp.trace(F.T@x)
+
+    @partial(jit, static_argnums=(0,))
+    def grad_log_v(mnf, x):
+        return kp*mnf.proj(x, mnf.inv_g_metric(x, F))
+
+    x, key = stf.rand_point(key)
+    eta, key = stf.rand_vec(key, x)
+
+    # print(jax.jvp(lambda x: log_v(stf, x), (x,), (eta,))[1])
+    # print(stf.inner(x, grad_log_v(stf, x), eta))
+
+    pay_offs = [None, func]
+
+    x_0, key = stf.rand_point(key)
+    key, sk = random.split(key)
+    t_final = 5.
+    n_path = 10000
+    n_div = 500
+    d_coeff = .5
+
+    wiener_dim = stf.shape[0]*stf.shape[1]
+    # crtr = cayley_se_retraction(se)
+
+    # rbrownian_ito_langevin_move(mnf, x, unit_move, scale, grad_log_v)
+    ret_rtr1 = sim.simulate(x_0,
+                            lambda x, unit_move, scale: gmi.ito_move_with_drift(
+                                stf, x, unit_move, scale, 0.5*grad_log_v(stf, x)),
+                            pay_offs[0],
+                            # lambda x: x[1, -1]*x[1, -1],
+                            pay_offs[1],
+                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
+
+    print("ito langevin %.3f" % jnp.nanmean(ret_rtr1[0]))
+
+    ret_rtr2 = sim.simulate(x_0,
+                            lambda x, unit_move, scale: gmi.stratonovich_move_with_drift(
+                                stf, x, unit_move, scale, 0.5*grad_log_v(stf, x)),
+                            pay_offs[0],
+                            # lambda x: x[1, -1]*x[1, -1],
+                            pay_offs[1],
+                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
+
+    print("stratonovich langevin %.3f" % jnp.nanmean(ret_rtr2[0]))
+
+    ret_rtr3 = sim.simulate(x_0,
+                            lambda x, unit_move, scale: gmi.geodesic_move_with_drift(
+                                stf, x, unit_move, scale, 0.5*grad_log_v(stf, x)),
+                            pay_offs[0],
+                            # lambda x: x[1, -1]*x[1, -1],
+                            pay_offs[1],
+                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
+
+    print("geodesic langevin %.3f" % jnp.nanmean(ret_rtr3[0]))
+
+    n_samples = 1000**2
+    ret_spl = uniform_sampling(key, stf.shape,
+                               lambda x: pay_offs[1](x)*jnp.exp(log_v(None, x)),
+                               n_samples)
+
+    ret_spl_0 = uniform_sampling(key, stf.shape,
+                                 lambda x: jnp.exp(log_v(None, x)),
+                                 n_samples)
+
+    print("stiefel uniform sampling with density %.3f" % (ret_spl/ret_spl_0))
+    # import scipy.special as ss
+    # print(jnp.sqrt(2)*ss.iv(1, 1)/ss.iv(.5, 1)*ss.gamma(1.5))
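A sketch of the gradient check that the test drivers here keep commented out: the Riemannian gradient returned by `grad_log_v` should reproduce the directional derivative of `log_v` along random tangent vectors, which pins down both the metric inverse and the projection:

```python
# Sketch: directional derivative of log_v vs. metric pairing with grad_log_v.
import jax

def check_grad(key, stf, log_v, grad_log_v):
    x, key = stf.rand_point(key)
    eta, key = stf.rand_vec(key, x)
    d1 = jax.jvp(lambda y: log_v(stf, y), (x,), (eta,))[1]
    d2 = stf.inner(x, grad_log_v(stf, x), eta)
    print(d1 - d2)   # should be near zero
```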
+def test_stiefel_langevin_bingham(key, stf, A, func):
+    # test Langevin on Stiefel with the Bingham density vfunc = e^{tr(x^T A x)}
+    # jax.config.update('jax_default_device', jax.devices('cpu')[0])
+    @partial(jit, static_argnums=(0,))
+    def log_v(_, x):
+        return jnp.trace(x.T@A@x)
+
+    @partial(jit, static_argnums=(0,))
+    def grad_log_v(mnf, x):
+        return mnf.proj(x, mnf.inv_g_metric(x, 2*A@x))
+
+    print("Doing Bingham (n, d)=%s alpha=%s" % (str(stf.shape), str(stf.alpha)))
+
+    # x, key = stf.rand_point(key)
+    # eta, key = stf.rand_vec(key, x)
+
+    # print(jax.jvp(lambda x: log_v(stf, x), (x,), (eta,))[1])
+    # print(stf.inner(x, grad_log_v(stf, x), eta))
+
+    pay_offs = [None, func]
+
+    x_0, key = stf.rand_point(key)
+    key, sk = random.split(key)
+    t_final = 5.
+    n_path = 10000
+    n_div = 500
+    d_coeff = .5
+
+    wiener_dim = stf.shape[0]*stf.shape[1]
+    # crtr = cayley_se_retraction(se)
+
+    # rbrownian_ito_langevin_move(mnf, x, unit_move, scale, grad_log_v)
+    ret_rtr1 = sim.simulate(x_0,
+                            lambda x, unit_move, scale: gmi.ito_move_with_drift(
+                                stf, x, unit_move, scale, 0.5*grad_log_v(stf, x)),
+                            pay_offs[0],
+                            # lambda x: x[1, -1]*x[1, -1],
+                            pay_offs[1],
+                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
+
+    print("ito langevin %.3f" % jnp.nanmean(ret_rtr1[0]))
+
+    ret_rtr2 = sim.simulate(x_0,
+                            lambda x, unit_move, scale: gmi.stratonovich_move_with_drift(
+                                stf, x, unit_move, scale, 0.5*grad_log_v(stf, x)),
+                            pay_offs[0],
+                            # lambda x: x[1, -1]*x[1, -1],
+                            pay_offs[1],
+                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
+
+    print("stratonovich langevin %.3f" % jnp.nanmean(ret_rtr2[0]))
+
+    ret_rtr3 = sim.simulate(x_0,
+                            lambda x, unit_move, scale: gmi.geodesic_move_with_drift(
+                                stf, x, unit_move, scale, 0.5*grad_log_v(stf, x)),
+                            pay_offs[0],
+                            # lambda x: x[1, -1]*x[1, -1],
+                            pay_offs[1],
+                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
+
+    print("geodesic langevin %.3f" % jnp.nanmean(ret_rtr3[0]))
+
+    n_samples = 1000**2
+    ret_spl = uniform_sampling(key, stf.shape,
+                               lambda x: pay_offs[1](x)*jnp.exp(log_v(None, x)),
+                               n_samples)
+
+    ret_spl_0 = uniform_sampling(key, stf.shape,
+                                 lambda x: jnp.exp(log_v(None, x)),
+                                 n_samples)
+
+    print("stiefel uniform sampling with density %.3f" % (ret_spl/ret_spl_0))
+
+
+def test_all_stiefel_von_mises_fisher():
+    n = 3
+    d = 1
+    alp = jnp.array([1, .6])
+    key = random.PRNGKey(0)
+
+    stf = stm.RealStiefelAlpha((n, d), alp)
+
+    # F, key = stf.rand_point(key)
+    kp = 1.2
+    F, key = stf.rand_point(key)
+    test_stiefel_langevin_von_mises_fisher(
+        key, stf, kp, F,
+        lambda x: jnp.sqrt(jnp.abs(1-jnp.trace(F.T@x)**2)))
+    # print(jnp.sqrt(2)*ss.iv(1, 1)/ss.iv(.5, 1)*ss.gamma(1.5))
+
+    n = 5
+    d = 3
+    alp = jnp.array([1, .6])
+    key = random.PRNGKey(0)
+
+    stf = stm.RealStiefelAlpha((n, d), alp)
+
+    # F, key = stf.rand_point(key)
+    kp = 1.2
+    F, key = stf.rand_point(key)
+
+    test_stiefel_langevin_von_mises_fisher(
+        key, stf, kp, F,
+        lambda x: jnp.sqrt(jnp.abs(1-jnp.trace(F.T@x)**2)))
+
+    test_stiefel_langevin_von_mises_fisher(
+        key, stf, kp, F,
+        lambda x: jnp.sum(jnp.abs(x)))
+
+
+def gen_sym_traceless(key, n):
+    A, key = grand(key, (n, n))
+    return sym(A) - jnp.trace(A)/n*jnp.eye(n), key
+
+
+def test_all_bingham():
+    n = 3
+    d = 1
+    alp = jnp.array([1, .6])
+    key = random.PRNGKey(0)
+
+    stf = stm.RealStiefelAlpha((n, d), alp)
+
+    A, key = gen_sym_traceless(key, n)
+    test_stiefel_langevin_bingham(
+        key, stf, A,
+        lambda x: jnp.sum(jnp.abs(x)))
+    # print(jnp.sqrt(2)*ss.iv(1, 1)/ss.iv(.5, 1)*ss.gamma(1.5))
+
+    n = 5
+    d = 3
+    alp = jnp.array([1, .6])
+    key = random.PRNGKey(0)
+
+    stf = stm.RealStiefelAlpha((n, d), alp)
+    A, key = gen_sym_traceless(key, n)
+    test_stiefel_langevin_bingham(
+        key, stf, A,
+        lambda x: jnp.sum(jnp.abs(x)))
+
+    test_stiefel_langevin_bingham(
+        key, stf, A,
+        lambda x: jnp.sum(jnp.abs(x)*(A@jnp.abs(x))))
+
+    n = 7
+    d = 3
+    alp = jnp.array([1, .6])
+    key = random.PRNGKey(0)
+
+    stf = stm.RealStiefelAlpha((n, d), alp)
+    A, key = gen_sym_traceless(key, n)
+    test_stiefel_langevin_bingham(
+        key, stf, A,
+        lambda x: jnp.sum(jnp.abs(x)))
+
+    test_stiefel_langevin_bingham(
+        key, stf, A,
+        lambda x: jnp.sum(jnp.abs(x)*(A@jnp.abs(x))))
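Why `gen_sym_traceless` above may drop the trace of `A`: on the Stiefel manifold :math:`x^Tx = I_d`, so shifting `A` by a multiple of the identity only adds a constant to `log_v`, and that constant cancels between the numerator and denominator of every sampled ratio:

```latex
\operatorname{tr}\big(x^{T}(A + cI_n)x\big)
  = \operatorname{tr}(x^{T}Ax) + c\,\operatorname{tr}(x^{T}x)
  = \operatorname{tr}(x^{T}Ax) + c\,d .
```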
+def drift_adjust_verify(self, x, sigma, wiener_dim):
+    """return the adjustment :math:`\\mu_{adj}`
+    so that :math:`\\mu + \\mu_{adj} = \\mu_{\\mathfrak{r}}`
+    """
+    def sqt(a):
+        return a.T@a
+
+    return -0.5*x@jnp.sum(vmap(lambda seq:
+                               sqt(self.proj(x, sigma(x, seq.reshape(x.shape)))))(jnp.eye(wiener_dim)),
+                          axis=0)
+
+
+def test_polar_retract_adjust():
+    n = 7
+    d = 3
+    alp = jnp.array([1, .6])
+    key = random.PRNGKey(0)
+    stf = stm.RealStiefelAlpha((n, d), alp)
+    print("Doing Stiefel Polar retract for Bingham (n, d)=%s alpha=%s" % (str(stf.shape), str(stf.alpha)))
+
+    @partial(jit, static_argnums=(0,))
+    def log_v(_, x):
+        return jnp.trace(x.T@A@x)
+
+    @partial(jit, static_argnums=(0,))
+    def grad_log_v(mnf, x):
+        return mnf.proj(x, mnf.inv_g_metric(x, 2*A@x))
+
+    x, key = stf.rand_point(key)
+
+    # mu2 = -0.5*(n-d+0.5*(d-1)/alp[1])*x
+    prtr = stiefel_polar_retraction(stf)
+
+    mu1 = drift_adjust_verify(stf, x, stf.sigma, n*d)
+    mu2 = prtr.drift_adjust(x)
+    # print(mu2-mu1)
+
+    A, key = gen_sym_traceless(key, n)
+
+    x_0, key = stf.rand_point(key)
+    pay_offs = [None, lambda x: jnp.sum(jnp.abs(x))]
+
+    key, sk = random.split(key)
+    t_final = 5.
+    n_path = 10000
+    n_div = 500
+    d_coeff = .5
+
+    wiener_dim = stf.shape[0]*stf.shape[1]
+
+    test_stiefel_langevin_bingham(
+        key, stf, A,
+        pay_offs[1])
+
+    ret_rtr = sim.simulate(x_0,
+                           lambda x, unit_move, scale: prtr.retract(
+                               x, stf.proj(x, stf.sigma(x, unit_move.reshape(x.shape)*scale**.5
+                                                        + scale*(
+                                                            stf.ito_drift(x)
+                                                            + 0.5*grad_log_v(stf, x))))),
+                           pay_offs[0],
+                           pay_offs[1],
+                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
+
+    print("Polar adjust %.3f" % jnp.nanmean(ret_rtr[0]))
+
+
+if __name__ == '__main__':
+    test_all_stiefel_von_mises_fisher()
+    test_all_bingham()
+    test_polar_retract_adjust()