A simple implementation of Hamiltonian Monte Carlo in JAX.

martin-marek/mini-hmc-jax

mini-hmc-jax

This is a simple vectorized implementation of Hamiltonian Monte Carlo in JAX.

Here's a minimal example to sample from a distribution:

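```python
import jax
import jax.numpy as jnp
import jax.scipy.stats  # explicit import so jax.scipy.stats is available
import hmc  # the hmc module from this repository

# define the target distribution as a log-density over params
# (here, 10 independent Student's t variables with df=1, i.e. standard Cauchy)
def target_log_pdf(params):
    return jax.scipy.stats.t.logpdf(params, df=1).sum()

# run HMC: 10 iterations, each simulating 100 leapfrog steps of size 0.1
params_init = jnp.zeros(10)
key = jax.random.PRNGKey(0)
chain = hmc.sample(key, params_init, target_log_pdf, n_steps=10, n_leapfrog_steps=100, step_size=0.1)
```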
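To make the arguments in the call above (`n_steps`, `n_leapfrog_steps`, `step_size`) more concrete, here is a minimal, self-contained sketch of a textbook HMC transition in JAX. It is not the code in this repository's `hmc` module: it assumes a unit (identity) mass matrix and a fixed step size, and the names `leapfrog`, `hmc_step`, and `hmc_sample_sketch` are hypothetical, chosen for illustration only.

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats


def leapfrog(params, momentum, grad_fn, step_size, n_leapfrog_steps):
    """Simulate Hamiltonian dynamics with the leapfrog integrator (unit mass matrix)."""
    # initial half step for the momentum
    momentum = momentum + 0.5 * step_size * grad_fn(params)

    def body(_, state):
        p, m = state
        p = p + step_size * m           # full step for the position
        m = m + step_size * grad_fn(p)  # full step for the momentum
        return p, m

    params, momentum = jax.lax.fori_loop(0, n_leapfrog_steps - 1, body, (params, momentum))
    # last full position step, then a final half step for the momentum
    params = params + step_size * momentum
    momentum = momentum + 0.5 * step_size * grad_fn(params)
    return params, momentum


def hmc_step(key, params, log_prob_fn, step_size, n_leapfrog_steps):
    """One HMC transition: fresh momentum, leapfrog proposal, Metropolis accept/reject."""
    grad_fn = jax.grad(log_prob_fn)
    key_momentum, key_accept = jax.random.split(key)
    momentum = jax.random.normal(key_momentum, params.shape)
    new_params, new_momentum = leapfrog(params, momentum, grad_fn, step_size, n_leapfrog_steps)
    # Hamiltonian H(x, m) = -log p(x) + |m|^2 / 2 (potential + kinetic energy)
    h_old = -log_prob_fn(params) + 0.5 * jnp.sum(momentum ** 2)
    h_new = -log_prob_fn(new_params) + 0.5 * jnp.sum(new_momentum ** 2)
    # accept with probability min(1, exp(h_old - h_new)); a NaN energy is rejected
    log_accept_prob = jnp.minimum(0.0, h_old - h_new)
    accept = jnp.log(jax.random.uniform(key_accept)) < log_accept_prob
    return jnp.where(accept, new_params, params), accept


def hmc_sample_sketch(key, params_init, log_prob_fn, n_steps, n_leapfrog_steps, step_size):
    """Run n_steps HMC transitions and stack the visited states."""
    def step(params, step_key):
        params, accept = hmc_step(step_key, params, log_prob_fn, step_size, n_leapfrog_steps)
        return params, (params, accept)

    keys = jax.random.split(key, n_steps)
    _, (chain, accepted) = jax.lax.scan(step, params_init, keys)
    return chain, accepted


# usage with the same target as in the example above
def target_log_pdf(params):
    return jax.scipy.stats.t.logpdf(params, df=1).sum()

chain, accepted = hmc_sample_sketch(jax.random.PRNGKey(0), jnp.zeros(10), target_log_pdf,
                                    n_steps=10, n_leapfrog_steps=100, step_size=0.1)
print(chain.shape)  # (10, 10)
```

Each iteration draws a fresh Gaussian momentum, runs `n_leapfrog_steps` leapfrog updates of size `step_size`, and accepts or rejects the endpoint with a Metropolis test on the change in total energy. `jax.lax.scan` stacks the `n_steps` states into an array of shape `(n_steps, *params.shape)`.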

This code is based on the official repository for the paper "What Are Bayesian Neural Network Posteriors Really Like?" (Izmailov et al., 2021):

```bibtex
@article{izmailov2021bayesian,
  title={What Are Bayesian Neural Network Posteriors Really Like?},
  author={Izmailov, Pavel and Vikram, Sharad and Hoffman, Matthew D and Wilson, Andrew Gordon},
  journal={arXiv preprint arXiv:2104.14421},
  year={2021}
}
```
