Add improved goodness of fit implementation (#190)
* Started implementing improved goodness of fit implementation

* add tests and improve implementation

* Fix examples

* Fix docstring error

* Handle batch size = None for goodness of fit computation

* adapt GoF implementation

* Fix docstring tests

* Update docstring for goodness_of_fit_score

Co-authored-by: Célia Benquet <[email protected]>

* add annotations to goodness_of_fit_history

Co-authored-by: Célia Benquet <[email protected]>

* fix typo

Co-authored-by: Célia Benquet <[email protected]>

* improve err message

Co-authored-by: Célia Benquet <[email protected]>

* make numerical test less conservative

* Add tests for exception handling

* fix tests

---------

Co-authored-by: Célia Benquet <[email protected]>
stes and CeliaBenquet authored Feb 2, 2025
1 parent 7e74eda commit 4e32661
Showing 2 changed files with 272 additions and 0 deletions.
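For orientation, here is a minimal usage sketch of the metrics added in this commit, assembled from the docstring examples in the diff below (the synthetic data and hyperparameters are illustrative only, not part of the commit):

import cebra
import numpy as np

# Synthetic single-session recording: 1000 samples, 20 features.
neural_data = np.random.uniform(0, 1, (1000, 20))

# batch_size must be set (not None) for goodness-of-fit computation.
cebra_model = cebra.CEBRA(max_iterations=10, batch_size=512)
cebra_model.fit(neural_data)

# Goodness of fit in bits; higher is better, ~0 means chance level.
gof = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, neural_data, num_batches=500)

# Per-iteration goodness of fit derived from the training loss history.
gof_history = cebra.sklearn.metrics.goodness_of_fit_history(cebra_model)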
143 changes: 143 additions & 0 deletions cebra/integrations/sklearn/metrics.py
@@ -108,6 +108,149 @@ def infonce_loss(
return avg_loss


def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA,
X: Union[npt.NDArray, torch.Tensor],
*y,
session_id: Optional[int] = None,
num_batches: int = 500) -> float:
"""Compute the goodness of fit score on a *single session* dataset on the model.
This function uses the :func:`infonce_loss` function to compute the InfoNCE loss
for a given `cebra_model` and the :func:`infonce_to_goodness_of_fit` function
to derive the goodness of fit from the InfoNCE loss.
Args:
cebra_model: The model to use to compute the InfoNCE loss on the samples.
X: A 2D data matrix, corresponding to a *single session* recording.
y: An arbitrary amount of continuous indices passed as 2D matrices, and up to one
discrete index passed as a 1D array. Each index has to match the length of ``X``.
session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`cebra.CEBRA.num_sessions`
for multisession, set to ``None`` for single session.
num_batches: The number of iterations to consider to evaluate the model on the new data.
Higher values will give a more accurate estimate. Set it to at least 500 iterations.
Returns:
The average GoF score estimated over ``num_batches`` batches from the data distribution.
Related:
:func:`infonce_to_goodness_of_fit`
Example:
>>> import cebra
>>> import numpy as np
>>> neural_data = np.random.uniform(0, 1, (1000, 20))
>>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size = 512)
>>> cebra_model.fit(neural_data)
CEBRA(batch_size=512, max_iterations=10)
>>> gof = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, neural_data)
"""
loss = infonce_loss(cebra_model,
X,
*y,
session_id=session_id,
num_batches=num_batches,
correct_by_batchsize=False)
return infonce_to_goodness_of_fit(loss, cebra_model)


def goodness_of_fit_history(model: cebra_sklearn_cebra.CEBRA) -> np.ndarray:
"""Return the history of the goodness of fit score.
Args:
model: A trained CEBRA model.
Returns:
A numpy array containing the goodness of fit values, measured in bits.
Related:
:func:`infonce_to_goodness_of_fit`
Example:
>>> import cebra
>>> import numpy as np
>>> neural_data = np.random.uniform(0, 1, (1000, 20))
>>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size = 512)
>>> cebra_model.fit(neural_data)
CEBRA(batch_size=512, max_iterations=10)
>>> gof_history = cebra.sklearn.metrics.goodness_of_fit_history(cebra_model)
"""
infonce = np.array(model.state_dict_["log"]["total"])
return infonce_to_goodness_of_fit(infonce, model)


def infonce_to_goodness_of_fit(
infonce: Union[float, np.ndarray],
model: Optional[cebra_sklearn_cebra.CEBRA] = None,
batch_size: Optional[int] = None,
num_sessions: Optional[int] = None) -> Union[float, np.ndarray]:
"""Given a trained CEBRA model, return goodness of fit metric.
The goodness of fit ranges from 0 (lowest meaningful value)
to a positive number with the unit "bits", the higher the
better.
Values lower than 0 bits are possible, but these only occur
due to numerical effects. A perfectly collapsed embedding
(e.g., because the data cannot be fit with the provided
auxiliary variables) will have a goodness of fit of 0.
The conversion between the generalized InfoNCE metric that
CEBRA is trained with and the goodness of fit computed with this
function is
.. math::
S = \\log N - \\text{InfoNCE}
To use this function, either provide a trained CEBRA model or the
batch size and number of sessions.
Args:
infonce: The InfoNCE loss, either a single value or an iterable of values.
model: The trained CEBRA model.
batch_size: The batch size used to train the model.
num_sessions: The number of sessions used to train the model.
Returns:
Numpy array containing the goodness of fit values, measured in bits
Raises:
RuntimeError: If the provided model is not fit to data.
ValueError: If both ``model`` and ``(batch_size, num_sessions)`` are provided.
"""
if model is not None:
if batch_size is not None or num_sessions is not None:
raise ValueError(
"batch_size and num_sessions should not be provided if model is provided."
)
if not hasattr(model, "state_dict_"):
raise RuntimeError("Fit the CEBRA model first.")
        if model.batch_size is None:
            raise ValueError(
                "Computing the goodness of fit is not yet supported for "
                "models trained on the full dataset (batch_size = None).")
        batch_size = model.batch_size
        num_sessions = model.num_sessions_
        if num_sessions is None:
            num_sessions = 1
    else:
        if batch_size is None or num_sessions is None:
            raise ValueError(
                f"batch_size ({batch_size}) and num_sessions ({num_sessions}) "
                f"should be provided if model is not provided.")

nats_to_bits = np.log2(np.e)
chance_level = np.log(batch_size * num_sessions)
return (chance_level - infonce) * nats_to_bits


def _consistency_scores(
embeddings: List[Union[npt.NDArray, torch.Tensor]],
datasets: List[Union[int, str]],
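As a quick sanity check of the conversion implemented above (a hand-worked sketch; the loss value 5.24 is hypothetical): with batch_size=512 and a single session, the chance level is log(512) ≈ 6.24 nats, so an InfoNCE loss at chance maps to 0 bits and every nat below chance adds log2(e) ≈ 1.44 bits.

import numpy as np

batch_size, num_sessions = 512, 1
chance_level = np.log(batch_size * num_sessions)    # ln(512) ~= 6.2383 nats
nats_to_bits = np.log2(np.e)                        # ~= 1.4427 bits per nat

infonce = 5.24                                      # hypothetical loss value
gof_bits = (chance_level - infonce) * nats_to_bits  # ~= 1.44 bits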
129 changes: 129 additions & 0 deletions tests/test_sklearn_metrics.py
@@ -383,3 +383,132 @@ def test_sklearn_runs_consistency():
with pytest.raises(ValueError, match="Invalid.*embeddings"):
_, _, _ = cebra_sklearn_metrics.consistency_score(
invalid_embeddings_runs, between="runs")


@pytest.mark.parametrize("seed", [42, 24, 10])
def test_goodness_of_fit_score(seed):
"""
Ensure that the GoF score is close to 0 for a model fit on random data.
"""
cebra_model = cebra_sklearn_cebra.CEBRA(
model_architecture="offset1-model",
max_iterations=5,
batch_size=512,
)
generator = torch.Generator().manual_seed(seed)
X = torch.rand(5000, 50, dtype=torch.float32, generator=generator)
y = torch.rand(5000, 5, dtype=torch.float32, generator=generator)
cebra_model.fit(X, y)
score = cebra_sklearn_metrics.goodness_of_fit_score(cebra_model,
X,
y,
session_id=0,
num_batches=500)
assert isinstance(score, float)
assert np.isclose(score, 0, atol=0.01)


@pytest.mark.parametrize("seed", [42, 24, 10])
def test_goodness_of_fit_history(seed):
"""
Ensure that the GoF score is higher for a model fit on data with underlying
structure than for a model fit on random data.
"""

# Generate data
generator = torch.Generator().manual_seed(seed)
X = torch.rand(1000, 50, dtype=torch.float32, generator=generator)
y_random = torch.rand(len(X), 5, dtype=torch.float32, generator=generator)
linear_map = torch.randn(50, 5, dtype=torch.float32, generator=generator)
y_linear = X @ linear_map

def _fit_and_get_history(X, y):
cebra_model = cebra_sklearn_cebra.CEBRA(
model_architecture="offset1-model",
max_iterations=150,
batch_size=512,
device="cpu")
cebra_model.fit(X, y)
history = cebra_sklearn_metrics.goodness_of_fit_history(cebra_model)
# NOTE(stes): Ignore the first 5 iterations, they can have nonsensical values
# due to numerical issues.
return history[5:]

history_random = _fit_and_get_history(X, y_random)
history_linear = _fit_and_get_history(X, y_linear)

assert isinstance(history_random, np.ndarray)
assert history_random.shape[0] > 0
    # NOTE(stes): Only check the values that are non-negative; individual
    # iterations can dip below zero due to numerical effects.
    history_random_non_negative = history_random[history_random >= 0]
np.testing.assert_allclose(history_random_non_negative, 0, atol=0.075)

assert isinstance(history_linear, np.ndarray)
assert history_linear.shape[0] > 0

assert np.all(history_linear[-20:] > history_random[-20:])


@pytest.mark.parametrize("seed", [42, 24, 10])
def test_infonce_to_goodness_of_fit(seed):
"""Test the conversion from InfoNCE loss to goodness of fit metric."""
# Test with model
cebra_model = cebra_sklearn_cebra.CEBRA(
model_architecture="offset10-model",
max_iterations=5,
batch_size=128,
)
generator = torch.Generator().manual_seed(seed)
X = torch.rand(1000, 50, dtype=torch.float32, generator=generator)
cebra_model.fit(X)

# Test single value
gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
model=cebra_model)
assert isinstance(gof, float)

# Test array of values
infonce_values = np.array([1.0, 2.0, 3.0])
gof_array = cebra_sklearn_metrics.infonce_to_goodness_of_fit(
infonce_values, model=cebra_model)
assert isinstance(gof_array, np.ndarray)
assert gof_array.shape == infonce_values.shape

# Test with explicit batch_size and num_sessions
gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
batch_size=128,
num_sessions=1)
assert isinstance(gof, float)

# Test error cases
with pytest.raises(ValueError, match="batch_size.*should not be provided"):
cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
model=cebra_model,
batch_size=128)

with pytest.raises(ValueError, match="batch_size.*should not be provided"):
cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
model=cebra_model,
num_sessions=1)

# Test with unfitted model
unfitted_model = cebra_sklearn_cebra.CEBRA(max_iterations=5)
with pytest.raises(RuntimeError, match="Fit the CEBRA model first"):
cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
model=unfitted_model)

# Test with model having batch_size=None
none_batch_model = cebra_sklearn_cebra.CEBRA(batch_size=None,
max_iterations=5)
none_batch_model.fit(X)
with pytest.raises(ValueError, match="Computing the goodness of fit"):
cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
model=none_batch_model)

# Test missing batch_size or num_sessions when model is None
with pytest.raises(ValueError, match="batch_size.*and num_sessions"):
cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, batch_size=128)

with pytest.raises(ValueError, match="batch_size.*and num_sessions"):
cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, num_sessions=1)

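Assuming a development install of the repository, the new tests can be run selectively with pytest's keyword filter, e.g. ``pytest tests/test_sklearn_metrics.py -k goodness_of_fit``.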