This repository contains the code for Windowed Non-Linear Reparameterization, built on top of BlackJAX. This work was done as part of a summer internship at Aalto University under the guidance of Dr. Nikolas Siccha and Prof. Aki Vehtari. A presentation explaining the concepts involved in the proposed algorithm is also available.
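```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import blackjax  # requires the modified BlackJAX from this repository

# `d` (number of groups) and `true_centeredness` must be defined before this point.

def ll_pdf(centeredness, J, tau, mu, theta):
    # Log-density correction: swaps the centered prior on theta for one with the given centeredness.
    return dist.Normal(jnp.full(J, mu * (1 - centeredness)), jnp.exp(tau * (1 - centeredness))).log_prob(theta) - dist.Normal(jnp.full(J, mu), jnp.exp(tau)).log_prob(theta)

def funnel(J=d, c=true_centeredness):
    mu = numpyro.sample('mu', dist.Normal(0, 1))
    tau = numpyro.sample('tau', dist.Normal(0, 1))
    theta = numpyro.sample('theta', dist.Normal(jnp.full(J, mu), jnp.exp(tau)))
    numpyro.factor('theta_ll', ll_pdf(c, J, tau, mu, theta))

num_warmup = 1000
adapt = blackjax.window_adaptation(blackjax.nuts, funnel)
key = jax.random.PRNGKey(0)
(last_state, parameters), intermediate_states, logdensity_fn, estimated_centeredness, num_warmup_steps, num_evals = adapt.run(key, num_warmup)
```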
At times MCMC algorithms have trouble sampling from distributions. One such example is Neal's funnel, where sampling is difficult because of the strong non-linear dependence between latent variables. Non-centering the model removes this dependence, converting the funnel into a spherical Gaussian distribution.
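The dependence described above can be illustrated numerically. The following is a minimal sketch using only the Python standard library: it draws from a funnel shaped like the one in the code above (`mu`, `tau`, `theta`) and compares how strongly the scale of `theta` depends on `tau` before and after non-centering. The helper names (`sample_funnel`, `pearson`) and the choice of diagnostic are illustrative assumptions, not part of this repository.

```python
import math
import random

def sample_funnel(n, seed=0):
    """Draw (mu, tau, theta) from the centered funnel: theta ~ N(mu, exp(tau))."""
    rng = random.Random(seed)
    draws = []
    for _ in range(n):
        mu = rng.gauss(0.0, 1.0)
        tau = rng.gauss(0.0, 1.0)
        theta = rng.gauss(mu, math.exp(tau))
        draws.append((mu, tau, theta))
    return draws

def pearson(xs, ys):
    """Plain Pearson correlation coefficient."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    return cov / math.sqrt(vx * vy)

draws = sample_funnel(5000)
taus = [tau for _, tau, _ in draws]

# Centered coordinates: the spread of theta around mu grows with tau.
centered_scale = [math.log(abs(theta - mu)) for mu, _, theta in draws]
centered_corr = pearson(taus, centered_scale)

# Non-centered coordinates: z = (theta - mu) / exp(tau) is standard normal
# and its spread no longer depends on tau.
non_centered_scale = [math.log(abs((theta - mu) / math.exp(tau))) for mu, tau, theta in draws]
non_centered_corr = pearson(taus, non_centered_scale)

print(f"centered: {centered_corr:.2f}, non-centered: {non_centered_corr:.2f}")
```

In the centered coordinates the correlation is clearly positive, while after non-centering it is close to zero, which is the sense in which the funnel geometry disappears.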
- The best parameterization for a given model may lie somewhere between the centered and non-centered representations.
- Existing solutions:
  - Require separate pre-processing steps apart from the regular warmup and sampling steps, which increases the computational cost.
  - Need their hyperparameters tuned to give good results.
- Find the optimal centeredness during the successive windows of warmup.
- The loss function for finding the centeredness should bring the reparameterized distribution as close as possible to a normal distribution.
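One candidate loss that meets this requirement is the KL divergence from the reparameterized distribution to a standard normal. This is offered as an assumption for illustration, not as the loss implemented in this repository. For a parameter $\theta$ with location $\mu$ and log-scale $\tau$, assume the partial-centering transform

$$
z_c = \frac{\theta - (1-c)\,\mu}{\exp\big((1-c)\,\tau\big)}, \qquad \log\left|\frac{\partial z_c}{\partial \theta}\right| = -(1-c)\,\tau .
$$

Since $\mu$ and $\tau$ are left unchanged, the density of the transformed draws is $q(z_c) = p(\theta)\,\exp\big((1-c)\,\tau\big)$. Dropping terms that do not depend on $c$ gives

$$
L(c) = \mathbb{E}\left[\tfrac{1}{2}\, z_c^{2} + (1-c)\,\tau\right],
$$

which can be estimated by averaging over the draws collected in a window and minimized over $c \in [0, 1]$.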
- Standard window adaptation is used to adapt the inverse mass matrix ($M^{-1}$) and the time step size ($\Delta t$).
- It consists of three stages:
  - Initial buffer (I): time step adaptation ($\Delta t$)
  - Window buffer (II): both time step ($\Delta t$) and inverse mass matrix ($M^{-1}$) adaptation
  - Term buffer (III): final time step adaptation ($\Delta t$)
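The three stages can be laid out as a schedule over the warmup iterations. The sketch below assumes Stan's default sizes (initial buffer 75, term buffer 50, first window 25, each window twice as long as the previous one). These numbers, the stretching rule for the last window, and the function name `build_schedule` are illustrative assumptions, not values taken from this repository.

```python
def build_schedule(num_steps, initial_buffer=75, term_buffer=50, first_window=25):
    """Return a list of (stage, start, end) tuples covering num_steps warmup iterations."""
    if num_steps < initial_buffer + first_window + term_buffer:
        raise ValueError("num_steps is too small for the chosen buffer sizes")

    schedule = [("initial", 0, initial_buffer)]  # stage I: step size only
    slow_end = num_steps - term_buffer

    start, size = initial_buffer, first_window
    while start < slow_end:
        end = start + size
        # If the next (doubled) window would not fit, stretch this one
        # so that it fills the rest of stage II.
        if end + 2 * size > slow_end:
            end = slow_end
        schedule.append(("window", start, end))  # stage II: step size + mass matrix
        start, size = end, 2 * size

    schedule.append(("term", slow_end, num_steps))  # stage III: final step size
    return schedule

for stage, start, end in build_schedule(1000):
    print(f"{stage:8s} {start:4d} - {end:4d}")
```

With 1000 warmup steps this yields five window buffers of growing length, so quantities adapted per window are refined on progressively more samples.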
- The proposed window adaptation is used to adapt the inverse mass matrix ($M^{-1}$), the time step size ($\Delta t$) and the centeredness ($c$).
- The initial buffer and term buffer remain the same.
- Using the samples obtained after each window buffer, optimize the centeredness ($c$) so as to reduce the distance between the present reparameterized distribution and an independent normal distribution.
- For each successive window, reparameterize the model based on the optimal centeredness obtained and repeat the step for finding the optimal centeredness.
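The per-window optimization step can be sketched as follows, using only the Python standard library. Several things here are assumptions made for illustration and may differ from the repository's implementation: the draws are simulated directly from a model with known centeredness instead of coming from NUTS on the reparameterized model, the loss is a sample estimate of the KL divergence to a standard normal (up to a constant), and a grid search stands in for the optimizer. All function names are hypothetical.

```python
import math
import random

def simulate_draws(n, true_c, rng):
    """Stand-in for one window of sampler output: draws of (mu, tau, x)."""
    a = 1.0 - true_c
    draws = []
    for _ in range(n):
        mu = rng.gauss(0.0, 1.0)
        tau = rng.gauss(0.0, 1.0)  # tau = log(sigma)
        x = rng.gauss(a * mu, math.exp(a * tau))
        draws.append((mu, tau, x))
    return draws

def loss(c, draws):
    """Assumed loss: mean of 0.5 * z_c**2 + (1 - c) * tau over the draws."""
    a = 1.0 - c
    total = 0.0
    for mu, tau, x in draws:
        z = (x - a * mu) / math.exp(a * tau)  # candidate reparameterization
        total += 0.5 * z * z + a * tau        # -log N(z) minus log-Jacobian, up to a constant
    return total / len(draws)

def best_centeredness(draws, grid_size=101):
    """Grid search over c in [0, 1]; a gradient-based optimizer could be used instead."""
    grid = [i / (grid_size - 1) for i in range(grid_size)]
    return min(grid, key=lambda c: loss(c, draws))

rng = random.Random(0)
true_c = 0.3
window_sizes = [25, 50, 100, 200, 500]  # assumed window lengths

estimates = []
for size in window_sizes:
    window_draws = simulate_draws(size, true_c, rng)
    estimates.append(best_centeredness(window_draws))

print(estimates)
```

Because later windows contain more draws, the estimate from the last window is expected to lie closest to the true value.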
- Modified the BlackJAX sampler code to incorporate the proposed algorithm.
- Inference time increased from 2 seconds to 20 seconds.
- To evaluate the results, we used a model whose true centeredness is already known:
  - $\mu \sim N(0,1)$
  - $\log \sigma \sim N(0,1)$
  - $x \sim N((1-c_{i}) \mu, \sigma^{(1-c_{i})})$
- The image compares the centeredness achieved by our method with that of Variationally Inferred Parameterisation (VIP). Our method converged to the true centeredness of the model.
- The effective sample size per gradient evaluation (ESS/∇) of our method is higher than that of VIP for τ and θ, and lower for μ:

| Parameter | ESS/∇ (VIP) | ESS/∇ (ours) |
| --- | --- | --- |
| μ | 9.27e-4 | 5.54e-4 |
| τ | 4.22e-5 | 5.07e-4 |
| θ | 9.73e-4 | 1.66e-3 |
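For readers unfamiliar with the metric, the sketch below shows one way ESS per gradient evaluation can be computed for a single chain. It uses a simplified autocorrelation-based estimator with Geyer-style truncation and assumes a positively correlated chain. The estimator behind the table above is not stated, so this is an illustrative assumption, and the gradient count used in the example is a made-up number.

```python
import random

def autocorrelation(chain, lag, mean, var):
    """Biased sample autocorrelation at the given lag."""
    n = len(chain)
    total = sum((chain[i] - mean) * (chain[i + lag] - mean) for i in range(n - lag))
    return total / (n * var)

def effective_sample_size(chain):
    """ESS = n / tau, with tau = -1 + 2 * sum of (rho_2m + rho_2m+1) over positive pairs."""
    n = len(chain)
    mean = sum(chain) / n
    var = sum((x - mean) ** 2 for x in chain) / n
    tau = -1.0
    lag = 0
    while lag + 1 < n:
        pair = autocorrelation(chain, lag, mean, var) + autocorrelation(chain, lag + 1, mean, var)
        if pair <= 0.0:  # stop once the paired autocorrelations are lost in noise
            break
        tau += 2.0 * pair
        lag += 2
    return n / tau

def ess_per_gradient(chain, num_grad_evals):
    """Effective sample size divided by the total number of gradient evaluations."""
    return effective_sample_size(chain) / num_grad_evals

# Example: an AR(1) chain with coefficient 0.5, for which ESS / n is about 1/3 in theory.
rng = random.Random(0)
chain = [0.0]
for _ in range(19999):
    chain.append(0.5 * chain[-1] + rng.gauss(0.0, 1.0))

ess = effective_sample_size(chain)
hypothetical_grad_evals = 200000  # made-up count, for illustration only
print(ess / len(chain), ess_per_gradient(chain, hypothetical_grad_evals))
```

Dividing by gradient evaluations rather than by wall-clock time makes the comparison between samplers independent of hardware.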