Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjusted MCLMC #771

Merged
merged 14 commits on
Jan 21, 2025
2 changes: 2 additions & 0 deletions blackjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .diagnostics import effective_sample_size as ess
from .diagnostics import potential_scale_reduction as rhat
from .mcmc import adjusted_mclmc as _adjusted_mclmc
from .mcmc import adjusted_mclmc_dynamic as _adjusted_mclmc_dynamic
from .mcmc import barker
from .mcmc import dynamic_hmc as _dynamic_hmc
from .mcmc import elliptical_slice as _elliptical_slice
Expand Down Expand Up @@ -113,6 +114,7 @@ def generate_top_level_api_from(module):
additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk)

mclmc = generate_top_level_api_from(_mclmc)
adjusted_mclmc_dynamic = generate_top_level_api_from(_adjusted_mclmc_dynamic)
adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc)
elliptical_slice = generate_top_level_api_from(_elliptical_slice)
ghmc = generate_top_level_api_from(_ghmc)
Expand Down
73 changes: 52 additions & 21 deletions blackjax/adaptation/adjusted_mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,20 @@ def adjusted_mclmc_find_L_and_step_size(
dim = pytree_size(state.position)
if params is None:
params = MCLMCAdaptationState(
jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, sqrt_diag_cov=jnp.ones((dim,))
jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, inverse_mass_matrix=jnp.ones((dim,))
)

part1_key, part2_key = jax.random.split(rng_key, 2)

total_num_tuning_integrator_steps = 0
for i in range(num_windows):
window_key = jax.random.fold_in(part1_key, i)
(state, params, eigenvector) = adjusted_mclmc_make_L_step_size_adaptation(
(
state,
params,
eigenvector,
num_tuning_integrator_steps,
) = adjusted_mclmc_make_L_step_size_adaptation(
kernel=mclmc_kernel,
dim=dim,
frac_tune1=frac_tune1,
Expand All @@ -90,22 +96,38 @@ def adjusted_mclmc_find_L_and_step_size(
diagonal_preconditioning=diagonal_preconditioning,
max=max,
tuning_factor=tuning_factor,
)(state, params, num_steps, window_key)
)(
state, params, num_steps, window_key
)
total_num_tuning_integrator_steps += num_tuning_integrator_steps

if frac_tune3 != 0:
for i in range(num_windows):
part2_key = jax.random.fold_in(part2_key, i)
part2_key1, part2_key2 = jax.random.split(part2_key, 2)

state, params = adjusted_mclmc_make_adaptation_L(
(
state,
params,
num_tuning_integrator_steps,
) = adjusted_mclmc_make_adaptation_L(
mclmc_kernel,
frac=frac_tune3,
Lfactor=0.5,
max=max,
eigenvector=eigenvector,
)(state, params, num_steps, part2_key1)
)(
state, params, num_steps, part2_key1
)

total_num_tuning_integrator_steps += num_tuning_integrator_steps

(state, params, _) = adjusted_mclmc_make_L_step_size_adaptation(
(
state,
params,
_,
num_tuning_integrator_steps,
) = adjusted_mclmc_make_L_step_size_adaptation(
kernel=mclmc_kernel,
dim=dim,
frac_tune1=frac_tune1,
Expand All @@ -115,9 +137,13 @@ def adjusted_mclmc_find_L_and_step_size(
diagonal_preconditioning=diagonal_preconditioning,
max=max,
tuning_factor=tuning_factor,
)(state, params, num_steps, part2_key2)
)(
state, params, num_steps, part2_key2
)

total_num_tuning_integrator_steps += num_tuning_integrator_steps

return state, params
return state, params, total_num_tuning_integrator_steps


def adjusted_mclmc_make_L_step_size_adaptation(
Expand Down Expand Up @@ -152,7 +178,7 @@ def step(iteration_state, weight_and_key):
state=previous_state,
avg_num_integration_steps=avg_num_integration_steps,
step_size=params.step_size,
sqrt_diag_cov=params.sqrt_diag_cov,
inverse_mass_matrix=params.inverse_mass_matrix,
)

# step updating
Expand Down Expand Up @@ -256,6 +282,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
update_da=update_da,
)

num_tuning_integrator_steps = info.num_integration_steps.sum()
final_stepsize = final_da(dual_avg_state)
params = params._replace(step_size=final_stepsize)

Expand Down Expand Up @@ -283,9 +310,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
L=params.L * change, step_size=params.step_size * change
)
if diagonal_preconditioning:
params = params._replace(
sqrt_diag_cov=jnp.sqrt(variances), L=jnp.sqrt(dim)
)
params = params._replace(inverse_mass_matrix=variances, L=jnp.sqrt(dim))

initial_da, update_da, final_da = dual_averaging_adaptation(target=target)
(
Expand All @@ -301,9 +326,11 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
initial_da=initial_da,
)

num_tuning_integrator_steps += info.num_integration_steps.sum()

params = params._replace(step_size=final_da(dual_avg_state))

return state, params, eigenvector
return state, params, eigenvector, num_tuning_integrator_steps

return L_step_size_adaptation

Expand All @@ -318,16 +345,16 @@ def adaptation_L(state, params, num_steps, key):
adaptation_L_keys = jax.random.split(key, num_steps)

def step(state, key):
next_state, _ = kernel(
next_state, info = kernel(
rng_key=key,
state=state,
step_size=params.step_size,
avg_num_integration_steps=params.L / params.step_size,
sqrt_diag_cov=params.sqrt_diag_cov,
inverse_mass_matrix=params.inverse_mass_matrix,
)
return next_state, next_state.position
return next_state, (next_state.position, info)

state, samples = jax.lax.scan(
state, (samples, info) = jax.lax.scan(
f=step,
init=state,
xs=adaptation_L_keys,
Expand All @@ -348,10 +375,14 @@ def step(state, key):
# number of effective samples per 1 actual sample
ess = contract(effective_sample_size(flat_samples[None, ...])) / num_steps

return state, params._replace(
L=jnp.clip(
Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound
)
return (
state,
params._replace(
L=jnp.clip(
Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound
)
),
info.num_integration_steps.sum(),
)

return adaptation_L
Expand Down
30 changes: 15 additions & 15 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ class MCLMCAdaptationState(NamedTuple):
The momentum decoherence rate for the MCLMC algorithm.
step_size
The step size used for the MCLMC algorithm.
sqrt_diag_cov
inverse_mass_matrix
A matrix used for preconditioning.
"""

L: float
step_size: float
sqrt_diag_cov: float
inverse_mass_matrix: float


def mclmc_find_L_and_step_size(
Expand Down Expand Up @@ -87,10 +87,10 @@ def mclmc_find_L_and_step_size(
Example
-------
.. code::
kernel = lambda sqrt_diag_cov : blackjax.mcmc.mclmc.build_kernel(
kernel = lambda inverse_mass_matrix : blackjax.mcmc.mclmc.build_kernel(
logdensity_fn=logdensity_fn,
integrator=integrator,
sqrt_diag_cov=sqrt_diag_cov,
inverse_mass_matrix=inverse_mass_matrix,
)

(
Expand All @@ -106,7 +106,7 @@ def mclmc_find_L_and_step_size(
"""
dim = pytree_size(state.position)
params = MCLMCAdaptationState(
jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov=jnp.ones((dim,))
jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,))
)
part1_key, part2_key = jax.random.split(rng_key, 2)

Expand All @@ -123,10 +123,10 @@ def mclmc_find_L_and_step_size(

if frac_tune3 != 0:
state, params = make_adaptation_L(
mclmc_kernel(params.sqrt_diag_cov), frac=frac_tune3, Lfactor=0.4
mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=0.4
)(state, params, num_steps, part2_key)

return state, params
return state, params, num_steps * (frac_tune1 + frac_tune2 + frac_tune3)


def make_L_step_size_adaptation(
Expand All @@ -152,7 +152,7 @@ def predictor(previous_state, params, adaptive_state, rng_key):
rng_key, nan_key = jax.random.split(rng_key)

# dynamics
next_state, info = kernel(params.sqrt_diag_cov)(
next_state, info = kernel(params.inverse_mass_matrix)(
rng_key=rng_key,
state=previous_state,
L=params.L,
Expand Down Expand Up @@ -247,15 +247,15 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):

L = params.L
# determine L
sqrt_diag_cov = params.sqrt_diag_cov
inverse_mass_matrix = params.inverse_mass_matrix
if num_steps2 > 1:
x_average, x_squared_average = average[0], average[1]
variances = x_squared_average - jnp.square(x_average)
L = jnp.sqrt(jnp.sum(variances))

if diagonal_preconditioning:
sqrt_diag_cov = jnp.sqrt(variances)
params = params._replace(sqrt_diag_cov=sqrt_diag_cov)
inverse_mass_matrix = variances
params = params._replace(inverse_mass_matrix=inverse_mass_matrix)
L = jnp.sqrt(dim)

# readjust the stepsize
Expand All @@ -265,7 +265,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
xs=(jnp.ones(steps), keys), state=state, params=params
)

return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov)
return state, MCLMCAdaptationState(L, params.step_size, inverse_mass_matrix)

return L_step_size_adaptation

Expand All @@ -274,8 +274,8 @@ def make_adaptation_L(kernel, frac, Lfactor):
"""determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)"""

def adaptation_L(state, params, num_steps, key):
num_steps = int(num_steps * frac)
adaptation_L_keys = jax.random.split(key, num_steps)
num_steps_3 = int(num_steps * frac)
adaptation_L_keys = jax.random.split(key, num_steps_3)

def step(state, key):
next_state, _ = kernel(
Expand All @@ -297,7 +297,7 @@ def step(state, key):
ess = effective_sample_size(flat_samples[None, ...])

return state, params._replace(
L=Lfactor * params.step_size * jnp.mean(num_steps / ess)
L=Lfactor * params.step_size * jnp.mean(num_steps_3 / ess)
)

return adaptation_L
Expand Down
2 changes: 2 additions & 0 deletions blackjax/mcmc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import (
adjusted_mclmc,
adjusted_mclmc_dynamic,
barker,
elliptical_slice,
ghmc,
Expand All @@ -25,5 +26,6 @@
"marginal_latent_gaussian",
"random_walk",
"mclmc",
"adjusted_mclmc_dynamic",
"adjusted_mclmc",
]
Loading
Loading