diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 3387a49..00a7035 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -25,8 +25,8 @@ jobs: - "core" - "jax" - "pymc3" - # - "pymc4" - # - "pymc4_jax" + - "pymc" + - "pymc_jax" steps: - name: Checkout diff --git a/docs/api/pymc.rst b/docs/api/pymc.rst new file mode 100644 index 0000000..b140466 --- /dev/null +++ b/docs/api/pymc.rst @@ -0,0 +1,46 @@ +.. _pymc-api: + +PyMC (v5+) interface +=============== + +This ``celerite2.pymc`` submodule provides access to the *celerite2* models +within the `PyTensor `_ framework. Of special +interest, this adds support for probabilistic model building using `PyMC +`_ v5 or later. + +*Note: PyMC v4 was a short-lived version of PyMC with the Aesara backend. +Celerite2 now only supports to PyMC 5, but past releases of celerite2 +might work with Aesara.* + +The :ref:`first` tutorial demonstrates the use of this interface, while this +page provides the details for the :class:`celerite2.pymc.GaussianProcess` class +which provides all this functionality. This page does not include documentation +for the term models defined in PyTensor, but you can refer to the +:ref:`python-terms` section of the :ref:`python-api` documentation. All of those +models are implemented in PyTensor and you can access them using something like +the following: + +.. code-block:: python + + import pytensor.tensor as pt + from celerite2.pymc import GaussianProcess, terms + + term = terms.SHOTerm(S0=at.scalar(), w0=at.scalar(), Q=at.scalar()) + gp = GaussianProcess(term) + +The :class:`celerite2.pymc.GaussianProcess` class is detailed below: + +.. autoclass:: celerite2.pymc.GaussianProcess + :inherited-members: + :exclude-members: sample, sample_conditional, recompute + + +PyMC (v5+) support +----------------- + +This implementation comes with a custom PyMC ``Distribution`` that represents a +multivariate normal with a *celerite* covariance matrix. This is used by the +:func:`celerite2.pymc.GaussianProcess.marginal` method documented above which +adds a marginal likelihood node to a PyMC model. + +.. autoclass:: celerite2.pymc.distribution.CeleriteNormal diff --git a/docs/api/pymc4.rst b/docs/api/pymc4.rst deleted file mode 100644 index 3840c64..0000000 --- a/docs/api/pymc4.rst +++ /dev/null @@ -1,42 +0,0 @@ -.. _pymc4-api: - -PyMC4 interface -=============== - -This ``celerite2.pymc4`` submodule provides access to the *celerite2* models -within the `Aesara `_ framework. Of special -interest, this adds support for probabilistic model building using `PyMC -`_ v4 or later. - -The :ref:`first` tutorial demonstrates the use of this interface, while this -page provides the details for the :class:`celerite2.pymc4.GaussianProcess` class -which provides all this functionality. This page does not include documentation -for the term models defined in Aesara, but you can refer to the -:ref:`python-terms` section of the :ref:`python-api` documentation. All of those -models are implemented in Aesara and you can access them using something like -the following: - -.. code-block:: python - - import aesara.tensor as at - from celerite2.pymc4 import GaussianProcess, terms - - term = terms.SHOTerm(S0=at.scalar(), w0=at.scalar(), Q=at.scalar()) - gp = GaussianProcess(term) - -The :class:`celerite2.pymc4.GaussianProcess` class is detailed below: - -.. autoclass:: celerite2.pymc4.GaussianProcess - :inherited-members: - :exclude-members: sample, sample_conditional, recompute - - -PyMC (v4) support ------------------ - -This implementation comes with a custom PyMC ``Distribution`` that represents a -multivariate normal with a *celerite* covariance matrix. This is used by the -:func:`celerite2.pymc4.GaussianProcess.marginal` method documented above which -adds a marginal likelihood node to a PyMC model. - -.. autoclass:: celerite2.pymc4.distribution.CeleriteNormal diff --git a/docs/index.rst b/docs/index.rst index 9b1c3a8..7d0fbfd 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,7 +7,7 @@ Regression in one dimension and this library, *celerite2* is a re-write of the original `celerite project `_ to improve numerical stability and integration with various machine learning frameworks. This implementation includes interfaces in Python and C++, with full support for -PyMC (v3 and v4) and JAX. +PyMC (v3 and v5+) and JAX. This documentation won't teach you the fundamentals of GP modeling but the best resource for learning about this is available for free online: `Rasmussen & @@ -43,7 +43,7 @@ an issue `_ there. api/python api/pymc3 - api/pymc4 + api/pymc api/jax api/c++ diff --git a/docs/tutorials/first.ipynb b/docs/tutorials/first.ipynb index 4e6e965..9aa422f 100644 --- a/docs/tutorials/first.ipynb +++ b/docs/tutorials/first.ipynb @@ -330,7 +330,7 @@ "source": [ "## Posterior inference using PyMC\n", "\n", - "*celerite2* also includes support for probabilistic modeling using PyMC (v4 or v3, using the `celerite2.pymc3` or `celerite2.pymc4` submodule respectively), and we can implement the same model from above as follows:" + "*celerite2* also includes support for probabilistic modeling using PyMC (v5 or v3, using the `celerite2.pymc` or `celerite2.pymc3` submodule respectively), and we can implement the same model from above as follows:" ] }, { @@ -340,50 +340,44 @@ "metadata": {}, "outputs": [], "source": [ - "import warnings\n", + "import pymc as pm\n", + "from celerite2.pymc import GaussianProcess, terms as pm_terms\n", "\n", - "with warnings.catch_warnings():\n", - " warnings.filterwarnings(\"ignore\", category=UserWarning)\n", - " warnings.filterwarnings(\"ignore\", category=RuntimeWarning)\n", + "with pm.Model() as model:\n", + " mean = pm.Normal(\"mean\", mu=0.0, sigma=prior_sigma)\n", + " log_jitter = pm.Normal(\"log_jitter\", mu=0.0, sigma=prior_sigma)\n", "\n", - " import pymc as pm\n", - " from celerite2.pymc4 import GaussianProcess, terms as pm_terms\n", - "\n", - " with pm.Model() as model:\n", - " mean = pm.Normal(\"mean\", mu=0.0, sigma=prior_sigma)\n", - " log_jitter = pm.Normal(\"log_jitter\", mu=0.0, sigma=prior_sigma)\n", - "\n", - " log_sigma1 = pm.Normal(\"log_sigma1\", mu=0.0, sigma=prior_sigma)\n", - " log_rho1 = pm.Normal(\"log_rho1\", mu=0.0, sigma=prior_sigma)\n", - " log_tau = pm.Normal(\"log_tau\", mu=0.0, sigma=prior_sigma)\n", - " term1 = pm_terms.SHOTerm(\n", - " sigma=pm.math.exp(log_sigma1),\n", - " rho=pm.math.exp(log_rho1),\n", - " tau=pm.math.exp(log_tau),\n", - " )\n", + " log_sigma1 = pm.Normal(\"log_sigma1\", mu=0.0, sigma=prior_sigma)\n", + " log_rho1 = pm.Normal(\"log_rho1\", mu=0.0, sigma=prior_sigma)\n", + " log_tau = pm.Normal(\"log_tau\", mu=0.0, sigma=prior_sigma)\n", + " term1 = pm_terms.SHOTerm(\n", + " sigma=pm.math.exp(log_sigma1),\n", + " rho=pm.math.exp(log_rho1),\n", + " tau=pm.math.exp(log_tau),\n", + " )\n", "\n", - " log_sigma2 = pm.Normal(\"log_sigma2\", mu=0.0, sigma=prior_sigma)\n", - " log_rho2 = pm.Normal(\"log_rho2\", mu=0.0, sigma=prior_sigma)\n", - " term2 = pm_terms.SHOTerm(\n", - " sigma=pm.math.exp(log_sigma2), rho=pm.math.exp(log_rho2), Q=0.25\n", - " )\n", + " log_sigma2 = pm.Normal(\"log_sigma2\", mu=0.0, sigma=prior_sigma)\n", + " log_rho2 = pm.Normal(\"log_rho2\", mu=0.0, sigma=prior_sigma)\n", + " term2 = pm_terms.SHOTerm(\n", + " sigma=pm.math.exp(log_sigma2), rho=pm.math.exp(log_rho2), Q=0.25\n", + " )\n", "\n", - " kernel = term1 + term2\n", - " gp = GaussianProcess(kernel, mean=mean)\n", - " gp.compute(t, diag=yerr**2 + pm.math.exp(log_jitter), quiet=True)\n", - " gp.marginal(\"obs\", observed=y)\n", + " kernel = term1 + term2\n", + " gp = GaussianProcess(kernel, mean=mean)\n", + " gp.compute(t, diag=yerr**2 + pm.math.exp(log_jitter), quiet=True)\n", + " gp.marginal(\"obs\", observed=y)\n", "\n", - " pm.Deterministic(\"psd\", kernel.get_psd(omega))\n", + " pm.Deterministic(\"psd\", kernel.get_psd(omega))\n", "\n", - " trace = pm.sample(\n", - " tune=1000,\n", - " draws=1000,\n", - " target_accept=0.9,\n", - " init=\"adapt_full\",\n", - " cores=2,\n", - " chains=2,\n", - " random_seed=34923,\n", - " )" + " trace = pm.sample(\n", + " tune=1000,\n", + " draws=1000,\n", + " target_accept=0.9,\n", + " init=\"adapt_full\",\n", + " cores=2,\n", + " chains=2,\n", + " random_seed=34923,\n", + " )" ] }, { @@ -411,7 +405,7 @@ "plt.xlim(freq.min(), freq.max())\n", "plt.xlabel(\"frequency [1 / day]\")\n", "plt.ylabel(\"power [day ppt$^2$]\")\n", - "_ = plt.title(\"posterior psd using PyMC3\")" + "_ = plt.title(\"posterior psd using PyMC\")" ] }, { @@ -488,7 +482,7 @@ "id": "6f3abdcf", "metadata": {}, "source": [ - "This runtime was similar to the PyMC3 result from above, and (as we'll see below) the convergence is also similar.\n", + "This runtime was similar to the PyMC result from above, and (as we'll see below) the convergence is also similar.\n", "Any difference in runtime will probably disappear for more computationally expensive models, but this interface is looking pretty great here!\n", "\n", "As above, we can plot the posterior expectations for the power spectrum:" @@ -563,7 +557,7 @@ " bins,\n", " histtype=\"step\",\n", " density=True,\n", - " label=\"PyMC3\",\n", + " label=\"PyMC\",\n", ")\n", "plt.hist(\n", " np.exp(np.asarray((numpyro_data.posterior[\"log_rho1\"].T)).flatten()),\n", @@ -657,7 +651,7 @@ "metadata": {}, "source": [ "Overall these results are consistent, but the $\\hat{R}$ values are a bit high for the emcee run, so I'd probably run that for longer.\n", - "Either way, for models like these, PyMC3 and numpyro are generally going to be much better inference tools (in terms of runtime per effective sample) than emcee, so those are the recommended interfaces if the rest of your model can be easily implemented in such a framework." + "Either way, for models like these, PyMC and numpyro are generally going to be much better inference tools (in terms of runtime per effective sample) than emcee, so those are the recommended interfaces if the rest of your model can be easily implemented in such a framework." ] } ], diff --git a/noxfile.py b/noxfile.py index ff729c8..1560160 100644 --- a/noxfile.py +++ b/noxfile.py @@ -30,20 +30,20 @@ def pymc3(session): @nox.session(python=ALL_PYTHON_VS) -def pymc4(session): - session.install(".[test,pymc4]") - _session_run(session, "python/test/pymc4") +def pymc(session): + session.install(".[test,pymc]") + _session_run(session, "python/test/pymc") @nox.session(python=ALL_PYTHON_VS) -def pymc4_jax(session): - session.install(".[test,jax,pymc4]") - _session_run(session, "python/test/pymc4/test_pymc4_ops.py") +def pymc_jax(session): + session.install(".[test,jax,pymc]") + _session_run(session, "python/test/pymc/test_pymc_ops.py") @nox.session(python=ALL_PYTHON_VS) def full(session): - session.install(".[test,jax,pymc3,pymc4]") + session.install(".[test,jax,pymc3,pymc]") _session_run(session, "python/test") diff --git a/python/celerite2/jax/ops.py b/python/celerite2/jax/ops.py index 17e1666..bc6e983 100644 --- a/python/celerite2/jax/ops.py +++ b/python/celerite2/jax/ops.py @@ -25,7 +25,7 @@ import pkg_resources from jax import core, lax from jax import numpy as jnp -from jax.abstract_arrays import ShapedArray +from jax.core import ShapedArray from jax.interpreters import ad, xla from jax.lib import xla_client @@ -201,8 +201,8 @@ def _rev_translation_rule(name, spec, c, *args): def _build_op(name, spec): - xla_client.register_cpu_custom_call_target( - name, getattr(xla_ops, spec["name"])() + xla_client.register_custom_call_target( + name, getattr(xla_ops, spec["name"])(), platform="cpu" ) prim = core.Primitive(f"celerite2_{spec['name']}") @@ -216,8 +216,10 @@ def _build_op(name, spec): if not spec["has_rev"]: return prim, None - xla_client.register_cpu_custom_call_target( - name + b"_rev", getattr(xla_ops, f"{spec['name']}_rev")() + xla_client.register_custom_call_target( + name + b"_rev", + getattr(xla_ops, f"{spec['name']}_rev")(), + platform="cpu", ) jvp_prim = core.Primitive(f"celerite2_{spec['name']}_jvp") diff --git a/python/celerite2/pymc4/__init__.py b/python/celerite2/pymc/__init__.py similarity index 65% rename from python/celerite2/pymc4/__init__.py rename to python/celerite2/pymc/__init__.py index 76c8701..ee77b01 100644 --- a/python/celerite2/pymc4/__init__.py +++ b/python/celerite2/pymc/__init__.py @@ -4,27 +4,27 @@ def __set_compiler_flags(): - import aesara + import pytensor def add_flag(current, new): if new in current: return current return f"{current} {new}" - current = aesara.config.gcc__cxxflags + current = pytensor.config.gcc__cxxflags current = add_flag(current, "-Wno-c++11-narrowing") current = add_flag(current, "-fno-exceptions") current = add_flag(current, "-fno-unwind-tables") current = add_flag(current, "-fno-asynchronous-unwind-tables") - aesara.config.gcc__cxxflags = current + pytensor.config.gcc__cxxflags = current __set_compiler_flags() -from celerite2.pymc4 import terms -from celerite2.pymc4.celerite2 import GaussianProcess +from celerite2.pymc import terms # noqa +from celerite2.pymc.celerite2 import GaussianProcess # noqa try: - from celerite2.pymc4 import jax_support # noqa + from celerite2.pymc import jax_support # noqa except ImportError: pass diff --git a/python/celerite2/pymc4/celerite2.py b/python/celerite2/pymc/celerite2.py similarity index 87% rename from python/celerite2/pymc4/celerite2.py rename to python/celerite2/pymc/celerite2.py index 735007d..13386d9 100644 --- a/python/celerite2/pymc4/celerite2.py +++ b/python/celerite2/pymc/celerite2.py @@ -2,13 +2,13 @@ __all__ = ["GaussianProcess", "ConditionalDistribution"] -import aesara.tensor as tt import numpy as np -from aesara.raise_op import Assert +import pytensor.tensor as pt +from pytensor.raise_op import Assert from celerite2.citation import CITATIONS from celerite2.core import BaseConditionalDistribution, BaseGaussianProcess -from celerite2.pymc4 import ops +from celerite2.pymc import ops class ConditionalDistribution(BaseConditionalDistribution): @@ -22,37 +22,37 @@ def _do_general_matmul(self, c, U1, V1, U2, V2, inp, target): return target def _diagdot(self, a, b): - return tt.batched_dot(a.T, b.T) + return pt.batched_dot(a.T, b.T) class GaussianProcess(BaseGaussianProcess): conditional_distribution = ConditionalDistribution def _as_tensor(self, tensor): - return tt.as_tensor_variable(tensor).astype("float64") + return pt.as_tensor_variable(tensor).astype("float64") def _zeros_like(self, tensor): - return tt.zeros_like(tensor) + return pt.zeros_like(tensor) def _do_compute(self, quiet): if quiet: self._d, self._W, _ = ops.factor_quiet( self._t, self._c, self._a, self._U, self._V ) - self._log_det = tt.switch( - tt.any(self._d <= 0.0), -np.inf, tt.sum(tt.log(self._d)) + self._log_det = pt.switch( + pt.any(self._d <= 0.0), -np.inf, pt.sum(pt.log(self._d)) ) else: self._d, self._W, _ = ops.factor( self._t, self._c, self._a, self._U, self._V ) - self._log_det = tt.sum(tt.log(self._d)) + self._log_det = pt.sum(pt.log(self._d)) self._norm = -0.5 * (self._log_det + self._size * np.log(2 * np.pi)) def _check_sorted(self, t): - return Assert()(t, tt.all(t[1:] - t[:-1] >= 0)) + return Assert()(t, pt.all(t[1:] - t[:-1] >= 0)) def _do_solve(self, y): z = ops.solve_lower(self._t, self._c, self._U, self._W, y)[0] @@ -69,7 +69,7 @@ def _do_norm(self, y): alpha = ops.solve_lower( self._t, self._c, self._U, self._W, y[:, None] )[0][:, 0] - return tt.sum(alpha**2 / self._d) + return pt.sum(alpha**2 / self._d) def _add_citations_to_pymc_model(self, **kwargs): import pymc as pm @@ -87,10 +87,10 @@ def marginal(self, name, **kwargs): observed (optional): The observed data Returns: - A :class:`celerite2.pymc3.CeleriteNormal` distribution + A :class:`celerite2.pymc.CeleriteNormal` distribution representing the marginal likelihood. """ - from celerite2.pymc4.distribution import CeleriteNormal + from celerite2.pymc.distribution import CeleriteNormal self._add_citations_to_pymc_model(**kwargs) return CeleriteNormal( @@ -108,7 +108,7 @@ def marginal(self, name, **kwargs): def conditional( self, name, y, t=None, include_mean=True, kernel=None, **kwargs ): - """Add a variable representing the conditional density to a PyMC3 model + """Add a variable representing the conditional density to a PyMC model .. note:: The performance of this method will generally be poor since the sampler will numerically sample this parameter. Depending on diff --git a/python/celerite2/pymc4/distribution.py b/python/celerite2/pymc/distribution.py similarity index 76% rename from python/celerite2/pymc4/distribution.py rename to python/celerite2/pymc/distribution.py index 3754339..1583772 100644 --- a/python/celerite2/pymc4/distribution.py +++ b/python/celerite2/pymc/distribution.py @@ -2,16 +2,19 @@ __all__ = ["CeleriteNormal"] -import aesara.tensor as tt import numpy as np -from aesara.tensor.random.op import RandomVariable -from aesara.tensor.random.utils import broadcast_params +import pytensor.tensor as pt from pymc.distributions.dist_math import check_parameters from pymc.distributions.distribution import Continuous from pymc.distributions.shape_utils import rv_size_is_none +from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.random.utils import ( + broadcast_params, + supp_shape_from_ref_param_shape, +) import celerite2.driver as driver -from celerite2.pymc4 import ops +from celerite2.pymc import ops def safe_celerite_normal(rng, mean, norm, t, c, U, W, d, size=None): @@ -33,6 +36,14 @@ class CeleriteNormalRV(RandomVariable): dtype = "floatX" _print_name = ("CeleriteNormal", "\\operatorname{CeleriteNormal}") + def _supp_shape_from_params(self, dist_params, param_shapes=None): + return supp_shape_from_ref_param_shape( + ndim_supp=self.ndim_supp, + dist_params=dist_params, + param_shapes=param_shapes, + ref_param_idx=2, + ) + @classmethod def rng_fn(cls, rng, mean, norm, t, c, U, W, d, size): if any( @@ -90,26 +101,26 @@ class CeleriteNormal(Continuous): @classmethod def dist(cls, mean, norm, t, c, U, W, d, **kwargs): - mean = tt.as_tensor_variable(mean) - norm = tt.as_tensor_variable(norm) - t = tt.as_tensor_variable(t) - c = tt.as_tensor_variable(c) - U = tt.as_tensor_variable(U) - W = tt.as_tensor_variable(W) - d = tt.as_tensor_variable(d) - mean = tt.broadcast_arrays(mean, t)[0] + mean = pt.as_tensor_variable(mean) + norm = pt.as_tensor_variable(norm) + t = pt.as_tensor_variable(t) + c = pt.as_tensor_variable(c) + U = pt.as_tensor_variable(U) + W = pt.as_tensor_variable(W) + d = pt.as_tensor_variable(d) + mean = pt.broadcast_arrays(mean, t)[0] return super().dist([mean, norm, t, c, U, W, d], **kwargs) def moment(rv, size, mean, *args): moment = mean if not rv_size_is_none(size): - moment_size = tt.concatenate([size, [mean.shape[-1]]]) - moment = tt.full(moment_size, mean) + moment_size = pt.concatenate([size, [mean.shape[-1]]]) + moment = pt.full(moment_size, mean) return moment def logp(value, mean, norm, t, c, U, W, d): - ok = tt.all(tt.gt(d, 0.0)) + ok = pt.all(pt.gt(d, 0.0)) alpha = value - mean alpha = ops.solve_lower(t, c, U, W, alpha[:, None])[0][:, 0] - logp = norm - 0.5 * tt.sum(alpha**2 / d) + logp = norm - 0.5 * pt.sum(alpha**2 / d) return check_parameters(logp, ok) diff --git a/python/celerite2/pymc4/jax_support.py b/python/celerite2/pymc/jax_support.py similarity index 71% rename from python/celerite2/pymc4/jax_support.py rename to python/celerite2/pymc/jax_support.py index 2d268eb..b41af72 100644 --- a/python/celerite2/pymc4/jax_support.py +++ b/python/celerite2/pymc/jax_support.py @@ -1,10 +1,10 @@ -from aesara.link.jax.dispatch import jax_funcify +from pytensor.link.jax.dispatch import jax_funcify from celerite2.jax import ops as jax_ops -from celerite2.pymc4 import ops as pymc4_ops +from celerite2.pymc import ops as pymc_ops -@jax_funcify.register(pymc4_ops._CeleriteOp) +@jax_funcify.register(pymc_ops._CeleriteOp) def jax_funcify_Celerite(op, **kwargs): name = op.name if name.endswith("_fwd"): diff --git a/python/celerite2/pymc4/ops.py b/python/celerite2/pymc/ops.py similarity index 95% rename from python/celerite2/pymc4/ops.py rename to python/celerite2/pymc/ops.py index ff0cc4f..c7ecdc5 100644 --- a/python/celerite2/pymc4/ops.py +++ b/python/celerite2/pymc/ops.py @@ -14,11 +14,11 @@ import json from itertools import chain -import aesara -import aesara.tensor as tt import numpy as np import pkg_resources -from aesara.graph import basic, op +import pytensor +import pytensor.tensor as pt +from pytensor.graph import basic, op import celerite2.backprop as backprop import celerite2.driver as driver @@ -82,7 +82,7 @@ def make_node(self, *inputs): for spec in self.spec["dimensions"] } otypes = [ - tt.TensorType( + pt.TensorType( "float64", [broadcastable[k] for k in spec["shape"]] )() for spec in self.spec["outputs"] + self.spec["extra_outputs"] @@ -130,8 +130,8 @@ def perform(self, node, inputs, outputs): def grad(self, inputs, gradients): outputs = self(*inputs) grads = ( - tt.zeros_like(outputs[n]) - if isinstance(b.type, aesara.gradient.DisconnectedType) + pt.zeros_like(outputs[n]) + if isinstance(b.type, pytensor.gradient.DisconnectedType) else b for n, b in enumerate(gradients[: len(self.spec["outputs"])]) ) diff --git a/python/celerite2/pymc4/terms.py b/python/celerite2/pymc/terms.py similarity index 69% rename from python/celerite2/pymc4/terms.py rename to python/celerite2/pymc/terms.py index 4472d5e..ad14221 100644 --- a/python/celerite2/pymc4/terms.py +++ b/python/celerite2/pymc/terms.py @@ -13,13 +13,13 @@ "RotationTerm", ] -import aesara -import aesara.tensor as tt import numpy as np -from aesara.ifelse import ifelse +import pytensor +import pytensor.tensor as pt +from pytensor.ifelse import ifelse import celerite2.terms as base_terms -from celerite2.pymc4 import ops +from celerite2.pymc import ops class Term(base_terms.Term): @@ -30,11 +30,11 @@ def __init__(self, *, dtype="float64"): self.coefficients = self.get_coefficients() def __add__(self, b): - dtype = aesara.scalar.upcast(self.dtype, b.dtype) + dtype = pytensor.scalar.upcast(self.dtype, b.dtype) return TermSum(self, b, dtype=dtype) def __mul__(self, b): - dtype = aesara.scalar.upcast(self.dtype, b.dtype) + dtype = pytensor.scalar.upcast(self.dtype, b.dtype) return TermProduct(self, b, dtype=dtype) @property @@ -51,17 +51,17 @@ def get_value(self, tau): def _get_value_real(self, coefficients, tau): ar, cr = coefficients - tau = tt.abs(tau) - tau = tt.shape_padright(tau) - return tt.sum(ar * tt.exp(-cr * tau), axis=-1) + tau = pt.abs(tau) + tau = pt.shape_padright(tau) + return pt.sum(ar * pt.exp(-cr * tau), axis=-1) def _get_value_complex(self, coefficients, tau): ac, bc, cc, dc = coefficients - tau = tt.abs(tau) - tau = tt.shape_padright(tau) - factor = tt.exp(-cc * tau) - K = tt.sum(ac * factor * tt.cos(dc * tau), axis=-1) - K += tt.sum(bc * factor * tt.sin(dc * tau), axis=-1) + tau = pt.abs(tau) + tau = pt.shape_padright(tau) + factor = pt.exp(-cc * tau) + K = pt.sum(ac * factor * pt.cos(dc * tau), axis=-1) + K += pt.sum(bc * factor * pt.sin(dc * tau), axis=-1) return K def get_psd(self, omega): @@ -71,17 +71,17 @@ def get_psd(self, omega): def _get_psd_real(self, coefficients, omega): ar, cr = coefficients - omega = tt.shape_padright(omega) + omega = pt.shape_padright(omega) w2 = omega**2 - power = tt.sum(ar * cr / (cr**2 + w2), axis=-1) + power = pt.sum(ar * cr / (cr**2 + w2), axis=-1) return np.sqrt(2.0 / np.pi) * power def _get_psd_complex(self, coefficients, omega): ac, bc, cc, dc = coefficients - omega = tt.shape_padright(omega) + omega = pt.shape_padright(omega) w2 = omega**2 w02 = cc**2 + dc**2 - power = tt.sum( + power = pt.sum( ((ac * cc + bc * dc) * w02 + (ac * cc - bc * dc) * w2) / (w2 * w2 + 2.0 * (cc**2 - dc**2) * w2 + w02 * w02), axis=-1, @@ -90,55 +90,55 @@ def _get_psd_complex(self, coefficients, omega): def to_dense(self, x, diag): K = self.get_value(x[:, None] - x[None, :]) - K += tt.diag(diag) + K += pt.diag(diag) return K def get_celerite_matrices(self, x, diag, **kwargs): - x = tt.as_tensor_variable(x) - diag = tt.as_tensor_variable(diag) + x = pt.as_tensor_variable(x) + diag = pt.as_tensor_variable(diag) cr, ar, Ur, Vr = self._get_celerite_matrices_real( self.coefficients[:2], x, **kwargs ) cc, ac, Uc, Vc = self._get_celerite_matrices_complex( self.coefficients[2:], x, **kwargs ) - c = tt.concatenate((cr, cc)) + c = pt.concatenate((cr, cc)) a = diag + ar + ac - U = tt.concatenate((Ur, Uc), axis=1) - V = tt.concatenate((Vr, Vc), axis=1) + U = pt.concatenate((Ur, Uc), axis=1) + V = pt.concatenate((Vr, Vc), axis=1) return c, a, U, V def _get_celerite_matrices_real(self, coefficients, x, **kwargs): ar, cr = coefficients - z = tt.zeros_like(x) + z = pt.zeros_like(x) return ( cr, - tt.sum(ar), + pt.sum(ar), ar[None, :] + z[:, None], - tt.ones_like(ar)[None, :] + z[:, None], + pt.ones_like(ar)[None, :] + z[:, None], ) def _get_celerite_matrices_complex(self, coefficients, x, **kwargs): ac, bc, cc, dc = coefficients arg = dc[None, :] * x[:, None] - cos = tt.cos(arg) - sin = tt.sin(arg) - U = tt.concatenate( + cos = pt.cos(arg) + sin = pt.sin(arg) + U = pt.concatenate( ( ac[None, :] * cos + bc[None, :] * sin, ac[None, :] * sin - bc[None, :] * cos, ), axis=1, ) - V = tt.concatenate((cos, sin), axis=1) - c = tt.concatenate((cc, cc)) + V = pt.concatenate((cos, sin), axis=1) + c = pt.concatenate((cc, cc)) - return c, tt.sum(ac), U, V + return c, pt.sum(ac), U, V def dot(self, x, diag, y): - x = tt.as_tensor_variable(x) - y = tt.as_tensor_variable(y) + x = pt.as_tensor_variable(x) + y = pt.as_tensor_variable(y) is_vector = False if y.ndim == 1: @@ -169,24 +169,24 @@ def terms(self): return self._terms def get_value(self, tau): - tau = tt.as_tensor_variable(tau) + tau = pt.as_tensor_variable(tau) return sum(term.get_value(tau) for term in self._terms) def get_psd(self, omega): - omega = tt.as_tensor_variable(omega) + omega = pt.as_tensor_variable(omega) return sum(term.get_psd(omega) for term in self._terms) def get_celerite_matrices(self, x, diag, **kwargs): matrices = ( - term.get_celerite_matrices(x, tt.zeros_like(diag), **kwargs) + term.get_celerite_matrices(x, pt.zeros_like(diag), **kwargs) for term in self._terms ) c, a, U, V = zip(*matrices) return ( - tt.concatenate(c, axis=-1), + pt.concatenate(c, axis=-1), sum(a) + diag, - tt.concatenate(U, axis=-1), - tt.concatenate(V, axis=-1), + pt.concatenate(U, axis=-1), + pt.concatenate(V, axis=-1), ) @@ -199,7 +199,7 @@ def __init__(self, term1, term2, **kwargs): self.dtype = kwargs.get("dtype", "float64") def get_value(self, tau): - tau = tt.as_tensor_variable(tau) + tau = pt.as_tensor_variable(tau) return self.term1.get_value(tau) * self.term2.get_value(tau) def get_psd(self, omega): @@ -208,13 +208,13 @@ def get_psd(self, omega): ) def get_celerite_matrices(self, x, diag, **kwargs): - z = tt.zeros_like(diag) + z = pt.zeros_like(diag) c1, a1, U1, V1 = self.term1.get_celerite_matrices(x, z, **kwargs) c2, a2, U2, V2 = self.term2.get_celerite_matrices(x, z, **kwargs) - mg = tt.mgrid[: c1.shape[0], : c2.shape[0]] - i = tt.flatten(mg[0]) - j = tt.flatten(mg[1]) + mg = pt.mgrid[: c1.shape[0], : c2.shape[0]] + i = pt.flatten(mg[0]) + j = pt.flatten(mg[1]) c = c1[i] + c2[j] a = a1 * a2 + diag @@ -259,7 +259,7 @@ def __init__(self, term, delta, **kwargs): "coefficients" ) self.term = term - self.delta = tt.as_tensor_variable(delta).astype("float64") + self.delta = pt.as_tensor_variable(delta).astype("float64") super().__init__(**kwargs) def get_celerite_matrices(self, x, diag, **kwargs): @@ -268,7 +268,7 @@ def get_celerite_matrices(self, x, diag, **kwargs): # Real part cd = cr * dt - delta_diag = 2 * tt.sum(ar * (cd - tt.sinh(cd)) / cd**2) + delta_diag = 2 * pt.sum(ar * (cd - pt.sinh(cd)) / cd**2) # Complex part cd = c * dt @@ -279,12 +279,12 @@ def get_celerite_matrices(self, x, diag, **kwargs): C1 = a * (c2 - d2) + 2 * b * c * d C2 = b * (c2 - d2) - 2 * a * c * d norm = (dt * c2pd2) ** 2 - sinh = tt.sinh(cd) - cosh = tt.cosh(cd) - delta_diag += 2 * tt.sum( + sinh = pt.sinh(cd) + cosh = pt.cosh(cd) + delta_diag += 2 * pt.sum( ( - C2 * cosh * tt.sin(dd) - - C1 * sinh * tt.cos(dd) + C2 * cosh * pt.sin(dd) + - C1 * sinh * pt.cos(dd) + (a * c + b * d) * dt * c2pd2 ) / norm @@ -298,7 +298,7 @@ def get_coefficients(self): # Real componenets crd = cr * self.delta - coeffs = [2 * ar * (tt.cosh(crd) - 1) / crd**2, cr] + coeffs = [2 * ar * (pt.cosh(crd) - 1) / crd**2, cr] # Imaginary coefficients cd = c * self.delta @@ -306,8 +306,8 @@ def get_coefficients(self): c2 = c**2 d2 = d**2 factor = 2.0 / (self.delta * (c2 + d2)) ** 2 - cos_term = tt.cosh(cd) * tt.cos(dd) - 1 - sin_term = tt.sinh(cd) * tt.sin(dd) + cos_term = pt.cosh(cd) * pt.cos(dd) - 1 + sin_term = pt.sinh(cd) * pt.sin(dd) C1 = a * (c2 - d2) + 2 * b * c * d C2 = b * (c2 - d2) - 2 * a * c * d @@ -324,7 +324,7 @@ def get_coefficients(self): def get_psd(self, omega): psd0 = self.term.get_psd(omega) arg = 0.5 * self.delta * omega - sinc = tt.switch(tt.neq(arg, 0), tt.sin(arg) / arg, tt.ones_like(arg)) + sinc = pt.switch(pt.neq(arg, 0), pt.sin(arg) / arg, pt.ones_like(arg)) return psd0 * sinc**2 def get_value(self, tau0): @@ -332,7 +332,7 @@ def get_value(self, tau0): ar, cr, a, b, c, d = self.term.coefficients # Format the lags correctly - tau0 = tt.abs(tau0) + tau0 = pt.abs(tau0) tau = tau0[..., None] # Precompute some factors @@ -342,13 +342,13 @@ def get_value(self, tau0): # Real parts: # tau > Delta crd = cr * dt - cosh = tt.cosh(crd) + cosh = pt.cosh(crd) norm = 2 * ar / crd**2 - K_large = tt.sum(norm * (cosh - 1) * tt.exp(-cr * tau), axis=-1) + K_large = pt.sum(norm * (cosh - 1) * pt.exp(-cr * tau), axis=-1) # tau < Delta crdmt = cr * dmt - K_small = K_large + tt.sum(norm * (crdmt - tt.sinh(crdmt)), axis=-1) + K_small = K_large + pt.sum(norm * (crdmt - pt.sinh(crdmt)), axis=-1) # Complex part cd = c * dt @@ -359,46 +359,46 @@ def get_value(self, tau0): C1 = a * (c2 - d2) + 2 * b * c * d C2 = b * (c2 - d2) - 2 * a * c * d norm = 1.0 / (dt * c2pd2) ** 2 - k0 = tt.exp(-c * tau) - cdt = tt.cos(d * tau) - sdt = tt.sin(d * tau) + k0 = pt.exp(-c * tau) + cdt = pt.cos(d * tau) + sdt = pt.sin(d * tau) # For tau > Delta - cos_term = 2 * (tt.cosh(cd) * tt.cos(dd) - 1) - sin_term = 2 * (tt.sinh(cd) * tt.sin(dd)) + cos_term = 2 * (pt.cosh(cd) * pt.cos(dd) - 1) + sin_term = 2 * (pt.sinh(cd) * pt.sin(dd)) factor = k0 * norm - K_large += tt.sum( + K_large += pt.sum( (C1 * cos_term - C2 * sin_term) * factor * cdt, axis=-1 ) - K_large += tt.sum( + K_large += pt.sum( (C2 * cos_term + C1 * sin_term) * factor * sdt, axis=-1 ) # tau < Delta - edmt = tt.exp(-c * dmt) - edpt = tt.exp(-c * dpt) + edmt = pt.exp(-c * dmt) + edpt = pt.exp(-c * dpt) cos_term = ( - edmt * tt.cos(d * dmt) + edpt * tt.cos(d * dpt) - 2 * k0 * cdt + edmt * pt.cos(d * dmt) + edpt * pt.cos(d * dpt) - 2 * k0 * cdt ) sin_term = ( - edmt * tt.sin(d * dmt) + edpt * tt.sin(d * dpt) - 2 * k0 * sdt + edmt * pt.sin(d * dmt) + edpt * pt.sin(d * dpt) - 2 * k0 * sdt ) - K_small += tt.sum(2 * (a * c + b * d) * c2pd2 * dmt * norm, axis=-1) - K_small += tt.sum((C1 * cos_term + C2 * sin_term) * norm, axis=-1) + K_small += pt.sum(2 * (a * c + b * d) * c2pd2 * dmt * norm, axis=-1) + K_small += pt.sum((C1 * cos_term + C2 * sin_term) * norm, axis=-1) - return tt.switch(tt.le(tau0, dt), K_small, K_large) + return pt.switch(pt.le(tau0, dt), K_small, K_large) class RealTerm(Term): __doc__ = base_terms.RealTerm.__doc__ def __init__(self, *, a, c, **kwargs): - self.a = tt.as_tensor_variable(a).astype("float64") - self.c = tt.as_tensor_variable(c).astype("float64") + self.a = pt.as_tensor_variable(a).astype("float64") + self.c = pt.as_tensor_variable(c).astype("float64") super().__init__(**kwargs) def get_coefficients(self): - empty = tt.zeros(0, dtype=self.dtype) + empty = pt.zeros(0, dtype=self.dtype) if self.a.ndim == 0: return ( self.a[None], @@ -415,14 +415,14 @@ class ComplexTerm(Term): __doc__ = base_terms.ComplexTerm.__doc__ def __init__(self, *, a, b, c, d, **kwargs): - self.a = tt.as_tensor_variable(a).astype("float64") - self.b = tt.as_tensor_variable(b).astype("float64") - self.c = tt.as_tensor_variable(c).astype("float64") - self.d = tt.as_tensor_variable(d).astype("float64") + self.a = pt.as_tensor_variable(a).astype("float64") + self.b = pt.as_tensor_variable(b).astype("float64") + self.c = pt.as_tensor_variable(c).astype("float64") + self.d = pt.as_tensor_variable(d).astype("float64") super().__init__(**kwargs) def get_coefficients(self): - empty = tt.zeros(0, dtype=self.dtype) + empty = pt.zeros(0, dtype=self.dtype) if self.a.ndim == 0: return ( empty, @@ -440,14 +440,14 @@ class SHOTerm(Term): __parameter_spec__ = base_terms.SHOTerm.__parameter_spec__ @base_terms.handle_parameter_spec( - lambda x: tt.as_tensor_variable(x).astype("float64") + lambda x: pt.as_tensor_variable(x).astype("float64") ) def __init__(self, *, eps=1e-5, **kwargs): - self.eps = tt.as_tensor_variable(eps).astype("float64") + self.eps = pt.as_tensor_variable(eps).astype("float64") self.dtype = kwargs.get("dtype", "float64") self.overdamped = self.get_overdamped_coefficients() self.underdamped = self.get_underdamped_coefficients() - self.cond = tt.lt(self.Q, 0.5) + self.cond = pt.lt(self.Q, 0.5) def get_value(self, tau): return ifelse( @@ -464,8 +464,8 @@ def get_psd(self, omega): ) def get_celerite_matrices(self, x, diag, **kwargs): - x = tt.as_tensor_variable(x) - diag = tt.as_tensor_variable(diag) + x = pt.as_tensor_variable(x) + diag = pt.as_tensor_variable(diag) cr, ar, Ur, Vr = super()._get_celerite_matrices_real( self.overdamped, x, **kwargs ) @@ -475,38 +475,38 @@ def get_celerite_matrices(self, x, diag, **kwargs): ar = ar + diag ac = ac + diag - cr, cc = tt.broadcast_arrays(cr, cc) - ar, ac = tt.broadcast_arrays(ar, ac) - Ur, Uc = tt.broadcast_arrays(Ur, Uc) - Vr, Vc = tt.broadcast_arrays(Vr, Vc) + cr, cc = pt.broadcast_arrays(cr, cc) + ar, ac = pt.broadcast_arrays(ar, ac) + Ur, Uc = pt.broadcast_arrays(Ur, Uc) + Vr, Vc = pt.broadcast_arrays(Vr, Vc) return [ - tt.switch(self.cond, a, b) + pt.switch(self.cond, a, b) for a, b in zip((cr, ar, Ur, Vr), (cc, ac, Uc, Vc)) ] def get_overdamped_coefficients(self): Q = self.Q - f = tt.sqrt(tt.maximum(1.0 - 4.0 * Q**2, self.eps)) + f = pt.sqrt(pt.maximum(1.0 - 4.0 * Q**2, self.eps)) return ( 0.5 * self.S0 * self.w0 * Q - * tt.stack([1.0 + 1.0 / f, 1.0 - 1.0 / f]), - 0.5 * self.w0 / Q * tt.stack([1.0 - f, 1.0 + f]), + * pt.stack([1.0 + 1.0 / f, 1.0 - 1.0 / f]), + 0.5 * self.w0 / Q * pt.stack([1.0 - f, 1.0 + f]), ) def get_underdamped_coefficients(self): Q = self.Q - f = tt.sqrt(tt.maximum(4.0 * Q**2 - 1.0, self.eps)) + f = pt.sqrt(pt.maximum(4.0 * Q**2 - 1.0, self.eps)) a = self.S0 * self.w0 * Q c = 0.5 * self.w0 / Q return ( - tt.stack([a]), - tt.stack([a / f]), - tt.stack([c]), - tt.stack([c * f]), + pt.stack([a]), + pt.stack([a / f]), + pt.stack([c]), + pt.stack([c * f]), ) @@ -514,15 +514,15 @@ class Matern32Term(Term): __doc__ = base_terms.Matern32Term.__doc__ def __init__(self, *, sigma, rho, eps=0.01, **kwargs): - self.sigma = tt.as_tensor_variable(sigma).astype("float64") - self.rho = tt.as_tensor_variable(rho).astype("float64") - self.eps = tt.as_tensor_variable(eps).astype("float64") + self.sigma = pt.as_tensor_variable(sigma).astype("float64") + self.rho = pt.as_tensor_variable(rho).astype("float64") + self.eps = pt.as_tensor_variable(eps).astype("float64") super().__init__(**kwargs) def get_coefficients(self): w0 = np.sqrt(3.0) / self.rho S0 = self.sigma**2 / w0 - empty = tt.zeros(0, dtype=self.dtype) + empty = pt.zeros(0, dtype=self.dtype) return ( empty, empty, @@ -537,22 +537,22 @@ class RotationTerm(TermSum): __doc__ = base_terms.RotationTerm.__doc__ def __init__(self, *, sigma, period, Q0, dQ, f, **kwargs): - self.sigma = tt.as_tensor_variable(sigma).astype("float64") - self.period = tt.as_tensor_variable(period).astype("float64") - self.Q0 = tt.as_tensor_variable(Q0).astype("float64") - self.dQ = tt.as_tensor_variable(dQ).astype("float64") - self.f = tt.as_tensor_variable(f).astype("float64") + self.sigma = pt.as_tensor_variable(sigma).astype("float64") + self.period = pt.as_tensor_variable(period).astype("float64") + self.Q0 = pt.as_tensor_variable(Q0).astype("float64") + self.dQ = pt.as_tensor_variable(dQ).astype("float64") + self.f = pt.as_tensor_variable(f).astype("float64") self.amp = self.sigma**2 / (1 + self.f) # One term with a period of period Q1 = 0.5 + self.Q0 + self.dQ - w1 = 4 * np.pi * Q1 / (self.period * tt.sqrt(4 * Q1**2 - 1)) + w1 = 4 * np.pi * Q1 / (self.period * pt.sqrt(4 * Q1**2 - 1)) S1 = self.amp / (w1 * Q1) # Another term at half the period Q2 = 0.5 + self.Q0 - w2 = 8 * np.pi * Q2 / (self.period * tt.sqrt(4 * Q2**2 - 1)) + w2 = 8 * np.pi * Q2 / (self.period * pt.sqrt(4 * Q2**2 - 1)) S2 = self.f * self.amp / (w2 * Q2) super().__init__( diff --git a/python/test/pymc/conftest.py b/python/test/pymc/conftest.py new file mode 100644 index 0000000..94625c2 --- /dev/null +++ b/python/test/pymc/conftest.py @@ -0,0 +1,14 @@ +def pytest_configure(config): + try: + import pytensor + except ImportError: + return + + import platform + + pytensor.config.floatX = "float64" + # TODO: Uncomment when PyMC commit 714b4a0 makes it into a release (probably 5.9.2) + # pytensor.config.compute_test_value = "raise" + if platform.system() == "Darwin": + pytensor.config.gcc.cxxflags = "-Wno-c++11-narrowing" + config.addinivalue_line("filterwarnings", "ignore") diff --git a/python/test/pymc4/test_pymc4_celerite2.py b/python/test/pymc/test_pymc_celerite2.py similarity index 92% rename from python/test/pymc4/test_pymc4_celerite2.py rename to python/test/pymc/test_pymc_celerite2.py index 26e4100..d39a5d9 100644 --- a/python/test/pymc4/test_pymc4_celerite2.py +++ b/python/test/pymc/test_pymc_celerite2.py @@ -5,12 +5,12 @@ import celerite2 -pytest.importorskip("celerite2.pymc4") +pytest.importorskip("celerite2.pymc") try: from celerite2 import terms as pyterms - from celerite2.pymc4 import GaussianProcess, terms - from celerite2.pymc4.celerite2 import CITATIONS + from celerite2.pymc import GaussianProcess, terms + from celerite2.pymc.celerite2 import CITATIONS from celerite2.testing import check_gp_models except (ImportError, ModuleNotFoundError): pass @@ -115,8 +115,8 @@ def test_marginal(data): gp.marginal("obs", observed=y) np.testing.assert_allclose( - model.compile_logp()(model.test_point), - model.compile_fn(gp.log_likelihood(y))(model.test_point), + model.compile_logp()(model.initial_point()), + model.compile_fn(gp.log_likelihood(y))(model.initial_point()), ) diff --git a/python/test/pymc4/test_pymc4_ops.py b/python/test/pymc/test_pymc_ops.py similarity index 81% rename from python/test/pymc4/test_pymc4_ops.py rename to python/test/pymc/test_pymc_ops.py index dc1da8d..5f1b834 100644 --- a/python/test/pymc4/test_pymc4_ops.py +++ b/python/test/pymc/test_pymc_ops.py @@ -8,18 +8,18 @@ from celerite2 import backprop, driver from celerite2.testing import get_matrices -pytest.importorskip("celerite2.pymc4") +pytest.importorskip("celerite2.pymc") try: - import aesara - import aesara.tensor as tt - from aesara.compile.mode import Mode - from aesara.compile.sharedvalue import SharedVariable - from aesara.graph.fg import FunctionGraph - from aesara.graph.optdb import OptimizationQuery - from aesara.link.jax import JAXLinker - - from celerite2.pymc4 import ops + import pytensor + import pytensor.tensor as tt + from pytensor.compile.mode import Mode + from pytensor.compile.sharedvalue import SharedVariable + from pytensor.graph.fg import FunctionGraph + from pytensor.graph.rewriting.db import RewriteDatabaseQuery + from pytensor.link.jax import JAXLinker + + from celerite2.pymc import ops except (ImportError, ModuleNotFoundError): pass @@ -30,7 +30,9 @@ jax = None else: - opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) + opts = RewriteDatabaseQuery( + include=[None], exclude=["cxx_only", "BlasOpt"] + ) jax_mode = Mode(JAXLinker(), opts) py_mode = Mode("py", opts) @@ -54,20 +56,20 @@ def convert_values_to_types(values): def check_shape(op, inputs, outputs, values, result, multi): if multi: - shapes = aesara.function(inputs, [o.shape for o in outputs])(*values) + shapes = pytensor.function(inputs, [o.shape for o in outputs])(*values) assert all( np.all(v.shape == s) for v, s in zip(result, shapes) ), "Invalid shape inference" else: - shape = aesara.function(inputs, outputs.shape)(*values) + shape = pytensor.function(inputs, outputs.shape)(*values) assert result.shape == shape def check_basic(ref_func, op, values): inputs = convert_values_to_types(values) outputs = op(*inputs) - result = aesara.function(inputs, outputs)(*values) + result = pytensor.function(inputs, outputs)(*values) try: result.shape @@ -93,7 +95,7 @@ def check_basic(ref_func, op, values): def check_grad(op, values, num_out, eps=1.234e-8): inputs = convert_values_to_types(values) outputs = op(*inputs) - func = aesara.function(inputs, outputs) + func = pytensor.function(inputs, outputs) vals0 = func(*values) # Compute numerical grad @@ -113,8 +115,8 @@ def check_grad(op, values, num_out, eps=1.234e-8): # Compute the backprop for k in range(num_out): for i in range(vals0[k].size): - res = aesara.function( - inputs, aesara.grad(outputs[k].flatten()[i], inputs) + res = pytensor.function( + inputs, pytensor.grad(outputs[k].flatten()[i], inputs) )(*values) for n, b in enumerate(res): @@ -124,7 +126,7 @@ def check_grad(op, values, num_out, eps=1.234e-8): if jax is not None: for k in range(num_out): - out_grad = aesara.grad(tt.sum(outputs[k]), inputs) + out_grad = pytensor.grad(tt.sum(outputs[k]), inputs) fg = FunctionGraph(inputs, out_grad) compare_jax_and_py(fg, values) @@ -248,20 +250,19 @@ def compare_jax_and_py( assert_fn = partial(np.testing.assert_allclose, rtol=1e-4) fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] - aesara_jax_fn = aesara.function(fn_inputs, fgraph.outputs, mode=jax_mode) - jax_res = aesara_jax_fn(*test_inputs) + pytensor_jax_fn = pytensor.function( + fn_inputs, fgraph.outputs, mode=jax_mode + ) + jax_res = pytensor_jax_fn(*test_inputs) if must_be_device_array: if isinstance(jax_res, list): - assert all( - isinstance(res, jax.interpreters.xla.DeviceArray) - for res in jax_res - ) + assert all(isinstance(res, jax.Array) for res in jax_res) else: - assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) + assert isinstance(jax_res, jax.Array) - aesara_py_fn = aesara.function(fn_inputs, fgraph.outputs, mode=py_mode) - py_res = aesara_py_fn(*test_inputs) + pytensor_py_fn = pytensor.function(fn_inputs, fgraph.outputs, mode=py_mode) + py_res = pytensor_py_fn(*test_inputs) if len(fgraph.outputs) > 1: for j, p in zip(jax_res, py_res): diff --git a/python/test/pymc4/test_pymc4_terms.py b/python/test/pymc/test_pymc_terms.py similarity index 90% rename from python/test/pymc4/test_pymc4_terms.py rename to python/test/pymc/test_pymc_terms.py index 0858e58..ea186fc 100644 --- a/python/test/pymc4/test_pymc4_terms.py +++ b/python/test/pymc/test_pymc_terms.py @@ -5,11 +5,11 @@ import numpy as np import pytest -pytest.importorskip("celerite2.pymc4") +pytest.importorskip("celerite2.pymc") try: from celerite2 import terms as pyterms - from celerite2.pymc4 import terms + from celerite2.pymc import terms from celerite2.testing import check_tensor_term except (ImportError, ModuleNotFoundError): pass @@ -72,16 +72,16 @@ def test_base_terms(name, args): def test_opt_error(): - import aesara.tensor as at - from aesara import config, function, grad + import pytensor.tensor as pt + from pytensor import config, function, grad x = np.linspace(0, 5, 10) diag = np.full_like(x, 0.2) with config.change_flags(on_opt_error="raise"): - arg = at.scalar() + arg = pt.scalar() arg.tag.test_value = 0.5 matrices = terms.SHOTerm(S0=1.0, w0=0.5, Q=arg).get_celerite_matrices( x, diag ) - function([arg], grad(sum(at.sum(m) for m in matrices), [arg]))(0.5) + function([arg], grad(sum(pt.sum(m) for m in matrices), [arg]))(0.5) diff --git a/python/test/pymc4/conftest.py b/python/test/pymc4/conftest.py deleted file mode 100644 index 3866a42..0000000 --- a/python/test/pymc4/conftest.py +++ /dev/null @@ -1,13 +0,0 @@ -def pytest_configure(config): - try: - import aesara - except ImportError: - return - - import platform - - aesara.config.floatX = "float64" - aesara.config.compute_test_value = "raise" - if platform.system() == "Darwin": - aesara.config.gcc.cxxflags = "-Wno-c++11-narrowing" - config.addinivalue_line("filterwarnings", "ignore") diff --git a/setup.py b/setup.py index 45778f8..60ad2ac 100644 --- a/setup.py +++ b/setup.py @@ -39,8 +39,8 @@ "scipy", "celerite>=0.3.1", ], - "pymc3": ["pymc3>=3.9", "numpy<1.22"], - "pymc4": ["pymc>=4,<5"], + "pymc3": ["pymc3>=3.9", "numpy<1.22", "xarray<2023.10.0"], + "pymc": ["pymc>=5"], "jax": ["jax", "jaxlib"], "docs": [ "sphinx", @@ -53,7 +53,7 @@ "matplotlib", "scipy", "emcee", - "pymc>=4,<5", + "pymc>=5", "tqdm", "numpyro", ],