Skip to content

Commit

Permalink
feat: Add device parameter to compute_sparsity_mask function
Browse files Browse the repository at this point in the history
  • Loading branch information
pbarbarant committed Aug 5, 2024
1 parent 35c4525 commit d4b3e36
Showing 1 changed file with 14 additions and 16 deletions.
30 changes: 14 additions & 16 deletions src/fugw/scripts/coarse_to_fine.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def compute_sparsity_mask(
source_selection_radius=1,
target_selection_radius=1,
method="topk",
device="auto",
):
"""
Compute sparsity mask from coarse mapping.
Expand Down Expand Up @@ -301,12 +302,20 @@ def compute_sparsity_mask(
Method used to select pairs of source and target features
whose neighbourhoods will be used to define
the sparsity mask of the solution
device: "auto" or torch.device
if "auto": use first available gpu if it's available,
cpu otherwise.
Returns
-------
sparsity_mask: torch.sparse_coo_tensor of size (n, m)
Sparsity mask used to initialize the fine mapping.
"""
if device == "auto":
if torch.cuda.is_available():
device = torch.device("cuda", 0)
else:
device = torch.device("cpu")
if method == "quantile":
# Method 1: keep first percentile
threshold = np.percentile(coarse_mapping.pi, 99.95)
Expand Down Expand Up @@ -334,29 +343,17 @@ def compute_sparsity_mask(
# which vertex is close to which sampled point
N_source = get_neighbourhood_matrix(
source_geometry_embeddings, source_sample, source_selection_radius
)
).to(device)
N_target = get_neighbourhood_matrix(
target_geometry_embeddings, target_sample, target_selection_radius
)
).to(device)
# b. cluster matrices that encode
# which sampled point belongs to which cluster
C_source = get_cluster_matrix(rows, source_sample.shape[0])
C_target = get_cluster_matrix(cols, target_sample.shape[0])
C_source = get_cluster_matrix(rows, source_sample.shape[0]).to(device)
C_target = get_cluster_matrix(cols, target_sample.shape[0]).to(device)

sparsity_mask = (N_source @ C_source) @ (N_target @ C_target).T

# Check for any empty row/column in the sparsity mask
rows, cols = sparsity_mask.indices()
if (
len(np.unique(rows)) < sparsity_mask.shape[0]
or len(np.unique(cols)) < sparsity_mask.shape[1]
):
raise ValueError(
"Sparsity mask contains empty rows or columns. "
"Please consider increasing the selection radius "
"or increasing the number of samples."
)

return sparsity_mask


Expand Down Expand Up @@ -536,6 +533,7 @@ def fit(
source_selection_radius=source_selection_radius,
target_selection_radius=target_selection_radius,
method=coarse_pairs_selection_method,
device=device,
)

# Define init plan from sparsity mask
Expand Down

0 comments on commit d4b3e36

Please sign in to comment.