diff --git a/src/fugw/mappings/sparse_barycenter.py b/src/fugw/mappings/sparse_barycenter.py index b75f058..caf3464 100644 --- a/src/fugw/mappings/sparse_barycenter.py +++ b/src/fugw/mappings/sparse_barycenter.py @@ -52,6 +52,10 @@ def update_barycenter_features(plans, weights_list, features_list, device): else: barycenter_features += acc + # Check for NaN values in the barycenter features + if torch.isnan(barycenter_features).any(): + raise ValueError("Barycenter features contain NaN values") + return barycenter_features.T @staticmethod