From 2fa2bc15a8ad332102332d1f25bef8b840c4ead2 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 17:52:57 +0100 Subject: [PATCH] Added test cases for new behavior --- test/core/prior/dict_test.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/core/prior/dict_test.py b/test/core/prior/dict_test.py index d6e6239f..edc2fb20 100644 --- a/test/core/prior/dict_test.py +++ b/test/core/prior/dict_test.py @@ -287,6 +287,11 @@ def test_sample_subset_with_actual_subset(self): expected = dict(length=np.array([42.0, 42.0, 42.0])) self.assertTrue(np.array_equal(expected["length"], samples["length"])) + joint_prior = self.joint_prior_from_file + samples = joint_prior.sample_subset(keys=["testAbase"], size=size) + self.assertTrue(joint_prior["testAbase"].dist.sampled_parameters == []) + self.assertTrue(joint_prior["testBbase"].dist.sampled_parameters == []) + def test_sample_subset_constrained_as_array(self): size = 3 keys = ["mass", "speed"] @@ -320,6 +325,15 @@ def test_ln_prob(self): ) + self.second_prior.ln_prob(samples["speed"]) self.assertEqual(expected, self.prior_set_from_dict.ln_prob(samples)) + def test_ln_prob_actual_subset(self): + joint_prior = self.joint_prior_from_file + keys = ["testAbase"] + samples = joint_prior.sample_subset(keys=keys, size=1) + lnprob = joint_prior.ln_prob(samples) + self.assertTrue(joint_prior["testAbase"].dist.requested_parameters["testAbase"] is None) + self.assertTrue(joint_prior["testBbase"].dist.requested_parameters["testBbase"] is None) + self.assertTrue(lnprob == 0) + def test_rescale(self): theta = [0.5, 0.5, 0.5] expected = [ @@ -336,6 +350,16 @@ def test_rescale(self): ), ) + def test_rescale_actual_subset(self): + theta = [0.5] + keys = ["testAbase"] + joint_prior = self.joint_prior_from_file + samples = joint_prior.rescale(keys=keys, theta=theta) + print(joint_prior["testAbase"].dist._rescale_parameters) + self.assertTrue(joint_prior["testAbase"].dist._rescale_parameters["testAbase"] is None) + self.assertTrue(joint_prior["testBbase"].dist._rescale_parameters["testBbase"] is None) + self.assertTrue(np.all(np.isnan(samples))) + def test_cdf(self): """ Test that the CDF method is the inverse of the rescale method.