diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index eabb642a3..408c31383 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -79,9 +79,15 @@ def adjusted_mclmc_find_L_and_step_size( part1_key, part2_key = jax.random.split(rng_key, 2) + total_num_tuning_integrator_steps = 0 for i in range(num_windows): window_key = jax.random.fold_in(part1_key, i) - (state, params, eigenvector) = adjusted_mclmc_make_L_step_size_adaptation( + ( + state, + params, + eigenvector, + num_tuning_integrator_steps, + ) = adjusted_mclmc_make_L_step_size_adaptation( kernel=mclmc_kernel, dim=dim, frac_tune1=frac_tune1, @@ -90,22 +96,38 @@ def adjusted_mclmc_find_L_and_step_size( diagonal_preconditioning=diagonal_preconditioning, max=max, tuning_factor=tuning_factor, - )(state, params, num_steps, window_key) + )( + state, params, num_steps, window_key + ) + total_num_tuning_integrator_steps += num_tuning_integrator_steps if frac_tune3 != 0: for i in range(num_windows): part2_key = jax.random.fold_in(part2_key, i) part2_key1, part2_key2 = jax.random.split(part2_key, 2) - state, params = adjusted_mclmc_make_adaptation_L( + ( + state, + params, + num_tuning_integrator_steps, + ) = adjusted_mclmc_make_adaptation_L( mclmc_kernel, frac=frac_tune3, Lfactor=0.5, max=max, eigenvector=eigenvector, - )(state, params, num_steps, part2_key1) + )( + state, params, num_steps, part2_key1 + ) - (state, params, _) = adjusted_mclmc_make_L_step_size_adaptation( + total_num_tuning_integrator_steps += num_tuning_integrator_steps + + ( + state, + params, + _, + num_tuning_integrator_steps, + ) = adjusted_mclmc_make_L_step_size_adaptation( kernel=mclmc_kernel, dim=dim, frac_tune1=frac_tune1, @@ -115,9 +137,13 @@ def adjusted_mclmc_find_L_and_step_size( diagonal_preconditioning=diagonal_preconditioning, max=max, tuning_factor=tuning_factor, - )(state, params, num_steps, part2_key2) + )( + state, params, num_steps, part2_key2 + ) - return state, params + total_num_tuning_integrator_steps += num_tuning_integrator_steps + + return state, params, total_num_tuning_integrator_steps def adjusted_mclmc_make_L_step_size_adaptation( @@ -256,6 +282,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): update_da=update_da, ) + num_tuning_integrator_steps = info.num_integration_steps.sum() final_stepsize = final_da(dual_avg_state) params = params._replace(step_size=final_stepsize) @@ -299,9 +326,11 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): initial_da=initial_da, ) + num_tuning_integrator_steps += info.num_integration_steps.sum() + params = params._replace(step_size=final_da(dual_avg_state)) - return state, params, eigenvector + return state, params, eigenvector, num_tuning_integrator_steps return L_step_size_adaptation @@ -316,16 +345,16 @@ def adaptation_L(state, params, num_steps, key): adaptation_L_keys = jax.random.split(key, num_steps) def step(state, key): - next_state, _ = kernel( + next_state, info = kernel( rng_key=key, state=state, step_size=params.step_size, avg_num_integration_steps=params.L / params.step_size, inverse_mass_matrix=params.inverse_mass_matrix, ) - return next_state, next_state.position + return next_state, (next_state.position, info) - state, samples = jax.lax.scan( + state, (samples, info) = jax.lax.scan( f=step, init=state, xs=adaptation_L_keys, @@ -346,10 +375,14 @@ def step(state, key): # number of effective samples per 1 actual sample ess = contract(effective_sample_size(flat_samples[None, ...])) / num_steps - return state, params._replace( - L=jnp.clip( - Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound - ) + return ( + state, + params._replace( + L=jnp.clip( + Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound + ) + ), + info.num_integration_steps.sum(), ) return adaptation_L diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index aa192b964..60fd46359 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -126,7 +126,7 @@ def mclmc_find_L_and_step_size( mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=0.4 )(state, params, num_steps, part2_key) - return state, params + return state, params, num_steps * (frac_tune1 + frac_tune2 + frac_tune3) def make_L_step_size_adaptation( @@ -274,8 +274,8 @@ def make_adaptation_L(kernel, frac, Lfactor): """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" def adaptation_L(state, params, num_steps, key): - num_steps = int(num_steps * frac) - adaptation_L_keys = jax.random.split(key, num_steps) + num_steps_3 = int(num_steps * frac) + adaptation_L_keys = jax.random.split(key, num_steps_3) def step(state, key): next_state, _ = kernel( @@ -297,7 +297,7 @@ def step(state, key): ess = effective_sample_size(flat_samples[None, ...]) return state, params._replace( - L=Lfactor * params.step_size * jnp.mean(num_steps / ess) + L=Lfactor * params.step_size * jnp.mean(num_steps_3 / ess) ) return adaptation_L diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 34c2e3010..4d8a9fa61 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -122,6 +122,7 @@ def run_mclmc( ( blackjax_state_after_tuning, blackjax_mclmc_sampler_params, + _, ) = blackjax.mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -183,6 +184,7 @@ def run_adjusted_mclmc( ( blackjax_state_after_tuning, blackjax_mclmc_sampler_params, + _, ) = blackjax.adjusted_mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -252,6 +254,7 @@ def run_adjusted_mclmc_static( ( blackjax_state_after_tuning, blackjax_mclmc_sampler_params, + _, ) = blackjax.adjusted_mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -509,10 +512,7 @@ def get_inverse_mass_matrix(): inverse_mass_matrix=inverse_mass_matrix, ) - ( - _, - blackjax_mclmc_sampler_params, - ) = blackjax.mclmc_find_L_and_step_size( + (_, blackjax_mclmc_sampler_params, _) = blackjax.mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, state=initial_state,