From da40ea5ed3e9275a43b819b7b5522ea02d56bd00 Mon Sep 17 00:00:00 2001 From: Thomas Vandal Date: Tue, 31 Oct 2023 23:52:30 -0400 Subject: [PATCH 01/18] import ShapedArray from core instead of abstract_arrays --- python/celerite2/jax/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/celerite2/jax/ops.py b/python/celerite2/jax/ops.py index 17e1666..274228f 100644 --- a/python/celerite2/jax/ops.py +++ b/python/celerite2/jax/ops.py @@ -25,7 +25,7 @@ import pkg_resources from jax import core, lax from jax import numpy as jnp -from jax.abstract_arrays import ShapedArray +from jax.core import ShapedArray from jax.interpreters import ad, xla from jax.lib import xla_client From b3f871b99482f0617b263821d4eee1839d57278f Mon Sep 17 00:00:00 2001 From: Thomas Vandal Date: Tue, 31 Oct 2023 23:54:18 -0400 Subject: [PATCH 02/18] Use `xla_client.register_custom_call_target()` instead of deprecated `register_cpu_custom_call_target()` Ref: https://github.com/google/jax/issues/17662 --- python/celerite2/jax/ops.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/celerite2/jax/ops.py b/python/celerite2/jax/ops.py index 274228f..bc6e983 100644 --- a/python/celerite2/jax/ops.py +++ b/python/celerite2/jax/ops.py @@ -201,8 +201,8 @@ def _rev_translation_rule(name, spec, c, *args): def _build_op(name, spec): - xla_client.register_cpu_custom_call_target( - name, getattr(xla_ops, spec["name"])() + xla_client.register_custom_call_target( + name, getattr(xla_ops, spec["name"])(), platform="cpu" ) prim = core.Primitive(f"celerite2_{spec['name']}") @@ -216,8 +216,10 @@ def _build_op(name, spec): if not spec["has_rev"]: return prim, None - xla_client.register_cpu_custom_call_target( - name + b"_rev", getattr(xla_ops, f"{spec['name']}_rev")() + xla_client.register_custom_call_target( + name + b"_rev", + getattr(xla_ops, f"{spec['name']}_rev")(), + platform="cpu", ) jvp_prim = core.Primitive(f"celerite2_{spec['name']}_jvp") From 
54cf43a53a83285b651896bdc53f40cdf7c991b5 Mon Sep 17 00:00:00 2001 From: Thomas Vandal Date: Wed, 1 Nov 2023 00:10:46 -0400 Subject: [PATCH 03/18] pymc4 -> pymc and aesara -> pytensor --- python/celerite2/{pymc4 => pymc}/__init__.py | 12 +- python/celerite2/{pymc4 => pymc}/celerite2.py | 28 +-- .../celerite2/{pymc4 => pymc}/distribution.py | 32 +-- .../celerite2/{pymc4 => pymc}/jax_support.py | 6 +- python/celerite2/{pymc4 => pymc}/ops.py | 12 +- python/celerite2/{pymc4 => pymc}/terms.py | 224 +++++++++--------- python/test/{pymc4 => pymc}/conftest.py | 8 +- .../test_pymc_celerite2.py} | 6 +- .../test_pymc_ops.py} | 41 ++-- .../test_pymc_terms.py} | 12 +- setup.py | 4 +- 11 files changed, 193 insertions(+), 192 deletions(-) rename python/celerite2/{pymc4 => pymc}/__init__.py (65%) rename python/celerite2/{pymc4 => pymc}/celerite2.py (87%) rename python/celerite2/{pymc4 => pymc}/distribution.py (82%) rename python/celerite2/{pymc4 => pymc}/jax_support.py (71%) rename python/celerite2/{pymc4 => pymc}/ops.py (95%) rename python/celerite2/{pymc4 => pymc}/terms.py (69%) rename python/test/{pymc4 => pymc}/conftest.py (52%) rename python/test/{pymc4/test_pymc4_celerite2.py => pymc/test_pymc_celerite2.py} (95%) rename python/test/{pymc4/test_pymc4_ops.py => pymc/test_pymc_ops.py} (84%) rename python/test/{pymc4/test_pymc4_terms.py => pymc/test_pymc_terms.py} (90%) diff --git a/python/celerite2/pymc4/__init__.py b/python/celerite2/pymc/__init__.py similarity index 65% rename from python/celerite2/pymc4/__init__.py rename to python/celerite2/pymc/__init__.py index 76c8701..ee77b01 100644 --- a/python/celerite2/pymc4/__init__.py +++ b/python/celerite2/pymc/__init__.py @@ -4,27 +4,27 @@ def __set_compiler_flags(): - import aesara + import pytensor def add_flag(current, new): if new in current: return current return f"{current} {new}" - current = aesara.config.gcc__cxxflags + current = pytensor.config.gcc__cxxflags current = add_flag(current, "-Wno-c++11-narrowing") current = 
add_flag(current, "-fno-exceptions") current = add_flag(current, "-fno-unwind-tables") current = add_flag(current, "-fno-asynchronous-unwind-tables") - aesara.config.gcc__cxxflags = current + pytensor.config.gcc__cxxflags = current __set_compiler_flags() -from celerite2.pymc4 import terms -from celerite2.pymc4.celerite2 import GaussianProcess +from celerite2.pymc import terms # noqa +from celerite2.pymc.celerite2 import GaussianProcess # noqa try: - from celerite2.pymc4 import jax_support # noqa + from celerite2.pymc import jax_support # noqa except ImportError: pass diff --git a/python/celerite2/pymc4/celerite2.py b/python/celerite2/pymc/celerite2.py similarity index 87% rename from python/celerite2/pymc4/celerite2.py rename to python/celerite2/pymc/celerite2.py index 735007d..9083057 100644 --- a/python/celerite2/pymc4/celerite2.py +++ b/python/celerite2/pymc/celerite2.py @@ -2,13 +2,13 @@ __all__ = ["GaussianProcess", "ConditionalDistribution"] -import aesara.tensor as tt +import pytensor.tensor as pt import numpy as np -from aesara.raise_op import Assert +from pytensor.raise_op import Assert from celerite2.citation import CITATIONS from celerite2.core import BaseConditionalDistribution, BaseGaussianProcess -from celerite2.pymc4 import ops +from celerite2.pymc import ops class ConditionalDistribution(BaseConditionalDistribution): @@ -22,37 +22,37 @@ def _do_general_matmul(self, c, U1, V1, U2, V2, inp, target): return target def _diagdot(self, a, b): - return tt.batched_dot(a.T, b.T) + return pt.batched_dot(a.T, b.T) class GaussianProcess(BaseGaussianProcess): conditional_distribution = ConditionalDistribution def _as_tensor(self, tensor): - return tt.as_tensor_variable(tensor).astype("float64") + return pt.as_tensor_variable(tensor).astype("float64") def _zeros_like(self, tensor): - return tt.zeros_like(tensor) + return pt.zeros_like(tensor) def _do_compute(self, quiet): if quiet: self._d, self._W, _ = ops.factor_quiet( self._t, self._c, self._a, self._U, 
self._V ) - self._log_det = tt.switch( - tt.any(self._d <= 0.0), -np.inf, tt.sum(tt.log(self._d)) + self._log_det = pt.switch( + pt.any(self._d <= 0.0), -np.inf, pt.sum(pt.log(self._d)) ) else: self._d, self._W, _ = ops.factor( self._t, self._c, self._a, self._U, self._V ) - self._log_det = tt.sum(tt.log(self._d)) + self._log_det = pt.sum(pt.log(self._d)) self._norm = -0.5 * (self._log_det + self._size * np.log(2 * np.pi)) def _check_sorted(self, t): - return Assert()(t, tt.all(t[1:] - t[:-1] >= 0)) + return Assert()(t, pt.all(t[1:] - t[:-1] >= 0)) def _do_solve(self, y): z = ops.solve_lower(self._t, self._c, self._U, self._W, y)[0] @@ -69,7 +69,7 @@ def _do_norm(self, y): alpha = ops.solve_lower( self._t, self._c, self._U, self._W, y[:, None] )[0][:, 0] - return tt.sum(alpha**2 / self._d) + return pt.sum(alpha**2 / self._d) def _add_citations_to_pymc_model(self, **kwargs): import pymc as pm @@ -87,10 +87,10 @@ def marginal(self, name, **kwargs): observed (optional): The observed data Returns: - A :class:`celerite2.pymc3.CeleriteNormal` distribution + A :class:`celerite2.pymc.CeleriteNormal` distribution representing the marginal likelihood. """ - from celerite2.pymc4.distribution import CeleriteNormal + from celerite2.pymc.distribution import CeleriteNormal self._add_citations_to_pymc_model(**kwargs) return CeleriteNormal( @@ -108,7 +108,7 @@ def marginal(self, name, **kwargs): def conditional( self, name, y, t=None, include_mean=True, kernel=None, **kwargs ): - """Add a variable representing the conditional density to a PyMC3 model + """Add a variable representing the conditional density to a PyMC model .. note:: The performance of this method will generally be poor since the sampler will numerically sample this parameter. 
Depending on diff --git a/python/celerite2/pymc4/distribution.py b/python/celerite2/pymc/distribution.py similarity index 82% rename from python/celerite2/pymc4/distribution.py rename to python/celerite2/pymc/distribution.py index 3754339..a271cfc 100644 --- a/python/celerite2/pymc4/distribution.py +++ b/python/celerite2/pymc/distribution.py @@ -2,16 +2,16 @@ __all__ = ["CeleriteNormal"] -import aesara.tensor as tt +import pytensor.tensor as pt import numpy as np -from aesara.tensor.random.op import RandomVariable -from aesara.tensor.random.utils import broadcast_params +from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.random.utils import broadcast_params from pymc.distributions.dist_math import check_parameters from pymc.distributions.distribution import Continuous from pymc.distributions.shape_utils import rv_size_is_none import celerite2.driver as driver -from celerite2.pymc4 import ops +from celerite2.pymc import ops def safe_celerite_normal(rng, mean, norm, t, c, U, W, d, size=None): @@ -90,26 +90,26 @@ class CeleriteNormal(Continuous): @classmethod def dist(cls, mean, norm, t, c, U, W, d, **kwargs): - mean = tt.as_tensor_variable(mean) - norm = tt.as_tensor_variable(norm) - t = tt.as_tensor_variable(t) - c = tt.as_tensor_variable(c) - U = tt.as_tensor_variable(U) - W = tt.as_tensor_variable(W) - d = tt.as_tensor_variable(d) - mean = tt.broadcast_arrays(mean, t)[0] + mean = pt.as_tensor_variable(mean) + norm = pt.as_tensor_variable(norm) + t = pt.as_tensor_variable(t) + c = pt.as_tensor_variable(c) + U = pt.as_tensor_variable(U) + W = pt.as_tensor_variable(W) + d = pt.as_tensor_variable(d) + mean = pt.broadcast_arrays(mean, t)[0] return super().dist([mean, norm, t, c, U, W, d], **kwargs) def moment(rv, size, mean, *args): moment = mean if not rv_size_is_none(size): - moment_size = tt.concatenate([size, [mean.shape[-1]]]) - moment = tt.full(moment_size, mean) + moment_size = pt.concatenate([size, [mean.shape[-1]]]) + moment = 
pt.full(moment_size, mean) return moment def logp(value, mean, norm, t, c, U, W, d): - ok = tt.all(tt.gt(d, 0.0)) + ok = pt.all(pt.gt(d, 0.0)) alpha = value - mean alpha = ops.solve_lower(t, c, U, W, alpha[:, None])[0][:, 0] - logp = norm - 0.5 * tt.sum(alpha**2 / d) + logp = norm - 0.5 * pt.sum(alpha**2 / d) return check_parameters(logp, ok) diff --git a/python/celerite2/pymc4/jax_support.py b/python/celerite2/pymc/jax_support.py similarity index 71% rename from python/celerite2/pymc4/jax_support.py rename to python/celerite2/pymc/jax_support.py index 2d268eb..b41af72 100644 --- a/python/celerite2/pymc4/jax_support.py +++ b/python/celerite2/pymc/jax_support.py @@ -1,10 +1,10 @@ -from aesara.link.jax.dispatch import jax_funcify +from pytensor.link.jax.dispatch import jax_funcify from celerite2.jax import ops as jax_ops -from celerite2.pymc4 import ops as pymc4_ops +from celerite2.pymc import ops as pymc_ops -@jax_funcify.register(pymc4_ops._CeleriteOp) +@jax_funcify.register(pymc_ops._CeleriteOp) def jax_funcify_Celerite(op, **kwargs): name = op.name if name.endswith("_fwd"): diff --git a/python/celerite2/pymc4/ops.py b/python/celerite2/pymc/ops.py similarity index 95% rename from python/celerite2/pymc4/ops.py rename to python/celerite2/pymc/ops.py index ff0cc4f..d73ae3e 100644 --- a/python/celerite2/pymc4/ops.py +++ b/python/celerite2/pymc/ops.py @@ -14,11 +14,11 @@ import json from itertools import chain -import aesara -import aesara.tensor as tt +import pytensor +import pytensor.tensor as pt import numpy as np import pkg_resources -from aesara.graph import basic, op +from pytensor.graph import basic, op import celerite2.backprop as backprop import celerite2.driver as driver @@ -82,7 +82,7 @@ def make_node(self, *inputs): for spec in self.spec["dimensions"] } otypes = [ - tt.TensorType( + pt.TensorType( "float64", [broadcastable[k] for k in spec["shape"]] )() for spec in self.spec["outputs"] + self.spec["extra_outputs"] @@ -130,8 +130,8 @@ def perform(self, node, 
inputs, outputs): def grad(self, inputs, gradients): outputs = self(*inputs) grads = ( - tt.zeros_like(outputs[n]) - if isinstance(b.type, aesara.gradient.DisconnectedType) + pt.zeros_like(outputs[n]) + if isinstance(b.type, pytensor.gradient.DisconnectedType) else b for n, b in enumerate(gradients[: len(self.spec["outputs"])]) ) diff --git a/python/celerite2/pymc4/terms.py b/python/celerite2/pymc/terms.py similarity index 69% rename from python/celerite2/pymc4/terms.py rename to python/celerite2/pymc/terms.py index 4472d5e..91721af 100644 --- a/python/celerite2/pymc4/terms.py +++ b/python/celerite2/pymc/terms.py @@ -13,13 +13,13 @@ "RotationTerm", ] -import aesara -import aesara.tensor as tt +import pytensor +import pytensor.tensor as pt import numpy as np -from aesara.ifelse import ifelse +from pytensor.ifelse import ifelse import celerite2.terms as base_terms -from celerite2.pymc4 import ops +from celerite2.pymc import ops class Term(base_terms.Term): @@ -30,11 +30,11 @@ def __init__(self, *, dtype="float64"): self.coefficients = self.get_coefficients() def __add__(self, b): - dtype = aesara.scalar.upcast(self.dtype, b.dtype) + dtype = pytensor.scalar.upcast(self.dtype, b.dtype) return TermSum(self, b, dtype=dtype) def __mul__(self, b): - dtype = aesara.scalar.upcast(self.dtype, b.dtype) + dtype = pytensor.scalar.upcast(self.dtype, b.dtype) return TermProduct(self, b, dtype=dtype) @property @@ -51,17 +51,17 @@ def get_value(self, tau): def _get_value_real(self, coefficients, tau): ar, cr = coefficients - tau = tt.abs(tau) - tau = tt.shape_padright(tau) - return tt.sum(ar * tt.exp(-cr * tau), axis=-1) + tau = pt.abs(tau) + tau = pt.shape_padright(tau) + return pt.sum(ar * pt.exp(-cr * tau), axis=-1) def _get_value_complex(self, coefficients, tau): ac, bc, cc, dc = coefficients - tau = tt.abs(tau) - tau = tt.shape_padright(tau) - factor = tt.exp(-cc * tau) - K = tt.sum(ac * factor * tt.cos(dc * tau), axis=-1) - K += tt.sum(bc * factor * tt.sin(dc * tau), axis=-1) 
+ tau = pt.abs(tau) + tau = pt.shape_padright(tau) + factor = pt.exp(-cc * tau) + K = pt.sum(ac * factor * pt.cos(dc * tau), axis=-1) + K += pt.sum(bc * factor * pt.sin(dc * tau), axis=-1) return K def get_psd(self, omega): @@ -71,17 +71,17 @@ def get_psd(self, omega): def _get_psd_real(self, coefficients, omega): ar, cr = coefficients - omega = tt.shape_padright(omega) + omega = pt.shape_padright(omega) w2 = omega**2 - power = tt.sum(ar * cr / (cr**2 + w2), axis=-1) + power = pt.sum(ar * cr / (cr**2 + w2), axis=-1) return np.sqrt(2.0 / np.pi) * power def _get_psd_complex(self, coefficients, omega): ac, bc, cc, dc = coefficients - omega = tt.shape_padright(omega) + omega = pt.shape_padright(omega) w2 = omega**2 w02 = cc**2 + dc**2 - power = tt.sum( + power = pt.sum( ((ac * cc + bc * dc) * w02 + (ac * cc - bc * dc) * w2) / (w2 * w2 + 2.0 * (cc**2 - dc**2) * w2 + w02 * w02), axis=-1, @@ -90,55 +90,55 @@ def _get_psd_complex(self, coefficients, omega): def to_dense(self, x, diag): K = self.get_value(x[:, None] - x[None, :]) - K += tt.diag(diag) + K += pt.diag(diag) return K def get_celerite_matrices(self, x, diag, **kwargs): - x = tt.as_tensor_variable(x) - diag = tt.as_tensor_variable(diag) + x = pt.as_tensor_variable(x) + diag = pt.as_tensor_variable(diag) cr, ar, Ur, Vr = self._get_celerite_matrices_real( self.coefficients[:2], x, **kwargs ) cc, ac, Uc, Vc = self._get_celerite_matrices_complex( self.coefficients[2:], x, **kwargs ) - c = tt.concatenate((cr, cc)) + c = pt.concatenate((cr, cc)) a = diag + ar + ac - U = tt.concatenate((Ur, Uc), axis=1) - V = tt.concatenate((Vr, Vc), axis=1) + U = pt.concatenate((Ur, Uc), axis=1) + V = pt.concatenate((Vr, Vc), axis=1) return c, a, U, V def _get_celerite_matrices_real(self, coefficients, x, **kwargs): ar, cr = coefficients - z = tt.zeros_like(x) + z = pt.zeros_like(x) return ( cr, - tt.sum(ar), + pt.sum(ar), ar[None, :] + z[:, None], - tt.ones_like(ar)[None, :] + z[:, None], + pt.ones_like(ar)[None, :] + z[:, None], ) 
def _get_celerite_matrices_complex(self, coefficients, x, **kwargs): ac, bc, cc, dc = coefficients arg = dc[None, :] * x[:, None] - cos = tt.cos(arg) - sin = tt.sin(arg) - U = tt.concatenate( + cos = pt.cos(arg) + sin = pt.sin(arg) + U = pt.concatenate( ( ac[None, :] * cos + bc[None, :] * sin, ac[None, :] * sin - bc[None, :] * cos, ), axis=1, ) - V = tt.concatenate((cos, sin), axis=1) - c = tt.concatenate((cc, cc)) + V = pt.concatenate((cos, sin), axis=1) + c = pt.concatenate((cc, cc)) - return c, tt.sum(ac), U, V + return c, pt.sum(ac), U, V def dot(self, x, diag, y): - x = tt.as_tensor_variable(x) - y = tt.as_tensor_variable(y) + x = pt.as_tensor_variable(x) + y = pt.as_tensor_variable(y) is_vector = False if y.ndim == 1: @@ -169,24 +169,24 @@ def terms(self): return self._terms def get_value(self, tau): - tau = tt.as_tensor_variable(tau) + tau = pt.as_tensor_variable(tau) return sum(term.get_value(tau) for term in self._terms) def get_psd(self, omega): - omega = tt.as_tensor_variable(omega) + omega = pt.as_tensor_variable(omega) return sum(term.get_psd(omega) for term in self._terms) def get_celerite_matrices(self, x, diag, **kwargs): matrices = ( - term.get_celerite_matrices(x, tt.zeros_like(diag), **kwargs) + term.get_celerite_matrices(x, pt.zeros_like(diag), **kwargs) for term in self._terms ) c, a, U, V = zip(*matrices) return ( - tt.concatenate(c, axis=-1), + pt.concatenate(c, axis=-1), sum(a) + diag, - tt.concatenate(U, axis=-1), - tt.concatenate(V, axis=-1), + pt.concatenate(U, axis=-1), + pt.concatenate(V, axis=-1), ) @@ -199,7 +199,7 @@ def __init__(self, term1, term2, **kwargs): self.dtype = kwargs.get("dtype", "float64") def get_value(self, tau): - tau = tt.as_tensor_variable(tau) + tau = pt.as_tensor_variable(tau) return self.term1.get_value(tau) * self.term2.get_value(tau) def get_psd(self, omega): @@ -208,13 +208,13 @@ def get_psd(self, omega): ) def get_celerite_matrices(self, x, diag, **kwargs): - z = tt.zeros_like(diag) + z = pt.zeros_like(diag) 
c1, a1, U1, V1 = self.term1.get_celerite_matrices(x, z, **kwargs) c2, a2, U2, V2 = self.term2.get_celerite_matrices(x, z, **kwargs) - mg = tt.mgrid[: c1.shape[0], : c2.shape[0]] - i = tt.flatten(mg[0]) - j = tt.flatten(mg[1]) + mg = pt.mgrid[: c1.shape[0], : c2.shape[0]] + i = pt.flatten(mg[0]) + j = pt.flatten(mg[1]) c = c1[i] + c2[j] a = a1 * a2 + diag @@ -259,7 +259,7 @@ def __init__(self, term, delta, **kwargs): "coefficients" ) self.term = term - self.delta = tt.as_tensor_variable(delta).astype("float64") + self.delta = pt.as_tensor_variable(delta).astype("float64") super().__init__(**kwargs) def get_celerite_matrices(self, x, diag, **kwargs): @@ -268,7 +268,7 @@ def get_celerite_matrices(self, x, diag, **kwargs): # Real part cd = cr * dt - delta_diag = 2 * tt.sum(ar * (cd - tt.sinh(cd)) / cd**2) + delta_diag = 2 * pt.sum(ar * (cd - pt.sinh(cd)) / cd**2) # Complex part cd = c * dt @@ -279,12 +279,12 @@ def get_celerite_matrices(self, x, diag, **kwargs): C1 = a * (c2 - d2) + 2 * b * c * d C2 = b * (c2 - d2) - 2 * a * c * d norm = (dt * c2pd2) ** 2 - sinh = tt.sinh(cd) - cosh = tt.cosh(cd) - delta_diag += 2 * tt.sum( + sinh = pt.sinh(cd) + cosh = pt.cosh(cd) + delta_diag += 2 * pt.sum( ( - C2 * cosh * tt.sin(dd) - - C1 * sinh * tt.cos(dd) + C2 * cosh * pt.sin(dd) + - C1 * sinh * pt.cos(dd) + (a * c + b * d) * dt * c2pd2 ) / norm @@ -298,7 +298,7 @@ def get_coefficients(self): # Real componenets crd = cr * self.delta - coeffs = [2 * ar * (tt.cosh(crd) - 1) / crd**2, cr] + coeffs = [2 * ar * (pt.cosh(crd) - 1) / crd**2, cr] # Imaginary coefficients cd = c * self.delta @@ -306,8 +306,8 @@ def get_coefficients(self): c2 = c**2 d2 = d**2 factor = 2.0 / (self.delta * (c2 + d2)) ** 2 - cos_term = tt.cosh(cd) * tt.cos(dd) - 1 - sin_term = tt.sinh(cd) * tt.sin(dd) + cos_term = pt.cosh(cd) * pt.cos(dd) - 1 + sin_term = pt.sinh(cd) * pt.sin(dd) C1 = a * (c2 - d2) + 2 * b * c * d C2 = b * (c2 - d2) - 2 * a * c * d @@ -324,7 +324,7 @@ def get_coefficients(self): def 
get_psd(self, omega): psd0 = self.term.get_psd(omega) arg = 0.5 * self.delta * omega - sinc = tt.switch(tt.neq(arg, 0), tt.sin(arg) / arg, tt.ones_like(arg)) + sinc = pt.switch(pt.neq(arg, 0), pt.sin(arg) / arg, pt.ones_like(arg)) return psd0 * sinc**2 def get_value(self, tau0): @@ -332,7 +332,7 @@ def get_value(self, tau0): ar, cr, a, b, c, d = self.term.coefficients # Format the lags correctly - tau0 = tt.abs(tau0) + tau0 = pt.abs(tau0) tau = tau0[..., None] # Precompute some factors @@ -342,13 +342,13 @@ def get_value(self, tau0): # Real parts: # tau > Delta crd = cr * dt - cosh = tt.cosh(crd) + cosh = pt.cosh(crd) norm = 2 * ar / crd**2 - K_large = tt.sum(norm * (cosh - 1) * tt.exp(-cr * tau), axis=-1) + K_large = pt.sum(norm * (cosh - 1) * pt.exp(-cr * tau), axis=-1) # tau < Delta crdmt = cr * dmt - K_small = K_large + tt.sum(norm * (crdmt - tt.sinh(crdmt)), axis=-1) + K_small = K_large + pt.sum(norm * (crdmt - pt.sinh(crdmt)), axis=-1) # Complex part cd = c * dt @@ -359,46 +359,46 @@ def get_value(self, tau0): C1 = a * (c2 - d2) + 2 * b * c * d C2 = b * (c2 - d2) - 2 * a * c * d norm = 1.0 / (dt * c2pd2) ** 2 - k0 = tt.exp(-c * tau) - cdt = tt.cos(d * tau) - sdt = tt.sin(d * tau) + k0 = pt.exp(-c * tau) + cdt = pt.cos(d * tau) + sdt = pt.sin(d * tau) # For tau > Delta - cos_term = 2 * (tt.cosh(cd) * tt.cos(dd) - 1) - sin_term = 2 * (tt.sinh(cd) * tt.sin(dd)) + cos_term = 2 * (pt.cosh(cd) * pt.cos(dd) - 1) + sin_term = 2 * (pt.sinh(cd) * pt.sin(dd)) factor = k0 * norm - K_large += tt.sum( + K_large += pt.sum( (C1 * cos_term - C2 * sin_term) * factor * cdt, axis=-1 ) - K_large += tt.sum( + K_large += pt.sum( (C2 * cos_term + C1 * sin_term) * factor * sdt, axis=-1 ) # tau < Delta - edmt = tt.exp(-c * dmt) - edpt = tt.exp(-c * dpt) + edmt = pt.exp(-c * dmt) + edpt = pt.exp(-c * dpt) cos_term = ( - edmt * tt.cos(d * dmt) + edpt * tt.cos(d * dpt) - 2 * k0 * cdt + edmt * pt.cos(d * dmt) + edpt * pt.cos(d * dpt) - 2 * k0 * cdt ) sin_term = ( - edmt * tt.sin(d * dmt) 
+ edpt * tt.sin(d * dpt) - 2 * k0 * sdt + edmt * pt.sin(d * dmt) + edpt * pt.sin(d * dpt) - 2 * k0 * sdt ) - K_small += tt.sum(2 * (a * c + b * d) * c2pd2 * dmt * norm, axis=-1) - K_small += tt.sum((C1 * cos_term + C2 * sin_term) * norm, axis=-1) + K_small += pt.sum(2 * (a * c + b * d) * c2pd2 * dmt * norm, axis=-1) + K_small += pt.sum((C1 * cos_term + C2 * sin_term) * norm, axis=-1) - return tt.switch(tt.le(tau0, dt), K_small, K_large) + return pt.switch(pt.le(tau0, dt), K_small, K_large) class RealTerm(Term): __doc__ = base_terms.RealTerm.__doc__ def __init__(self, *, a, c, **kwargs): - self.a = tt.as_tensor_variable(a).astype("float64") - self.c = tt.as_tensor_variable(c).astype("float64") + self.a = pt.as_tensor_variable(a).astype("float64") + self.c = pt.as_tensor_variable(c).astype("float64") super().__init__(**kwargs) def get_coefficients(self): - empty = tt.zeros(0, dtype=self.dtype) + empty = pt.zeros(0, dtype=self.dtype) if self.a.ndim == 0: return ( self.a[None], @@ -415,14 +415,14 @@ class ComplexTerm(Term): __doc__ = base_terms.ComplexTerm.__doc__ def __init__(self, *, a, b, c, d, **kwargs): - self.a = tt.as_tensor_variable(a).astype("float64") - self.b = tt.as_tensor_variable(b).astype("float64") - self.c = tt.as_tensor_variable(c).astype("float64") - self.d = tt.as_tensor_variable(d).astype("float64") + self.a = pt.as_tensor_variable(a).astype("float64") + self.b = pt.as_tensor_variable(b).astype("float64") + self.c = pt.as_tensor_variable(c).astype("float64") + self.d = pt.as_tensor_variable(d).astype("float64") super().__init__(**kwargs) def get_coefficients(self): - empty = tt.zeros(0, dtype=self.dtype) + empty = pt.zeros(0, dtype=self.dtype) if self.a.ndim == 0: return ( empty, @@ -440,14 +440,14 @@ class SHOTerm(Term): __parameter_spec__ = base_terms.SHOTerm.__parameter_spec__ @base_terms.handle_parameter_spec( - lambda x: tt.as_tensor_variable(x).astype("float64") + lambda x: pt.as_tensor_variable(x).astype("float64") ) def __init__(self, *, 
eps=1e-5, **kwargs): - self.eps = tt.as_tensor_variable(eps).astype("float64") + self.eps = pt.as_tensor_variable(eps).astype("float64") self.dtype = kwargs.get("dtype", "float64") self.overdamped = self.get_overdamped_coefficients() self.underdamped = self.get_underdamped_coefficients() - self.cond = tt.lt(self.Q, 0.5) + self.cond = pt.lt(self.Q, 0.5) def get_value(self, tau): return ifelse( @@ -464,8 +464,8 @@ def get_psd(self, omega): ) def get_celerite_matrices(self, x, diag, **kwargs): - x = tt.as_tensor_variable(x) - diag = tt.as_tensor_variable(diag) + x = pt.as_tensor_variable(x) + diag = pt.as_tensor_variable(diag) cr, ar, Ur, Vr = super()._get_celerite_matrices_real( self.overdamped, x, **kwargs ) @@ -475,38 +475,38 @@ def get_celerite_matrices(self, x, diag, **kwargs): ar = ar + diag ac = ac + diag - cr, cc = tt.broadcast_arrays(cr, cc) - ar, ac = tt.broadcast_arrays(ar, ac) - Ur, Uc = tt.broadcast_arrays(Ur, Uc) - Vr, Vc = tt.broadcast_arrays(Vr, Vc) + cr, cc = pt.broadcast_arrays(cr, cc) + ar, ac = pt.broadcast_arrays(ar, ac) + Ur, Uc = pt.broadcast_arrays(Ur, Uc) + Vr, Vc = pt.broadcast_arrays(Vr, Vc) return [ - tt.switch(self.cond, a, b) + pt.switch(self.cond, a, b) for a, b in zip((cr, ar, Ur, Vr), (cc, ac, Uc, Vc)) ] def get_overdamped_coefficients(self): Q = self.Q - f = tt.sqrt(tt.maximum(1.0 - 4.0 * Q**2, self.eps)) + f = pt.sqrt(pt.maximum(1.0 - 4.0 * Q**2, self.eps)) return ( 0.5 * self.S0 * self.w0 * Q - * tt.stack([1.0 + 1.0 / f, 1.0 - 1.0 / f]), - 0.5 * self.w0 / Q * tt.stack([1.0 - f, 1.0 + f]), + * pt.stack([1.0 + 1.0 / f, 1.0 - 1.0 / f]), + 0.5 * self.w0 / Q * pt.stack([1.0 - f, 1.0 + f]), ) def get_underdamped_coefficients(self): Q = self.Q - f = tt.sqrt(tt.maximum(4.0 * Q**2 - 1.0, self.eps)) + f = pt.sqrt(pt.maximum(4.0 * Q**2 - 1.0, self.eps)) a = self.S0 * self.w0 * Q c = 0.5 * self.w0 / Q return ( - tt.stack([a]), - tt.stack([a / f]), - tt.stack([c]), - tt.stack([c * f]), + pt.stack([a]), + pt.stack([a / f]), + pt.stack([c]), + 
pt.stack([c * f]), ) @@ -514,15 +514,15 @@ class Matern32Term(Term): __doc__ = base_terms.Matern32Term.__doc__ def __init__(self, *, sigma, rho, eps=0.01, **kwargs): - self.sigma = tt.as_tensor_variable(sigma).astype("float64") - self.rho = tt.as_tensor_variable(rho).astype("float64") - self.eps = tt.as_tensor_variable(eps).astype("float64") + self.sigma = pt.as_tensor_variable(sigma).astype("float64") + self.rho = pt.as_tensor_variable(rho).astype("float64") + self.eps = pt.as_tensor_variable(eps).astype("float64") super().__init__(**kwargs) def get_coefficients(self): w0 = np.sqrt(3.0) / self.rho S0 = self.sigma**2 / w0 - empty = tt.zeros(0, dtype=self.dtype) + empty = pt.zeros(0, dtype=self.dtype) return ( empty, empty, @@ -537,22 +537,22 @@ class RotationTerm(TermSum): __doc__ = base_terms.RotationTerm.__doc__ def __init__(self, *, sigma, period, Q0, dQ, f, **kwargs): - self.sigma = tt.as_tensor_variable(sigma).astype("float64") - self.period = tt.as_tensor_variable(period).astype("float64") - self.Q0 = tt.as_tensor_variable(Q0).astype("float64") - self.dQ = tt.as_tensor_variable(dQ).astype("float64") - self.f = tt.as_tensor_variable(f).astype("float64") + self.sigma = pt.as_tensor_variable(sigma).astype("float64") + self.period = pt.as_tensor_variable(period).astype("float64") + self.Q0 = pt.as_tensor_variable(Q0).astype("float64") + self.dQ = pt.as_tensor_variable(dQ).astype("float64") + self.f = pt.as_tensor_variable(f).astype("float64") self.amp = self.sigma**2 / (1 + self.f) # One term with a period of period Q1 = 0.5 + self.Q0 + self.dQ - w1 = 4 * np.pi * Q1 / (self.period * tt.sqrt(4 * Q1**2 - 1)) + w1 = 4 * np.pi * Q1 / (self.period * pt.sqrt(4 * Q1**2 - 1)) S1 = self.amp / (w1 * Q1) # Another term at half the period Q2 = 0.5 + self.Q0 - w2 = 8 * np.pi * Q2 / (self.period * tt.sqrt(4 * Q2**2 - 1)) + w2 = 8 * np.pi * Q2 / (self.period * pt.sqrt(4 * Q2**2 - 1)) S2 = self.f * self.amp / (w2 * Q2) super().__init__( diff --git a/python/test/pymc4/conftest.py 
b/python/test/pymc/conftest.py similarity index 52% rename from python/test/pymc4/conftest.py rename to python/test/pymc/conftest.py index 3866a42..a8b7f68 100644 --- a/python/test/pymc4/conftest.py +++ b/python/test/pymc/conftest.py @@ -1,13 +1,13 @@ def pytest_configure(config): try: - import aesara + import pytensor except ImportError: return import platform - aesara.config.floatX = "float64" - aesara.config.compute_test_value = "raise" + pytensor.config.floatX = "float64" + pytensor.config.compute_test_value = "raise" if platform.system() == "Darwin": - aesara.config.gcc.cxxflags = "-Wno-c++11-narrowing" + pytensor.config.gcc.cxxflags = "-Wno-c++11-narrowing" config.addinivalue_line("filterwarnings", "ignore") diff --git a/python/test/pymc4/test_pymc4_celerite2.py b/python/test/pymc/test_pymc_celerite2.py similarity index 95% rename from python/test/pymc4/test_pymc4_celerite2.py rename to python/test/pymc/test_pymc_celerite2.py index 26e4100..ec6e96d 100644 --- a/python/test/pymc4/test_pymc4_celerite2.py +++ b/python/test/pymc/test_pymc_celerite2.py @@ -5,12 +5,12 @@ import celerite2 -pytest.importorskip("celerite2.pymc4") +pytest.importorskip("celerite2.pymc") try: from celerite2 import terms as pyterms - from celerite2.pymc4 import GaussianProcess, terms - from celerite2.pymc4.celerite2 import CITATIONS + from celerite2.pymc import GaussianProcess, terms + from celerite2.pymc.celerite2 import CITATIONS from celerite2.testing import check_gp_models except (ImportError, ModuleNotFoundError): pass diff --git a/python/test/pymc4/test_pymc4_ops.py b/python/test/pymc/test_pymc_ops.py similarity index 84% rename from python/test/pymc4/test_pymc4_ops.py rename to python/test/pymc/test_pymc_ops.py index dc1da8d..8d4f5b2 100644 --- a/python/test/pymc4/test_pymc4_ops.py +++ b/python/test/pymc/test_pymc_ops.py @@ -8,18 +8,19 @@ from celerite2 import backprop, driver from celerite2.testing import get_matrices -pytest.importorskip("celerite2.pymc4") 
+pytest.importorskip("celerite2.pymc") try: - import aesara - import aesara.tensor as tt - from aesara.compile.mode import Mode - from aesara.compile.sharedvalue import SharedVariable - from aesara.graph.fg import FunctionGraph - from aesara.graph.optdb import OptimizationQuery - from aesara.link.jax import JAXLinker - - from celerite2.pymc4 import ops + import pytensor + import pytensor.tensor as tt + from pytensor.compile.mode import Mode + from pytensor.compile.sharedvalue import SharedVariable + from pytensor.graph.fg import FunctionGraph + # TODO: pytensor.graph.rewriting.db.RewriteDatabaseQuery? + from pytensor.graph.rewriting.db import RewriteDatabaseQuery + from pytensor.link.jax import JAXLinker + + from celerite2.pymc import ops except (ImportError, ModuleNotFoundError): pass @@ -30,7 +31,7 @@ jax = None else: - opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) + opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) jax_mode = Mode(JAXLinker(), opts) py_mode = Mode("py", opts) @@ -54,20 +55,20 @@ def convert_values_to_types(values): def check_shape(op, inputs, outputs, values, result, multi): if multi: - shapes = aesara.function(inputs, [o.shape for o in outputs])(*values) + shapes = pytensor.function(inputs, [o.shape for o in outputs])(*values) assert all( np.all(v.shape == s) for v, s in zip(result, shapes) ), "Invalid shape inference" else: - shape = aesara.function(inputs, outputs.shape)(*values) + shape = pytensor.function(inputs, outputs.shape)(*values) assert result.shape == shape def check_basic(ref_func, op, values): inputs = convert_values_to_types(values) outputs = op(*inputs) - result = aesara.function(inputs, outputs)(*values) + result = pytensor.function(inputs, outputs)(*values) try: result.shape @@ -93,7 +94,7 @@ def check_basic(ref_func, op, values): def check_grad(op, values, num_out, eps=1.234e-8): inputs = convert_values_to_types(values) outputs = op(*inputs) - func = 
aesara.function(inputs, outputs) + func = pytensor.function(inputs, outputs) vals0 = func(*values) # Compute numerical grad @@ -113,8 +114,8 @@ def check_grad(op, values, num_out, eps=1.234e-8): # Compute the backprop for k in range(num_out): for i in range(vals0[k].size): - res = aesara.function( - inputs, aesara.grad(outputs[k].flatten()[i], inputs) + res = pytensor.function( + inputs, pytensor.grad(outputs[k].flatten()[i], inputs) )(*values) for n, b in enumerate(res): @@ -124,7 +125,7 @@ def check_grad(op, values, num_out, eps=1.234e-8): if jax is not None: for k in range(num_out): - out_grad = aesara.grad(tt.sum(outputs[k]), inputs) + out_grad = pytensor.grad(tt.sum(outputs[k]), inputs) fg = FunctionGraph(inputs, out_grad) compare_jax_and_py(fg, values) @@ -248,7 +249,7 @@ def compare_jax_and_py( assert_fn = partial(np.testing.assert_allclose, rtol=1e-4) fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] - aesara_jax_fn = aesara.function(fn_inputs, fgraph.outputs, mode=jax_mode) + aesara_jax_fn = pytensor.function(fn_inputs, fgraph.outputs, mode=jax_mode) jax_res = aesara_jax_fn(*test_inputs) if must_be_device_array: @@ -260,7 +261,7 @@ def compare_jax_and_py( else: assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) - aesara_py_fn = aesara.function(fn_inputs, fgraph.outputs, mode=py_mode) + aesara_py_fn = pytensor.function(fn_inputs, fgraph.outputs, mode=py_mode) py_res = aesara_py_fn(*test_inputs) if len(fgraph.outputs) > 1: diff --git a/python/test/pymc4/test_pymc4_terms.py b/python/test/pymc/test_pymc_terms.py similarity index 90% rename from python/test/pymc4/test_pymc4_terms.py rename to python/test/pymc/test_pymc_terms.py index 0858e58..ea186fc 100644 --- a/python/test/pymc4/test_pymc4_terms.py +++ b/python/test/pymc/test_pymc_terms.py @@ -5,11 +5,11 @@ import numpy as np import pytest -pytest.importorskip("celerite2.pymc4") +pytest.importorskip("celerite2.pymc") try: from celerite2 import terms as pyterms - from 
celerite2.pymc4 import terms + from celerite2.pymc import terms from celerite2.testing import check_tensor_term except (ImportError, ModuleNotFoundError): pass @@ -72,16 +72,16 @@ def test_base_terms(name, args): def test_opt_error(): - import aesara.tensor as at - from aesara import config, function, grad + import pytensor.tensor as pt + from pytensor import config, function, grad x = np.linspace(0, 5, 10) diag = np.full_like(x, 0.2) with config.change_flags(on_opt_error="raise"): - arg = at.scalar() + arg = pt.scalar() arg.tag.test_value = 0.5 matrices = terms.SHOTerm(S0=1.0, w0=0.5, Q=arg).get_celerite_matrices( x, diag ) - function([arg], grad(sum(at.sum(m) for m in matrices), [arg]))(0.5) + function([arg], grad(sum(pt.sum(m) for m in matrices), [arg]))(0.5) diff --git a/setup.py b/setup.py index 45778f8..7a5d2bc 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ "celerite>=0.3.1", ], "pymc3": ["pymc3>=3.9", "numpy<1.22"], - "pymc4": ["pymc>=4,<5"], + "pymc": ["pymc>=5"], "jax": ["jax", "jaxlib"], "docs": [ "sphinx", @@ -53,7 +53,7 @@ "matplotlib", "scipy", "emcee", - "pymc>=4,<5", + "pymc>=5", "tqdm", "numpyro", ], From 82fb3cf7da476404ab60c06bcfdc0d44c611e0e9 Mon Sep 17 00:00:00 2001 From: Thomas Vandal Date: Wed, 1 Nov 2023 00:11:21 -0400 Subject: [PATCH 04/18] Update from `DeviceArray` to `jax.Array` Ref: https://jax.readthedocs.io/en/latest/jax_array_migration.html --- python/test/pymc/test_pymc_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/test/pymc/test_pymc_ops.py b/python/test/pymc/test_pymc_ops.py index 8d4f5b2..217f59d 100644 --- a/python/test/pymc/test_pymc_ops.py +++ b/python/test/pymc/test_pymc_ops.py @@ -255,11 +255,11 @@ def compare_jax_and_py( if must_be_device_array: if isinstance(jax_res, list): assert all( - isinstance(res, jax.interpreters.xla.DeviceArray) + isinstance(res, jax.Array) for res in jax_res ) else: - assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) + assert 
isinstance(jax_res, jax.Array) aesara_py_fn = pytensor.function(fn_inputs, fgraph.outputs, mode=py_mode) py_res = aesara_py_fn(*test_inputs) From f2a5996c865cba89f4934280f2a954e65a340060 Mon Sep 17 00:00:00 2001 From: Thomas Vandal Date: Wed, 1 Nov 2023 13:57:03 -0400 Subject: [PATCH 05/18] Implement `_supp_shape_from_params()` for `CeleriteNormalRV` There use to be a default implementation of this but it has been removed. Implemented using a reference parameter, as is done for most distributions. --- python/celerite2/pymc/distribution.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/celerite2/pymc/distribution.py b/python/celerite2/pymc/distribution.py index a271cfc..ea5e866 100644 --- a/python/celerite2/pymc/distribution.py +++ b/python/celerite2/pymc/distribution.py @@ -5,7 +5,7 @@ import pytensor.tensor as pt import numpy as np from pytensor.tensor.random.op import RandomVariable -from pytensor.tensor.random.utils import broadcast_params +from pytensor.tensor.random.utils import broadcast_params, supp_shape_from_ref_param_shape from pymc.distributions.dist_math import check_parameters from pymc.distributions.distribution import Continuous from pymc.distributions.shape_utils import rv_size_is_none @@ -33,6 +33,14 @@ class CeleriteNormalRV(RandomVariable): dtype = "floatX" _print_name = ("CeleriteNormal", "\\operatorname{CeleriteNormal}") + def _supp_shape_from_params(self, dist_params, param_shapes=None): + return supp_shape_from_ref_param_shape( + ndim_supp=self.ndim_supp, + dist_params=dist_params, + param_shapes=param_shapes, + ref_param_idx=0, + ) + @classmethod def rng_fn(cls, rng, mean, norm, t, c, U, W, d, size): if any( From 8eacba4a0293b7c97ad01bb3aa51c7769002740e Mon Sep 17 00:00:00 2001 From: Thomas Vandal Date: Wed, 1 Nov 2023 13:57:43 -0400 Subject: [PATCH 06/18] Replace deprecated `model.test_point` by `model.initial_point()` --- python/test/pymc/test_pymc_celerite2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/python/test/pymc/test_pymc_celerite2.py b/python/test/pymc/test_pymc_celerite2.py index ec6e96d..d39a5d9 100644 --- a/python/test/pymc/test_pymc_celerite2.py +++ b/python/test/pymc/test_pymc_celerite2.py @@ -115,8 +115,8 @@ def test_marginal(data): gp.marginal("obs", observed=y) np.testing.assert_allclose( - model.compile_logp()(model.test_point), - model.compile_fn(gp.log_likelihood(y))(model.test_point), + model.compile_logp()(model.initial_point()), + model.compile_fn(gp.log_likelihood(y))(model.initial_point()), ) From ec3bf11cca3b1d07124ee0a33333d541aa5514ce Mon Sep 17 00:00:00 2001 From: Thomas Vandal Date: Wed, 1 Nov 2023 14:26:00 -0400 Subject: [PATCH 07/18] Use time as reference parameter for `CeleriteNormalRV` --- python/celerite2/pymc/distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/celerite2/pymc/distribution.py b/python/celerite2/pymc/distribution.py index ea5e866..3c22e24 100644 --- a/python/celerite2/pymc/distribution.py +++ b/python/celerite2/pymc/distribution.py @@ -38,7 +38,7 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None): ndim_supp=self.ndim_supp, dist_params=dist_params, param_shapes=param_shapes, - ref_param_idx=0, + ref_param_idx=2, ) @classmethod From a8566baaa37bb07430f8513349f5e88c08940912 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Nov 2023 18:40:59 +0000 Subject: [PATCH 08/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/celerite2/pymc/celerite2.py | 2 +- python/celerite2/pymc/distribution.py | 9 ++++++--- python/celerite2/pymc/ops.py | 4 ++-- python/celerite2/pymc/terms.py | 2 +- python/test/pymc/test_pymc_ops.py | 10 +++++----- 5 files changed, 15 insertions(+), 12 deletions(-) diff --git a/python/celerite2/pymc/celerite2.py b/python/celerite2/pymc/celerite2.py index 9083057..13386d9 100644 --- 
a/python/celerite2/pymc/celerite2.py +++ b/python/celerite2/pymc/celerite2.py @@ -2,8 +2,8 @@ __all__ = ["GaussianProcess", "ConditionalDistribution"] -import pytensor.tensor as pt import numpy as np +import pytensor.tensor as pt from pytensor.raise_op import Assert from celerite2.citation import CITATIONS diff --git a/python/celerite2/pymc/distribution.py b/python/celerite2/pymc/distribution.py index 3c22e24..1583772 100644 --- a/python/celerite2/pymc/distribution.py +++ b/python/celerite2/pymc/distribution.py @@ -2,13 +2,16 @@ __all__ = ["CeleriteNormal"] -import pytensor.tensor as pt import numpy as np -from pytensor.tensor.random.op import RandomVariable -from pytensor.tensor.random.utils import broadcast_params, supp_shape_from_ref_param_shape +import pytensor.tensor as pt from pymc.distributions.dist_math import check_parameters from pymc.distributions.distribution import Continuous from pymc.distributions.shape_utils import rv_size_is_none +from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.random.utils import ( + broadcast_params, + supp_shape_from_ref_param_shape, +) import celerite2.driver as driver from celerite2.pymc import ops diff --git a/python/celerite2/pymc/ops.py b/python/celerite2/pymc/ops.py index d73ae3e..c7ecdc5 100644 --- a/python/celerite2/pymc/ops.py +++ b/python/celerite2/pymc/ops.py @@ -14,10 +14,10 @@ import json from itertools import chain -import pytensor -import pytensor.tensor as pt import numpy as np import pkg_resources +import pytensor +import pytensor.tensor as pt from pytensor.graph import basic, op import celerite2.backprop as backprop diff --git a/python/celerite2/pymc/terms.py b/python/celerite2/pymc/terms.py index 91721af..ad14221 100644 --- a/python/celerite2/pymc/terms.py +++ b/python/celerite2/pymc/terms.py @@ -13,9 +13,9 @@ "RotationTerm", ] +import numpy as np import pytensor import pytensor.tensor as pt -import numpy as np from pytensor.ifelse import ifelse import celerite2.terms as base_terms 
diff --git a/python/test/pymc/test_pymc_ops.py b/python/test/pymc/test_pymc_ops.py index 217f59d..10f668e 100644 --- a/python/test/pymc/test_pymc_ops.py +++ b/python/test/pymc/test_pymc_ops.py @@ -16,6 +16,7 @@ from pytensor.compile.mode import Mode from pytensor.compile.sharedvalue import SharedVariable from pytensor.graph.fg import FunctionGraph + # TODO: pytensor.graph.rewriting.db.RewriteDatabaseQuery? from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.link.jax import JAXLinker @@ -31,7 +32,9 @@ jax = None else: - opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) + opts = RewriteDatabaseQuery( + include=[None], exclude=["cxx_only", "BlasOpt"] + ) jax_mode = Mode(JAXLinker(), opts) py_mode = Mode("py", opts) @@ -254,10 +257,7 @@ def compare_jax_and_py( if must_be_device_array: if isinstance(jax_res, list): - assert all( - isinstance(res, jax.Array) - for res in jax_res - ) + assert all(isinstance(res, jax.Array) for res in jax_res) else: assert isinstance(jax_res, jax.Array) From eae6d5fe7278928a990d04ad31280c9cfb0a74d4 Mon Sep 17 00:00:00 2001 From: Thomas Vandal Date: Wed, 1 Nov 2023 14:53:31 -0400 Subject: [PATCH 09/18] Move `pymc4.rst` to `pymc.rst` in docs --- docs/api/{pymc4.rst => pymc.rst} | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) rename docs/api/{pymc4.rst => pymc.rst} (77%) diff --git a/docs/api/pymc4.rst b/docs/api/pymc.rst similarity index 77% rename from docs/api/pymc4.rst rename to docs/api/pymc.rst index 3840c64..26b82cf 100644 --- a/docs/api/pymc4.rst +++ b/docs/api/pymc.rst @@ -1,12 +1,16 @@ .. _pymc4-api: -PyMC4 interface +PyMC (v5+) interface =============== -This ``celerite2.pymc4`` submodule provides access to the *celerite2* models -within the `Aesara `_ framework. Of special +This ``celerite2.pymc`` submodule provides access to the *celerite2* models +within the `PyTensor `_ framework. 
Of special interest, this adds support for probabilistic model building using `PyMC -`_ v4 or later. +`_ v5 or later. + +*Note: PyMC v4 was a short-lived version of PyMC with the aesara backend. +Upgrading to PyMC 5 or above is the recommended way forward, but past releases of +celerite2 might work with aesara.* The :ref:`first` tutorial demonstrates the use of this interface, while this page provides the details for the :class:`celerite2.pymc4.GaussianProcess` class From 611fdf094f93487e90360a723508f003924d7621 Mon Sep 17 00:00:00 2001 From: Thomas Vandal Date: Wed, 1 Nov 2023 15:01:40 -0400 Subject: [PATCH 10/18] Replace aesara_jax_fn by pytensor_jax_fn in test_pymc_ops.py --- python/test/pymc/test_pymc_ops.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/test/pymc/test_pymc_ops.py b/python/test/pymc/test_pymc_ops.py index 10f668e..5f1b834 100644 --- a/python/test/pymc/test_pymc_ops.py +++ b/python/test/pymc/test_pymc_ops.py @@ -16,8 +16,6 @@ from pytensor.compile.mode import Mode from pytensor.compile.sharedvalue import SharedVariable from pytensor.graph.fg import FunctionGraph - - # TODO: pytensor.graph.rewriting.db.RewriteDatabaseQuery? 
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.link.jax import JAXLinker @@ -252,8 +250,10 @@ def compare_jax_and_py( assert_fn = partial(np.testing.assert_allclose, rtol=1e-4) fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] - aesara_jax_fn = pytensor.function(fn_inputs, fgraph.outputs, mode=jax_mode) - jax_res = aesara_jax_fn(*test_inputs) + pytensor_jax_fn = pytensor.function( + fn_inputs, fgraph.outputs, mode=jax_mode + ) + jax_res = pytensor_jax_fn(*test_inputs) if must_be_device_array: if isinstance(jax_res, list): @@ -261,8 +261,8 @@ def compare_jax_and_py( else: assert isinstance(jax_res, jax.Array) - aesara_py_fn = pytensor.function(fn_inputs, fgraph.outputs, mode=py_mode) - py_res = aesara_py_fn(*test_inputs) + pytensor_py_fn = pytensor.function(fn_inputs, fgraph.outputs, mode=py_mode) + py_res = pytensor_py_fn(*test_inputs) if len(fgraph.outputs) > 1: for j, p in zip(jax_res, py_res): From e9daaa74ec4a37e20639150e40bdfdbf9e16468c Mon Sep 17 00:00:00 2001 From: Thomas Vandal Date: Wed, 1 Nov 2023 15:08:37 -0400 Subject: [PATCH 11/18] Replace pymc4 by pymc in API docs --- docs/api/pymc.rst | 28 ++++++++++++++-------------- docs/index.rst | 4 ++-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/docs/api/pymc.rst b/docs/api/pymc.rst index 26b82cf..b140466 100644 --- a/docs/api/pymc.rst +++ b/docs/api/pymc.rst @@ -1,4 +1,4 @@ -.. _pymc4-api: +.. _pymc-api: PyMC (v5+) interface =============== @@ -8,39 +8,39 @@ within the `PyTensor `_ framework. Of special interest, this adds support for probabilistic model building using `PyMC `_ v5 or later. -*Note: PyMC v4 was a short-lived version of PyMC with the aesara backend. -Upgrading to PyMC 5 or above is the recommended way forward, but past releases of -celerite2 might work with aesara.* +*Note: PyMC v4 was a short-lived version of PyMC with the Aesara backend. 
+Celerite2 now only supports PyMC 5, but past releases of celerite2 +might work with Aesara.* The :ref:`first` tutorial demonstrates the use of this interface, while this -page provides the details for the :class:`celerite2.pymc4.GaussianProcess` class +page provides the details for the :class:`celerite2.pymc.GaussianProcess` class which provides all this functionality. This page does not include documentation -for the term models defined in Aesara, but you can refer to the +for the term models defined in PyTensor, but you can refer to the :ref:`python-terms` section of the :ref:`python-api` documentation. All of those -models are implemented in Aesara and you can access them using something like +models are implemented in PyTensor and you can access them using something like the following: .. code-block:: python - import aesara.tensor as at - from celerite2.pymc4 import GaussianProcess, terms + import pytensor.tensor as pt + from celerite2.pymc import GaussianProcess, terms term = terms.SHOTerm(S0=at.scalar(), w0=at.scalar(), Q=at.scalar()) gp = GaussianProcess(term) -The :class:`celerite2.pymc4.GaussianProcess` class is detailed below: +The :class:`celerite2.pymc.GaussianProcess` class is detailed below: -.. autoclass:: celerite2.pymc4.GaussianProcess +.. autoclass:: celerite2.pymc.GaussianProcess :inherited-members: :exclude-members: sample, sample_conditional, recompute -PyMC (v4) support +PyMC (v5+) support ----------------- This implementation comes with a custom PyMC ``Distribution`` that represents a multivariate normal with a *celerite* covariance matrix. This is used by the -:func:`celerite2.pymc4.GaussianProcess.marginal` method documented above which +:func:`celerite2.pymc.GaussianProcess.marginal` method documented above which adds a marginal likelihood node to a PyMC model. -.. autoclass:: celerite2.pymc4.distribution.CeleriteNormal +..
autoclass:: celerite2.pymc.distribution.CeleriteNormal diff --git a/docs/index.rst b/docs/index.rst index 9b1c3a8..7d0fbfd 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,7 +7,7 @@ Regression in one dimension and this library, *celerite2* is a re-write of the original `celerite project `_ to improve numerical stability and integration with various machine learning frameworks. This implementation includes interfaces in Python and C++, with full support for -PyMC (v3 and v4) and JAX. +PyMC (v3 and v5+) and JAX. This documentation won't teach you the fundamentals of GP modeling but the best resource for learning about this is available for free online: `Rasmussen & @@ -43,7 +43,7 @@ an issue `_ there. api/python api/pymc3 - api/pymc4 + api/pymc api/jax api/c++ From 3fc3deded067615b7124a1cb8aedcee943b7df7c Mon Sep 17 00:00:00 2001 From: Thomas Vandal Date: Wed, 1 Nov 2023 15:09:08 -0400 Subject: [PATCH 12/18] Replace pymc4 by pymc in first.ipynb --- docs/tutorials/first.ipynb | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/tutorials/first.ipynb b/docs/tutorials/first.ipynb index 4e6e965..eadc827 100644 --- a/docs/tutorials/first.ipynb +++ b/docs/tutorials/first.ipynb @@ -330,7 +330,7 @@ "source": [ "## Posterior inference using PyMC\n", "\n", - "*celerite2* also includes support for probabilistic modeling using PyMC (v4 or v3, using the `celerite2.pymc3` or `celerite2.pymc4` submodule respectively), and we can implement the same model from above as follows:" + "*celerite2* also includes support for probabilistic modeling using PyMC (v5 or v3, using the `celerite2.pymc` or `celerite2.pymc3` submodule respectively), and we can implement the same model from above as follows:" ] }, { @@ -347,7 +347,7 @@ " warnings.filterwarnings(\"ignore\", category=RuntimeWarning)\n", "\n", " import pymc as pm\n", - " from celerite2.pymc4 import GaussianProcess, terms as pm_terms\n", + " from celerite2.pymc import GaussianProcess, terms as 
pm_terms\n", "\n", " with pm.Model() as model:\n", " mean = pm.Normal(\"mean\", mu=0.0, sigma=prior_sigma)\n", @@ -488,7 +488,7 @@ "id": "6f3abdcf", "metadata": {}, "source": [ - "This runtime was similar to the PyMC3 result from above, and (as we'll see below) the convergence is also similar.\n", + "This runtime was similar to the PyMC result from above, and (as we'll see below) the convergence is also similar.\n", "Any difference in runtime will probably disappear for more computationally expensive models, but this interface is looking pretty great here!\n", "\n", "As above, we can plot the posterior expectations for the power spectrum:" @@ -563,7 +563,7 @@ " bins,\n", " histtype=\"step\",\n", " density=True,\n", - " label=\"PyMC3\",\n", + " label=\"PyMC\",\n", ")\n", "plt.hist(\n", " np.exp(np.asarray((numpyro_data.posterior[\"log_rho1\"].T)).flatten()),\n", @@ -657,7 +657,7 @@ "metadata": {}, "source": [ "Overall these results are consistent, but the $\\hat{R}$ values are a bit high for the emcee run, so I'd probably run that for longer.\n", - "Either way, for models like these, PyMC3 and numpyro are generally going to be much better inference tools (in terms of runtime per effective sample) than emcee, so those are the recommended interfaces if the rest of your model can be easily implemented in such a framework." + "Either way, for models like these, PyMC and numpyro are generally going to be much better inference tools (in terms of runtime per effective sample) than emcee, so those are the recommended interfaces if the rest of your model can be easily implemented in such a framework." 
] } ], From ed4e697c441cb8b61449d33987499c061b86e9f2 Mon Sep 17 00:00:00 2001 From: Thomas Vandal Date: Wed, 1 Nov 2023 16:32:53 -0400 Subject: [PATCH 13/18] Replace pymc4 by pymc in workflows and noxfile --- .github/workflows/python.yml | 4 ++-- noxfile.py | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 3387a49..c6d32a8 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -25,8 +25,8 @@ jobs: - "core" - "jax" - "pymc3" - # - "pymc4" - # - "pymc4_jax" + # - "pymc" + # - "pymc_jax" steps: - name: Checkout diff --git a/noxfile.py b/noxfile.py index ff729c8..1560160 100644 --- a/noxfile.py +++ b/noxfile.py @@ -30,20 +30,20 @@ def pymc3(session): @nox.session(python=ALL_PYTHON_VS) -def pymc4(session): - session.install(".[test,pymc4]") - _session_run(session, "python/test/pymc4") +def pymc(session): + session.install(".[test,pymc]") + _session_run(session, "python/test/pymc") @nox.session(python=ALL_PYTHON_VS) -def pymc4_jax(session): - session.install(".[test,jax,pymc4]") - _session_run(session, "python/test/pymc4/test_pymc4_ops.py") +def pymc_jax(session): + session.install(".[test,jax,pymc]") + _session_run(session, "python/test/pymc/test_pymc_ops.py") @nox.session(python=ALL_PYTHON_VS) def full(session): - session.install(".[test,jax,pymc3,pymc4]") + session.install(".[test,jax,pymc3,pymc]") _session_run(session, "python/test") From 808efdef4f85059787df00ccb4ce48f4b2fe118d Mon Sep 17 00:00:00 2001 From: Thomas Vandal Date: Wed, 1 Nov 2023 20:37:59 -0400 Subject: [PATCH 14/18] PyMC3 -> PyMC in Getting Started (first.ipynb) tutorial PSD figure --- docs/tutorials/first.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/first.ipynb b/docs/tutorials/first.ipynb index eadc827..75d6eef 100644 --- a/docs/tutorials/first.ipynb +++ b/docs/tutorials/first.ipynb @@ -411,7 +411,7 @@ "plt.xlim(freq.min(), 
freq.max())\n", "plt.xlabel(\"frequency [1 / day]\")\n", "plt.ylabel(\"power [day ppt$^2$]\")\n", - "_ = plt.title(\"posterior psd using PyMC3\")" + "_ = plt.title(\"posterior psd using PyMC\")" ] }, { From 06d182e8a6aeb97725ef374c31a45f158da5299c Mon Sep 17 00:00:00 2001 From: Thomas Vandal Date: Wed, 1 Nov 2023 20:48:59 -0400 Subject: [PATCH 15/18] Uncomment pymc and pymc_jax Github workflows Were commented because of errors with PyMC v4 --- .github/workflows/python.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index c6d32a8..00a7035 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -25,8 +25,8 @@ jobs: - "core" - "jax" - "pymc3" - # - "pymc" - # - "pymc_jax" + - "pymc" + - "pymc_jax" steps: - name: Checkout From 525ed6c719f00b56fe1adf3d4b2d61ccc14fe763 Mon Sep 17 00:00:00 2001 From: Thomas Vandal Date: Wed, 1 Nov 2023 20:54:44 -0400 Subject: [PATCH 16/18] Remove warning filters when running PyMC sampling in Getting Started tutorial --- docs/tutorials/first.ipynb | 82 ++++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 44 deletions(-) diff --git a/docs/tutorials/first.ipynb b/docs/tutorials/first.ipynb index 75d6eef..9aa422f 100644 --- a/docs/tutorials/first.ipynb +++ b/docs/tutorials/first.ipynb @@ -340,50 +340,44 @@ "metadata": {}, "outputs": [], "source": [ - "import warnings\n", - "\n", - "with warnings.catch_warnings():\n", - " warnings.filterwarnings(\"ignore\", category=UserWarning)\n", - " warnings.filterwarnings(\"ignore\", category=RuntimeWarning)\n", - "\n", - " import pymc as pm\n", - " from celerite2.pymc import GaussianProcess, terms as pm_terms\n", - "\n", - " with pm.Model() as model:\n", - " mean = pm.Normal(\"mean\", mu=0.0, sigma=prior_sigma)\n", - " log_jitter = pm.Normal(\"log_jitter\", mu=0.0, sigma=prior_sigma)\n", - "\n", - " log_sigma1 = pm.Normal(\"log_sigma1\", mu=0.0, sigma=prior_sigma)\n", - " 
log_rho1 = pm.Normal(\"log_rho1\", mu=0.0, sigma=prior_sigma)\n", - " log_tau = pm.Normal(\"log_tau\", mu=0.0, sigma=prior_sigma)\n", - " term1 = pm_terms.SHOTerm(\n", - " sigma=pm.math.exp(log_sigma1),\n", - " rho=pm.math.exp(log_rho1),\n", - " tau=pm.math.exp(log_tau),\n", - " )\n", - "\n", - " log_sigma2 = pm.Normal(\"log_sigma2\", mu=0.0, sigma=prior_sigma)\n", - " log_rho2 = pm.Normal(\"log_rho2\", mu=0.0, sigma=prior_sigma)\n", - " term2 = pm_terms.SHOTerm(\n", - " sigma=pm.math.exp(log_sigma2), rho=pm.math.exp(log_rho2), Q=0.25\n", - " )\n", - "\n", - " kernel = term1 + term2\n", - " gp = GaussianProcess(kernel, mean=mean)\n", - " gp.compute(t, diag=yerr**2 + pm.math.exp(log_jitter), quiet=True)\n", - " gp.marginal(\"obs\", observed=y)\n", - "\n", - " pm.Deterministic(\"psd\", kernel.get_psd(omega))\n", - "\n", - " trace = pm.sample(\n", - " tune=1000,\n", - " draws=1000,\n", - " target_accept=0.9,\n", - " init=\"adapt_full\",\n", - " cores=2,\n", - " chains=2,\n", - " random_seed=34923,\n", - " )" + "import pymc as pm\n", + "from celerite2.pymc import GaussianProcess, terms as pm_terms\n", + "\n", + "with pm.Model() as model:\n", + " mean = pm.Normal(\"mean\", mu=0.0, sigma=prior_sigma)\n", + " log_jitter = pm.Normal(\"log_jitter\", mu=0.0, sigma=prior_sigma)\n", + "\n", + " log_sigma1 = pm.Normal(\"log_sigma1\", mu=0.0, sigma=prior_sigma)\n", + " log_rho1 = pm.Normal(\"log_rho1\", mu=0.0, sigma=prior_sigma)\n", + " log_tau = pm.Normal(\"log_tau\", mu=0.0, sigma=prior_sigma)\n", + " term1 = pm_terms.SHOTerm(\n", + " sigma=pm.math.exp(log_sigma1),\n", + " rho=pm.math.exp(log_rho1),\n", + " tau=pm.math.exp(log_tau),\n", + " )\n", + "\n", + " log_sigma2 = pm.Normal(\"log_sigma2\", mu=0.0, sigma=prior_sigma)\n", + " log_rho2 = pm.Normal(\"log_rho2\", mu=0.0, sigma=prior_sigma)\n", + " term2 = pm_terms.SHOTerm(\n", + " sigma=pm.math.exp(log_sigma2), rho=pm.math.exp(log_rho2), Q=0.25\n", + " )\n", + "\n", + " kernel = term1 + term2\n", + " gp = 
GaussianProcess(kernel, mean=mean)\n", + " gp.compute(t, diag=yerr**2 + pm.math.exp(log_jitter), quiet=True)\n", + " gp.marginal(\"obs\", observed=y)\n", + "\n", + " pm.Deterministic(\"psd\", kernel.get_psd(omega))\n", + "\n", + " trace = pm.sample(\n", + " tune=1000,\n", + " draws=1000,\n", + " target_accept=0.9,\n", + " init=\"adapt_full\",\n", + " cores=2,\n", + " chains=2,\n", + " random_seed=34923,\n", + " )" ] }, { From 65fbc0faafe12c08535ee4634f742d0041e75d89 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 2 Nov 2023 16:07:40 -0400 Subject: [PATCH 17/18] Upper limit on xarray version for PyMC3 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 7a5d2bc..60ad2ac 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,7 @@ "scipy", "celerite>=0.3.1", ], - "pymc3": ["pymc3>=3.9", "numpy<1.22"], + "pymc3": ["pymc3>=3.9", "numpy<1.22", "xarray<2023.10.0"], "pymc": ["pymc>=5"], "jax": ["jax", "jaxlib"], "docs": [ From 953dc87d8704db1250326cff3e198157967bb4c5 Mon Sep 17 00:00:00 2001 From: Thomas Vandal Date: Thu, 2 Nov 2023 19:15:44 -0400 Subject: [PATCH 18/18] Comment `pytensor.config.compute_test_value = "raise"` for now Requires a fix in PyMC. Introduced to main by PR #6982, commit 714b4a, will probably be in v5.9.2. --- python/test/pymc/conftest.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/test/pymc/conftest.py b/python/test/pymc/conftest.py index a8b7f68..94625c2 100644 --- a/python/test/pymc/conftest.py +++ b/python/test/pymc/conftest.py @@ -7,7 +7,8 @@ def pytest_configure(config): import platform pytensor.config.floatX = "float64" - pytensor.config.compute_test_value = "raise" + # TODO: Uncomment when PyMC commit 714b4a0 makes it into a release (probably 5.9.2) + # pytensor.config.compute_test_value = "raise" if platform.system() == "Darwin": pytensor.config.gcc.cxxflags = "-Wno-c++11-narrowing" config.addinivalue_line("filterwarnings", "ignore")