From 8eed424c181f0c5c0b5cf5e50780115d14704fc6 Mon Sep 17 00:00:00 2001
From: =
Date: Wed, 15 Jan 2025 12:42:00 -0500
Subject: [PATCH 01/10] add static adjusted mclmc

---
 blackjax/__init__.py                    |   2 +
 blackjax/mcmc/__init__.py               |   4 +-
 blackjax/mcmc/adjusted_mclmc.py         |  57 ++----
 blackjax/mcmc/adjusted_mclmc_dynamic.py | 257 ++++++++++++++++++++++++
 tests/mcmc/test_sampling.py             | 110 +++++++++-
 5 files changed, 384 insertions(+), 46 deletions(-)
 create mode 100644 blackjax/mcmc/adjusted_mclmc_dynamic.py

diff --git a/blackjax/__init__.py b/blackjax/__init__.py
index 6a0de3809..35c9e3b58 100644
--- a/blackjax/__init__.py
+++ b/blackjax/__init__.py
@@ -13,6 +13,7 @@
 from .diagnostics import effective_sample_size as ess
 from .diagnostics import potential_scale_reduction as rhat
 from .mcmc import adjusted_mclmc as _adjusted_mclmc
+from .mcmc import adjusted_mclmc_dynamic as _adjusted_mclmc_dynamic
 from .mcmc import barker
 from .mcmc import dynamic_hmc as _dynamic_hmc
 from .mcmc import elliptical_slice as _elliptical_slice
@@ -112,6 +113,7 @@ def generate_top_level_api_from(module):
 additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk)
 
 mclmc = generate_top_level_api_from(_mclmc)
+adjusted_mclmc_dynamic = generate_top_level_api_from(_adjusted_mclmc_dynamic)
 adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc)
 elliptical_slice = generate_top_level_api_from(_elliptical_slice)
 ghmc = generate_top_level_api_from(_ghmc)
diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py
index 1e1317684..fad5dcb97 100644
--- a/blackjax/mcmc/__init__.py
+++ b/blackjax/mcmc/__init__.py
@@ -1,5 +1,5 @@
 from . import (
-    adjusted_mclmc,
+    adjusted_mclmc_dynamic,
     barker,
     elliptical_slice,
     ghmc,
@@ -25,5 +25,5 @@
     "marginal_latent_gaussian",
     "random_walk",
     "mclmc",
-    "adjusted_mclmc",
+    "adjusted_mclmc_dynamic",
 ]
diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py
index 81fbc2835..8288772a3 100644
--- a/blackjax/mcmc/adjusted_mclmc.py
+++ b/blackjax/mcmc/adjusted_mclmc.py
@@ -11,7 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin"."""
+"""Public API for the Metropolis-Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also refreshes the momentum variable only after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin".
+
+NOTE: For best performance, we recommend the dynamic variant, adjusted_mclmc_dynamic; this static module is primarily intended for parallelized versions of the algorithm.
+
+"""
 from typing import Callable, Union
 
 import jax
@@ -19,28 +23,26 @@
 
 import blackjax.mcmc.integrators as integrators
 from blackjax.base import SamplingAlgorithm
-from blackjax.mcmc.dynamic_hmc import DynamicHMCState, halton_sequence
-from blackjax.mcmc.hmc import HMCInfo
+from blackjax.mcmc.hmc import HMCInfo, HMCState
 from blackjax.mcmc.proposal import static_binomial_sampling
-from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
+from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
 from blackjax.util import generate_unit_vector
 
 __all__ = ["init", "build_kernel", "as_top_level_api"]
 
 
-def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array):
+def init(position: ArrayLikeTree, logdensity_fn: Callable):
     logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
-    return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg)
+    return HMCState(position, logdensity, logdensity_grad)
 
 
 def build_kernel(
-    integration_steps_fn,
+    num_integration_steps: int,
     integrator: Callable = integrators.isokinetic_mclachlan,
     divergence_threshold: float = 1000,
-    next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1],
    sqrt_diag_cov=1.0,
 ):
-    """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly.
+    """Build an MHMCHMC kernel with a fixed, user-specified number of integration steps.
 
     Parameters
     ----------
@@ -63,15 +65,13 @@
 
     def kernel(
         rng_key: PRNGKey,
-        state: DynamicHMCState,
+        state: HMCState,
         logdensity_fn: Callable,
         step_size: float,
         L_proposal_factor: float = jnp.inf,
-    ) -> tuple[DynamicHMCState, HMCInfo]:
+    ) -> tuple[HMCState, HMCInfo]:
         """Generate a new sample with the MHMCHMC kernel."""
 
-        num_integration_steps = integration_steps_fn(state.random_generator_arg)
-
         key_momentum, key_integrator = jax.random.split(rng_key, 2)
         momentum = generate_unit_vector(key_momentum, state.position)
         proposal, info, _ = adjusted_mclmc_proposal(
@@ -90,11 +90,10 @@
         )
 
         return (
-            DynamicHMCState(
+            HMCState(
                 proposal.position,
                 proposal.logdensity,
                 proposal.logdensity_grad,
-                next_random_arg_fn(state.random_generator_arg),
             ),
             info,
         )
@@ -110,10 +109,9 @@
     *,
     divergence_threshold: int = 1000,
     integrator: Callable = integrators.isokinetic_mclachlan,
-    next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1],
-    integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10),
+    num_integration_steps,
 ) -> SamplingAlgorithm:
-    """Implements the (basic) user interface for the dynamic MHMCHMC kernel.
+    """Implements the (basic) user interface for the MHMCHMC kernel.
 
     Parameters
     ----------
@@ -140,15 +138,15 @@
     """
 
     kernel = build_kernel(
-        integration_steps_fn=integration_steps_fn,
+        num_integration_steps,
         integrator=integrator,
-        next_random_arg_fn=next_random_arg_fn,
         sqrt_diag_cov=sqrt_diag_cov,
         divergence_threshold=divergence_threshold,
     )
 
-    def init_fn(position: ArrayLikeTree, rng_key: Array):
-        return init(position, logdensity_fn, rng_key)
+    def init_fn(position: ArrayLikeTree, rng_key=None):
+        del rng_key
+        return init(position, logdensity_fn)
 
     def update_fn(rng_key: PRNGKey, state):
         return kernel(
@@ -240,18 +238,3 @@
         return sampled_state, info, other_proposal_info
 
     return generate
-
-
-def rescale(mu):
-    """Returns s such that
-        round(U(0, 1) * s + 0.5)
-    has expected value mu.
- """ - k = jnp.floor(2 * mu - 1) - x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) - return k + x - - -def trajectory_length(t, mu): - s = rescale(mu) - return jnp.rint(0.5 + halton_sequence(t) * s) diff --git a/blackjax/mcmc/adjusted_mclmc_dynamic.py b/blackjax/mcmc/adjusted_mclmc_dynamic.py new file mode 100644 index 000000000..81fbc2835 --- /dev/null +++ b/blackjax/mcmc/adjusted_mclmc_dynamic.py @@ -0,0 +1,257 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin".""" +from typing import Callable, Union + +import jax +import jax.numpy as jnp + +import blackjax.mcmc.integrators as integrators +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.dynamic_hmc import DynamicHMCState, halton_sequence +from blackjax.mcmc.hmc import HMCInfo +from blackjax.mcmc.proposal import static_binomial_sampling +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.util import generate_unit_vector + +__all__ = ["init", "build_kernel", "as_top_level_api"] + + +def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array): + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg) + + +def build_kernel( + integration_steps_fn, + integrator: Callable = integrators.isokinetic_mclachlan, + divergence_threshold: float = 1000, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + sqrt_diag_cov=1.0, +): + """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. + + Parameters + ---------- + integrator + The integrator to use to integrate the Hamiltonian dynamics. + divergence_threshold + Value of the difference in energy above which we consider that the transition is divergent. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. Needs to return an `int`. + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. 
+ """ + + def kernel( + rng_key: PRNGKey, + state: DynamicHMCState, + logdensity_fn: Callable, + step_size: float, + L_proposal_factor: float = jnp.inf, + ) -> tuple[DynamicHMCState, HMCInfo]: + """Generate a new sample with the MHMCHMC kernel.""" + + num_integration_steps = integration_steps_fn(state.random_generator_arg) + + key_momentum, key_integrator = jax.random.split(rng_key, 2) + momentum = generate_unit_vector(key_momentum, state.position) + proposal, info, _ = adjusted_mclmc_proposal( + integrator=integrators.with_isokinetic_maruyama( + integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + ), + step_size=step_size, + L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), + num_integration_steps=num_integration_steps, + divergence_threshold=divergence_threshold, + )( + key_integrator, + integrators.IntegratorState( + state.position, momentum, state.logdensity, state.logdensity_grad + ), + ) + + return ( + DynamicHMCState( + proposal.position, + proposal.logdensity, + proposal.logdensity_grad, + next_random_arg_fn(state.random_generator_arg), + ), + info, + ) + + return kernel + + +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + L_proposal_factor: float = jnp.inf, + sqrt_diag_cov=1.0, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.isokinetic_mclachlan, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), +) -> SamplingAlgorithm: + """Implements the (basic) user interface for the dynamic MHMCHMC kernel. + + Parameters + ---------- + logdensity_fn + The log-density function we wish to draw samples from. + step_size + The value to use for the step size in the symplectic integrator. + divergence_threshold + The absolute value of the difference in energy between two states above + which we say that the transition is divergent. The default value is + commonly found in other libraries, and yet is arbitrary. + integrator + (algorithm parameter) The symplectic integrator to use to integrate the trajectory. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. + + + Returns + ------- + A ``SamplingAlgorithm``. + """ + + kernel = build_kernel( + integration_steps_fn=integration_steps_fn, + integrator=integrator, + next_random_arg_fn=next_random_arg_fn, + sqrt_diag_cov=sqrt_diag_cov, + divergence_threshold=divergence_threshold, + ) + + def init_fn(position: ArrayLikeTree, rng_key: Array): + return init(position, logdensity_fn, rng_key) + + def update_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + L_proposal_factor, + ) + + return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] + + +def adjusted_mclmc_proposal( + integrator: Callable, + step_size: Union[float, ArrayLikeTree], + L_proposal_factor: float, + num_integration_steps: int = 1, + divergence_threshold: float = 1000, + *, + sample_proposal: Callable = static_binomial_sampling, +) -> Callable: + """Vanilla MHMCHMC algorithm. + + The algorithm integrates the trajectory applying a integrator + `num_integration_steps` times in one direction to get a proposal and uses a + Metropolis-Hastings acceptance step to either reject or accept this + proposal. 
This mirrors the structure people usually refer to as "the
+    HMC algorithm", with an isokinetic integrator generating the trajectory.
+
+    Parameters
+    ----------
+    integrator
+        Integrator used to build the trajectory step by step.
+    L_proposal_factor
+        Scale of the partial momentum refreshment inside the proposal, relative to the trajectory length; the default jnp.inf disables refreshment.
+    step_size
+        Size of the integration step.
+    num_integration_steps
+        Number of times we run the integrator to build the trajectory.
+    divergence_threshold
+        Threshold above which we say that there is a divergence.
+
+    Returns
+    -------
+    A kernel that generates a new chain state and information about the transition.
+
+    """
+
+    def step(i, vars):
+        state, kinetic_energy, rng_key = vars
+        rng_key, next_rng_key = jax.random.split(rng_key)
+        next_state, next_kinetic_energy = integrator(
+            state, step_size, L_proposal_factor, rng_key
+        )
+
+        return next_state, kinetic_energy + next_kinetic_energy, next_rng_key
+
+    def build_trajectory(state, num_integration_steps, rng_key):
+        return jax.lax.fori_loop(
+            0 * num_integration_steps, num_integration_steps, step, (state, 0, rng_key)
+        )
+
+    def generate(
+        rng_key, state: integrators.IntegratorState
+    ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]:
+        """Generate a new chain state."""
+        end_state, kinetic_energy, rng_key = build_trajectory(
+            state, num_integration_steps, rng_key
+        )
+
+        new_energy = -end_state.logdensity
+        delta_energy = -state.logdensity + end_state.logdensity - kinetic_energy
+        delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy)
+        is_diverging = -delta_energy > divergence_threshold
+        sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state)
+        do_accept, p_accept, other_proposal_info = info
+
+        info = HMCInfo(
+            state.momentum,
+            p_accept,
+            do_accept,
+            is_diverging,
+            new_energy,
+            end_state,
+            num_integration_steps,
+        )
+
+        return sampled_state, info, other_proposal_info
+
+    return generate
+
+
+def rescale(mu):
+    """Returns s such that
+        round(U(0, 1) * s + 0.5)
+    has expected value mu.
+ """ + k = jnp.floor(2 * mu - 1) + x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) + return k + x + + +def trajectory_length(t, mu): + s = rescale(mu) + return jnp.rint(0.5 + halton_sequence(t) * s) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 474f67293..45d60f84a 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -14,7 +14,7 @@ import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info -from blackjax.mcmc.adjusted_mclmc import rescale +from blackjax.mcmc.adjusted_mclmc_dynamic import rescale from blackjax.mcmc.integrators import isokinetic_mclachlan from blackjax.util import run_inference_algorithm @@ -146,7 +146,7 @@ def run_mclmc( return samples - def run_adjusted_mclmc( + def run_adjusted_mclmc_dynamic( self, logdensity_fn, num_steps, @@ -158,13 +158,13 @@ def run_adjusted_mclmc( init_key, tune_key, run_key = jax.random.split(key, 3) - initial_state = blackjax.mcmc.adjusted_mclmc.init( + initial_state = blackjax.mcmc.adjusted_mclmc_dynamic.init( position=initial_position, logdensity_fn=logdensity_fn, random_generator_arg=init_key, ) - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel( + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc_dynamic.build_kernel( integrator=integrator, integration_steps_fn=lambda k: jnp.ceil( jax.random.uniform(k) * rescale(avg_num_integration_steps) @@ -177,7 +177,7 @@ def run_adjusted_mclmc( logdensity_fn=logdensity_fn, ) - target_acc_rate = 0.65 + target_acc_rate = 0.9 ( blackjax_state_after_tuning, @@ -197,7 +197,7 @@ def run_adjusted_mclmc( step_size = blackjax_mclmc_sampler_params.step_size L = blackjax_mclmc_sampler_params.L - alg = blackjax.adjusted_mclmc( + alg = blackjax.adjusted_mclmc_dynamic( logdensity_fn=logdensity_fn, step_size=step_size, integration_steps_fn=lambda key: jnp.ceil( @@ -218,6 +218,73 @@ def run_adjusted_mclmc( return out + def run_adjusted_mclmc( + self, + logdensity_fn, + num_steps, + initial_position, + key, + diagonal_preconditioning=False, + ): + integrator = isokinetic_mclachlan + + init_key, tune_key, run_key = jax.random.split(key, 3) + + initial_state = blackjax.mcmc.adjusted_mclmc.init( + position=initial_position, + logdensity_fn=logdensity_fn, + ) + + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel( + integrator=integrator, + num_integration_steps=avg_num_integration_steps, + sqrt_diag_cov=sqrt_diag_cov, + )( + rng_key=rng_key, + state=state, + step_size=step_size, + logdensity_fn=logdensity_fn, + ) + + target_acc_rate = 0.9 + + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + ) = blackjax.adjusted_mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_steps, + state=initial_state, + rng_key=tune_key, + target=target_acc_rate, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.1, + diagonal_preconditioning=diagonal_preconditioning, + ) + + step_size = blackjax_mclmc_sampler_params.step_size + L = blackjax_mclmc_sampler_params.L + + alg = blackjax.adjusted_mclmc( + logdensity_fn=logdensity_fn, + step_size=step_size, + num_integration_steps=L / step_size, + integrator=integrator, + sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + ) + + _, out = run_inference_algorithm( + rng_key=run_key, + 
initial_state=blackjax_state_after_tuning, + inference_algorithm=alg, + num_steps=num_steps, + transform=lambda state, _: state.position, + progress_bar=False, + ) + + return out + @parameterized.parameters( itertools.product( regression_test_cases, [True, False], window_adaptation_filters @@ -334,7 +401,35 @@ def test_mclmc(self): np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) - def test_adjusted_mclmc(self): + @parameterized.parameters([True, False]) + def test_adjusted_mclmc_dynamic(self, diagonal_preconditioning): + """Test the MCLMC kernel.""" + + init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + x_data = jax.random.normal(init_key0, shape=(1000, 1)) + y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) + + logposterior_fn_ = functools.partial( + self.regression_logprob, x=x_data, preds=y_data + ) + logdensity_fn = lambda x: logposterior_fn_(**x) + + states = self.run_adjusted_mclmc_dynamic( + initial_position={"coefs": 1.0, "log_scale": 1.0}, + logdensity_fn=logdensity_fn, + key=inference_key, + num_steps=10000, + diagonal_preconditioning=diagonal_preconditioning, + ) + + coefs_samples = states["coefs"][3000:] + scale_samples = np.exp(states["log_scale"][3000:]) + + np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + + @parameterized.parameters([True, False]) + def test_adjusted_mclmc(self, diagonal_preconditioning): """Test the MCLMC kernel.""" init_key0, init_key1, inference_key = jax.random.split(self.key, 3) @@ -351,6 +446,7 @@ def test_adjusted_mclmc(self): logdensity_fn=logdensity_fn, key=inference_key, num_steps=10000, + diagonal_preconditioning=diagonal_preconditioning, ) coefs_samples = states["coefs"][3000:] From 9dd6bdba039003538eaf635fde7678defdbc7350 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 13:27:08 -0500 Subject: [PATCH 02/10] add static adjusted mclmc --- .../adaptation/adjusted_mclmc_adaptation.py | 10 ++--- blackjax/adaptation/mclmc_adaptation.py | 22 +++++----- blackjax/mcmc/adjusted_mclmc.py | 10 +++-- blackjax/mcmc/adjusted_mclmc_dynamic.py | 10 +++-- blackjax/mcmc/integrators.py | 12 +++--- blackjax/mcmc/mclmc.py | 8 ++-- tests/mcmc/test_integrators.py | 4 +- tests/mcmc/test_sampling.py | 40 +++++++++++-------- 8 files changed, 64 insertions(+), 52 deletions(-) diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index f5d54e5c9..eabb642a3 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -74,7 +74,7 @@ def adjusted_mclmc_find_L_and_step_size( dim = pytree_size(state.position) if params is None: params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, sqrt_diag_cov=jnp.ones((dim,)) + jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, inverse_mass_matrix=jnp.ones((dim,)) ) part1_key, part2_key = jax.random.split(rng_key, 2) @@ -152,7 +152,7 @@ def step(iteration_state, weight_and_key): state=previous_state, avg_num_integration_steps=avg_num_integration_steps, step_size=params.step_size, - sqrt_diag_cov=params.sqrt_diag_cov, + inverse_mass_matrix=params.inverse_mass_matrix, ) # step updating @@ -283,9 +283,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L=params.L * change, step_size=params.step_size * change ) if diagonal_preconditioning: - params = params._replace( - 
sqrt_diag_cov=jnp.sqrt(variances), L=jnp.sqrt(dim) - ) + params = params._replace(inverse_mass_matrix=variances, L=jnp.sqrt(dim)) initial_da, update_da, final_da = dual_averaging_adaptation(target=target) ( @@ -323,7 +321,7 @@ def step(state, key): state=state, step_size=params.step_size, avg_num_integration_steps=params.L / params.step_size, - sqrt_diag_cov=params.sqrt_diag_cov, + inverse_mass_matrix=params.inverse_mass_matrix, ) return next_state, next_state.position diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 8452b6171..aa192b964 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -30,13 +30,13 @@ class MCLMCAdaptationState(NamedTuple): The momentum decoherent rate for the MCLMC algorithm. step_size The step size used for the MCLMC algorithm. - sqrt_diag_cov + inverse_mass_matrix A matrix used for preconditioning. """ L: float step_size: float - sqrt_diag_cov: float + inverse_mass_matrix: float def mclmc_find_L_and_step_size( @@ -87,10 +87,10 @@ def mclmc_find_L_and_step_size( Example ------- .. code:: - kernel = lambda sqrt_diag_cov : blackjax.mcmc.mclmc.build_kernel( + kernel = lambda inverse_mass_matrix : blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=integrator, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, ) ( @@ -106,7 +106,7 @@ def mclmc_find_L_and_step_size( """ dim = pytree_size(state.position) params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov=jnp.ones((dim,)) + jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,)) ) part1_key, part2_key = jax.random.split(rng_key, 2) @@ -123,7 +123,7 @@ def mclmc_find_L_and_step_size( if frac_tune3 != 0: state, params = make_adaptation_L( - mclmc_kernel(params.sqrt_diag_cov), frac=frac_tune3, Lfactor=0.4 + mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=0.4 )(state, params, num_steps, part2_key) return state, params @@ -152,7 +152,7 @@ def predictor(previous_state, params, adaptive_state, rng_key): rng_key, nan_key = jax.random.split(rng_key) # dynamics - next_state, info = kernel(params.sqrt_diag_cov)( + next_state, info = kernel(params.inverse_mass_matrix)( rng_key=rng_key, state=previous_state, L=params.L, @@ -247,15 +247,15 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L = params.L # determine L - sqrt_diag_cov = params.sqrt_diag_cov + inverse_mass_matrix = params.inverse_mass_matrix if num_steps2 > 1: x_average, x_squared_average = average[0], average[1] variances = x_squared_average - jnp.square(x_average) L = jnp.sqrt(jnp.sum(variances)) if diagonal_preconditioning: - sqrt_diag_cov = jnp.sqrt(variances) - params = params._replace(sqrt_diag_cov=sqrt_diag_cov) + inverse_mass_matrix = variances + params = params._replace(inverse_mass_matrix=inverse_mass_matrix) L = jnp.sqrt(dim) # readjust the stepsize @@ -265,7 +265,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): xs=(jnp.ones(steps), keys), state=state, params=params ) - return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov) + return state, MCLMCAdaptationState(L, params.step_size, inverse_mass_matrix) return L_step_size_adaptation diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 8288772a3..9b868562c 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -40,7 +40,7 @@ def build_kernel( num_integration_steps: int, integrator: Callable = 
integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, ): """Build an MHMCHMC kernel where the number of integration steps is chosen randomly. @@ -76,7 +76,9 @@ def kernel( momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( integrator=integrators.with_isokinetic_maruyama( - integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + integrator( + logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix + ) ), step_size=step_size, L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), @@ -105,7 +107,7 @@ def as_top_level_api( logdensity_fn: Callable, step_size: float, L_proposal_factor: float = jnp.inf, - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, *, divergence_threshold: int = 1000, integrator: Callable = integrators.isokinetic_mclachlan, @@ -140,7 +142,7 @@ def as_top_level_api( kernel = build_kernel( num_integration_steps, integrator=integrator, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, divergence_threshold=divergence_threshold, ) diff --git a/blackjax/mcmc/adjusted_mclmc_dynamic.py b/blackjax/mcmc/adjusted_mclmc_dynamic.py index 81fbc2835..1a69e1a28 100644 --- a/blackjax/mcmc/adjusted_mclmc_dynamic.py +++ b/blackjax/mcmc/adjusted_mclmc_dynamic.py @@ -38,7 +38,7 @@ def build_kernel( integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, ): """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. @@ -76,7 +76,9 @@ def kernel( momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( integrator=integrators.with_isokinetic_maruyama( - integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + integrator( + logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix + ) ), step_size=step_size, L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), @@ -106,7 +108,7 @@ def as_top_level_api( logdensity_fn: Callable, step_size: float, L_proposal_factor: float = jnp.inf, - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, *, divergence_threshold: int = 1000, integrator: Callable = integrators.isokinetic_mclachlan, @@ -143,7 +145,7 @@ def as_top_level_api( integration_steps_fn=integration_steps_fn, integrator=integrator, next_random_arg_fn=next_random_arg_fn, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, divergence_threshold=divergence_threshold, ) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 593683ca4..733e7e960 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -311,7 +311,9 @@ def _normalized_flatten_array(x, tol=1e-13): return jnp.where(norm > tol, x / norm, x), norm -def esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0): +def esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0): + sqrt_inverse_mass_matrix = jax.tree_util.tree_map(jnp.sqrt, inverse_mass_matrix) + def update( momentum: ArrayTree, logdensity_grad: ArrayTree, @@ -330,7 +332,7 @@ def update( logdensity_grad = logdensity_grad flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) - flatten_grads = flatten_grads * sqrt_diag_cov + flatten_grads = flatten_grads * sqrt_inverse_mass_matrix flatten_momentum, _ = ravel_pytree(momentum) dims = 
flatten_momentum.shape[0] normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) @@ -342,7 +344,7 @@ def update( + 2 * zeta * flatten_momentum ) new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw) - gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov) + gr = unravel_fn(new_momentum_normalized * sqrt_inverse_mass_matrix) next_momentum = unravel_fn(new_momentum_normalized) kinetic_energy_change = ( delta @@ -374,11 +376,11 @@ def format_isokinetic_state_output( def generate_isokinetic_integrator(coefficients): def isokinetic_integrator( - logdensity_fn: Callable, sqrt_diag_cov: ArrayTree = 1.0 + logdensity_fn: Callable, inverse_mass_matrix: ArrayTree = 1.0 ) -> GeneralIntegrator: position_update_fn = euclidean_position_update_fn(logdensity_fn) one_step = generalized_two_stage_integrator( - esh_dynamics_momentum_update_one_step(sqrt_diag_cov), + esh_dynamics_momentum_update_one_step(inverse_mass_matrix), position_update_fn, coefficients, format_output_fn=format_isokinetic_state_output, diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index e7a69849b..ff9638a1f 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -60,7 +60,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): +def build_kernel(logdensity_fn, inverse_mass_matrix, integrator): """Build a HMC kernel. Parameters @@ -81,7 +81,7 @@ def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): """ step = with_isokinetic_maruyama( - integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + integrator(logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix) ) def kernel( @@ -107,7 +107,7 @@ def as_top_level_api( L, step_size, integrator=isokinetic_mclachlan, - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, ) -> SamplingAlgorithm: """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. Since most users only need to specify the kernel @@ -155,7 +155,7 @@ def as_top_level_api( A ``SamplingAlgorithm``. 
""" - kernel = build_kernel(logdensity_fn, sqrt_diag_cov, integrator) + kernel = build_kernel(logdensity_fn, inverse_mass_matrix, integrator) def init_fn(position: ArrayLike, rng_key: PRNGKey): return init(position, logdensity_fn, rng_key) diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index c38009e5e..c37c0ede6 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -238,7 +238,7 @@ def test_esh_momentum_update(self, dims): # Efficient implementation update_stable = self.variant( - esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) + esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0) ) next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0) np.testing.assert_array_almost_equal(next_momentum, next_momentum1) @@ -263,7 +263,7 @@ def test_isokinetic_velocity_verlet(self): next_state, kinetic_energy_change = step(initial_state, step_size) # explicit integration - op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) + op1 = esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0) op2 = integrators.euclidean_position_update_fn(logdensity_fn) position, momentum, _, logdensity_grad = initial_state momentum, kinetic_grad, kinetic_energy_change0 = op1( diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 45d60f84a..a4ea66a9b 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -112,10 +112,10 @@ def run_mclmc( position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key ) - kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=blackjax.mcmc.mclmc.isokinetic_mclachlan, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, ) ( @@ -133,7 +133,7 @@ def run_mclmc( logdensity_fn, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, - sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, ) _, samples = run_inference_algorithm( @@ -144,6 +144,8 @@ def run_mclmc( transform=lambda state, info: state.position, ) + print(samples["coefs"][0].item()) + return samples def run_adjusted_mclmc_dynamic( @@ -164,12 +166,12 @@ def run_adjusted_mclmc_dynamic( random_generator_arg=init_key, ) - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc_dynamic.build_kernel( + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc_dynamic.build_kernel( integrator=integrator, integration_steps_fn=lambda k: jnp.ceil( jax.random.uniform(k) * rescale(avg_num_integration_steps) ), - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, )( rng_key=rng_key, state=state, @@ -204,7 +206,7 @@ def run_adjusted_mclmc_dynamic( jax.random.uniform(key) * rescale(L / step_size) ), integrator=integrator, - sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, ) _, out = run_inference_algorithm( @@ -216,6 +218,8 @@ def run_adjusted_mclmc_dynamic( progress_bar=False, ) + print(blackjax_mclmc_sampler_params.inverse_mass_matrix[1].item()) + return out def run_adjusted_mclmc( @@ -235,10 +239,10 @@ def run_adjusted_mclmc( logdensity_fn=logdensity_fn, ) - kernel = lambda rng_key, state, avg_num_integration_steps, 
step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel( + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc.build_kernel( integrator=integrator, num_integration_steps=avg_num_integration_steps, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, )( rng_key=rng_key, state=state, @@ -271,7 +275,7 @@ def run_adjusted_mclmc( step_size=step_size, num_integration_steps=L / step_size, integrator=integrator, - sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, ) _, out = run_inference_algorithm( @@ -402,7 +406,10 @@ def test_mclmc(self): np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) @parameterized.parameters([True, False]) - def test_adjusted_mclmc_dynamic(self, diagonal_preconditioning): + def test_adjusted_mclmc_dynamic( + self, + diagonal_preconditioning, + ): """Test the MCLMC kernel.""" init_key0, init_key1, inference_key = jax.random.split(self.key, 3) @@ -495,7 +502,7 @@ def __init__(self, d, condition_number): integrator = isokinetic_mclachlan - def get_sqrt_diag_cov(): + def get_inverse_mass_matrix(): init_key, tune_key = jax.random.split(key) initial_position = model.sample_init(init_key) @@ -506,10 +513,10 @@ def get_sqrt_diag_cov(): rng_key=init_key, ) - kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=model.logdensity_fn, integrator=integrator, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, ) ( @@ -523,13 +530,14 @@ def get_sqrt_diag_cov(): diagonal_preconditioning=True, ) - return blackjax_mclmc_sampler_params.sqrt_diag_cov + return blackjax_mclmc_sampler_params.inverse_mass_matrix - sqrt_diag_cov = get_sqrt_diag_cov() + inverse_mass_matrix = get_inverse_mass_matrix() assert ( jnp.abs( jnp.dot( - (sqrt_diag_cov**2) / jnp.linalg.norm(sqrt_diag_cov**2), + (inverse_mass_matrix**2) + / jnp.linalg.norm(inverse_mass_matrix**2), eigs / jnp.linalg.norm(eigs), ) - 1 From a49bb35f37293f0033ea4c9c5b8daf7ff62c1461 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 13:54:36 -0500 Subject: [PATCH 03/10] add static adjusted mclmc --- tests/mcmc/test_sampling.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index a4ea66a9b..d788696f8 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -144,8 +144,6 @@ def run_mclmc( transform=lambda state, info: state.position, ) - print(samples["coefs"][0].item()) - return samples def run_adjusted_mclmc_dynamic( @@ -218,8 +216,6 @@ def run_adjusted_mclmc_dynamic( progress_bar=False, ) - print(blackjax_mclmc_sampler_params.inverse_mass_matrix[1].item()) - return out def run_adjusted_mclmc( From 04522f57c94cf6d5aa2bbb25f9cd14450be645da Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 14:18:55 -0500 Subject: [PATCH 04/10] change order of parameters --- blackjax/mcmc/adjusted_mclmc.py | 16 ++++++++-------- tests/mcmc/test_sampling.py | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 9b868562c..f390402f2 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -37,7 +37,7 @@ def init(position: ArrayLikeTree, logdensity_fn: Callable): def build_kernel( - num_integration_steps: int, + logdensity_fn: Callable, 
integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, inverse_mass_matrix=1.0, @@ -66,8 +66,8 @@ def build_kernel( def kernel( rng_key: PRNGKey, state: HMCState, - logdensity_fn: Callable, step_size: float, + num_integration_steps: int, L_proposal_factor: float = jnp.inf, ) -> tuple[HMCState, HMCInfo]: """Generate a new sample with the MHMCHMC kernel.""" @@ -140,7 +140,7 @@ def as_top_level_api( """ kernel = build_kernel( - num_integration_steps, + logdensity_fn=logdensity_fn, integrator=integrator, inverse_mass_matrix=inverse_mass_matrix, divergence_threshold=divergence_threshold, @@ -152,11 +152,11 @@ def init_fn(position: ArrayLikeTree, rng_key=None): def update_fn(rng_key: PRNGKey, state): return kernel( - rng_key, - state, - logdensity_fn, - step_size, - L_proposal_factor, + rng_key=rng_key, + state=state, + step_size=step_size, + num_integration_steps=num_integration_steps, + L_proposal_factor=L_proposal_factor, ) return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index d788696f8..e9068326e 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -237,13 +237,13 @@ def run_adjusted_mclmc( kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc.build_kernel( integrator=integrator, - num_integration_steps=avg_num_integration_steps, inverse_mass_matrix=inverse_mass_matrix, + logdensity_fn=logdensity_fn, )( rng_key=rng_key, state=state, step_size=step_size, - logdensity_fn=logdensity_fn, + num_integration_steps=avg_num_integration_steps, ) target_acc_rate = 0.9 From d2717e93305c93c4eb5713ce2b5972d934fbcba4 Mon Sep 17 00:00:00 2001 From: = Date: Sun, 19 Jan 2025 02:26:54 +0000 Subject: [PATCH 05/10] return tuning steps --- .../adaptation/adjusted_mclmc_adaptation.py | 27 ++++++++++++------- blackjax/adaptation/mclmc_adaptation.py | 8 +++--- tests/mcmc/test_sampling.py | 4 +++ 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index eabb642a3..762fe75fb 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -79,9 +79,10 @@ def adjusted_mclmc_find_L_and_step_size( part1_key, part2_key = jax.random.split(rng_key, 2) + total_num_tuning_integrator_steps = 0 for i in range(num_windows): window_key = jax.random.fold_in(part1_key, i) - (state, params, eigenvector) = adjusted_mclmc_make_L_step_size_adaptation( + (state, params, eigenvector, num_tuning_integrator_steps) = adjusted_mclmc_make_L_step_size_adaptation( kernel=mclmc_kernel, dim=dim, frac_tune1=frac_tune1, @@ -91,13 +92,14 @@ def adjusted_mclmc_find_L_and_step_size( max=max, tuning_factor=tuning_factor, )(state, params, num_steps, window_key) + total_num_tuning_integrator_steps += num_tuning_integrator_steps if frac_tune3 != 0: for i in range(num_windows): part2_key = jax.random.fold_in(part2_key, i) part2_key1, part2_key2 = jax.random.split(part2_key, 2) - state, params = adjusted_mclmc_make_adaptation_L( + state, params, num_tuning_integrator_steps = adjusted_mclmc_make_adaptation_L( mclmc_kernel, frac=frac_tune3, Lfactor=0.5, @@ -105,7 +107,9 @@ def adjusted_mclmc_find_L_and_step_size( eigenvector=eigenvector, )(state, params, num_steps, part2_key1) - (state, params, _) = adjusted_mclmc_make_L_step_size_adaptation( + total_num_tuning_integrator_steps 
+= num_tuning_integrator_steps + + (state, params, _, num_tuning_integrator_steps) = adjusted_mclmc_make_L_step_size_adaptation( kernel=mclmc_kernel, dim=dim, frac_tune1=frac_tune1, @@ -117,7 +121,9 @@ def adjusted_mclmc_find_L_and_step_size( tuning_factor=tuning_factor, )(state, params, num_steps, part2_key2) - return state, params + total_num_tuning_integrator_steps += num_tuning_integrator_steps + + return state, params, total_num_tuning_integrator_steps def adjusted_mclmc_make_L_step_size_adaptation( @@ -256,6 +262,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): update_da=update_da, ) + num_tuning_integrator_steps = info.num_integration_steps.sum() final_stepsize = final_da(dual_avg_state) params = params._replace(step_size=final_stepsize) @@ -299,9 +306,11 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): initial_da=initial_da, ) + num_tuning_integrator_steps += info.num_integration_steps.sum() + params = params._replace(step_size=final_da(dual_avg_state)) - return state, params, eigenvector + return state, params, eigenvector, num_tuning_integrator_steps return L_step_size_adaptation @@ -316,16 +325,16 @@ def adaptation_L(state, params, num_steps, key): adaptation_L_keys = jax.random.split(key, num_steps) def step(state, key): - next_state, _ = kernel( + next_state, info = kernel( rng_key=key, state=state, step_size=params.step_size, avg_num_integration_steps=params.L / params.step_size, inverse_mass_matrix=params.inverse_mass_matrix, ) - return next_state, next_state.position + return next_state, (next_state.position, info) - state, samples = jax.lax.scan( + state, (samples, info) = jax.lax.scan( f=step, init=state, xs=adaptation_L_keys, @@ -350,7 +359,7 @@ def step(state, key): L=jnp.clip( Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound ) - ) + ), info.num_integration_steps.sum() return adaptation_L diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index aa192b964..60fd46359 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -126,7 +126,7 @@ def mclmc_find_L_and_step_size( mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=0.4 )(state, params, num_steps, part2_key) - return state, params + return state, params, num_steps * (frac_tune1 + frac_tune2 + frac_tune3) def make_L_step_size_adaptation( @@ -274,8 +274,8 @@ def make_adaptation_L(kernel, frac, Lfactor): """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" def adaptation_L(state, params, num_steps, key): - num_steps = int(num_steps * frac) - adaptation_L_keys = jax.random.split(key, num_steps) + num_steps_3 = int(num_steps * frac) + adaptation_L_keys = jax.random.split(key, num_steps_3) def step(state, key): next_state, _ = kernel( @@ -297,7 +297,7 @@ def step(state, key): ess = effective_sample_size(flat_samples[None, ...]) return state, params._replace( - L=Lfactor * params.step_size * jnp.mean(num_steps / ess) + L=Lfactor * params.step_size * jnp.mean(num_steps_3 / ess) ) return adaptation_L diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index e9068326e..3a53fa084 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -121,6 +121,7 @@ def run_mclmc( ( blackjax_state_after_tuning, blackjax_mclmc_sampler_params, + num_tuning_integrator_steps, ) = blackjax.mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -182,6 +183,7 @@ def 
run_adjusted_mclmc_dynamic( ( blackjax_state_after_tuning, blackjax_mclmc_sampler_params, + num_tuning_integrator_steps, ) = blackjax.adjusted_mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -251,6 +253,7 @@ def run_adjusted_mclmc( ( blackjax_state_after_tuning, blackjax_mclmc_sampler_params, + num_tuning_integrator_steps, ) = blackjax.adjusted_mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -518,6 +521,7 @@ def get_inverse_mass_matrix(): ( _, blackjax_mclmc_sampler_params, + _ ) = blackjax.mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, From 98d48735f7a94434d1201abcc5566e411bfe6b43 Mon Sep 17 00:00:00 2001 From: = Date: Sun, 19 Jan 2025 02:49:20 +0000 Subject: [PATCH 06/10] return tuning steps --- .../adaptation/adjusted_mclmc_adaptation.py | 46 ++++++++++++++----- tests/mcmc/test_sampling.py | 6 +-- 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index 762fe75fb..408c31383 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -82,7 +82,12 @@ def adjusted_mclmc_find_L_and_step_size( total_num_tuning_integrator_steps = 0 for i in range(num_windows): window_key = jax.random.fold_in(part1_key, i) - (state, params, eigenvector, num_tuning_integrator_steps) = adjusted_mclmc_make_L_step_size_adaptation( + ( + state, + params, + eigenvector, + num_tuning_integrator_steps, + ) = adjusted_mclmc_make_L_step_size_adaptation( kernel=mclmc_kernel, dim=dim, frac_tune1=frac_tune1, @@ -91,7 +96,9 @@ def adjusted_mclmc_find_L_and_step_size( diagonal_preconditioning=diagonal_preconditioning, max=max, tuning_factor=tuning_factor, - )(state, params, num_steps, window_key) + )( + state, params, num_steps, window_key + ) total_num_tuning_integrator_steps += num_tuning_integrator_steps if frac_tune3 != 0: @@ -99,17 +106,28 @@ def adjusted_mclmc_find_L_and_step_size( part2_key = jax.random.fold_in(part2_key, i) part2_key1, part2_key2 = jax.random.split(part2_key, 2) - state, params, num_tuning_integrator_steps = adjusted_mclmc_make_adaptation_L( + ( + state, + params, + num_tuning_integrator_steps, + ) = adjusted_mclmc_make_adaptation_L( mclmc_kernel, frac=frac_tune3, Lfactor=0.5, max=max, eigenvector=eigenvector, - )(state, params, num_steps, part2_key1) + )( + state, params, num_steps, part2_key1 + ) total_num_tuning_integrator_steps += num_tuning_integrator_steps - (state, params, _, num_tuning_integrator_steps) = adjusted_mclmc_make_L_step_size_adaptation( + ( + state, + params, + _, + num_tuning_integrator_steps, + ) = adjusted_mclmc_make_L_step_size_adaptation( kernel=mclmc_kernel, dim=dim, frac_tune1=frac_tune1, @@ -119,7 +137,9 @@ def adjusted_mclmc_find_L_and_step_size( diagonal_preconditioning=diagonal_preconditioning, max=max, tuning_factor=tuning_factor, - )(state, params, num_steps, part2_key2) + )( + state, params, num_steps, part2_key2 + ) total_num_tuning_integrator_steps += num_tuning_integrator_steps @@ -355,11 +375,15 @@ def step(state, key): # number of effective samples per 1 actual sample ess = contract(effective_sample_size(flat_samples[None, ...])) / num_steps - return state, params._replace( - L=jnp.clip( - Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound - ) - ), info.num_integration_steps.sum() + return ( + state, + params._replace( + L=jnp.clip( + Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound 
+ ) + ), + info.num_integration_steps.sum(), + ) return adaptation_L diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 3a53fa084..fb19fe67c 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -518,11 +518,7 @@ def get_inverse_mass_matrix(): inverse_mass_matrix=inverse_mass_matrix, ) - ( - _, - blackjax_mclmc_sampler_params, - _ - ) = blackjax.mclmc_find_L_and_step_size( + (_, blackjax_mclmc_sampler_params, _) = blackjax.mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, state=initial_state, From 4ef0b92621d149e1aa66a453d30d388b52ce4659 Mon Sep 17 00:00:00 2001 From: = Date: Sun, 19 Jan 2025 09:22:04 +0000 Subject: [PATCH 07/10] return tuning steps --- blackjax/mcmc/adjusted_mclmc_dynamic.py | 1 - 1 file changed, 1 deletion(-) diff --git a/blackjax/mcmc/adjusted_mclmc_dynamic.py b/blackjax/mcmc/adjusted_mclmc_dynamic.py index 1a69e1a28..c0ac96fec 100644 --- a/blackjax/mcmc/adjusted_mclmc_dynamic.py +++ b/blackjax/mcmc/adjusted_mclmc_dynamic.py @@ -135,7 +135,6 @@ def as_top_level_api( Function that generates the next pseudo or quasi-random number of integration steps in the sequence, given the current `random_generator_arg`. - Returns ------- A ``SamplingAlgorithm``. From 5b398730ce225e6f72c083166be51d6d13197842 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Mon, 20 Jan 2025 09:10:42 +0100 Subject: [PATCH 08/10] disable some test branch --- tests/vi/test_schrodinger_follmer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/vi/test_schrodinger_follmer.py b/tests/vi/test_schrodinger_follmer.py index 79e7afedd..33e406d11 100644 --- a/tests/vi/test_schrodinger_follmer.py +++ b/tests/vi/test_schrodinger_follmer.py @@ -14,7 +14,7 @@ def setUp(self): super().setUp() self.key = jax.random.key(1) - @chex.all_variants(with_pmap=True) + @chex.all_variants(with_pmap=True, without_jit=False, without_device=False) def test_recover_posterior(self): """Simple Normal mean test""" From 3175cb7fa96720671a442a3516c173f95e50fc06 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Mon, 20 Jan 2025 09:28:35 +0100 Subject: [PATCH 09/10] keeping only jit branch --- tests/vi/test_schrodinger_follmer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/vi/test_schrodinger_follmer.py b/tests/vi/test_schrodinger_follmer.py index 33e406d11..13bea143f 100644 --- a/tests/vi/test_schrodinger_follmer.py +++ b/tests/vi/test_schrodinger_follmer.py @@ -14,7 +14,7 @@ def setUp(self): super().setUp() self.key = jax.random.key(1) - @chex.all_variants(with_pmap=True, without_jit=False, without_device=False) + # @chex.all_variants(with_pmap=True) def test_recover_posterior(self): """Simple Normal mean test""" @@ -70,7 +70,7 @@ def logp_unnormalized_posterior(x, observed, prior_mu, prior_prec, true_cov): schrodinger_follmer_algo = schrodinger_follmer(logp_model, 50, 25) initial_state = schrodinger_follmer_algo.init(initial_position) - schrodinger_follmer_algo_sample = self.variant( + schrodinger_follmer_algo_sample = jax.jit( lambda k, s: schrodinger_follmer_algo.sample(k, s, 100) ) sampled_states = schrodinger_follmer_algo_sample(rng_key_init, initial_state) From 6714ad31e903b68f3e5cf9238ca259f8d10aabce Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Mon, 20 Jan 2025 11:20:19 +0100 Subject: [PATCH 10/10] revert change --- tests/vi/test_schrodinger_follmer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/vi/test_schrodinger_follmer.py b/tests/vi/test_schrodinger_follmer.py index 
13bea143f..79e7afedd 100644 --- a/tests/vi/test_schrodinger_follmer.py +++ b/tests/vi/test_schrodinger_follmer.py @@ -14,7 +14,7 @@ def setUp(self): super().setUp() self.key = jax.random.key(1) - # @chex.all_variants(with_pmap=True) + @chex.all_variants(with_pmap=True) def test_recover_posterior(self): """Simple Normal mean test""" @@ -70,7 +70,7 @@ def logp_unnormalized_posterior(x, observed, prior_mu, prior_prec, true_cov): schrodinger_follmer_algo = schrodinger_follmer(logp_model, 50, 25) initial_state = schrodinger_follmer_algo.init(initial_position) - schrodinger_follmer_algo_sample = jax.jit( + schrodinger_follmer_algo_sample = self.variant( lambda k, s: schrodinger_follmer_algo.sample(k, s, 100) ) sampled_states = schrodinger_follmer_algo_sample(rng_key_init, initial_state)
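
For reference, the API this series converges on can be exercised end to end as follows. This is a minimal usage sketch distilled from the run_adjusted_mclmc test helper added to tests/mcmc/test_sampling.py; the standard-normal target, dimension, chain lengths, and tuning fractions are illustrative stand-ins, not values prescribed by the patches.

import jax
import jax.numpy as jnp

import blackjax
from blackjax.util import run_inference_algorithm

# Assumed target: a 10-dimensional standard normal.
logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))
init_key, tune_key, run_key = jax.random.split(jax.random.key(0), 3)

# The static kernel draws no per-step trajectory lengths, so `init` needs no key.
initial_state = blackjax.mcmc.adjusted_mclmc.init(
    position=jnp.ones((10,)), logdensity_fn=logdensity_fn
)

# Kernel factory with the argument order the tuner expects (see patch 04).
kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc.build_kernel(
    logdensity_fn=logdensity_fn,
    inverse_mass_matrix=inverse_mass_matrix,
)(
    rng_key=rng_key,
    state=state,
    step_size=step_size,
    num_integration_steps=avg_num_integration_steps,
)

# Since patch 05, the tuner also reports how many integrator steps tuning used.
(
    state_after_tuning,
    params,
    num_tuning_integrator_steps,
) = blackjax.adjusted_mclmc_find_L_and_step_size(
    mclmc_kernel=kernel,
    num_steps=10000,
    state=initial_state,
    rng_key=tune_key,
    target=0.9,
    frac_tune1=0.1,
    frac_tune2=0.1,
    frac_tune3=0.1,
    diagonal_preconditioning=False,
)

alg = blackjax.adjusted_mclmc(
    logdensity_fn=logdensity_fn,
    step_size=params.step_size,
    num_integration_steps=params.L / params.step_size,
    inverse_mass_matrix=params.inverse_mass_matrix,
)

_, samples = run_inference_algorithm(
    rng_key=run_key,
    initial_state=state_after_tuning,
    inference_algorithm=alg,
    num_steps=10000,
    transform=lambda state, _: state.position,
    progress_bar=False,
)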
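
The dynamic variant, by contrast, redraws the number of integration steps at every MH step and therefore threads a random_generator_arg through its state. A sketch under the same assumed standard-normal target, with stand-in values for the tuned L and step_size (in practice these come from blackjax.adjusted_mclmc_find_L_and_step_size, as in the run_adjusted_mclmc_dynamic helper):

import jax
import jax.numpy as jnp

import blackjax
from blackjax.mcmc.adjusted_mclmc_dynamic import rescale
from blackjax.util import run_inference_algorithm

logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))  # assumed target
init_key, run_key = jax.random.split(jax.random.key(0))

# The dynamic state carries the argument used to draw trajectory lengths.
initial_state = blackjax.mcmc.adjusted_mclmc_dynamic.init(
    position=jnp.ones((10,)),
    logdensity_fn=logdensity_fn,
    random_generator_arg=init_key,
)

step_size, L = 0.5, 3.0  # stand-ins for tuned values
alg = blackjax.adjusted_mclmc_dynamic(
    logdensity_fn=logdensity_fn,
    step_size=step_size,
    # Draw step counts whose mean is L / step_size, via `rescale`.
    integration_steps_fn=lambda key: jnp.ceil(
        jax.random.uniform(key) * rescale(L / step_size)
    ),
)

_, samples = run_inference_algorithm(
    rng_key=run_key,
    initial_state=initial_state,
    inference_algorithm=alg,
    num_steps=10000,
    transform=lambda state, _: state.position,
    progress_bar=False,
)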
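
The rescale helper that patch 01 relocates into adjusted_mclmc_dynamic.py is stated without derivation: it returns the s for which round(U(0, 1) * s + 0.5) has expected value mu, which is what lets the integration_steps_fn above average out to L / step_size steps. A quick Monte Carlo check of that identity (illustrative only, not part of the series):

import jax
import jax.numpy as jnp

from blackjax.mcmc.adjusted_mclmc_dynamic import rescale

mu = 7.3
u = jax.random.uniform(jax.random.key(0), (100_000,))
# jnp.rint rounds half-to-even rather than half-up, but the tie set has
# measure zero here, so the empirical mean should still come out near mu.
steps = jnp.rint(u * rescale(mu) + 0.5)
print(steps.mean())  # approximately 7.3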