Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix how sample function handles Processed inputs #25

Merged
merged 8 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import jax
from jax import numpy as jnp

import thermox


def test_sample_array_input():
    """Sample with raw Array drift/diffusion inputs and check empirical moments.

    Draws a long trajectory from thermox.sample and asserts that the sample
    mean is close to b and that A @ (sample covariance) is close to the
    identity, both within an absolute tolerance of 1e-1.
    """
    dim = 2
    dt = 0.1
    rng_key = jax.random.PRNGKey(0)
    ts = jnp.arange(0, 10_000, dt)

    A = jnp.array([[3, 2], [2, 4.0]])
    D = 2 * jnp.eye(dim)
    b = jnp.zeros(dim)
    x0 = jnp.zeros(dim)

    samples = thermox.sample(rng_key, ts, x0, A, b, D)

    # samples has one row per time point, so columns are the variables.
    empirical_cov = jnp.cov(samples, rowvar=False)
    empirical_mean = jnp.mean(samples, axis=0)

    assert jnp.allclose(A @ empirical_cov, jnp.eye(dim), atol=1e-1)
    assert jnp.allclose(empirical_mean, b, atol=1e-1)
60 changes: 60 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from jax import numpy as jnp

from thermox.utils import (
handle_matrix_inputs,
ProcessedDriftMatrix,
ProcessedDiffusionMatrix,
preprocess,
)


def test_handle_matrix_inputs_arrays():
    """Raw Array inputs must be preprocessed exactly as thermox's preprocess does.

    The reference result comes from preprocess(A, D); the result under test
    comes from handle_matrix_inputs(A, D) with both arguments as raw arrays.
    """
    A = jnp.array([[1, 3], [1, 4]])
    D = jnp.array([[9, 4], [4, 20]])

    # Reference: explicit preprocessing of the raw matrices.
    a, d = preprocess(A, D)

    # Under test: the helper must take the preprocessing path for raw arrays.
    A_star, D_star = handle_matrix_inputs(A, D)

    assert isinstance(A_star, ProcessedDriftMatrix)
    assert isinstance(D_star, ProcessedDiffusionMatrix)
    assert jnp.all(a.val == A_star.val)
    assert jnp.all(d.val == D_star.val)


def test_handle_matrix_inputs_processed():
    """Already-processed drift and diffusion inputs must keep the same drift value."""
    raw_drift = jnp.array([[1, 3], [1, 4]])
    raw_diffusion = jnp.array([[9, 4], [4, 20]])

    a, d = preprocess(raw_drift, raw_diffusion)

    # Both arguments are processed objects here, not raw arrays.
    A_star, D_star = handle_matrix_inputs(a, d)

    assert isinstance(A_star, ProcessedDriftMatrix)
    assert isinstance(D_star, ProcessedDiffusionMatrix)
    assert (a.val == A_star.val).all()


def test_handle_matrix_inputs_array_drift_processed_diffusion():
    """Raw drift with processed diffusion must match the full preprocess result."""
    raw_drift = jnp.array([[1, 3], [1, 4]])
    raw_diffusion = jnp.array([[9, 4], [4, 20]])

    a, d = preprocess(raw_drift, raw_diffusion)

    # Mixed input: drift is a raw array, diffusion is already processed.
    A_star, D_star = handle_matrix_inputs(raw_drift, d)

    assert isinstance(A_star, ProcessedDriftMatrix)
    assert isinstance(D_star, ProcessedDiffusionMatrix)
    assert (a.val == A_star.val).all()


def test_handle_matrix_inputs_array_diffusion_processed_drift():
    """Processed drift with raw diffusion is re-preprocessed, so the drift value changes."""
    raw_drift = jnp.array([[1, 3], [1, 4]])
    raw_diffusion = jnp.array([[9, 4], [4, 20]])

    a, _ = preprocess(raw_drift, raw_diffusion)

    # Mixed input: drift is already processed, diffusion is a raw array.
    A_star, D_star = handle_matrix_inputs(a, raw_diffusion)

    assert isinstance(A_star, ProcessedDriftMatrix)
    assert isinstance(D_star, ProcessedDiffusionMatrix)
    # The returned drift is expected to differ from the processed drift passed in.
    assert not (a.val == A_star.val).all()
