Commit

Merge branch 'main' of github.com:alexisthual/fugw
alexisthual committed May 27, 2024
2 parents: 1befe4e + e0170f7 · commit 224bde1
Showing 8 changed files with 33 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/fugw/mappings/barycenter.py
@@ -155,6 +155,7 @@ def compute_all_ot_plans(
     solver=solver,
     solver_params=solver_params,
     device=device,
+    storing_device=device,
     verbose=verbose,
 )

5 changes: 4 additions & 1 deletion src/fugw/mappings/dense.py
@@ -27,6 +27,7 @@ def fit(
     solver_params={},
     callback_bcd=None,
     device="auto",
+    storing_device="cpu",
     verbose=False,
 ):
     """
@@ -100,6 +101,8 @@ def fit(
     device: "auto" or torch.device
         if "auto": use first available gpu if it's available,
         cpu otherwise.
+    storing_device: torch.device, default="cpu"
+        Device on which to store the computed transport plan.
     verbose: bool, optional, defaults to False
         Log solving process.
@@ -255,7 +258,7 @@ def fit(
     )

     # Store variables of interest in model
-    self.pi = res["pi"].detach().cpu()
+    self.pi = res["pi"].detach().to(device=storing_device)
     self.loss = res["loss"]
     self.loss_steps = res["loss_steps"]
     self.loss_times = res["loss_times"]
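For readers skimming the diff, here is a minimal, hypothetical usage sketch of the new storing_device argument on the dense mapping. It assumes the FUGW class from fugw.mappings, its documented fit() keyword arguments (source_geometry, target_geometry, etc.), and random toy data; only storing_device itself is introduced by this commit.

import torch
from fugw.mappings import FUGW

# Toy data: features are (n_features, n_points), geometries are (n_points, n_points).
source_features = torch.rand(10, 100)
target_features = torch.rand(10, 120)
source_geometry = torch.rand(100, 100)
target_geometry = torch.rand(120, 120)

mapping = FUGW(alpha=0.5)
mapping.fit(
    source_features,
    target_features,
    source_geometry=source_geometry,
    target_geometry=target_geometry,
    device="auto",          # solve on GPU if one is available
    storing_device="cpu",   # but keep the fitted plan on CPU (new in this commit)
)

print(mapping.pi.device)  # expected: cpu
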
5 changes: 4 additions & 1 deletion src/fugw/mappings/sparse.py
@@ -33,6 +33,7 @@ def fit(
     solver_params={},
     callback_bcd=None,
     device="auto",
+    storing_device="cpu",
     verbose=False,
 ):
     """
@@ -106,6 +107,8 @@ def fit(
     device: "auto" or torch.device
         if "auto": use first available gpu if it's available,
         cpu otherwise.
+    storing_device: torch.device, default="cpu"
+        Device on which to store the computed transport plan.
     verbose: bool, optional, defaults to False
         Log solving process.
@@ -292,7 +295,7 @@ def fit(
         verbose=verbose,
     )

-    self.pi = res["pi"].to_sparse_coo().detach().cpu()
+    self.pi = res["pi"].to_sparse_coo().detach().to(device=storing_device)
     self.loss = res["loss"]
     self.loss_val = res["loss_val"]
     self.loss_steps = res["loss_steps"]
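The sparse mapping gets the same treatment: its plan is converted to sparse COO layout and then moved to storing_device (also defaulting to "cpu"), so a plan solved on a GPU does not have to remain in GPU memory once fit() returns, which is presumably the point of the new argument, since full-resolution transport plans can be large.
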
Empty file added tests/__init__.py
Empty file.
17 changes: 17 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,17 @@
import torch
import pytest


def pytest_configure(config):
    config.addinivalue_line(
        "markers", "skip_if_no_mkl: Skip test if MKL support is not available."
    )


@pytest.fixture(autouse=True)
def check_mkl_availability(request):
    if (
        "skip_if_no_mkl" in request.keywords
        and not torch.backends.mkl.is_available()
    ):
        pytest.skip("Test requires MKL support which is not available.")
3 changes: 3 additions & 0 deletions tests/mappings/test_sparse_mapping.py
@@ -27,6 +27,7 @@
 callbacks = [None, lambda x: x["gamma"]]


+@pytest.mark.skip_if_no_mkl
 @pytest.mark.parametrize(
     "device,return_numpy,solver,callback",
     product(devices, return_numpys, solvers, callbacks),
@@ -95,6 +96,7 @@ def test_sparse_mapping(device, return_numpy, solver, callback):
     assert isinstance(target_features_on_source, torch.Tensor)


+@pytest.mark.skip_if_no_mkl
 @pytest.mark.parametrize(
     "device,sparse_layout,return_numpy",
     product(devices, sparse_layouts, return_numpys),
@@ -163,6 +165,7 @@ def test_fugw_sparse_with_init(device, sparse_layout, return_numpy):
     assert isinstance(target_features_on_source, torch.Tensor)


+@pytest.mark.skip_if_no_mkl
 @pytest.mark.parametrize(
     "validation", ["None", "features", "geometries", "Both"]
 )
2 changes: 2 additions & 0 deletions tests/scripts/test_coarse_to_fine.py
@@ -26,6 +26,7 @@
 return_numpys = [False, True]


+@pytest.mark.skip_if_no_mkl
 @pytest.mark.parametrize("return_numpy", product(return_numpys))
 def test_random_normalizing(return_numpy):
     _, _, _, embeddings = _init_mock_distribution(
@@ -53,6 +54,7 @@ def test_uniform_mesh_sampling():
     assert np.unique(sample).shape == (n_samples,)


+@pytest.mark.skip_if_no_mkl
 @pytest.mark.parametrize(
     "device,return_numpy", product(devices, return_numpys)
 )
2 changes: 2 additions & 0 deletions tests/solvers/test_sparse_solver.py
@@ -18,6 +18,7 @@


 # TODO: need to test sinkhorn
+@pytest.mark.skip_if_no_mkl
 @pytest.mark.parametrize(
     "solver,device,callback,alpha",
     product(["sinkhorn", "mm", "ibpp"], devices, callbacks, alphas),
@@ -143,6 +144,7 @@ def test_sparse_solvers(solver, device, callback, alpha):
     )


+@pytest.mark.skip_if_no_mkl
 @pytest.mark.parametrize(
     "validation,device",
     product(["None", "features", "geometries", "Both"], devices),
