From 51625a8f7cb117b1a25ea4c579b3cc539c109d0b Mon Sep 17 00:00:00 2001 From: Adrien Corenflos Date: Mon, 23 Sep 2024 09:09:18 +0100 Subject: [PATCH] Speed up Schrodinger Follmer test (#741) * Plotting BlackJAX with BlackJAX * Plotting BlackJAX with BlackJAX * Update blackjax/mcmc/metrics.py Co-authored-by: Junpeng Lao * Update blackjax/mcmc/metrics.py Co-authored-by: Junpeng Lao * Merged comments from Junpeng * Speed up Follmer --------- Co-authored-by: Junpeng Lao --- tests/vi/test_schrodinger_follmer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/vi/test_schrodinger_follmer.py b/tests/vi/test_schrodinger_follmer.py index fd58fed0a..79e7afedd 100644 --- a/tests/vi/test_schrodinger_follmer.py +++ b/tests/vi/test_schrodinger_follmer.py @@ -51,7 +51,7 @@ def logp_unnormalized_posterior(x, observed, prior_mu, prior_prec, true_cov): # Simulate the data observed = jax.random.multivariate_normal( - rng_key_observed, true_mu, true_cov, shape=(10_000,) + rng_key_observed, true_mu, true_cov, shape=(25,) ) logp_model = functools.partial(