diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index 3387a49..00a7035 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -25,8 +25,8 @@ jobs:
- "core"
- "jax"
- "pymc3"
- # - "pymc4"
- # - "pymc4_jax"
+ - "pymc"
+ - "pymc_jax"
steps:
- name: Checkout
diff --git a/docs/api/pymc.rst b/docs/api/pymc.rst
new file mode 100644
index 0000000..b140466
--- /dev/null
+++ b/docs/api/pymc.rst
@@ -0,0 +1,46 @@
+.. _pymc-api:
+
+PyMC (v5+) interface
+===============
+
+This ``celerite2.pymc`` submodule provides access to the *celerite2* models
+within the `PyTensor `_ framework. Of special
+interest, this adds support for probabilistic model building using `PyMC
+`_ v5 or later.
+
+*Note: PyMC v4 was a short-lived version of PyMC with the Aesara backend.
+Celerite2 now only supports to PyMC 5, but past releases of celerite2
+might work with Aesara.*
+
+The :ref:`first` tutorial demonstrates the use of this interface, while this
+page provides the details for the :class:`celerite2.pymc.GaussianProcess` class
+which provides all this functionality. This page does not include documentation
+for the term models defined in PyTensor, but you can refer to the
+:ref:`python-terms` section of the :ref:`python-api` documentation. All of those
+models are implemented in PyTensor and you can access them using something like
+the following:
+
+.. code-block:: python
+
+ import pytensor.tensor as pt
+ from celerite2.pymc import GaussianProcess, terms
+
+ term = terms.SHOTerm(S0=at.scalar(), w0=at.scalar(), Q=at.scalar())
+ gp = GaussianProcess(term)
+
+The :class:`celerite2.pymc.GaussianProcess` class is detailed below:
+
+.. autoclass:: celerite2.pymc.GaussianProcess
+ :inherited-members:
+ :exclude-members: sample, sample_conditional, recompute
+
+
+PyMC (v5+) support
+-----------------
+
+This implementation comes with a custom PyMC ``Distribution`` that represents a
+multivariate normal with a *celerite* covariance matrix. This is used by the
+:func:`celerite2.pymc.GaussianProcess.marginal` method documented above which
+adds a marginal likelihood node to a PyMC model.
+
+.. autoclass:: celerite2.pymc.distribution.CeleriteNormal
diff --git a/docs/api/pymc4.rst b/docs/api/pymc4.rst
deleted file mode 100644
index 3840c64..0000000
--- a/docs/api/pymc4.rst
+++ /dev/null
@@ -1,42 +0,0 @@
-.. _pymc4-api:
-
-PyMC4 interface
-===============
-
-This ``celerite2.pymc4`` submodule provides access to the *celerite2* models
-within the `Aesara `_ framework. Of special
-interest, this adds support for probabilistic model building using `PyMC
-`_ v4 or later.
-
-The :ref:`first` tutorial demonstrates the use of this interface, while this
-page provides the details for the :class:`celerite2.pymc4.GaussianProcess` class
-which provides all this functionality. This page does not include documentation
-for the term models defined in Aesara, but you can refer to the
-:ref:`python-terms` section of the :ref:`python-api` documentation. All of those
-models are implemented in Aesara and you can access them using something like
-the following:
-
-.. code-block:: python
-
- import aesara.tensor as at
- from celerite2.pymc4 import GaussianProcess, terms
-
- term = terms.SHOTerm(S0=at.scalar(), w0=at.scalar(), Q=at.scalar())
- gp = GaussianProcess(term)
-
-The :class:`celerite2.pymc4.GaussianProcess` class is detailed below:
-
-.. autoclass:: celerite2.pymc4.GaussianProcess
- :inherited-members:
- :exclude-members: sample, sample_conditional, recompute
-
-
-PyMC (v4) support
------------------
-
-This implementation comes with a custom PyMC ``Distribution`` that represents a
-multivariate normal with a *celerite* covariance matrix. This is used by the
-:func:`celerite2.pymc4.GaussianProcess.marginal` method documented above which
-adds a marginal likelihood node to a PyMC model.
-
-.. autoclass:: celerite2.pymc4.distribution.CeleriteNormal
diff --git a/docs/index.rst b/docs/index.rst
index 9b1c3a8..7d0fbfd 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -7,7 +7,7 @@ Regression in one dimension and this library, *celerite2* is a re-write of the
original `celerite project `_ to improve
numerical stability and integration with various machine learning frameworks.
This implementation includes interfaces in Python and C++, with full support for
-PyMC (v3 and v4) and JAX.
+PyMC (v3 and v5+) and JAX.
This documentation won't teach you the fundamentals of GP modeling but the best
resource for learning about this is available for free online: `Rasmussen &
@@ -43,7 +43,7 @@ an issue `_ there.
api/python
api/pymc3
- api/pymc4
+ api/pymc
api/jax
api/c++
diff --git a/docs/tutorials/first.ipynb b/docs/tutorials/first.ipynb
index 4e6e965..9aa422f 100644
--- a/docs/tutorials/first.ipynb
+++ b/docs/tutorials/first.ipynb
@@ -330,7 +330,7 @@
"source": [
"## Posterior inference using PyMC\n",
"\n",
- "*celerite2* also includes support for probabilistic modeling using PyMC (v4 or v3, using the `celerite2.pymc3` or `celerite2.pymc4` submodule respectively), and we can implement the same model from above as follows:"
+ "*celerite2* also includes support for probabilistic modeling using PyMC (v5 or v3, using the `celerite2.pymc` or `celerite2.pymc3` submodule respectively), and we can implement the same model from above as follows:"
]
},
{
@@ -340,50 +340,44 @@
"metadata": {},
"outputs": [],
"source": [
- "import warnings\n",
+ "import pymc as pm\n",
+ "from celerite2.pymc import GaussianProcess, terms as pm_terms\n",
"\n",
- "with warnings.catch_warnings():\n",
- " warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
- " warnings.filterwarnings(\"ignore\", category=RuntimeWarning)\n",
+ "with pm.Model() as model:\n",
+ " mean = pm.Normal(\"mean\", mu=0.0, sigma=prior_sigma)\n",
+ " log_jitter = pm.Normal(\"log_jitter\", mu=0.0, sigma=prior_sigma)\n",
"\n",
- " import pymc as pm\n",
- " from celerite2.pymc4 import GaussianProcess, terms as pm_terms\n",
- "\n",
- " with pm.Model() as model:\n",
- " mean = pm.Normal(\"mean\", mu=0.0, sigma=prior_sigma)\n",
- " log_jitter = pm.Normal(\"log_jitter\", mu=0.0, sigma=prior_sigma)\n",
- "\n",
- " log_sigma1 = pm.Normal(\"log_sigma1\", mu=0.0, sigma=prior_sigma)\n",
- " log_rho1 = pm.Normal(\"log_rho1\", mu=0.0, sigma=prior_sigma)\n",
- " log_tau = pm.Normal(\"log_tau\", mu=0.0, sigma=prior_sigma)\n",
- " term1 = pm_terms.SHOTerm(\n",
- " sigma=pm.math.exp(log_sigma1),\n",
- " rho=pm.math.exp(log_rho1),\n",
- " tau=pm.math.exp(log_tau),\n",
- " )\n",
+ " log_sigma1 = pm.Normal(\"log_sigma1\", mu=0.0, sigma=prior_sigma)\n",
+ " log_rho1 = pm.Normal(\"log_rho1\", mu=0.0, sigma=prior_sigma)\n",
+ " log_tau = pm.Normal(\"log_tau\", mu=0.0, sigma=prior_sigma)\n",
+ " term1 = pm_terms.SHOTerm(\n",
+ " sigma=pm.math.exp(log_sigma1),\n",
+ " rho=pm.math.exp(log_rho1),\n",
+ " tau=pm.math.exp(log_tau),\n",
+ " )\n",
"\n",
- " log_sigma2 = pm.Normal(\"log_sigma2\", mu=0.0, sigma=prior_sigma)\n",
- " log_rho2 = pm.Normal(\"log_rho2\", mu=0.0, sigma=prior_sigma)\n",
- " term2 = pm_terms.SHOTerm(\n",
- " sigma=pm.math.exp(log_sigma2), rho=pm.math.exp(log_rho2), Q=0.25\n",
- " )\n",
+ " log_sigma2 = pm.Normal(\"log_sigma2\", mu=0.0, sigma=prior_sigma)\n",
+ " log_rho2 = pm.Normal(\"log_rho2\", mu=0.0, sigma=prior_sigma)\n",
+ " term2 = pm_terms.SHOTerm(\n",
+ " sigma=pm.math.exp(log_sigma2), rho=pm.math.exp(log_rho2), Q=0.25\n",
+ " )\n",
"\n",
- " kernel = term1 + term2\n",
- " gp = GaussianProcess(kernel, mean=mean)\n",
- " gp.compute(t, diag=yerr**2 + pm.math.exp(log_jitter), quiet=True)\n",
- " gp.marginal(\"obs\", observed=y)\n",
+ " kernel = term1 + term2\n",
+ " gp = GaussianProcess(kernel, mean=mean)\n",
+ " gp.compute(t, diag=yerr**2 + pm.math.exp(log_jitter), quiet=True)\n",
+ " gp.marginal(\"obs\", observed=y)\n",
"\n",
- " pm.Deterministic(\"psd\", kernel.get_psd(omega))\n",
+ " pm.Deterministic(\"psd\", kernel.get_psd(omega))\n",
"\n",
- " trace = pm.sample(\n",
- " tune=1000,\n",
- " draws=1000,\n",
- " target_accept=0.9,\n",
- " init=\"adapt_full\",\n",
- " cores=2,\n",
- " chains=2,\n",
- " random_seed=34923,\n",
- " )"
+ " trace = pm.sample(\n",
+ " tune=1000,\n",
+ " draws=1000,\n",
+ " target_accept=0.9,\n",
+ " init=\"adapt_full\",\n",
+ " cores=2,\n",
+ " chains=2,\n",
+ " random_seed=34923,\n",
+ " )"
]
},
{
@@ -411,7 +405,7 @@
"plt.xlim(freq.min(), freq.max())\n",
"plt.xlabel(\"frequency [1 / day]\")\n",
"plt.ylabel(\"power [day ppt$^2$]\")\n",
- "_ = plt.title(\"posterior psd using PyMC3\")"
+ "_ = plt.title(\"posterior psd using PyMC\")"
]
},
{
@@ -488,7 +482,7 @@
"id": "6f3abdcf",
"metadata": {},
"source": [
- "This runtime was similar to the PyMC3 result from above, and (as we'll see below) the convergence is also similar.\n",
+ "This runtime was similar to the PyMC result from above, and (as we'll see below) the convergence is also similar.\n",
"Any difference in runtime will probably disappear for more computationally expensive models, but this interface is looking pretty great here!\n",
"\n",
"As above, we can plot the posterior expectations for the power spectrum:"
@@ -563,7 +557,7 @@
" bins,\n",
" histtype=\"step\",\n",
" density=True,\n",
- " label=\"PyMC3\",\n",
+ " label=\"PyMC\",\n",
")\n",
"plt.hist(\n",
" np.exp(np.asarray((numpyro_data.posterior[\"log_rho1\"].T)).flatten()),\n",
@@ -657,7 +651,7 @@
"metadata": {},
"source": [
"Overall these results are consistent, but the $\\hat{R}$ values are a bit high for the emcee run, so I'd probably run that for longer.\n",
- "Either way, for models like these, PyMC3 and numpyro are generally going to be much better inference tools (in terms of runtime per effective sample) than emcee, so those are the recommended interfaces if the rest of your model can be easily implemented in such a framework."
+ "Either way, for models like these, PyMC and numpyro are generally going to be much better inference tools (in terms of runtime per effective sample) than emcee, so those are the recommended interfaces if the rest of your model can be easily implemented in such a framework."
]
}
],
diff --git a/noxfile.py b/noxfile.py
index ff729c8..1560160 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -30,20 +30,20 @@ def pymc3(session):
@nox.session(python=ALL_PYTHON_VS)
-def pymc4(session):
- session.install(".[test,pymc4]")
- _session_run(session, "python/test/pymc4")
+def pymc(session):
+ session.install(".[test,pymc]")
+ _session_run(session, "python/test/pymc")
@nox.session(python=ALL_PYTHON_VS)
-def pymc4_jax(session):
- session.install(".[test,jax,pymc4]")
- _session_run(session, "python/test/pymc4/test_pymc4_ops.py")
+def pymc_jax(session):
+ session.install(".[test,jax,pymc]")
+ _session_run(session, "python/test/pymc/test_pymc_ops.py")
@nox.session(python=ALL_PYTHON_VS)
def full(session):
- session.install(".[test,jax,pymc3,pymc4]")
+ session.install(".[test,jax,pymc3,pymc]")
_session_run(session, "python/test")
diff --git a/python/celerite2/jax/ops.py b/python/celerite2/jax/ops.py
index 17e1666..bc6e983 100644
--- a/python/celerite2/jax/ops.py
+++ b/python/celerite2/jax/ops.py
@@ -25,7 +25,7 @@
import pkg_resources
from jax import core, lax
from jax import numpy as jnp
-from jax.abstract_arrays import ShapedArray
+from jax.core import ShapedArray
from jax.interpreters import ad, xla
from jax.lib import xla_client
@@ -201,8 +201,8 @@ def _rev_translation_rule(name, spec, c, *args):
def _build_op(name, spec):
- xla_client.register_cpu_custom_call_target(
- name, getattr(xla_ops, spec["name"])()
+ xla_client.register_custom_call_target(
+ name, getattr(xla_ops, spec["name"])(), platform="cpu"
)
prim = core.Primitive(f"celerite2_{spec['name']}")
@@ -216,8 +216,10 @@ def _build_op(name, spec):
if not spec["has_rev"]:
return prim, None
- xla_client.register_cpu_custom_call_target(
- name + b"_rev", getattr(xla_ops, f"{spec['name']}_rev")()
+ xla_client.register_custom_call_target(
+ name + b"_rev",
+ getattr(xla_ops, f"{spec['name']}_rev")(),
+ platform="cpu",
)
jvp_prim = core.Primitive(f"celerite2_{spec['name']}_jvp")
diff --git a/python/celerite2/pymc4/__init__.py b/python/celerite2/pymc/__init__.py
similarity index 65%
rename from python/celerite2/pymc4/__init__.py
rename to python/celerite2/pymc/__init__.py
index 76c8701..ee77b01 100644
--- a/python/celerite2/pymc4/__init__.py
+++ b/python/celerite2/pymc/__init__.py
@@ -4,27 +4,27 @@
def __set_compiler_flags():
- import aesara
+ import pytensor
def add_flag(current, new):
if new in current:
return current
return f"{current} {new}"
- current = aesara.config.gcc__cxxflags
+ current = pytensor.config.gcc__cxxflags
current = add_flag(current, "-Wno-c++11-narrowing")
current = add_flag(current, "-fno-exceptions")
current = add_flag(current, "-fno-unwind-tables")
current = add_flag(current, "-fno-asynchronous-unwind-tables")
- aesara.config.gcc__cxxflags = current
+ pytensor.config.gcc__cxxflags = current
__set_compiler_flags()
-from celerite2.pymc4 import terms
-from celerite2.pymc4.celerite2 import GaussianProcess
+from celerite2.pymc import terms # noqa
+from celerite2.pymc.celerite2 import GaussianProcess # noqa
try:
- from celerite2.pymc4 import jax_support # noqa
+ from celerite2.pymc import jax_support # noqa
except ImportError:
pass
diff --git a/python/celerite2/pymc4/celerite2.py b/python/celerite2/pymc/celerite2.py
similarity index 87%
rename from python/celerite2/pymc4/celerite2.py
rename to python/celerite2/pymc/celerite2.py
index 735007d..13386d9 100644
--- a/python/celerite2/pymc4/celerite2.py
+++ b/python/celerite2/pymc/celerite2.py
@@ -2,13 +2,13 @@
__all__ = ["GaussianProcess", "ConditionalDistribution"]
-import aesara.tensor as tt
import numpy as np
-from aesara.raise_op import Assert
+import pytensor.tensor as pt
+from pytensor.raise_op import Assert
from celerite2.citation import CITATIONS
from celerite2.core import BaseConditionalDistribution, BaseGaussianProcess
-from celerite2.pymc4 import ops
+from celerite2.pymc import ops
class ConditionalDistribution(BaseConditionalDistribution):
@@ -22,37 +22,37 @@ def _do_general_matmul(self, c, U1, V1, U2, V2, inp, target):
return target
def _diagdot(self, a, b):
- return tt.batched_dot(a.T, b.T)
+ return pt.batched_dot(a.T, b.T)
class GaussianProcess(BaseGaussianProcess):
conditional_distribution = ConditionalDistribution
def _as_tensor(self, tensor):
- return tt.as_tensor_variable(tensor).astype("float64")
+ return pt.as_tensor_variable(tensor).astype("float64")
def _zeros_like(self, tensor):
- return tt.zeros_like(tensor)
+ return pt.zeros_like(tensor)
def _do_compute(self, quiet):
if quiet:
self._d, self._W, _ = ops.factor_quiet(
self._t, self._c, self._a, self._U, self._V
)
- self._log_det = tt.switch(
- tt.any(self._d <= 0.0), -np.inf, tt.sum(tt.log(self._d))
+ self._log_det = pt.switch(
+ pt.any(self._d <= 0.0), -np.inf, pt.sum(pt.log(self._d))
)
else:
self._d, self._W, _ = ops.factor(
self._t, self._c, self._a, self._U, self._V
)
- self._log_det = tt.sum(tt.log(self._d))
+ self._log_det = pt.sum(pt.log(self._d))
self._norm = -0.5 * (self._log_det + self._size * np.log(2 * np.pi))
def _check_sorted(self, t):
- return Assert()(t, tt.all(t[1:] - t[:-1] >= 0))
+ return Assert()(t, pt.all(t[1:] - t[:-1] >= 0))
def _do_solve(self, y):
z = ops.solve_lower(self._t, self._c, self._U, self._W, y)[0]
@@ -69,7 +69,7 @@ def _do_norm(self, y):
alpha = ops.solve_lower(
self._t, self._c, self._U, self._W, y[:, None]
)[0][:, 0]
- return tt.sum(alpha**2 / self._d)
+ return pt.sum(alpha**2 / self._d)
def _add_citations_to_pymc_model(self, **kwargs):
import pymc as pm
@@ -87,10 +87,10 @@ def marginal(self, name, **kwargs):
observed (optional): The observed data
Returns:
- A :class:`celerite2.pymc3.CeleriteNormal` distribution
+ A :class:`celerite2.pymc.CeleriteNormal` distribution
representing the marginal likelihood.
"""
- from celerite2.pymc4.distribution import CeleriteNormal
+ from celerite2.pymc.distribution import CeleriteNormal
self._add_citations_to_pymc_model(**kwargs)
return CeleriteNormal(
@@ -108,7 +108,7 @@ def marginal(self, name, **kwargs):
def conditional(
self, name, y, t=None, include_mean=True, kernel=None, **kwargs
):
- """Add a variable representing the conditional density to a PyMC3 model
+ """Add a variable representing the conditional density to a PyMC model
.. note:: The performance of this method will generally be poor since
the sampler will numerically sample this parameter. Depending on
diff --git a/python/celerite2/pymc4/distribution.py b/python/celerite2/pymc/distribution.py
similarity index 76%
rename from python/celerite2/pymc4/distribution.py
rename to python/celerite2/pymc/distribution.py
index 3754339..1583772 100644
--- a/python/celerite2/pymc4/distribution.py
+++ b/python/celerite2/pymc/distribution.py
@@ -2,16 +2,19 @@
__all__ = ["CeleriteNormal"]
-import aesara.tensor as tt
import numpy as np
-from aesara.tensor.random.op import RandomVariable
-from aesara.tensor.random.utils import broadcast_params
+import pytensor.tensor as pt
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import Continuous
from pymc.distributions.shape_utils import rv_size_is_none
+from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.random.utils import (
+ broadcast_params,
+ supp_shape_from_ref_param_shape,
+)
import celerite2.driver as driver
-from celerite2.pymc4 import ops
+from celerite2.pymc import ops
def safe_celerite_normal(rng, mean, norm, t, c, U, W, d, size=None):
@@ -33,6 +36,14 @@ class CeleriteNormalRV(RandomVariable):
dtype = "floatX"
_print_name = ("CeleriteNormal", "\\operatorname{CeleriteNormal}")
+ def _supp_shape_from_params(self, dist_params, param_shapes=None):
+ return supp_shape_from_ref_param_shape(
+ ndim_supp=self.ndim_supp,
+ dist_params=dist_params,
+ param_shapes=param_shapes,
+ ref_param_idx=2,
+ )
+
@classmethod
def rng_fn(cls, rng, mean, norm, t, c, U, W, d, size):
if any(
@@ -90,26 +101,26 @@ class CeleriteNormal(Continuous):
@classmethod
def dist(cls, mean, norm, t, c, U, W, d, **kwargs):
- mean = tt.as_tensor_variable(mean)
- norm = tt.as_tensor_variable(norm)
- t = tt.as_tensor_variable(t)
- c = tt.as_tensor_variable(c)
- U = tt.as_tensor_variable(U)
- W = tt.as_tensor_variable(W)
- d = tt.as_tensor_variable(d)
- mean = tt.broadcast_arrays(mean, t)[0]
+ mean = pt.as_tensor_variable(mean)
+ norm = pt.as_tensor_variable(norm)
+ t = pt.as_tensor_variable(t)
+ c = pt.as_tensor_variable(c)
+ U = pt.as_tensor_variable(U)
+ W = pt.as_tensor_variable(W)
+ d = pt.as_tensor_variable(d)
+ mean = pt.broadcast_arrays(mean, t)[0]
return super().dist([mean, norm, t, c, U, W, d], **kwargs)
def moment(rv, size, mean, *args):
moment = mean
if not rv_size_is_none(size):
- moment_size = tt.concatenate([size, [mean.shape[-1]]])
- moment = tt.full(moment_size, mean)
+ moment_size = pt.concatenate([size, [mean.shape[-1]]])
+ moment = pt.full(moment_size, mean)
return moment
def logp(value, mean, norm, t, c, U, W, d):
- ok = tt.all(tt.gt(d, 0.0))
+ ok = pt.all(pt.gt(d, 0.0))
alpha = value - mean
alpha = ops.solve_lower(t, c, U, W, alpha[:, None])[0][:, 0]
- logp = norm - 0.5 * tt.sum(alpha**2 / d)
+ logp = norm - 0.5 * pt.sum(alpha**2 / d)
return check_parameters(logp, ok)
diff --git a/python/celerite2/pymc4/jax_support.py b/python/celerite2/pymc/jax_support.py
similarity index 71%
rename from python/celerite2/pymc4/jax_support.py
rename to python/celerite2/pymc/jax_support.py
index 2d268eb..b41af72 100644
--- a/python/celerite2/pymc4/jax_support.py
+++ b/python/celerite2/pymc/jax_support.py
@@ -1,10 +1,10 @@
-from aesara.link.jax.dispatch import jax_funcify
+from pytensor.link.jax.dispatch import jax_funcify
from celerite2.jax import ops as jax_ops
-from celerite2.pymc4 import ops as pymc4_ops
+from celerite2.pymc import ops as pymc_ops
-@jax_funcify.register(pymc4_ops._CeleriteOp)
+@jax_funcify.register(pymc_ops._CeleriteOp)
def jax_funcify_Celerite(op, **kwargs):
name = op.name
if name.endswith("_fwd"):
diff --git a/python/celerite2/pymc4/ops.py b/python/celerite2/pymc/ops.py
similarity index 95%
rename from python/celerite2/pymc4/ops.py
rename to python/celerite2/pymc/ops.py
index ff0cc4f..c7ecdc5 100644
--- a/python/celerite2/pymc4/ops.py
+++ b/python/celerite2/pymc/ops.py
@@ -14,11 +14,11 @@
import json
from itertools import chain
-import aesara
-import aesara.tensor as tt
import numpy as np
import pkg_resources
-from aesara.graph import basic, op
+import pytensor
+import pytensor.tensor as pt
+from pytensor.graph import basic, op
import celerite2.backprop as backprop
import celerite2.driver as driver
@@ -82,7 +82,7 @@ def make_node(self, *inputs):
for spec in self.spec["dimensions"]
}
otypes = [
- tt.TensorType(
+ pt.TensorType(
"float64", [broadcastable[k] for k in spec["shape"]]
)()
for spec in self.spec["outputs"] + self.spec["extra_outputs"]
@@ -130,8 +130,8 @@ def perform(self, node, inputs, outputs):
def grad(self, inputs, gradients):
outputs = self(*inputs)
grads = (
- tt.zeros_like(outputs[n])
- if isinstance(b.type, aesara.gradient.DisconnectedType)
+ pt.zeros_like(outputs[n])
+ if isinstance(b.type, pytensor.gradient.DisconnectedType)
else b
for n, b in enumerate(gradients[: len(self.spec["outputs"])])
)
diff --git a/python/celerite2/pymc4/terms.py b/python/celerite2/pymc/terms.py
similarity index 69%
rename from python/celerite2/pymc4/terms.py
rename to python/celerite2/pymc/terms.py
index 4472d5e..ad14221 100644
--- a/python/celerite2/pymc4/terms.py
+++ b/python/celerite2/pymc/terms.py
@@ -13,13 +13,13 @@
"RotationTerm",
]
-import aesara
-import aesara.tensor as tt
import numpy as np
-from aesara.ifelse import ifelse
+import pytensor
+import pytensor.tensor as pt
+from pytensor.ifelse import ifelse
import celerite2.terms as base_terms
-from celerite2.pymc4 import ops
+from celerite2.pymc import ops
class Term(base_terms.Term):
@@ -30,11 +30,11 @@ def __init__(self, *, dtype="float64"):
self.coefficients = self.get_coefficients()
def __add__(self, b):
- dtype = aesara.scalar.upcast(self.dtype, b.dtype)
+ dtype = pytensor.scalar.upcast(self.dtype, b.dtype)
return TermSum(self, b, dtype=dtype)
def __mul__(self, b):
- dtype = aesara.scalar.upcast(self.dtype, b.dtype)
+ dtype = pytensor.scalar.upcast(self.dtype, b.dtype)
return TermProduct(self, b, dtype=dtype)
@property
@@ -51,17 +51,17 @@ def get_value(self, tau):
def _get_value_real(self, coefficients, tau):
ar, cr = coefficients
- tau = tt.abs(tau)
- tau = tt.shape_padright(tau)
- return tt.sum(ar * tt.exp(-cr * tau), axis=-1)
+ tau = pt.abs(tau)
+ tau = pt.shape_padright(tau)
+ return pt.sum(ar * pt.exp(-cr * tau), axis=-1)
def _get_value_complex(self, coefficients, tau):
ac, bc, cc, dc = coefficients
- tau = tt.abs(tau)
- tau = tt.shape_padright(tau)
- factor = tt.exp(-cc * tau)
- K = tt.sum(ac * factor * tt.cos(dc * tau), axis=-1)
- K += tt.sum(bc * factor * tt.sin(dc * tau), axis=-1)
+ tau = pt.abs(tau)
+ tau = pt.shape_padright(tau)
+ factor = pt.exp(-cc * tau)
+ K = pt.sum(ac * factor * pt.cos(dc * tau), axis=-1)
+ K += pt.sum(bc * factor * pt.sin(dc * tau), axis=-1)
return K
def get_psd(self, omega):
@@ -71,17 +71,17 @@ def get_psd(self, omega):
def _get_psd_real(self, coefficients, omega):
ar, cr = coefficients
- omega = tt.shape_padright(omega)
+ omega = pt.shape_padright(omega)
w2 = omega**2
- power = tt.sum(ar * cr / (cr**2 + w2), axis=-1)
+ power = pt.sum(ar * cr / (cr**2 + w2), axis=-1)
return np.sqrt(2.0 / np.pi) * power
def _get_psd_complex(self, coefficients, omega):
ac, bc, cc, dc = coefficients
- omega = tt.shape_padright(omega)
+ omega = pt.shape_padright(omega)
w2 = omega**2
w02 = cc**2 + dc**2
- power = tt.sum(
+ power = pt.sum(
((ac * cc + bc * dc) * w02 + (ac * cc - bc * dc) * w2)
/ (w2 * w2 + 2.0 * (cc**2 - dc**2) * w2 + w02 * w02),
axis=-1,
@@ -90,55 +90,55 @@ def _get_psd_complex(self, coefficients, omega):
def to_dense(self, x, diag):
K = self.get_value(x[:, None] - x[None, :])
- K += tt.diag(diag)
+ K += pt.diag(diag)
return K
def get_celerite_matrices(self, x, diag, **kwargs):
- x = tt.as_tensor_variable(x)
- diag = tt.as_tensor_variable(diag)
+ x = pt.as_tensor_variable(x)
+ diag = pt.as_tensor_variable(diag)
cr, ar, Ur, Vr = self._get_celerite_matrices_real(
self.coefficients[:2], x, **kwargs
)
cc, ac, Uc, Vc = self._get_celerite_matrices_complex(
self.coefficients[2:], x, **kwargs
)
- c = tt.concatenate((cr, cc))
+ c = pt.concatenate((cr, cc))
a = diag + ar + ac
- U = tt.concatenate((Ur, Uc), axis=1)
- V = tt.concatenate((Vr, Vc), axis=1)
+ U = pt.concatenate((Ur, Uc), axis=1)
+ V = pt.concatenate((Vr, Vc), axis=1)
return c, a, U, V
def _get_celerite_matrices_real(self, coefficients, x, **kwargs):
ar, cr = coefficients
- z = tt.zeros_like(x)
+ z = pt.zeros_like(x)
return (
cr,
- tt.sum(ar),
+ pt.sum(ar),
ar[None, :] + z[:, None],
- tt.ones_like(ar)[None, :] + z[:, None],
+ pt.ones_like(ar)[None, :] + z[:, None],
)
def _get_celerite_matrices_complex(self, coefficients, x, **kwargs):
ac, bc, cc, dc = coefficients
arg = dc[None, :] * x[:, None]
- cos = tt.cos(arg)
- sin = tt.sin(arg)
- U = tt.concatenate(
+ cos = pt.cos(arg)
+ sin = pt.sin(arg)
+ U = pt.concatenate(
(
ac[None, :] * cos + bc[None, :] * sin,
ac[None, :] * sin - bc[None, :] * cos,
),
axis=1,
)
- V = tt.concatenate((cos, sin), axis=1)
- c = tt.concatenate((cc, cc))
+ V = pt.concatenate((cos, sin), axis=1)
+ c = pt.concatenate((cc, cc))
- return c, tt.sum(ac), U, V
+ return c, pt.sum(ac), U, V
def dot(self, x, diag, y):
- x = tt.as_tensor_variable(x)
- y = tt.as_tensor_variable(y)
+ x = pt.as_tensor_variable(x)
+ y = pt.as_tensor_variable(y)
is_vector = False
if y.ndim == 1:
@@ -169,24 +169,24 @@ def terms(self):
return self._terms
def get_value(self, tau):
- tau = tt.as_tensor_variable(tau)
+ tau = pt.as_tensor_variable(tau)
return sum(term.get_value(tau) for term in self._terms)
def get_psd(self, omega):
- omega = tt.as_tensor_variable(omega)
+ omega = pt.as_tensor_variable(omega)
return sum(term.get_psd(omega) for term in self._terms)
def get_celerite_matrices(self, x, diag, **kwargs):
matrices = (
- term.get_celerite_matrices(x, tt.zeros_like(diag), **kwargs)
+ term.get_celerite_matrices(x, pt.zeros_like(diag), **kwargs)
for term in self._terms
)
c, a, U, V = zip(*matrices)
return (
- tt.concatenate(c, axis=-1),
+ pt.concatenate(c, axis=-1),
sum(a) + diag,
- tt.concatenate(U, axis=-1),
- tt.concatenate(V, axis=-1),
+ pt.concatenate(U, axis=-1),
+ pt.concatenate(V, axis=-1),
)
@@ -199,7 +199,7 @@ def __init__(self, term1, term2, **kwargs):
self.dtype = kwargs.get("dtype", "float64")
def get_value(self, tau):
- tau = tt.as_tensor_variable(tau)
+ tau = pt.as_tensor_variable(tau)
return self.term1.get_value(tau) * self.term2.get_value(tau)
def get_psd(self, omega):
@@ -208,13 +208,13 @@ def get_psd(self, omega):
)
def get_celerite_matrices(self, x, diag, **kwargs):
- z = tt.zeros_like(diag)
+ z = pt.zeros_like(diag)
c1, a1, U1, V1 = self.term1.get_celerite_matrices(x, z, **kwargs)
c2, a2, U2, V2 = self.term2.get_celerite_matrices(x, z, **kwargs)
- mg = tt.mgrid[: c1.shape[0], : c2.shape[0]]
- i = tt.flatten(mg[0])
- j = tt.flatten(mg[1])
+ mg = pt.mgrid[: c1.shape[0], : c2.shape[0]]
+ i = pt.flatten(mg[0])
+ j = pt.flatten(mg[1])
c = c1[i] + c2[j]
a = a1 * a2 + diag
@@ -259,7 +259,7 @@ def __init__(self, term, delta, **kwargs):
"coefficients"
)
self.term = term
- self.delta = tt.as_tensor_variable(delta).astype("float64")
+ self.delta = pt.as_tensor_variable(delta).astype("float64")
super().__init__(**kwargs)
def get_celerite_matrices(self, x, diag, **kwargs):
@@ -268,7 +268,7 @@ def get_celerite_matrices(self, x, diag, **kwargs):
# Real part
cd = cr * dt
- delta_diag = 2 * tt.sum(ar * (cd - tt.sinh(cd)) / cd**2)
+ delta_diag = 2 * pt.sum(ar * (cd - pt.sinh(cd)) / cd**2)
# Complex part
cd = c * dt
@@ -279,12 +279,12 @@ def get_celerite_matrices(self, x, diag, **kwargs):
C1 = a * (c2 - d2) + 2 * b * c * d
C2 = b * (c2 - d2) - 2 * a * c * d
norm = (dt * c2pd2) ** 2
- sinh = tt.sinh(cd)
- cosh = tt.cosh(cd)
- delta_diag += 2 * tt.sum(
+ sinh = pt.sinh(cd)
+ cosh = pt.cosh(cd)
+ delta_diag += 2 * pt.sum(
(
- C2 * cosh * tt.sin(dd)
- - C1 * sinh * tt.cos(dd)
+ C2 * cosh * pt.sin(dd)
+ - C1 * sinh * pt.cos(dd)
+ (a * c + b * d) * dt * c2pd2
)
/ norm
@@ -298,7 +298,7 @@ def get_coefficients(self):
# Real componenets
crd = cr * self.delta
- coeffs = [2 * ar * (tt.cosh(crd) - 1) / crd**2, cr]
+ coeffs = [2 * ar * (pt.cosh(crd) - 1) / crd**2, cr]
# Imaginary coefficients
cd = c * self.delta
@@ -306,8 +306,8 @@ def get_coefficients(self):
c2 = c**2
d2 = d**2
factor = 2.0 / (self.delta * (c2 + d2)) ** 2
- cos_term = tt.cosh(cd) * tt.cos(dd) - 1
- sin_term = tt.sinh(cd) * tt.sin(dd)
+ cos_term = pt.cosh(cd) * pt.cos(dd) - 1
+ sin_term = pt.sinh(cd) * pt.sin(dd)
C1 = a * (c2 - d2) + 2 * b * c * d
C2 = b * (c2 - d2) - 2 * a * c * d
@@ -324,7 +324,7 @@ def get_coefficients(self):
def get_psd(self, omega):
psd0 = self.term.get_psd(omega)
arg = 0.5 * self.delta * omega
- sinc = tt.switch(tt.neq(arg, 0), tt.sin(arg) / arg, tt.ones_like(arg))
+ sinc = pt.switch(pt.neq(arg, 0), pt.sin(arg) / arg, pt.ones_like(arg))
return psd0 * sinc**2
def get_value(self, tau0):
@@ -332,7 +332,7 @@ def get_value(self, tau0):
ar, cr, a, b, c, d = self.term.coefficients
# Format the lags correctly
- tau0 = tt.abs(tau0)
+ tau0 = pt.abs(tau0)
tau = tau0[..., None]
# Precompute some factors
@@ -342,13 +342,13 @@ def get_value(self, tau0):
# Real parts:
# tau > Delta
crd = cr * dt
- cosh = tt.cosh(crd)
+ cosh = pt.cosh(crd)
norm = 2 * ar / crd**2
- K_large = tt.sum(norm * (cosh - 1) * tt.exp(-cr * tau), axis=-1)
+ K_large = pt.sum(norm * (cosh - 1) * pt.exp(-cr * tau), axis=-1)
# tau < Delta
crdmt = cr * dmt
- K_small = K_large + tt.sum(norm * (crdmt - tt.sinh(crdmt)), axis=-1)
+ K_small = K_large + pt.sum(norm * (crdmt - pt.sinh(crdmt)), axis=-1)
# Complex part
cd = c * dt
@@ -359,46 +359,46 @@ def get_value(self, tau0):
C1 = a * (c2 - d2) + 2 * b * c * d
C2 = b * (c2 - d2) - 2 * a * c * d
norm = 1.0 / (dt * c2pd2) ** 2
- k0 = tt.exp(-c * tau)
- cdt = tt.cos(d * tau)
- sdt = tt.sin(d * tau)
+ k0 = pt.exp(-c * tau)
+ cdt = pt.cos(d * tau)
+ sdt = pt.sin(d * tau)
# For tau > Delta
- cos_term = 2 * (tt.cosh(cd) * tt.cos(dd) - 1)
- sin_term = 2 * (tt.sinh(cd) * tt.sin(dd))
+ cos_term = 2 * (pt.cosh(cd) * pt.cos(dd) - 1)
+ sin_term = 2 * (pt.sinh(cd) * pt.sin(dd))
factor = k0 * norm
- K_large += tt.sum(
+ K_large += pt.sum(
(C1 * cos_term - C2 * sin_term) * factor * cdt, axis=-1
)
- K_large += tt.sum(
+ K_large += pt.sum(
(C2 * cos_term + C1 * sin_term) * factor * sdt, axis=-1
)
# tau < Delta
- edmt = tt.exp(-c * dmt)
- edpt = tt.exp(-c * dpt)
+ edmt = pt.exp(-c * dmt)
+ edpt = pt.exp(-c * dpt)
cos_term = (
- edmt * tt.cos(d * dmt) + edpt * tt.cos(d * dpt) - 2 * k0 * cdt
+ edmt * pt.cos(d * dmt) + edpt * pt.cos(d * dpt) - 2 * k0 * cdt
)
sin_term = (
- edmt * tt.sin(d * dmt) + edpt * tt.sin(d * dpt) - 2 * k0 * sdt
+ edmt * pt.sin(d * dmt) + edpt * pt.sin(d * dpt) - 2 * k0 * sdt
)
- K_small += tt.sum(2 * (a * c + b * d) * c2pd2 * dmt * norm, axis=-1)
- K_small += tt.sum((C1 * cos_term + C2 * sin_term) * norm, axis=-1)
+ K_small += pt.sum(2 * (a * c + b * d) * c2pd2 * dmt * norm, axis=-1)
+ K_small += pt.sum((C1 * cos_term + C2 * sin_term) * norm, axis=-1)
- return tt.switch(tt.le(tau0, dt), K_small, K_large)
+ return pt.switch(pt.le(tau0, dt), K_small, K_large)
class RealTerm(Term):
__doc__ = base_terms.RealTerm.__doc__
def __init__(self, *, a, c, **kwargs):
- self.a = tt.as_tensor_variable(a).astype("float64")
- self.c = tt.as_tensor_variable(c).astype("float64")
+ self.a = pt.as_tensor_variable(a).astype("float64")
+ self.c = pt.as_tensor_variable(c).astype("float64")
super().__init__(**kwargs)
def get_coefficients(self):
- empty = tt.zeros(0, dtype=self.dtype)
+ empty = pt.zeros(0, dtype=self.dtype)
if self.a.ndim == 0:
return (
self.a[None],
@@ -415,14 +415,14 @@ class ComplexTerm(Term):
__doc__ = base_terms.ComplexTerm.__doc__
def __init__(self, *, a, b, c, d, **kwargs):
- self.a = tt.as_tensor_variable(a).astype("float64")
- self.b = tt.as_tensor_variable(b).astype("float64")
- self.c = tt.as_tensor_variable(c).astype("float64")
- self.d = tt.as_tensor_variable(d).astype("float64")
+ self.a = pt.as_tensor_variable(a).astype("float64")
+ self.b = pt.as_tensor_variable(b).astype("float64")
+ self.c = pt.as_tensor_variable(c).astype("float64")
+ self.d = pt.as_tensor_variable(d).astype("float64")
super().__init__(**kwargs)
def get_coefficients(self):
- empty = tt.zeros(0, dtype=self.dtype)
+ empty = pt.zeros(0, dtype=self.dtype)
if self.a.ndim == 0:
return (
empty,
@@ -440,14 +440,14 @@ class SHOTerm(Term):
__parameter_spec__ = base_terms.SHOTerm.__parameter_spec__
@base_terms.handle_parameter_spec(
- lambda x: tt.as_tensor_variable(x).astype("float64")
+ lambda x: pt.as_tensor_variable(x).astype("float64")
)
def __init__(self, *, eps=1e-5, **kwargs):
- self.eps = tt.as_tensor_variable(eps).astype("float64")
+ self.eps = pt.as_tensor_variable(eps).astype("float64")
self.dtype = kwargs.get("dtype", "float64")
self.overdamped = self.get_overdamped_coefficients()
self.underdamped = self.get_underdamped_coefficients()
- self.cond = tt.lt(self.Q, 0.5)
+ self.cond = pt.lt(self.Q, 0.5)
def get_value(self, tau):
return ifelse(
@@ -464,8 +464,8 @@ def get_psd(self, omega):
)
def get_celerite_matrices(self, x, diag, **kwargs):
- x = tt.as_tensor_variable(x)
- diag = tt.as_tensor_variable(diag)
+ x = pt.as_tensor_variable(x)
+ diag = pt.as_tensor_variable(diag)
cr, ar, Ur, Vr = super()._get_celerite_matrices_real(
self.overdamped, x, **kwargs
)
@@ -475,38 +475,38 @@ def get_celerite_matrices(self, x, diag, **kwargs):
ar = ar + diag
ac = ac + diag
- cr, cc = tt.broadcast_arrays(cr, cc)
- ar, ac = tt.broadcast_arrays(ar, ac)
- Ur, Uc = tt.broadcast_arrays(Ur, Uc)
- Vr, Vc = tt.broadcast_arrays(Vr, Vc)
+ cr, cc = pt.broadcast_arrays(cr, cc)
+ ar, ac = pt.broadcast_arrays(ar, ac)
+ Ur, Uc = pt.broadcast_arrays(Ur, Uc)
+ Vr, Vc = pt.broadcast_arrays(Vr, Vc)
return [
- tt.switch(self.cond, a, b)
+ pt.switch(self.cond, a, b)
for a, b in zip((cr, ar, Ur, Vr), (cc, ac, Uc, Vc))
]
def get_overdamped_coefficients(self):
Q = self.Q
- f = tt.sqrt(tt.maximum(1.0 - 4.0 * Q**2, self.eps))
+ f = pt.sqrt(pt.maximum(1.0 - 4.0 * Q**2, self.eps))
return (
0.5
* self.S0
* self.w0
* Q
- * tt.stack([1.0 + 1.0 / f, 1.0 - 1.0 / f]),
- 0.5 * self.w0 / Q * tt.stack([1.0 - f, 1.0 + f]),
+ * pt.stack([1.0 + 1.0 / f, 1.0 - 1.0 / f]),
+ 0.5 * self.w0 / Q * pt.stack([1.0 - f, 1.0 + f]),
)
def get_underdamped_coefficients(self):
Q = self.Q
- f = tt.sqrt(tt.maximum(4.0 * Q**2 - 1.0, self.eps))
+ f = pt.sqrt(pt.maximum(4.0 * Q**2 - 1.0, self.eps))
a = self.S0 * self.w0 * Q
c = 0.5 * self.w0 / Q
return (
- tt.stack([a]),
- tt.stack([a / f]),
- tt.stack([c]),
- tt.stack([c * f]),
+ pt.stack([a]),
+ pt.stack([a / f]),
+ pt.stack([c]),
+ pt.stack([c * f]),
)
@@ -514,15 +514,15 @@ class Matern32Term(Term):
__doc__ = base_terms.Matern32Term.__doc__
def __init__(self, *, sigma, rho, eps=0.01, **kwargs):
- self.sigma = tt.as_tensor_variable(sigma).astype("float64")
- self.rho = tt.as_tensor_variable(rho).astype("float64")
- self.eps = tt.as_tensor_variable(eps).astype("float64")
+ self.sigma = pt.as_tensor_variable(sigma).astype("float64")
+ self.rho = pt.as_tensor_variable(rho).astype("float64")
+ self.eps = pt.as_tensor_variable(eps).astype("float64")
super().__init__(**kwargs)
def get_coefficients(self):
w0 = np.sqrt(3.0) / self.rho
S0 = self.sigma**2 / w0
- empty = tt.zeros(0, dtype=self.dtype)
+ empty = pt.zeros(0, dtype=self.dtype)
return (
empty,
empty,
@@ -537,22 +537,22 @@ class RotationTerm(TermSum):
__doc__ = base_terms.RotationTerm.__doc__
def __init__(self, *, sigma, period, Q0, dQ, f, **kwargs):
- self.sigma = tt.as_tensor_variable(sigma).astype("float64")
- self.period = tt.as_tensor_variable(period).astype("float64")
- self.Q0 = tt.as_tensor_variable(Q0).astype("float64")
- self.dQ = tt.as_tensor_variable(dQ).astype("float64")
- self.f = tt.as_tensor_variable(f).astype("float64")
+ self.sigma = pt.as_tensor_variable(sigma).astype("float64")
+ self.period = pt.as_tensor_variable(period).astype("float64")
+ self.Q0 = pt.as_tensor_variable(Q0).astype("float64")
+ self.dQ = pt.as_tensor_variable(dQ).astype("float64")
+ self.f = pt.as_tensor_variable(f).astype("float64")
self.amp = self.sigma**2 / (1 + self.f)
# One term with a period of period
Q1 = 0.5 + self.Q0 + self.dQ
- w1 = 4 * np.pi * Q1 / (self.period * tt.sqrt(4 * Q1**2 - 1))
+ w1 = 4 * np.pi * Q1 / (self.period * pt.sqrt(4 * Q1**2 - 1))
S1 = self.amp / (w1 * Q1)
# Another term at half the period
Q2 = 0.5 + self.Q0
- w2 = 8 * np.pi * Q2 / (self.period * tt.sqrt(4 * Q2**2 - 1))
+ w2 = 8 * np.pi * Q2 / (self.period * pt.sqrt(4 * Q2**2 - 1))
S2 = self.f * self.amp / (w2 * Q2)
super().__init__(
diff --git a/python/test/pymc/conftest.py b/python/test/pymc/conftest.py
new file mode 100644
index 0000000..94625c2
--- /dev/null
+++ b/python/test/pymc/conftest.py
@@ -0,0 +1,14 @@
+def pytest_configure(config):
+ try:
+ import pytensor
+ except ImportError:
+ return
+
+ import platform
+
+ pytensor.config.floatX = "float64"
+ # TODO: Uncomment when PyMC commit 714b4a0 makes it into a release (probably 5.9.2)
+ # pytensor.config.compute_test_value = "raise"
+ if platform.system() == "Darwin":
+ pytensor.config.gcc.cxxflags = "-Wno-c++11-narrowing"
+ config.addinivalue_line("filterwarnings", "ignore")
diff --git a/python/test/pymc4/test_pymc4_celerite2.py b/python/test/pymc/test_pymc_celerite2.py
similarity index 92%
rename from python/test/pymc4/test_pymc4_celerite2.py
rename to python/test/pymc/test_pymc_celerite2.py
index 26e4100..d39a5d9 100644
--- a/python/test/pymc4/test_pymc4_celerite2.py
+++ b/python/test/pymc/test_pymc_celerite2.py
@@ -5,12 +5,12 @@
import celerite2
-pytest.importorskip("celerite2.pymc4")
+pytest.importorskip("celerite2.pymc")
try:
from celerite2 import terms as pyterms
- from celerite2.pymc4 import GaussianProcess, terms
- from celerite2.pymc4.celerite2 import CITATIONS
+ from celerite2.pymc import GaussianProcess, terms
+ from celerite2.pymc.celerite2 import CITATIONS
from celerite2.testing import check_gp_models
except (ImportError, ModuleNotFoundError):
pass
@@ -115,8 +115,8 @@ def test_marginal(data):
gp.marginal("obs", observed=y)
np.testing.assert_allclose(
- model.compile_logp()(model.test_point),
- model.compile_fn(gp.log_likelihood(y))(model.test_point),
+ model.compile_logp()(model.initial_point()),
+ model.compile_fn(gp.log_likelihood(y))(model.initial_point()),
)
diff --git a/python/test/pymc4/test_pymc4_ops.py b/python/test/pymc/test_pymc_ops.py
similarity index 81%
rename from python/test/pymc4/test_pymc4_ops.py
rename to python/test/pymc/test_pymc_ops.py
index dc1da8d..5f1b834 100644
--- a/python/test/pymc4/test_pymc4_ops.py
+++ b/python/test/pymc/test_pymc_ops.py
@@ -8,18 +8,18 @@
from celerite2 import backprop, driver
from celerite2.testing import get_matrices
-pytest.importorskip("celerite2.pymc4")
+pytest.importorskip("celerite2.pymc")
try:
- import aesara
- import aesara.tensor as tt
- from aesara.compile.mode import Mode
- from aesara.compile.sharedvalue import SharedVariable
- from aesara.graph.fg import FunctionGraph
- from aesara.graph.optdb import OptimizationQuery
- from aesara.link.jax import JAXLinker
-
- from celerite2.pymc4 import ops
+ import pytensor
+ import pytensor.tensor as tt
+ from pytensor.compile.mode import Mode
+ from pytensor.compile.sharedvalue import SharedVariable
+ from pytensor.graph.fg import FunctionGraph
+ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
+ from pytensor.link.jax import JAXLinker
+
+ from celerite2.pymc import ops
except (ImportError, ModuleNotFoundError):
pass
@@ -30,7 +30,9 @@
jax = None
else:
- opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
+ opts = RewriteDatabaseQuery(
+ include=[None], exclude=["cxx_only", "BlasOpt"]
+ )
jax_mode = Mode(JAXLinker(), opts)
py_mode = Mode("py", opts)
@@ -54,20 +56,20 @@ def convert_values_to_types(values):
def check_shape(op, inputs, outputs, values, result, multi):
if multi:
- shapes = aesara.function(inputs, [o.shape for o in outputs])(*values)
+ shapes = pytensor.function(inputs, [o.shape for o in outputs])(*values)
assert all(
np.all(v.shape == s) for v, s in zip(result, shapes)
), "Invalid shape inference"
else:
- shape = aesara.function(inputs, outputs.shape)(*values)
+ shape = pytensor.function(inputs, outputs.shape)(*values)
assert result.shape == shape
def check_basic(ref_func, op, values):
inputs = convert_values_to_types(values)
outputs = op(*inputs)
- result = aesara.function(inputs, outputs)(*values)
+ result = pytensor.function(inputs, outputs)(*values)
try:
result.shape
@@ -93,7 +95,7 @@ def check_basic(ref_func, op, values):
def check_grad(op, values, num_out, eps=1.234e-8):
inputs = convert_values_to_types(values)
outputs = op(*inputs)
- func = aesara.function(inputs, outputs)
+ func = pytensor.function(inputs, outputs)
vals0 = func(*values)
# Compute numerical grad
@@ -113,8 +115,8 @@ def check_grad(op, values, num_out, eps=1.234e-8):
# Compute the backprop
for k in range(num_out):
for i in range(vals0[k].size):
- res = aesara.function(
- inputs, aesara.grad(outputs[k].flatten()[i], inputs)
+ res = pytensor.function(
+ inputs, pytensor.grad(outputs[k].flatten()[i], inputs)
)(*values)
for n, b in enumerate(res):
@@ -124,7 +126,7 @@ def check_grad(op, values, num_out, eps=1.234e-8):
if jax is not None:
for k in range(num_out):
- out_grad = aesara.grad(tt.sum(outputs[k]), inputs)
+ out_grad = pytensor.grad(tt.sum(outputs[k]), inputs)
fg = FunctionGraph(inputs, out_grad)
compare_jax_and_py(fg, values)
@@ -248,20 +250,19 @@ def compare_jax_and_py(
assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)
fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)]
- aesara_jax_fn = aesara.function(fn_inputs, fgraph.outputs, mode=jax_mode)
- jax_res = aesara_jax_fn(*test_inputs)
+ pytensor_jax_fn = pytensor.function(
+ fn_inputs, fgraph.outputs, mode=jax_mode
+ )
+ jax_res = pytensor_jax_fn(*test_inputs)
if must_be_device_array:
if isinstance(jax_res, list):
- assert all(
- isinstance(res, jax.interpreters.xla.DeviceArray)
- for res in jax_res
- )
+ assert all(isinstance(res, jax.Array) for res in jax_res)
else:
- assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
+ assert isinstance(jax_res, jax.Array)
- aesara_py_fn = aesara.function(fn_inputs, fgraph.outputs, mode=py_mode)
- py_res = aesara_py_fn(*test_inputs)
+ pytensor_py_fn = pytensor.function(fn_inputs, fgraph.outputs, mode=py_mode)
+ py_res = pytensor_py_fn(*test_inputs)
if len(fgraph.outputs) > 1:
for j, p in zip(jax_res, py_res):
diff --git a/python/test/pymc4/test_pymc4_terms.py b/python/test/pymc/test_pymc_terms.py
similarity index 90%
rename from python/test/pymc4/test_pymc4_terms.py
rename to python/test/pymc/test_pymc_terms.py
index 0858e58..ea186fc 100644
--- a/python/test/pymc4/test_pymc4_terms.py
+++ b/python/test/pymc/test_pymc_terms.py
@@ -5,11 +5,11 @@
import numpy as np
import pytest
-pytest.importorskip("celerite2.pymc4")
+pytest.importorskip("celerite2.pymc")
try:
from celerite2 import terms as pyterms
- from celerite2.pymc4 import terms
+ from celerite2.pymc import terms
from celerite2.testing import check_tensor_term
except (ImportError, ModuleNotFoundError):
pass
@@ -72,16 +72,16 @@ def test_base_terms(name, args):
def test_opt_error():
- import aesara.tensor as at
- from aesara import config, function, grad
+ import pytensor.tensor as pt
+ from pytensor import config, function, grad
x = np.linspace(0, 5, 10)
diag = np.full_like(x, 0.2)
with config.change_flags(on_opt_error="raise"):
- arg = at.scalar()
+ arg = pt.scalar()
arg.tag.test_value = 0.5
matrices = terms.SHOTerm(S0=1.0, w0=0.5, Q=arg).get_celerite_matrices(
x, diag
)
- function([arg], grad(sum(at.sum(m) for m in matrices), [arg]))(0.5)
+ function([arg], grad(sum(pt.sum(m) for m in matrices), [arg]))(0.5)
diff --git a/python/test/pymc4/conftest.py b/python/test/pymc4/conftest.py
deleted file mode 100644
index 3866a42..0000000
--- a/python/test/pymc4/conftest.py
+++ /dev/null
@@ -1,13 +0,0 @@
-def pytest_configure(config):
- try:
- import aesara
- except ImportError:
- return
-
- import platform
-
- aesara.config.floatX = "float64"
- aesara.config.compute_test_value = "raise"
- if platform.system() == "Darwin":
- aesara.config.gcc.cxxflags = "-Wno-c++11-narrowing"
- config.addinivalue_line("filterwarnings", "ignore")
diff --git a/setup.py b/setup.py
index 45778f8..60ad2ac 100644
--- a/setup.py
+++ b/setup.py
@@ -39,8 +39,8 @@
"scipy",
"celerite>=0.3.1",
],
- "pymc3": ["pymc3>=3.9", "numpy<1.22"],
- "pymc4": ["pymc>=4,<5"],
+ "pymc3": ["pymc3>=3.9", "numpy<1.22", "xarray<2023.10.0"],
+ "pymc": ["pymc>=5"],
"jax": ["jax", "jaxlib"],
"docs": [
"sphinx",
@@ -53,7 +53,7 @@
"matplotlib",
"scipy",
"emcee",
- "pymc>=4,<5",
+ "pymc>=5",
"tqdm",
"numpyro",
],