diff --git a/tests/notebooks/TestExpmAndCayleyIntegrator.ipynb b/tests/notebooks/TestExpmAndCayleyIntegrator.ipynb index 95614bb..ad969f4 100644 --- a/tests/notebooks/TestExpmAndCayleyIntegrator.ipynb +++ b/tests/notebooks/TestExpmAndCayleyIntegrator.ipynb @@ -4,8 +4,7 @@ "metadata": { "colab": { "provenance": [], - "authorship_tag": "ABX9TyOqYilyIAPowpS6WwTskmcE", - "include_colab_link": true + "authorship_tag": "ABX9TyM6M9Arew2I3GJUj0pDfjVV" }, "kernelspec": { "name": "python3", @@ -16,16 +15,6 @@ } }, "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, { "cell_type": "markdown", "source": [ @@ -48,34 +37,34 @@ "base_uri": "https://localhost:8080/" }, "id": "wpJ5PwCM-CIa", - "outputId": "4f4f4478-905e-4535-e5c5-135b56ab8fb5" + "outputId": "de4b8528-7bfc-4c8e-b071-dbedcbbf47c4" }, - "execution_count": null, + "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting git+https://github.com/dnguyend/jax-rb\n", - " Cloning https://github.com/dnguyend/jax-rb to /tmp/pip-req-build-oa2a9hi1\n", - " Running command git clone --filter=blob:none --quiet https://github.com/dnguyend/jax-rb /tmp/pip-req-build-oa2a9hi1\n", - " Resolved https://github.com/dnguyend/jax-rb to commit 20efd03c04d80b3438f32dcbf48cd917036675b4\n", + " Cloning https://github.com/dnguyend/jax-rb to /tmp/pip-req-build-lsy3c8_z\n", + " Running command git clone --filter=blob:none --quiet https://github.com/dnguyend/jax-rb /tmp/pip-req-build-lsy3c8_z\n", + " Resolved https://github.com/dnguyend/jax-rb to commit 581cc9d9b79fd59e4e49f03ca352f9b35c65ae65\n", " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "Requirement already satisfied: jax in /usr/local/lib/python3.10/dist-packages (from jax_rb==0.1.dev50+g20efd03) (0.4.26)\n", - "Requirement already satisfied: jaxlib in /usr/local/lib/python3.10/dist-packages (from jax_rb==0.1.dev50+g20efd03) (0.4.26+cuda12.cudnn89)\n", - "Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev50+g20efd03) (0.2.0)\n", - "Requirement already satisfied: numpy>=1.22 in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev50+g20efd03) (1.25.2)\n", - "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev50+g20efd03) (3.3.0)\n", - "Requirement already satisfied: scipy>=1.9 in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev50+g20efd03) (1.11.4)\n", + "Requirement already satisfied: jax in /usr/local/lib/python3.10/dist-packages (from jax_rb==0.1.dev57+g581cc9d) (0.4.26)\n", + "Requirement already satisfied: jaxlib in /usr/local/lib/python3.10/dist-packages (from jax_rb==0.1.dev57+g581cc9d) (0.4.26+cuda12.cudnn89)\n", + "Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev57+g581cc9d) (0.2.0)\n", + "Requirement already satisfied: numpy>=1.22 in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev57+g581cc9d) (1.25.2)\n", + "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev57+g581cc9d) (3.3.0)\n", + "Requirement already satisfied: scipy>=1.9 in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev57+g581cc9d) (1.13.1)\n", "Building wheels for collected packages: jax_rb\n", " Building wheel for jax_rb (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for jax_rb: filename=jax_rb-0.1.dev50+g20efd03-py3-none-any.whl size=33135 sha256=f24800b7a2d206c9e978d3abd86bf29f7eca9b81150400954a41b4ac54b236ec\n", - " Stored in directory: /tmp/pip-ephem-wheel-cache-lgolyt48/wheels/0f/76/88/65e675f8bcca47be98c588d9a787a4c1c9b0a5044517ba6490\n", + " Created wheel for jax_rb: filename=jax_rb-0.1.dev57+g581cc9d-py3-none-any.whl size=33706 sha256=c4c979ea7f80ff92e40bcf9dea861d559273b4ea6131973c353c83213783be2a\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-tsmvnpug/wheels/0f/76/88/65e675f8bcca47be98c588d9a787a4c1c9b0a5044517ba6490\n", "Successfully built jax_rb\n", "Installing collected packages: jax_rb\n", - "Successfully installed jax_rb-0.1.dev50+g20efd03\n" + "Successfully installed jax_rb-0.1.dev57+g581cc9d\n" ] } ] @@ -101,7 +90,7 @@ "For a random move $\\Delta_W \\sim N(0, h^{\\frac{1}{2}}I_{\\R^{N\\times N}})$, the Euler-Maruyama exponential step will be\n", "$$X_{i+1} = X_i\\expm(X_i^{-1}(h\\mu_r(X_i)+ \\sigmam(X)\\Delta_W ))\n", "$$\n", - "It is a [remarkable fact](https://en.wikipedia.org/wiki/Pad%C3%A9_approximant#cite_note-wolfram-alpha-pade-exp-11) that the diagonal Pade approximator of $e^x$ is a rational function of the form $\\frac{p(x)}{p(-x)}$, with the first order approximation corresponds to $p(x) = 1+\\frac{x}{2}$.\n", + "It is a [remarkable fact](https://en.wikipedia.org/wiki/Pad%C3%A9_approximant#cite_note-wolfram-alpha-pade-exp-11) that the diagonal Pade approximator of $e^x$ is a rational function of the form $\\frac{p(x)}{p(-x)}$, with the first approximation corresponds to $p(x) = 1+\\frac{x}{2}$.\n", "\n", "For the group SO(N), or (more generally, for quadratic Lie group [Celledoni and Iserle]), for $a\\in \\so(N)$, we have $p(-a)^{-1}p(a)$ is in $SO(N)$ for all analytic $p$ with real coefficients. With $p(x) = 1+\\frac{x}{2}$, we have the Cayley retraction\n", "\n", @@ -111,7 +100,7 @@ "$$X_{i+1} = X_i(I - \\frac{1}{2}X_i^{-1}(h\\mu_r(X_i)+ \\sigmam(X)\\Delta_W ))^{-1}\n", "(I + \\frac{1}{2}X_i^{-1}(h\\mu_r(X_i)+ \\sigmam(X)\\Delta_W ))\n", "$$\n", - "We will check that these steps give the same simulation results as the geodesic, Ito and Stratonovich integrator in the paper.\n", + "We will check that these steps give the same simulation results as the geodesic, Ito and Stratonovich integrator in the paper. Beyond SO(n), when the group is not compact, the error growth is difficult to control for long term simulations.\n", "\n", "\n", "### References\n", @@ -132,34 +121,34 @@ "base_uri": "https://localhost:8080/" }, "id": "R50_I6ldmHMn", - "outputId": "e64532f6-84ac-4d9b-dcec-2a06f1aff4bf" + "outputId": "8222b66f-56f0-4d60-b207-7020c8dbe934" }, - "execution_count": null, + "execution_count": 2, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting git+https://github.com/dnguyend/jax-rb\n", - " Cloning https://github.com/dnguyend/jax-rb to /tmp/pip-req-build-_uzzt7oc\n", - " Running command git clone --filter=blob:none --quiet https://github.com/dnguyend/jax-rb /tmp/pip-req-build-_uzzt7oc\n", - " Resolved https://github.com/dnguyend/jax-rb to commit 20efd03c04d80b3438f32dcbf48cd917036675b4\n", + " Cloning https://github.com/dnguyend/jax-rb to /tmp/pip-req-build-cvhpd3gb\n", + " Running command git clone --filter=blob:none --quiet https://github.com/dnguyend/jax-rb /tmp/pip-req-build-cvhpd3gb\n", + " Resolved https://github.com/dnguyend/jax-rb to commit 581cc9d9b79fd59e4e49f03ca352f9b35c65ae65\n", " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "Requirement already satisfied: jax in /usr/local/lib/python3.10/dist-packages (from jax_rb==0.1.dev50+g20efd03) (0.4.26)\n", - "Requirement already satisfied: jaxlib in /usr/local/lib/python3.10/dist-packages (from jax_rb==0.1.dev50+g20efd03) (0.4.26+cuda12.cudnn89)\n", - "Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev50+g20efd03) (0.2.0)\n", - "Requirement already satisfied: numpy>=1.22 in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev50+g20efd03) (1.25.2)\n", - "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev50+g20efd03) (3.3.0)\n", - "Requirement already satisfied: scipy>=1.9 in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev50+g20efd03) (1.11.4)\n" + "Requirement already satisfied: jax in /usr/local/lib/python3.10/dist-packages (from jax_rb==0.1.dev57+g581cc9d) (0.4.26)\n", + "Requirement already satisfied: jaxlib in /usr/local/lib/python3.10/dist-packages (from jax_rb==0.1.dev57+g581cc9d) (0.4.26+cuda12.cudnn89)\n", + "Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev57+g581cc9d) (0.2.0)\n", + "Requirement already satisfied: numpy>=1.22 in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev57+g581cc9d) (1.25.2)\n", + "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev57+g581cc9d) (3.3.0)\n", + "Requirement already satisfied: scipy>=1.9 in /usr/local/lib/python3.10/dist-packages (from jax->jax_rb==0.1.dev57+g581cc9d) (1.13.1)\n" ] } ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "id": "QYVLILrQRLLO" }, @@ -176,8 +165,6 @@ "from jax import random, vmap, jit\n", "from jax.scipy.linalg import expm\n", "import jax_rb.manifolds.so_left_invariant as som\n", - "import jax_rb.manifolds.se_left_invariant as sem\n", - "import jax_rb.manifolds.affine_left_invariant as afm\n", "\n", "from jax_rb.utils.utils import (rand_positive_definite, sym, vcat)\n", "import jax_rb.simulation.simulator as sim\n", @@ -228,13 +215,13 @@ "metadata": { "id": "mIbxYgIAvuRx" }, - "execution_count": null, + "execution_count": 4, "outputs": [] }, { "cell_type": "markdown", "source": [ - "A few classes for retraction on groups and specialized implementations for $SO(n)$ and $SE(n)$. Here, we are in the case of Riemannian Brownian." + "A few classes for retraction on groups and specialized implementations for $SO(n)$. Here, we are in the case of Riemannian Brownian." ], "metadata": { "id": "PzJEj3XSnDgW" @@ -288,40 +275,13 @@ " )(jnp.eye(driver_dim)),\n", " axis=0)\n", "\n", - "class cayley_se_retraction():\n", - " \"\"\"Cayley retraction of a matrix Lie group\n", - " this is the most general, and not efficient implementation\n", - " for each lie group, we should have a custom implementation of this\n", - " \"\"\"\n", - " def __init__(self, mnf):\n", - " self.mnf = mnf\n", - "\n", - " def retract(self, x, v):\n", - " \"\"\"rescaling :math:`x+v` to be on the manifold\n", - " \"\"\"\n", - " n = x.shape[0] - 1\n", - " ixva = x[:-1, :-1].T@v[:-1, :-1]\n", - " return vcat(jnp.concatenate([x[:-1, :-1] + x[:-1, :-1]@jla.solve(jnp.eye(n)-0.5*ixva, ixva),\n", - " jla.solve(jnp.eye(n)-0.5*ixva, v[:-1, n:])], axis=1),\n", - " jnp.zeros(x.shape[0]).at[-1].set(1.).reshape(1, -1))\n", - " # x + x@jla.solve(jnp.eye(ixv.shape[0]) - 0.5*ixv, ixv)\n", - "\n", - " def drift_adjust(self, sigma, x, t, driver_dim):\n", - " \"\"\"return the adjustment :math:`\\\\mu_{adj}`\n", - " so that :math:`\\\\mu + \\\\mu_{adj} = \\\\mu_{\\\\mathfrak{r}}`\n", - " \"\"\"\n", - " return -0.5*jnp.sum(vmap(lambda seq:\n", - " x@sqr(self.mnf.sigma_id(seq.reshape(x.shape)))\n", - " )(jnp.eye(driver_dim)),\n", - " axis=0)\n", - "\n", "def sqr(a):\n", - " return a@a\n" + " return a@a" ], "metadata": { "id": "7A5PZbZ1msZg" }, - "execution_count": null, + "execution_count": 5, "outputs": [] }, { @@ -440,365 +400,45 @@ "base_uri": "https://localhost:8080/" }, "id": "Hp0BDZJowmDX", - "outputId": "12419e30-f42a-495a-9039-45d7f839567a" + "outputId": "12afeaf4-60f7-4f7d-d041-d7a3eb38925b" }, - "execution_count": null, + "execution_count": 6, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "test sum -gamma - ito drift=[[-3.26128013e-16 2.08166817e-16 1.66533454e-16 8.32667268e-17\n", - " 6.24500451e-17]\n", - " [-7.97972799e-17 -8.32667268e-17 -1.11022302e-16 -1.80411242e-16\n", - " -3.60822483e-16]\n", - " [ 3.12250226e-17 -3.60822483e-16 6.93889390e-17 -2.77555756e-17\n", - " -2.22044605e-16]\n", - " [-8.32667268e-17 -2.08166817e-16 1.24900090e-16 -2.77555756e-17\n", - " -2.49800181e-16]\n", - " [ 0.00000000e+00 -5.55111512e-17 -1.52655666e-16 4.85722573e-17\n", - " 2.08166817e-16]]\n", - "test adjusted ito is tangent=[[-2.37661530e-16 -2.22460559e-17 -1.90646870e-16 -7.17538281e-17\n", - " 1.26459054e-16]\n", - " [-2.22460559e-17 3.04724698e-16 3.08194244e-17 1.62580828e-17\n", - " 1.29192211e-16]\n", - " [-1.90646870e-16 3.08194244e-17 5.47579078e-17 1.45219969e-16\n", - " 6.05938757e-17]\n", - " [-7.17538281e-17 1.62580828e-17 1.45219969e-16 -1.71952758e-17\n", - " 5.51827559e-17]\n", - " [ 1.26459054e-16 1.29192211e-16 6.05938757e-17 5.51827559e-17\n", - " 4.28030441e-16]]\n", - "Time rtr 73.857657\n", - "Time crtr 20.281238\n", - "geo second order = 6.056695178129371\n", - "Ito = 6.055929373235796\n", - "Stratonovich = 6.054953057202997\n", + "test sum -gamma - ito drift=[[ 3.46944695e-16 -6.93889390e-17 -5.55111512e-17 2.77555756e-17\n", + " -1.04083409e-17]\n", + " [-5.72458747e-17 3.60822483e-16 0.00000000e+00 2.77555756e-17\n", + " 2.22044605e-16]\n", + " [-5.20417043e-17 2.22044605e-16 -5.55111512e-17 9.71445147e-17\n", + " 1.94289029e-16]\n", + " [ 6.93889390e-17 -1.38777878e-17 -4.16333634e-17 1.94289029e-16\n", + " 1.11022302e-16]\n", + " [-5.55111512e-17 2.77555756e-17 1.66533454e-16 4.16333634e-17\n", + " -3.46944695e-17]]\n", + "test adjusted ito is tangent=[[-4.77257607e-17 6.50674636e-17 1.89911247e-16 1.44996565e-16\n", + " -2.18644708e-17]\n", + " [ 6.50674636e-17 -3.20764925e-17 8.53113921e-18 -1.27781329e-16\n", + " -1.46461781e-16]\n", + " [ 1.89911247e-16 8.53113921e-18 -1.96166997e-16 -1.96229575e-17\n", + " 5.03222025e-17]\n", + " [ 1.44996565e-16 -1.27781329e-16 -1.96229575e-17 -5.22337541e-17\n", + " -1.14971926e-16]\n", + " [-2.18644708e-17 -1.46461781e-16 5.03222025e-17 -1.14971926e-16\n", + " -3.93447870e-16]]\n", + "Time rtr 70.728604\n", + "Time crtr 14.665379\n", + "geo second order = 6.056695178129377\n", + "Ito = 6.055929373235797\n", + "Stratonovich = 6.054953057202998\n", "Retractive = 6.055528965229489\n", "expm_so_Retractive = 6.055528965229489\n", "Cayley Retractive = 6.05581697657123\n" ] } ] - }, - { - "cell_type": "markdown", - "source": [ - "Test SE. The integrator for SE" - ], - "metadata": { - "id": "KSyq6UHQuxYJ" - } - }, - { - "cell_type": "code", - "source": [ - "def test_expm_integrator_se():\n", - " n = 3\n", - " key = random.PRNGKey(0)\n", - " se_dim = n*(n+1)//2\n", - " metric_mat, key = rand_positive_definite(key, se_dim, (.1, 30.))\n", - " mnf = sem.SELeftInvariant(n, metric_mat)\n", - " x, key = mnf.rand_point(key)\n", - " n1 = n+1\n", - "\n", - " gsum = jnp.zeros((n1, n1))\n", - " hsum = jnp.zeros((n1, n1))\n", - " for i in range(n1**2):\n", - " nsg = mnf.proj(x, mnf.sigma(x, jnp.zeros(n1**2).at[i].set(1.).reshape(n1, n1)))\n", - " hsum += x@sqr(jla.solve(x, nsg))\n", - " gsum += - mnf.gamma(x, nsg, nsg)\n", - " # print(jnp.sum(mnf.grad_c(x)*(hsum-gsum)))\n", - "\n", - " print(f\"test sum -gamma - ito drift={0.5*gsum - mnf.ito_drift(x)}\")\n", - " print(f\"test adjusted ito is tangent={sym(x.T@(-0.5*hsum+mnf.ito_drift(x)))}\")\n", - "\n", - " # now test the equation.\n", - " # test Brownian motion\n", - "\n", - " def new_sigma(x, _, dw):\n", - " return mnf.proj(x, mnf.sigma(x, dw))\n", - "\n", - " def mu(x, _):\n", - " return mnf.ito_drift(x)\n", - "\n", - " pay_offs = [lambda x, t: t*jnp.maximum(x[0, 0]-.5, 0),\n", - " lambda x: x[0, 0]**2]\n", - "\n", - " key, sk = random.split(key)\n", - " t_final = 1.\n", - " n_path = 1000\n", - " n_div = 1000\n", - " d_coeff = .5\n", - " wiener_dim = n1**2\n", - " x_0 = jnp.eye(n1)\n", - "\n", - " ret_geo = sim.simulate(x_0,\n", - " lambda x, unit_move, scale: mi.geodesic_move(\n", - " mnf, x, unit_move, scale),\n", - " pay_offs[0],\n", - " pay_offs[1],\n", - " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", - "\n", - " ret_ito = sim.simulate(x_0,\n", - " lambda x, unit_move, scale: mi.rbrownian_ito_move(\n", - " mnf, x, unit_move, scale),\n", - " pay_offs[0],\n", - " pay_offs[1],\n", - " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", - "\n", - " ret_str = sim.simulate(x_0,\n", - " lambda x, unit_move, scale: mi.rbrownian_stratonovich_move(\n", - " mnf, x, unit_move, scale),\n", - " pay_offs[0],\n", - " pay_offs[1],\n", - " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", - " rtr = expm_retraction(mnf)\n", - " ret_rtr = sim.simulate(x_0,\n", - " lambda x, unit_move, scale: matrix_retractive_move(\n", - " rtr, x, 1., unit_move, scale, new_sigma, mu),\n", - " pay_offs[0],\n", - " pay_offs[1],\n", - " [sk, t_final, 5, 5, d_coeff, wiener_dim])\n", - "\n", - " t0 = perf_counter()\n", - " ret_rtr = sim.simulate(x_0,\n", - " lambda x, unit_move, scale: matrix_retractive_move(\n", - " rtr, x, 1., unit_move, scale, new_sigma, mu),\n", - " pay_offs[0],\n", - " pay_offs[1],\n", - " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", - " t1 = perf_counter()\n", - " print('Time rtr %f' % (t1-t0))\n", - "\n", - " crtr = cayley_se_retraction(mnf)\n", - " t4 = perf_counter()\n", - " ret_crtr = sim.simulate(x_0,\n", - " lambda x, unit_move, scale: matrix_retractive_move(\n", - " crtr, x, 1., unit_move, scale, new_sigma, mu),\n", - " pay_offs[0],\n", - " pay_offs[1],\n", - " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", - " t5 = perf_counter()\n", - " print('Time crtr %f' % (t5-t4))\n", - "\n", - " print(f\"geo second order = {jnp.nanmean(ret_geo[0])}\")\n", - " print(f\"Ito = {jnp.nanmean(ret_ito[0])}\")\n", - " print(f\"Stratonovich = {jnp.nanmean(ret_str[0])}\")\n", - " print(f\"Retractive = {jnp.nanmean(ret_rtr[0])}\")\n", - " print(f\"Cayley Retractive = {jnp.nanmean(ret_crtr[0])}\")\n", - "test_expm_integrator_se()\n" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "TlZkYHY4uvgB", - "outputId": "0c950e9e-e80a-4a20-addc-adaf2d280593" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "test sum -gamma - ito drift=[[ 1.17961196e-16 9.02056208e-17 7.63278329e-17 2.77555756e-17]\n", - " [-5.55111512e-17 -4.94396191e-17 -6.93889390e-17 -3.64291930e-17]\n", - " [ 9.02056208e-17 7.28583860e-17 -1.38777878e-16 2.77555756e-17]\n", - " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", - "test adjusted ito is tangent=[[ 1.28024894e-16 9.54199437e-17 -2.23648658e-17 -2.70499180e-17]\n", - " [ 9.54199437e-17 5.49101560e-17 2.59732528e-17 -1.82677302e-17]\n", - " [-2.23648658e-17 2.59732528e-17 1.49963480e-16 -7.62769047e-18]\n", - " [-2.70499180e-17 -1.82677302e-17 -7.62769047e-18 -1.42003740e-16]]\n", - "Time rtr 42.457243\n", - "Time crtr 14.400974\n", - "geo second order = 1.1363965505646139\n", - "Ito = 1.136390973605797\n", - "Stratonovich = 1.1363671560250506\n", - "Retractive = 1.1363802045207878\n", - "Cayley Retractive = 1.1363867289743548\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "# A simple approximation of expm\n", - " The two-terms Taylor series. For most groups, this does not work, but for the affine group and the generalized linear group, this works." - ], - "metadata": { - "id": "-F3_McSfN4lu" - } - }, - { - "cell_type": "code", - "source": [ - "class expm_apprx_retraction():\n", - " \"\"\"the a retractive approximation of expm. This is simply a Taylor expansion\n", - " it works for affine group and GL(n), but the second Taylor expansion\n", - " in general is not a retraction. The other type is pade\n", - " \"\"\"\n", - " def __init__(self, mnf):\n", - " self.mnf = mnf\n", - "\n", - " def retract(self, x, v):\n", - " \"\"\"rescaling :math:`x+v` to be on the manifold\n", - " \"\"\"\n", - " return x + v + 0.5*x@sqr(jla.solve(x, v))\n", - "\n", - " def drift_adjust(self, _, x, t, driver_dim):\n", - " \"\"\"return the adjustment :math:`\\\\mu_{adj}`\n", - " so that :math:`\\\\mu + \\\\mu_{adj} = \\\\mu_{\\\\mathfrak{r}}`\n", - " \"\"\"\n", - "\n", - " return -0.5*jnp.sum(vmap(lambda seq:\n", - " x@sqr(self.mnf.sigma_id(seq.reshape(x.shape))))(jnp.eye(driver_dim)),\n", - " axis=0)\n" - ], - "metadata": { - "id": "OJSr4lXWN3Sf" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Again, test the adjusted ito is tangent, then show the expm and the two terms taylor series simulations give the same result as the other 3 simulations" - ], - "metadata": { - "id": "UyP3uTdkPrgp" - } - }, - { - "cell_type": "code", - "source": [ - "def test_expm_integrator_affine():\n", - " n = 3\n", - " aff_dim = n*(n+1)\n", - " n1 = n + 1\n", - "\n", - " key = random.PRNGKey(0)\n", - " metric_mat, key = rand_positive_definite(key, aff_dim, (.1, 10.))\n", - " mnf = afm.AffineLeftInvariant(n, metric_mat)\n", - "\n", - " x, key = mnf.rand_point(key)\n", - "\n", - " gsum = jnp.zeros((n1, n1))\n", - " hsum = jnp.zeros((n1, n1))\n", - " for i in range(n1**2):\n", - " nsg = mnf.proj(x, mnf.sigma(x, jnp.zeros(n1**2).at[i].set(1.).reshape(n1, n1)))\n", - " hsum += x@sqr(jla.solve(x, nsg))\n", - " gsum += - mnf.gamma(x, nsg, nsg)\n", - " # print(jnp.sum(mnf.grad_c(x)*(hsum-gsum)))\n", - "\n", - " print(f\"test sum -gamma - ito drift={0.5*gsum - mnf.ito_drift(x)}\")\n", - " print(f\"test adjusted ito is tangent={jla.solve(x, (-0.5*hsum+mnf.ito_drift(x)))}\")\n", - "\n", - " # now test the equation.\n", - " # test Brownian motion\n", - "\n", - " def new_sigma(x, _, dw):\n", - " return mnf.proj(x, mnf.sigma(x, dw))\n", - "\n", - " def mu(x, _):\n", - " return mnf.ito_drift(x)\n", - "\n", - " pay_offs = [lambda x, t: t*jnp.maximum(x[0, 0]-.5, 0),\n", - " lambda x: (1+jnp.abs(x[0, 0]))**(-.5)\n", - " ]\n", - "\n", - "\n", - " key, sk = random.split(key)\n", - " t_final = 1.\n", - " n_path = 1000\n", - " n_div = 200\n", - " d_coeff = .5\n", - " wiener_dim = n1**2\n", - " x_0 = jnp.eye(n1)\n", - "\n", - " ret_geo = sim.simulate(x_0,\n", - " lambda x, unit_move, scale: mi.geodesic_move(\n", - " mnf, x, unit_move, scale),\n", - " pay_offs[0],\n", - " pay_offs[1],\n", - " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", - "\n", - " ret_ito = sim.simulate(x_0,\n", - " lambda x, unit_move, scale: mi.rbrownian_ito_move(\n", - " mnf, x, unit_move, scale),\n", - " pay_offs[0],\n", - " pay_offs[1],\n", - " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", - "\n", - " ret_str = sim.simulate(x_0,\n", - " lambda x, unit_move, scale: mi.rbrownian_stratonovich_move(\n", - " mnf, x, unit_move, scale),\n", - " pay_offs[0],\n", - " pay_offs[1],\n", - " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", - "\n", - " rtr = expm_retraction(mnf)\n", - " t0 = perf_counter()\n", - " ret_rtr = sim.simulate(x_0,\n", - " lambda x, unit_move, scale: matrix_retractive_move(\n", - " rtr, x, None, unit_move, scale, new_sigma, mu),\n", - " pay_offs[0],\n", - " pay_offs[1],\n", - " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", - " t1 = perf_counter()\n", - " print('Time rtr %f' % (t1-t0))\n", - "\n", - " artr = expm_apprx_retraction(mnf)\n", - " t2 = perf_counter()\n", - " ret_artr = sim.simulate(x_0,\n", - " lambda x, unit_move, scale: matrix_retractive_move(\n", - " artr, x, None, unit_move, scale, new_sigma, mu),\n", - " pay_offs[0],\n", - " pay_offs[1],\n", - " [sk, t_final, n_path, n_div, d_coeff, wiener_dim])\n", - " t3 = perf_counter()\n", - " print('Time artr %f' % (t3-t2))\n", - "\n", - " print(f\"geo second order = {jnp.nanmean(ret_geo[0])}\")\n", - " print(f\"Ito = {jnp.nanmean(ret_ito[0])}\")\n", - " print(f\"Stratonovich = {jnp.nanmean(ret_str[0])}\")\n", - " print(f\"Retractive = {jnp.nanmean(ret_rtr[0])}\")\n", - " print(f\"Appx Exp Retractive = {jnp.nanmean(ret_artr[0])}\")\n", - "\n", - "test_expm_integrator_affine()\n" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "WD-SSgqMPFTq", - "outputId": "1dce2e27-30d4-45d8-9811-cb115a861008" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "test sum -gamma - ito drift=[[ 3.57353036e-16 -8.88178420e-16 1.04083409e-16 -1.56125113e-16]\n", - " [-2.08166817e-16 -3.46944695e-18 6.93889390e-17 1.87350135e-16]\n", - " [ 3.95516953e-16 -3.74700271e-16 -3.98986399e-16 -2.84494650e-16]\n", - " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", - "test adjusted ito is tangent=[[-0.05771071 0.03133447 -0.01165553 -0.03045669]\n", - " [ 0.03785821 -0.20563126 0.03272205 0.04153271]\n", - " [ 0.0002595 0.04121875 -0.13468639 0.03533443]\n", - " [ 0. 0. 0. 0. ]]\n", - "Time rtr 11.436188\n", - "Time artr 2.475534\n", - "geo second order = 0.985876600067447\n", - "Ito = 0.985415106027261\n", - "Stratonovich = 0.9857234685704925\n", - "Retractive = 0.9857389429799134\n", - "Appx Exp Retractive = 0.9856962692286884\n" - ] - } - ] } ] } \ No newline at end of file