Windowed Non-Linear Reparameterization


This repository contains the code for Windowed Non-Linear Reparameterization, built on top of Blackjax. This work was done as part of a summer internship at Aalto University under the guidance of Dr. Nikolas Siccha and Prof. Aki Vehtari. See the presentation for an explanation of the concepts involved in the proposed algorithm.

Usage Example

Creating a NumPyro model for sampling with Blackjax

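import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# `d` (dimension of theta) and `true_centeredness` are assumed to be defined beforehand.
def ll_pdf(centeredness, J, tau, mu, theta):
    # Log-density correction: target N(mu*(1-c), exp(tau*(1-c))) minus the sampled N(mu, exp(tau)).
    target = dist.Normal(jnp.full(J, mu * (1 - centeredness)), jnp.exp(tau * (1 - centeredness)))
    base = dist.Normal(jnp.full(J, mu), jnp.exp(tau))
    return target.log_prob(theta) - base.log_prob(theta)

def funnel(J=d, c=true_centeredness):
    mu = numpyro.sample('mu', dist.Normal(0, 1))
    tau = numpyro.sample('tau', dist.Normal(0, 1))
    theta = numpyro.sample('theta', dist.Normal(jnp.full(J, mu), jnp.exp(tau)))
    numpyro.factor('theta_ll', ll_pdf(c, J, tau, mu, theta))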

Running the inference to obtain the estimated centeredness

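import blackjax
import jax

num_warmup = 1000

# window_adaptation is this repository's modified version, which also returns the estimated centeredness.
adapt = blackjax.window_adaptation(blackjax.nuts, funnel)
key = jax.random.PRNGKey(0)
(last_state, parameters), intermediate_states, logdensity_fn, estimated_centeredness, num_warmup_steps, num_evals = adapt.run(key, num_warmup)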

Why:

At times, MCMC algorithms have trouble sampling from a distribution. One such example is Neal’s funnel, which is hard to sample because of the strong non-linear dependence between its latent variables. Non-centering the model removes this dependence, converting the funnel into a spherical Gaussian distribution.
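The effect of non-centering can be checked numerically. The sketch below is not part of this repository: it draws from a one-dimensional funnel using only the Python standard library, and the helper names `sample_funnel` and `pearson` are made up for illustration. It measures how strongly the log-magnitude of the deviation depends on `tau` before and after non-centering.

```python
import math
import random
import statistics


def sample_funnel(n, seed=0):
    """Draw (mu, tau, theta) with mu, tau ~ N(0, 1) and theta ~ N(mu, exp(tau))."""
    rng = random.Random(seed)
    draws = []
    for _ in range(n):
        mu = rng.gauss(0.0, 1.0)
        tau = rng.gauss(0.0, 1.0)
        theta = rng.gauss(mu, math.exp(tau))
        draws.append((mu, tau, theta))
    return draws


def pearson(xs, ys):
    """Plain Pearson correlation coefficient of two equally long sequences."""
    mx, my = statistics.fmean(xs), statistics.fmean(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    var_x = sum((x - mx) ** 2 for x in xs)
    var_y = sum((y - my) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


draws = sample_funnel(5000)
taus = [tau for _, tau, _ in draws]

# Centered: the spread of theta around mu grows with tau (the funnel shape).
centered = [math.log(abs(theta - mu)) for mu, _, theta in draws]

# Non-centered: theta_raw = (theta - mu) / exp(tau) is a standard normal draw.
non_centered = [
    math.log(abs((theta - mu) / math.exp(tau))) for mu, tau, theta in draws
]

corr_centered = pearson(taus, centered)
corr_non_centered = pearson(taus, non_centered)
```

In the centered form the correlation with `tau` is clearly positive, while in the non-centered form it should be close to zero, which is what makes the non-centered geometry easier for a sampler.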

Centeredness vs Non-centeredness:

  • The best parameterization for a given model may lie somewhere between the centered and non-centered representations.

  • Existing solutions:

    • Variationally Inferred Parameterization[1]

    • NeuTra-lizing Bad Geometry in HMC[2]
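A parameterization "between" centered and non-centered can be written as a one-parameter family. The sketch below uses the partial-centering map as I understand it from VIP [1], with c = 1 meaning fully centered and c = 0 fully non-centered; the function names are hypothetical and the convention used inside this repository may differ.

```python
def to_partial(theta, mu, sigma, c):
    """Map a centered value theta (prior N(mu, sigma)) to centeredness c.

    c = 1 returns theta unchanged, c = 0 returns (theta - mu) / sigma.
    """
    return c * mu + (theta - mu) / sigma ** (1.0 - c)


def from_partial(theta_c, mu, sigma, c):
    """Inverse of to_partial: recover the centered value theta."""
    return mu + sigma ** (1.0 - c) * (theta_c - c * mu)


# Example values (arbitrary): theta = 3, mu = 1, sigma = 2.
fully_centered = to_partial(3.0, 1.0, 2.0, c=1.0)
fully_non_centered = to_partial(3.0, 1.0, 2.0, c=0.0)
halfway = to_partial(3.0, 1.0, 2.0, c=0.5)
```

Under this map, a value with prior N(mu, sigma) has prior N(c·mu, sigma^c) after the transform, so c interpolates continuously between the two extremes.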

Problems with existing solutions

  • They require separate pre-processing steps in addition to the regular warmup and sampling steps, which increases the computational cost.
  • Their hyperparameters need to be tuned to get good results.

Proposed Solution

  • Find the optimal centeredness during the successive windows of warmup.
  • The loss function for finding the centeredness should bring the reparameterized distribution as close as possible to a Normal distribution.

Warmup Phase

  • Used for adaptation of inverse mass matrix ($M^{-1}$) and time step size ($\Delta t$).

  • Consists of three stages:

    image
  • Initial buffer (I): time step adaptation ($\Delta t$)

  • Window buffer (II): both time step ($\Delta t$) and inverse mass matrix ($M^{-1}$) adaptation

  • Term buffer (III): final time step adaptation ($\Delta t$)
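The three stages can be laid out as index ranges over the warmup iterations. The sketch below is a simplified, Stan-style schedule, not code from this repository: the defaults of 75, 50 and 25 iterations and the doubling of window sizes are assumptions based on Stan's documented warmup, and `warmup_schedule` is a hypothetical helper name.

```python
def warmup_schedule(num_warmup, initial_buffer=75, term_buffer=50, base_window=25):
    """Return a list of (stage, start, end) tuples covering range(num_warmup)."""
    if num_warmup < initial_buffer + base_window + term_buffer:
        raise ValueError("num_warmup is too small for the requested buffers")

    slow_end = num_warmup - term_buffer
    schedule = [("initial buffer (I)", 0, initial_buffer)]

    start, size = initial_buffer, base_window
    while start < slow_end:
        end = start + size
        # Stretch this window to the end of stage II when the next,
        # doubled window would no longer fit.
        if end + 2 * size > slow_end:
            end = slow_end
        schedule.append(("window (II)", start, end))
        start, size = end, 2 * size

    schedule.append(("term buffer (III)", slow_end, num_warmup))
    return schedule


schedule = warmup_schedule(1000)
for stage, start, end in schedule:
    print(f"{stage:>20}: iterations {start} to {end - 1}")
```

Each stage II window ends with an update of the inverse mass matrix, so longer windows later in warmup give that estimate more samples to work with.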

Modified Warmup Phase

  • Used for adaptation of the inverse mass matrix ($M^{-1}$), the time step size ($\Delta t$) and the centeredness ($c$).
  • The initial buffer and the term buffer remain the same.
  • Using the samples obtained after each window buffer, optimize the centeredness ($c$) so as to reduce the distance between the present reparameterized distribution and an independent normal distribution.
  • For each successive window, reparameterize the model based on the optimal centeredness obtained and repeat the step for finding the optimal centeredness.
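The per-window optimization step can be illustrated on draws from the test model described under Implementation. Everything below is an assumption made for this sketch rather than the repository's implementation: the reparameterization family z_c = (x − (1 − c)μ) / exp((1 − c)τ), the loss (a Monte Carlo estimate, up to a constant, of the KL divergence from the reparameterized draws to a standard normal, which adds the log-Jacobian of the transform), and the use of a plain grid search in place of a gradient-based optimizer. The draws are generated directly here instead of coming from a warmup window.

```python
import math
import random


def draw_test_model(n, true_c, seed=0):
    """Draw (mu, tau, x): mu, tau ~ N(0, 1), x ~ N((1-c) mu, exp((1-c) tau))."""
    rng = random.Random(seed)
    draws = []
    for _ in range(n):
        mu = rng.gauss(0.0, 1.0)
        tau = rng.gauss(0.0, 1.0)
        x = rng.gauss((1.0 - true_c) * mu, math.exp((1.0 - true_c) * tau))
        draws.append((mu, tau, x))
    return draws


def loss(c, draws):
    """Estimate KL(reparameterized draws || N(0, 1)) up to an additive constant."""
    total = 0.0
    for mu, tau, x in draws:
        log_scale = (1.0 - c) * tau
        z = (x - (1.0 - c) * mu) / math.exp(log_scale)
        # 0.5 * z^2 is -log N(z; 0, 1) up to a constant; log_scale is the
        # log-Jacobian term of the map from x to z.
        total += 0.5 * z * z + log_scale
    return total / len(draws)


def estimate_centeredness(draws, grid_size=21):
    """Pick the grid value of c in [0, 1] with the smallest loss."""
    grid = [i / (grid_size - 1) for i in range(grid_size)]
    return min(grid, key=lambda c: loss(c, draws))


draws = draw_test_model(5000, true_c=0.5)
c_hat = estimate_centeredness(draws)
```

When c matches the true centeredness, z_c is a standard normal draw independent of μ and τ, so the loss has its minimum there and `c_hat` should land on or next to the true value.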

Implementation

  • Modified the Blackjax sampler code to incorporate the proposed algorithm.
  • Inference time increased from 2 seconds to 20 seconds.
  • To evaluate the results, we used a model whose true centeredness is already known:
    • $\mu \sim N(0,1)$
    • $\log \sigma \sim N(0,1)$
    • $x \sim N((1-c_{i}) \mu, \sigma^{(1-c_{i})})$

Results

  • The image shows a comparison of the centeredness achieved by our method versus Variationally Inferred Parameterisation (VIP). Our method converges to the true centeredness of the model.

    image
  • The effective sample size per gradient evaluation (ESS/∇) of our method is higher than VIP's for τ (about 12×) and θ (about 1.7×), but lower for μ (about 0.6×).

    | Parameter | ESS/∇ (VIP) | ESS/∇ (ours) |
    |-----------|-------------|--------------|
    | μ         | 9.27e-4     | 5.54e-4      |
    | τ         | 4.22e-5     | 5.07e-4      |
    | θ         | 9.73e-4     | 1.66e-3      |