diff --git a/src/fugw/mappings/dense.py b/src/fugw/mappings/dense.py index f3ede0f..18fe2de 100644 --- a/src/fugw/mappings/dense.py +++ b/src/fugw/mappings/dense.py @@ -270,7 +270,9 @@ def fit( return self - def transform(self, source_features, device="auto"): + def transform( + self, source_features, id_interpolation=False, device="auto" + ): """ Transport source feature maps using fitted OT plan. Use GPUs if available. @@ -279,6 +281,9 @@ def transform(self, source_features, device="auto"): ---------- source_features: ndarray(n_samples, n) or ndarray(n) Contrast map for source subject + id_interpolation: bool, optional, defaults to False + If True and source/target share the same geometry, + interpolate the transport plan with the identity. device: "auto" or torch.device If "auto": use first available GPU if it's available, CPU otherwise. @@ -312,11 +317,37 @@ def transform(self, source_features, device="auto"): source_features_tensor = _make_tensor(source_features, device=device) # Transform data - transformed_data = ( - (pi.T @ source_features_tensor.T / pi.sum(dim=0).reshape(-1, 1)) - .T.detach() - .cpu() - ) + if id_interpolation: + if pi.shape[0] != pi.shape[1]: + raise ValueError( + "Cannot interpolate with identity if source and target" + " have different geometries." + ) + transformed_data = ( + ( + ( + ( + pi.T + @ source_features_tensor.T + / pi.sum(dim=0).reshape(-1, 1) + ) + + source_features_tensor.T + ) + / 2 + ) + .T.detach() + .cpu() + ) + else: + transformed_data = ( + ( + pi.T + @ source_features_tensor.T + / pi.sum(dim=0).reshape(-1, 1) + ) + .T.detach() + .cpu() + ) # Free allocated GPU memory del pi, source_features_tensor