From 41cb703c5a0c283843095ab94750c5f9ee5ba491 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 8 Aug 2024 12:40:53 +0000 Subject: [PATCH] FEAT: enable caching to be disabled in hyper.model.Model --- bilby/hyper/model.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/bilby/hyper/model.py b/bilby/hyper/model.py index 359274079..3e84fbb6b 100644 --- a/bilby/hyper/model.py +++ b/bilby/hyper/model.py @@ -1,7 +1,7 @@ from ..core.utils import infer_args_from_function_except_n_args -class Model(object): +class Model: r""" Population model that combines a set of factorizable models. @@ -12,18 +12,24 @@ class Model(object): p(\theta | \Lambda) = \prod_{i} p_{i}(\theta | \Lambda) """ - def __init__(self, model_functions=None): + def __init__(self, model_functions=None, cache=True): """ Parameters ========== model_functions: list List of callables to compute the probability. - If this includes classes, the `__call__` method should return the - probability. + If this includes classes, the :code:`__call__`: method + should return the probability. The requires variables are chosen at run time based on either inspection or querying a :code:`variable_names` attribute. + cache: bool + Whether to cache the value returned by the model functions, + default=:code:`True`. The caching only looks at the parameters + not the data, so should be used with caution. The caching also + breaks :code:`jax` JIT compilation. """ self.models = model_functions + self.cache = cache self._cached_parameters = {model: None for model in self.models} self._cached_probability = {model: None for model in self.models} @@ -48,14 +54,18 @@ def prob(self, data, **kwargs): probability = 1.0 for ii, function in enumerate(self.models): function_parameters = self._get_function_parameters(function) - if self._cached_parameters[function] == function_parameters: + if ( + self.cache + and self._cached_parameters[function] == function_parameters + ): new_probability = self._cached_probability[function] else: new_probability = function( data, **self._get_function_parameters(function) ) - self._cached_parameters[function] = function_parameters - self._cached_probability[function] = new_probability + if self.cache: + self._cached_parameters[function] = function_parameters + self._cached_probability[function] = new_probability probability *= new_probability return probability