Skip to content

Commit

Permalink
Added WeightedCategorical-Prior
Browse files Browse the repository at this point in the history
  • Loading branch information
JasperMartins committed Jan 22, 2025
1 parent 36470d5 commit aed4c4d
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 36 deletions.
126 changes: 102 additions & 24 deletions bilby/core/prior/analytical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1439,34 +1439,63 @@ def ln_prob(self, val):
return lnp


class Categorical(Prior):
def __init__(self, ncategories, name=None, latex_label=None,
unit=None, boundary="periodic"):
""" An equal-weighted Categorical prior
class WeightedCategorical(Prior):
def __init__(
self,
ncategories,
weights=None,
name=None,
latex_label=None,
unit=None,
boundary="periodic",
):
"""A weighted Categorical prior
Parameters
==========
ncategories: int
The number of available categories. The prior mass support is then
integers [0, ncategories - 1].
weights: array_like
The weights of each category. If None, then all categories are
equally weighted.
name: str
See superclass
The name of the parameter
latex_label: str
See superclass
The latex label of the parameter. Used for plotting.
unit: str
See superclass
The unit of the parameter. Used for plotting.
"""

minimum = 0
# Small delta added to help with MCMC walking
maximum = ncategories - 1 + 1e-15
super(Categorical, self).__init__(
name=name, latex_label=latex_label, minimum=minimum,
maximum=maximum, unit=unit, boundary=boundary)
super().__init__(
name=name,
latex_label=latex_label,
minimum=minimum,
maximum=maximum,
unit=unit,
boundary=boundary,
)
self.ncategories = ncategories
self.categories = np.arange(self.minimum, self.maximum)
self.p = 1 / self.ncategories
self.lnp = -np.log(self.ncategories)
self.categories = np.arange(self.minimum, self.maximum, dtype=int)

p = (
np.atleast_1d(weights) / np.sum(weights)
if weights is not None
else np.ones(self.ncategories) / self.ncategories
)
# check for consistent shape of input
if len(p) != self.ncategories or len(p.shape) != 1:
raise ValueError(
"Inconsistent shape of weights and number of categories:"
+ f"np.atleast_1d(weights) has shape {p.shape} "
+ f"while number of categories is {self.ncategories}")
self.p = p
# save cdf for rescaling
self._cum_p = np.cumsum(p)
self.lnp = np.log(self.p)

def rescale(self, val):
"""
Expand All @@ -1483,7 +1512,30 @@ def rescale(self, val):
=======
Union[float, array_like]: Rescaled probability
"""
return np.floor(val * (1 + self.maximum))
return np.searchsorted(self._cum_p, val)

def cdf(self, val):
"""Return the cumulative prior probability of val.
Parameters
==========
val: Union[float, int, array_like]
Returns
=======
float: cumulative prior probability of val
"""
if (not hasattr(val, "__len__")):
if val in self.categories:
return self._cum_p[int(val)]
else:
return 0
else:
val = np.atleast_1d(val).astype(int)
cumprobs = np.zeros_like(val, dtype=np.float64)
idxs = np.isin(val, self.categories)
cumprobs[idxs] = self._cum_p[val[idxs]]
return cumprobs

def prob(self, val):
"""Return the prior probability of val.
Expand All @@ -1496,16 +1548,16 @@ def prob(self, val):
=======
float: Prior probability of val
"""
if isinstance(val, (float, int)):
if (not hasattr(val, "__len__")):
if val in self.categories:
return self.p
return self.p[int(val)]
else:
return 0
else:
val = np.atleast_1d(val)
val = np.atleast_1d(val).astype(int)
probs = np.zeros_like(val, dtype=np.float64)
idxs = np.isin(val, self.categories)
probs[idxs] = self.p
probs[idxs] = self.p[val[idxs]]
return probs

def ln_prob(self, val):
Expand All @@ -1520,17 +1572,43 @@ def ln_prob(self, val):
float:
"""
if isinstance(val, (float, int)):
if (not hasattr(val, "__len__")):
if val in self.categories:
return self.lnp
return self.lnp[int(val)]
else:
return -np.inf
else:
val = np.atleast_1d(val)
probs = -np.inf * np.ones_like(val, dtype=np.float64)
val = np.atleast_1d(val).astype(int)
lnprobs = np.ones_like(val, dtype=np.float64) * (-np.inf)
idxs = np.isin(val, self.categories)
probs[idxs] = self.lnp
return probs
lnprobs[idxs] = self.lnp[val[idxs]]
return lnprobs


