From 1348369c15bb560c99ded165f74bdd91099576b2 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Wed, 5 Jun 2024 13:07:02 +0100 Subject: [PATCH] Fix solve --- thermox/linalg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thermox/linalg.py b/thermox/linalg.py index c6edbbe..a544aea 100644 --- a/thermox/linalg.py +++ b/thermox/linalg.py @@ -1,7 +1,7 @@ import jax.numpy as jnp from jax.lax import fori_loop from jax import Array, random -from thermox.sampler import sample, _sample_identity_diffusion +from thermox.sampler import sample, sample_identity_diffusion from thermox.utils import ProcessedDriftMatrix @@ -37,7 +37,7 @@ def solve( key = random.PRNGKey(0) ts = jnp.arange(burnin, burnin + num_samples) * dt x0 = jnp.zeros_like(b) - samples = _sample_identity_diffusion(key, ts, x0, A, jnp.linalg.solve(A, b)) + samples = sample_identity_diffusion(key, ts, x0, A, jnp.linalg.solve(A, b)) return jnp.mean(samples, axis=0)