Skip to content

Commit

Permalink
FEAT: enable caching to be disabled in hyper.model.Model
Browse files Browse the repository at this point in the history
  • Loading branch information
ColmTalbot committed Aug 8, 2024
1 parent e37c308 commit 41cb703
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions bilby/hyper/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ..core.utils import infer_args_from_function_except_n_args


class Model(object):
class Model:
r"""
Population model that combines a set of factorizable models.
Expand All @@ -12,18 +12,24 @@ class Model(object):
p(\theta | \Lambda) = \prod_{i} p_{i}(\theta | \Lambda)
"""

def __init__(self, model_functions=None):
def __init__(self, model_functions=None, cache=True):
"""
Parameters
==========
model_functions: list
List of callables to compute the probability.
If this includes classes, the `__call__` method should return the
probability.
If this includes classes, the :code:`__call__`: method
should return the probability.
The requires variables are chosen at run time based on either
inspection or querying a :code:`variable_names` attribute.
cache: bool
Whether to cache the value returned by the model functions,
default=:code:`True`. The caching only looks at the parameters
not the data, so should be used with caution. The caching also
breaks :code:`jax` JIT compilation.
"""
self.models = model_functions
self.cache = cache
self._cached_parameters = {model: None for model in self.models}
self._cached_probability = {model: None for model in self.models}

Expand All @@ -48,14 +54,18 @@ def prob(self, data, **kwargs):
probability = 1.0
for ii, function in enumerate(self.models):
function_parameters = self._get_function_parameters(function)
if self._cached_parameters[function] == function_parameters:
if (
self.cache
and self._cached_parameters[function] == function_parameters
):
new_probability = self._cached_probability[function]
else:
new_probability = function(
data, **self._get_function_parameters(function)
)
self._cached_parameters[function] = function_parameters
self._cached_probability[function] = new_probability
if self.cache:
self._cached_parameters[function] = function_parameters
self._cached_probability[function] = new_probability
probability *= new_probability
return probability

Expand Down

0 comments on commit 41cb703

Please sign in to comment.