Skip to content

Commit

Permalink
Refactor test_sparse_barycenter.py to include callback functions
Browse files Browse the repository at this point in the history
  • Loading branch information
pbarbarant committed Jul 15, 2024
1 parent f245fc5 commit 4e525da
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions tests/mappings/test_sparse_barycenter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from itertools import product

import numpy as np
import pytest
import torch
Expand All @@ -9,10 +11,15 @@
if torch.cuda.is_available():
devices.append(torch.device("cuda:0"))

callbacks = [None, lambda x: x["plans"]]


@pytest.mark.skip_if_no_mkl
@pytest.mark.parametrize("device", devices)
def test_fugw_barycenter(device):
@pytest.mark.parametrize(
"device, callback",
product(devices, callbacks),
)
def test_fugw_barycenter(device, callback):
np.random.seed(0)
n_subjects = 4
n_voxels = 100
Expand Down

0 comments on commit 4e525da

Please sign in to comment.