From 3f7981d135f10919ce8b458babf06d2d8988c042 Mon Sep 17 00:00:00 2001
From: Pierre-Louis Barbarant
Date: Mon, 4 Nov 2024 16:39:31 +0100
Subject: [PATCH 1/3] Add piecewise.py with label checking and one-hot encoding functions

---
 src/fugw/scripts/piecewise.py | 90 +++++++++++++++++++++++++++++++++++
 1 file changed, 90 insertions(+)
 create mode 100644 src/fugw/scripts/piecewise.py

diff --git a/src/fugw/scripts/piecewise.py b/src/fugw/scripts/piecewise.py
new file mode 100644
index 0000000..fc4376a
--- /dev/null
+++ b/src/fugw/scripts/piecewise.py
@@ -0,0 +1,90 @@
+import torch
+from sklearn.preprocessing import OneHotEncoder
+
+
+def check_labels(labels):
+    """
+    Check that labels are a 1D tensor of integers.
+
+    Parameters
+    ----------
+    labels: torch.Tensor
+        Labels to check.
+
+    Raises
+    ------
+    ValueError
+        If labels are not a 1D tensor of integers.
+    """
+    if not torch.is_tensor(labels):
+        raise ValueError(f"labels must be a tensor, got {type(labels)}.")
+    if labels.dim() != 1:
+        raise ValueError(f"labels must be a 1D tensor, got {labels.dim()}D.")
+    if labels.dtype not in {
+        torch.uint8,
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.int64,
+    }:
+        raise TypeError(
+            f"labels must be an integer tensor, got {labels.dtype}."
+        )
+
+
+def one_hot_encoding(labels):
+    """
+    Compute one-hot encoding of the labels.
+
+    Parameters
+    ----------
+    labels: torch.Tensor of size (n,)
+        Cluster labels for each voxel.
+        Must be a 1D integer tensor taking c distinct values.
+
+    Returns
+    -------
+    one_hot: torch.sparse_coo_tensor of size (n, c)
+        One-hot encoding of the labels.
+    """
+    # Convert labels to string
+    labels_categorical = labels.cpu().numpy().astype(str).reshape(-1, 1)
+    # Use sklearn to compute the one-hot encoding
+    encoder = OneHotEncoder(sparse_output=False)
+    one_hot = encoder.fit_transform(labels_categorical)
+    one_hot_tensor = torch.from_numpy(one_hot)
+    return one_hot_tensor.to_sparse_coo().to(labels.device)
+
+
+def compute_sparsity_mask(
+    labels,
+    device="auto",
+):
+    """
+    Compute sparsity mask from coarse mapping.
+
+    Parameters
+    ----------
+    labels: torch.Tensor of size (n,)
+        Cluster labels for each voxel.
+    device: "auto" or torch.device
+        if "auto": use the first available GPU if there is one,
+        cpu otherwise.
+
+    Returns
+    -------
+    sparsity_mask: torch.sparse_coo_tensor of size (n, n)
+        Sparsity mask used to initialize the fine mapping.
+    """
+    check_labels(labels)
+
+    if device == "auto":
+        if torch.cuda.is_available():
+            device = torch.device("cuda", 0)
+        else:
+            device = torch.device("cpu")
+    labels = labels.to(device)
+
+    # Create a one-hot encoding of the voxels
+    one_hot = one_hot_encoding(labels)
+    return (one_hot @ one_hot.T).coalesce().to(torch.float32)
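
Two voxels are connected in the mask above exactly when they carry the same label, so `compute_sparsity_mask` returns an n x n block matrix built from the parcel one-hot encoding. A minimal usage sketch, mirroring the toy case exercised by the tests in the next patch (expected output shown for the label vector [0, 1, 1]):

    import torch
    from fugw.scripts import piecewise

    # Three voxels split into two parcels: voxel 0 alone, voxels 1 and 2 together
    labels = torch.tensor([0, 1, 1])
    mask = piecewise.compute_sparsity_mask(labels, device=torch.device("cpu"))
    print(mask.to_dense())
    # tensor([[1., 0., 0.],
    #         [0., 1., 1.],
    #         [0., 1., 1.]])
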
From 622174c0c3eedecd355299211874b514e0d672a6 Mon Sep 17 00:00:00 2001
From: Pierre-Louis Barbarant
Date: Mon, 4 Nov 2024 16:39:40 +0100
Subject: [PATCH 2/3] Add unit tests for piecewise functions including one-hot encoding and sparsity mask

---
 tests/scripts/test_piecewise.py | 148 ++++++++++++++++++++++++++++++++
 1 file changed, 148 insertions(+)
 create mode 100644 tests/scripts/test_piecewise.py

diff --git a/tests/scripts/test_piecewise.py b/tests/scripts/test_piecewise.py
new file mode 100644
index 0000000..0b52b80
--- /dev/null
+++ b/tests/scripts/test_piecewise.py
@@ -0,0 +1,148 @@
+from itertools import product
+
+import numpy as np
+import pytest
+import torch
+
+from fugw.scripts import piecewise
+from fugw.mappings import FUGWSparse
+from fugw.utils import _init_mock_distribution
+
+np.random.seed(0)
+torch.manual_seed(0)
+
+n_voxels = 100
+n_samples_source = 50
+n_samples_target = 45
+n_features_train = 10
+n_features_test = 5
+n_pieces = 10
+
+devices = [torch.device("cpu")]
+if torch.cuda.is_available():
+    devices.append(torch.device("cuda:0"))
+
+return_numpys = [False, True]
+
+
+@pytest.mark.skip_if_no_mkl
+def test_one_hot_encoding():
+    labels = torch.randint(0, n_pieces, (n_voxels,))
+    one_hot = piecewise.one_hot_encoding(labels)
+    assert one_hot.shape == (n_voxels, n_pieces)
+
+
+@pytest.mark.skip_if_no_mkl
+@pytest.mark.parametrize(
+    "device",
+    devices,
+)
+def test_compute_sparsity_mask(device):
+    labels = torch.tensor([0, 1, 1], device=device)
+    mask = piecewise.compute_sparsity_mask(labels, device=device)
+    assert mask.shape == (3, 3)
+    assert mask.is_sparse
+    assert torch.allclose(
+        mask.to_dense(),
+        torch.tensor(
+            [[1.0, 0, 0], [0, 1.0, 1.0], [0, 1.0, 1.0]], device=device
+        ),
+    )
+
+    labels = torch.randint(0, n_pieces, (n_voxels,))
+    sparsity_mask = piecewise.compute_sparsity_mask(labels)
+    assert sparsity_mask.shape == (n_voxels, n_voxels)
+
+
+@pytest.mark.skip_if_no_mkl
+@pytest.mark.parametrize(
+    "device,return_numpy",
+    product(devices, return_numpys),
+)
+def test_piecewise(device, return_numpy):
+    source_weights, source_features, source_geometry, source_embeddings = (
+        _init_mock_distribution(
+            n_features_train, n_voxels, return_numpy=return_numpy
+        )
+    )
+    target_weights, target_features, target_geometry, target_embeddings = (
+        _init_mock_distribution(
+            n_features_train, n_voxels, return_numpy=return_numpy
+        )
+    )
+
+    labels = torch.randint(0, n_pieces, (n_voxels,))
+    init_plan = piecewise.compute_sparsity_mask(
+        labels=labels,
+        device=device,
+    )
+
+    piecewise_mapping = FUGWSparse()
+    piecewise_mapping.fit(
+        source_features,
+        target_features,
+        source_geometry_embedding=source_embeddings,
+        target_geometry_embedding=target_embeddings,
+        source_weights=source_weights,
+        target_weights=target_weights,
+        init_plan=init_plan,
+        device=device,
+        verbose=True,
+    )
+
+    assert piecewise_mapping.pi.shape == (n_voxels, n_voxels)
+
+    # Use trained model to transport new features
+    # 1. with numpy arrays
+    source_features_test = np.random.rand(n_features_test, n_voxels)
+    target_features_test = np.random.rand(n_features_test, n_voxels)
+    source_features_on_target = piecewise_mapping.transform(
+        source_features_test
+    )
+    assert source_features_on_target.shape == target_features_test.shape
+    assert isinstance(source_features_on_target, np.ndarray)
+    target_features_on_source = piecewise_mapping.inverse_transform(
+        target_features_test
+    )
+    assert target_features_on_source.shape == source_features_test.shape
+    assert isinstance(target_features_on_source, np.ndarray)
+
+    source_features_test = np.random.rand(n_voxels)
+    target_features_test = np.random.rand(n_voxels)
+    source_features_on_target = piecewise_mapping.transform(
+        source_features_test
+    )
+    assert source_features_on_target.shape == target_features_test.shape
+    assert isinstance(source_features_on_target, np.ndarray)
+    target_features_on_source = piecewise_mapping.inverse_transform(
+        target_features_test
+    )
+    assert target_features_on_source.shape == source_features_test.shape
+    assert isinstance(target_features_on_source, np.ndarray)
+
+    # 2. with torch tensors
+    source_features_test = torch.rand(n_features_test, n_voxels)
+    target_features_test = torch.rand(n_features_test, n_voxels)
+    source_features_on_target = piecewise_mapping.transform(
+        source_features_test
+    )
+    assert source_features_on_target.shape == target_features_test.shape
+    assert isinstance(source_features_on_target, torch.Tensor)
+    target_features_on_source = piecewise_mapping.inverse_transform(
+        target_features_test
+    )
+    assert target_features_on_source.shape == source_features_test.shape
+    assert isinstance(target_features_on_source, torch.Tensor)
+
+    source_features_test = torch.rand(n_voxels)
+    target_features_test = torch.rand(n_voxels)
+    source_features_on_target = piecewise_mapping.transform(
+        source_features_test
+    )
+    assert source_features_on_target.shape == target_features_test.shape
+    assert isinstance(source_features_on_target, torch.Tensor)
+    target_features_on_source = piecewise_mapping.inverse_transform(
+        target_features_test
+    )
+    assert target_features_on_source.shape == source_features_test.shape
+    assert isinstance(target_features_on_source, torch.Tensor)
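
The test above also shows the intended end-to-end flow: the label-derived mask is passed as `init_plan` to `FUGWSparse.fit`, seeding the fine transport plan on voxel pairs that share a parcel label. A condensed sketch of that flow, reusing the private test helper `_init_mock_distribution` for mock data (in practice these would be real features and geometry embeddings); optional arguments such as `device` and `verbose` are left at their defaults:

    import torch
    from fugw.mappings import FUGWSparse
    from fugw.scripts import piecewise
    from fugw.utils import _init_mock_distribution

    n_voxels, n_pieces, n_features = 100, 10, 10
    source_weights, source_features, _, source_embeddings = (
        _init_mock_distribution(n_features, n_voxels)
    )
    target_weights, target_features, _, target_embeddings = (
        _init_mock_distribution(n_features, n_voxels)
    )

    # Piecewise initialization: one parcel label per voxel
    labels = torch.randint(0, n_pieces, (n_voxels,))
    init_plan = piecewise.compute_sparsity_mask(labels)

    mapping = FUGWSparse()
    mapping.fit(
        source_features,
        target_features,
        source_geometry_embedding=source_embeddings,
        target_geometry_embedding=target_embeddings,
        source_weights=source_weights,
        target_weights=target_weights,
        init_plan=init_plan,
    )
    # Transport unseen features from source to target
    new_features_on_target = mapping.transform(torch.rand(5, n_voxels))
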
From d4a9910b172049faac48e94a100444bdb0f8a403 Mon Sep 17 00:00:00 2001
From: Pierre-Louis Barbarant
Date: Mon, 4 Nov 2024 18:37:38 +0100
Subject: [PATCH 3/3] Add type hints for piecewise.py

---
 src/fugw/scripts/piecewise.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/fugw/scripts/piecewise.py b/src/fugw/scripts/piecewise.py
index fc4376a..0bf02a1 100644
--- a/src/fugw/scripts/piecewise.py
+++ b/src/fugw/scripts/piecewise.py
@@ -2,7 +2,7 @@
 from sklearn.preprocessing import OneHotEncoder
 
 
-def check_labels(labels):
+def check_labels(labels: torch.Tensor) -> None:
     """
     Check that labels are a 1D tensor of integers.
 
@@ -32,7 +32,7 @@ def check_labels(labels):
         )
 
 
-def one_hot_encoding(labels):
+def one_hot_encoding(labels: torch.Tensor) -> torch.Tensor:
     """
     Compute one-hot encoding of the labels.
 
@@ -57,9 +57,9 @@ def one_hot_encoding(labels):
 
 
 def compute_sparsity_mask(
-    labels,
-    device="auto",
-):
+    labels: torch.Tensor,
+    device: str = "auto",
+) -> torch.Tensor:
     """
     Compute sparsity mask from coarse mapping.
 
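
One remark on the new hints: the docstring and the tests accept `device` as either the string "auto" or a `torch.device`, while the annotation narrows it to `str`. A hypothetical wider signature matching the documented contract (not part of this series) could read:

    from typing import Union

    import torch

    def compute_sparsity_mask(
        labels: torch.Tensor,
        device: Union[str, torch.device] = "auto",
    ) -> torch.Tensor:
        ...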