class Categorical(WeightedCategorical):
def __init__(self, ncategories, name=None, latex_label=None,
unit=None, boundary="periodic"):
""" An equal-weighted Categorical prior
Parameters
==========
ncategories: int
The number of available categories. The prior mass support is then
integers [0, ncategories - 1].
name: str
See superclass
latex_label: str
See superclass
unit: str
See superclass
"""
super().__init__(ncategories=ncategories,
weights=None,
name=name,
latex_label=latex_label,
unit=unit,
boundary=boundary
)


class Triangular(Prior):
Expand Down
100 changes: 88 additions & 12 deletions test/core/prior/analytical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,7 @@ def test_single_probability(self):
def test_array_probability(self):
N = 3
categorical_prior = bilby.core.prior.Categorical(N)
self.assertTrue(
np.all(
categorical_prior.prob([0, 1, 1, 2, 3])
== np.array([1 / N, 1 / N, 1 / N, 1 / N, 0])
)
)
self.assertTrue(np.all(categorical_prior.prob([0, 1, 1, 2, 3]) == np.array([1 / N, 1 / N, 1 / N, 1 / N, 0])))

def test_single_lnprobability(self):
N = 3
Expand All @@ -58,12 +53,93 @@ def test_single_lnprobability(self):
def test_array_lnprobability(self):
N = 3
categorical_prior = bilby.core.prior.Categorical(N)
self.assertTrue(
np.all(
categorical_prior.ln_prob([0, 1, 1, 2, 3])
== np.array([-np.log(N), -np.log(N), -np.log(N), -np.log(N), -np.inf])
)
)
self.assertTrue(np.all(categorical_prior.ln_prob([0, 1, 1, 2, 3]) == np.array(
[-np.log(N), -np.log(N), -np.log(N), -np.log(N), -np.inf])))


class TestWeightedCategoricalPrior(unittest.TestCase):
def test_single_sample(self):
categorical_prior = bilby.core.prior.WeightedCategorical(3, [1, 2, 3])
in_prior = True
for _ in range(1000):
s = categorical_prior.sample()
if s not in [0, 1, 2]:
in_prior = False
self.assertTrue(in_prior)

def test_fail_init(self):
with self.assertRaises(ValueError):
bilby.core.prior.WeightedCategorical(3, [[1, 2], [2, 3], [3, 4]])
with self.assertRaises(ValueError):
bilby.core.prior.WeightedCategorical(3, [1, 2, 3, 4])

def test_array_sample(self):
ncat = 4
weights = np.arange(1, ncat + 1)
categorical_prior = bilby.core.prior.WeightedCategorical(ncat, weights=weights)
N = 100000
s = categorical_prior.sample(N)
cases = 0
for i in categorical_prior.categories:
print(i)
case = np.sum(s == i)
cases += case
self.assertAlmostEqual(case / N, categorical_prior.prob(i), places=int(np.log10(np.sqrt(N))))
self.assertAlmostEqual(case / N, weights[i] / np.sum(weights), places=int(np.log10(np.sqrt(N))))
self.assertEqual(cases, N)

def test_single_probability(self):
N = 3
weights = np.arange(1, N + 1)
categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights)
for i in categorical_prior.categories:
self.assertEqual(categorical_prior.prob(i), weights[i] / np.sum(weights))
self.assertEqual(categorical_prior.prob(0.5), 0)

def test_array_probability(self):
N = 3
test_cases = [0, 1, 1, 2, 3]
weights = np.arange(1, N + 1)
categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights)
probs = np.arange(1, N + 2) / np.sum(weights)
probs[-1] = 0
self.assertTrue(np.all(categorical_prior.prob(test_cases) == probs[test_cases]))

def test_single_lnprobability(self):
N = 3
weights = np.arange(1, N + 1)
categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights)
for i in categorical_prior.categories:
self.assertEqual(categorical_prior.ln_prob(i), np.log(weights[i] / np.sum(weights)))
self.assertEqual(categorical_prior.prob(0.5), 0)

def test_array_lnprobability(self):
N = 3
test_cases = [0, 1, 1, 2, 3]
weights = np.arange(1, N + 1)

categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights)
ln_probs = np.log(np.arange(1, N + 2) / np.sum(weights))
ln_probs[-1] = -np.inf

self.assertTrue(np.all(categorical_prior.ln_prob(test_cases) == ln_probs[test_cases]))

def test_cdf(self):
"""
Test that the CDF method is the inverse of the rescale method.
Note that the format of inputs/outputs is different between the two methods.
"""
N = 3
weights = np.arange(1, N + 1)

categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights)
sample = categorical_prior.sample(size=10)
original = np.asarray(sample)
new = np.array(categorical_prior.rescale(
categorical_prior.cdf(sample)
))
np.testing.assert_array_equal(original, new)


if __name__ == "__main__":
Expand Down

0 comments on commit aed4c4d

Please sign in to comment.