From 3175cb7fa96720671a442a3516c173f95e50fc06 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Mon, 20 Jan 2025 09:28:35 +0100 Subject: [PATCH] keeping only jit branch --- tests/vi/test_schrodinger_follmer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/vi/test_schrodinger_follmer.py b/tests/vi/test_schrodinger_follmer.py index 33e406d11..13bea143f 100644 --- a/tests/vi/test_schrodinger_follmer.py +++ b/tests/vi/test_schrodinger_follmer.py @@ -14,7 +14,7 @@ def setUp(self): super().setUp() self.key = jax.random.key(1) - @chex.all_variants(with_pmap=True, without_jit=False, without_device=False) + # @chex.all_variants(with_pmap=True) def test_recover_posterior(self): """Simple Normal mean test""" @@ -70,7 +70,7 @@ def logp_unnormalized_posterior(x, observed, prior_mu, prior_prec, true_cov): schrodinger_follmer_algo = schrodinger_follmer(logp_model, 50, 25) initial_state = schrodinger_follmer_algo.init(initial_position) - schrodinger_follmer_algo_sample = self.variant( + schrodinger_follmer_algo_sample = jax.jit( lambda k, s: schrodinger_follmer_algo.sample(k, s, 100) ) sampled_states = schrodinger_follmer_algo_sample(rng_key_init, initial_state)