diff --git a/aemcmc/sample/__init__.py b/aemcmc/sample/__init__.py
new file mode 100644
index 0000000..f53e034
--- /dev/null
+++ b/aemcmc/sample/__init__.py
@@ -0,0 +1,3 @@
+from .prior import sample_prior
+
+__all__ = ["sample_prior"]
diff --git a/aemcmc/sample/prior.py b/aemcmc/sample/prior.py
new file mode 100644
index 0000000..c363cb7
--- /dev/null
+++ b/aemcmc/sample/prior.py
@@ -0,0 +1,101 @@
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+
+import aesara
+import aesara.tensor as at
+import aesara.tensor.random as ar
+from aesara.compile.sharedvalue import SharedVariable
+from aesara.graph.basic import ancestors
+from aesara.tensor.random.type import RandomType
+
+if TYPE_CHECKING:
+    from aesara.graph.basic import Variable
+
+
+def get_rv_updates(
+    srng: ar.RandomStream, *rvs: at.TensorVariable
+) -> Dict[SharedVariable, "Variable"]:
+    r"""Get the updates needed to update RNG objects during sampling of `rvs`.
+
+    A search is performed over `rvs` for `SharedVariable`\s with default
+    updates and the updates stored in `srng`.
+
+    Parameters
+    ----------
+    srng:
+        `RandomStream` instance with which the model was defined.
+    rvs:
+        The random variables whose prior distribution we want to sample.
+
+    Returns
+    -------
+    A dict containing the updates needed to sample from the models given by
+    `rvs`.
+
+    """
+    # TODO: It's kind of weird that this is an alist-like data structure; we
+    # should revisit this in `RandomStream`
+    srng_updates = dict(srng.state_updates)
+    rv_updates = {}
+
+    for var in ancestors(rvs):
+        if not isinstance(var, SharedVariable) and not isinstance(var.type, RandomType):
+            continue
+
+        # TODO: Consider making sure the updates correspond to "in-place"
+        # updates of the RNGs for relevant `RandomVariable`s?
+        # More generally, a function like this could be used to determine the
+        # consistency of `RandomVariable` updates in general (e.g. find
+        # bad/disassociated updates).
+        srng_update = srng_updates.get(var)
+
+        # The filter above also lets non-shared `RandomType` variables through,
+        # and nothing here guarantees that `default_update` is set, so a
+        # missing attribute is treated the same as "no default update".
+        default_update = getattr(var, "default_update", None)
+
+        # Compare against `None` explicitly instead of relying on the
+        # truthiness of graph variables.
+        if default_update is not None:
+            if srng_update is not None:
+                assert srng_update == default_update
+
+            # We prefer the default update (for no particular reason)
+            rv_updates[var] = default_update
+        elif srng_update is not None:
+            rv_updates[var] = srng_update
+
+    return rv_updates
+
+
+def sample_prior(
+    srng: ar.RandomStream, num_samples: at.TensorVariable, *rvs: at.TensorVariable
+) -> Tuple[
+    Union[at.TensorVariable, List[at.TensorVariable]],
+    Dict[SharedVariable, "Variable"],
+]:
+    """Sample from a model's prior distributions.
+
+    Parameters
+    ----------
+    srng:
+        `RandomStream` instance with which the model was defined.
+    num_samples:
+        The number of prior samples to generate.
+    rvs:
+        The random variables whose prior distribution we want to sample.
+
+    Returns
+    -------
+    The `(samples, updates)` pair returned by `aesara.scan`; `updates` should
+    be passed to `aesara.function` when compiling the sampler.
+
+    """
+
+    rv_updates = get_rv_updates(srng, *rvs)
+
+    def step_fn():
+        return rvs, rv_updates
+
+    samples, updates = aesara.scan(step_fn, n_steps=num_samples)
+
+    return samples, updates
diff --git a/tests/test_sample.py b/tests/test_sample.py
new file mode 100644
index 0000000..e5886bd
--- /dev/null
+++ b/tests/test_sample.py
@@ -0,0 +1,31 @@
+import aesara
+import aesara.tensor as at
+import numpy as np
+from aesara.compile.sharedvalue import SharedVariable
+
+from aemcmc.sample import sample_prior
+
+
+def test_sample_prior():
+    srng = at.random.RandomStream(123)
+
+    mu_rv = srng.normal(0, 1, name="mu")
+    Y_rv = srng.normal(mu_rv, 1.0, name="Y")
+    Z_rv = srng.gamma(0.5, 0.5, name="Z")
+
+    samples, updates = sample_prior(srng, 10, Y_rv)
+    fn = aesara.function([], samples, updates=updates)
+
+    # Make sure that `Z_rv` doesn't sneak into our prior sampling.
+    rng_objects = set(
+        var.get_value(borrow=True)
+        for var in fn.maker.fgraph.variables
+        if isinstance(var, SharedVariable)
+    )
+
+    assert mu_rv.owner.inputs[0].get_value(borrow=True) in rng_objects
+    assert Y_rv.owner.inputs[0].get_value(borrow=True) in rng_objects
+    assert Z_rv.owner.inputs[0].get_value(borrow=True) not in rng_objects
+
+    samples_vals = fn()
+    assert np.shape(np.unique(samples_vals)) == (10,)