Skip to content

Commit

Permalink
bug fix; first part
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed Aug 21, 2024
1 parent 4a11236 commit 62ba7d2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
9 changes: 6 additions & 3 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ def step(iteration_state, weight_and_key):
expectation=jnp.array([x, jnp.square(x)]),
incremental_val=streaming_avg,
weight=(1 - mask) * success * params.step_size,
zero_prevention=mask,
)

return (state, params, adaptive_state, streaming_avg), None
Expand Down Expand Up @@ -243,7 +242,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
L = params.L
# determine L
sqrt_diag_cov = params.sqrt_diag_cov
if num_steps2 != 0.0:
if num_steps2 > 1:
x_average, x_squared_average = average[0], average[1]
variances = x_squared_average - jnp.square(x_average)
L = jnp.sqrt(jnp.sum(variances))
Expand All @@ -260,6 +259,9 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
xs=(jnp.ones(steps), keys), state=state, params=params
)

jax.debug.print(
"params {x}", x=MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov)
)
return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov)

return L_step_size_adaptation
Expand Down Expand Up @@ -304,7 +306,8 @@ def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_ch

reduced_step_size = 0.8
p, unravel_fn = ravel_pytree(next_state.position)
nonans = jnp.all(jnp.isfinite(p))
q, unravel_fn = ravel_pytree(next_state.momentum)
nonans = jnp.logical_and(jnp.all(jnp.isfinite(p)), jnp.all(jnp.isfinite(q)))
state, step_size, kinetic_change = jax.tree_util.tree_map(
lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old),
(next_state, step_size_max, kinetic_change),
Expand Down
12 changes: 9 additions & 3 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ def transform(state_and_incremental_val, info):
return SamplingAlgorithm(init_fn, update_fn), transform


def safediv(x, y):
return jnp.where(x == 0.0, 0.0, x / y)


def incremental_value_update(
expectation, incremental_val, weight=1.0, zero_prevention=0.0
):
Expand All @@ -302,10 +306,12 @@ def incremental_value_update(

total, average = incremental_val
average = tree_map(
lambda exp, av: (total * av + weight * exp)
/ (total + weight + zero_prevention),
lambda exp, av: safediv(
total * av + weight * exp, (total + weight + zero_prevention)
),
expectation,
average,
)
total += weight
return total, average
incremental_val = total, average
return incremental_val

0 comments on commit 62ba7d2

Please sign in to comment.