Skip to content

Commit

Permalink
Added test cases for new behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
JasperMartins committed Nov 22, 2024
1 parent e5f3f34 commit 2fa2bc1
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions test/core/prior/dict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,11 @@ def test_sample_subset_with_actual_subset(self):
expected = dict(length=np.array([42.0, 42.0, 42.0]))
self.assertTrue(np.array_equal(expected["length"], samples["length"]))

joint_prior = self.joint_prior_from_file
samples = joint_prior.sample_subset(keys=["testAbase"], size=size)
self.assertTrue(joint_prior["testAbase"].dist.sampled_parameters == [])
self.assertTrue(joint_prior["testBbase"].dist.sampled_parameters == [])

def test_sample_subset_constrained_as_array(self):
size = 3
keys = ["mass", "speed"]
Expand Down Expand Up @@ -320,6 +325,15 @@ def test_ln_prob(self):
) + self.second_prior.ln_prob(samples["speed"])
self.assertEqual(expected, self.prior_set_from_dict.ln_prob(samples))

def test_ln_prob_actual_subset(self):
joint_prior = self.joint_prior_from_file
keys = ["testAbase"]
samples = joint_prior.sample_subset(keys=keys, size=1)
lnprob = joint_prior.ln_prob(samples)
self.assertTrue(joint_prior["testAbase"].dist.requested_parameters["testAbase"] is None)
self.assertTrue(joint_prior["testBbase"].dist.requested_parameters["testBbase"] is None)
self.assertTrue(lnprob == 0)

def test_rescale(self):
theta = [0.5, 0.5, 0.5]
expected = [
Expand All @@ -336,6 +350,16 @@ def test_rescale(self):
),
)

def test_rescale_actual_subset(self):
theta = [0.5]
keys = ["testAbase"]
joint_prior = self.joint_prior_from_file
samples = joint_prior.rescale(keys=keys, theta=theta)
print(joint_prior["testAbase"].dist._rescale_parameters)
self.assertTrue(joint_prior["testAbase"].dist._rescale_parameters["testAbase"] is None)
self.assertTrue(joint_prior["testBbase"].dist._rescale_parameters["testBbase"] is None)
self.assertTrue(np.all(np.isnan(samples)))

def test_cdf(self):
"""
Test that the CDF method is the inverse of the rescale method.
Expand Down

0 comments on commit 2fa2bc1

Please sign in to comment.