diff --git a/tests/notebooks/LangevinGroup.ipynb b/tests/notebooks/LangevinGroup.ipynb new file mode 100644 index 0000000..eb1580b --- /dev/null +++ b/tests/notebooks/LangevinGroup.ipynb @@ -0,0 +1,649 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyP16ELfRbFYOu/Q/JtpXi0J" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "$\\newcommand{\\sigmam}{\\mathring{\\sigma}}$\n", + "$\\newcommand{\\rgrad}{\\mathsf{rgrad}}$\n", + "$\\newcommand{\\sfg}{\\mathsf{g}}$\n", + "$\\newcommand{\\cI}{\\mathcal{I}}$\n", + "$\\newcommand{\\R}{\\mathbb{R}}$\n", + "$\\newcommand{\\fR}{\\mathfrak{r}}$\n", + "# Simulating the Riemannian Langevin process to sample a distribution on a manifold with a specified density relative to the Riemannian measure on matrix Lie Groups.\n", + "\n", + "The Langevin process is a solution of the equation\n", + "$$dX_t = (\\frac{1}{2}rgrad_{\\log V}(X_t) + \\mu_B(X_t))dt +\\sigmam(X_t) dW_t\n", + "$$\n", + "where $V$ is a smooth positive function on $M$ and\n", + "$$dB_t = \\mu_B(B_t)dt +\\sigmam(B_t) dW_t\n", + "$$\n", + "is the Riemannian Brownian motion of a metric $\\sfg$. With smoothness and curvature conditions, this process converges to a distribution on the manifold with density relative to the Riemannian measure proportional to $V$.\n", + "\n", + "We test the groups SO and SE in this workbook. First, install jax_rb" + ], + "metadata": { + "id": "RVwgseAEAnFg" + } + }, + { + "cell_type": "code", + "source": [ + "pip install git+https://github.com/dnguyend/jax-rb" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Sf83bhTlBsbm", + "outputId": "dd3194cb-5e66-4a1f-e711-cc3a47c39a91" + }, + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting git+https://github.com/dnguyend/jax-rb\n", + " Cloning https://github.com/dnguyend/jax-rb to /tmp/pip-req-build-ph2b9mtd\n", + " Running command git clone --filter=blob:none --quiet https://github.com/dnguyend/jax-rb /tmp/pip-req-build-ph2b9mtd\n", + " Resolved https://github.com/dnguyend/jax-rb to commit 829c06c2301ca7671986bc311b70b19267494bf9\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: jax in /usr/local/lib/python3.10/dist-packages (from jax_rb==0.1.dev53+g829c06c) (0.4.26)\n", + "Requirement already satisfied: jaxlib in /usr/local/lib/python3.10/dist-packages (from jax_rb==0.1.dev53+g829c06c) (0.4.26+cuda12.cudnn89)\n", + "Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev53+g829c06c) (0.2.0)\n", + "Requirement already satisfied: numpy>=1.22 in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev53+g829c06c) (1.25.2)\n", + "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev53+g829c06c) (3.3.0)\n", + "Requirement already satisfied: scipy>=1.9 in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev53+g829c06c) (1.11.4)\n", + "Building wheels for collected packages: jax_rb\n", + " Building wheel for jax_rb (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for jax_rb: filename=jax_rb-0.1.dev53+g829c06c-py3-none-any.whl size=33510 sha256=220c81000e3e5e7e676ecce151f8832014c1b3cd9166c4c51108fd4bda63ea3b\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-6nyleohr/wheels/0f/76/88/65e675f8bcca47be98c588d9a787a4c1c9b0a5044517ba6490\n", + "Successfully built jax_rb\n", + "Installing collected packages: jax_rb\n", + "Successfully installed jax_rb-0.1.dev53+g829c06c\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Basic import. We also add a function to sample on a $SO(n)$ using polar retraction based on [Chikuse2003], to compare with the SDE simulation.\n", + "Thus, given a function $V$ as a density relative to the Riemannian measure, the expectation with respect to the measure defined by the density $V$ is\n", + " $\\frac{\\int fVdvol_{\\sfg}}{\\int Vdvol_{\\sfg}}$, where we will compute the two integrals by sampling using the function uniform_sample below\n", + "\n", + "*Reference*\\\n", + "Y. Chikuse, Statistics on Special Manifolds, Springer New York, NY,\n", + "New York, NY, USA, 2003." + ], + "metadata": { + "id": "pqACqfoCHoXo" + } + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "PaZRI6MIAVGw" + }, + "outputs": [], + "source": [ + "\"\"\" test riemannian langevin for SO and SE\n", + "\"\"\"\n", + "\n", + "\n", + "from functools import partial\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jax.numpy.linalg as jla\n", + "from jax import random, vmap, jit\n", + "import jax_rb.manifolds.so_left_invariant as som\n", + "import jax_rb.manifolds.se_left_invariant as sem\n", + "\n", + "from jax_rb.utils.utils import (rand_positive_definite, sym, vcat, grand)\n", + "import jax_rb.simulation.simulator as sim\n", + "import jax_rb.simulation.matrix_group_integrator as mi\n", + "\n", + "\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "\n", + "\n", + "def sqr(x):\n", + " return x@x\n", + "\n", + "\n", + "def cz(mat):\n", + " return jnp.max(jnp.abs(mat))\n", + "\n", + "def uniform_sample(key, shape, pay_off, n_samples):\n", + " \"\"\" Sample the manifold uniformly\n", + " \"\"\"\n", + " x_all, key = grand(key, (shape[0], shape[1], n_samples))\n", + "\n", + " def do_one_point(seq):\n", + " # ei, ev = jla.eigh(seq.T@seq)\n", + " # return pay_off(seq@ev@((1/jnp.sqrt(ei))[:, None]*ev.T))\n", + " u, _, vt = jla.svd(seq)\n", + " return pay_off(u[:, :shape[0]]@vt)\n", + "\n", + " s = jax.vmap(do_one_point, in_axes=2)(x_all)\n", + " return jnp.nanmean(s)\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Sampling on the special orthogonal group $SO(n)$.\n", + "Besides the three main simulation methods, we also use the drift-adjust method for the polar decomposition. As in the case of the Stiefel manifold, the polar decomposition has the second order derivative\n", + "$$\\fR^{(2)}(x, 0, v,v )= -x v^Tv = x (x^Tv)^2$$\n", + "This terms offsets the Ito drift of the RB gradient term in $\\mu_{\\fR}$ (we can see this at $x=I_n$, then translate to any point). Thus, we only have the gradient component in $\\mu_{\\fR}$.\n", + "\n", + "The density is proprtional to $e^{-\\Lambda_0 x^T\\Lambda_1 x}$ with a diagonal $\\Lambda_0$. The function to take expectation is $f(x) = (1+\\sum |\\lambda_{ij}x_{ij}|)^{\\frac{1}{2}}$.\n" + ], + "metadata": { + "id": "GOnOoYNkIfH3" + } + }, + { + "cell_type": "code", + "source": [ + "class cayley_so_retraction():\n", + " \"\"\"Cayley retraction of a matrix Lie group\n", + " this is the most general, and not efficient implementation\n", + " for each lie group, we should have a custom implementation of this\n", + " \"\"\"\n", + " def __init__(self, mnf):\n", + " self.mnf = mnf\n", + "\n", + " def retract(self, x, v):\n", + " \"\"\"rescaling :math:`x+v` to be on the manifold\n", + " \"\"\"\n", + " ixv = x.T@v\n", + " return x + x@jla.solve(jnp.eye(ixv.shape[0]) - 0.5*ixv, ixv)\n", + "\n", + " def inverse_retract(self, x, y):\n", + " u = x.T@y\n", + " n = self.mnf.shape[0]\n", + " return 2*x@jla.solve(jnp.eye(n)+u, u-jnp.eye(n))\n", + "\n", + " def drift_adjust(self, x, driver_dim):\n", + " \"\"\"return the adjustment :math:`\\\\mu_{adj}`\n", + " so that :math:`\\\\mu + \\\\mu_{adj} = \\\\mu_{\\\\mathfrak{r}}`\n", + " \"\"\"\n", + " return -0.5*jnp.sum(vmap(lambda seq:\n", + " x@sqr(self.mnf.sigma_id(seq.reshape(x.shape)))\n", + " )(jnp.eye(driver_dim)),\n", + " axis=0)\n", + "\n", + "\n", + "\n", + "def test_langevin_so():\n", + " # test Langevin on se(n) with vfunc = e^{-\\frac{1}{2}v^T\\Lambda v}\n", + " # jax.config.update('jax_default_device', jax.devices('cpu')[0])\n", + " n = 4\n", + " so_dim = n*(n-1)//2\n", + "\n", + " lbd = [0.5*jnp.diag(jnp.arange(1, n+1)), 0.5*jnp.arange(1, n**2+1).reshape(n, n)]\n", + "\n", + " def log_v(_, x):\n", + " return -jnp.trace(lbd[0]@x.T@lbd[1]@x)\n", + "\n", + " def grad_log_v(mnf, x):\n", + " return -mnf.proj(x, mnf.inv_g_metric(\n", + " x, (lbd[1]+lbd[1].T)@x@lbd[0]))\n", + "\n", + " key = random.PRNGKey(0)\n", + "\n", + " metric_mat, key = rand_positive_definite(key, so_dim, (.1, 30.))\n", + "\n", + " print(\"Doing SO\")\n", + "\n", + " # metric_mat = jnp.eye(se_dim)\n", + " so = som.SOLeftInvariant(n, metric_mat)\n", + " crtr = cayley_so_retraction(so)\n", + " x, key = so.rand_point(key)\n", + " eta, key = so.rand_vec(key, x)\n", + "\n", + " # print(jax.jvp(lambda x: log_v(so, x), (x,), (eta,))[1])\n", + " # print(so.inner(x, eta, grad_log_v(so, x)))\n", + "\n", + " # x1 = crtr.retract(x, eta)\n", + " # eta1 = crtr.inverse_retract(x, x1)\n", + " # print(cz(eta1-eta))\n", + "\n", + "\n", + " pay_offs = [None, lambda x: jnp.sqrt(1+jnp.sum(jnp.abs(x)))]\n", + "\n", + " lbd1, key = grand(key, (n**2,))\n", + " pay_offs = [None, lambda x: jnp.sqrt(1+jnp.sum(jnp.abs(lbd1*x.reshape(-1))))]\n", + "\n", + " x_0 = jnp.eye(n)\n", + "\n", + " key, sk = random.split(key)\n", + " t_final = 20.\n", + " # t_final = 1.5\n", + " n_path = 1000\n", + " n_div = 1000\n", + " d_coeff = .5\n", + "\n", + " wiener_dim = n**2\n", + "\n", + " ret_rtr = sim.simulate(x_0,\n", + " lambda x, unit_move, scale: crtr.retract(\n", + " x,\n", + " x@so.sigma_id(unit_move.reshape(x.shape))*scale**.5\n", + " + 0.5*grad_log_v(so, x)*scale),\n", + " pay_offs[0],\n", + " pay_offs[1],\n", + " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", + "\n", + " print(\"SO Cayley retract %.3f\" % jnp.nanmean(ret_rtr[0]))\n", + "\n", + "\n", + " ret_rtr1 = sim.simulate(x_0,\n", + " lambda x, unit_move, scale: mi.ito_move_with_drift(\n", + " so, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(so, x))),\n", + " pay_offs[0],\n", + " # lambda x: x[1, -1]*x[1, -1],\n", + " pay_offs[1],\n", + " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", + "\n", + " print(\"Ito Langevin %.3f\" % jnp.nanmean(ret_rtr1[0]))\n", + "\n", + " ret_rtr2 = sim.simulate(x_0,\n", + " lambda x, unit_move, scale: mi.stratonovich_move_with_drift(\n", + " so, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(so, x))),\n", + " pay_offs[0],\n", + " # lambda x: x[1, -1]*x[1, -1],\n", + " pay_offs[1],\n", + " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", + "\n", + " print(\"Stratonovich Langevin %.3f\" % jnp.nanmean(ret_rtr2[0]))\n", + "\n", + "\n", + " ret_rtr3 = sim.simulate(x_0,\n", + " lambda x, unit_move, scale: mi.geodesic_move_with_drift(\n", + " so, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(so, x))),\n", + " pay_offs[0],\n", + " # lambda x: x[1, -1]*x[1, -1],\n", + " pay_offs[1],\n", + " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", + "\n", + " print(\"Geodesic 2nd order langevin %.3f\" % jnp.nanmean(ret_rtr3[0]))\n", + "\n", + " # Using known method to sample Stiefel manifold uniformly with homogenous measure\n", + " n_samples = 1000**2\n", + " ret_denom = uniform_sample(\n", + " key, so.shape,\n", + " lambda x: jnp.exp(log_v(so, x)),\n", + " n_samples)\n", + "\n", + " ret_num = uniform_sample(\n", + " key, so.shape,\n", + " lambda x: pay_offs[1](x)*jnp.exp(log_v(so, x)),\n", + " n_samples)\n", + " print(\"SO sampling with density %.3f\" % (ret_num/ret_denom))\n", + "\n", + "test_langevin_so()\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "wYFUaqOgEoEQ", + "outputId": "ebdeb63d-10a9-4f0f-b268-22415e6c2d80" + }, + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Doing SO\n", + "SO Cayley retract 2.841\n", + "Ito Langevin 2.841\n", + "Stratonovich Langevin 2.841\n", + "Geodesic 2nd order langevin 2.841\n", + "SO sampling with density 2.800\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Sampling on $SE(n)$. The group $SE(n)$ is not compact. For a left-invariant metric, the volume form is $\\det(\\cI)^{\\frac{1}{2}}dvol_I$, where $I$ is the standard (product) measure of $SO(n)\\times \\R^n$. For the density $V$, we take the function $V(U, v) = e^{-\\frac{1}{2}v^T\\Lambda v}$ for an element $(U, v)\\in SE(n)$ for a diagonal matrix $\\Lambda$.\n", + "\n", + "For the first test we take $n=3$ and $f=\\sum |x_{ij}|$ where we identify $SE(n)$ with a subgroup of $GL(n+1)$." + ], + "metadata": { + "id": "v7MTJ0T2IXuw" + } + }, + { + "cell_type": "code", + "source": [ + "\n", + "def test_langevin_se():\n", + " # test Langevin on se(n) with vfunc = e^{-\\frac{1}{2}v^T\\Lambda v}\n", + " # jax.config.update('jax_default_device', jax.devices('cpu')[0])\n", + " n = 3\n", + " lbd = 10.*jnp.arange(1, n+1)\n", + "\n", + " @partial(jit, static_argnums=(0,))\n", + " def log_v(_, x):\n", + " return -0.5*jnp.sum(x[:-1, -1]*lbd*x[:-1, -1])\n", + "\n", + " @partial(jit, static_argnums=(0,))\n", + " def grad_log_v(mnf, x):\n", + " return mnf.proj(x, mnf.inv_g_metric(\n", + " x,\n", + " jnp.zeros_like(x).at[:-1, -1].set(-lbd*x[:-1, -1])))\n", + "\n", + " key = random.PRNGKey(0)\n", + "\n", + " se_dim = n*(n+1)//2\n", + " n1 = n+1\n", + " metric_mat, key = rand_positive_definite(key, se_dim, (.1, 30.))\n", + "\n", + " # convergent seems to be to same metric, but different rate\n", + "\n", + " # metric_mat = jnp.eye(se_dim)\n", + " # metric_mat = metric_mat.at[0, 0].set(1.)\n", + " se = sem.SELeftInvariant(n, metric_mat)\n", + " # x, key = se.rand_point(key)\n", + " # eta, key = se.rand_vec(key, x)\n", + "\n", + " # print(jax.jvp(lambda x: log_v(se, x), (x,), (eta,))[1])\n", + " # print(se.inner(x, grad_log_v(se, x), eta))\n", + "\n", + " # pay_offs = [None, lambda x: jnp.sqrt(jnp.sum(x[:-1, -1]**2))]\n", + "\n", + " # pay_offs = [None, lambda x: jnp.sum(jnp.abs(x[:-1, -1]))]\n", + "\n", + " # pay_offs = [None, lambda x: jnp.sqrt(jnp.sum(x*x))]\n", + " print(\"Test SE with n=%d expectation of sum |x|\" % (n))\n", + " pay_offs = [None, lambda x: jnp.sum(jnp.abs(x))]\n", + "\n", + " x_0 = jnp.eye(n1)\n", + " key, sk = random.split(key)\n", + " t_final = 20.\n", + " n_path = 5000\n", + " n_div = 1000\n", + " d_coeff = .5\n", + "\n", + " wiener_dim = n1**2\n", + " ret_rtr1 = sim.simulate(x_0,\n", + " lambda x, unit_move, scale: mi.ito_move_with_drift(\n", + " se, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(se, x))),\n", + " pay_offs[0],\n", + " # lambda x: x[1, -1]*x[1, -1],\n", + " pay_offs[1],\n", + " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", + "\n", + " print(\"Ito Langevin %.3f\" % jnp.nanmean(ret_rtr1[0]))\n", + "\n", + " ret_rtr2 = sim.simulate(x_0,\n", + " lambda x, unit_move, scale: mi.stratonovich_move_with_drift(\n", + " se, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(se, x))),\n", + " pay_offs[0],\n", + " # lambda x: x[1, -1]*x[1, -1],\n", + " pay_offs[1],\n", + " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", + "\n", + " print(\"Stratonovich Langevin %.3f\" % jnp.nanmean(ret_rtr2[0]))\n", + "\n", + "\n", + " ret_rtr3 = sim.simulate(x_0,\n", + " lambda x, unit_move, scale: mi.geodesic_move_with_drift(\n", + " se, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(se, x))),\n", + " pay_offs[0],\n", + " # lambda x: x[1, -1]*x[1, -1],\n", + " pay_offs[1],\n", + " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", + "\n", + " print(\"Geodesic 2nd order langevin %.3f\" % jnp.nanmean(ret_rtr3[0]))\n", + "\n", + "\n", + " def se_sample(key, shape, pay_off, n_samples):\n", + " \"\"\" Sample the manifold uniformly on the sphere\n", + " and with the\n", + " \"\"\"\n", + " x_all, key = grand(key, (shape[0]-1, shape[1], n_samples))\n", + "\n", + " def do_one_point(seq):\n", + " # ei, ev = jla.eigh(seq.T@seq)\n", + " # return pay_off(seq@ev@((1/jnp.sqrt(ei))[:, None]*ev.T))\n", + " u, _, vt = jla.svd(seq[:, :-1])\n", + " x = vcat(jnp.concatenate(\n", + " [u@vt, seq[:, -1][:, None]], axis=1),\n", + " jnp.zeros((1, shape[1])).at[0, -1].set(1.))\n", + " return pay_off(x)*jnp.exp(log_v(se, x)+0.5*jnp.sum(x[:-1, -1]**2))\n", + "\n", + " s = jax.vmap(do_one_point, in_axes=2)(x_all)\n", + " return jnp.nanmean(s)\n", + "\n", + " n_samples = 1000**2\n", + " ret_denom = se_sample(\n", + " key, se.shape,\n", + " lambda x: 1.,\n", + " n_samples)\n", + " ret_num = se_sample(\n", + " key, se.shape,\n", + " pay_offs[1],\n", + " n_samples)\n", + "\n", + " print(\"uniform sampling with density %.3f\" % (ret_num/ret_denom))\n", + "\n", + "\n", + "test_langevin_se()\n", + "\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "K6TcrRjbEqsX", + "outputId": "0071c0bd-e3a0-4974-fbc5-d324eaf87960" + }, + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Test SE with n=3 expectation of sum |x|\n", + "Ito Langevin 6.076\n", + "Stratonovich Langevin 6.076\n", + "Geodesic 2nd order langevin 6.076\n", + "uniform sampling with density 6.075\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "For the second test, we take $n=4$, with the same density, and take expectation of the function $vec(x)^TAvec(x)$, where we vectorize the first $n$ rows of $x\\in SE(n)$, identified with an element of $\\R^{(n+1)\\times(n+1)}$, for a randomly generated matrix $A$ of size $n(n+1)\\times n(n+1)$." + ], + "metadata": { + "id": "MOKBL02mfKzF" + } + }, + { + "cell_type": "code", + "source": [ + "\n", + "def test_langevin_se2():\n", + " # jax.config.update('jax_default_device', jax.devices('cpu')[0])\n", + " n = 4\n", + " se_dim = n*(n+1)//2\n", + " n1 = n+1\n", + "\n", + " lbd = 10. + jnp.arange(n)\n", + "\n", + " @partial(jit, static_argnums=(0,))\n", + " def log_v(_, x):\n", + " return -0.5*jnp.sum(x[:-1, -1]*lbd*x[:-1, -1])\n", + "\n", + " @partial(jit, static_argnums=(0,))\n", + " def grad_log_v(mnf, x):\n", + " return mnf.proj(x, mnf.inv_g_metric(\n", + " x,\n", + " jnp.zeros_like(x).at[:-1, -1].set(-lbd*x[:-1, -1])))\n", + "\n", + " key = random.PRNGKey(0)\n", + "\n", + " # metric_mat, key = rand_positive_definite(key, se_dim, (.1, 30.))\n", + " A, key = grand(key, (n*n1,n*n1))\n", + " A = sym(A@A.T)\n", + " # convergent seems to be to same metric, but different rate\n", + "\n", + " metric_mat = jnp.eye(se_dim)\n", + " # metric_mat = metric_mat.at[0, 0].set(1.)\n", + " se = sem.SELeftInvariant(n, metric_mat)\n", + " # x, key = se.rand_point(key)\n", + " # eta, key = se.rand_vec(key, x)\n", + "\n", + " # print(jax.jvp(lambda x: log_v(se, x), (x,), (eta,))[1])\n", + " # print(se.inner(x, grad_log_v(se, x), eta))\n", + " print(\"Test SE n=%d expectation of |x^TAx|^(1/2) for a positive definite matrix A\" % (n))\n", + "\n", + " pay_offs = [None, lambda x: jnp.sqrt(jnp.abs(jnp.sum(x[:-1, :].reshape(-1)*(A@x[:-1, :].reshape(-1)))))]\n", + " # pay_offs = [None, lambda x: jnp.sqrt(jnp.sum(x*x*jnp.arange(1, n1+1)[None, :]))]\n", + " # pay_offs = [None, lambda x: jnp.sum(x[0, :-1]*x[:-1, -1])]\n", + "\n", + " x_0 = jnp.eye(n1)\n", + " key, sk = random.split(key)\n", + " t_final = 20.\n", + " n_path = 5000\n", + " n_div = 1000\n", + " d_coeff = .5\n", + "\n", + " wiener_dim = n1**2\n", + "\n", + " ret_rtr1 = sim.simulate(x_0,\n", + " lambda x, unit_move, scale: mi.ito_move_with_drift(\n", + " se, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(se, x))),\n", + " pay_offs[0],\n", + " # lambda x: x[1, -1]*x[1, -1],\n", + " pay_offs[1],\n", + " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", + "\n", + " print(\"Ito Langevin %.3f\" % jnp.nanmean(ret_rtr1[0]))\n", + "\n", + "\n", + " ret_rtr2 = sim.simulate(x_0,\n", + " lambda x, unit_move, scale: mi.stratonovich_move_with_drift(\n", + " se, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(se, x))),\n", + " pay_offs[0],\n", + " # lambda x: x[1, -1]*x[1, -1],\n", + " pay_offs[1],\n", + " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", + "\n", + " print(\"Stratonovich Langevin %.3f\" % jnp.nanmean(ret_rtr2[0]))\n", + "\n", + "\n", + " ret_rtr3 = sim.simulate(x_0,\n", + " lambda x, unit_move, scale: mi.geodesic_move_with_drift(\n", + " se, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(se, x))),\n", + " pay_offs[0],\n", + " # lambda x: x[1, -1]*x[1, -1],\n", + " pay_offs[1],\n", + " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", + "\n", + " print(\"Geodesic 2nd order langevin %.3f\" % jnp.nanmean(ret_rtr3[0]))\n", + "\n", + "\n", + " def se_sample(key, shape, pay_off, n_samples):\n", + " \"\"\" Sample the manifold uniformly on the sphere\n", + " and with the\n", + " \"\"\"\n", + " x_all, key = grand(key, (shape[0]-1, shape[1], n_samples))\n", + "\n", + " def do_one_point(seq):\n", + " # ei, ev = jla.eigh(seq.T@seq)\n", + " # return pay_off(seq@ev@((1/jnp.sqrt(ei))[:, None]*ev.T))\n", + " u, _, vt = jla.svd(seq[:, :-1])\n", + " x = vcat(jnp.concatenate(\n", + " [u@vt, seq[:, -1][:, None]], axis=1),\n", + " jnp.zeros((1, shape[1])).at[0, -1].set(1.))\n", + " return pay_off(x)*jnp.exp(log_v(se, x)+0.5*jnp.sum(x[:-1, -1]**2))\n", + " #return jnp.sqrt(3+jnp.sum(x[:-1, -1]**2))*jnp.exp(log_v(se, x)+0.5*jnp.sum(x[:-1, -1]**2))\n", + "\n", + " s = jax.vmap(do_one_point, in_axes=2)(x_all)\n", + " # ret = []\n", + " # for i in range(x_all.shape[2]):\n", + " # ret.append(do_one_point(x_all[:, :, i]))\n", + " # s = jnp.array(ret)\n", + " return jnp.nanmean(s)\n", + "\n", + " n_samples = 1000**2\n", + "\n", + " ret_denom = se_sample(\n", + " key, se.shape,\n", + " lambda x: 1.,\n", + " n_samples)\n", + " \"\"\"\n", + " ret_num = se_sample(\n", + " key, se.shape,\n", + " lambda x: pay_offs[1](x),\n", + " n_path*500)\n", + " \"\"\"\n", + " ret_num = se_sample(\n", + " key, se.shape,\n", + " # lambda x: x[1, -1]*x[1, -1],\n", + " pay_offs[1],\n", + " # lambda x: pay_offs[1](x) - jnp.sqrt(3+jnp.sum(x[:-1, -1]**2)),\n", + " n_samples)\n", + "\n", + " print(\"uniform sampling with density %.3f\" % (ret_num/ret_denom))\n", + "\n", + "\n", + "test_langevin_se2()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "HtlnhgR9Esn3", + "outputId": "5aec477d-e5e2-459a-8a92-1f66adfdb122" + }, + "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Test SE n=4 expectation of |x^TAx|^(1/2) for a positive definite matrix A\n", + "Ito Langevin 9.289\n", + "Stratonovich Langevin 9.295\n", + "Geodesic 2nd order langevin 9.288\n", + "uniform sampling with density 9.303\n" + ] + } + ] + } + ] +} \ No newline at end of file