From ac95ce8434d0cb4177deddd72f42b1e2f8392c21 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 21 Oct 2023 01:29:42 +0200 Subject: [PATCH] Add id_interpolation to sparse mappings transforms --- src/fugw/mappings/sparse.py | 65 ++++++++++++++++++++++++++++--------- 1 file changed, 49 insertions(+), 16 deletions(-) diff --git a/src/fugw/mappings/sparse.py b/src/fugw/mappings/sparse.py index d494099..04be031 100644 --- a/src/fugw/mappings/sparse.py +++ b/src/fugw/mappings/sparse.py @@ -305,7 +305,9 @@ def fit( return self - def transform(self, source_features, device="auto"): + def transform( + self, source_features, id_interpolation=False, device="auto" + ): """ Transport source feature maps using fitted OT plan. Use GPUs if available. @@ -314,6 +316,9 @@ def transform(self, source_features, device="auto"): ---------- source_features: ndarray(n_samples, n) or ndarray(n) Contrast map for source subject + id_interpolation: bool, optional, defaults to False + If True and source/target share the same geometry, + interpolate the transport plan with the identity. device: "auto" or torch.device If "auto": use first available GPU if it's available, CPU otherwise. @@ -346,23 +351,51 @@ def transform(self, source_features, device="auto"): ) # Transform data - transformed_data = ( - ( - torch.sparse.mm( - self.pi.to(device).transpose(0, 1), - source_features_tensor.T, - ).to_dense() - / ( - torch.sparse.sum(self.pi.to(device), dim=0) - .to_dense() - .reshape(-1, 1) - # Add very small value to handle null rows - + 1e-16 + if id_interpolation: + if self.pi.shape[0] != self.pi.shape[1]: + raise ValueError( + "Cannot interpolate with identity if source and target" + " have different geometries." + ) + transformed_data = ( + ( + ( + torch.sparse.mm( + self.pi.to(device).transpose(0, 1), + source_features_tensor.T, + ).to_dense() + / ( + torch.sparse.sum(self.pi.to(device), dim=0) + .to_dense() + .reshape(-1, 1) + # Add very small value to handle null rows + + 1e-16 + ) + + source_features_tensor.T + ) + / 2 ) + .T.detach() + .cpu() + ) + else: + transformed_data = ( + ( + torch.sparse.mm( + self.pi.to(device).transpose(0, 1), + source_features_tensor.T, + ).to_dense() + / ( + torch.sparse.sum(self.pi.to(device), dim=0) + .to_dense() + .reshape(-1, 1) + # Add very small value to handle null rows + + 1e-16 + ) + ) + .T.detach() + .cpu() ) - .T.detach() - .cpu() - ) # Free allocated GPU memory del source_features_tensor