From c18897178096d9c97e7afc948f2b4a3538b95f14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 17 Feb 2023 13:25:46 +0100 Subject: [PATCH] Add a function to generate prior samples --- aemcmc/sample/__init__.py | 3 +++ aemcmc/sample/prior.py | 27 +++++++++++++++++++++++++++ tests/test_sample.py | 17 +++++++++++++++++ 3 files changed, 47 insertions(+) create mode 100644 aemcmc/sample/__init__.py create mode 100644 aemcmc/sample/prior.py create mode 100644 tests/test_sample.py diff --git a/aemcmc/sample/__init__.py b/aemcmc/sample/__init__.py new file mode 100644 index 0000000..f53e034 --- /dev/null +++ b/aemcmc/sample/__init__.py @@ -0,0 +1,3 @@ +from .prior import sample_prior + +__all__ = ["sample_prior"] diff --git a/aemcmc/sample/prior.py b/aemcmc/sample/prior.py new file mode 100644 index 0000000..a1ddd19 --- /dev/null +++ b/aemcmc/sample/prior.py @@ -0,0 +1,27 @@ +import aesara +import aesara.tensor as at +import aesara.tensor.random as ar + + +def sample_prior( + srng: ar.RandomStream, num_samples: at.TensorVariable, *rvs: at.TensorVariable +) -> at.TensorVariable: + """Sample from a model's prior distributions. + + Parameters + ---------- + srng: + `RandomStream` instance with which the model was defined. + num_samples: + The number of prior samples to generate. + rvs: + The random variables whose prior distribution we want to sample. + + """ + + def step_fn(): + return rvs, srng.state_updates + + samples, updates = aesara.scan(step_fn, n_steps=num_samples) + + return samples, updates diff --git a/tests/test_sample.py b/tests/test_sample.py new file mode 100644 index 0000000..c098cc3 --- /dev/null +++ b/tests/test_sample.py @@ -0,0 +1,17 @@ +import aesara +import aesara.tensor as at +import numpy as np + +from aemcmc.sample import sample_prior + + +def test_sample_prior(): + srng = at.random.RandomStream(0) + mu_rv = srng.normal(0, 1) + Y_rv = srng.normal(mu_rv, 1.0) + + samples, updates = sample_prior(srng, 10, Y_rv) + fn = aesara.function([], samples) + + samples_vals = fn() + assert np.shape(np.unique(samples_vals)) == (10,)