+"""Base class for matrix groups with left invariant metrics.
+"""
+
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+import jax.numpy.linalg as jla
+from .global_manifold import GlobalManifold
+from ..utils.utils import (grand, lie)
+
+
+
+
[docs]
+
class MatrixLeftInvariant(GlobalManifold):
+
"""Matrix group with left invariant metric.
+
+
:param p: the size of the matrix
+
:param g_mat: The matrix defining the inner product at the identity. Typically g_mat is of size :math:`\\dim \\mathrm{G}` .
+
"""
+
def __init__(self, p, g_mat):
+
"""Constructor
+
"""
+
self.shape = (p, p)
+
self.dim = p*p
+
self._g_mat = g_mat
+
ei, ev = jla.eigh(g_mat)
+
self._i_sqrt_g_mat = ev@((1/jnp.sqrt(ei))[:, None]*ev.T)
+
# Stratonovich drift at id
+
self.v0 = self._make_v0() #: Stratonovich drift at the identity.
+
self.id_drift = self._make_id_drift() #: Ito drift at the identity.
+
+
+
[docs]
+
def name(self):
+
raise NotImplementedError
+
+
+
def _lie_algebra_proj(self, omg):
+
""" The projection :math:`p_{\\mathfrak{g}}` at the identity.
+
"""
+
raise NotImplementedError
+
+
def _mat_apply(self, mat, omg):
+
""" Implementing the operator :math:`\\mathcal{I}` applied on omg in :math:`\\mathcal{E}`.
+
"""
+
raise NotImplementedError
+
+
def _id_opt(self, omg):
+
"""The metric applied at the identity.
+
"""
+
return self._mat_apply(self._g_mat, omg)
+
+
def _inv_id_opt(self, omg):
+
"""Invert of _id_opt.
+
"""
+
return self._mat_apply(jla.inv(self._g_mat), omg)
+
+
+
[docs]
+
def proj(self, x, omg):
+
return x@self._lie_algebra_proj(jla.solve(x, omg))
+
+
+
def _d_proj(self, x, xi, eta):
+
ivx = jla.inv(x)
+
return xi@self._lie_algebra_proj(jla.solve(x, eta)) \
+
- x@self._lie_algebra_proj(ivx@xi@ivx@eta)
+
+
+
[docs]
+
def rand_ambient(self, key):
+
return grand(key, (self.shape))
+
+
+
+
[docs]
+
def rand_vec(self, key, x):
+
omg, key = grand(key, self.shape)
+
return self.proj(x, omg), key
+
+
+
+
[docs]
+
def rand_point(self, key):
+
raise NotImplementedError
+
+
+
+
[docs]
+
def inner(self, x, a, b):
+
return jnp.sum(jla.solve(x, a)*self._id_opt(jla.solve(x, b)))
+
+
+
+
[docs]
+
def g_metric(self, x, omg):
+
return jla.solve(x.T, self._id_opt(jla.solve(x, omg)))
+
+
+
+
[docs]
+
def inv_g_metric(self, x, omg):
+
return x@self._inv_id_opt(x.T@omg)
+
+
+
@partial(jax.jit, static_argnums=(0,))
+
def gamma(self, x, xi, eta):
+
# return - self.d_proj(x, xi, eta)
+
# + self.proj(x, self.gamma_ambient(x, xi, eta))
+
ivxi = jla.solve(x, xi)
+
iveta = jla.solve(x, eta)
+
+
return -0.5*(xi@iveta + eta@ivxi) \
+
+ 0.5*x@self._inv_id_opt(
+
self._lie_algebra_proj(
+
lie(self._id_opt(ivxi), iveta.T) \
+
+ lie(self._id_opt(iveta), ivxi.T)))
+
+
+
[docs]
+
@partial(jax.jit, static_argnums=(0,))
+
def gamma_ambient(self, x, xi, eta):
+
"""Christoffel function for ambient manifold.
+
"""
+
ivx = jla.inv(x)
+
return 0.5*x@self._inv_id_opt(
+
- self._id_opt(ivx@xi@ivx@eta + ivx@eta@ivx@xi) \
+
+ lie(self._id_opt(ivx@xi), eta.T@ivx.T) \
+
+ lie(self._id_opt(ivx@eta), xi.T@ivx.T))
+
+
+
+
[docs]
+
def retract(self, x, v):
+
raise NotImplementedError
+
+
+
def _make_id_drift_longer(self):
+
"""make the drift at identity.
+
The longer way, sum gamma.x.
+
"""
+
p = self.shape[0]
+
drft = jnp.zeros(self.shape)
+
for i in range(p):
+
for j in range(p):
+
eij = jnp.zeros((p, p)).at[i, j].set(1.)
+
drft -= self.gamma(jnp.eye(p), eij,
+
self._lie_algebra_proj(self._inv_id_opt(eij)))
+
return 0.5*drft
+
+
def _make_id_drift(self):
+
"""make the drift at identity.
+
Simplify so we dont need to evaluate gamma.
+
"""
+
p = self.shape[0]
+
v = jnp.zeros((p, p))
+
zr = jnp.zeros((p, p))
+
+
def lp(a):
+
return self._lie_algebra_proj(a)
+
for i in range(p):
+
for j in range(p):
+
eij = zr.at[i, j].set(1.)
+
v += - eij@self._inv_id_opt(lp(eij)) \
+
+ self._inv_id_opt(lp(lie(lp(eij), eij.T)))
+
+
return -0.5*v
+
+
+
def _make_v0(self):
+
""" make v0, the identity tangent vector corresponding to
+
the Stratonovich drift.
+
"""
+
p = self.shape[0]
+
v = jnp.zeros((p, p))
+
zr = jnp.zeros((p, p))
+
+
for i in range(p):
+
for j in range(p):
+
eij = zr.at[i, j].set(1.)
+
+
v += self._inv_id_opt(
+
self._lie_algebra_proj(
+
lie(self._lie_algebra_proj(eij), eij.T)))
+
+
return -0.5*v
+
+
+
[docs]
+
def approx_nearest(self, q):
+
""" find point on the manifold that
+
is nearest to q, same order as the nearest point.
+
"""
+
raise NotImplementedError
+
+
+
@partial(jax.jit, static_argnums=(0,))
+
def ito_drift(self, x):
+
return x@self.id_drift
+
+
+
+
[docs]
+
@partial(jax.jit, static_argnums=(0,))
+
def stratonovich_drift(self, x):
+
""" Stratonovich drift.
+
"""
+
return x@self.v0
+
+
+
# @partial(jax.jit, static_argnums=(0,))
+
+
[docs]
+
def laplace_beltrami(self, x, egradx, ehessvp):
+
p = self.shape[0]
+
tup = jnp.zeros((p, p))
+
ret = 0
+
for i in range(p):
+
for j in range(p):
+
e_ij = tup.at[i, j].set(1.)
+
ret += self.proj(x, self.inv_g_metric(
+
x, ehessvp(x, e_ij)))[i, j]
+
return ret + 2*jnp.sum(self.ito_drift(x)*egradx)
+
+
+
+
[docs]
+
def left_invariant_vector_field(self, x, v):
+
""" map from a unit vector in the trace metric
+
to a vector field with unit length in the
+
left invariant metric.
+
"""
+
return x@self._mat_apply(self._i_sqrt_g_mat, v)
+
+
+
+
[docs]
+
@partial(jax.jit, static_argnums=(0,))
+
def pseudo_transport(self, x, y, v):
+
"""the easy one
+
"""
+
return y@jla.solve(x, v)
+
+
+
+
[docs]
+
@partial(jax.jit, static_argnums=(0,))
+
def sigma_id(self, dw):
+
""" sigma, to generate the Brownian motion at the identity.
+
"""
+
return self._lie_algebra_proj(self._mat_apply(self._i_sqrt_g_mat, dw))
+
+
+
+
[docs]
+
@partial(jax.jit, static_argnums=(0,))
+
def sigma(self, x, dw):
+
""" sigma, to generate the Brownian motion.
+
"""
+
return x@self.sigma_id(dw)
+
+
+