From 8ddd7ea85c42b0f80b9ede396ef911ef70b78199 Mon Sep 17 00:00:00 2001
From: Huy Tran
Date: Wed, 24 Jan 2024 15:25:48 +0100
Subject: [PATCH] Fix error in L2 calculation (#40)

* fix bug in L2 calculation

* edit the description
---
 doc/pages/introduction.md |  8 ++++----
 src/fugw/solvers/utils.py | 20 ++++++++++----------
 2 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/doc/pages/introduction.md b/doc/pages/introduction.md
index 19dd663e..8756b992 100644
--- a/doc/pages/introduction.md
+++ b/doc/pages/introduction.md
@@ -45,16 +45,16 @@
 In their paper, authors relax this constraint and only impose that the mass of
 each plan is equal (ie `$\text{m}(P) = \text{m}(Q)$`)
 and the problem is now convex in `$P$` and in `$Q$`.
 Finally, they derive a block-coordinate-descent (BCD) algorithm
 in which they alternatively freeze the value of `$P$` (resp. `$Q$`)
-while running a convex-problem solver (in their case it's a sinkhorn algorithm)
-to optimize `$Q$` (resp. `$P$`).
+while running a solver for the regularized unbalanced optimal transport problem
+(in their case it's a Sinkhorn algorithm) to optimize `$Q$` (resp. `$P$`).
 
 In this work, we adapt the previous approach to approximate solutions to FUGW
 losses. Moreover, we provide multiple solvers to run inside the BCD algorithm.
 Namely, we provide:
 
 - `sinkhorn`: the classical Sinkhorn procedure described in [(Chizat et al. 2017) [7]](#7)
-- `mm`: a majorize-minimization algorithm described in [(Chapel et al. 2021) [4]](#4)
-- `ibpp`: an inexact-bregman-proximal-point algorithm described in [(Xie et al. 2020) [5]](#5)
+- `mm`: a majorization-minimization algorithm described in [(Chapel et al. 2021) [4]](#4)
+- `ibpp`: an extension to the unbalanced setting of the inexact-bregman-proximal-point algorithm described in [(Xie et al. 2020) [5]](#5)
 
 ## References
 
diff --git a/src/fugw/solvers/utils.py b/src/fugw/solvers/utils.py
index f7e8e7bf..1687d640 100644
--- a/src/fugw/solvers/utils.py
+++ b/src/fugw/solvers/utils.py
@@ -1147,7 +1147,7 @@ def compute_l2(p, q):
     -------
     l2: float
     """
-    return torch.sum((p - q) ** 2)
+    return torch.sum((p - q) ** 2) / 2
 
 
 def compute_l2_sparse(p, q):
@@ -1163,7 +1163,7 @@ def compute_l2_sparse(p, q):
     -------
     l2: float
     """
-    return torch.sum((p.values() - q.values()) ** 2)
+    return torch.sum((p.values() - q.values()) ** 2) / 2
 
 
 def compute_divergence(p, q, divergence="kl"):
@@ -1177,7 +1177,7 @@ def compute_divergence(p, q, divergence="kl"):
     divergence: str
         Either "kl" or "l2".
         If "kl", compute KL(p, q).
-        If "l2", compute || p - q ||^2.
+        If "l2", compute || p - q ||^2 / 2.
         Default: "kl"
 
     Returns
@@ -1201,7 +1201,7 @@ def compute_divergence_sparse(p, q, divergence="kl"):
     divergence: str
         Either "kl" or "l2".
         If "kl", compute KL(p, q).
-        If "l2", compute || p - q ||^2.
+        If "l2", compute || p - q ||^2 / 2.
         Default: "kl"
 
     Returns
@@ -1263,7 +1263,7 @@ def compute_quad_kl_sparse(mu, nu, alpha, beta):
 
 
 def compute_quad_l2(a, b, mu, nu):
-    """Compute || a otimes b - mu otimes nu ||^2."""
+    """Compute || a otimes b - mu otimes nu ||^2 / 2."""
     norm = (
         (a**2).sum() * (b**2).sum()
@@ -1271,11 +1271,11 @@ def compute_quad_l2(a, b, mu, nu):
         + (mu**2).sum() * (nu**2).sum()
     )
 
-    return norm
+    return norm / 2
 
 
 def compute_quad_l2_sparse(a, b, mu, nu):
-    """Compute || a otimes b - mu otimes nu ||^2.
+    """Compute || a otimes b - mu otimes nu ||^2 / 2.
     Because a otimes b is constly to store in memory,
     we expand the norm so that we only have to deal
     with scalars.
@@ -1302,7 +1302,7 @@ def compute_quad_l2_sparse(a, b, mu, nu):
         + (mu.values() ** 2).sum() * (nu.values() ** 2).sum()
     )
 
-    return norm
+    return norm / 2
 
 
 def compute_quad_divergence(mu, nu, alpha, beta, divergence="kl"):
@@ -1319,7 +1319,7 @@ def compute_quad_divergence(mu, nu, alpha, beta, divergence="kl"):
     divergence: str
         Either "kl" or "l2".
         If "kl", compute KL(mu otimes nu, alpha otimes beta).
-        If "l2", compute || mu otimes nu - alpha otimes beta ||^2.
+        If "l2", compute || mu otimes nu - alpha otimes beta ||^2 / 2.
         Default: "kl"
 
     Returns
@@ -1346,7 +1346,7 @@ def compute_quad_divergence_sparse(mu, nu, alpha, beta, divergence="kl"):
     divergence: str
         Either "kl" or "l2".
         If "kl", compute KL(mu otimes nu, alpha otimes beta).
-        If "l2", compute || mu otimes nu - alpha otimes beta ||^2.
+        If "l2", compute || mu otimes nu - alpha otimes beta ||^2 / 2.
         Default: "kl"
 
     Returns
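As a quick sanity check of the new `/ 2` factor, here is a minimal standalone sketch (not part of the patch). It mirrors the patched dense formula from compute_l2 and the expanded quadratic formula used by compute_quad_l2, and verifies on random marginals that both code paths return the same halved squared norm. The helper names, tensor sizes, and the fully written-out cross term below are illustrative assumptions and are not taken from the diff.

import torch


def halved_l2(p, q):
    # Halved squared Frobenius distance, matching the patched compute_l2.
    return torch.sum((p - q) ** 2) / 2


def halved_quad_l2(a, b, mu, nu):
    # Algebraic expansion of || a otimes b - mu otimes nu ||^2 / 2,
    # which avoids materializing the two outer products.
    norm = (
        (a**2).sum() * (b**2).sum()
        - 2 * (a * mu).sum() * (b * nu).sum()
        + (mu**2).sum() * (nu**2).sum()
    )
    return norm / 2


# Illustrative check on random marginals.
torch.manual_seed(0)
a, mu = torch.rand(5), torch.rand(5)
b, nu = torch.rand(7), torch.rand(7)

dense = halved_l2(torch.outer(a, b), torch.outer(mu, nu))
expanded = halved_quad_l2(a, b, mu, nu)
assert torch.allclose(dense, expanded)

Since the patch halves both the dense and the expanded quadratic L2 terms, the two code paths stay consistent with each other; the check above only illustrates that consistency.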