Skip to content

Commit

Permalink
feat: Add callback_barycenter parameter to FUGWSparseBarycenter class
Browse files Browse the repository at this point in the history
  • Loading branch information
pbarbarant committed Jul 15, 2024
1 parent ac7c395 commit f245fc5
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/fugw/mappings/sparse_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def fit(
mesh_sample=None,
nits_barycenter=5,
device="auto",
callback_barycenter=None,
verbose=False,
):
"""Compute barycentric features and geometry
Expand Down Expand Up @@ -192,6 +193,11 @@ def fit(
device: "auto" or torch.device
if "auto": use first available gpu if it's available,
cpu otherwise.
callback_barycenter: callable or None
Callback function called at the end of each barycenter step.
It will be called with the following arguments:
- locals (dictionary containing all local variables)
Returns
-------
Expand Down Expand Up @@ -274,6 +280,9 @@ def fit(
plans, weights_list, features_list, device
)

if callback_barycenter is not None:
callback_barycenter(locals())

return (
barycenter_weights,
barycenter_features,
Expand Down

0 comments on commit f245fc5

Please sign in to comment.