Skip to content

Commit

Permalink
Handle JointPrior's better in rescale, sample, (ln)prob, and _check_c…
Browse files Browse the repository at this point in the history
…onditions_resolved of (Condtional)PriorDict - keep track of dependencies of JointPriors necessary for their complete evaluation and handle cases where not all necessary keys are requested.
  • Loading branch information
JasperMartins committed Nov 22, 2024
1 parent 40377f5 commit e5f3f34
Showing 1 changed file with 56 additions and 8 deletions.
64 changes: 56 additions & 8 deletions bilby/core/prior/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,9 @@ def sample_subset(self, keys=iter([]), size=None):
samples[key] = self[key].sample(size=size)
else:
logger.debug("{} not a known prior.".format(key))
# ensure that `reset_sampled()` of all JointPrior.dist
# with missing dependencies is called
self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_sampled")
return samples

@property
Expand All @@ -430,6 +433,27 @@ def non_fixed_keys(self):
keys = [k for k in keys if k not in self.constraint_keys]
return keys

@property
def jointprior_dependencies(self):
keys = self.keys()
keys = [k for k in keys if isinstance(self[k], JointPrior)]
dependencies = {k: list(set(self[k].dist.names) - set([k])) for k in keys}
return dependencies

def _reset_jointprior_dists_with_missed_dependencies(self, keys, reset_func):
keys = set(keys)
dependencies = self.jointprior_dependencies
requested_jointpriors = set(dependencies).intersection()
missing_dependencies = {value for key in requested_jointpriors for value in dependencies[key]}
reset_dists = []
for key in missing_dependencies:
dist = self[key].dist
if id(dist) in reset_dists:
pass
else:
getattr(dist, reset_func)()
reset_dists.append(id(dist))

@property
def fixed_keys(self):
return [
Expand Down Expand Up @@ -499,13 +523,16 @@ def _estimate_normalization(self, keys, min_accept, sampling_chunk):
factor = len(keep) / np.count_nonzero(keep)
return factor

def prob(self, sample, **kwargs):
def prob(self, sample, normalized=True, **kwargs):
"""
Parameters
==========
sample: dict
Dictionary of the samples of which we want to have the probability of
normalized: bool
When False, disables calculation of constraint normalization factor
during prior probability computation. Default value is True.
kwargs:
The keyword arguments are passed directly to `np.prod`
Expand All @@ -516,10 +543,16 @@ def prob(self, sample, **kwargs):
"""
prob = np.prod([self[key].prob(sample[key]) for key in sample], **kwargs)

return self.check_prob(sample, prob)
# ensure that `reset_request()` of all JointPrior.dist
# with missing dependencies is called
self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), reset_func="reset_request")
return self.check_prob(sample, prob, normalized)

def check_prob(self, sample, prob):
ratio = self.normalize_constraint_factor(tuple(sample.keys()))
def check_prob(self, sample, prob, normalized=True):
if normalized:
ratio = self.normalize_constraint_factor(tuple(sample.keys()))
else:
ratio = 1
if np.all(prob == 0.0):
return prob * ratio
else:
Expand All @@ -534,26 +567,30 @@ def check_prob(self, sample, prob):
constrained_prob[keep] = prob[keep] * ratio
return constrained_prob

def ln_prob(self, sample, axis=None, normalized=True):
def ln_prob(self, sample, normalized=True, **kwargs):
"""
Parameters
==========
sample: dict
Dictionary of the samples of which to calculate the log probability
axis: None or int
Axis along which the summation is performed
normalized: bool
When False, disables calculation of constraint normalization factor
during prior probability computation. Default value is True.
kwargs:
The keyword arguments are passed directly to `np.prod`
Returns
=======
float or ndarray:
Joint log probability of all the individual sample probabilities
"""
ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], axis=axis)
ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], **kwargs)

# ensure that `reset_request()` of all JointPrior.dist
# with missing dependencies is called
self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), "reset_request")
return self.check_ln_prob(sample, ln_prob,
normalized=normalized)

Expand Down Expand Up @@ -615,6 +652,7 @@ def rescale(self, keys, theta):
samps = self[key].rescale(units)
# turns 0d-arrays into scalars
samples.append(np.squeeze(samps).tolist())
self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_rescale")
return samples

def test_redundancy(self, key, disable_logging=False):
Expand Down Expand Up @@ -713,6 +751,12 @@ def _check_conditions_resolved(self, key, sampled_keys):
for k in self[key].required_variables:
if k not in sampled_keys:
conditions_resolved = False
break
elif isinstance(self[k], JointPrior):
dependencies = self.jointprior_dependencies[k]
if len(set(dependencies) - set(sampled_keys)) > 0:
conditions_resolved = False
break
return conditions_resolved

def sample_subset(self, keys=iter([]), size=None):
Expand Down Expand Up @@ -754,6 +798,7 @@ def sample_subset(self, keys=iter([]), size=None):
samples[key][i] = self[key].sample(**rvars)
else:
logger.debug("{} not a known prior.".format(key))
self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_sampled")
return samples

def get_required_variables(self, key):
Expand Down Expand Up @@ -794,6 +839,7 @@ def prob(self, sample, **kwargs):
for key in sample
]
prob = np.prod(res, **kwargs)
self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), "reset_request")
return self.check_prob(sample, prob)

def ln_prob(self, sample, axis=None, normalized=True):
Expand All @@ -820,6 +866,7 @@ def ln_prob(self, sample, axis=None, normalized=True):
for key in sample
]
ln_prob = np.sum(res, axis=axis)
self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), "reset_request")
return self.check_ln_prob(sample, ln_prob,
normalized=normalized)

Expand Down Expand Up @@ -879,6 +926,7 @@ def rescale(self, keys, theta):
# turns 0d-arrays into scalars
res = np.squeeze(result[key]).tolist()
samples.append(res)
self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_rescale")
return samples

def _prepare_evaluation(self, keys, theta):
Expand Down

0 comments on commit e5f3f34

Please sign in to comment.