Skip to content

Commit

Permalink
test CI: old tests with addition of num tuning steps
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed Jan 21, 2025
1 parent 85fe088 commit f65a4b2
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 23 deletions.
63 changes: 48 additions & 15 deletions blackjax/adaptation/adjusted_mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,15 @@ def adjusted_mclmc_find_L_and_step_size(

part1_key, part2_key = jax.random.split(rng_key, 2)

total_num_tuning_integrator_steps = 0
for i in range(num_windows):
window_key = jax.random.fold_in(part1_key, i)
(state, params, eigenvector) = adjusted_mclmc_make_L_step_size_adaptation(
(
state,
params,
eigenvector,
num_tuning_integrator_steps,
) = adjusted_mclmc_make_L_step_size_adaptation(
kernel=mclmc_kernel,
dim=dim,
frac_tune1=frac_tune1,
Expand All @@ -90,22 +96,38 @@ def adjusted_mclmc_find_L_and_step_size(
diagonal_preconditioning=diagonal_preconditioning,
max=max,
tuning_factor=tuning_factor,
)(state, params, num_steps, window_key)
)(
state, params, num_steps, window_key
)
total_num_tuning_integrator_steps += num_tuning_integrator_steps

if frac_tune3 != 0:
for i in range(num_windows):
part2_key = jax.random.fold_in(part2_key, i)
part2_key1, part2_key2 = jax.random.split(part2_key, 2)

state, params = adjusted_mclmc_make_adaptation_L(
(
state,
params,
num_tuning_integrator_steps,
) = adjusted_mclmc_make_adaptation_L(
mclmc_kernel,
frac=frac_tune3,
Lfactor=0.5,
max=max,
eigenvector=eigenvector,
)(state, params, num_steps, part2_key1)
)(
state, params, num_steps, part2_key1
)

(state, params, _) = adjusted_mclmc_make_L_step_size_adaptation(
total_num_tuning_integrator_steps += num_tuning_integrator_steps

(
state,
params,
_,
num_tuning_integrator_steps,
) = adjusted_mclmc_make_L_step_size_adaptation(
kernel=mclmc_kernel,
dim=dim,
frac_tune1=frac_tune1,
Expand All @@ -115,9 +137,13 @@ def adjusted_mclmc_find_L_and_step_size(
diagonal_preconditioning=diagonal_preconditioning,
max=max,
tuning_factor=tuning_factor,
)(state, params, num_steps, part2_key2)
)(
state, params, num_steps, part2_key2
)

return state, params
total_num_tuning_integrator_steps += num_tuning_integrator_steps

return state, params, total_num_tuning_integrator_steps


def adjusted_mclmc_make_L_step_size_adaptation(
Expand Down Expand Up @@ -256,6 +282,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
update_da=update_da,
)

num_tuning_integrator_steps = info.num_integration_steps.sum()
final_stepsize = final_da(dual_avg_state)
params = params._replace(step_size=final_stepsize)

Expand Down Expand Up @@ -299,9 +326,11 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
initial_da=initial_da,
)

num_tuning_integrator_steps += info.num_integration_steps.sum()

params = params._replace(step_size=final_da(dual_avg_state))

return state, params, eigenvector
return state, params, eigenvector, num_tuning_integrator_steps

return L_step_size_adaptation

Expand All @@ -316,16 +345,16 @@ def adaptation_L(state, params, num_steps, key):
adaptation_L_keys = jax.random.split(key, num_steps)

def step(state, key):
next_state, _ = kernel(
next_state, info = kernel(
rng_key=key,
state=state,
step_size=params.step_size,
avg_num_integration_steps=params.L / params.step_size,
inverse_mass_matrix=params.inverse_mass_matrix,
)
return next_state, next_state.position
return next_state, (next_state.position, info)

state, samples = jax.lax.scan(
state, (samples, info) = jax.lax.scan(
f=step,
init=state,
xs=adaptation_L_keys,
Expand All @@ -346,10 +375,14 @@ def step(state, key):
# number of effective samples per 1 actual sample
ess = contract(effective_sample_size(flat_samples[None, ...])) / num_steps

return state, params._replace(
L=jnp.clip(
Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound
)
return (
state,
params._replace(
L=jnp.clip(
Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound
)
),
info.num_integration_steps.sum(),
)

return adaptation_L
Expand Down
8 changes: 4 additions & 4 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def mclmc_find_L_and_step_size(
mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=0.4
)(state, params, num_steps, part2_key)

return state, params
return state, params, num_steps * (frac_tune1 + frac_tune2 + frac_tune3)


def make_L_step_size_adaptation(
Expand Down Expand Up @@ -274,8 +274,8 @@ def make_adaptation_L(kernel, frac, Lfactor):
"""determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)"""

def adaptation_L(state, params, num_steps, key):
num_steps = int(num_steps * frac)
adaptation_L_keys = jax.random.split(key, num_steps)
num_steps_3 = int(num_steps * frac)
adaptation_L_keys = jax.random.split(key, num_steps_3)

def step(state, key):
next_state, _ = kernel(
Expand All @@ -297,7 +297,7 @@ def step(state, key):
ess = effective_sample_size(flat_samples[None, ...])

return state, params._replace(
L=Lfactor * params.step_size * jnp.mean(num_steps / ess)
L=Lfactor * params.step_size * jnp.mean(num_steps_3 / ess)
)

return adaptation_L
Expand Down
8 changes: 4 additions & 4 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def run_mclmc(
(
blackjax_state_after_tuning,
blackjax_mclmc_sampler_params,
_,
) = blackjax.mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
Expand Down Expand Up @@ -183,6 +184,7 @@ def run_adjusted_mclmc(
(
blackjax_state_after_tuning,
blackjax_mclmc_sampler_params,
_,
) = blackjax.adjusted_mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
Expand Down Expand Up @@ -252,6 +254,7 @@ def run_adjusted_mclmc_static(
(
blackjax_state_after_tuning,
blackjax_mclmc_sampler_params,
_,
) = blackjax.adjusted_mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
Expand Down Expand Up @@ -509,10 +512,7 @@ def get_inverse_mass_matrix():
inverse_mass_matrix=inverse_mass_matrix,
)

(
_,
blackjax_mclmc_sampler_params,
) = blackjax.mclmc_find_L_and_step_size(
(_, blackjax_mclmc_sampler_params, _) = blackjax.mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
state=initial_state,
Expand Down

0 comments on commit f65a4b2

Please sign in to comment.