Skip to content

Commit

Permalink
test CI: old tests
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed Jan 21, 2025
1 parent c79a9ac commit 7974eea
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def run_mclmc(
(
blackjax_state_after_tuning,
blackjax_mclmc_sampler_params,
num_tuning_integrator_steps,
) = blackjax.mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
Expand Down Expand Up @@ -183,6 +184,7 @@ def run_adjusted_mclmc_dynamic(
(
blackjax_state_after_tuning,
blackjax_mclmc_sampler_params,
num_tuning_integrator_steps,
) = blackjax.adjusted_mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
Expand Down Expand Up @@ -252,6 +254,7 @@ def run_adjusted_mclmc(
(
blackjax_state_after_tuning,
blackjax_mclmc_sampler_params,
num_tuning_integrator_steps,
) = blackjax.adjusted_mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
Expand Down Expand Up @@ -402,9 +405,10 @@ def test_mclmc(self):
np.testing.assert_allclose(np.mean(scale_samples), 1.0, rtol=1e-2, atol=1e-1)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, rtol=1e-2, atol=1e-1)

# @parameterized.parameters([True, False])
@parameterized.parameters([True, False])
def test_adjusted_mclmc_dynamic(
self,
diagonal_preconditioning,
):
"""Test the MCLMC kernel."""

Expand All @@ -422,7 +426,7 @@ def test_adjusted_mclmc_dynamic(
logdensity_fn=logdensity_fn,
key=inference_key,
num_steps=10000,
diagonal_preconditioning=True,
diagonal_preconditioning=diagonal_preconditioning,
)

coefs_samples = states["coefs"][3000:]
Expand All @@ -431,10 +435,8 @@ def test_adjusted_mclmc_dynamic(
np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2)

# @parameterized.parameters([True, False])
def test_adjusted_mclmc(
self,
):
@parameterized.parameters([True, False])
def test_adjusted_mclmc(self, diagonal_preconditioning):
"""Test the MCLMC kernel."""

init_key0, init_key1, inference_key = jax.random.split(self.key, 3)
Expand All @@ -451,7 +453,7 @@ def test_adjusted_mclmc(
logdensity_fn=logdensity_fn,
key=inference_key,
num_steps=10000,
diagonal_preconditioning=True,
diagonal_preconditioning=diagonal_preconditioning,
)

coefs_samples = states["coefs"][3000:]
Expand Down Expand Up @@ -517,7 +519,7 @@ def get_inverse_mass_matrix():
inverse_mass_matrix=inverse_mass_matrix,
)

(_, blackjax_mclmc_sampler_params) = blackjax.mclmc_find_L_and_step_size(
(_, blackjax_mclmc_sampler_params, _) = blackjax.mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
state=initial_state,
Expand Down

0 comments on commit 7974eea

Please sign in to comment.