37 changes: 20 additions & 17 deletions thermox/prob.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from jax import Array, vmap

from thermox.utils import (
preprocess,
handle_matrix_inputs,
preprocess_drift_matrix,
ProcessedDriftMatrix,
ProcessedDiffusionMatrix,
Expand All @@ -28,10 +28,10 @@ def log_prob_identity_diffusion(
Preprocessing (diagonalisation) costs O(d^3) and evaluation then costs O(T * d^2).

Args:
ts: array-like, times at which samples are collected. Includes time for x0.
xs: initial state of the process.
A: drift matrix (Array or thermox.ProcessedDriftMatrix).
b: drift displacement vector.
ts: Times at which samples are collected. Includes time for x0.
xs: Initial state of the process.
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
b: Drift displacement vector.
Returns:
Scalar log probability of given xs.
"""
Expand Down Expand Up @@ -94,24 +94,27 @@ def log_prob(

Assumes x(t_0) is given deterministically.

Preprocessing (diagonalisation) costs O(d^3) and evaluation then costs O(T * d^2).
Preprocessing (diagonalisation) costs O(d^3) and evaluation then costs O(T * d^2),
where T=len(ts).

By default, this function does the preprocessing on A and D before the evaluation.
However, the preprocessing can be done externally using thermox.preprocess,
the output of which can be used as A and D here; this will skip the preprocessing.

Args:
ts: array-like, times at which samples are collected. Includes time for x0.
xs: initial state of the process.
A: drift matrix (Array or thermox.ProcessedDriftMatrix).
b: drift displacement vector.
D: diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
ts: Times at which samples are collected. Includes time for x0.
xs: Initial state of the process.
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
Note: If a thermox.ProcessedDriftMatrix instance is used as input,
it must be the transformed drift matrix, A_y, given by thermox.preprocess,
not the output of thermox.utils.preprocess_drift_matrix.
b: Drift displacement vector.
D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).

Returns:
Scalar log probability of given xs.
"""
if isinstance(A, Array) or isinstance(D, Array):
if isinstance(A, ProcessedDriftMatrix):
A = A.val
if isinstance(D, ProcessedDiffusionMatrix):
D = D.val
A_y, D = preprocess(A, D)
A_y, D = handle_matrix_inputs(A, D)

ys = vmap(jnp.matmul, in_axes=(None, 0))(D.sqrt_inv, xs)
b_y = D.sqrt_inv @ b
Expand Down
39 changes: 21 additions & 18 deletions thermox/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jax import Array

from thermox.utils import (
preprocess,
handle_matrix_inputs,
preprocess_drift_matrix,
ProcessedDriftMatrix,
ProcessedDiffusionMatrix,
Expand All @@ -28,11 +28,11 @@ def sample_identity_diffusion(
where T=len(ts).

Args:
key: jax PRNGKey.
ts: array-like, times at which samples are collected. Includes time for x0.
x0: initial state of the process.
A: drift matrix (Array or thermox.ProcessedDriftMatrix).
b: drift displacement vector.
key: Jax PRNGKey.
ts: Times at which samples are collected. Includes time for x0.
x0: Initial state of the process.
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
b: Drift displacement vector.

Returns:
Array-like, desired samples.
Expand Down Expand Up @@ -88,26 +88,29 @@ def sample(

by using exact diagonalization.

Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2)
Preprocessing (diagonalisation) costs O(d^3) and sampling costs O(T * d^2),
where T=len(ts).

By default, this function does the preprocessing on A and D before sampling.
However, the preprocessing can be done externally using thermox.preprocess,
the output of which can be used as A and D here; this will skip the preprocessing.

Args:
key: jax PRNGKey.
ts: array-like, times at which samples are collected. Includes time for x0.
x0: initial state of the process.
A: drift matrix (Array or thermox.ProcessedDriftMatrix).
b: drift displacement vector.
D: diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
key: Jax PRNGKey.
ts: Times at which samples are collected. Includes time for x0.
x0: Initial state of the process.
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
Note: If a thermox.ProcessedDriftMatrix instance is used as input,
it must be the transformed drift matrix, A_y, given by thermox.preprocess,
not the output of thermox.utils.preprocess_drift_matrix.
b: Drift displacement vector.
D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).

Returns:
Array-like, desired samples.
shape: (len(ts), ) + x0.shape
"""
if isinstance(A, Array) and isinstance(D, Array):
A_y, D = preprocess(A, D)

assert isinstance(A_y, ProcessedDriftMatrix)
assert isinstance(D, ProcessedDiffusionMatrix)
A_y, D = handle_matrix_inputs(A, D)

y0 = D.sqrt_inv @ x0
b_y = D.sqrt_inv @ b
Expand Down
33 changes: 29 additions & 4 deletions thermox/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def preprocess_drift_matrix(A: Array) -> ProcessedDriftMatrix:
"""Preprocesses matrix A (calculates eigendecompositions of A and (A+A^T)/2)

Args:
A: drift matrix.
A: Drift matrix.

Returns:
ProcessedDriftMatrix containing eigendecomposition of A and (A+A^T)/2.
Expand Down Expand Up @@ -59,7 +59,7 @@ def preprocess_diffusion_matrix(D: Array) -> ProcessedDiffusionMatrix:
"""Preprocesses diffusion matrix D (calculates D^0.5 and D^-0.5 via Cholesky)

Args:
D: diffusion matrix.
D: Diffusion matrix.

Returns:
ProcessedDiffusionMatrix containing D^0.5 and D^-0.5.
Expand All @@ -77,8 +77,8 @@ def preprocess(
D^0.5 and D^-0.5)

Args:
A: drift matrix.
D: diffusion matrix.
A: Drift matrix.
D: Diffusion matrix.

Returns:
ProcessedDriftMatrix containing eigendecomposition of A_y and (A_y+A_y^T)/2.
Expand All @@ -89,3 +89,28 @@ def preprocess(
A_y = PD.sqrt_inv @ A @ PD.sqrt
PA_y = preprocess_drift_matrix(A_y)
return PA_y, PD


def handle_matrix_inputs(
    A: Array | ProcessedDriftMatrix, D: Array | ProcessedDiffusionMatrix
) -> Tuple[ProcessedDriftMatrix, ProcessedDiffusionMatrix]:
    """Normalises drift and diffusion inputs into their processed forms.

    Helper function for the sample and log_prob functions.

    If neither A nor D is a raw Array, both are returned unchanged. Otherwise
    any already-processed input is unwrapped through its `.val` attribute and
    both matrices are passed to preprocess together.

    Note: when A is a ProcessedDriftMatrix but D is a raw Array, A.val is
    run through preprocess again with D, so the returned drift matrix is not
    the one that was passed in.

    Args:
        A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
        D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).

    Returns:
        ProcessedDriftMatrix containing eigendecomposition of A_y and (A_y+A_y^T)/2,
            where A_y = D^-0.5 @ A @ D^0.5.
        ProcessedDiffusionMatrix containing D^0.5 and D^-0.5.
    """
    # Nothing to do unless at least one input is still a raw array.
    if not (isinstance(A, Array) or isinstance(D, Array)):
        return A, D

    # preprocess expects raw matrices, so unwrap whichever input is already processed.
    drift = A.val if isinstance(A, ProcessedDriftMatrix) else A
    diffusion = D.val if isinstance(D, ProcessedDiffusionMatrix) else D

    processed_drift, processed_diffusion = preprocess(drift, diffusion)
    return processed_drift, processed_diffusion