diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py
index 7645a890b..eb11c7bf6 100644
--- a/blackjax/adaptation/mclmc_adaptation.py
+++ b/blackjax/adaptation/mclmc_adaptation.py
@@ -203,7 +203,6 @@ def step(iteration_state, weight_and_key):
             expectation=jnp.array([x, jnp.square(x)]),
             incremental_val=streaming_avg,
             weight=(1 - mask) * success * params.step_size,
-            zero_prevention=mask,
         )
 
         return (state, params, adaptive_state, streaming_avg), None
@@ -243,7 +242,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
         L = params.L
         # determine L
         sqrt_diag_cov = params.sqrt_diag_cov
-        if num_steps2 != 0.0:
+        if num_steps2 > 1:
             x_average, x_squared_average = average[0], average[1]
             variances = x_squared_average - jnp.square(x_average)
             L = jnp.sqrt(jnp.sum(variances))
@@ -260,6 +259,9 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
                 xs=(jnp.ones(steps), keys), state=state, params=params
             )
 
+        jax.debug.print(
+            "params {x}", x=MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov)
+        )
         return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov)
 
     return L_step_size_adaptation
@@ -304,7 +306,8 @@ def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_ch
 
     reduced_step_size = 0.8
     p, unravel_fn = ravel_pytree(next_state.position)
-    nonans = jnp.all(jnp.isfinite(p))
+    q, unravel_fn = ravel_pytree(next_state.momentum)
+    nonans = jnp.logical_and(jnp.all(jnp.isfinite(p)), jnp.all(jnp.isfinite(q)))
     state, step_size, kinetic_change = jax.tree_util.tree_map(
         lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old),
         (next_state, step_size_max, kinetic_change),
diff --git a/blackjax/util.py b/blackjax/util.py
index b6c5367b5..2eae8f7b2 100644
--- a/blackjax/util.py
+++ b/blackjax/util.py
@@ -281,6 +281,10 @@ def transform(state_and_incremental_val, info):
     return SamplingAlgorithm(init_fn, update_fn), transform
 
 
+def safediv(x, y):
+    return jnp.where(x == 0.0, 0.0, x / y)
+
+
 def incremental_value_update(
     expectation, incremental_val, weight=1.0, zero_prevention=0.0
 ):
@@ -302,10 +306,12 @@ def incremental_value_update(
     total, average = incremental_val
     average = tree_map(
-        lambda exp, av: (total * av + weight * exp)
-        / (total + weight + zero_prevention),
+        lambda exp, av: safediv(
+            total * av + weight * exp, (total + weight + zero_prevention)
+        ),
         expectation,
         average,
     )
     total += weight
 
-    return total, average
+    incremental_val = total, average
+    return incremental_val
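
A quick sanity check on the `safediv` change (not part of the patch; the toy inputs below are illustrative assumptions): when every weight is masked to zero, the old expression divided 0 by 0 unless `zero_prevention` was supplied at the call site, whereas the patched update now simply returns 0 for the running average.

```python
# Standalone sketch of the patched incremental_value_update; safediv and the
# update rule are copied from the diff above, the example inputs are made up.
import jax.numpy as jnp
from jax.tree_util import tree_map


def safediv(x, y):
    # Return 0 where the numerator is 0, avoiding a 0/0 -> nan.
    return jnp.where(x == 0.0, 0.0, x / y)


def incremental_value_update(
    expectation, incremental_val, weight=1.0, zero_prevention=0.0
):
    total, average = incremental_val
    average = tree_map(
        lambda exp, av: safediv(
            total * av + weight * exp, (total + weight + zero_prevention)
        ),
        expectation,
        average,
    )
    total += weight
    return total, average


# weight == 0 and total == 0 (e.g. the adaptation mask zeroed this step out):
# previously this required zero_prevention=mask to avoid nan; now it yields 0.
total, average = incremental_value_update(
    expectation=jnp.array([1.0, 1.0]),
    incremental_val=(0.0, jnp.array([0.0, 0.0])),
    weight=0.0,
)
print(total, average)  # 0.0 [0. 0.]
```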
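Similarly, a minimal sketch (again not from the patch; the `State` namedtuple and the toy values are assumptions) of the strengthened check in `handle_nans`, which now flattens and inspects the momentum as well as the position:

```python
# Illustrative check mirroring the new nonans condition in handle_nans.
from collections import namedtuple

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

State = namedtuple("State", ["position", "momentum"])


def no_nans(state):
    # Flatten both pytrees and require every entry to be finite.
    p, _ = ravel_pytree(state.position)
    q, _ = ravel_pytree(state.momentum)
    return jnp.logical_and(jnp.all(jnp.isfinite(p)), jnp.all(jnp.isfinite(q)))


good = State(position=jnp.array([0.1, 0.2]), momentum=jnp.array([1.0, -1.0]))
bad = State(position=jnp.array([0.1, 0.2]), momentum=jnp.array([jnp.nan, -1.0]))
print(no_nans(good), no_nans(bad))  # True False
```

Before this change a non-finite momentum paired with a finite position would have been accepted; with the `jnp.logical_and` the state is rolled back in that case as well.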