Context
BlackJAX currently implements Sequential Monte Carlo for the static case, where the target distribution doesn’t change across iterations. This method is more of a template than an algorithm, in the sense that there are many possible instantiations of a static SMC sampler, each composed of possibly many different inner algorithms.
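To make the “template” idea concrete, here is a minimal sketch in plain Python (not JAX, and not BlackJAX code) of one generic SMC step: reweight, resample, mutate. The names `smc_step`, `log_weight_fn` and `mutate_fn` are hypothetical, and systematic resampling is just one possible resampling scheme.

```python
import math
import random


def normalize_log_weights(log_weights):
    """Turn unnormalized log-weights into normalized weights (log-sum-exp trick)."""
    m = max(log_weights)
    w = [math.exp(lw - m) for lw in log_weights]
    total = sum(w)
    return [x / total for x in w]


def systematic_resample(weights, rng):
    """Return ancestor indices drawn by systematic resampling."""
    n = len(weights)
    u0 = rng.random() / n
    indices, cumulative, i = [], weights[0], 0
    for k in range(n):
        u = u0 + k / n
        # Advance to the first particle whose cumulative weight reaches u.
        while u > cumulative and i < n - 1:
            i += 1
            cumulative += weights[i]
        indices.append(i)
    return indices


def smc_step(particles, log_weight_fn, mutate_fn, rng):
    """One generic SMC step: reweight, resample, then mutate.

    log_weight_fn and mutate_fn are the pluggable pieces: swapping them
    yields a different SMC sampler without touching this template.
    """
    weights = normalize_log_weights([log_weight_fn(x) for x in particles])
    ancestors = systematic_resample(weights, rng)
    resampled = [particles[i] for i in ancestors]
    return [mutate_fn(x, rng) for x in resampled]
```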
The “sequential” nature of SMC comes from traversing a sequence of intermediate targets in order to reach the final target distribution. Currently BlackJAX supports three ways of doing that: adaptive tempering, tempering from a fixed sequence, and partial posteriors (data tempering).
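A minimal sketch of how adaptive tempering can pick the next exponent, assuming the common rule of keeping the effective sample size (ESS) of the incremental weights at a target fraction of the number of particles. It uses bisection in plain Python; the function names are hypothetical and BlackJAX’s own solver may differ. Bisection is valid here because the ESS of the weights exp(δ·loglik) does not increase as the step δ grows. Tempering from a fixed sequence would simply read the next exponent from a precomputed list instead.

```python
import math


def ess_fraction(log_weights):
    """ESS of unnormalized log-weights, divided by the number of particles."""
    m = max(log_weights)
    w = [math.exp(lw - m) for lw in log_weights]
    return sum(w) ** 2 / (len(w) * sum(x * x for x in w))


def next_tempering_exponent(log_likelihoods, current, target_ess=0.5, tol=1e-8):
    """Approximately the largest exponent in (current, 1] whose incremental
    weights exp((new - current) * loglik) keep ESS/N at or above target_ess."""
    def ess_at(new):
        return ess_fraction([(new - current) * ll for ll in log_likelihoods])

    if ess_at(1.0) >= target_ess:
        return 1.0  # the final target can be reached in a single step
    lo, hi = current, 1.0  # invariant: ess_at(lo) >= target_ess > ess_at(hi)
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if ess_at(mid) >= target_ess:
            lo = mid
        else:
            hi = mid
    return lo
```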
Moreover, internally we support SMC samplers that mutate the particles using MCMC chains. The chains can be run in two ways: either take several steps and keep only the last state, or run a waste-free algorithm.
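The two ways of running the mutation chains can be sketched as follows, using a scalar random-walk Metropolis kernel as a stand-in for any MCMC step. This is plain Python with hypothetical names, not BlackJAX code. The waste-free variant assumes the incoming particles are equally weighted, so drawing M of them uniformly with replacement stands in for the resampling step.

```python
import math
import random


def rwm_step(x, log_density, step_size, rng):
    """One random-walk Metropolis step on a scalar state."""
    proposal = x + step_size * rng.gauss(0.0, 1.0)
    log_accept = log_density(proposal) - log_density(x)
    if rng.random() < math.exp(min(0.0, log_accept)):
        return proposal
    return x


def mutate_last_state(particles, log_density, step_size, num_steps, rng):
    """Run num_steps MCMC steps from every particle and keep only the final state."""
    out = []
    for x in particles:
        for _ in range(num_steps):
            x = rwm_step(x, log_density, step_size, rng)
        out.append(x)
    return out


def mutate_waste_free(particles, log_density, step_size, chain_length, rng):
    """Waste-free mutation: start M = N // chain_length chains and keep every
    state of every chain, so that M * chain_length = N particles come back."""
    n = len(particles)
    assert n % chain_length == 0, "N must be a multiple of the chain length"
    starts = rng.choices(particles, k=n // chain_length)
    out = []
    for x in starts:
        out.append(x)  # the starting point counts as the first state
        for _ in range(chain_length - 1):
            x = rwm_step(x, log_density, step_size, rng)
            out.append(x)
    return out
```

The last-state scheme spends `num_steps` MCMC steps per particle but keeps only one state each, while the waste-free scheme keeps all intermediate states of fewer, longer chains.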
The parameters of those MCMC chains can be chosen for performance. We call it tuning when the values are chosen based on information from the particles or from the sampling process. Pretuning keeps a set of parameter samples across iterations and selects among them according to chain performance.
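A minimal sketch of the difference between tuning and pretuning for a scalar random-walk step size. The 2.38 scale is a common heuristic for random-walk proposals, and the squared jump distance of each chain is used as the performance score; both are assumptions for illustration, as are all function names. This is plain Python, not BlackJAX’s pretuning code.

```python
import math
import random
import statistics


def tune_step_size(particles, scale=2.38):
    """Tuning: derive a random-walk step size from the spread of the particles."""
    return scale * statistics.pstdev(particles)


def squared_jumps(before, after):
    """Per-chain performance score: squared distance moved by each chain."""
    return [(a - b) ** 2 for a, b in zip(before, after)]


def pretune_step_sizes(step_sizes, scores, rng, jitter=0.1):
    """Pretuning: each particle carries its own step size. Resample the step
    sizes in proportion to their chains' scores, then perturb them
    multiplicatively so the population keeps exploring."""
    weights = scores if sum(scores) > 0 else None  # None means a uniform choice
    chosen = rng.choices(step_sizes, weights=weights, k=len(step_sizes))
    return [s * math.exp(jitter * rng.gauss(0.0, 1.0)) for s in chosen]


# Usage: the chain with step size 1.0 moved the most, so its value is favoured.
rng = random.Random(0)
step_sizes = [0.1, 1.0, 10.0]
scores = squared_jumps([0.0, 0.0, 0.0], [0.01, 1.5, 0.0])
new_step_sizes = pretune_step_sizes(step_sizes, scores, rng)
```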
As you may imagine from the above, there are many combinations of all these pieces, and users may even provide their own implementation of any of them and combine them with existing ones.
Our APIs
Today we support building SMC samplers either by calling build_kernel() or via top_level_api() (exposed, respectively, as blackjax.algorithm_name.build_kernel() and blackjax.algorithm_name()).
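To illustrate the two entry points, here is a toy stand-in in plain Python. `algorithm_name` mirrors the placeholder used above, and the `(init, step)` pair loosely imitates what BlackJAX’s top-level functions return; everything below, including `update_fn` and this local `SamplingAlgorithm`, is hypothetical and does not reproduce BlackJAX’s actual signatures.

```python
from typing import Any, Callable, NamedTuple


class SamplingAlgorithm(NamedTuple):
    """Toy (init, step) pair returned by the top-level API."""
    init: Callable[[Any], Any]
    step: Callable[[Any, Any], Any]


def build_kernel(update_fn: Callable) -> Callable:
    """Low-level entry point: every parameter is passed on each call."""
    def kernel(rng_key, state, **parameters):
        return update_fn(rng_key, state, **parameters)
    return kernel


def algorithm_name(update_fn: Callable, **parameters) -> SamplingAlgorithm:
    """Top-level entry point: closes over the parameters, so users only
    pass a key and a state to `step`."""
    kernel = build_kernel(update_fn)

    def init(position):
        return position  # a real sampler would build a richer state here

    def step(rng_key, state):
        return kernel(rng_key, state, **parameters)

    return SamplingAlgorithm(init, step)


# Expose the low-level entry point under the top-level name.
algorithm_name.build_kernel = build_kernel
```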
Let me now describe two interrelated problems: the growing complexity of top_level_api() for SMC samplers, and the coupling between their implementations.
Imagine you would like to build an SMC sampler with adaptive tempering. Internally, adaptive_tempered creates a tempered kernel, which creates a from_mcmc object, which in turn creates a base SMC object. This means that you can’t run adaptive_tempered with a different implementation of tempered, can’t run tempering without creating a from_mcmc object, and so on. In other words, the creation of this chain of functions is coupled to its execution. Moreover, the same pattern is repeated in both build_kernel and top_level_api.
As a consequence, if you want, for example, to support different ways of updating the particles as discussed above, you have to add a parameter to every intermediate function.
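The coupling can be sketched with toy functions (hypothetical names and signatures, not BlackJAX’s): each layer builds its own inner kernel, so a new `update_strategy` option has to be threaded through every layer even though only the innermost one uses it.

```python
def base_build_kernel(update_strategy):
    """Innermost layer: the only one that actually uses update_strategy."""
    def kernel(particles, mcmc_step):
        return update_strategy(particles, mcmc_step)
    return kernel


def from_mcmc_build_kernel(mcmc_step, update_strategy):
    # Creates its own base kernel, so it must accept update_strategy too.
    base_kernel = base_build_kernel(update_strategy)

    def kernel(particles):
        return base_kernel(particles, mcmc_step)
    return kernel


def tempered_build_kernel(mcmc_step, update_strategy):
    # Creates its own from_mcmc kernel and merely forwards update_strategy.
    inner = from_mcmc_build_kernel(mcmc_step, update_strategy)

    def kernel(particles, exponent):
        # reweighting by the tempering exponent would happen here
        return inner(particles)
    return kernel


def adaptive_tempered_build_kernel(mcmc_step, update_strategy):
    # Creates its own tempered kernel: no other implementation can be used.
    tempered = tempered_build_kernel(mcmc_step, update_strategy)

    def kernel(particles, exponent):
        next_exponent = min(1.0, exponent + 0.25)  # placeholder schedule
        return tempered(particles, next_exponent), next_exponent
    return kernel
```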
A proposal
1. Decouple build_kernel from the creation of its dependencies: adaptive_tempered shouldn’t create a tempered kernel but take one as a parameter. This reduces the coupling between these functions and limits each one’s dependencies to what it needs to fulfill a single responsibility. It would also foster composition for research-oriented users.
2. Keep top_level_api as it is in terms of parameters: it is how users currently run samplers when they don’t need to modify them. Internally, these functions will need to create their dependencies (in the adaptive tempered case, a tempered kernel).
3. Incorporate a Builder API that uses the build_kernel functions internally. It may well become the de facto way of building BlackJAX SMC samplers in the future. It should foster building consistent objects while keeping coupling low.
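A toy sketch of these three points, in plain Python: each build function receives its dependency instead of creating it, the top-level function keeps a simple signature and wires the dependencies itself, and a small builder assembles and validates the pieces. All names (`SMCBuilder`, `with_mcmc_step`, `last_state_update`, ...) are illustrative only; they are not the API from the POC in #773 nor BlackJAX’s signatures.

```python
def last_state_update(num_steps):
    """Update strategy: run num_steps MCMC steps, keep the final state."""
    def update(particles, mcmc_step):
        out = []
        for x in particles:
            for _ in range(num_steps):
                x = mcmc_step(x)
            out.append(x)
        return out
    return update


def base_build_kernel(update_strategy):
    def kernel(particles, mcmc_step):
        return update_strategy(particles, mcmc_step)
    return kernel


def from_mcmc_build_kernel(base_kernel, mcmc_step):
    # Receives the base kernel instead of creating it.
    def kernel(particles):
        return base_kernel(particles, mcmc_step)
    return kernel


def tempered_build_kernel(from_mcmc_kernel):
    # Receives the from_mcmc kernel instead of creating it.
    def kernel(particles, exponent):
        # reweighting by the tempering exponent would happen here
        return from_mcmc_kernel(particles)
    return kernel


def adaptive_tempered_build_kernel(tempered_kernel):
    # Receives any tempered kernel with the same call signature.
    def kernel(particles, exponent):
        next_exponent = min(1.0, exponent + 0.25)  # placeholder schedule
        return tempered_kernel(particles, next_exponent), next_exponent
    return kernel


def adaptive_tempered(mcmc_step, num_mcmc_steps=1):
    """Top-level API: unchanged parameters; wires the dependencies itself."""
    base = base_build_kernel(last_state_update(num_mcmc_steps))
    tempered = tempered_build_kernel(from_mcmc_build_kernel(base, mcmc_step))
    return adaptive_tempered_build_kernel(tempered)


class SMCBuilder:
    """Assembles the same pieces step by step and checks they are complete."""

    def __init__(self):
        self._mcmc_step = None
        self._update_strategy = None

    def with_mcmc_step(self, mcmc_step):
        self._mcmc_step = mcmc_step
        return self

    def with_update_strategy(self, update_strategy):
        self._update_strategy = update_strategy
        return self

    def build_adaptive_tempered(self):
        if self._mcmc_step is None or self._update_strategy is None:
            raise ValueError("both an MCMC step and an update strategy are required")
        base = base_build_kernel(self._update_strategy)
        from_mcmc = from_mcmc_build_kernel(base, self._mcmc_step)
        return adaptive_tempered_build_kernel(tempered_build_kernel(from_mcmc))
```

Because `adaptive_tempered_build_kernel` only receives a kernel, any other tempered implementation with the same call signature can be passed in its place, and a new update strategy no longer needs an extra parameter on every layer.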
I have made a proof of concept (POC) of what this Builder API would look like, and of what the refactor of the build_kernel() functions would look like. Check the tests: I have rewritten every SMC instance using the new API, while keeping top_level_api in place. See #773.