diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index 762fe75fb..408c31383 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -82,7 +82,12 @@ def adjusted_mclmc_find_L_and_step_size( total_num_tuning_integrator_steps = 0 for i in range(num_windows): window_key = jax.random.fold_in(part1_key, i) - (state, params, eigenvector, num_tuning_integrator_steps) = adjusted_mclmc_make_L_step_size_adaptation( + ( + state, + params, + eigenvector, + num_tuning_integrator_steps, + ) = adjusted_mclmc_make_L_step_size_adaptation( kernel=mclmc_kernel, dim=dim, frac_tune1=frac_tune1, @@ -91,7 +96,9 @@ def adjusted_mclmc_find_L_and_step_size( diagonal_preconditioning=diagonal_preconditioning, max=max, tuning_factor=tuning_factor, - )(state, params, num_steps, window_key) + )( + state, params, num_steps, window_key + ) total_num_tuning_integrator_steps += num_tuning_integrator_steps if frac_tune3 != 0: @@ -99,17 +106,28 @@ def adjusted_mclmc_find_L_and_step_size( part2_key = jax.random.fold_in(part2_key, i) part2_key1, part2_key2 = jax.random.split(part2_key, 2) - state, params, num_tuning_integrator_steps = adjusted_mclmc_make_adaptation_L( + ( + state, + params, + num_tuning_integrator_steps, + ) = adjusted_mclmc_make_adaptation_L( mclmc_kernel, frac=frac_tune3, Lfactor=0.5, max=max, eigenvector=eigenvector, - )(state, params, num_steps, part2_key1) + )( + state, params, num_steps, part2_key1 + ) total_num_tuning_integrator_steps += num_tuning_integrator_steps - (state, params, _, num_tuning_integrator_steps) = adjusted_mclmc_make_L_step_size_adaptation( + ( + state, + params, + _, + num_tuning_integrator_steps, + ) = adjusted_mclmc_make_L_step_size_adaptation( kernel=mclmc_kernel, dim=dim, frac_tune1=frac_tune1, @@ -119,7 +137,9 @@ def adjusted_mclmc_find_L_and_step_size( diagonal_preconditioning=diagonal_preconditioning, max=max, tuning_factor=tuning_factor, - )(state, params, num_steps, part2_key2) + )( + state, params, num_steps, part2_key2 + ) total_num_tuning_integrator_steps += num_tuning_integrator_steps @@ -355,11 +375,15 @@ def step(state, key): # number of effective samples per 1 actual sample ess = contract(effective_sample_size(flat_samples[None, ...])) / num_steps - return state, params._replace( - L=jnp.clip( - Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound - ) - ), info.num_integration_steps.sum() + return ( + state, + params._replace( + L=jnp.clip( + Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound + ) + ), + info.num_integration_steps.sum(), + ) return adaptation_L diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 3a53fa084..fb19fe67c 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -518,11 +518,7 @@ def get_inverse_mass_matrix(): inverse_mass_matrix=inverse_mass_matrix, ) - ( - _, - blackjax_mclmc_sampler_params, - _ - ) = blackjax.mclmc_find_L_and_step_size( + (_, blackjax_mclmc_sampler_params, _) = blackjax.mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, state=initial_state,