Merge pull request #82 from alexisthual/feat/parcellations
Parcellations as inputs to FUGWSparse
Showing 2 changed files with 238 additions and 0 deletions.
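At a glance, the new `piecewise` helpers turn a parcellation (one integer label per voxel) into a block-sparse mask that can seed a `FUGWSparse` mapping. A minimal sketch of the intended behavior, condensed from the tests in this PR (`device` is forced to CPU here for reproducibility):

```python
import torch
from fugw.scripts import piecewise

# Three voxels in two parcels: voxel 0 alone, voxels 1 and 2 together
labels = torch.tensor([0, 1, 1])
mask = piecewise.compute_sparsity_mask(labels, device=torch.device("cpu"))
print(mask.to_dense())
# tensor([[1., 0., 0.],
#         [0., 1., 1.],
#         [0., 1., 1.]])
```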
First file — the new `piecewise` module (imported in the tests as `fugw.scripts.piecewise`):

@@ -0,0 +1,90 @@

```python
import torch
from sklearn.preprocessing import OneHotEncoder


def check_labels(labels: torch.Tensor) -> None:
    """
    Check that labels are a 1D tensor of integers.

    Parameters
    ----------
    labels: torch.Tensor
        Labels to check.

    Raises
    ------
    ValueError
        If labels are not a tensor or not 1D.
    TypeError
        If labels do not have an integer dtype.
    """
    if not torch.is_tensor(labels):
        raise ValueError(f"labels must be a tensor, got {type(labels)}.")
    if labels.dim() != 1:
        raise ValueError(f"labels must be a 1D tensor, got {labels.dim()}D.")
    if labels.dtype not in {
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
    }:
        raise TypeError(
            f"labels must be an integer tensor, got {labels.dtype}."
        )


def one_hot_encoding(labels: torch.Tensor) -> torch.Tensor:
    """
    Compute the one-hot encoding of the labels.

    Parameters
    ----------
    labels: torch.Tensor of size (n,)
        Cluster labels for each voxel.
        Must be a 1D integer tensor with c distinct values.

    Returns
    -------
    one_hot: torch.sparse_coo_tensor of size (n, c)
        One-hot encoding of the labels.
    """
    # Convert labels to strings so sklearn treats them as categorical
    labels_categorical = labels.cpu().numpy().astype(str).reshape(-1, 1)
    # Use sklearn to compute the one-hot encoding
    encoder = OneHotEncoder(sparse_output=False)
    one_hot = encoder.fit_transform(labels_categorical)
    one_hot_tensor = torch.from_numpy(one_hot)
    return one_hot_tensor.to_sparse_coo().to(labels.device)


def compute_sparsity_mask(
    labels: torch.Tensor,
    device: str = "auto",
) -> torch.Tensor:
    """
    Compute a sparsity mask from a parcellation
    (in place of a coarse mapping).

    Parameters
    ----------
    labels: torch.Tensor of size (n,)
        Cluster labels for each voxel.
    device: "auto" or torch.device
        If "auto", use the first available GPU
        if there is one, the CPU otherwise.

    Returns
    -------
    sparsity_mask: torch.sparse_coo_tensor of size (n, n)
        Sparsity mask used to initialize the fine mapping.
    """
    check_labels(labels)

    if device == "auto":
        if torch.cuda.is_available():
            device = torch.device("cuda", 0)
        else:
            device = torch.device("cpu")
    labels = labels.to(device)

    # Create a one-hot encoding of the voxels;
    # one_hot @ one_hot.T is nonzero exactly where two voxels share a label
    one_hot = one_hot_encoding(labels)
    return (one_hot @ one_hot.T).coalesce().to(torch.float32)
```
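Downstream, the mask is passed to `FUGWSparse.fit()` as `init_plan`, constraining the fine mapping to voxel pairs that share a parcel. A usage sketch condensed from the test file that follows (`_init_mock_distribution` is the package's mock-data helper; real use would supply actual features and geometry embeddings):

```python
import torch
from fugw.mappings import FUGWSparse
from fugw.scripts import piecewise
from fugw.utils import _init_mock_distribution

n_voxels, n_pieces = 100, 10
source_weights, source_features, _, source_embeddings = \
    _init_mock_distribution(10, n_voxels)
target_weights, target_features, _, target_embeddings = \
    _init_mock_distribution(10, n_voxels)

# One parcel label per voxel
labels = torch.randint(0, n_pieces, (n_voxels,))
init_plan = piecewise.compute_sparsity_mask(labels)

mapping = FUGWSparse()
mapping.fit(
    source_features,
    target_features,
    source_geometry_embedding=source_embeddings,
    target_geometry_embedding=target_embeddings,
    source_weights=source_weights,
    target_weights=target_weights,
    init_plan=init_plan,
)
assert mapping.pi.shape == (n_voxels, n_voxels)
```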
Second file — tests for the new helpers:

@@ -0,0 +1,148 @@

```python
from itertools import product

import numpy as np
import pytest
import torch

from fugw.scripts import piecewise
from fugw.mappings import FUGWSparse
from fugw.utils import _init_mock_distribution

np.random.seed(0)
torch.manual_seed(0)

n_voxels = 100
n_samples_source = 50
n_samples_target = 45
n_features_train = 10
n_features_test = 5
n_pieces = 10

devices = [torch.device("cpu")]
if torch.cuda.is_available():
    devices.append(torch.device("cuda:0"))

return_numpys = [False, True]


@pytest.mark.skip_if_no_mkl
def test_one_hot_encoding():
    labels = torch.randint(0, n_pieces, (n_voxels,))
    one_hot = piecewise.one_hot_encoding(labels)
    assert one_hot.shape == (n_voxels, n_pieces)


@pytest.mark.skip_if_no_mkl
@pytest.mark.parametrize(
    "device",
    devices,
)
def test_compute_sparsity_mask(device):
    labels = torch.tensor([0, 1, 1], device=device)
    mask = piecewise.compute_sparsity_mask(labels, device=device)
    assert mask.shape == (3, 3)
    assert mask.is_sparse
    assert torch.allclose(
        mask.to_dense(),
        torch.tensor(
            [[1.0, 0, 0], [0, 1.0, 1.0], [0, 1.0, 1.0]], device=device
        ),
    )

    labels = torch.randint(0, n_pieces, (n_voxels,))
    sparsity_mask = piecewise.compute_sparsity_mask(labels)
    assert sparsity_mask.shape == (n_voxels, n_voxels)


@pytest.mark.skip_if_no_mkl
@pytest.mark.parametrize(
    "device,return_numpy",
    product(devices, return_numpys),
)
def test_piecewise(device, return_numpy):
    source_weights, source_features, source_geometry, source_embeddings = (
        _init_mock_distribution(
            n_features_train, n_voxels, return_numpy=return_numpy
        )
    )
    target_weights, target_features, target_geometry, target_embeddings = (
        _init_mock_distribution(
            n_features_train, n_voxels, return_numpy=return_numpy
        )
    )

    labels = torch.randint(0, n_pieces, (n_voxels,))
    init_plan = piecewise.compute_sparsity_mask(
        labels=labels,
        device=device,
    )

    piecewise_mapping = FUGWSparse()
    piecewise_mapping.fit(
        source_features,
        target_features,
        source_geometry_embedding=source_embeddings,
        target_geometry_embedding=target_embeddings,
        source_weights=source_weights,
        target_weights=target_weights,
        init_plan=init_plan,
        device=device,
        verbose=True,
    )

    assert piecewise_mapping.pi.shape == (n_voxels, n_voxels)

    # Use trained model to transport new features
    # 1. with numpy arrays
    source_features_test = np.random.rand(n_features_test, n_voxels)
    target_features_test = np.random.rand(n_features_test, n_voxels)
    source_features_on_target = piecewise_mapping.transform(
        source_features_test
    )
    assert source_features_on_target.shape == target_features_test.shape
    assert isinstance(source_features_on_target, np.ndarray)
    target_features_on_source = piecewise_mapping.inverse_transform(
        target_features_test
    )
    assert target_features_on_source.shape == source_features_test.shape
    assert isinstance(target_features_on_source, np.ndarray)

    source_features_test = np.random.rand(n_voxels)
    target_features_test = np.random.rand(n_voxels)
    source_features_on_target = piecewise_mapping.transform(
        source_features_test
    )
    assert source_features_on_target.shape == target_features_test.shape
    assert isinstance(source_features_on_target, np.ndarray)
    target_features_on_source = piecewise_mapping.inverse_transform(
        target_features_test
    )
    assert target_features_on_source.shape == source_features_test.shape
    assert isinstance(target_features_on_source, np.ndarray)

    # 2. with torch tensors
    source_features_test = torch.rand(n_features_test, n_voxels)
    target_features_test = torch.rand(n_features_test, n_voxels)
    source_features_on_target = piecewise_mapping.transform(
        source_features_test
    )
    assert source_features_on_target.shape == target_features_test.shape
    assert isinstance(source_features_on_target, torch.Tensor)
    target_features_on_source = piecewise_mapping.inverse_transform(
        target_features_test
    )
    assert target_features_on_source.shape == source_features_test.shape
    assert isinstance(target_features_on_source, torch.Tensor)

    source_features_test = torch.rand(n_voxels)
    target_features_test = torch.rand(n_voxels)
    source_features_on_target = piecewise_mapping.transform(
        source_features_test
    )
    assert source_features_on_target.shape == target_features_test.shape
    assert isinstance(source_features_on_target, torch.Tensor)
    target_features_on_source = piecewise_mapping.inverse_transform(
        target_features_test
    )
    assert target_features_on_source.shape == source_features_test.shape
    assert isinstance(target_features_on_source, torch.Tensor)
```
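As a sanity check (not part of the diff): by construction, the mask coincides with the pairwise label-equality matrix, so its density is the sum of squared parcel sizes over n², roughly 1/k for k balanced parcels:

```python
import torch
from fugw.scripts import piecewise

labels = torch.randint(0, 10, (100,))
mask = piecewise.compute_sparsity_mask(labels, device=torch.device("cpu"))
# (one_hot @ one_hot.T)[i, j] is 1 exactly when labels[i] == labels[j]
expected = (labels[:, None] == labels[None, :]).to(torch.float32)
assert torch.equal(mask.to_dense(), expected)

# Density = sum of squared parcel sizes / n^2 (~0.1 for 10 balanced parcels)
counts = torch.bincount(labels)
print((counts.float() ** 2).sum() / labels.numel() ** 2)
```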