Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for PyMC >= 5 with PyTensor backend #89

Merged
merged 14 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
EXTRA_REQUIRE = {
"pymc3": ["pymc3>=3.9", "numpy<1.22"],
"pymc4": ["pymc>=4.0.0,<5"],
"pymc": ["pymc>=5.0.0"],
"jax": ["jax", "jaxlib"],
"test": ["pytest"],
"comparison": ["batman-package", "starry", "numpy<1.22"],
Expand Down
30 changes: 30 additions & 0 deletions src/exoplanet_core/pymc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-

__all__ = ["ops"]


def __set_compiler_flags():
import pytensor

def add_flag(current, new):
if new in current:
return current
return f"{current} {new}"

current = pytensor.config.gcc__cxxflags
current = add_flag(current, "-Wno-c++11-narrowing")
current = add_flag(current, "-fno-exceptions")
current = add_flag(current, "-fno-unwind-tables")
current = add_flag(current, "-fno-asynchronous-unwind-tables")
pytensor.config.gcc__cxxflags = current
vandalt marked this conversation as resolved.
Show resolved Hide resolved


__set_compiler_flags()


from exoplanet_core.pymc import ops

try:
from exoplanet_core.pymc import jax_support # noqa
except ImportError:
pass
28 changes: 28 additions & 0 deletions src/exoplanet_core/pymc/jax_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from pytensor.link.jax.dispatch import jax_funcify

from exoplanet_core.jax import ops as jax_ops
from exoplanet_core.pymc import ops as pymc_ops


@jax_funcify.register(pymc_ops.Kepler)
def jax_funcify_Kepler(op, **kwargs):
def kepler(M, ecc):
return jax_ops.kepler(M, ecc)

return kepler


@jax_funcify.register(pymc_ops.QuadSolutionVector)
def jax_funcify_QuadSolutionVector(op, **kwargs):
def quad_solution_vector(b, r):
return jax_ops._base_quad_solution_vector(b, r)

return quad_solution_vector


@jax_funcify.register(pymc_ops.ContactPoints)
def jax_funcify_ContactPoints(op, **kwargs):
def contact_points(a, e, cosw, sinw, cosi, sini, L):
return jax_ops.contact_points(a, e, cosw, sinw, cosi, sini, L)

return contact_points
235 changes: 235 additions & 0 deletions src/exoplanet_core/pymc/ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
# -*- coding: utf-8 -*-

__all__ = ["kepler", "quad_solution_vector", "contact_points"]

from itertools import chain

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph import basic, op

from exoplanet_core import driver


def as_tensor_variable(x, dtype="float64", **kwargs):
t = pt.as_tensor_variable(x, **kwargs)
if dtype is None:
return t
return t.astype(dtype)


def resize_or_set(outputs, n, shape, dtype=np.float64):
if outputs[n][0] is None:
outputs[n][0] = np.empty(shape, dtype=dtype)
else:
outputs[n][0] = np.ascontiguousarray(
np.resize(outputs[n][0], shape), dtype=dtype
)
return outputs[n][0]


# **********
# * KEPLER *
# **********
class Kepler(op.Op):
r"""Solve Kepler's equation

This op numerically evaluates the solution to Kepler's equation for given
mean anomaly ``M`` and eccentricity ``e``:

.. math::

M = E - e sin(E)

For computational efficiency this op actually returns ``cos(f)`` and
``sin(f)``, where ``f`` is the true anomaly, defined as

.. math::

f = 2\,\arctan\left(\sqrt{\frac{1 + e}{1 - e}}\,\tan\frac{E}{2}\right)

Args:
mean_anomaly: The mean anomaly
eccentricity: Eccentricity, like it says on the box

Returns:
(cos(f), sin(f)): The cosine and sine of the true anomaly ``f``.

"""
__props__ = ()

def make_node(self, M, ecc):
in_args = [as_tensor_variable(M), as_tensor_variable(ecc)]
if any(i.dtype != "float64" for i in in_args):
raise ValueError("float64 precision is required")
out_args = [in_args[0].type(), in_args[0].type()]
return basic.Apply(self, in_args, out_args)

def infer_shape(self, *args):
return args[-1]

def perform(self, node, inputs, outputs):
M, ecc = inputs
sinf = resize_or_set(outputs, 0, M.shape)
cosf = resize_or_set(outputs, 1, M.shape)
driver.solve_kepler(M, ecc, sinf, cosf)

def grad(self, inputs, gradients):
M, e = inputs
sinf, cosf = self(M, e)

ecosf = e * cosf
ome2 = 1 - e**2
dfdM = (1 + ecosf) ** 2 / ome2**1.5
dfde = (2 + ecosf) * sinf / ome2

bM = pt.zeros_like(M)
be = pt.zeros_like(M)
if not isinstance(
gradients[0].type, pytensor.gradient.DisconnectedType
):
bM += gradients[0] * cosf * dfdM
be += gradients[0] * cosf * dfde

if not isinstance(
gradients[1].type, pytensor.gradient.DisconnectedType
):
bM -= gradients[1] * sinf * dfdM
be -= gradients[1] * sinf * dfde

return [bM, be]

def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return eval_points
return self.grad(inputs, eval_points)


kepler = Kepler()


# **********
# * STARRY *
# **********
class QuadSolutionVector(op.Op):
r"""Compute the "solution vector" for a quadratic limb-darkening model

Note that you probably don't ever want to directly instantiate this op.
Use ``exoplanet_core.pymc.ops.quad_solution_vector`` instead.

This will return a tensor with an extra dimension of size 3 on the right
hand side which represents the solution vector for each pair of impact
parameter ``b`` and radius ratio ``r`` values. See `Agol+ (2020)
<https://arxiv.org/abs/1908.03222>`_ for more details.

Args:
b: The impact parameter
r: The radius ratio

Returns:
(s, dsdb, dsdr): The solution vector and its first derivatives. Each
will have the shape ``[..., 3]``, where ``...`` indicates the shape
of ``b`` or ``r``.

"""
__props__ = ()

def make_node(self, b, r):
in_args = [as_tensor_variable(b), as_tensor_variable(r)]
if any(i.dtype != "float64" for i in in_args):
raise ValueError("float64 precision is required")
x = in_args[0]
o = [
pt.tensor(
# NOTE: Changed broadcastable to shape because it caused deprecation warning,
# BUT this might change again in the future: https://github.com/pymc-devs/pytensor/issues/408
shape=tuple(x.broadcastable) + (False,),
dtype=x.dtype,
)
for _ in range(3)
]
return basic.Apply(self, in_args, o)

def infer_shape(self, *args):
shapes = args[-1]
shape = tuple(shapes[0]) + (3,)
return shape, shape, shape

def perform(self, node, inputs, outputs):
b, r = inputs
shape = b.shape + (3,)
s = resize_or_set(outputs, 0, shape)
dsdb = resize_or_set(outputs, 1, shape)
dsdr = resize_or_set(outputs, 2, shape)
driver.quad_solution_vector_with_grad(b, r, s, dsdb, dsdr)

def grad(self, inputs, gradients):
b, r = inputs
s, dsdb, dsdr = self(b, r)
bs = gradients[0]

for g in gradients[1:]:
if not isinstance(g.type, pytensor.gradient.DisconnectedType):
raise ValueError(
"Backpropagation is only supported for the solution vector"
)

if isinstance(bs.type, pytensor.gradient.DisconnectedType):
return [
pytensor.gradient.DisconnectedType()(),
pytensor.gradient.DisconnectedType()(),
]

return [pt.sum(bs * dsdb, axis=-1), pt.sum(bs * dsdr, axis=-1)]

def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return eval_points
return self.grad(inputs, eval_points)


_quad_solution_vector = QuadSolutionVector()


def quad_solution_vector(b, r):
return _quad_solution_vector(b, r)[0]


# ******************
# * CONTACT POINTS *
# ******************
class ContactPoints(op.Op):
__props__ = ()

def make_node(self, a, e, cosw, sinw, cosi, sini, L):
in_args = list(
map(as_tensor_variable, (a, e, cosw, sinw, cosi, sini, L))
)
if any(i.dtype != "float64" for i in in_args):
raise ValueError("float64 precision is required")
out_args = [
in_args[0].type(),
in_args[0].type(),
pt.tensor(
# NOTE: Changed broadcastable to shape because it caused deprecation warning,
# BUT this might change again in the future: https://github.com/pymc-devs/pytensor/issues/408
shape=tuple(in_args[0].broadcastable),
dfm marked this conversation as resolved.
Show resolved Hide resolved
dtype="int32",
),
]
return basic.Apply(self, in_args, out_args)

def infer_shape(self, *args):
shapes = args[-1]
return (shapes[0], shapes[0], shapes[0])

def perform(self, node, inputs, outputs):
shape = inputs[0].shape
M_left = resize_or_set(outputs, 0, shape)
M_right = resize_or_set(outputs, 1, shape)
flag = resize_or_set(outputs, 2, shape, dtype=np.int32)
driver.contact_points(*chain(inputs, (M_left, M_right, flag)))


contact_points = ContactPoints()
Loading