Add support for PyMC v5 (and fix some recent Jax issues) #91

Merged
merged 18 commits into from
Nov 3, 2023
4 changes: 2 additions & 2 deletions .github/workflows/python.yml
@@ -25,8 +25,8 @@ jobs:
           - "core"
           - "jax"
           - "pymc3"
-          # - "pymc4"
-          # - "pymc4_jax"
+          - "pymc"
+          - "pymc_jax"

     steps:
       - name: Checkout
46 changes: 46 additions & 0 deletions docs/api/pymc.rst
@@ -0,0 +1,46 @@
+.. _pymc-api:
+
+PyMC (v5+) interface
+====================
+
+This ``celerite2.pymc`` submodule provides access to the *celerite2* models
+within the `PyTensor <https://pytensor.readthedocs.io/>`_ framework. Of special
+interest, this adds support for probabilistic model building using `PyMC
+<https://docs.pymc.io/>`_ v5 or later.
+
+*Note: PyMC v4 was a short-lived version of PyMC with the Aesara backend.
+Celerite2 now only supports PyMC 5, but past releases of celerite2
+might work with Aesara.*
+
+The :ref:`first` tutorial demonstrates the use of this interface, while this
+page provides the details for the :class:`celerite2.pymc.GaussianProcess` class
+which provides all this functionality. This page does not include documentation
+for the term models defined in PyTensor, but you can refer to the
+:ref:`python-terms` section of the :ref:`python-api` documentation. All of those
+models are implemented in PyTensor and you can access them using something like
+the following:
+
+.. code-block:: python
+
+    import pytensor.tensor as pt
+    from celerite2.pymc import GaussianProcess, terms
+
+    term = terms.SHOTerm(S0=pt.scalar(), w0=pt.scalar(), Q=pt.scalar())
+    gp = GaussianProcess(term)
+
+The :class:`celerite2.pymc.GaussianProcess` class is detailed below:
+
+.. autoclass:: celerite2.pymc.GaussianProcess
+   :inherited-members:
+   :exclude-members: sample, sample_conditional, recompute
+
+
+PyMC (v5+) support
+------------------
+
+This implementation comes with a custom PyMC ``Distribution`` that represents a
+multivariate normal with a *celerite* covariance matrix. This is used by the
+:func:`celerite2.pymc.GaussianProcess.marginal` method documented above, which
+adds a marginal likelihood node to a PyMC model.
+
+.. autoclass:: celerite2.pymc.distribution.CeleriteNormal
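
For context, the density behind a "multivariate normal with a *celerite* covariance matrix" can be written out explicitly. This is the textbook Gaussian log-density, not a formula quoted from the celerite2 source. The symbols are notation assumed here: `y` for the observations, `\mu` for the mean, `K` for the covariance matrix (including any diagonal term), and `N` for the number of data points.

```latex
\log p(y) = -\frac{1}{2} (y - \mu)^{\mathsf{T}} K^{-1} (y - \mu)
            - \frac{1}{2} \log \det K
            - \frac{N}{2} \log (2\pi)
```

The marginal likelihood node added by `marginal` would presumably evaluate this quantity for the observed data, using the celerite solver for the terms involving `K`.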
42 changes: 0 additions & 42 deletions docs/api/pymc4.rst

This file was deleted.

4 changes: 2 additions & 2 deletions docs/index.rst
@@ -7,7 +7,7 @@ Regression in one dimension and this library, *celerite2* is a re-write of the
 original `celerite project <https://celerite.readthedocs.io>`_ to improve
 numerical stability and integration with various machine learning frameworks.
 This implementation includes interfaces in Python and C++, with full support for
-PyMC (v3 and v4) and JAX.
+PyMC (v3 and v5+) and JAX.

 This documentation won't teach you the fundamentals of GP modeling but the best
 resource for learning about this is available for free online: `Rasmussen &
@@ -43,7 +43,7 @@ an issue <https://github.com/exoplanet-dev/celerite2/issues>`_ there.

    api/python
    api/pymc3
-   api/pymc4
+   api/pymc
    api/jax
    api/c++

80 changes: 37 additions & 43 deletions docs/tutorials/first.ipynb
@@ -330,7 +330,7 @@
    "source": [
     "## Posterior inference using PyMC\n",
     "\n",
-    "*celerite2* also includes support for probabilistic modeling using PyMC (v4 or v3, using the `celerite2.pymc3` or `celerite2.pymc4` submodule respectively), and we can implement the same model from above as follows:"
+    "*celerite2* also includes support for probabilistic modeling using PyMC (v5 or v3, using the `celerite2.pymc` or `celerite2.pymc3` submodule respectively), and we can implement the same model from above as follows:"
    ]
   },
   {
@@ -340,50 +340,44 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import warnings\n",
+    "import pymc as pm\n",
+    "from celerite2.pymc import GaussianProcess, terms as pm_terms\n",
     "\n",
-    "with warnings.catch_warnings():\n",
-    " warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
-    " warnings.filterwarnings(\"ignore\", category=RuntimeWarning)\n",
+    "with pm.Model() as model:\n",
+    " mean = pm.Normal(\"mean\", mu=0.0, sigma=prior_sigma)\n",
+    " log_jitter = pm.Normal(\"log_jitter\", mu=0.0, sigma=prior_sigma)\n",
     "\n",
-    " import pymc as pm\n",
-    " from celerite2.pymc4 import GaussianProcess, terms as pm_terms\n",
-    "\n",
-    " with pm.Model() as model:\n",
-    " mean = pm.Normal(\"mean\", mu=0.0, sigma=prior_sigma)\n",
-    " log_jitter = pm.Normal(\"log_jitter\", mu=0.0, sigma=prior_sigma)\n",
-    "\n",
-    " log_sigma1 = pm.Normal(\"log_sigma1\", mu=0.0, sigma=prior_sigma)\n",
-    " log_rho1 = pm.Normal(\"log_rho1\", mu=0.0, sigma=prior_sigma)\n",
-    " log_tau = pm.Normal(\"log_tau\", mu=0.0, sigma=prior_sigma)\n",
-    " term1 = pm_terms.SHOTerm(\n",
-    " sigma=pm.math.exp(log_sigma1),\n",
-    " rho=pm.math.exp(log_rho1),\n",
-    " tau=pm.math.exp(log_tau),\n",
-    " )\n",
+    " log_sigma1 = pm.Normal(\"log_sigma1\", mu=0.0, sigma=prior_sigma)\n",
+    " log_rho1 = pm.Normal(\"log_rho1\", mu=0.0, sigma=prior_sigma)\n",
+    " log_tau = pm.Normal(\"log_tau\", mu=0.0, sigma=prior_sigma)\n",
+    " term1 = pm_terms.SHOTerm(\n",
+    " sigma=pm.math.exp(log_sigma1),\n",
+    " rho=pm.math.exp(log_rho1),\n",
+    " tau=pm.math.exp(log_tau),\n",
+    " )\n",
     "\n",
-    " log_sigma2 = pm.Normal(\"log_sigma2\", mu=0.0, sigma=prior_sigma)\n",
-    " log_rho2 = pm.Normal(\"log_rho2\", mu=0.0, sigma=prior_sigma)\n",
-    " term2 = pm_terms.SHOTerm(\n",
-    " sigma=pm.math.exp(log_sigma2), rho=pm.math.exp(log_rho2), Q=0.25\n",
-    " )\n",
+    " log_sigma2 = pm.Normal(\"log_sigma2\", mu=0.0, sigma=prior_sigma)\n",
+    " log_rho2 = pm.Normal(\"log_rho2\", mu=0.0, sigma=prior_sigma)\n",
+    " term2 = pm_terms.SHOTerm(\n",
+    " sigma=pm.math.exp(log_sigma2), rho=pm.math.exp(log_rho2), Q=0.25\n",
+    " )\n",
     "\n",
-    " kernel = term1 + term2\n",
-    " gp = GaussianProcess(kernel, mean=mean)\n",
-    " gp.compute(t, diag=yerr**2 + pm.math.exp(log_jitter), quiet=True)\n",
-    " gp.marginal(\"obs\", observed=y)\n",
+    " kernel = term1 + term2\n",
+    " gp = GaussianProcess(kernel, mean=mean)\n",
+    " gp.compute(t, diag=yerr**2 + pm.math.exp(log_jitter), quiet=True)\n",
+    " gp.marginal(\"obs\", observed=y)\n",
     "\n",
-    " pm.Deterministic(\"psd\", kernel.get_psd(omega))\n",
+    " pm.Deterministic(\"psd\", kernel.get_psd(omega))\n",
     "\n",
-    " trace = pm.sample(\n",
-    " tune=1000,\n",
-    " draws=1000,\n",
-    " target_accept=0.9,\n",
-    " init=\"adapt_full\",\n",
-    " cores=2,\n",
-    " chains=2,\n",
-    " random_seed=34923,\n",
-    " )"
+    " trace = pm.sample(\n",
+    " tune=1000,\n",
+    " draws=1000,\n",
+    " target_accept=0.9,\n",
+    " init=\"adapt_full\",\n",
+    " cores=2,\n",
+    " chains=2,\n",
+    " random_seed=34923,\n",
+    " )"
    ]
   },
   {
@@ -411,7 +405,7 @@
     "plt.xlim(freq.min(), freq.max())\n",
     "plt.xlabel(\"frequency [1 / day]\")\n",
     "plt.ylabel(\"power [day ppt$^2$]\")\n",
-    "_ = plt.title(\"posterior psd using PyMC3\")"
+    "_ = plt.title(\"posterior psd using PyMC\")"
    ]
   },
   {
@@ -488,7 +482,7 @@
    "id": "6f3abdcf",
    "metadata": {},
    "source": [
-    "This runtime was similar to the PyMC3 result from above, and (as we'll see below) the convergence is also similar.\n",
+    "This runtime was similar to the PyMC result from above, and (as we'll see below) the convergence is also similar.\n",
     "Any difference in runtime will probably disappear for more computationally expensive models, but this interface is looking pretty great here!\n",
     "\n",
     "As above, we can plot the posterior expectations for the power spectrum:"
@@ -563,7 +557,7 @@
     " bins,\n",
     " histtype=\"step\",\n",
     " density=True,\n",
-    " label=\"PyMC3\",\n",
+    " label=\"PyMC\",\n",
     ")\n",
     "plt.hist(\n",
     " np.exp(np.asarray((numpyro_data.posterior[\"log_rho1\"].T)).flatten()),\n",
@@ -657,7 +651,7 @@
    "metadata": {},
    "source": [
     "Overall these results are consistent, but the $\\hat{R}$ values are a bit high for the emcee run, so I'd probably run that for longer.\n",
-    "Either way, for models like these, PyMC3 and numpyro are generally going to be much better inference tools (in terms of runtime per effective sample) than emcee, so those are the recommended interfaces if the rest of your model can be easily implemented in such a framework."
+    "Either way, for models like these, PyMC and numpyro are generally going to be much better inference tools (in terms of runtime per effective sample) than emcee, so those are the recommended interfaces if the rest of your model can be easily implemented in such a framework."
    ]
   }
  ],
14 changes: 7 additions & 7 deletions noxfile.py
@@ -30,20 +30,20 @@ def pymc3(session):


 @nox.session(python=ALL_PYTHON_VS)
-def pymc4(session):
-    session.install(".[test,pymc4]")
-    _session_run(session, "python/test/pymc4")
+def pymc(session):
+    session.install(".[test,pymc]")
+    _session_run(session, "python/test/pymc")


 @nox.session(python=ALL_PYTHON_VS)
-def pymc4_jax(session):
-    session.install(".[test,jax,pymc4]")
-    _session_run(session, "python/test/pymc4/test_pymc4_ops.py")
+def pymc_jax(session):
+    session.install(".[test,jax,pymc]")
+    _session_run(session, "python/test/pymc/test_pymc_ops.py")


 @nox.session(python=ALL_PYTHON_VS)
 def full(session):
-    session.install(".[test,jax,pymc3,pymc4]")
+    session.install(".[test,jax,pymc3,pymc]")
     _session_run(session, "python/test")


12 changes: 7 additions & 5 deletions python/celerite2/jax/ops.py
@@ -25,7 +25,7 @@
 import pkg_resources
 from jax import core, lax
 from jax import numpy as jnp
-from jax.abstract_arrays import ShapedArray
+from jax.core import ShapedArray
 from jax.interpreters import ad, xla
 from jax.lib import xla_client

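The hunk above replaces the `jax.abstract_arrays` import with `jax.core` outright. If one wanted to keep supporting JAX versions on both sides of that move, a fallback import is one option. The sketch below is not part of this PR: `import_first` is a hypothetical helper, and the demonstration uses only standard-library modules so that it runs without JAX installed.

```python
import importlib


def import_first(candidates, attr):
    """Return `attr` from the first module in `candidates` that provides it.

    Hypothetical helper for illustration; it is not part of celerite2 or JAX.
    """
    errors = []
    for module_name in candidates:
        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr)
        except (ImportError, AttributeError) as exc:
            # Remember why this candidate failed, then try the next one
            errors.append(f"{module_name}: {exc}")
    raise ImportError(f"could not import {attr!r}; tried: " + "; ".join(errors))


# With JAX installed, the pattern would look like this (assumed usage):
# ShapedArray = import_first(["jax.core", "jax.abstract_arrays"], "ShapedArray")

# Demonstration using standard-library modules only
sqrt = import_first(["not_a_real_module_xyz", "math"], "sqrt")
print(sqrt(9.0))
```

The PR takes the simpler route of requiring the newer location, which avoids carrying compatibility code at the cost of dropping older JAX releases.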
@@ -201,8 +201,8 @@ def _rev_translation_rule(name, spec, c, *args):


 def _build_op(name, spec):
-    xla_client.register_cpu_custom_call_target(
-        name, getattr(xla_ops, spec["name"])()
+    xla_client.register_custom_call_target(
+        name, getattr(xla_ops, spec["name"])(), platform="cpu"
     )

     prim = core.Primitive(f"celerite2_{spec['name']}")
@@ -216,8 +216,10 @@ def _build_op(name, spec):
     if not spec["has_rev"]:
         return prim, None

-    xla_client.register_cpu_custom_call_target(
-        name + b"_rev", getattr(xla_ops, f"{spec['name']}_rev")()
+    xla_client.register_custom_call_target(
+        name + b"_rev",
+        getattr(xla_ops, f"{spec['name']}_rev")(),
+        platform="cpu",
     )

     jvp_prim = core.Primitive(f"celerite2_{spec['name']}_jvp")
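
Both hunks above make the same substitution: `register_cpu_custom_call_target(name, fn)` becomes `register_custom_call_target(name, fn, platform="cpu")`. As a rough illustration of how the two call shapes relate, the sketch below dispatches on whichever function a client object exposes. The wrapper `register_cpu_target` and the stand-in client objects are hypothetical and exist only for this example; they are not the real `jax.lib.xla_client`, and the PR itself does not add such a wrapper.

```python
from types import SimpleNamespace


def register_cpu_target(xla_client, name, fn):
    """Register `fn` under `name` for CPU, using whichever API the client has.

    Hypothetical wrapper; only the two function names and the `platform`
    keyword are taken from the diff above.
    """
    if hasattr(xla_client, "register_custom_call_target"):
        # Newer call shape: the platform is passed explicitly
        xla_client.register_custom_call_target(name, fn, platform="cpu")
        return "new"
    # Older call shape: the platform is implied by the function name
    xla_client.register_cpu_custom_call_target(name, fn)
    return "old"


# Stand-in clients (NOT the real xla_client) that just record their calls
calls = []
new_style = SimpleNamespace(
    register_custom_call_target=lambda name, fn, platform: calls.append(
        (name, platform)
    )
)
old_style = SimpleNamespace(
    register_cpu_custom_call_target=lambda name, fn: calls.append(
        (name, "cpu (implicit)")
    )
)

print(register_cpu_target(new_style, b"demo", object()))
print(register_cpu_target(old_style, b"demo_rev", object()))
```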
@@ -4,27 +4,27 @@


 def __set_compiler_flags():
-    import aesara
+    import pytensor

     def add_flag(current, new):
         if new in current:
             return current
         return f"{current} {new}"

-    current = aesara.config.gcc__cxxflags
+    current = pytensor.config.gcc__cxxflags
     current = add_flag(current, "-Wno-c++11-narrowing")
     current = add_flag(current, "-fno-exceptions")
     current = add_flag(current, "-fno-unwind-tables")
     current = add_flag(current, "-fno-asynchronous-unwind-tables")
-    aesara.config.gcc__cxxflags = current
+    pytensor.config.gcc__cxxflags = current


 __set_compiler_flags()

-from celerite2.pymc4 import terms
-from celerite2.pymc4.celerite2 import GaussianProcess
+from celerite2.pymc import terms # noqa
+from celerite2.pymc.celerite2 import GaussianProcess # noqa

 try:
-    from celerite2.pymc4 import jax_support # noqa
+    from celerite2.pymc import jax_support # noqa
 except ImportError:
     pass
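
The `add_flag` helper in this hunk is what lets `__set_compiler_flags` run more than once without duplicating flags. A standalone copy makes that behavior easy to check; the starting value `"-O2"` is an arbitrary assumption for the demonstration, not a value taken from PyTensor's configuration.

```python
def add_flag(current, new):
    """Append `new` to a space-separated flag string unless already present."""
    if new in current:
        return current
    return f"{current} {new}"


# Assumed starting flags, chosen only for this example
flags = "-O2"
for flag in ["-fno-exceptions", "-fno-unwind-tables", "-fno-exceptions"]:
    # The repeated "-fno-exceptions" is skipped on the third pass
    flags = add_flag(flags, flag)

print(flags)
```

Two properties of this helper are worth knowing: the membership test is a plain substring check, so a flag that happens to be a substring of an existing one would be treated as already present, and starting from an empty string leaves a leading space in the result.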