Skip to content

Commit

Permalink
Merge branch 'main' into adjusted_mclmc_static
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao authored Jan 20, 2025
2 parents 6714ad3 + a0812be commit 24014c5
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 45 deletions.
18 changes: 0 additions & 18 deletions .github/workflows/schedule-meeting.yml

This file was deleted.

21 changes: 17 additions & 4 deletions tests/mcmc/test_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,23 @@ def kinetic_energy(p, position=None):
"c": jnp.ones((2, 1)),
}
_, unravel_fn = ravel_pytree(mvnormal_position_init)
key0, key1 = jax.random.split(jax.random.key(52))
mvnormal_momentum_init = unravel_fn(jax.random.normal(key0, (6,)))
a = jax.random.normal(key1, (6, 6))
cov = jnp.matmul(a.T, a)
mvnormal_momentum_init = {
"a": jnp.asarray(0.53288144),
"b": jnp.asarray([0.25310317, 1.3788314, -0.13486017]),
"c": jnp.asarray([[-0.59082425], [1.2088736]]),
}

cov = jnp.asarray(
[
[5.9959664, 1.1494889, -1.0420643, -0.6328479, -0.20363973, 2.1600752],
[1.1494889, 1.3504763, -0.3601517, -0.98311526, 1.1569028, -1.4185406],
[-1.0420643, -0.3601517, 6.3011055, -2.0662997, -0.10126236, 1.2898219],
[-0.6328479, -0.98311526, -2.0662997, 4.82699, -2.575554, 2.5724294],
[-0.20363973, 1.1569028, -0.10126236, -2.575554, 3.35319, -2.9411654],
[2.1600752, -1.4185406, 1.2898219, 2.5724294, -2.9411654, 6.3740206],
]
)

# Validated numerically
mvnormal_position_end = unravel_fn(
jnp.asarray([0.38887993, 0.85231394, 2.7879136, 3.0339851, 0.5856687, 1.9291426])
Expand Down
22 changes: 8 additions & 14 deletions tests/mcmc/test_proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import jax
import numpy as np
import pytest
from absl.testing import parameterized
from jax import numpy as jnp

from blackjax.mcmc.random_walk import normal
Expand All @@ -10,25 +11,18 @@
class TestNormalProposalDistribution(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.key(20220611)
self.key = jax.random.key(20250120)

def test_normal_univariate(self):
@parameterized.parameters([10.0, 15000.0])
def test_normal_univariate(self, initial_position):
"""
Move samples are generated in the univariate case,
with std following sigma, and independently of the position.
"""
key1, key2 = jax.random.split(self.key)
keys = jax.random.split(self.key, 200)
proposal = normal(sigma=jnp.array([1.0]))
samples_from_initial_position = [
proposal(key, jnp.array([10.0])) for key in jax.random.split(key1, 100)
]
samples_from_another_position = [
proposal(key, jnp.array([15000.0])) for key in jax.random.split(key2, 100)
]

for samples in [samples_from_initial_position, samples_from_another_position]:
np.testing.assert_allclose(0.0, np.mean(samples), rtol=1e-2, atol=1e-1)
np.testing.assert_allclose(1.0, np.std(samples), rtol=1e-2, atol=1e-1)
samples = [proposal(key, jnp.array([initial_position])) for key in keys]
self._check_mean_and_std(jnp.array([0.0]), jnp.array([1.0]), samples)

def test_normal_multivariate(self):
proposal = normal(sigma=jnp.array([1.0, 2.0]))
Expand Down Expand Up @@ -61,7 +55,7 @@ def _check_mean_and_std(expected_mean, expected_std, samples):
)
np.testing.assert_allclose(
expected_std,
np.sqrt(np.diag(np.cov(np.array(samples).T))),
np.sqrt(np.diag(np.atleast_2d(np.cov(np.array(samples).T)))),
rtol=1e-2,
atol=1e-1,
)
17 changes: 9 additions & 8 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test the accuracy of the MCMC kernels."""

import functools
import itertools

Expand Down Expand Up @@ -401,8 +402,8 @@ def test_mclmc(self):
coefs_samples = states["coefs"][3000:]
scale_samples = np.exp(states["log_scale"][3000:])

np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2)
np.testing.assert_allclose(np.mean(scale_samples), 1.0, rtol=1e-2, atol=1e-1)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, rtol=1e-2, atol=1e-1)

@parameterized.parameters([True, False])
def test_adjusted_mclmc_dynamic(
Expand Down Expand Up @@ -458,8 +459,8 @@ def test_adjusted_mclmc(self, diagonal_preconditioning):
coefs_samples = states["coefs"][3000:]
scale_samples = np.exp(states["log_scale"][3000:])

np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2)
np.testing.assert_allclose(np.mean(scale_samples), 1.0, rtol=1e-2, atol=1e-1)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, rtol=1e-2, atol=1e-1)

def test_mclmc_preconditioning(self):
class IllConditionedGaussian:
Expand Down Expand Up @@ -707,8 +708,8 @@ def test_barker(self):
coefs_samples = states["coefs"][3000:]
scale_samples = np.exp(states["log_scale"][3000:])

np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2)
np.testing.assert_allclose(np.mean(scale_samples), 1.0, rtol=1e-2, atol=1e-1)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, rtol=1e-2, atol=1e-1)


class SGMCMCTest(chex.TestCase):
Expand Down Expand Up @@ -961,7 +962,7 @@ def test_irmh(self):
@chex.all_variants(with_pmap=False)
def test_nuts(self):
inference_algorithm = blackjax.nuts(
self.normal_logprob, step_size=4.0, inverse_mass_matrix=jnp.array([1.0])
self.normal_logprob, step_size=1.0, inverse_mass_matrix=jnp.array([1.0])
)

initial_state = inference_algorithm.init(jnp.array(3.0))
Expand Down Expand Up @@ -1121,7 +1122,7 @@ def test_barker(self):
},
{
"algorithm": blackjax.barker_proposal,
"parameters": {"step_size": 0.5},
"parameters": {"step_size": 0.45},
"is_mass_matrix_diagonal": None,
},
]
Expand Down
2 changes: 1 addition & 1 deletion tests/smc/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_smc_waste_free(self):
{},
)
same_for_all_params = dict(
step_size=1e-2, inverse_mass_matrix=jnp.eye(1), num_integration_steps=50
step_size=1e-2, inverse_mass_matrix=jnp.eye(1), num_integration_steps=100
)
hmc_kernel = functools.partial(
blackjax.hmc.build_kernel(), **same_for_all_params
Expand Down

0 comments on commit 24014c5

Please sign in to comment.