From d4b3e36421d7e83f316bfed4c7a68dffadfbace9 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Mon, 5 Aug 2024 17:14:11 +0200 Subject: [PATCH] feat: Add device parameter to compute_sparsity_mask function --- src/fugw/scripts/coarse_to_fine.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/fugw/scripts/coarse_to_fine.py b/src/fugw/scripts/coarse_to_fine.py index 8cf029a..43605a0 100644 --- a/src/fugw/scripts/coarse_to_fine.py +++ b/src/fugw/scripts/coarse_to_fine.py @@ -273,6 +273,7 @@ def compute_sparsity_mask( source_selection_radius=1, target_selection_radius=1, method="topk", + device="auto", ): """ Compute sparsity mask from coarse mapping. @@ -301,12 +302,20 @@ def compute_sparsity_mask( Method used to select pairs of source and target features whose neighbourhoods will be used to define the sparsity mask of the solution + device: "auto" or torch.device + if "auto": use first available gpu if it's available, + cpu otherwise. Returns ------- sparsity_mask: torch.sparse_coo_tensor of size (n, m) Sparsity mask used to initialize the fine mapping. """ + if device == "auto": + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + else: + device = torch.device("cpu") if method == "quantile": # Method 1: keep first percentile threshold = np.percentile(coarse_mapping.pi, 99.95) @@ -334,29 +343,17 @@ def compute_sparsity_mask( # which vertex is close to which sampled point N_source = get_neighbourhood_matrix( source_geometry_embeddings, source_sample, source_selection_radius - ) + ).to(device) N_target = get_neighbourhood_matrix( target_geometry_embeddings, target_sample, target_selection_radius - ) + ).to(device) # b. cluster matrices that encode # which sampled point belongs to which cluster - C_source = get_cluster_matrix(rows, source_sample.shape[0]) - C_target = get_cluster_matrix(cols, target_sample.shape[0]) + C_source = get_cluster_matrix(rows, source_sample.shape[0]).to(device) + C_target = get_cluster_matrix(cols, target_sample.shape[0]).to(device) sparsity_mask = (N_source @ C_source) @ (N_target @ C_target).T - # Check for any empty row/column in the sparsity mask - rows, cols = sparsity_mask.indices() - if ( - len(np.unique(rows)) < sparsity_mask.shape[0] - or len(np.unique(cols)) < sparsity_mask.shape[1] - ): - raise ValueError( - "Sparsity mask contains empty rows or columns. " - "Please consider increasing the selection radius " - "or increasing the number of samples." - ) - return sparsity_mask @@ -536,6 +533,7 @@ def fit( source_selection_radius=source_selection_radius, target_selection_radius=target_selection_radius, method=coarse_pairs_selection_method, + device=device, ) # Define init plan from sparsity mask