Skip to content

Commit

Permalink
Add id_interpolation to sparse mappings transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
pbarbarant committed Oct 20, 2023
1 parent 17265dd commit ac95ce8
Showing 1 changed file with 49 additions and 16 deletions.
65 changes: 49 additions & 16 deletions src/fugw/mappings/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,9 @@ def fit(

return self

def transform(self, source_features, device="auto"):
def transform(
self, source_features, id_interpolation=False, device="auto"
):
"""
Transport source feature maps using fitted OT plan.
Use GPUs if available.
Expand All @@ -314,6 +316,9 @@ def transform(self, source_features, device="auto"):
----------
source_features: ndarray(n_samples, n) or ndarray(n)
Contrast map for source subject
id_interpolation: bool, optional, defaults to False
If True and source/target share the same geometry,
interpolate the transport plan with the identity.
device: "auto" or torch.device
If "auto": use first available GPU if it's available,
CPU otherwise.
Expand Down Expand Up @@ -346,23 +351,51 @@ def transform(self, source_features, device="auto"):
)

# Transform data
transformed_data = (
(
torch.sparse.mm(
self.pi.to(device).transpose(0, 1),
source_features_tensor.T,
).to_dense()
/ (
torch.sparse.sum(self.pi.to(device), dim=0)
.to_dense()
.reshape(-1, 1)
# Add very small value to handle null rows
+ 1e-16
if id_interpolation:
if self.pi.shape[0] != self.pi.shape[1]:
raise ValueError(
"Cannot interpolate with identity if source and target"
" have different geometries."
)
transformed_data = (
(
(
torch.sparse.mm(
self.pi.to(device).transpose(0, 1),
source_features_tensor.T,
).to_dense()
/ (
torch.sparse.sum(self.pi.to(device), dim=0)
.to_dense()
.reshape(-1, 1)
# Add very small value to handle null rows
+ 1e-16
)
+ source_features_tensor.T
)
/ 2
)
.T.detach()
.cpu()
)
else:
transformed_data = (
(
torch.sparse.mm(
self.pi.to(device).transpose(0, 1),
source_features_tensor.T,
).to_dense()
/ (
torch.sparse.sum(self.pi.to(device), dim=0)
.to_dense()
.reshape(-1, 1)
# Add very small value to handle null rows
+ 1e-16
)
)
.T.detach()
.cpu()
)
.T.detach()
.cpu()
)

# Free allocated GPU memory
del source_features_tensor
Expand Down

0 comments on commit ac95ce8

Please sign in to comment.