Commit

Fix error in L2 calculation (#40)
* fix bug in L2 calculation

* edit the description
6Ulm authored Jan 24, 2024
1 parent 63e7007 commit 8ddd7ea
Showing 2 changed files with 14 additions and 14 deletions.
8 changes: 4 additions & 4 deletions doc/pages/introduction.md

```diff
@@ -45,16 +45,16 @@ In their paper, authors relax this constraint and only impose that
 the mass of each plan is equal (ie `$\text{m}(P) = \text{m}(Q)$`) and the problem is now
 convex in `$P$` and in `$Q$`. Finally, they derive a block-coordinate-descent
 (BCD) algorithm in which they alternatively freeze the value of `$P$` (resp. `$Q$`)
-while running a convex-problem solver (in their case it's a sinkhorn algorithm)
-to optimize `$Q$` (resp. `$P$`).
+while running a solver for the regularized unbalanced optimal transport problem
+(in their case it's a Sinkhorn algorithm) to optimize `$Q$` (resp. `$P$`).
 
 In this work, we adapt the previous approach to approximate solutions
 to FUGW losses. Moreover, we provide multiple solvers to run inside the BCD algorithm.
 Namely, we provide:
 
 - `sinkhorn`: the classical Sinkhorn procedure described in [(Chizat et al. 2017) [7]](#7)
-- `mm`: a majorize-minimization algorithm described in [(Chapel et al. 2021) [4]](#4)
-- `ibpp`: an inexact-bregman-proximal-point algorithm described in [(Xie et al. 2020) [5]](#5)
+- `mm`: a majorization-minimization algorithm described in [(Chapel et al. 2021) [4]](#4)
+- `ibpp`: an extension to the unbalanced setting of the inexact-bregman-proximal-point algorithm described in [(Xie et al. 2020) [5]](#5)
 
 ## References
```
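The freeze-and-solve structure of the BCD scheme discussed in this commit can be sketched on a toy problem (a hypothetical numpy illustration, not part of the fugw package; the real sub-solver is Sinkhorn, replaced here by a closed-form update so the sketch stays self-contained):

```python
import numpy as np

# Toy block-coordinate descent (BCD), illustrative only.
# Minimize f(P, Q) = ||P - A||^2 + ||Q - B||^2 + lam * ||P - Q||^2
# by alternately freezing Q (resp. P) and solving the resulting
# convex sub-problem in the other variable in closed form.
rng = np.random.default_rng(0)
A, B = rng.random((3, 3)), rng.random((3, 3))
lam = 0.5

P, Q = np.zeros_like(A), np.zeros_like(B)
for _ in range(100):
    # Q frozen: argmin_P ||P - A||^2 + lam * ||P - Q||^2
    P = (A + lam * Q) / (1 + lam)
    # P frozen: argmin_Q ||Q - B||^2 + lam * ||P - Q||^2
    Q = (B + lam * P) / (1 + lam)

# At convergence, both stationarity conditions hold simultaneously.
assert np.allclose(P, (A + lam * Q) / (1 + lam))
```

Each half-step is a contraction here, so the alternation converges; in the fugw setting the closed-form update is replaced by one of the sub-solvers listed above.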
20 changes: 10 additions & 10 deletions src/fugw/solvers/utils.py

```diff
@@ -1147,7 +1147,7 @@ def compute_l2(p, q):
     -------
     l2: float
     """
-    return torch.sum((p - q) ** 2)
+    return torch.sum((p - q) ** 2) / 2
 
 
 def compute_l2_sparse(p, q):
@@ -1163,7 +1163,7 @@ def compute_l2_sparse(p, q):
     -------
     l2: float
     """
-    return torch.sum((p.values() - q.values()) ** 2)
+    return torch.sum((p.values() - q.values()) ** 2) / 2
 
 
 def compute_divergence(p, q, divergence="kl"):
@@ -1177,7 +1177,7 @@ def compute_divergence(p, q, divergence="kl"):
     divergence: str
         Either "kl" or "l2".
         If "kl", compute KL(p, q).
-        If "l2", compute || p - q ||^2.
+        If "l2", compute || p - q ||^2 / 2.
         Default: "kl"
 
     Returns
@@ -1201,7 +1201,7 @@ def compute_divergence_sparse(p, q, divergence="kl"):
     divergence: str
         Either "kl" or "l2".
        If "kl", compute KL(p, q).
-        If "l2", compute || p - q ||^2.
+        If "l2", compute || p - q ||^2 / 2.
         Default: "kl"
 
     Returns
@@ -1263,19 +1263,19 @@ def compute_quad_kl_sparse(mu, nu, alpha, beta):
 
 
 def compute_quad_l2(a, b, mu, nu):
-    """Compute || a otimes b - mu otimes nu ||^2."""
+    """Compute || a otimes b - mu otimes nu ||^2 / 2."""
 
     norm = (
         (a**2).sum() * (b**2).sum()
         - 2 * (a * mu).sum() * (b * nu).sum()
         + (mu**2).sum() * (nu**2).sum()
     )
 
-    return norm
+    return norm / 2
 
 
 def compute_quad_l2_sparse(a, b, mu, nu):
-    """Compute || a otimes b - mu otimes nu ||^2.
+    """Compute || a otimes b - mu otimes nu ||^2 / 2.
     Because a otimes b is costly to store in memory,
     we expand the norm so that we only have to deal with scalars.
@@ -1302,7 +1302,7 @@ def compute_quad_l2_sparse(a, b, mu, nu):
         + (mu.values() ** 2).sum() * (nu.values() ** 2).sum()
     )
 
-    return norm
+    return norm / 2
 
 
 def compute_quad_divergence(mu, nu, alpha, beta, divergence="kl"):
@@ -1319,7 +1319,7 @@ def compute_quad_divergence(mu, nu, alpha, beta, divergence="kl"):
     divergence: str
         Either "kl" or "l2".
        If "kl", compute KL(mu otimes nu, alpha otimes beta).
-        If "l2", compute || mu otimes nu - alpha otimes beta ||^2.
+        If "l2", compute || mu otimes nu - alpha otimes beta ||^2 / 2.
         Default: "kl"
 
     Returns
@@ -1346,7 +1346,7 @@ def compute_quad_divergence_sparse(mu, nu, alpha, beta, divergence="kl"):
     divergence: str
        Either "kl" or "l2".
        If "kl", compute KL(mu otimes nu, alpha otimes beta).
-        If "l2", compute || mu otimes nu - alpha otimes beta ||^2.
+        If "l2", compute || mu otimes nu - alpha otimes beta ||^2 / 2.
         Default: "kl"
 
     Returns
```
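A quick sanity check of the expansion used in `compute_quad_l2` (an illustrative numpy sketch; the package itself uses torch): the scalar formula agrees with the direct evaluation of `|| a otimes b - mu otimes nu ||^2 / 2` without ever materializing the outer products, because `<a otimes b, mu otimes nu> = (sum a*mu) * (sum b*nu)`.

```python
import numpy as np

rng = np.random.default_rng(0)
a, b, mu, nu = (rng.random(5) for _ in range(4))

# Direct (memory-hungry) evaluation: build both outer products explicitly.
direct = np.sum((np.outer(a, b) - np.outer(mu, nu)) ** 2) / 2

# Expanded evaluation, as in the diff: only scalar sums are needed.
expanded = (
    (a**2).sum() * (b**2).sum()
    - 2 * (a * mu).sum() * (b * nu).sum()
    + (mu**2).sum() * (nu**2).sum()
) / 2

assert np.isclose(direct, expanded)
```

The same identity is what lets the sparse variant avoid storing an `n x m` outer product when `a` and `b` are long marginals.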
