Skip to content

Commit

Permalink
Add id_interpolation to dense mappings transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
pbarbarant committed Oct 20, 2023
1 parent 7df8231 commit 17265dd
Showing 1 changed file with 37 additions and 6 deletions.
43 changes: 37 additions & 6 deletions src/fugw/mappings/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,9 @@ def fit(

return self

def transform(self, source_features, device="auto"):
def transform(
self, source_features, id_interpolation=False, device="auto"
):
"""
Transport source feature maps using fitted OT plan.
Use GPUs if available.
Expand All @@ -279,6 +281,9 @@ def transform(self, source_features, device="auto"):
----------
source_features: ndarray(n_samples, n) or ndarray(n)
Contrast map for source subject
id_interpolation: bool, optional, defaults to False
If True and source/target share the same geometry,
interpolate the transport plan with the identity.
device: "auto" or torch.device
If "auto": use first available GPU if it's available,
CPU otherwise.
Expand Down Expand Up @@ -312,11 +317,37 @@ def transform(self, source_features, device="auto"):
source_features_tensor = _make_tensor(source_features, device=device)

# Transform data
transformed_data = (
(pi.T @ source_features_tensor.T / pi.sum(dim=0).reshape(-1, 1))
.T.detach()
.cpu()
)
if id_interpolation:
if pi.shape[0] != pi.shape[1]:
raise ValueError(
"Cannot interpolate with identity if source and target"
" have different geometries."
)
transformed_data = (
(
(
(
pi.T
@ source_features_tensor.T
/ pi.sum(dim=0).reshape(-1, 1)
)
+ source_features_tensor.T
)
/ 2
)
.T.detach()
.cpu()
)
else:
transformed_data = (
(
pi.T
@ source_features_tensor.T
/ pi.sum(dim=0).reshape(-1, 1)
)
.T.detach()
.cpu()
)

# Free allocated GPU memory
del pi, source_features_tensor
Expand Down

0 comments on commit 17265dd

Please sign in to comment.