Merge pull request #82 from alexisthual/feat/parcellations
Parcellations as inputs to FUGWSparse
pbarbarant authored Nov 5, 2024
2 parents bede443 + d4a9910 commit b77086e
Showing 2 changed files with 238 additions and 0 deletions.
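In short, this PR lets a parcellation (one integer label per voxel) be turned into a sparse block mask and passed as the initial plan of a FUGWSparse mapping. A minimal end-to-end sketch mirroring the test file below; the random features, embeddings, and label counts here are hypothetical placeholders, not real data:

import torch

from fugw.mappings import FUGWSparse
from fugw.scripts import piecewise

n_voxels, n_features, emb_dim = 100, 10, 5

# Hypothetical inputs; a real use case would supply actual data
source_features = torch.rand(n_features, n_voxels)
target_features = torch.rand(n_features, n_voxels)
source_embeddings = torch.rand(n_voxels, emb_dim)
target_embeddings = torch.rand(n_voxels, emb_dim)

# One parcel label per voxel, e.g. from a clustering of the source mesh
labels = torch.randint(0, 10, (n_voxels,))
init_plan = piecewise.compute_sparsity_mask(labels)

mapping = FUGWSparse()
mapping.fit(
    source_features,
    target_features,
    source_geometry_embedding=source_embeddings,
    target_geometry_embedding=target_embeddings,
    init_plan=init_plan,
)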
90 changes: 90 additions & 0 deletions src/fugw/scripts/piecewise.py
@@ -0,0 +1,90 @@
import torch
from sklearn.preprocessing import OneHotEncoder


def check_labels(labels: torch.Tensor) -> None:
    """
    Check that labels are a 1D tensor of integers.

    Parameters
    ----------
    labels: torch.Tensor
        Labels to check.

    Raises
    ------
    ValueError
        If labels are not a tensor or not a 1D tensor.
    TypeError
        If labels do not have an integer dtype.
    """
    if not torch.is_tensor(labels):
        raise ValueError(f"labels must be a tensor, got {type(labels)}.")
    if labels.dim() != 1:
        raise ValueError(f"labels must be a 1D tensor, got {labels.dim()}D.")
    if labels.dtype not in {
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
    }:
        raise TypeError(
            f"labels must be an integer tensor, got {labels.dtype}."
        )


def one_hot_encoding(labels: torch.Tensor) -> torch.Tensor:
    """
    Compute the one-hot encoding of the labels.

    Parameters
    ----------
    labels: torch.Tensor of size (n,)
        Cluster labels for each voxel.
        Must be a 1D integer tensor with c distinct values.

    Returns
    -------
    one_hot: torch.sparse_coo_tensor of size (n, c)
        One-hot encoding of the labels.
    """
    # Convert labels to strings so sklearn treats them as categories
    labels_categorical = labels.cpu().numpy().astype(str).reshape(-1, 1)
    # Use sklearn to compute the one-hot encoding
    encoder = OneHotEncoder(sparse_output=False)
    one_hot = encoder.fit_transform(labels_categorical)
    one_hot_tensor = torch.from_numpy(one_hot)
    return one_hot_tensor.to_sparse_coo().to(labels.device)


def compute_sparsity_mask(
    labels: torch.Tensor,
    device: str = "auto",
) -> torch.Tensor:
    """
    Compute sparsity mask from a parcellation.

    Parameters
    ----------
    labels: torch.Tensor of size (n,)
        Cluster labels for each voxel.
    device: "auto" or torch.device
        If "auto", use the first available GPU
        if there is one, and the CPU otherwise.

    Returns
    -------
    sparsity_mask: torch.sparse_coo_tensor of size (n, n)
        Sparsity mask used to initialize the fine mapping.
    """
    check_labels(labels)

    if device == "auto":
        if torch.cuda.is_available():
            device = torch.device("cuda", 0)
        else:
            device = torch.device("cpu")
    labels = labels.to(device)

    # Create a one-hot encoding of the voxels
    one_hot = one_hot_encoding(labels)
    return (one_hot @ one_hot.T).coalesce().to(torch.float32)
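The mask returned by compute_sparsity_mask is nonzero at entry (i, j) exactly when voxels i and j share a parcel label, so a parcellation induces a block-structured plan. A small usage sketch, with made-up labels and the output one would expect from the implementation above:

import torch

from fugw.scripts import piecewise

# Hypothetical parcellation: 6 voxels assigned to 3 parcels
labels = torch.tensor([0, 0, 1, 2, 2, 2])
mask = piecewise.compute_sparsity_mask(labels, device=torch.device("cpu"))

# Entry (i, j) is 1.0 iff labels[i] == labels[j]
print(mask.to_dense())
# tensor([[1., 1., 0., 0., 0., 0.],
#         [1., 1., 0., 0., 0., 0.],
#         [0., 0., 1., 0., 0., 0.],
#         [0., 0., 0., 1., 1., 1.],
#         [0., 0., 0., 1., 1., 1.],
#         [0., 0., 0., 1., 1., 1.]])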
148 changes: 148 additions & 0 deletions tests/scripts/test_piecewise.py
@@ -0,0 +1,148 @@
from itertools import product

import numpy as np
import pytest
import torch

from fugw.scripts import piecewise
from fugw.mappings import FUGWSparse
from fugw.utils import _init_mock_distribution

np.random.seed(0)
torch.manual_seed(0)

n_voxels = 100
n_samples_source = 50
n_samples_target = 45
n_features_train = 10
n_features_test = 5
n_pieces = 10

devices = [torch.device("cpu")]
if torch.cuda.is_available():
    devices.append(torch.device("cuda:0"))

return_numpys = [False, True]


@pytest.mark.skip_if_no_mkl
def test_one_hot_encoding():
    labels = torch.randint(0, n_pieces, (n_voxels,))
    one_hot = piecewise.one_hot_encoding(labels)
    assert one_hot.shape == (n_voxels, n_pieces)


@pytest.mark.skip_if_no_mkl
@pytest.mark.parametrize(
"device",
devices,
)
def test_compute_sparsity_mask(device):
    labels = torch.tensor([0, 1, 1], device=device)
    mask = piecewise.compute_sparsity_mask(labels, device=device)
    assert mask.shape == (3, 3)
    assert mask.is_sparse
    assert torch.allclose(
        mask.to_dense(),
        torch.tensor(
            [[1.0, 0, 0], [0, 1.0, 1.0], [0, 1.0, 1.0]], device=device
        ),
    )

    labels = torch.randint(0, n_pieces, (n_voxels,))
    sparsity_mask = piecewise.compute_sparsity_mask(labels)
    assert sparsity_mask.shape == (n_voxels, n_voxels)


@pytest.mark.skip_if_no_mkl
@pytest.mark.parametrize(
"device,return_numpy",
product(devices, return_numpys),
)
def test_piecewise(device, return_numpy):
    source_weights, source_features, source_geometry, source_embeddings = (
        _init_mock_distribution(
            n_features_train, n_voxels, return_numpy=return_numpy
        )
    )
    target_weights, target_features, target_geometry, target_embeddings = (
        _init_mock_distribution(
            n_features_train, n_voxels, return_numpy=return_numpy
        )
    )

    labels = torch.randint(0, n_pieces, (n_voxels,))
    init_plan = piecewise.compute_sparsity_mask(
        labels=labels,
        device=device,
    )

    piecewise_mapping = FUGWSparse()
    piecewise_mapping.fit(
        source_features,
        target_features,
        source_geometry_embedding=source_embeddings,
        target_geometry_embedding=target_embeddings,
        source_weights=source_weights,
        target_weights=target_weights,
        init_plan=init_plan,
        device=device,
        verbose=True,
    )

    assert piecewise_mapping.pi.shape == (n_voxels, n_voxels)

    # Use trained model to transport new features
    # 1. with numpy arrays
    source_features_test = np.random.rand(n_features_test, n_voxels)
    target_features_test = np.random.rand(n_features_test, n_voxels)
    source_features_on_target = piecewise_mapping.transform(
        source_features_test
    )
    assert source_features_on_target.shape == target_features_test.shape
    assert isinstance(source_features_on_target, np.ndarray)
    target_features_on_source = piecewise_mapping.inverse_transform(
        target_features_test
    )
    assert target_features_on_source.shape == source_features_test.shape
    assert isinstance(target_features_on_source, np.ndarray)

    source_features_test = np.random.rand(n_voxels)
    target_features_test = np.random.rand(n_voxels)
    source_features_on_target = piecewise_mapping.transform(
        source_features_test
    )
    assert source_features_on_target.shape == target_features_test.shape
    assert isinstance(source_features_on_target, np.ndarray)
    target_features_on_source = piecewise_mapping.inverse_transform(
        target_features_test
    )
    assert target_features_on_source.shape == source_features_test.shape
    assert isinstance(target_features_on_source, np.ndarray)

    # 2. with torch tensors
    source_features_test = torch.rand(n_features_test, n_voxels)
    target_features_test = torch.rand(n_features_test, n_voxels)
    source_features_on_target = piecewise_mapping.transform(
        source_features_test
    )
    assert source_features_on_target.shape == target_features_test.shape
    assert isinstance(source_features_on_target, torch.Tensor)
    target_features_on_source = piecewise_mapping.inverse_transform(
        target_features_test
    )
    assert target_features_on_source.shape == source_features_test.shape
    assert isinstance(target_features_on_source, torch.Tensor)

    source_features_test = torch.rand(n_voxels)
    target_features_test = torch.rand(n_voxels)
    source_features_on_target = piecewise_mapping.transform(
        source_features_test
    )
    assert source_features_on_target.shape == target_features_test.shape
    assert isinstance(source_features_on_target, torch.Tensor)
    target_features_on_source = piecewise_mapping.inverse_transform(
        target_features_test
    )
    assert target_features_on_source.shape == source_features_test.shape
    assert isinstance(target_features_on_source, torch.Tensor)
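For intuition about why this initialization pays off: with n voxels split into k roughly equal parcels, the mask keeps about k * (n/k)^2 = n^2 / k nonzero entries out of n^2, i.e. a density of roughly 1/k. A small sanity check, reusing the constants from the tests above:

import torch

from fugw.scripts import piecewise

n_voxels, n_pieces = 100, 10
labels = torch.randint(0, n_pieces, (n_voxels,))
mask = piecewise.compute_sparsity_mask(labels, device=torch.device("cpu"))

# The mask is coalesced, so the number of stored values equals its nnz
density = mask.values().numel() / (n_voxels * n_voxels)
print(f"mask density: {density:.3f}")  # close to 1 / n_pieces for balanced labels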
