Skip to content

Commit

Permalink
Add a function to generate prior samples
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Feb 18, 2023
1 parent 2fc0ee3 commit c188971
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 0 deletions.
3 changes: 3 additions & 0 deletions aemcmc/sample/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .prior import sample_prior

__all__ = ["sample_prior"]
27 changes: 27 additions & 0 deletions aemcmc/sample/prior.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import aesara
import aesara.tensor as at
import aesara.tensor.random as ar


def sample_prior(
srng: ar.RandomStream, num_samples: at.TensorVariable, *rvs: at.TensorVariable
) -> at.TensorVariable:
"""Sample from a model's prior distributions.
Parameters
----------
srng:
`RandomStream` instance with which the model was defined.
num_samples:
The number of prior samples to generate.
rvs:
The random variables whose prior distribution we want to sample.
"""

def step_fn():
return rvs, srng.state_updates

samples, updates = aesara.scan(step_fn, n_steps=num_samples)

return samples, updates
17 changes: 17 additions & 0 deletions tests/test_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import aesara
import aesara.tensor as at
import numpy as np

from aemcmc.sample import sample_prior


def test_sample_prior():
srng = at.random.RandomStream(0)
mu_rv = srng.normal(0, 1)
Y_rv = srng.normal(mu_rv, 1.0)

samples, updates = sample_prior(srng, 10, Y_rv)
fn = aesara.function([], samples)

samples_vals = fn()
assert np.shape(np.unique(samples_vals)) == (10,)

0 comments on commit c188971

Please sign in to comment.