Skip to content

Commit

Permalink
chore: Refactor test_barycenter.py to include callback for barycenter…
Browse files Browse the repository at this point in the history
… step
  • Loading branch information
Pierre-Louis Barbarant committed Jul 15, 2024
1 parent 71cf202 commit ac7c395
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions tests/mappings/test_barycenter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from itertools import product

import numpy as np
import pytest
import torch
Expand All @@ -9,9 +11,14 @@
if torch.cuda.is_available():
devices.append(torch.device("cuda:0"))

callbacks = [None, lambda x: x["plans"]]


@pytest.mark.parametrize("device", devices)
def test_fugw_barycenter(device):
@pytest.mark.parametrize(
"device, callback",
product(devices, callbacks),
)
def test_fugw_barycenter(device, callback):
np.random.seed(0)
n_subjects = 4
n_voxels = 100
Expand All @@ -32,5 +39,9 @@ def test_fugw_barycenter(device):

fugw_barycenter = FUGWBarycenter()
fugw_barycenter.fit(
weights_list, features_list, geometry_list, device=device
weights_list,
features_list,
geometry_list,
device=device,
callback_barycenter=callback,
)

0 comments on commit ac7c395

Please sign in to comment.