diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py index 797cbd1c..fcad7253 100644 --- a/bilby/core/prior/conditional.py +++ b/bilby/core/prior/conditional.py @@ -3,13 +3,14 @@ from .analytical import DeltaFunction, PowerLaw, Uniform, LogUniform, \ SymmetricLogUniform, Cosine, Sine, Gaussian, TruncatedGaussian, HalfGaussian, \ LogNormal, Exponential, StudentT, Beta, Logistic, Cauchy, Gamma, ChiSquared, FermiDirac -from ..utils import infer_args_from_method, infer_parameters_from_function +from .joint import JointPrior +from ..utils import infer_args_from_method, infer_parameters_from_function, get_dict_with_properties def conditional_prior_factory(prior_class): class ConditionalPrior(prior_class): - def __init__(self, condition_func, name=None, latex_label=None, unit=None, - boundary=None, **reference_params): + def __init__(self, condition_func, name=None, latex_label=None, unit=None, boundary=None, dist=None, + **reference_params): """ Parameters @@ -41,23 +42,26 @@ def condition_func(reference_params, y): See superclass boundary: str, optional See superclass + dist: BaseJointPriorDist, optional + See superclass reference_params: Initial values for attributes such as `minimum`, `maximum`. This differs on the `prior_class`, for example for the Gaussian prior this is `mu` and `sigma`. """ - if 'boundary' in infer_args_from_method(super(ConditionalPrior, self).__init__): - super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label, - unit=unit, boundary=boundary, **reference_params) - else: - super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label, - unit=unit, **reference_params) + kwargs = {"name": name, "latex_label": latex_label, "unit": unit, "boundary": boundary, "dist": dist} + needed_kwargs = infer_args_from_method(super(ConditionalPrior, self).__init__) + for kw in kwargs.copy(): + if kw not in needed_kwargs: + kwargs.pop(kw) + + super(ConditionalPrior, self).__init__(**kwargs, **reference_params) self._required_variables = None self.condition_func = condition_func self._reference_params = reference_params - self.__class__.__name__ = 'Conditional{}'.format(prior_class.__name__) - self.__class__.__qualname__ = 'Conditional{}'.format(prior_class.__qualname__) + self.__class__.__name__ = "Conditional{}".format(prior_class.__name__) + self.__class__.__qualname__ = "Conditional{}".format(prior_class.__qualname__) def sample(self, size=None, **required_variables): """Draw a sample from the prior @@ -202,7 +206,9 @@ def required_variables(self): return self._required_variables def get_instantiation_dict(self): - instantiation_dict = super(ConditionalPrior, self).get_instantiation_dict() + superclass_args = infer_args_from_method(super(ConditionalPrior, self).__init__) + dict_with_properties = get_dict_with_properties(self) + instantiation_dict = {key: dict_with_properties[key] for key in superclass_args} for key, value in self.reference_params.items(): instantiation_dict[key] = value return instantiation_dict @@ -228,8 +234,8 @@ def __repr__(self): prior_name = self.__class__.__name__ instantiation_dict = self.get_instantiation_dict() instantiation_dict["condition_func"] = ".".join([ - instantiation_dict["condition_func"].__module__, - instantiation_dict["condition_func"].__name__ + self.condition_func.__module__, + self.condition_func.__name__ ]) args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key])) for key in instantiation_dict]) @@ -322,6 +328,10 @@ class ConditionalInterped(conditional_prior_factory(Interped)): pass +class ConditionalJointPrior(conditional_prior_factory(JointPrior)): + pass + + class DirichletElement(ConditionalBeta): r""" Single element in a dirichlet distribution diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index be3d543a..eae6cc79 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -420,6 +420,9 @@ def sample_subset(self, keys=iter([]), size=None): samples[key] = self[key].sample(size=size) else: logger.debug("{} not a known prior.".format(key)) + # ensure that `reset_sampled()` of all JointPrior.dist + # with missing dependencies is called + self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_sampled") return samples @property @@ -430,6 +433,27 @@ def non_fixed_keys(self): keys = [k for k in keys if k not in self.constraint_keys] return keys + @property + def jointprior_dependencies(self): + keys = self.keys() + keys = [k for k in keys if isinstance(self[k], JointPrior)] + dependencies = {k: list(set(self[k].dist.names) - set([k])) for k in keys} + return dependencies + + def _reset_jointprior_dists_with_missed_dependencies(self, keys, reset_func): + keys = set(keys) + dependencies = self.jointprior_dependencies + requested_jointpriors = set(dependencies).intersection(keys) + missing_dependencies = {value for key in requested_jointpriors for value in dependencies[key]} + reset_dists = [] + for key in missing_dependencies: + dist = self[key].dist + if id(dist) in reset_dists: + pass + else: + getattr(dist, reset_func)() + reset_dists.append(id(dist)) + @property def fixed_keys(self): return [ @@ -499,13 +523,16 @@ def _estimate_normalization(self, keys, min_accept, sampling_chunk): factor = len(keep) / np.count_nonzero(keep) return factor - def prob(self, sample, **kwargs): + def prob(self, sample, normalized=True, **kwargs): """ Parameters ========== sample: dict Dictionary of the samples of which we want to have the probability of + normalized: bool + When False, disables calculation of constraint normalization factor + during prior probability computation. Default value is True. kwargs: The keyword arguments are passed directly to `np.prod` @@ -516,10 +543,16 @@ def prob(self, sample, **kwargs): """ prob = np.prod([self[key].prob(sample[key]) for key in sample], **kwargs) - return self.check_prob(sample, prob) + # ensure that `reset_request()` of all JointPrior.dist + # with missing dependencies is called + self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), reset_func="reset_request") + return self.check_prob(sample, prob, normalized) - def check_prob(self, sample, prob): - ratio = self.normalize_constraint_factor(tuple(sample.keys())) + def check_prob(self, sample, prob, normalized=True): + if normalized: + ratio = self.normalize_constraint_factor(tuple(sample.keys())) + else: + ratio = 1 if np.all(prob == 0.0): return prob * ratio else: @@ -534,18 +567,18 @@ def check_prob(self, sample, prob): constrained_prob[keep] = prob[keep] * ratio return constrained_prob - def ln_prob(self, sample, axis=None, normalized=True): + def ln_prob(self, sample, normalized=True, **kwargs): """ Parameters ========== sample: dict Dictionary of the samples of which to calculate the log probability - axis: None or int - Axis along which the summation is performed normalized: bool When False, disables calculation of constraint normalization factor during prior probability computation. Default value is True. + kwargs: + The keyword arguments are passed directly to `np.prod` Returns ======= @@ -553,7 +586,11 @@ def ln_prob(self, sample, axis=None, normalized=True): Joint log probability of all the individual sample probabilities """ - ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], axis=axis) + ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], **kwargs) + + # ensure that `reset_request()` of all JointPrior.dist + # with missing dependencies is called + self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), "reset_request") return self.check_ln_prob(sample, ln_prob, normalized=normalized) @@ -600,18 +637,24 @@ def rescale(self, keys, theta): ========== keys: list List of prior keys to be rescaled - theta: list - List of randomly drawn values on a unit cube associated with the prior keys + theta: dict or array-like + Randomly drawn values on a unit cube associated with the prior keys Returns ======= - list: List of floats containing the rescaled sample + list: + If theta is 1D, returns list of floats containing the rescaled sample. + If theta is 2D, returns list of lists containing the rescaled samples. """ - theta = list(theta) + theta = [theta[key] for key in keys] if isinstance(theta, dict) else list(theta) samples = [] for key, units in zip(keys, theta): samps = self[key].rescale(units) - samples += list(np.asarray(samps).flatten()) + samples.append(samps) + for i, samps in enumerate(samples): + # turns 0d-arrays into scalars + samples[i] = np.squeeze(samps).tolist() + self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_rescale") return samples def test_redundancy(self, key, disable_logging=False): @@ -667,8 +710,6 @@ def __init__(self, dictionary=None, filename=None, conversion_function=None): self._conditional_keys = [] self._unconditional_keys = [] self._rescale_keys = [] - self._rescale_indexes = [] - self._least_recently_rescaled_keys = [] super(ConditionalPriorDict, self).__init__( dictionary=dictionary, filename=filename, @@ -712,42 +753,54 @@ def _check_conditions_resolved(self, key, sampled_keys): for k in self[key].required_variables: if k not in sampled_keys: conditions_resolved = False + break + elif isinstance(self[k], JointPrior): + dependencies = self.jointprior_dependencies[k] + if len(set(dependencies) - set(sampled_keys)) > 0: + conditions_resolved = False + break return conditions_resolved def sample_subset(self, keys=iter([]), size=None): + keys = list(keys) self.convert_floats_to_delta_functions() - add_delta_keys = [ - key - for key in self.keys() - if key not in keys and isinstance(self[key], DeltaFunction) - ] - use_keys = add_delta_keys + list(keys) - subset_dict = ConditionalPriorDict({key: self[key] for key in use_keys}) - if not subset_dict._resolved: - raise IllegalConditionsException( - "The current set of priors contains unresolvable conditions." - ) + add_delta_keys = [] + for key in self.keys(): + if key not in keys and isinstance(self[key], DeltaFunction): + add_delta_keys.append(key) + + use_keys = add_delta_keys + keys + unconditional_use_keys = [key for key in self.unconditional_keys if key in use_keys] + sorted_conditional_use_keys = [key for key in self.conditional_keys if key in use_keys] + + for i, key in enumerate(sorted_conditional_use_keys): + if not self._check_conditions_resolved(key, unconditional_use_keys + sorted_conditional_use_keys[:i]): + raise IllegalConditionsException( + "The current set of priors contains unresolvable conditions." + ) + sorted_use_keys = unconditional_use_keys + sorted_conditional_use_keys samples = dict() - for key in subset_dict.sorted_keys: + for key in sorted_use_keys: if key not in keys or isinstance(self[key], Constraint): continue if isinstance(self[key], Prior): try: - samples[key] = subset_dict[key].sample( - size=size, **subset_dict.get_required_variables(key) + samples[key] = self[key].sample( + size=size, **self.get_required_variables(key) ) except ValueError: # Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw) # If that is the case, we sample each sample individually. - required_variables = subset_dict.get_required_variables(key) + required_variables = self.get_required_variables(key) samples[key] = np.zeros(size) for i in range(size): rvars = { key: value[i] for key, value in required_variables.items() } - samples[key][i] = subset_dict[key].sample(**rvars) + samples[key][i] = self[key].sample(**rvars) else: logger.debug("{} not a known prior.".format(key)) + self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_sampled") return samples def get_required_variables(self, key): @@ -788,6 +841,7 @@ def prob(self, sample, **kwargs): for key in sample ] prob = np.prod(res, **kwargs) + self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), "reset_request") return self.check_prob(sample, prob) def ln_prob(self, sample, axis=None, normalized=True): @@ -814,6 +868,7 @@ def ln_prob(self, sample, axis=None, normalized=True): for key in sample ] ln_prob = np.sum(res, axis=axis) + self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), "reset_request") return self.check_ln_prob(sample, ln_prob, normalized=normalized) @@ -832,38 +887,50 @@ def rescale(self, keys, theta): ========== keys: list List of prior keys to be rescaled - theta: list - List of randomly drawn values on a unit cube associated with the prior keys + theta: dict or array-like + Randomly drawn values on a unit cube associated with the prior keys Returns ======= - list: List of floats containing the rescaled sample + list: + If theta is float for each key, returns list of floats containing the rescaled sample. + If theta is array-like for each key, returns list of lists containing the rescaled samples. """ keys = list(keys) - theta = list(theta) - self._check_resolved() - self._update_rescale_keys(keys) + + unconditional_keys = [key for key in self.unconditional_keys if key in keys] + sorted_conditional_keys = [key for key in self.conditional_keys if key in keys] + + for i, key in enumerate(sorted_conditional_keys): + if not self._check_conditions_resolved(key, unconditional_keys + sorted_conditional_keys[:i]): + raise IllegalConditionsException( + "The current set of priors contains unresolvable conditions." + ) + sorted_keys = unconditional_keys + sorted_conditional_keys + theta = [theta[key] for key in sorted_keys] if isinstance(theta, dict) else list(theta) result = dict() - for key, index in zip( - self.sorted_keys_without_fixed_parameters, self._rescale_indexes - ): - result[key] = self[key].rescale( - theta[index], **self.get_required_variables(key) - ) + for key, vals in zip(sorted_keys, theta): + try: + result[key] = self[key].rescale(vals, **self.get_required_variables(key)) + except ValueError: + # Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw) + # If that is the case, we sample each sample individually. + required_variables = self.get_required_variables(key) + result[key] = np.zeros_like(vals) + for i in range(len(vals)): + rvars = { + key: value[i] for key, value in required_variables.items() + } + result[key][i] = self[key].rescale(vals[i], **rvars) self[key].least_recently_sampled = result[key] samples = [] for key in keys: - samples += list(np.asarray(result[key]).flatten()) + # turns 0d-arrays into scalars + res = np.squeeze(result[key]).tolist() + samples.append(res) + self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_rescale") return samples - def _update_rescale_keys(self, keys): - if not keys == self._least_recently_rescaled_keys: - self._rescale_indexes = [ - keys.index(element) - for element in self.sorted_keys_without_fixed_parameters - ] - self._least_recently_rescaled_keys = keys - def _prepare_evaluation(self, keys, theta): self._check_resolved() for key, value in zip(keys, theta): diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 43c8913e..172382c2 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -63,8 +63,11 @@ def __init__(self, names, bounds=None): self.requested_parameters = dict() self.reset_request() - # a dictionary of the rescaled parameters - self.rescale_parameters = dict() + # a dictionary that stores the unit-cube values of parameters for later rescaling + self._current_unit_cube_parameter_values = dict() + # a dictionary of arrays that are used as intermediate return values of JointPrior.rescale() + # and updated in-place once all parameters have been requested + self._current_rescaled_parameter_values = dict() self.reset_rescale() # a list of sampled parameters @@ -94,15 +97,24 @@ def filled_rescale(self): Check if all the rescaled parameters have been filled. """ - return not np.any([val is None for val in self.rescale_parameters.values()]) + return not np.any([val is None for val in self._current_unit_cube_parameter_values.values()]) + + def set_rescale(self, key, values): + self._current_unit_cube_parameter_values[key] = np.array(values) + self._current_rescaled_parameter_values[key] = np.full_like(values, np.nan, dtype=float) def reset_rescale(self): """ Reset the rescaled parameters to None. """ - for name in self.names: - self.rescale_parameters[name] = None + self._current_unit_cube_parameter_values[name] = None + self._current_rescaled_parameter_values[name] = None + + def get_rescaled(self, key): + """Return an array that will be updated in-place once the rescale-operation + has been performed.""" + return self._current_rescaled_parameter_values[key] def get_instantiation_dict(self): subclass_args = infer_args_from_method(self.__init__) @@ -172,13 +184,13 @@ def _split_repr(cls, string): kwargs[key.strip()] = arg return kwargs - def prob(self, samp): + def prob(self, samp, **kwargs): """ Get the probability of a sample. For bounded priors the probability will not be properly normalised. """ - return np.exp(self.ln_prob(samp)) + return np.exp(self.ln_prob(samp, **kwargs)) def _check_samp(self, value): """ @@ -209,14 +221,12 @@ def _check_samp(self, value): raise ValueError("Array is the wrong shape") # check sample(s) is within bounds - outbounds = np.ones(samp.shape[0], dtype=bool) + outbounds = np.zeros(samp.shape[0], dtype=bool) for s, bound in zip(samp.T, self.bounds.values()): - outbounds = (s < bound[0]) | (s > bound[1]) - if np.any(outbounds): - break + outbounds += (s < bound[0]) | (s > bound[1]) return samp, outbounds - def ln_prob(self, value): + def ln_prob(self, value, **kwargs): """ Get the log-probability of a sample. For bounded priors the probability will not be properly normalised. @@ -231,7 +241,7 @@ def ln_prob(self, value): samp, outbounds = self._check_samp(value) lnprob = -np.inf * np.ones(samp.shape[0]) - lnprob = self._ln_prob(samp, lnprob, outbounds) + lnprob = self._ln_prob(samp, lnprob, outbounds, **kwargs) if samp.shape[0] == 1: return lnprob[0] else: @@ -303,10 +313,12 @@ def rescale(self, value, **kwargs): Parameters ========== - value: array - A 1d vector sample (one for each parameter) drawn from a uniform + value: array or None + If given, a 1d vector sample (one for each parameter) drawn from a uniform distribution between 0 and 1, or a 2d NxM array of samples where N is the number of samples and M is the number of parameters. + If None, the values previously set using BaseJointPriorDist.set_rescale() are used, + the result is stored and can be accessed using get_rescaled(). kwargs: dict All keyword args that need to be passed to _rescale method, these keyword args are called in the JointPrior rescale methods for each parameter @@ -317,16 +329,30 @@ def rescale(self, value, **kwargs): An vector sample drawn from the multivariate Gaussian distribution. """ - samp = np.array(value) - if len(samp.shape) == 1: - samp = samp.reshape(1, self.num_vars) - - if len(samp.shape) != 2: - raise ValueError("Array is the wrong shape") - elif samp.shape[1] != self.num_vars: - raise ValueError("Array is the wrong shape") + if value is None: + if not self.filled_rescale(): + raise ValueError("Attempting to rescale from stored values without having set all required values.") + samp = np.array([self._current_unit_cube_parameter_values[key] for key in self.names]).T + if len(samp.shape) == 1: + samp = samp.reshape(1, self.num_vars) + else: + samp = np.asarray(value) + if len(samp.shape) == 1: + samp = samp.reshape(1, self.num_vars) + if len(samp.shape) != 2: + raise ValueError("Array is the wrong shape") + elif samp.shape[1] != self.num_vars: + raise ValueError("Array is the wrong shape") samp = self._rescale(samp, **kwargs) + # only store result if the rescale was done with saved unit-cube values + if value is None: + for i, key in enumerate(self.names): + # get the numpy array used for indermediate outputs + # prior to a full rescale-operation + output = self.get_rescaled(key) + # update the array in-place + output[...] = samp[:, i] return np.squeeze(samp) def _rescale(self, samp, **kwargs): @@ -612,76 +638,95 @@ def add_mode(self, mus=None, sigmas=None, corrcoef=None, cov=None, weight=1.0): ) def _rescale(self, samp, **kwargs): - try: - mode = kwargs["mode"] - except KeyError: - mode = None + mode = kwargs.get("mode", self.mode) if mode is None: if self.nmodes == 1: mode = 0 else: - mode = np.argwhere(self.cumweights - random.rng.uniform(0, 1) > 0)[0][0] + mode = random.rng.choice( + self.nmodes, size=len(samp), p=self.weights + ) samp = erfinv(2.0 * samp - 1) * 2.0 ** 0.5 # rotate and scale to the multivariate normal shape - samp = self.mus[mode] + self.sigmas[mode] * np.einsum( - "ij,kj->ik", samp * self.sqeigvalues[mode], self.eigvectors[mode] - ) + uniques = np.unique(mode) + if len(uniques) == 1: + unique = uniques[0] + samp = self.mus[unique] + self.sigmas[unique] * np.einsum( + "ij,kj->ik", samp * self.sqeigvalues[unique], self.eigvectors[unique] + ) + else: + mode = np.asarray(mode) + if mode.shape != (samp.shape[0],): + raise ValueError(f"Inconsistent sizes of the array-like used to select modes " + f"with shape {mode.shape} and the array of requested samps " + f"with length {len(samp)}.") + for m in uniques: + mask = m == mode + samp[mask] = self.mus[m] + self.sigmas[m] * np.einsum( + "ij,kj->ik", samp[mask] * self.sqeigvalues[m], self.eigvectors[m] + ) + return samp def _sample(self, size, **kwargs): - try: - mode = kwargs["mode"] - except KeyError: - mode = None + mode = kwargs.get("mode", self.mode) + + samps = np.zeros((size, len(self))) + outbound = np.ones(size, dtype=bool) if mode is None: if self.nmodes == 1: mode = 0 - else: - if size == 1: - mode = np.argwhere(self.cumweights - random.rng.uniform(0, 1) > 0)[0][0] - else: - # pick modes - mode = [ - np.argwhere(self.cumweights - r > 0)[0][0] - for r in random.rng.uniform(0, 1, size) - ] + while np.any(outbound): + # sample the multivariate Gaussian keys + vals = random.rng.uniform(0, 1, (np.sum(outbound), len(self))) - samps = np.zeros((size, len(self))) - for i in range(size): - inbound = False - while not inbound: - # sample the multivariate Gaussian keys - vals = random.rng.uniform(0, 1, len(self)) - - if isinstance(mode, list): - samp = np.atleast_1d(self.rescale(vals, mode=mode[i])) - else: - samp = np.atleast_1d(self.rescale(vals, mode=mode)) - samps[i, :] = samp + if mode is None: + mode = random.rng.choice( + self.nmodes, size=np.sum(outbound), p=self.weights + ) - # check sample is in bounds (otherwise perform another draw) - outbound = False - for name, val in zip(self.names, samp): - if val < self.bounds[name][0] or val > self.bounds[name][1]: - outbound = True - break + samps[outbound] = np.atleast_1d(self.rescale(vals, mode=mode)) - if not outbound: - inbound = True + # check sample is in bounds and redraw those which are not + samps, outbound = self._check_samp(samps) return samps - def _ln_prob(self, samp, lnprob, outbounds): - for j in range(samp.shape[0]): + def _ln_prob(self, samp, lnprob, outbounds, **kwargs): + mode = kwargs.get("mode", self.mode) + + if mode is None: # loop over the modes and sum the probabilities for i in range(self.nmodes): # self.mvn[i] is a "standard" multivariate normal distribution; see add_mode() - z = (samp[j] - self.mus[i]) / self.sigmas[i] - lnprob[j] = np.logaddexp(lnprob[j], self.mvn[i].logpdf(z) - self.logprodsigmas[i]) + z = (samp - self.mus[i]) / self.sigmas[i] + lnprob = np.logaddexp( + lnprob, + self.mvn[i].logpdf(z) - self.logprodsigmas[i] + np.log(self.weights[i]) + ) + else: + uniques = np.unique(np.asarray(mode, dtype=int)) + if len(uniques) == 1: + unique = uniques[0] + z = (samp - self.mus[unique]) / self.sigmas[unique] + # don't multiply by the mode weight if the mode is given (ie. prob(mode|mode) = 1) + lnprob = np.logaddexp(lnprob, self.mvn[unique].logpdf(z) - self.logprodsigmas[unique]) + else: + mode = np.asarray(mode) + print(mode.shape, samp.shape) + if mode.shape != (samp.shape[0],): + raise ValueError(f"Inconsistent sizes of the array-like used to select modes " + f"with shape {mode.shape} and the array of requested samps " + f"with length {len(samp)}.") + for m in uniques: + mask = mode == m + z = (samp[mask] - self.mus[m]) / self.sigmas[m] + # don't multiply by the mode weight if the mode is given (ie. prob(mode|mode) = 1) + lnprob[mask] = np.logaddexp(lnprob[mask], self.mvn[m].logpdf(z) - self.logprodsigmas[m]) # set out-of-bounds values to -inf lnprob[outbounds] = -np.inf @@ -721,13 +766,28 @@ def __eq__(self, other): return False return True + @property + def mode(self): + if hasattr(self, "_mode"): + return self._mode + else: + return None + + @mode.setter + def mode(self, mode): + if not np.issubdtype(np.asarray(mode).dtype, np.integer): + raise ValueError("The mode to set must have integral data type.") + if np.any(mode >= self.nmodes) or np.any(mode < 0): + raise ValueError("The value of mode cannot be higher than the number of modes or smaller than zero.") + self._mode = mode + class MultivariateNormalDist(MultivariateGaussianDist): """A synonym for the :class:`~bilby.core.prior.MultivariateGaussianDist` distribution.""" class JointPrior(Prior): - def __init__(self, dist, name=None, latex_label=None, unit=None): + def __init__(self, dist, name=None, latex_label=None, unit=None, **kwargs): """This defines the single parameter Prior object for parameters that belong to a JointPriorDist Parameters @@ -778,6 +838,17 @@ def maximum(self, maximum): self._maximum = maximum self.dist.bounds[self.name] = (self.dist.bounds[self.name][0], maximum) + def __setattr__(self, name, value): + # first check that the JointPrior has an explicit setter method for the attribute, which should take presedence + if hasattr(self.__class__, name) and getattr(self.__class__, name).fset is not None: + return super().__setattr__(name, value) + # then check if the BaseJointPriorDist-!subclass! has an explicit setter method for the attribute + elif hasattr(self, "dist") and hasattr(self.dist, name) and getattr(self.dist.__class__, name).fset is not None: + return self.dist.__setattr__(name, value) + # if not, use the default settattr + else: + return super().__setattr__(name, value) + def rescale(self, val, **kwargs): """ Scale a unit hypercube sample to the prior. @@ -790,19 +861,41 @@ def rescale(self, val, **kwargs): all kwargs passed to the dist.rescale method Returns ======= - float: - A sample from the prior parameter. + np.ndarray: + The samples from the prior parameter. If not all names in "dist" have been filled, + the array contains only np.nan. *This* specific array instance will be filled with + the rescaled value once all parameters have been requested """ - self.dist.rescale_parameters[self.name] = val + if self.dist.get_rescaled(self.name) is not None: + import warnings + warnings.warn( + f"Rescale values for {self.name} in {self.dist} have already been set, " + "indicating that another rescale-operation is in progress.\n" + "Call dist.reset_rescale() on the joint prior distribution associated " + "with this prior after using dist.rescale() and make sure all parameters " + "necessary for rescaling have been requested in PriorDict.rescale().\n" + "Resetting now.", + RuntimeWarning + ) + self.dist.reset_rescale() + self.dist.set_rescale(self.name, val) if self.dist.filled_rescale(): - values = np.array(list(self.dist.rescale_parameters.values())).T - samples = self.dist.rescale(values, **kwargs) + # If all names have been filled, perform rescale operation + self.dist.rescale(value=None, **kwargs) + # get the rescaled values for the requested parameter + output = self.dist.get_rescaled(self.name) + # reset the rescale operation self.dist.reset_rescale() - return samples else: - return [] # return empty list + # If not all names have been filled, return a *numpy array* + # filled only with `np.nan`. Once all names have been requested, + # this array is updated *in-place* with the rescaled values. + output = self.dist.get_rescaled(self.name) + + # have to return raw output to conserve in-place modifications + return output def sample(self, size=1, **kwargs): """ diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index 20c0cda9..1382eae3 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -334,34 +334,99 @@ def test_rescale_with_joint_prior(self): # set multivariate Gaussian distribution names = ["mvgvar_0", "mvgvar_1"] - mu = [[0.79, -0.83]] + mu = [[1, 1]] cov = [[[0.03, 0.], [0., 0.04]]] mvg = bilby.core.prior.MultivariateGaussianDist(names, mus=mu, covs=cov) + names_2 = ["mvgvar_a", "mvgvar_b"] + mvg_dual_mode = bilby.core.prior.MultivariateGaussianDist( + names=names_2, + nmodes=2, + mus=[mu[0], (np.array(mu[0]) + np.ones_like(mu[0])).tolist()], + covs=[cov[0], cov[0]], + weights=[1, 2] + ) + + def condition_func_2(reference_params, var_0): + return dict(mode=np.searchsorted(np.cumsum(np.array([1, 2]) / 3), var_0)) + + def condition_func_1(reference_params, var_0, var_1): + return {"minimum": var_0 - 1, "maximum": var_1 + 1} + + def condition_func_5(reference_parameters, mvgvar_a): + return dict(minimum=reference_parameters["minimum"], maximum=mvgvar_a) + + prior_5 = bilby.core.prior.ConditionalUniform( + condition_func=condition_func_5, minimum=self.minimum, maximum=self.maximum + ) + priordict = bilby.core.prior.ConditionalPriorDict( dict( + var_5=prior_5, + mvgvar_a=bilby.core.prior.ConditionalJointPrior( + condition_func_2, dist=mvg_dual_mode, name="mvgvar_a", + minimum=self.minimum, maximum=self.maximum, mode=None), var_3=self.prior_3, var_2=self.prior_2, var_0=self.prior_0, var_1=self.prior_1, - mvgvar_0=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_0"), - mvgvar_1=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_1"), + mvgvar_0=bilby.core.prior.ConditionalJointPrior( + condition_func_1, dist=mvg, name="mvgvar_0", minimum=self.minimum, maximum=self.maximum), + mvgvar_1=bilby.core.prior.ConditionalJointPrior( + condition_func_1, dist=mvg, name="mvgvar_1", minimum=self.minimum, maximum=self.maximum), + mvgvar_b=bilby.core.prior.ConditionalJointPrior( + condition_func_2, dist=mvg_dual_mode, name="mvgvar_b", + minimum=self.minimum, maximum=self.maximum, mode=None), ) ) - ref_variables = list(self.test_sample.values()) + [0.4, 0.1] - keys = list(self.test_sample.keys()) + names + ref_variables = self.test_sample.copy() + ref_variables.update({"mvgvar_0": 0.5, "mvgvar_1": 0.5, "mvgvar_a": 0.5, "mvgvar_b": 0.2, "var_5": 0.5}) + keys = list(self.test_sample.keys()) + names + names_2 + ["var_5"] res = priordict.rescale(keys=keys, theta=ref_variables) self.assertIsInstance(res, list) - self.assertEqual(np.shape(res), (6,)) - self.assertListEqual([isinstance(r, float) for r in res], 6 * [True]) + self.assertEqual(np.shape(res), (9,)) + self.assertListEqual([isinstance(r, float) for r in res], 9 * [True]) # check conditional values are still as expected expected = [self.test_sample["var_0"]] + self.assertFalse(np.any(np.isnan(res))) for ii in range(1, 4): expected.append(expected[-1] * self.test_sample[f"var_{ii}"]) - self.assertListEqual(expected, res[0:4]) + expected.extend([1, 1]) + self.assertListEqual(expected, res[:6]) + res_sample = priordict.sample(1) + self.assertEqual(list(res_sample.keys()), priordict.sorted_keys_without_fixed_parameters) + res_sample = priordict.sample(1000) + self.assertListEqual([len(val) for val in res_sample.values()], [1000] * len(res_sample.keys())) + lnprobs = priordict.ln_prob(priordict.sample(10), axis=0) + self.assertEqual(len(lnprobs), 10) + + with self.assertRaises(bilby.core.prior.IllegalConditionsException): + keys = set(priordict.keys()) - set(["mvgvar_a"]) + priordict.rescale(keys=keys, theta=ref_variables) + + def condition_func_6(reference_params, var_5): + return dict(mode=np.searchsorted(np.cumsum(np.array([1, 2]) / 3), var_5)) + + priordict_unresolveable = bilby.core.prior.ConditionalPriorDict( + dict( + var_5=prior_5, + var_3=self.prior_3, + var_2=self.prior_2, + var_0=self.prior_0, + var_1=self.prior_1, + mvgvar_a=bilby.core.prior.ConditionalJointPrior( + condition_func_6, dist=mvg_dual_mode, name="mvgvar_a", + minimum=self.minimum, maximum=self.maximum, mode=None), + mvgvar_b=bilby.core.prior.ConditionalJointPrior( + condition_func_6, dist=mvg_dual_mode, name="mvgvar_b", + minimum=self.minimum, maximum=self.maximum, mode=None), + + ) + ) + self.assertEqual(priordict_unresolveable._resolved, False) def test_cdf(self): """ @@ -378,10 +443,11 @@ def test_cdf(self): ) def test_rescale_illegal_conditions(self): - del self.conditional_priors["var_0"] + test_sample = self.test_sample.copy() + test_sample.pop("var_0") with self.assertRaises(bilby.core.prior.IllegalConditionsException): self.conditional_priors.rescale( - keys=list(self.test_sample.keys()), + keys=list(test_sample.keys()), theta=list(self.test_sample.values()), ) diff --git a/test/core/prior/dict_test.py b/test/core/prior/dict_test.py index 08e730bb..b8421601 100644 --- a/test/core/prior/dict_test.py +++ b/test/core/prior/dict_test.py @@ -33,8 +33,17 @@ def setUp(self): name="b", alpha=3, minimum=1, maximum=2, unit="m/s", boundary=None ) self.third_prior = bilby.core.prior.DeltaFunction(name="c", peak=42, unit="m") + + mvg = bilby.core.prior.MultivariateGaussianDist( + names=["testa", "testb"], + mus=[1, 1], + covs=np.array([[2.0, 0.5], [0.5, 2.0]]), + weights=1.0, + ) + self.testa = bilby.core.prior.MultivariateGaussian(dist=mvg, name="testa", unit="unit") + self.testb = bilby.core.prior.MultivariateGaussian(dist=mvg, name="testb", unit="unit") self.priors = dict( - mass=self.first_prior, speed=self.second_prior, length=self.third_prior + mass=self.first_prior, speed=self.second_prior, length=self.third_prior, testa=self.testa, testb=self.testb ) self.prior_set_from_dict = bilby.core.prior.PriorDict(dictionary=self.priors) self.default_prior_file = os.path.join( @@ -70,7 +79,7 @@ def test_prior_set_is_dict(self): self.assertIsInstance(self.prior_set_from_dict, dict) def test_prior_set_has_correct_length(self): - self.assertEqual(3, len(self.prior_set_from_dict)) + self.assertEqual(5, len(self.prior_set_from_dict)) def test_prior_set_has_expected_priors(self): self.assertDictEqual(self.priors, dict(self.prior_set_from_dict)) @@ -160,6 +169,12 @@ def test_to_file(self): "unit='m/s', boundary=None)\n", "mass = Uniform(minimum=0, maximum=1, name='a', latex_label='a', " "unit='kg', boundary=None)\n", + "testa_testb_mvg = MultivariateGaussianDist(names=['testa', 'testb'], nmodes=1, mus=[[1, 1]], " + "sigmas=[[1.4142135623730951, 1.4142135623730951]], " + "corrcoefs=[[[0.9999999999999998, 0.24999999999999994], [0.24999999999999994, 0.9999999999999998]]], " + "covs=[[[2.0, 0.5], [0.5, 2.0]]], weights=[1.0], bounds={'testa': (-inf, inf), 'testb': (-inf, inf)})\n", + "testa = MultivariateGaussian(dist=testa_testb_mvg, name='testa', latex_label='testa', unit='unit')\n", + "testb = MultivariateGaussian(dist=testa_testb_mvg, name='testb', latex_label='testb', unit='unit')\n", ] self.prior_set_from_dict.to_file(outdir="prior_files", label="to_file_test") with open("prior_files/to_file_test.prior") as f: @@ -178,6 +193,13 @@ def test_from_dict_with_string(self): self.assertDictEqual(self.prior_set_from_dict, from_dict) def test_convert_floats_to_delta_functions(self): + mvg = bilby.core.prior.MultivariateGaussianDist( + names=["testa", "testb"], + mus=[1, 1], + covs=np.array([[2.0, 0.5], [0.5, 2.0]]), + weights=1.0, + ) + self.prior_set_from_dict["d"] = 5 self.prior_set_from_dict["e"] = 7.3 self.prior_set_from_dict["f"] = "unconvertable" @@ -190,6 +212,8 @@ def test_convert_floats_to_delta_functions(self): name="b", alpha=3, minimum=1, maximum=2, unit="m/s", boundary=None ), length=bilby.core.prior.DeltaFunction(name="c", peak=42, unit="m"), + testa=bilby.core.prior.MultivariateGaussian(dist=mvg, name="testa", unit="unit"), + testb=bilby.core.prior.MultivariateGaussian(dist=mvg, name="testb", unit="unit"), d=bilby.core.prior.DeltaFunction(peak=5), e=bilby.core.prior.DeltaFunction(peak=7.3), f="unconvertable", @@ -287,6 +311,11 @@ def test_sample_subset_with_actual_subset(self): expected = dict(length=np.array([42.0, 42.0, 42.0])) self.assertTrue(np.array_equal(expected["length"], samples["length"])) + joint_prior = self.joint_prior_from_file + samples = joint_prior.sample_subset(keys=["testAbase"], size=size) + self.assertTrue(joint_prior["testAbase"].dist.sampled_parameters == []) + self.assertTrue(joint_prior["testBbase"].dist.sampled_parameters == []) + def test_sample_subset_constrained_as_array(self): size = 3 keys = ["mass", "speed"] @@ -320,13 +349,25 @@ def test_ln_prob(self): ) + self.second_prior.ln_prob(samples["speed"]) self.assertEqual(expected, self.prior_set_from_dict.ln_prob(samples)) + def test_ln_prob_actual_subset(self): + joint_prior = self.joint_prior_from_file + keys = ["testAbase"] + samples = joint_prior.sample_subset(keys=keys, size=1) + lnprob = joint_prior.ln_prob(samples) + self.assertTrue(joint_prior["testAbase"].dist.requested_parameters["testAbase"] is None) + self.assertTrue(joint_prior["testBbase"].dist.requested_parameters["testBbase"] is None) + self.assertTrue(lnprob == 0) + def test_rescale(self): - theta = [0.5, 0.5, 0.5] + theta = [0.5, 0.5, 0.5, 0.5, 0.5] expected = [ self.first_prior.rescale(0.5), self.second_prior.rescale(0.5), self.third_prior.rescale(0.5), + self.testa.rescale(0.5), + self.testb.rescale(0.5) ] + assert not np.any(np.isnan(expected)) self.assertListEqual( sorted(expected), sorted( @@ -336,13 +377,22 @@ def test_rescale(self): ), ) + def test_rescale_actual_subset(self): + theta = [0.5] + keys = ["testAbase"] + joint_prior = self.joint_prior_from_file + samples = joint_prior.rescale(keys=keys, theta=theta) + self.assertTrue(joint_prior["testAbase"].dist._current_rescaled_parameter_values["testAbase"] is None) + self.assertTrue(joint_prior["testBbase"].dist._current_rescaled_parameter_values["testBbase"] is None) + self.assertTrue(np.all(np.isnan(samples))) + def test_cdf(self): """ Test that the CDF method is the inverse of the rescale method. Note that the format of inputs/outputs is different between the two methods. """ - sample = self.prior_set_from_dict.sample() + sample = self.prior_set_from_dict.sample_subset(keys=["length", "speed", "mass"]) original = np.array(list(sample.values())) new = np.array(self.prior_set_from_dict.rescale( sample.keys(), diff --git a/test/core/prior/joint_test.py b/test/core/prior/joint_test.py index c99373b0..784cd303 100644 --- a/test/core/prior/joint_test.py +++ b/test/core/prior/joint_test.py @@ -40,33 +40,76 @@ def test_mvg_from_repr(self): class TestMultivariateGaussianDistParameterScales(unittest.TestCase): - def _test_mvg_ln_prob_diff_expected(self, mvg, mus, sigmas, corrcoefs): - # the columns of the Cholesky decompsition give the directions along which - # the multivariate Gaussian PDF will decrease by exact differences per unit - # sigma; test that these are as expected - ln_prob_mus = mvg.ln_prob(mus) - d = np.linalg.cholesky(corrcoefs) - for i in np.ndindex(4, 4, 4): - ln_prob_mus_sigmas_d_i = mvg.ln_prob(mus + sigmas * (d @ i)) - diff_ln_prob = ln_prob_mus - ln_prob_mus_sigmas_d_i - diff_ln_prob_expected = 0.5 * np.sum(np.array(i)**2) - self.assertTrue( - np.allclose(diff_ln_prob, diff_ln_prob_expected) - ) + def _test_mvg_ln_prob_diff_expected(self, mvg, weights, muss, sigmass, corrcoefss): + all_test_points = [] + all_expected_probs = [] + + # first test all modes individually and store the results + for mode, (weight, mus, sigmas, corrcoefs) in enumerate(zip(weights, muss, sigmass, corrcoefss)): + # the columns of the Cholesky decompsition give the directions along which + # the multivariate Gaussian PDF will decrease by exact differences per unit + # sigma; test that these are as expected + ln_prob_mus = mvg.ln_prob(mus, mode=mode) + d = np.linalg.cholesky(corrcoefs) + test_points = [] + test_point_probs = [] + for i in np.ndindex(4, 4, 4): + ln_prob_mus_sigmas_d_i = mvg.ln_prob(mus + sigmas * (d @ i), mode=mode) + test_points.append(mus + sigmas * (d @ i)) + test_point_probs.append(ln_prob_mus_sigmas_d_i) + diff_ln_prob = ln_prob_mus - ln_prob_mus_sigmas_d_i + diff_ln_prob_expected = 0.5 * np.sum(np.array(i)**2) + self.assertTrue( + np.allclose(diff_ln_prob, diff_ln_prob_expected) + ) + test_point_probs_at_once = mvg.ln_prob(test_points, mode=mode) + + np.testing.assert_allclose(test_point_probs, test_point_probs_at_once) + all_test_points.append(test_points) + all_expected_probs.append(test_point_probs_at_once) + + # For the points associated with each mode, we have calculated the expected probabilities + # above. We can test if the other modes are taken into account correctly by: + # 1. For the points with known values for *one* mode, calculate the probabilties summed over all modes + # 2. Calculate the probabilities for these points for all modes except the one for which we know the true values + # 3. Subtract the probability over all other modes from the total probability + # 4. Check if the resulting probability (adjusted for the weight of the mode) is equal to the + # expected value + for current_mode, (test_points, expected_probs) in enumerate(zip(all_test_points, all_expected_probs)): + test_point_probs_at_once_all_modes = mvg.ln_prob(test_points) + prob_other_modes = np.full(len(test_points), -np.inf) + for mode, weight in enumerate(weights): + if mode == current_mode: + continue + prob_other_modes = np.logaddexp(prob_other_modes, np.log(weight) + mvg.ln_prob(test_points, mode=mode)) + ln_prob_current_mode = np.log(np.exp(test_point_probs_at_once_all_modes) - np.exp(prob_other_modes)) + ln_prob_current_mode -= np.log(weights[current_mode]) + np.testing.assert_allclose(ln_prob_current_mode, expected_probs) def test_mvg_unit_scales(self): # test using order-unity standard deviations and correlations - sigmas = 0.3 * np.ones(3) - corrcoefs = np.identity(3) - mus = np.array([3, 1, 2]) + sigmas_1 = 0.3 * np.ones(3) + corrcoefs_1 = np.identity(3) + mus_1 = np.array([3, 1, 2]) + + sigmas_2 = 0.4 * np.ones(3) + corrcoefs_2 = np.identity(3) + mus_2 = np.array([3, 1, 2]) + + sigmas_3 = 0.1 * np.ones(3) + corrcoefs_3 = np.identity(3) + mus_3 = np.array([3.2, 1., 2.5]) + weights = [0.5, 0.3, 0.2] mvg = bilby.core.prior.MultivariateGaussianDist( + nmodes=3, names=['a', 'b', 'c'], - mus=mus, - sigmas=sigmas, - corrcoefs=corrcoefs, + mus=[mus_1, mus_2, mus_3], + sigmas=[sigmas_1, sigmas_2, sigmas_3], + corrcoefs=[corrcoefs_1, corrcoefs_2, corrcoefs_3], + weights=weights ) - self._test_mvg_ln_prob_diff_expected(mvg, mus, sigmas, corrcoefs) + self._test_mvg_ln_prob_diff_expected(mvg, mvg.weights, mvg.mus, mvg.sigmas, mvg.corrcoefs) def test_mvg_cw_scales(self): # test using standard deviations and correlations from the @@ -92,7 +135,7 @@ def test_mvg_cw_scales(self): corrcoefs=corrcoefs, ) - self._test_mvg_ln_prob_diff_expected(mvg, mus, sigmas, corrcoefs) + self._test_mvg_ln_prob_diff_expected(mvg, [1], [mus], [sigmas], [corrcoefs]) if __name__ == "__main__":