From 4e525da1281e340ef8c6922e22ec0e3d2d99de4f Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Mon, 15 Jul 2024 16:29:52 +0200 Subject: [PATCH] Refactor test_sparse_barycenter.py to include callback functions --- tests/mappings/test_sparse_barycenter.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/mappings/test_sparse_barycenter.py b/tests/mappings/test_sparse_barycenter.py index 8d6477e..d067a15 100644 --- a/tests/mappings/test_sparse_barycenter.py +++ b/tests/mappings/test_sparse_barycenter.py @@ -1,3 +1,5 @@ +from itertools import product + import numpy as np import pytest import torch @@ -9,10 +11,15 @@ if torch.cuda.is_available(): devices.append(torch.device("cuda:0")) +callbacks = [None, lambda x: x["plans"]] + @pytest.mark.skip_if_no_mkl -@pytest.mark.parametrize("device", devices) -def test_fugw_barycenter(device): +@pytest.mark.parametrize( + "device, callback", + product(devices, callbacks), +) +def test_fugw_barycenter(device, callback): np.random.seed(0) n_subjects = 4 n_voxels = 100