From 456576bf54ddab44650f3b618f4d2728fa5fa05a Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Wed, 20 Nov 2024 18:01:36 +0100 Subject: [PATCH 01/22] Changed rescale-method of JointPrior to always return correct-size array and update in-place once all keys are requested. Changed (Conditional)PriorDict.rescale to always return samples in right shape. --- bilby/core/prior/dict.py | 35 ++++++++++++++------------ bilby/core/prior/joint.py | 52 ++++++++++++++++++++++++++++----------- 2 files changed, 57 insertions(+), 30 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index be3d543a9..2908bd083 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -600,18 +600,21 @@ def rescale(self, keys, theta): ========== keys: list List of prior keys to be rescaled - theta: list - List of randomly drawn values on a unit cube associated with the prior keys + theta: dict or array-like + Randomly drawn values on a unit cube associated with the prior keys Returns ======= - list: List of floats containing the rescaled sample + list: + If theta is 1D, returns list of floats containing the rescaled sample. + If theta is 2D, returns list of lists containing the rescaled samples. """ - theta = list(theta) + theta = [theta[key] for key in keys] if isinstance(theta, dict) else list(theta) samples = [] for key, units in zip(keys, theta): samps = self[key].rescale(units) - samples += list(np.asarray(samps).flatten()) + # turns 0d-arrays into scalars + samples.append(np.squeeze(samps).tolist()) return samples def test_redundancy(self, key, disable_logging=False): @@ -832,28 +835,28 @@ def rescale(self, keys, theta): ========== keys: list List of prior keys to be rescaled - theta: list - List of randomly drawn values on a unit cube associated with the prior keys + theta: dict or array-like + Randomly drawn values on a unit cube associated with the prior keys Returns ======= - list: List of floats containing the rescaled sample + list: + If theta is float for each key, returns list of floats containing the rescaled sample. + If theta is array-like for each key, returns list of lists containing the rescaled samples. 
""" keys = list(keys) - theta = list(theta) + theta = [theta[key] for key in keys] if isinstance(theta, dict) else list(theta) self._check_resolved() self._update_rescale_keys(keys) result = dict() - for key, index in zip( - self.sorted_keys_without_fixed_parameters, self._rescale_indexes - ): - result[key] = self[key].rescale( - theta[index], **self.get_required_variables(key) - ) + for key, index in zip(self.sorted_keys_without_fixed_parameters, self._rescale_indexes): + result[key] = self[key].rescale(theta[index], **self.get_required_variables(key)) self[key].least_recently_sampled = result[key] samples = [] for key in keys: - samples += list(np.asarray(result[key]).flatten()) + # turns 0d-arrays into scalars + res = np.squeeze(result[key]).tolist() + samples.append(res) return samples def _update_rescale_keys(self, keys): diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 43c8913e3..b088b15ba 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -63,8 +63,9 @@ def __init__(self, names, bounds=None): self.requested_parameters = dict() self.reset_request() - # a dictionary of the rescaled parameters - self.rescale_parameters = dict() + # a dictionary of the rescale(d) parameters + self._rescale_parameters = dict() + self._rescaled_parameters = dict() self.reset_rescale() # a list of sampled parameters @@ -94,7 +95,12 @@ def filled_rescale(self): Check if all the rescaled parameters have been filled. """ - return not np.any([val is None for val in self.rescale_parameters.values()]) + return not np.any([val is None for val in self._rescale_parameters.values()]) + + def set_rescale(self, key, values): + values = np.array(values) + self._rescale_parameters[key] = values + self._rescaled_parameters[key] = np.atleast_1d(np.ones_like(values)) * np.nan def reset_rescale(self): """ @@ -102,7 +108,11 @@ def reset_rescale(self): """ for name in self.names: - self.rescale_parameters[name] = None + self._rescale_parameters[name] = None + self._rescaled_parameters[name] = None + + def get_rescaled(self, key): + return self._rescaled_parameters[key] def get_instantiation_dict(self): subclass_args = infer_args_from_method(self.__init__) @@ -303,10 +313,11 @@ def rescale(self, value, **kwargs): Parameters ========== - value: array - A 1d vector sample (one for each parameter) drawn from a uniform + value: array or None + If given, a 1d vector sample (one for each parameter) drawn from a uniform distribution between 0 and 1, or a 2d NxM array of samples where N is the number of samples and M is the number of parameters. + If None, values previously set using BaseJointPriorDist.set_rescale() are used. kwargs: dict All keyword args that need to be passed to _rescale method, these keyword args are called in the JointPrior rescale methods for each parameter @@ -317,7 +328,11 @@ def rescale(self, value, **kwargs): An vector sample drawn from the multivariate Gaussian distribution. 
""" - samp = np.array(value) + if value is None: + samp = np.array(list(self._rescale_parameters.values())).T + else: + samp = np.array(value) + if len(samp.shape) == 1: samp = samp.reshape(1, self.num_vars) @@ -327,6 +342,11 @@ def rescale(self, value, **kwargs): raise ValueError("Array is the wrong shape") samp = self._rescale(samp, **kwargs) + if value is None: + for i, key in enumerate(self.names): + output = self.get_rescaled(key) + # update in-place for proper handling in PriorDict-instances + output[:] = samp[:, i] return np.squeeze(samp) def _rescale(self, samp, **kwargs): @@ -790,19 +810,23 @@ def rescale(self, val, **kwargs): all kwargs passed to the dist.rescale method Returns ======= - float: - A sample from the prior parameter. + np.ndarray: + The samples from the prior parameter. If not all names in "dist" have been filled, + the array contains only np.nan. *This* specific array instance will be filled with + the rescaled value once all parameters have been requested """ - self.dist.rescale_parameters[self.name] = val + self.dist.set_rescale(self.name, val) if self.dist.filled_rescale(): - values = np.array(list(self.dist.rescale_parameters.values())).T - samples = self.dist.rescale(values, **kwargs) + self.dist.rescale(values=None, **kwargs) + output = self.dist.get_rescaled(self.name) self.dist.reset_rescale() - return samples else: - return [] # return empty list + output = self.dist.get_rescaled(self.name) + + # have to return raw output to conserve in-place modifications + return output def sample(self, size=1, **kwargs): """ From 8270402a4fde7639c9a3ce1dba28a30b20a3c022 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 09:55:06 +0100 Subject: [PATCH 02/22] Small fix to rescale of JointPrior --- bilby/core/prior/joint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index b088b15ba..7e6655ee3 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -819,7 +819,7 @@ def rescale(self, val, **kwargs): self.dist.set_rescale(self.name, val) if self.dist.filled_rescale(): - self.dist.rescale(values=None, **kwargs) + self.dist.rescale(value=None, **kwargs) output = self.dist.get_rescaled(self.name) self.dist.reset_rescale() else: From 27f4ef6b0de3b8a50cc897eddf39fd2f89e87a39 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Mon, 25 Nov 2024 16:50:02 +0100 Subject: [PATCH 03/22] For jointprior rescale, only cast to list once its save to loose mutability --- bilby/core/prior/dict.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 2908bd083..0490d194e 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -613,8 +613,10 @@ def rescale(self, keys, theta): samples = [] for key, units in zip(keys, theta): samps = self[key].rescale(units) + samples.append(samps) + for i, samps in enumerate(samples): # turns 0d-arrays into scalars - samples.append(np.squeeze(samps).tolist()) + samples[i] = np.squeeze(samps).tolist() return samples def test_redundancy(self, key, disable_logging=False): From 60d7be10874a72399aaf43c602b43c52c599203f Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Tue, 21 Jan 2025 16:27:34 +0100 Subject: [PATCH 04/22] Improved nomenclature, added comments, new tests that include joint priors --- bilby/core/prior/joint.py | 49 ++++++++++++++++++----------- test/core/prior/conditional_test.py | 8 +++-- test/core/prior/dict_test.py | 35 
++++++++++++++++++--- 3 files changed, 66 insertions(+), 26 deletions(-) diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 7e6655ee3..9d2349343 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -63,9 +63,11 @@ def __init__(self, names, bounds=None): self.requested_parameters = dict() self.reset_request() - # a dictionary of the rescale(d) parameters - self._rescale_parameters = dict() - self._rescaled_parameters = dict() + # a dictionary that stores the unit-cube values of parameters for later rescaling + self._current_unit_cube_parameter_values = dict() + # a dictionary of arrays that are used as intermediate return values of JointPrior.rescale() + # and updated in-place once all parameters have been requested + self._current_rescaled_parameter_values = dict() self.reset_rescale() # a list of sampled parameters @@ -95,24 +97,24 @@ def filled_rescale(self): Check if all the rescaled parameters have been filled. """ - return not np.any([val is None for val in self._rescale_parameters.values()]) + return not np.any([val is None for val in self._current_unit_cube_parameter_values.values()]) def set_rescale(self, key, values): - values = np.array(values) - self._rescale_parameters[key] = values - self._rescaled_parameters[key] = np.atleast_1d(np.ones_like(values)) * np.nan + self._current_unit_cube_parameter_values[key] = np.array(values) + self._current_rescaled_parameter_values[key] = np.full_like(values, np.nan, dtype=float) def reset_rescale(self): """ Reset the rescaled parameters to None. """ - for name in self.names: - self._rescale_parameters[name] = None - self._rescaled_parameters[name] = None + self._current_unit_cube_parameter_values[name] = None + self._current_rescaled_parameter_values[name] = None def get_rescaled(self, key): - return self._rescaled_parameters[key] + """Return an array that will be updated in-place once the rescale-operation + has been performed.""" + return self._current_rescaled_parameter_values[key] def get_instantiation_dict(self): subclass_args = infer_args_from_method(self.__init__) @@ -317,7 +319,7 @@ def rescale(self, value, **kwargs): If given, a 1d vector sample (one for each parameter) drawn from a uniform distribution between 0 and 1, or a 2d NxM array of samples where N is the number of samples and M is the number of parameters. - If None, values previously set using BaseJointPriorDist.set_rescale() are used. + If None, the values previously set using BaseJointPriorDist.set_rescale() are used. kwargs: dict All keyword args that need to be passed to _rescale method, these keyword args are called in the JointPrior rescale methods for each parameter @@ -329,9 +331,11 @@ def rescale(self, value, **kwargs): distribution. 
""" if value is None: - samp = np.array(list(self._rescale_parameters.values())).T + samp = np.array(list(self._current_unit_cube_parameter_values.values())).T else: - samp = np.array(value) + for key, val in zip(self.names, value): + self.set_rescale(key, val) + samp = np.asarray(value) if len(samp.shape) == 1: samp = samp.reshape(1, self.num_vars) @@ -342,11 +346,12 @@ def rescale(self, value, **kwargs): raise ValueError("Array is the wrong shape") samp = self._rescale(samp, **kwargs) - if value is None: - for i, key in enumerate(self.names): - output = self.get_rescaled(key) - # update in-place for proper handling in PriorDict-instances - output[:] = samp[:, i] + for i, key in enumerate(self.names): + # get the numpy array used for indermediate outputs + # prior to a full rescale-operation + output = self.get_rescaled(key) + # update the array in-place + output[...] = samp[:, i] return np.squeeze(samp) def _rescale(self, samp, **kwargs): @@ -819,10 +824,16 @@ def rescale(self, val, **kwargs): self.dist.set_rescale(self.name, val) if self.dist.filled_rescale(): + # If all names have been filled, perform rescale operation self.dist.rescale(value=None, **kwargs) + # get the rescaled values for the requested parameter output = self.dist.get_rescaled(self.name) + # reset the rescale operation self.dist.reset_rescale() else: + # If not all names have been filled, return a *numpy array* + # filled only with `np.nan`. Once all names have been requested, + # this array is updated *in-place* with the rescaled values. output = self.dist.get_rescaled(self.name) # have to return raw output to conserve in-place modifications diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index 20c0cda93..abd83430f 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -334,7 +334,7 @@ def test_rescale_with_joint_prior(self): # set multivariate Gaussian distribution names = ["mvgvar_0", "mvgvar_1"] - mu = [[0.79, -0.83]] + mu = [[1, 1]] cov = [[[0.03, 0.], [0., 0.04]]] mvg = bilby.core.prior.MultivariateGaussianDist(names, mus=mu, covs=cov) @@ -349,7 +349,7 @@ def test_rescale_with_joint_prior(self): ) ) - ref_variables = list(self.test_sample.values()) + [0.4, 0.1] + ref_variables = list(self.test_sample.values()) + [0.5, 0.5] keys = list(self.test_sample.keys()) + names res = priordict.rescale(keys=keys, theta=ref_variables) @@ -359,9 +359,11 @@ def test_rescale_with_joint_prior(self): # check conditional values are still as expected expected = [self.test_sample["var_0"]] + self.assertFalse(np.any(np.isnan(res))) for ii in range(1, 4): expected.append(expected[-1] * self.test_sample[f"var_{ii}"]) - self.assertListEqual(expected, res[0:4]) + expected.extend([1, 1]) + self.assertListEqual(expected, res) def test_cdf(self): """ diff --git a/test/core/prior/dict_test.py b/test/core/prior/dict_test.py index 08e730bbe..a3c8e1cd6 100644 --- a/test/core/prior/dict_test.py +++ b/test/core/prior/dict_test.py @@ -33,8 +33,17 @@ def setUp(self): name="b", alpha=3, minimum=1, maximum=2, unit="m/s", boundary=None ) self.third_prior = bilby.core.prior.DeltaFunction(name="c", peak=42, unit="m") + + mvg = bilby.core.prior.MultivariateGaussianDist( + names=["testa", "testb"], + mus=[1, 1], + covs=np.array([[2.0, 0.5], [0.5, 2.0]]), + weights=1.0, + ) + self.testa = bilby.core.prior.MultivariateGaussian(dist=mvg, name="testa", unit="unit") + self.testb = bilby.core.prior.MultivariateGaussian(dist=mvg, name="testb", unit="unit") self.priors = dict( - 
mass=self.first_prior, speed=self.second_prior, length=self.third_prior + mass=self.first_prior, speed=self.second_prior, length=self.third_prior, testa=self.testa, testb=self.testb ) self.prior_set_from_dict = bilby.core.prior.PriorDict(dictionary=self.priors) self.default_prior_file = os.path.join( @@ -70,7 +79,7 @@ def test_prior_set_is_dict(self): self.assertIsInstance(self.prior_set_from_dict, dict) def test_prior_set_has_correct_length(self): - self.assertEqual(3, len(self.prior_set_from_dict)) + self.assertEqual(5, len(self.prior_set_from_dict)) def test_prior_set_has_expected_priors(self): self.assertDictEqual(self.priors, dict(self.prior_set_from_dict)) @@ -160,6 +169,12 @@ def test_to_file(self): "unit='m/s', boundary=None)\n", "mass = Uniform(minimum=0, maximum=1, name='a', latex_label='a', " "unit='kg', boundary=None)\n", + "testa_testb_mvg = MultivariateGaussianDist(names=['testa', 'testb'], nmodes=1, mus=[[1, 1]], " + "sigmas=[[1.4142135623730951, 1.4142135623730951]], " + "corrcoefs=[[[0.9999999999999998, 0.24999999999999994], [0.24999999999999994, 0.9999999999999998]]], " + "covs=[[[2.0, 0.5], [0.5, 2.0]]], weights=[1.0], bounds={'testa': (-inf, inf), 'testb': (-inf, inf)})\n", + "testa = MultivariateGaussian(dist=testa_testb_mvg, name='testa', latex_label='testa', unit='unit')\n", + "testb = MultivariateGaussian(dist=testa_testb_mvg, name='testb', latex_label='testb', unit='unit')\n", ] self.prior_set_from_dict.to_file(outdir="prior_files", label="to_file_test") with open("prior_files/to_file_test.prior") as f: @@ -178,6 +193,13 @@ def test_from_dict_with_string(self): self.assertDictEqual(self.prior_set_from_dict, from_dict) def test_convert_floats_to_delta_functions(self): + mvg = bilby.core.prior.MultivariateGaussianDist( + names=["testa", "testb"], + mus=[1, 1], + covs=np.array([[2.0, 0.5], [0.5, 2.0]]), + weights=1.0, + ) + self.prior_set_from_dict["d"] = 5 self.prior_set_from_dict["e"] = 7.3 self.prior_set_from_dict["f"] = "unconvertable" @@ -190,6 +212,8 @@ def test_convert_floats_to_delta_functions(self): name="b", alpha=3, minimum=1, maximum=2, unit="m/s", boundary=None ), length=bilby.core.prior.DeltaFunction(name="c", peak=42, unit="m"), + testa=bilby.core.prior.MultivariateGaussian(dist=mvg, name="testa", unit="unit"), + testb=bilby.core.prior.MultivariateGaussian(dist=mvg, name="testb", unit="unit"), d=bilby.core.prior.DeltaFunction(peak=5), e=bilby.core.prior.DeltaFunction(peak=7.3), f="unconvertable", @@ -321,12 +345,15 @@ def test_ln_prob(self): self.assertEqual(expected, self.prior_set_from_dict.ln_prob(samples)) def test_rescale(self): - theta = [0.5, 0.5, 0.5] + theta = [0.5, 0.5, 0.5, 0.5, 0.5] expected = [ self.first_prior.rescale(0.5), self.second_prior.rescale(0.5), self.third_prior.rescale(0.5), + self.testa.rescale(0.5), + self.testb.rescale(0.5) ] + assert not np.any(np.isnan(expected)) self.assertListEqual( sorted(expected), sorted( @@ -342,7 +369,7 @@ def test_cdf(self): Note that the format of inputs/outputs is different between the two methods. 
""" - sample = self.prior_set_from_dict.sample() + sample = self.prior_set_from_dict.sample_subset(keys=["length", "speed", "mass"]) original = np.array(list(sample.values())) new = np.array(self.prior_set_from_dict.rescale( sample.keys(), From 6c1bd66849ce8bceab25e6c38c0334b690b2eb22 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Tue, 21 Jan 2025 17:32:43 +0100 Subject: [PATCH 05/22] Added Warning for previously set rescale_parameters --- bilby/core/prior/joint.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 9d2349343..cd2d0c622 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -332,18 +332,18 @@ def rescale(self, value, **kwargs): """ if value is None: samp = np.array(list(self._current_unit_cube_parameter_values.values())).T + if len(samp.shape) == 1: + samp = samp.reshape(1, self.num_vars) else: - for key, val in zip(self.names, value): - self.set_rescale(key, val) samp = np.asarray(value) - - if len(samp.shape) == 1: - samp = samp.reshape(1, self.num_vars) - - if len(samp.shape) != 2: - raise ValueError("Array is the wrong shape") - elif samp.shape[1] != self.num_vars: - raise ValueError("Array is the wrong shape") + if len(samp.shape) == 1: + samp = samp.reshape(1, self.num_vars) + if len(samp.shape) != 2: + raise ValueError("Array is the wrong shape") + elif samp.shape[1] != self.num_vars: + raise ValueError("Array is the wrong shape") + for key, val in zip(self.names, samp.T): + self.set_rescale(key, val) samp = self._rescale(samp, **kwargs) for i, key in enumerate(self.names): @@ -821,6 +821,18 @@ def rescale(self, val, **kwargs): the rescaled value once all parameters have been requested """ + if self.dist.get_rescaled(self.name) is not None: + import warnings + warnings.warn( + f"Rescale values for {self.name} in {self.dist} have already been set, " + "indicating that another rescale-operation is in progress.\n" + "Call dist.reset_rescale() on the joint prior distribution associated " + "with this prior after using dist.rescale() and make sure all parameters " + "necessary for rescaling have been requested in PriorDict.rescale().\n" + "Resetting now.", + RuntimeWarning + ) + self.dist.reset_rescale() self.dist.set_rescale(self.name, val) if self.dist.filled_rescale(): From 6148a2c20467b2330e49208b8873c286368f7374 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 09:57:00 +0100 Subject: [PATCH 06/22] Fix to BaseJointPriorDist bound check --- bilby/core/prior/joint.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index cd2d0c622..f8aff1378 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -221,11 +221,9 @@ def _check_samp(self, value): raise ValueError("Array is the wrong shape") # check sample(s) is within bounds - outbounds = np.ones(samp.shape[0], dtype=bool) + outbounds = np.zeros(samp.shape[0], dtype=bool) for s, bound in zip(samp.T, self.bounds.values()): - outbounds = (s < bound[0]) | (s > bound[1]) - if np.any(outbounds): - break + outbounds += (s < bound[0]) | (s > bound[1]) return samp, outbounds def ln_prob(self, value): From 9320b3666bed4b6e2a030569bf266d909bf7084c Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 09:58:11 +0100 Subject: [PATCH 07/22] Allow setting of "dist" attributes through JointPrior --- bilby/core/prior/joint.py | 13 ++++++++++++- 1 file 
changed, 12 insertions(+), 1 deletion(-) diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index f8aff1378..dd4df6bcc 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -750,7 +750,7 @@ class MultivariateNormalDist(MultivariateGaussianDist): class JointPrior(Prior): - def __init__(self, dist, name=None, latex_label=None, unit=None): + def __init__(self, dist, name=None, latex_label=None, unit=None, **kwargs): """This defines the single parameter Prior object for parameters that belong to a JointPriorDist Parameters @@ -801,6 +801,17 @@ def maximum(self, maximum): self._maximum = maximum self.dist.bounds[self.name] = (self.dist.bounds[self.name][0], maximum) + def __setattr__(self, name, value): + # first check that the JointPrior has an explicit setter method for the attribute, which should take precedence + if hasattr(self.__class__, name) and getattr(self.__class__, name).fset is not None: + return super().__setattr__(name, value) + # then check if the BaseJointPriorDist-!subclass! has an explicit setter method for the attribute + elif hasattr(self, "dist") and hasattr(self.dist, name) and getattr(self.dist.__class__, name).fset is not None: + return self.dist.__setattr__(name, value) + # if not, use the default setattr + else: + return super().__setattr__(name, value) + def rescale(self, val, **kwargs): """ Scale a unit hypercube sample to the prior. From ee17ee0397a64f8ce457da8fc2a63650dacd185b Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 10:02:38 +0100 Subject: [PATCH 08/22] Make ConditionalPrior ready for JointPrior --- bilby/core/prior/conditional.py | 34 +++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py index 797cbd1c4..1a6ca7359 100644 --- a/bilby/core/prior/conditional.py +++ b/bilby/core/prior/conditional.py @@ -3,13 +3,14 @@ from .analytical import DeltaFunction, PowerLaw, Uniform, LogUniform, \ SymmetricLogUniform, Cosine, Sine, Gaussian, TruncatedGaussian, HalfGaussian, \ LogNormal, Exponential, StudentT, Beta, Logistic, Cauchy, Gamma, ChiSquared, FermiDirac -from ..utils import infer_args_from_method, infer_parameters_from_function +from .joint import JointPrior +from ..utils import infer_args_from_method, infer_parameters_from_function, get_dict_with_properties def conditional_prior_factory(prior_class): class ConditionalPrior(prior_class): - def __init__(self, condition_func, name=None, latex_label=None, unit=None, - boundary=None, **reference_params): + def __init__(self, condition_func, name=None, latex_label=None, unit=None, boundary=None, dist=None, + **reference_params): """ Parameters ========== condition_func: func Functional form of the condition for this prior. The first function argument has to be a dictionary for the `reference_params` (see below). The following arguments are the required variables that are required before this prior can be evaluated. It needs to return a dictionary with the modified values for the `reference_params` that are being used in the prior class. For example if we have a Uniform prior class and this condition function:
""" - if 'boundary' in infer_args_from_method(super(ConditionalPrior, self).__init__): - super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label, - unit=unit, boundary=boundary, **reference_params) - else: - super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label, - unit=unit, **reference_params) + kwargs = {"name": name, "latex_label": latex_label, "unit": unit, "boundary": boundary, "dist": dist} + needed_kwargs = infer_args_from_method(super(ConditionalPrior, self).__init__) + for kw in kwargs.copy(): + if kw not in needed_kwargs: + kwargs.pop(kw) + + super(ConditionalPrior, self).__init__(**kwargs, **reference_params) self._required_variables = None self.condition_func = condition_func self._reference_params = reference_params - self.__class__.__name__ = 'Conditional{}'.format(prior_class.__name__) - self.__class__.__qualname__ = 'Conditional{}'.format(prior_class.__qualname__) + self.__class__.__name__ = "Conditional{}".format(prior_class.__name__) + self.__class__.__qualname__ = "Conditional{}".format(prior_class.__qualname__) def sample(self, size=None, **required_variables): """Draw a sample from the prior @@ -202,7 +206,9 @@ def required_variables(self): return self._required_variables def get_instantiation_dict(self): - instantiation_dict = super(ConditionalPrior, self).get_instantiation_dict() + superclass_args = infer_args_from_method(super(ConditionalPrior, self).__init__) + dict_with_properties = get_dict_with_properties(self) + instantiation_dict = {key: dict_with_properties[key] for key in superclass_args} for key, value in self.reference_params.items(): instantiation_dict[key] = value return instantiation_dict @@ -322,6 +328,10 @@ class ConditionalInterped(conditional_prior_factory(Interped)): pass +class ConditionalJointPrior(conditional_prior_factory(JointPrior)): + pass + + class DirichletElement(ConditionalBeta): r""" Single element in a dirichlet distribution From d16c3d514be9c5c41100c364b690a695a5dfd80c Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 10:04:28 +0100 Subject: [PATCH 09/22] Make "mode" of MultivariateGaussianDist a setable property to use with ConditionalPriors --- bilby/core/prior/joint.py | 118 ++++++++++++++++++++++---------------- 1 file changed, 70 insertions(+), 48 deletions(-) diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index dd4df6bcc..7e655ec4a 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -635,76 +635,83 @@ def add_mode(self, mus=None, sigmas=None, corrcoef=None, cov=None, weight=1.0): ) def _rescale(self, samp, **kwargs): - try: - mode = kwargs["mode"] - except KeyError: - mode = None + mode = kwargs.get("mode", self.mode) if mode is None: if self.nmodes == 1: mode = 0 else: - mode = np.argwhere(self.cumweights - random.rng.uniform(0, 1) > 0)[0][0] + mode = random.rng.choice( + self.nmodes, size=len(samp), p=self.weights + ) samp = erfinv(2.0 * samp - 1) * 2.0 ** 0.5 # rotate and scale to the multivariate normal shape - samp = self.mus[mode] + self.sigmas[mode] * np.einsum( - "ij,kj->ik", samp * self.sqeigvalues[mode], self.eigvectors[mode] - ) + uniques = np.unique(mode) + if len(uniques) == 1: + unique = uniques[0] + samp = self.mus[unique] + self.sigmas[unique] * np.einsum( + "ij,kj->ik", samp * self.sqeigvalues[unique], self.eigvectors[unique] + ) + else: + for m in uniques: + mask = m == mode + samp[mask] = self.mus[m] + self.sigmas[m] * np.einsum( + "ij,kj->ik", samp[mask] * self.sqeigvalues[m], self.eigvectors[m] + ) + 
return samp def _sample(self, size, **kwargs): - try: - mode = kwargs["mode"] - except KeyError: - mode = None + mode = kwargs.get("mode", self.mode) + + samps = np.zeros((size, len(self))) + outbound = np.ones(size, dtype=bool) if mode is None: if self.nmodes == 1: mode = 0 - else: - if size == 1: - mode = np.argwhere(self.cumweights - random.rng.uniform(0, 1) > 0)[0][0] - else: - # pick modes - mode = [ - np.argwhere(self.cumweights - r > 0)[0][0] - for r in random.rng.uniform(0, 1, size) - ] + while np.any(outbound): + # sample the multivariate Gaussian keys + vals = random.rng.uniform(0, 1, (np.sum(outbound), len(self))) - samps = np.zeros((size, len(self))) - for i in range(size): - inbound = False - while not inbound: - # sample the multivariate Gaussian keys - vals = random.rng.uniform(0, 1, len(self)) - - if isinstance(mode, list): - samp = np.atleast_1d(self.rescale(vals, mode=mode[i])) - else: - samp = np.atleast_1d(self.rescale(vals, mode=mode)) - samps[i, :] = samp + if mode is None: + mode = random.rng.choice( + self.nmodes, size=np.sum(outbound), p=self.weights + ) - # check sample is in bounds (otherwise perform another draw) - outbound = False - for name, val in zip(self.names, samp): - if val < self.bounds[name][0] or val > self.bounds[name][1]: - outbound = True - break + samps[outbound] = np.atleast_1d(self.rescale(vals, mode=mode)) - if not outbound: - inbound = True + # check sample is in bounds and redraw those which are not + samps, outbound = self._check_samp(samps) return samps - def _ln_prob(self, samp, lnprob, outbounds): - for j in range(samp.shape[0]): - # loop over the modes and sum the probabilities - for i in range(self.nmodes): - # self.mvn[i] is a "standard" multivariate normal distribution; see add_mode() - z = (samp[j] - self.mus[i]) / self.sigmas[i] - lnprob[j] = np.logaddexp(lnprob[j], self.mvn[i].logpdf(z) - self.logprodsigmas[i]) + def _ln_prob(self, samp, lnprob, outbounds, **kwargs): + mode = kwargs.get("mode", self.mode) + + if mode is None: + for j in range(samp.shape[0]): + # loop over the modes and sum the probabilities + for i in range(self.nmodes): + # self.mvn[i] is a "standard" multivariate normal distribution; see add_mode() + z = (samp[j] - self.mus[i]) / self.sigmas[i] + lnprob[j] = np.logaddexp(lnprob[j], + self.mvn[i].logpdf(z) - self.logprodsigmas[i] + np.log(self.weights[i])) + else: + uniques = np.unique(np.asarray(mode, dtype=int)) + if len(uniques) == 1: + unique = uniques[0] + z = (samp[j] - self.mus[unique]) / self.sigmas[unique] + # don't multiply by the mode weight if the mode is given (ie. prob(mode|mode) = 1) + lnprob[j] = np.logaddexp(lnprob[j], self.mvn[unique].logpdf(z) - self.logprodsigmas[unique]) + else: + for m in uniques: + mask = mode == m + z = (samp[mask] - self.mus[m]) / self.sigmas[m] + # don't multiply by the mode weight if the mode is given (ie. 
prob(mode|mode) = 1) + lnprob[mask] = np.logaddexp(lnprob[mask], self.mvn[m].logpdf(z) - self.logprodsigmas[m]) # set out-of-bounds values to -inf lnprob[outbounds] = -np.inf @@ -744,6 +751,21 @@ def __eq__(self, other): return False return True + @property + def mode(self): + if hasattr(self, "_mode"): + return self._mode + else: + return None + + @mode.setter + def mode(self, mode): + if not np.isdtype(np.asarray(mode).dtype, "integral"): + raise ValueError("The mode to set must have integral data type.") + if np.any(mode >= self.nmodes) or np.any(mode < 0): + raise ValueError("The value of mode cannot be higher than the number of modes or smaller than zero.") + self._mode = mode + class MultivariateNormalDist(MultivariateGaussianDist): """A synonym for the :class:`~bilby.core.prior.MultivariateGaussianDist` distribution.""" From aa867324c3671236cb7f0557ae2c4aba6a23140f Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 10:07:15 +0100 Subject: [PATCH 10/22] Added TestCase for mode-setting of MultivariateGaussian with ConditionalPrior --- bilby/core/prior/joint.py | 2 +- test/core/prior/conditional_test.py | 37 +++++++++++++++++++++++------ 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 7e655ec4a..eb55c634c 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -760,7 +760,7 @@ def mode(self): @mode.setter def mode(self, mode): - if not np.isdtype(np.asarray(mode).dtype, "integral"): + if not np.issubdtype(np.asarray(mode).dtype, np.integer): raise ValueError("The mode to set must have integral data type.") if np.any(mode >= self.nmodes) or np.any(mode < 0): raise ValueError("The value of mode cannot be higher than the number of modes or smaller than zero.") diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index abd83430f..d14684fb7 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -338,24 +338,47 @@ def test_rescale_with_joint_prior(self): cov = [[[0.03, 0.], [0., 0.04]]] mvg = bilby.core.prior.MultivariateGaussianDist(names, mus=mu, covs=cov) + names_2 = ["mvgvar_a", "mvgvar_b"] + mvg_dual_mode = bilby.core.prior.MultivariateGaussianDist( + names=names_2, + nmodes=2, + mus=[mu[0], (np.array(mu[0]) + np.ones_like(mu[0])).tolist()], + covs=[cov[0], cov[0]], + weights=[1, 2] + ) + + def condition_func_2(reference_params, var_0): + return dict(mode=np.searchsorted(np.cumsum(np.array([1, 2]) / 3), var_0)) + + def condition_func_1(reference_params, var_0, var_1): + return {"minimum": var_0, "maximum": var_1} + priordict = bilby.core.prior.ConditionalPriorDict( dict( var_3=self.prior_3, var_2=self.prior_2, var_0=self.prior_0, var_1=self.prior_1, - mvgvar_0=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_0"), - mvgvar_1=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_1"), + mvgvar_0=bilby.core.prior.ConditionalJointPrior( + condition_func_1, dist=mvg, name="mvgvar_0", minimum=self.minimum, maximum=self.maximum), + mvgvar_1=bilby.core.prior.ConditionalJointPrior( + condition_func_1, dist=mvg, name="mvgvar_1", minimum=self.minimum, maximum=self.maximum), + mvgvar_a=bilby.core.prior.ConditionalJointPrior( + condition_func_2, dist=mvg_dual_mode, name="mvgvar_a", + minimum=self.minimum, maximum=self.maximum, mode=None), + mvgvar_b=bilby.core.prior.ConditionalJointPrior( + condition_func_2, dist=mvg_dual_mode, name="mvgvar_b", + minimum=self.minimum, maximum=self.maximum, mode=None), ) ) - 
ref_variables = list(self.test_sample.values()) + [0.5, 0.5] - keys = list(self.test_sample.keys()) + names + ref_variables = list(self.test_sample.values()) + [0.5, 0.5] + [0.5, 0.2] + keys = list(self.test_sample.keys()) + names + names_2 res = priordict.rescale(keys=keys, theta=ref_variables) self.assertIsInstance(res, list) - self.assertEqual(np.shape(res), (6,)) - self.assertListEqual([isinstance(r, float) for r in res], 6 * [True]) + self.assertEqual(np.shape(res), (8,)) + self.assertListEqual([isinstance(r, float) for r in res], 8 * [True]) # check conditional values are still as expected expected = [self.test_sample["var_0"]] @@ -363,7 +386,7 @@ def test_rescale_with_joint_prior(self): for ii in range(1, 4): expected.append(expected[-1] * self.test_sample[f"var_{ii}"]) expected.extend([1, 1]) - self.assertListEqual(expected, res) + self.assertListEqual(expected, res[:-2]) def test_cdf(self): """ From 6eb76fac1c57993dd38742f07897874268ec9532 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 10:54:55 +0100 Subject: [PATCH 11/22] Avoid recreation of ConditionalPriorDict if not necessary --- bilby/core/prior/dict.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 0490d194e..ed91e54fc 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -727,7 +727,10 @@ def sample_subset(self, keys=iter([]), size=None): if key not in keys and isinstance(self[key], DeltaFunction) ] use_keys = add_delta_keys + list(keys) - subset_dict = ConditionalPriorDict({key: self[key] for key in use_keys}) + if set(use_keys) == set(self.keys()): + subset_dict = self + else: + subset_dict = ConditionalPriorDict({key: self[key] for key in use_keys}) if not subset_dict._resolved: raise IllegalConditionsException( "The current set of priors contains unresolvable conditions." From 82c38d04259f0cea6c76a2ef3552311da919c87b Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 11:12:03 +0100 Subject: [PATCH 12/22] Ensure rescaling step works for the chosen set of keys and that the conditional properties of the priors can be set to arrays or loop over rescale values if not --- bilby/core/prior/dict.py | 28 +++++++++++++++++++++++----- test/core/prior/conditional_test.py | 5 +++-- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index ed91e54fc..282573037 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -851,12 +851,30 @@ def rescale(self, keys, theta): """ keys = list(keys) theta = [theta[key] for key in keys] if isinstance(theta, dict) else list(theta) - self._check_resolved() - self._update_rescale_keys(keys) + if set(keys) == set(self.non_fixed_keys): + subset_dict = self + else: + subset_dict = ConditionalPriorDict({key: self[key] for key in keys}) + if not subset_dict._resolved: + raise IllegalConditionsException( + "The current set of priors contains unresolvable conditions." 
+ ) + subset_dict._update_rescale_keys(keys) result = dict() - for key, index in zip(self.sorted_keys_without_fixed_parameters, self._rescale_indexes): - result[key] = self[key].rescale(theta[index], **self.get_required_variables(key)) - self[key].least_recently_sampled = result[key] + for key, index in zip(subset_dict.sorted_keys_without_fixed_parameters, subset_dict._rescale_indexes): + try: + result[key] = subset_dict[key].rescale(theta[index], **subset_dict.get_required_variables(key)) + except ValueError: + # Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw) + # If that is the case, we sample each sample individually. + required_variables = subset_dict.get_required_variables(key) + result[key] = np.zeros_like(theta[key]) + for i in range(len(theta[key])): + rvars = { + key: value[i] for key, value in required_variables.items() + } + result[key][i] = subset_dict[key].rescale(theta[index][i], **rvars) + subset_dict[key].least_recently_sampled = result[key] samples = [] for key in keys: # turns 0d-arrays into scalars diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index d14684fb7..052f756cb 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -403,10 +403,11 @@ def test_cdf(self): ) def test_rescale_illegal_conditions(self): - del self.conditional_priors["var_0"] + test_sample = self.test_sample.copy() + test_sample.pop("var_0") with self.assertRaises(bilby.core.prior.IllegalConditionsException): self.conditional_priors.rescale( - keys=list(self.test_sample.keys()), + keys=list(test_sample.keys()), theta=list(self.test_sample.values()), ) From 5430f266f1c0e491de79145c9750f748c36db94b Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 11:23:40 +0100 Subject: [PATCH 13/22] Updated test case for conditional MultivariateGaussian to be more comprehensive --- test/core/prior/conditional_test.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index 052f756cb..57901cd65 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -351,7 +351,7 @@ def condition_func_2(reference_params, var_0): return dict(mode=np.searchsorted(np.cumsum(np.array([1, 2]) / 3), var_0)) def condition_func_1(reference_params, var_0, var_1): - return {"minimum": var_0, "maximum": var_1} + return {"minimum": var_0 - 1, "maximum": var_1 + 1} priordict = bilby.core.prior.ConditionalPriorDict( dict( @@ -387,6 +387,12 @@ def condition_func_1(reference_params, var_0, var_1): expected.append(expected[-1] * self.test_sample[f"var_{ii}"]) expected.extend([1, 1]) self.assertListEqual(expected, res[:-2]) + res_sample = priordict.sample(1) + self.assertEqual(list(res_sample.keys()), priordict.sorted_keys_without_fixed_parameters) + res_sample = priordict.sample(10) + self.assertListEqual([len(val) for val in res_sample.values()], [10] * len(res_sample.keys())) + lnprobs = priordict.ln_prob(priordict.sample(10), axis=0) + self.assertEqual(len(lnprobs), 10) def test_cdf(self): """ From c155da88198524283132c3b1216bf516d2fd46e5 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 16:46:06 +0100 Subject: [PATCH 14/22] Added and updated Test cases --- test/core/prior/conditional_test.py | 52 ++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 9 deletions(-) diff --git a/test/core/prior/conditional_test.py 
b/test/core/prior/conditional_test.py index 57901cd65..f80711b3a 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -353,8 +353,19 @@ def condition_func_2(reference_params, var_0): def condition_func_1(reference_params, var_0, var_1): return {"minimum": var_0 - 1, "maximum": var_1 + 1} + def condition_func_5(reference_parameters, mvgvar_a): + return dict(minimum=reference_parameters["minimum"], maximum=mvgvar_a) + + prior_5 = bilby.core.prior.ConditionalUniform( + condition_func=condition_func_5, minimum=self.minimum, maximum=self.maximum + ) + priordict = bilby.core.prior.ConditionalPriorDict( dict( + var_5=prior_5, + mvgvar_a=bilby.core.prior.ConditionalJointPrior( + condition_func_2, dist=mvg_dual_mode, name="mvgvar_a", + minimum=self.minimum, maximum=self.maximum, mode=None), var_3=self.prior_3, var_2=self.prior_2, var_0=self.prior_0, @@ -363,22 +374,20 @@ def condition_func_1(reference_params, var_0, var_1): condition_func_1, dist=mvg, name="mvgvar_0", minimum=self.minimum, maximum=self.maximum), mvgvar_1=bilby.core.prior.ConditionalJointPrior( condition_func_1, dist=mvg, name="mvgvar_1", minimum=self.minimum, maximum=self.maximum), - mvgvar_a=bilby.core.prior.ConditionalJointPrior( - condition_func_2, dist=mvg_dual_mode, name="mvgvar_a", - minimum=self.minimum, maximum=self.maximum, mode=None), mvgvar_b=bilby.core.prior.ConditionalJointPrior( condition_func_2, dist=mvg_dual_mode, name="mvgvar_b", minimum=self.minimum, maximum=self.maximum, mode=None), ) ) - ref_variables = list(self.test_sample.values()) + [0.5, 0.5] + [0.5, 0.2] - keys = list(self.test_sample.keys()) + names + names_2 + ref_variables = self.test_sample.copy() + ref_variables.update({"mvgvar_0": 0.5, "mvgvar_1": 0.5, "mvgvar_a": 0.5, "mvgvar_b": 0.2, "var_5": 0.5}) + keys = list(self.test_sample.keys()) + names + names_2 + ["var_5"] res = priordict.rescale(keys=keys, theta=ref_variables) self.assertIsInstance(res, list) - self.assertEqual(np.shape(res), (8,)) - self.assertListEqual([isinstance(r, float) for r in res], 8 * [True]) + self.assertEqual(np.shape(res), (9,)) + self.assertListEqual([isinstance(r, float) for r in res], 9 * [True]) # check conditional values are still as expected expected = [self.test_sample["var_0"]] @@ -389,11 +398,36 @@ def condition_func_1(reference_params, var_0, var_1): self.assertListEqual(expected, res[:-2]) res_sample = priordict.sample(1) self.assertEqual(list(res_sample.keys()), priordict.sorted_keys_without_fixed_parameters) - res_sample = priordict.sample(10) - self.assertListEqual([len(val) for val in res_sample.values()], [10] * len(res_sample.keys())) + res_sample = priordict.sample(1000) + self.assertListEqual([len(val) for val in res_sample.values()], [1000] * len(res_sample.keys())) lnprobs = priordict.ln_prob(priordict.sample(10), axis=0) self.assertEqual(len(lnprobs), 10) + with self.assertRaises(bilby.core.prior.IllegalConditionsException): + keys = set(priordict.keys()) - set(["mvgvar_a"]) + priordict.rescale(keys=keys, theta=ref_variables) + + def condition_func_6(reference_params, var_5): + return dict(mode=np.searchsorted(np.cumsum(np.array([1, 2]) / 3), var_5)) + + priordict_unresolveable = bilby.core.prior.ConditionalPriorDict( + dict( + var_5=prior_5, + var_3=self.prior_3, + var_2=self.prior_2, + var_0=self.prior_0, + var_1=self.prior_1, + mvgvar_a=bilby.core.prior.ConditionalJointPrior( + condition_func_6, dist=mvg_dual_mode, name="mvgvar_a", + minimum=self.minimum, maximum=self.maximum, mode=None), + 
mvgvar_b=bilby.core.prior.ConditionalJointPrior( + condition_func_6, dist=mvg_dual_mode, name="mvgvar_b", + minimum=self.minimum, maximum=self.maximum, mode=None), + + ) + ) + self.assertEqual(priordict_unresolveable._resolved, False) + def test_cdf(self): """ Test that the CDF method is the inverse of the rescale method. From 1bb471412bcf13264651be23e24deb435d4b0f90 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 17:12:06 +0100 Subject: [PATCH 15/22] Fixed ConditionalPrior.__repr__ after changes to ConditionalPrior. --- bilby/core/prior/conditional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py index 1a6ca7359..fcad7253d 100644 --- a/bilby/core/prior/conditional.py +++ b/bilby/core/prior/conditional.py @@ -234,8 +234,8 @@ def __repr__(self): prior_name = self.__class__.__name__ instantiation_dict = self.get_instantiation_dict() instantiation_dict["condition_func"] = ".".join([ - instantiation_dict["condition_func"].__module__, - instantiation_dict["condition_func"].__name__ + self.condition_func.__module__, + self.condition_func.__name__ ]) args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key])) for key in instantiation_dict]) From 27e07915258cf3507f98cee485215c607beb9d66 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 17:17:15 +0100 Subject: [PATCH 16/22] Improve ConditionalPriorDict.rescale and ConditionalPriorDict.sample: Remove necessity to initialize a whole new class instance for lists of dicts that do not span all keys of the ConditionalPriorDict --- bilby/core/prior/dict.py | 85 ++++++++++++++++++---------------------- 1 file changed, 39 insertions(+), 46 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 282573037..17ebc9b4f 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -672,8 +672,6 @@ def __init__(self, dictionary=None, filename=None, conversion_function=None): self._conditional_keys = [] self._unconditional_keys = [] self._rescale_keys = [] - self._rescale_indexes = [] - self._least_recently_rescaled_keys = [] super(ConditionalPriorDict, self).__init__( dictionary=dictionary, filename=filename, @@ -720,40 +718,42 @@ def _check_conditions_resolved(self, key, sampled_keys): return conditions_resolved def sample_subset(self, keys=iter([]), size=None): + keys = list(keys) self.convert_floats_to_delta_functions() - add_delta_keys = [ - key - for key in self.keys() - if key not in keys and isinstance(self[key], DeltaFunction) - ] - use_keys = add_delta_keys + list(keys) - if set(use_keys) == set(self.keys()): - subset_dict = self - else: - subset_dict = ConditionalPriorDict({key: self[key] for key in use_keys}) - if not subset_dict._resolved: - raise IllegalConditionsException( - "The current set of priors contains unresolvable conditions." - ) + add_delta_keys = [] + for key in self.keys(): + if key not in keys and isinstance(self[key], DeltaFunction): + add_delta_keys.append(key) + + use_keys = add_delta_keys + keys + unconditional_use_keys = [key for key in self.unconditional_keys if key in use_keys] + sorted_conditional_use_keys = [key for key in self.conditional_keys if key in use_keys] + + for i, key in enumerate(sorted_conditional_use_keys): + if not self._check_conditions_resolved(key, unconditional_use_keys + sorted_conditional_use_keys[:i]): + raise IllegalConditionsException( + "The current set of priors contains unresolvable conditions." 
+ ) + sorted_use_keys = unconditional_use_keys + sorted_conditional_use_keys samples = dict() - for key in subset_dict.sorted_keys: + for key in sorted_use_keys: if key not in keys or isinstance(self[key], Constraint): continue if isinstance(self[key], Prior): try: - samples[key] = subset_dict[key].sample( - size=size, **subset_dict.get_required_variables(key) + samples[key] = self[key].sample( + size=size, **self.get_required_variables(key) ) except ValueError: # Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw) # If that is the case, we sample each sample individually. - required_variables = subset_dict.get_required_variables(key) + required_variables = self.get_required_variables(key) samples[key] = np.zeros(size) for i in range(size): rvars = { key: value[i] for key, value in required_variables.items() } - samples[key][i] = subset_dict[key].sample(**rvars) + samples[key][i] = self[key].sample(**rvars) else: logger.debug("{} not a known prior.".format(key)) return samples @@ -850,31 +850,32 @@ def rescale(self, keys, theta): If theta is array-like for each key, returns list of lists containing the rescaled samples. """ keys = list(keys) - theta = [theta[key] for key in keys] if isinstance(theta, dict) else list(theta) - if set(keys) == set(self.non_fixed_keys): - subset_dict = self - else: - subset_dict = ConditionalPriorDict({key: self[key] for key in keys}) - if not subset_dict._resolved: - raise IllegalConditionsException( - "The current set of priors contains unresolvable conditions." - ) - subset_dict._update_rescale_keys(keys) + + unconditional_keys = [key for key in self.unconditional_keys if key in keys] + sorted_conditional_keys = [key for key in self.conditional_keys if key in keys] + + for i, key in enumerate(sorted_conditional_keys): + if not self._check_conditions_resolved(key, unconditional_keys + sorted_conditional_keys[:i]): + raise IllegalConditionsException( + "The current set of priors contains unresolvable conditions." + ) + sorted_keys = unconditional_keys + sorted_conditional_keys + theta = [theta[key] for key in sorted_keys] if isinstance(theta, dict) else list(theta) result = dict() - for key, index in zip(subset_dict.sorted_keys_without_fixed_parameters, subset_dict._rescale_indexes): + for key, vals in zip(sorted_keys, theta): try: - result[key] = subset_dict[key].rescale(theta[index], **subset_dict.get_required_variables(key)) + result[key] = self[key].rescale(vals, **self.get_required_variables(key)) except ValueError: # Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw) # If that is the case, we sample each sample individually. 
- required_variables = subset_dict.get_required_variables(key) - result[key] = np.zeros_like(theta[key]) - for i in range(len(theta[key])): + required_variables = self.get_required_variables(key) + result[key] = np.zeros_like(vals) + for i in range(len(vals)): rvars = { key: value[i] for key, value in required_variables.items() } - result[key][i] = subset_dict[key].rescale(theta[index][i], **rvars) - subset_dict[key].least_recently_sampled = result[key] + result[key][i] = self[key].rescale(vals[i], **rvars) + self[key].least_recently_sampled = result[key] samples = [] for key in keys: # turns 0d-arrays into scalars @@ -882,14 +883,6 @@ def rescale(self, keys, theta): samples.append(res) return samples - def _update_rescale_keys(self, keys): - if not keys == self._least_recently_rescaled_keys: - self._rescale_indexes = [ - keys.index(element) - for element in self.sorted_keys_without_fixed_parameters - ] - self._least_recently_rescaled_keys = keys - def _prepare_evaluation(self, keys, theta): self._check_resolved() for key, value in zip(keys, theta): From 3c76f3ffe16d1d7df3914bf18eb01288997effe6 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 17:21:53 +0100 Subject: [PATCH 17/22] Handle JointPriors better in rescale, sample, (ln)prob, and _check_conditions_resolved of (Conditional)PriorDict - keep track of dependencies of JointPriors necessary for their complete evaluation and handle cases where not all necessary keys are requested. --- bilby/core/prior/dict.py | 64 +++++++++++++++++++++++++++++++++++----- 1 file changed, 56 insertions(+), 8 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 17ebc9b4f..b9a4a97ea 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -420,6 +420,9 @@ def sample_subset(self, keys=iter([]), size=None): samples[key] = self[key].sample(size=size) else: logger.debug("{} not a known prior.".format(key)) + # ensure that `reset_sampled()` of all JointPrior.dist + # with missing dependencies is called + self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_sampled") return samples @property @@ -430,6 +433,27 @@ def non_fixed_keys(self): keys = [k for k in keys if k not in self.constraint_keys] return keys + @property + def jointprior_dependencies(self): + keys = self.keys() + keys = [k for k in keys if isinstance(self[k], JointPrior)] + dependencies = {k: list(set(self[k].dist.names) - set([k])) for k in keys} + return dependencies + def _reset_jointprior_dists_with_missed_dependencies(self, keys, reset_func): + keys = set(keys) + dependencies = self.jointprior_dependencies + requested_jointpriors = set(dependencies).intersection() + missing_dependencies = {value for key in requested_jointpriors for value in dependencies[key]} + reset_dists = [] + for key in missing_dependencies: + dist = self[key].dist + if id(dist) in reset_dists: + pass + else: + getattr(dist, reset_func)() + reset_dists.append(id(dist)) + @property def fixed_keys(self): return [ @@ -499,13 +523,16 @@ def _estimate_normalization(self, keys, min_accept, sampling_chunk): factor = len(keep) / np.count_nonzero(keep) return factor - def prob(self, sample, **kwargs): + def prob(self, sample, normalized=True, **kwargs): """ Parameters ========== sample: dict Dictionary of the samples of which we want to have the probability of + normalized: bool + When False, disables calculation of constraint normalization factor + during prior probability computation. Default value is True.
kwargs: The keyword arguments are passed directly to `np.prod` Returns ======= float: Joint probability of all the individual sample probabilities """ prob = np.prod([self[key].prob(sample[key]) for key in sample], **kwargs) - return self.check_prob(sample, prob) + # ensure that `reset_request()` of all JointPrior.dist + # with missing dependencies is called + self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), reset_func="reset_request") + return self.check_prob(sample, prob, normalized) - def check_prob(self, sample, prob): - ratio = self.normalize_constraint_factor(tuple(sample.keys())) + def check_prob(self, sample, prob, normalized=True): + if normalized: + ratio = self.normalize_constraint_factor(tuple(sample.keys())) + else: + ratio = 1 if np.all(prob == 0.0): return prob * ratio else: @@ -534,18 +567,18 @@ def check_prob(self, sample, prob): constrained_prob[keep] = prob[keep] * ratio return constrained_prob - def ln_prob(self, sample, axis=None, normalized=True): + def ln_prob(self, sample, normalized=True, **kwargs): """ Parameters ========== sample: dict Dictionary of the samples of which to calculate the log probability - axis: None or int - Axis along which the summation is performed normalized: bool When False, disables calculation of constraint normalization factor during prior probability computation. Default value is True. + kwargs: + The keyword arguments are passed directly to `np.sum` Returns ======= float: Joint log probability of all the individual sample probabilities """ - ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], axis=axis) + ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], **kwargs) + + # ensure that `reset_request()` of all JointPrior.dist + # with missing dependencies is called + self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), "reset_request") return self.check_ln_prob(sample, ln_prob, normalized=normalized) @@ -617,6 +654,7 @@ def rescale(self, keys, theta): for i, samps in enumerate(samples): # turns 0d-arrays into scalars samples[i] = np.squeeze(samps).tolist() + self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_rescale") return samples def test_redundancy(self, key, disable_logging=False): @@ -715,6 +753,12 @@ def _check_conditions_resolved(self, key, sampled_keys): for k in self[key].required_variables: if k not in sampled_keys: conditions_resolved = False + break + elif isinstance(self[k], JointPrior): + dependencies = self.jointprior_dependencies[k] + if len(set(dependencies) - set(sampled_keys)) > 0: + conditions_resolved = False + break return conditions_resolved def sample_subset(self, keys=iter([]), size=None): @@ -756,6 +800,7 @@ def sample_subset(self, keys=iter([]), size=None): samples[key][i] = self[key].sample(**rvars) else: logger.debug("{} not a known prior.".format(key)) + self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_sampled") return samples def get_required_variables(self, key): @@ -796,6 +841,7 @@ def prob(self, sample, **kwargs): for key in sample ] prob = np.prod(res, **kwargs) + self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), "reset_request") return self.check_prob(sample, prob) def ln_prob(self, sample, axis=None, normalized=True): @@ -822,6 +868,7 @@ def ln_prob(self, sample, axis=None, normalized=True): for key in sample ] ln_prob = np.sum(res, axis=axis) + self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), "reset_request") return
self.check_ln_prob(sample, ln_prob, normalized=normalized) @@ -881,6 +928,7 @@ def rescale(self, keys, theta): # turns 0d-arrays into scalars res = np.squeeze(result[key]).tolist() samples.append(res) + self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_rescale") return samples def _prepare_evaluation(self, keys, theta): From 82a4ee0d88f050ed2b6c4f087fabee86e0ba869a Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 17:52:57 +0100 Subject: [PATCH 18/22] Added test cases for new behavior --- test/core/prior/dict_test.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/core/prior/dict_test.py b/test/core/prior/dict_test.py index a3c8e1cd6..bf40887cf 100644 --- a/test/core/prior/dict_test.py +++ b/test/core/prior/dict_test.py @@ -311,6 +311,11 @@ def test_sample_subset_with_actual_subset(self): expected = dict(length=np.array([42.0, 42.0, 42.0])) self.assertTrue(np.array_equal(expected["length"], samples["length"])) + joint_prior = self.joint_prior_from_file + samples = joint_prior.sample_subset(keys=["testAbase"], size=size) + self.assertTrue(joint_prior["testAbase"].dist.sampled_parameters == []) + self.assertTrue(joint_prior["testBbase"].dist.sampled_parameters == []) + def test_sample_subset_constrained_as_array(self): size = 3 keys = ["mass", "speed"] @@ -344,6 +349,15 @@ def test_ln_prob(self): ) + self.second_prior.ln_prob(samples["speed"]) self.assertEqual(expected, self.prior_set_from_dict.ln_prob(samples)) + def test_ln_prob_actual_subset(self): + joint_prior = self.joint_prior_from_file + keys = ["testAbase"] + samples = joint_prior.sample_subset(keys=keys, size=1) + lnprob = joint_prior.ln_prob(samples) + self.assertTrue(joint_prior["testAbase"].dist.requested_parameters["testAbase"] is None) + self.assertTrue(joint_prior["testBbase"].dist.requested_parameters["testBbase"] is None) + self.assertTrue(lnprob == 0) + def test_rescale(self): theta = [0.5, 0.5, 0.5, 0.5, 0.5] expected = [ @@ -363,6 +377,16 @@ def test_rescale(self): ), ) + def test_rescale_actual_subset(self): + theta = [0.5] + keys = ["testAbase"] + joint_prior = self.joint_prior_from_file + samples = joint_prior.rescale(keys=keys, theta=theta) + print(joint_prior["testAbase"].dist._rescale_parameters) + self.assertTrue(joint_prior["testAbase"].dist._rescale_parameters["testAbase"] is None) + self.assertTrue(joint_prior["testBbase"].dist._rescale_parameters["testBbase"] is None) + self.assertTrue(np.all(np.isnan(samples))) + def test_cdf(self): """ Test that the CDF method is the inverse of the rescale method. 
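
The test cases above pin down what the new bookkeeping does when only part of a joint prior is requested through sample_subset, ln_prob, or rescale: the shared dist is reset afterwards so no half-filled request state leaks into later calls, and rescale returns NaN placeholders until every component of the joint prior has been requested. The following is a minimal sketch of that behaviour, not a verbatim excerpt from the series; it assumes bilby's MultivariateGaussianDist/MultivariateGaussian classes and a plain PriorDict, and the parameter names "a"/"b" and the numbers are purely illustrative stand-ins for the testAbase/testBbase fixture priors used in the tests.

    import numpy as np
    import bilby

    # Two parameters tied together by a single joint (multivariate Gaussian) distribution.
    dist = bilby.core.prior.MultivariateGaussianDist(
        names=["a", "b"], mus=[0.0, 0.0], sigmas=[1.0, 1.0], corrcoefs=np.identity(2)
    )
    priors = bilby.core.prior.PriorDict(dict(
        a=bilby.core.prior.MultivariateGaussian(dist=dist, name="a"),
        b=bilby.core.prior.MultivariateGaussian(dist=dist, name="b"),
    ))

    # Requesting only one component: the shared dist is reset afterwards,
    # mirroring test_sample_subset_with_actual_subset above.
    priors.sample_subset(keys=["a"], size=3)
    assert dist.sampled_parameters == []

    # Rescaling only one component returns NaN placeholders, because the joint
    # rescale cannot run until all components have been requested
    # (mirrors test_rescale_actual_subset above).
    out = priors.rescale(keys=["a"], theta=[0.5])
    assert np.all(np.isnan(out))
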
From 927efc01c6a1c4a5a4233440d7bea9d3b4a494a0 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Mon, 16 Dec 2024 11:00:08 +0100 Subject: [PATCH 19/22] Small bugfix --- bilby/core/prior/dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index b9a4a97ea..eae6cc794 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -443,7 +443,7 @@ def jointprior_dependencies(self): def _reset_jointprior_dists_with_missed_dependencies(self, keys, reset_func): keys = set(keys) dependencies = self.jointprior_dependencies - requested_jointpriors = set(dependencies).intersection() + requested_jointpriors = set(dependencies).intersection(keys) missing_dependencies = {value for key in requested_jointpriors for value in dependencies[key]} reset_dists = [] for key in missing_dependencies: From aaca2f888bbe50fbcedc01e0573b16372edc8edf Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Tue, 21 Jan 2025 17:07:15 +0100 Subject: [PATCH 20/22] fix tests after rebase --- test/core/prior/conditional_test.py | 2 +- test/core/prior/dict_test.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index f80711b3a..1382eae35 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -395,7 +395,7 @@ def condition_func_5(reference_parameters, mvgvar_a): for ii in range(1, 4): expected.append(expected[-1] * self.test_sample[f"var_{ii}"]) expected.extend([1, 1]) - self.assertListEqual(expected, res[:-2]) + self.assertListEqual(expected, res[:6]) res_sample = priordict.sample(1) self.assertEqual(list(res_sample.keys()), priordict.sorted_keys_without_fixed_parameters) res_sample = priordict.sample(1000) diff --git a/test/core/prior/dict_test.py b/test/core/prior/dict_test.py index bf40887cf..b8421601d 100644 --- a/test/core/prior/dict_test.py +++ b/test/core/prior/dict_test.py @@ -382,9 +382,8 @@ def test_rescale_actual_subset(self): keys = ["testAbase"] joint_prior = self.joint_prior_from_file samples = joint_prior.rescale(keys=keys, theta=theta) - print(joint_prior["testAbase"].dist._rescale_parameters) - self.assertTrue(joint_prior["testAbase"].dist._rescale_parameters["testAbase"] is None) - self.assertTrue(joint_prior["testBbase"].dist._rescale_parameters["testBbase"] is None) + self.assertTrue(joint_prior["testAbase"].dist._current_rescaled_parameter_values["testAbase"] is None) + self.assertTrue(joint_prior["testBbase"].dist._current_rescaled_parameter_values["testBbase"] is None) self.assertTrue(np.all(np.isnan(samples))) def test_cdf(self): From 3ff62eca3db6b3e979dbda7cd5fc82d3b23dd109 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Wed, 22 Jan 2025 12:05:17 +0100 Subject: [PATCH 21/22] Better Error Handling --- bilby/core/prior/joint.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index eb55c634c..80e13d40c 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -317,7 +317,8 @@ def rescale(self, value, **kwargs): If given, a 1d vector sample (one for each parameter) drawn from a uniform distribution between 0 and 1, or a 2d NxM array of samples where N is the number of samples and M is the number of parameters. - If None, the values previously set using BaseJointPriorDist.set_rescale() are used. 
+ If None, the values previously set using BaseJointPriorDist.set_rescale() are used, + the result is stored and can be accessed using get_rescaled(). kwargs: dict All keyword args that need to be passed to _rescale method, these keyword args are called in the JointPrior rescale methods for each parameter @@ -329,7 +330,9 @@ def rescale(self, value, **kwargs): distribution. """ if value is None: - samp = np.array(list(self._current_unit_cube_parameter_values.values())).T + if not self.filled_rescale(): + raise ValueError("Attempting to rescale from stored values without having set all required values.") + samp = np.array([self._current_unit_cube_parameter_values[key] for key in self.names]).T if len(samp.shape) == 1: samp = samp.reshape(1, self.num_vars) else: @@ -340,16 +343,16 @@ def rescale(self, value, **kwargs): raise ValueError("Array is the wrong shape") elif samp.shape[1] != self.num_vars: raise ValueError("Array is the wrong shape") - for key, val in zip(self.names, samp.T): - self.set_rescale(key, val) samp = self._rescale(samp, **kwargs) - for i, key in enumerate(self.names): - # get the numpy array used for indermediate outputs - # prior to a full rescale-operation - output = self.get_rescaled(key) - # update the array in-place - output[...] = samp[:, i] + # only store result if the rescale was done with saved unit-cube values + if value is None: + for i, key in enumerate(self.names): + # get the numpy array used for indermediate outputs + # prior to a full rescale-operation + output = self.get_rescaled(key) + # update the array in-place + output[...] = samp[:, i] return np.squeeze(samp) def _rescale(self, samp, **kwargs): From aa90c86eef0143236b76079e16cd46481e4bf877 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Wed, 22 Jan 2025 12:22:35 +0100 Subject: [PATCH 22/22] New test case for JointPrior.ln_log and subsequent fixes --- bilby/core/prior/joint.py | 38 ++++++++++------ test/core/prior/joint_test.py | 85 ++++++++++++++++++++++++++--------- 2 files changed, 89 insertions(+), 34 deletions(-) diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 80e13d40c..172382c28 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -184,13 +184,13 @@ def _split_repr(cls, string): kwargs[key.strip()] = arg return kwargs - def prob(self, samp): + def prob(self, samp, **kwargs): """ Get the probability of a sample. For bounded priors the probability will not be properly normalised. """ - return np.exp(self.ln_prob(samp)) + return np.exp(self.ln_prob(samp, **kwargs)) def _check_samp(self, value): """ @@ -226,7 +226,7 @@ def _check_samp(self, value): outbounds += (s < bound[0]) | (s > bound[1]) return samp, outbounds - def ln_prob(self, value): + def ln_prob(self, value, **kwargs): """ Get the log-probability of a sample. For bounded priors the probability will not be properly normalised. 
@@ -241,7 +241,7 @@ def ln_prob(self, value): samp, outbounds = self._check_samp(value) lnprob = -np.inf * np.ones(samp.shape[0]) - lnprob = self._ln_prob(samp, lnprob, outbounds) + lnprob = self._ln_prob(samp, lnprob, outbounds, **kwargs) if samp.shape[0] == 1: return lnprob[0] else: @@ -658,6 +658,11 @@ def _rescale(self, samp, **kwargs): "ij,kj->ik", samp * self.sqeigvalues[unique], self.eigvectors[unique] ) else: + mode = np.asarray(mode) + if mode.shape != (samp.shape[0],): + raise ValueError(f"Inconsistent sizes of the array-like used to select modes " + f"with shape {mode.shape} and the array of requested samps " + f"with length {len(samp)}.") for m in uniques: mask = m == mode samp[mask] = self.mus[m] + self.sigmas[m] * np.einsum( @@ -695,21 +700,28 @@ def _ln_prob(self, samp, lnprob, outbounds, **kwargs): mode = kwargs.get("mode", self.mode) if mode is None: - for j in range(samp.shape[0]): - # loop over the modes and sum the probabilities - for i in range(self.nmodes): - # self.mvn[i] is a "standard" multivariate normal distribution; see add_mode() - z = (samp[j] - self.mus[i]) / self.sigmas[i] - lnprob[j] = np.logaddexp(lnprob[j], - self.mvn[i].logpdf(z) - self.logprodsigmas[i] + np.log(self.weights[i])) + # loop over the modes and sum the probabilities + for i in range(self.nmodes): + # self.mvn[i] is a "standard" multivariate normal distribution; see add_mode() + z = (samp - self.mus[i]) / self.sigmas[i] + lnprob = np.logaddexp( + lnprob, + self.mvn[i].logpdf(z) - self.logprodsigmas[i] + np.log(self.weights[i]) + ) else: uniques = np.unique(np.asarray(mode, dtype=int)) if len(uniques) == 1: unique = uniques[0] - z = (samp[j] - self.mus[unique]) / self.sigmas[unique] + z = (samp - self.mus[unique]) / self.sigmas[unique] # don't multiply by the mode weight if the mode is given (ie. 
prob(mode|mode) = 1) - lnprob[j] = np.logaddexp(lnprob[j], self.mvn[unique].logpdf(z) - self.logprodsigmas[unique]) + lnprob = np.logaddexp(lnprob, self.mvn[unique].logpdf(z) - self.logprodsigmas[unique]) else: + mode = np.asarray(mode) + print(mode.shape, samp.shape) + if mode.shape != (samp.shape[0],): + raise ValueError(f"Inconsistent sizes of the array-like used to select modes " + f"with shape {mode.shape} and the array of requested samps " + f"with length {len(samp)}.") for m in uniques: mask = mode == m z = (samp[mask] - self.mus[m]) / self.sigmas[m] diff --git a/test/core/prior/joint_test.py b/test/core/prior/joint_test.py index c99373b00..784cd3030 100644 --- a/test/core/prior/joint_test.py +++ b/test/core/prior/joint_test.py @@ -40,33 +40,76 @@ def test_mvg_from_repr(self): class TestMultivariateGaussianDistParameterScales(unittest.TestCase): - def _test_mvg_ln_prob_diff_expected(self, mvg, mus, sigmas, corrcoefs): - # the columns of the Cholesky decompsition give the directions along which - # the multivariate Gaussian PDF will decrease by exact differences per unit - # sigma; test that these are as expected - ln_prob_mus = mvg.ln_prob(mus) - d = np.linalg.cholesky(corrcoefs) - for i in np.ndindex(4, 4, 4): - ln_prob_mus_sigmas_d_i = mvg.ln_prob(mus + sigmas * (d @ i)) - diff_ln_prob = ln_prob_mus - ln_prob_mus_sigmas_d_i - diff_ln_prob_expected = 0.5 * np.sum(np.array(i)**2) - self.assertTrue( - np.allclose(diff_ln_prob, diff_ln_prob_expected) - ) + def _test_mvg_ln_prob_diff_expected(self, mvg, weights, muss, sigmass, corrcoefss): + all_test_points = [] + all_expected_probs = [] + + # first test all modes individually and store the results + for mode, (weight, mus, sigmas, corrcoefs) in enumerate(zip(weights, muss, sigmass, corrcoefss)): + # the columns of the Cholesky decompsition give the directions along which + # the multivariate Gaussian PDF will decrease by exact differences per unit + # sigma; test that these are as expected + ln_prob_mus = mvg.ln_prob(mus, mode=mode) + d = np.linalg.cholesky(corrcoefs) + test_points = [] + test_point_probs = [] + for i in np.ndindex(4, 4, 4): + ln_prob_mus_sigmas_d_i = mvg.ln_prob(mus + sigmas * (d @ i), mode=mode) + test_points.append(mus + sigmas * (d @ i)) + test_point_probs.append(ln_prob_mus_sigmas_d_i) + diff_ln_prob = ln_prob_mus - ln_prob_mus_sigmas_d_i + diff_ln_prob_expected = 0.5 * np.sum(np.array(i)**2) + self.assertTrue( + np.allclose(diff_ln_prob, diff_ln_prob_expected) + ) + test_point_probs_at_once = mvg.ln_prob(test_points, mode=mode) + + np.testing.assert_allclose(test_point_probs, test_point_probs_at_once) + all_test_points.append(test_points) + all_expected_probs.append(test_point_probs_at_once) + + # For the points associated with each mode, we have calculated the expected probabilities + # above. We can test if the other modes are taken into account correctly by: + # 1. For the points with known values for *one* mode, calculate the probabilties summed over all modes + # 2. Calculate the probabilities for these points for all modes except the one for which we know the true values + # 3. Subtract the probability over all other modes from the total probability + # 4. 
Check if the resulting probability (adjusted for the weight of the mode) is equal to the + # expected value + for current_mode, (test_points, expected_probs) in enumerate(zip(all_test_points, all_expected_probs)): + test_point_probs_at_once_all_modes = mvg.ln_prob(test_points) + prob_other_modes = np.full(len(test_points), -np.inf) + for mode, weight in enumerate(weights): + if mode == current_mode: + continue + prob_other_modes = np.logaddexp(prob_other_modes, np.log(weight) + mvg.ln_prob(test_points, mode=mode)) + ln_prob_current_mode = np.log(np.exp(test_point_probs_at_once_all_modes) - np.exp(prob_other_modes)) + ln_prob_current_mode -= np.log(weights[current_mode]) + np.testing.assert_allclose(ln_prob_current_mode, expected_probs) def test_mvg_unit_scales(self): # test using order-unity standard deviations and correlations - sigmas = 0.3 * np.ones(3) - corrcoefs = np.identity(3) - mus = np.array([3, 1, 2]) + sigmas_1 = 0.3 * np.ones(3) + corrcoefs_1 = np.identity(3) + mus_1 = np.array([3, 1, 2]) + + sigmas_2 = 0.4 * np.ones(3) + corrcoefs_2 = np.identity(3) + mus_2 = np.array([3, 1, 2]) + + sigmas_3 = 0.1 * np.ones(3) + corrcoefs_3 = np.identity(3) + mus_3 = np.array([3.2, 1., 2.5]) + weights = [0.5, 0.3, 0.2] mvg = bilby.core.prior.MultivariateGaussianDist( + nmodes=3, names=['a', 'b', 'c'], - mus=mus, - sigmas=sigmas, - corrcoefs=corrcoefs, + mus=[mus_1, mus_2, mus_3], + sigmas=[sigmas_1, sigmas_2, sigmas_3], + corrcoefs=[corrcoefs_1, corrcoefs_2, corrcoefs_3], + weights=weights ) - self._test_mvg_ln_prob_diff_expected(mvg, mus, sigmas, corrcoefs) + self._test_mvg_ln_prob_diff_expected(mvg, mvg.weights, mvg.mus, mvg.sigmas, mvg.corrcoefs) def test_mvg_cw_scales(self): # test using standard deviations and correlations from the @@ -92,7 +135,7 @@ def test_mvg_cw_scales(self): corrcoefs=corrcoefs, ) - self._test_mvg_ln_prob_diff_expected(mvg, mus, sigmas, corrcoefs) + self._test_mvg_ln_prob_diff_expected(mvg, [1], [mus], [sigmas], [corrcoefs]) if __name__ == "__main__":
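
The multimode test above works by comparing the marginalised log-probability against per-mode evaluations obtained through the new mode= keyword, with the mode weights re-applied by hand. The relationship it relies on is a log-sum-exp over modes. Below is a self-contained sketch of that identity, reusing the same three-mode setup as test_mvg_unit_scales; the evaluation points are illustrative.

    import numpy as np
    import bilby

    weights = [0.5, 0.3, 0.2]
    mvg = bilby.core.prior.MultivariateGaussianDist(
        names=["a", "b", "c"],
        nmodes=3,
        mus=[[3.0, 1.0, 2.0], [3.0, 1.0, 2.0], [3.2, 1.0, 2.5]],
        sigmas=[0.3 * np.ones(3), 0.4 * np.ones(3), 0.1 * np.ones(3)],
        corrcoefs=[np.identity(3)] * 3,
        weights=weights,
    )

    # Two sample points, near the means of the first and third modes.
    points = np.array([[3.0, 1.0, 2.0], [3.2, 1.0, 2.5]])

    # mode=None marginalises over the weighted modes; ln_prob(..., mode=m)
    # deliberately omits the weight, so it is re-applied here by hand.
    total = mvg.ln_prob(points)
    per_mode = [np.log(w) + mvg.ln_prob(points, mode=m) for m, w in enumerate(weights)]
    np.testing.assert_allclose(total, np.logaddexp.reduce(per_mode, axis=0))
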
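Finally, the set_rescale/get_rescaled plumbing introduced in this series is what lets a JointPrior stage unit-cube values one parameter at a time and perform the joint rescale only once everything is available. The sketch below shows that workflow directly on the dist, assuming the BaseJointPriorDist API described in the docstrings of this series (set_rescale, get_rescaled, filled_rescale, reset_rescale, and rescale(None)); the parameter names "x"/"y", the values, and the exact placeholder contents are illustrative.

    import numpy as np
    import bilby

    dist = bilby.core.prior.MultivariateGaussianDist(
        names=["x", "y"], mus=[0.0, 0.0], sigmas=[1.0, 1.0], corrcoefs=np.identity(2)
    )

    # Stage unit-cube values one parameter at a time.
    dist.set_rescale("x", [0.1, 0.5, 0.9])
    assert not dist.filled_rescale()

    # Rescaling from stored values before all parameters are set raises an error.
    try:
        dist.rescale(None)
    except ValueError as err:
        print(err)

    dist.set_rescale("y", [0.5, 0.5, 0.5])
    assert dist.filled_rescale()

    # With value=None the stored unit-cube values are used, the per-parameter
    # output arrays are filled in place, and the results can be read back.
    dist.rescale(None)
    print(dist.get_rescaled("x"), dist.get_rescaled("y"))

    dist.reset_rescale()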