From bb12d26633118cea71f7b019e882d8b2a1374299 Mon Sep 17 00:00:00 2001 From: dnguyend Date: Fri, 26 Jul 2024 12:34:39 -0400 Subject: [PATCH] Add Langevin Stiefel method --- tests/notebooks/LangevinStiefel.ipynb | 742 ++++++++++++++++++++++++++ 1 file changed, 742 insertions(+) create mode 100644 tests/notebooks/LangevinStiefel.ipynb diff --git a/tests/notebooks/LangevinStiefel.ipynb b/tests/notebooks/LangevinStiefel.ipynb new file mode 100644 index 0000000..fd64c7b --- /dev/null +++ b/tests/notebooks/LangevinStiefel.ipynb @@ -0,0 +1,742 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyPmXvXiPZZ3MBv1sL7ITipf", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "$\\newcommand{\\sigmam}{\\mathring{\\sigma}}$\n", + "$\\newcommand{\\egrad}{\\mathsf{egrad}}$\n", + "$\\newcommand{\\rgrad}{\\mathsf{rgrad}}$\n", + "$\\newcommand{\\sfg}{\\mathsf{g}}$\n", + "$\\newcommand{\\cI}{\\mathcal{I}}$\n", + "$\\newcommand{\\R}{\\mathbb{R}}$\n", + "$\\newcommand{\\fR}{\\mathfrak{r}}$\n", + "\n", + "# Simulating Riemannian Langevin equations on Stiefel manifolds\n", + "The long-time limit of the Riemannian Langevin process allows us to sample a distribution on a manifold with a specified density relative to the Riemannian measure.\n", + "\n", + "The Langevin process is a solution of the equation\n", + "$$dX_t = (\\frac{1}{2}rgrad_{\\log V}(X_t) + \\mu_B(X_t))dt +\\sigmam(X_t) dW_t\n", + "$$\n", + "where $V$ is a smooth function positive function on $M$ and\n", + "$$dB_t = \\mu_B(B_t)dt +\\sigmam(B_t) dW_t\n", + "$$\n", + "is the Riemannian Brownian motion of a metric $\\sfg$. With smoothness and curvature conditions, this process converges to a distribution on the manifold with density relative to the Riemannian measure proportional to $V$.\n", + "\n", + "We test the von Mises-Fisher and Bingham distributions in this workbook. First, install jax_rb:" + ], + "metadata": { + "id": "bmG_t7bGESZf" + } + }, + { + "cell_type": "code", + "source": [ + "pip install git+https://github.com/dnguyend/jax-rb" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "IsEk89QmSlLj", + "outputId": "234c2dbd-4a27-406f-f49d-897c05a44a52" + }, + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting git+https://github.com/dnguyend/jax-rb\n", + " Cloning https://github.com/dnguyend/jax-rb to /tmp/pip-req-build-lo4kwtj9\n", + " Running command git clone --filter=blob:none --quiet https://github.com/dnguyend/jax-rb /tmp/pip-req-build-lo4kwtj9\n", + " Resolved https://github.com/dnguyend/jax-rb to commit 829c06c2301ca7671986bc311b70b19267494bf9\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: jax in /usr/local/lib/python3.10/dist-packages (from jax_rb==0.1.dev53+g829c06c) (0.4.26)\n", + "Requirement already satisfied: jaxlib in /usr/local/lib/python3.10/dist-packages (from jax_rb==0.1.dev53+g829c06c) (0.4.26+cuda12.cudnn89)\n", + "Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev53+g829c06c) (0.2.0)\n", + "Requirement already satisfied: numpy>=1.22 in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev53+g829c06c) (1.25.2)\n", + "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev53+g829c06c) (3.3.0)\n", + "Requirement already satisfied: scipy>=1.9 in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev53+g829c06c) (1.11.4)\n", + "Building wheels for collected packages: jax_rb\n", + " Building wheel for jax_rb (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for jax_rb: filename=jax_rb-0.1.dev53+g829c06c-py3-none-any.whl size=33510 sha256=bb66af84c0cce2991645eabc97adfd8a2e9d37e3f94cecc2577e8daf64fa696f\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-rl611tqh/wheels/0f/76/88/65e675f8bcca47be98c588d9a787a4c1c9b0a5044517ba6490\n", + "Successfully built jax_rb\n", + "Installing collected packages: jax_rb\n", + "Successfully installed jax_rb-0.1.dev53+g829c06c\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Basic imports and helper functions.\n", + "\n", + "We can sample the Stiefel manifold with the homogeneous measure uniformly by polar decomposition.\n", + "\n", + "On another note, consider the polar decomposition as a retraction. We can simulate an Ito process by applying an adjustment to the drift. Class stiefel_polar_retraction implement the polar retraction, and provide the adjustment for the drift." + ], + "metadata": { + "id": "1ITx2ORlqC2Q" + } + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "FVjba8YeD9n8" + }, + "outputs": [], + "source": [ + "\"\"\" test riemannian langevin for stiefel manifolds\n", + "\"\"\"\n", + "\n", + "from functools import partial\n", + "from time import perf_counter\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jax.numpy.linalg as jla\n", + "from jax import random, vmap, jit\n", + "\"\"\"\n", + "\"\"\"\n", + "from jax.scipy.linalg import expm\n", + "import jax_rb.manifolds.stiefel as stm\n", + "\n", + "from jax_rb.utils.utils import (sym, grand)\n", + "import jax_rb.simulation.simulator as sim\n", + "import jax_rb.simulation.global_manifold_integrator as gmi\n", + "\n", + "\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "\n", + "\n", + "def sqr(x):\n", + " return x@x\n", + "\n", + "\n", + "def cz(mat):\n", + " return jnp.max(jnp.abs(mat))\n", + "\n", + "\n", + "class stiefel_polar_retraction():\n", + " def __init__(self, mnf):\n", + " self.mnf = mnf\n", + "\n", + " def retract(self, x, v):\n", + " \"\"\"rescaling :math:`x+v` to be on the manifold\n", + " \"\"\"\n", + " u, _, vt = jla.svd(x+v, full_matrices=False)\n", + " return u@vt\n", + "\n", + " def drift_adjust(self, x):\n", + " n, d, alp1 = self.mnf.shape[0], self.mnf.shape[1], self.mnf.alpha[1]\n", + " return 0.5*(n-d+0.5*(d-1)/alp1)*x\n", + "\n", + "\n", + "def uniform_sampling(key, shape, pay_off, n_samples):\n", + " \"\"\" Sample the manifold uniformly\n", + " \"\"\"\n", + " x_all, key = grand(key, (shape[0], shape[1], n_samples))\n", + "\n", + " def do_one_point(seq):\n", + " # ei, ev = jla.eigh(seq.T@seq)\n", + " # return pay_off(seq@ev@((1/jnp.sqrt(ei))[:, None]*ev.T))\n", + " u, _, vt = jla.svd(seq, full_matrices=False)\n", + " return pay_off(u[:, :shape[0]]@vt)\n", + "\n", + " s = jax.vmap(do_one_point, in_axes=2)(x_all)\n", + " return jnp.nanmean(s)\n", + "\n", + "def gen_sym_traceless(key, n):\n", + " \"\"\" Generating a traceless symmetric matrix\n", + " \"\"\"\n", + " A, key = grand(key, (n, n))\n", + " return sym(A) - jnp.trace(A)/n*jnp.eye(n), key\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Test von Mises-Fisher\n", + "For the von Mises-Fisher the density is proportional to $V(x) = e^{\\kappa Tr (M^Tx)}$. In this case, $\\egrad_{\\log V}(x)=\\kappa M$, and $\\rgrad_{\\log V}(x) = \\kappa\\Pi(x)g^{-1}M$, this additional drift is used in the _with_drift methods in the library.\n", + "\n", + "For each test, given a function $f$, we run the 3 integration methods: Ito, Stratonovich and geodesic, which is a second order retraction of the projected drift. Long term, these integration methods are supposed to converge to the von Mises-Fisher distribution. The average simulated value $f(X_t)$ should converge to $E(f)$. We verify that they are consistent, and agree with $\\frac{\\int fVdvol}{\\int Vdvol}$, where each integral is computed by uniform sampling.\n", + "\n", + "For some cases, we test two different functions $f$. All examples show good agreement." + ], + "metadata": { + "id": "O-2b-ewlSr0r" + } + }, + { + "cell_type": "code", + "source": [ + "\n", + "def test_stiefel_langevin_von_mises_fisher(key, stf, kp, M, func, t_final, n_path, n_div, n_samples=1000**2):\n", + " # test Langevin on stiefel with vfunc = e^{-\\frac{1}{2}v^T\\Lambda v}\n", + " # jax.config.update('jax_default_device', jax.devices('cpu')[0])\n", + " print(\"Doing Stiefel von Mises Fisher (n, d)=%s alpha=%s\" % (str(stf.shape), str(stf.alpha)))\n", + "\n", + " @partial(jit, static_argnums=(0,))\n", + " def log_v(_, x):\n", + " return kp*jnp.trace(M.T@x)\n", + "\n", + " @partial(jit, static_argnums=(0,))\n", + " def grad_log_v(mnf, x):\n", + " return kp*mnf.proj(x, mnf.inv_g_metric(x, M))\n", + "\n", + " x, key = stf.rand_point(key)\n", + " eta, key = stf.rand_vec(key, x)\n", + "\n", + " # print(jax.jvp(lambda x: log_v(stf, x), (x,), (eta,))[1])\n", + " # print(stf.inner(x, grad_log_v(stf, x), eta))\n", + "\n", + " pay_offs = [None, func]\n", + "\n", + " x_0, key = stf.rand_point(key)\n", + " key, sk = random.split(key)\n", + " # t_final = 5.\n", + " # n_path = 10000\n", + " # n_div = 500\n", + " d_coeff = .5\n", + "\n", + " wiener_dim = stf.shape[0]*stf.shape[1]\n", + " # crtr = cayley_se_retraction(se)\n", + "\n", + " # rbrownian_ito_langevin_move(mnf, x, unit_move, scale, grad_log_v)\n", + " ret_rtr1 = sim.simulate(x_0,\n", + " lambda x, unit_move, scale: gmi.ito_move_with_drift(\n", + " stf, x, unit_move, scale, 0.5*grad_log_v(stf, x)),\n", + " pay_offs[0],\n", + " # lambda x: x[1, -1]*x[1, -1],\n", + " pay_offs[1],\n", + " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", + "\n", + " print(\"ito langevin %.3f\" % jnp.nanmean(ret_rtr1[0]))\n", + "\n", + " ret_rtr2 = sim.simulate(x_0,\n", + " lambda x, unit_move, scale: gmi.stratonovich_move_with_drift(\n", + " stf, x, unit_move, scale, 0.5*grad_log_v(stf, x)),\n", + " pay_offs[0],\n", + " # lambda x: x[1, -1]*x[1, -1],\n", + " pay_offs[1],\n", + " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", + "\n", + " print(\"stratonovich langevin %.3f\" % jnp.nanmean(ret_rtr2[0]))\n", + "\n", + " ret_rtr3 = sim.simulate(x_0,\n", + " lambda x, unit_move, scale: gmi.geodesic_move_with_drift(\n", + " stf, x, unit_move, scale, 0.5*grad_log_v(stf, x)),\n", + " pay_offs[0],\n", + " # lambda x: x[1, -1]*x[1, -1],\n", + " pay_offs[1],\n", + " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", + "\n", + " print(\"geodesic langevin %.3f\" % jnp.nanmean(ret_rtr3[0]))\n", + "\n", + "\n", + " ret_spl = uniform_sampling(key, stf.shape,\n", + " lambda x: pay_offs[1](x)*jnp.exp(log_v(None, x)),\n", + " n_samples)\n", + "\n", + " ret_spl_0 = uniform_sampling(key, stf.shape,\n", + " lambda x: jnp.exp(log_v(None, x)),\n", + " n_samples)\n", + "\n", + " print(\"stiefel uniform sampling with density %.3f\" % (ret_spl/ret_spl_0))\n", + " # import scipy.special as ss\n", + " # print(jnp.sqrt(2)*ss.iv(1, 1)/ss.iv(.5, 1)*ss.gamma(1.5))\n", + "\n", + "\n", + "def test_all_stiefel_von_mises_fisher():\n", + " n = 3\n", + " d = 1\n", + "\n", + " alp = jnp.array([1, 1.])\n", + " key = random.PRNGKey(0)\n", + "\n", + " stf = stm.RealStiefelAlpha((n, d), alp)\n", + "\n", + " # F, key = stf.rand_point(key)\n", + " kp = 1.\n", + " M, key = stf.rand_point(key)\n", + " test_stiefel_langevin_von_mises_fisher(\n", + " key, stf, kp, M,\n", + " lambda x: jnp.sqrt(jnp.abs(1-jnp.trace(M.T@x)**2)), t_final=5., n_path=10000, n_div=500, n_samples=1000**2)\n", + " # print(jnp.sqrt(2)*ss.iv(1, 1)/ss.iv(.5, 1)*ss.gamma(1.5))\n", + "\n", + "\n", + " n = 5\n", + " d = 3\n", + " alp = jnp.array([1, .6])\n", + " key = random.PRNGKey(0)\n", + "\n", + " stf = stm.RealStiefelAlpha((n, d), alp)\n", + "\n", + " # F, key = stf.rand_point(key)\n", + " kp = 1.2\n", + " M, key = stf.rand_point(key)\n", + "\n", + " test_stiefel_langevin_von_mises_fisher(key, stf, kp, M, lambda x: jnp.sqrt(jnp.abs(1-jnp.trace(M.T@x)**2)),\n", + " t_final=5., n_path=10000, n_div=500, n_samples=1000**2)\n", + "\n", + " test_stiefel_langevin_von_mises_fisher(key, stf, kp, M, lambda x: jnp.sum(jnp.abs(x)), t_final=5.,\n", + " n_path=10000, n_div=500,n_samples=1000**2)\n", + "\n", + " n = 5\n", + " d = 3\n", + " alp = jnp.array([1, 1.])\n", + " key = random.PRNGKey(0)\n", + "\n", + " stf = stm.RealStiefelAlpha((n, d), alp)\n", + "\n", + " # F, key = stf.rand_point(key)\n", + " kp = 1.2\n", + " M, key = stf.rand_point(key)\n", + "\n", + " test_stiefel_langevin_von_mises_fisher(key, stf, kp, M, lambda x: jnp.sqrt(jnp.abs(1-jnp.trace(M.T@x)**2)),\n", + " t_final=5., n_path=10000, n_div=500, n_samples=1000**2)\n", + "\n", + " test_stiefel_langevin_von_mises_fisher(key, stf, kp, M, lambda x: jnp.sum(jnp.abs(x)), t_final=5.,\n", + " n_path=10000, n_div=500,n_samples=1000**2)\n", + "\n", + "\n", + "test_all_stiefel_von_mises_fisher()\n", + "import scipy.special as ss\n", + "print(\"Exact value for m=3 d=1 val=%f\" % (jnp.sqrt(2)*ss.iv(1, 1.)/ss.iv(.5, 1.)*ss.gamma(1.5)))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5wUdQp2tSpdj", + "outputId": "9420b6eb-6e07-49b1-de94-f26c1a6be3ec" + }, + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Doing Stiefel von Mises Fisher (n, d)=(3, 1) alpha=[1. 1.]\n", + "ito langevin 0.756\n", + "stratonovich langevin 0.756\n", + "geodesic langevin 0.755\n", + "stiefel uniform sampling with density 0.754\n", + "Doing Stiefel von Mises Fisher (n, d)=(5, 3) alpha=[1. 0.6]\n", + "ito langevin 0.885\n", + "stratonovich langevin 0.885\n", + "geodesic langevin 0.888\n", + "stiefel uniform sampling with density 0.882\n", + "Doing Stiefel von Mises Fisher (n, d)=(5, 3) alpha=[1. 0.6]\n", + "ito langevin 5.629\n", + "stratonovich langevin 5.627\n", + "geodesic langevin 5.629\n", + "stiefel uniform sampling with density 5.624\n", + "Doing Stiefel von Mises Fisher (n, d)=(5, 3) alpha=[1. 1.]\n", + "ito langevin 0.883\n", + "stratonovich langevin 0.881\n", + "geodesic langevin 0.885\n", + "stiefel uniform sampling with density 0.882\n", + "Doing Stiefel von Mises Fisher (n, d)=(5, 3) alpha=[1. 1.]\n", + "ito langevin 5.633\n", + "stratonovich langevin 5.631\n", + "geodesic langevin 5.633\n", + "stiefel uniform sampling with density 5.624\n", + "Exact value for m=3 d=1 val=0.755402\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "A taller example. $(n, d)=(27,2)$." + ], + "metadata": { + "id": "X_eR2VGZ_LQD" + } + }, + { + "cell_type": "code", + "source": [ + "def test_tall_stiefel_von_mises_fisher():\n", + " n = 27\n", + " d = 2\n", + " alp = jnp.array([1, .6])\n", + " key = random.PRNGKey(0)\n", + "\n", + " stf = stm.RealStiefelAlpha((n, d), alp)\n", + "\n", + " # F, key = stf.rand_point(key)\n", + " kp = 1.2\n", + " M, key = stf.rand_point(key)\n", + "\n", + " test_stiefel_langevin_von_mises_fisher(key, stf, kp, M, lambda x: jnp.sqrt(jnp.abs(1-jnp.trace(M.T@x)**2)),\n", + " t_final=5., n_path=10000, n_div=500, n_samples=10000)\n", + "\n", + " test_stiefel_langevin_von_mises_fisher(key, stf, kp, M,\n", + " lambda x: jnp.sum(jnp.abs(x)),\n", + " t_final=5., n_path=10000, n_div=500, n_samples=10000)\n", + "\n", + "test_tall_stiefel_von_mises_fisher()\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "bjDB4QxZ7Trn", + "outputId": "9cd0d4e7-b951-4d29-a290-1e3eb4aa5722" + }, + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Doing Stiefel von Mises Fisher (n, d)=(27, 2) alpha=[1. 0.6]\n", + "ito langevin 0.957\n", + "stratonovich langevin 0.956\n", + "geodesic langevin 0.956\n", + "stiefel uniform sampling with density 0.957\n", + "Doing Stiefel von Mises Fisher (n, d)=(27, 2) alpha=[1. 0.6]\n", + "ito langevin 8.372\n", + "stratonovich langevin 8.372\n", + "geodesic langevin 8.370\n", + "stiefel uniform sampling with density 8.363\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Test Bingham\n", + "The density is $e^{Trx^TAx}$, thus the Riemannian gradient is $2\\Pi\\sfg^{-1}A x$, where $A$ is a traceless matrix." + ], + "metadata": { + "id": "LJlkPargS2kT" + } + }, + { + "cell_type": "code", + "source": [ + "def test_stiefel_langevin_bingham(key, stf, A, func):\n", + " # test Langevin on stiefel with vfunc = e^{-\\frac{1}{2}v^T\\Lambda v}\n", + " # jax.config.update('jax_default_device', jax.devices('cpu')[0])\n", + " @partial(jit, static_argnums=(0,))\n", + " def log_v(_, x):\n", + " return jnp.trace(x.T@A@x)\n", + "\n", + " @partial(jit, static_argnums=(0,))\n", + " def grad_log_v(mnf, x):\n", + " return mnf.proj(x, mnf.inv_g_metric(x, 2*A@x))\n", + "\n", + " print(\"Doing Bingham (n, d)=%s alpha=%s\" % (str(stf.shape), str(stf.alpha)))\n", + "\n", + " # x, key = stf.rand_point(key)\n", + " # eta, key = stf.rand_vec(key, x)\n", + "\n", + " # print(jax.jvp(lambda x: log_v(stf, x), (x,), (eta,))[1])\n", + " # print(stf.inner(x, grad_log_v(stf, x), eta))\n", + "\n", + " pay_offs = [None, func]\n", + "\n", + " x_0, key = stf.rand_point(key)\n", + " key, sk = random.split(key)\n", + " t_final = 5.\n", + " n_path = 10000\n", + " n_div = 500\n", + " d_coeff = .5\n", + "\n", + " wiener_dim = stf.shape[0]*stf.shape[1]\n", + " # crtr = cayley_se_retraction(se)\n", + "\n", + " # rbrownian_ito_langevin_move(mnf, x, unit_move, scale, grad_log_v)\n", + " ret_rtr1 = sim.simulate(x_0,\n", + " lambda x, unit_move, scale: gmi.ito_move_with_drift(\n", + " stf, x, unit_move, scale, 0.5*grad_log_v(stf, x)),\n", + " pay_offs[0],\n", + " # lambda x: x[1, -1]*x[1, -1],\n", + " pay_offs[1],\n", + " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", + "\n", + " print(\"ito langevin %.3f\" % jnp.nanmean(ret_rtr1[0]))\n", + "\n", + " ret_rtr2 = sim.simulate(x_0,\n", + " lambda x, unit_move, scale: gmi.stratonovich_move_with_drift(\n", + " stf, x, unit_move, scale, 0.5*grad_log_v(stf, x)),\n", + " pay_offs[0],\n", + " # lambda x: x[1, -1]*x[1, -1],\n", + " pay_offs[1],\n", + " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", + "\n", + " print(\"stratonovich langevin %.3f\" % jnp.nanmean(ret_rtr2[0]))\n", + "\n", + " ret_rtr3 = sim.simulate(x_0,\n", + " lambda x, unit_move, scale: gmi.geodesic_move_with_drift(\n", + " stf, x, unit_move, scale, 0.5*grad_log_v(stf, x)),\n", + " pay_offs[0],\n", + " # lambda x: x[1, -1]*x[1, -1],\n", + " pay_offs[1],\n", + " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", + "\n", + " print(\"geodesic langevin %.3f\" % jnp.nanmean(ret_rtr3[0]))\n", + "\n", + " n_samples = 1000**2\n", + " ret_spl = uniform_sampling(key, stf.shape,\n", + " lambda x: pay_offs[1](x)*jnp.exp(log_v(None, x)),\n", + " n_samples)\n", + "\n", + " ret_spl_0 = uniform_sampling(key, stf.shape,\n", + " lambda x: jnp.exp(log_v(None, x)),\n", + " n_samples)\n", + "\n", + " print(\"stiefel uniform sampling with density %.3f\" % (ret_spl/ret_spl_0))\n", + "\n", + "def test_all_bingham():\n", + " n = 3\n", + " d = 1\n", + " alp = jnp.array([1, .6])\n", + " key = random.PRNGKey(0)\n", + "\n", + " stf = stm.RealStiefelAlpha((n, d), alp)\n", + "\n", + " A, key = gen_sym_traceless(key, n)\n", + " test_stiefel_langevin_bingham(\n", + " key, stf, A,\n", + " lambda x: jnp.sum(jnp.abs(x)))\n", + " # print(jnp.sqrt(2)*ss.iv(1, 1)/ss.iv(.5, 1)*ss.gamma(1.5))\n", + "\n", + " n = 5\n", + " d = 3\n", + " alp = jnp.array([1, .6])\n", + " key = random.PRNGKey(0)\n", + "\n", + " stf = stm.RealStiefelAlpha((n, d), alp)\n", + " A, key = gen_sym_traceless(key, n)\n", + " test_stiefel_langevin_bingham(\n", + " key, stf, A,\n", + " lambda x: jnp.sum(jnp.abs(x)))\n", + "\n", + " test_stiefel_langevin_bingham(\n", + " key, stf, A,\n", + " lambda x: jnp.sum(jnp.abs(x)*(A@jnp.abs(x))))\n", + "\n", + " n = 7\n", + " d = 3\n", + " alp = jnp.array([1, .6])\n", + " key = random.PRNGKey(0)\n", + "\n", + " stf = stm.RealStiefelAlpha((n, d), alp)\n", + " A, key = gen_sym_traceless(key, n)\n", + " test_stiefel_langevin_bingham(\n", + " key, stf, A,\n", + " lambda x: jnp.sum(jnp.abs(x)))\n", + "\n", + " test_stiefel_langevin_bingham(\n", + " key, stf, A,\n", + " lambda x: jnp.sum(jnp.abs(x)*(A@jnp.abs(x))))\n", + "\n", + "\n", + "test_all_bingham()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "3FaJdgmyS1Tr", + "outputId": "3e04a36b-b0df-4945-d80f-81db26b0f81f" + }, + "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Doing Bingham (n, d)=(3, 1) alpha=[1. 0.6]\n", + "ito langevin 1.494\n", + "stratonovich langevin 1.495\n", + "geodesic langevin 1.493\n", + "stiefel uniform sampling with density 1.497\n", + "Doing Bingham (n, d)=(5, 3) alpha=[1. 0.6]\n", + "ito langevin 5.613\n", + "stratonovich langevin 5.617\n", + "geodesic langevin 5.613\n", + "stiefel uniform sampling with density 5.616\n", + "Doing Bingham (n, d)=(5, 3) alpha=[1. 0.6]\n", + "ito langevin -0.559\n", + "stratonovich langevin -0.571\n", + "geodesic langevin -0.555\n", + "stiefel uniform sampling with density -0.565\n", + "Doing Bingham (n, d)=(7, 3) alpha=[1. 0.6]\n", + "ito langevin 6.551\n", + "stratonovich langevin 6.551\n", + "geodesic langevin 6.550\n", + "stiefel uniform sampling with density 6.553\n", + "Doing Bingham (n, d)=(7, 3) alpha=[1. 0.6]\n", + "ito langevin -0.563\n", + "stratonovich langevin -0.567\n", + "geodesic langevin -0.548\n", + "stiefel uniform sampling with density -0.579\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "We also run the drift-adjust method for the polar decomposition. The second order expansion of the Polar decompostion is given as\n", + "$$\\fR(x, hv) = x+ hv-\\frac{h^2}{2}xv^{T}v + O(h^3)\n", + "$$\n", + "This is because if $x + hv = U\\Sigma$ is a polar deconposition then since $x^Tv$ is antisymmetric,\n", + "$$(x + hv)^T(x + hv)=\n", + "I_p + h^2v^Tv = \\Sigma^2,$$ or\n", + "$$U = (x + hv)(I_p + h^2v^Tv)^{-\\frac{1}{2}} = x + hv -\\frac{h^2}{2}xv^Tv+O(h^3).$$\n", + "We can check in this case, the second order adjustment\n", + "$$-\\frac{1}{2}\\sum_{ij}\\fR^{(2)}(x,0, \\Pi\\sfg^{-1}E_{ij}, \\Pi\\sfg^{-1}E_{ij}) $$\n", + "cancels the Riemannian-Brown Ito drift, leaving just the density contribution.\n", + "Results are below, showing consistency of the methods." + ], + "metadata": { + "id": "hx8PkACZmIzM" + } + }, + { + "cell_type": "code", + "source": [ + "def drift_adjust_verify(self, x, sigma, wiener_dim):\n", + " \"\"\"return the adjustment :math:`\\\\mu_{adj}`\n", + " so that :math:`\\\\mu + \\\\mu_{adj} = \\\\mu_{\\\\mathfrak{r}}`\n", + " \"\"\"\n", + " def sqt(a):\n", + " return a.T@a\n", + "\n", + " return -0.5*x@jnp.sum(vmap(lambda seq:\n", + " -sqt(self.proj(x, sigma(x, seq.reshape(x.shape)))))(jnp.eye(wiener_dim)),\n", + " axis=0)\n", + "\n", + "\n", + "def test_polar_retract_adjust():\n", + " n = 7\n", + " d = 3\n", + " alp = jnp.array([1, .6])\n", + " key = random.PRNGKey(0)\n", + " stf = stm.RealStiefelAlpha((n, d), alp)\n", + " print(\"Doing Stiefel Polar retract for Bingham (n, d)=%s alpha=%s\" % (str(stf.shape), str(stf.alpha)))\n", + " @partial(jit, static_argnums=(0,))\n", + " def log_v(_, x):\n", + " return jnp.trace(x.T@A@x)\n", + "\n", + " @partial(jit, static_argnums=(0,))\n", + " def grad_log_v(mnf, x):\n", + " return mnf.proj(x, mnf.inv_g_metric(x, 2*A@x))\n", + "\n", + " x, key = stf.rand_point(key)\n", + "\n", + " # mu2 = -0.5*(n-d+0.5*(d-1)/alp[1])*x\n", + " prtr = stiefel_polar_retraction(stf)\n", + "\n", + " mu1 = drift_adjust_verify(stf, x, stf.sigma, n*d)\n", + " mu2 = prtr.drift_adjust(x)\n", + " # print(\"compare drift adjust %s\" % str(mu2 - mu1))\n", + " print(\"compare drift adjust %s\" % str(mu1 + stf.ito_drift(x)))\n", + "\n", + " A, key = gen_sym_traceless(key, n)\n", + "\n", + " x_0, key = stf.rand_point(key)\n", + " pay_offs = [None, lambda x: jnp.sum(jnp.abs(x))]\n", + "\n", + " key, sk = random.split(key)\n", + " t_final = 5.\n", + " n_path = 10000\n", + " n_div = 500\n", + " d_coeff = .5\n", + "\n", + " wiener_dim = stf.shape[0]*stf.shape[1]\n", + "\n", + " test_stiefel_langevin_bingham(key, stf, A, pay_offs[1])\n", + "\n", + " @jax.jit\n", + " def polar_adj(x, unit_move, scale):\n", + " return prtr.retract(x, stf.proj(x, stf.sigma(x, unit_move.reshape(x.shape)*scale**.5\n", + " + scale*(\n", + " # prtr.drift_adjust(x)\n", + " # + stf.ito_drift(x)\n", + " + 0.5*grad_log_v(stf, x)))))\n", + " ret_rtr = sim.simulate(x_0,\n", + " polar_adj,\n", + " pay_offs[0],\n", + " pay_offs[1],\n", + " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", + "\n", + " print(\"Polar adjust %.3f\" % jnp.nanmean(ret_rtr[0]))\n", + "\n", + "\n", + "test_polar_retract_adjust()\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "23eloFGjS6hj", + "outputId": "62778763-e6d0-4e0b-c3db-51583709bfe1" + }, + "execution_count": 6, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Doing Stiefel Polar retract for Bingham (n, d)=(7, 3) alpha=[1. 0.6]\n", + "compare drift adjust [[ 4.44089210e-16 -3.33066907e-16 -5.55111512e-17]\n", + " [ 0.00000000e+00 1.66533454e-16 2.22044605e-16]\n", + " [-5.55111512e-17 4.44089210e-16 2.22044605e-16]\n", + " [ 1.11022302e-16 -3.33066907e-16 0.00000000e+00]\n", + " [ 0.00000000e+00 -6.66133815e-16 -2.22044605e-16]\n", + " [ 0.00000000e+00 5.55111512e-17 -1.11022302e-16]\n", + " [-2.22044605e-16 -1.11022302e-15 1.11022302e-16]]\n", + "Doing Bingham (n, d)=(7, 3) alpha=[1. 0.6]\n", + "ito langevin 6.562\n", + "stratonovich langevin 6.564\n", + "geodesic langevin 6.563\n", + "stiefel uniform sampling with density 6.564\n", + "Polar adjust 6.557\n" + ] + } + ] + } + ] +} \ No newline at end of file