Skip to content

Commit

Permalink
keeping only jit branch
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao committed Jan 20, 2025
1 parent 5b39873 commit 3175cb7
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/vi/test_schrodinger_follmer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def setUp(self):
super().setUp()
self.key = jax.random.key(1)

@chex.all_variants(with_pmap=True, without_jit=False, without_device=False)
# @chex.all_variants(with_pmap=True)
def test_recover_posterior(self):
"""Simple Normal mean test"""

Expand Down Expand Up @@ -70,7 +70,7 @@ def logp_unnormalized_posterior(x, observed, prior_mu, prior_prec, true_cov):
schrodinger_follmer_algo = schrodinger_follmer(logp_model, 50, 25)

initial_state = schrodinger_follmer_algo.init(initial_position)
schrodinger_follmer_algo_sample = self.variant(
schrodinger_follmer_algo_sample = jax.jit(
lambda k, s: schrodinger_follmer_algo.sample(k, s, 100)
)
sampled_states = schrodinger_follmer_algo_sample(rng_key_init, initial_state)
Expand Down

0 comments on commit 3175cb7

Please sign in to comment.