Skip to content

Commit

Permalink
Fix burnin for linalg (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamDuffield authored Jul 16, 2024
1 parent 7494f72 commit ab5ddd1
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
23 changes: 13 additions & 10 deletions thermox/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def solve(
num_samples: int = 10000,
dt: float = 1.0,
burnin: int = 0,
key: Array = None,
key: Array | None = None,
associative_scan: bool = True,
) -> Array:
"""
Expand All @@ -37,11 +37,12 @@ def solve(
"""
if key is None:
key = random.PRNGKey(0)
ts = jnp.arange(burnin, burnin + num_samples) * dt
ts = jnp.arange(burnin, burnin + num_samples + 1) * dt
ts = jnp.concatenate([jnp.array([0]), ts])
x0 = jnp.zeros_like(b)
samples = sample_identity_diffusion(
key, ts, x0, A, jnp.linalg.solve(A, b), associative_scan
)
)[1:]
return jnp.mean(samples, axis=0)


Expand All @@ -50,7 +51,7 @@ def inv(
num_samples: int = 10000,
dt: float = 1.0,
burnin: int = 0,
key: Array = None,
key: Array | None = None,
associative_scan: bool = True,
) -> Array:
"""
Expand All @@ -72,10 +73,11 @@ def inv(
"""
if key is None:
key = random.PRNGKey(0)
ts = jnp.arange(burnin, burnin + num_samples) * dt
ts = jnp.arange(burnin, burnin + num_samples + 1) * dt
ts = jnp.concatenate([jnp.array([0]), ts])
b = jnp.zeros(A.shape[0])
x0 = jnp.zeros_like(b)
samples = sample(key, ts, x0, A, b, 2 * jnp.eye(A.shape[0]), associative_scan)
samples = sample(key, ts, x0, A, b, 2 * jnp.eye(A.shape[0]), associative_scan)[1:]
return jnp.cov(samples.T)


Expand All @@ -84,7 +86,7 @@ def expnegm(
num_samples: int = 10000,
dt: float = 1.0,
burnin: int = 0,
key: Array = None,
key: Array | None = None,
alpha: float = 0.0,
associative_scan: bool = True,
) -> Array:
Expand Down Expand Up @@ -113,10 +115,11 @@ def expnegm(
A_shifted = (A + alpha * jnp.eye(A.shape[0])) / dt
B = A_shifted + A_shifted.T

ts = jnp.arange(burnin, burnin + num_samples) * dt
ts = jnp.arange(burnin, burnin + num_samples + 1) * dt
ts = jnp.concatenate([jnp.array([0]), ts])
b = jnp.zeros(A.shape[0])
x0 = jnp.zeros_like(b)
samples = sample(key, ts, x0, A_shifted, b, B, associative_scan)
samples = sample(key, ts, x0, A_shifted, b, B, associative_scan)[1:]
return autocovariance(samples) * jnp.exp(alpha)


Expand All @@ -125,7 +128,7 @@ def expm(
num_samples: int = 10000,
dt: float = 1.0,
burnin: int = 0,
key: Array = None,
key: Array | None = None,
alpha: float = 1.0,
associative_scan: bool = True,
) -> Array:
Expand Down
2 changes: 1 addition & 1 deletion thermox/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def sample(
by using exact diagonalization.
Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2),
Preprocessing (diagonalization) costs O(d^3) and sampling costs O(T * d^2),
where T=len(ts).
If associative_scan=True then jax.lax.associative_scan is used which will run in
Expand Down

0 comments on commit ab5ddd1

Please sign in to comment.