Skip to content

Commit

Permalink
Add perm.optimal_perm()
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Jan 5, 2025
1 parent fa6794d commit 4b8b51e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
23 changes: 21 additions & 2 deletions phiml/math/perm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import numpy as np

from ..backend import default_backend
from ._shape import concat_shapes_, batch, DimFilter, Shape, SHAPE_TYPES, shape, non_batch, channel, dual
from ._shape import concat_shapes_, batch, DimFilter, Shape, SHAPE_TYPES, shape, non_batch, channel, dual, primal
from ._magic_ops import unpack_dim, expand, stack, slice_
from ._tensors import reshaped_tensor, TensorOrTree, Tensor, wrap
from ._ops import unravel_index
from ._ops import unravel_index, psum, dmin


def random_permutation(*shape: Union[Shape, Any], dims=non_batch, index_dim=channel('index')) -> Tensor:
Expand Down Expand Up @@ -107,3 +107,22 @@ def all_permutations(dims: Shape, list_dim=dual('perm'), index_dim: Optional[Sha
assert len(dims) == 1, f"For multi-dim permutations, index_dim must be specified."
return perms
return unravel_index(perms, dims, index_dim)


def optimal_perm(cost_matrix: Tensor):
"""
Given a pair-wise cost matrix of two equal-size vectors, finds the optimal permutation to apply to one vector (corresponding to the dual dim of `cost_matrix`) in order to obtain the minimum total cost for a bijective map.
Args:
cost_matrix: Pair-wise cost matrix. Must be a square matrix with a dual and primal dim.
Returns:
dual_perm: Permutation that, when applied along the vector corresponding to the dual dim in `cost_matrix`, yields the minimum cost.
cost: Optimal cost vector, listed along primal dim of `cost_matrix`.
"""
assert dual(cost_matrix) and primal(cost_matrix), f"cost_matrix must have primal and dual dims but got {cost_matrix}"
perms = all_permutations(primal(cost_matrix), index_dim=None)
perms = expand(perms, channel(index=dual(cost_matrix).name_list))
cost_pairs_by_perm = cost_matrix[perms]
perm, cost = dmin((perms, cost_pairs_by_perm), key=psum(cost_pairs_by_perm))
return perm, cost
11 changes: 9 additions & 2 deletions tests/commit/math/test_perm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from unittest import TestCase

from phiml import math
from phiml.math import spatial
from phiml.math.perm import all_permutations
from phiml.math import spatial, wrap, squeeze, channel
from phiml.math.perm import all_permutations, optimal_perm


class TestSparse(TestCase):
Expand All @@ -17,3 +17,10 @@ def test_all_permutations_2d(self):
self.assertEqual(24, p.permutations.size)
math.assert_close(1, p.max)
math.assert_close(0, p.min)

def test_optimal_perm(self):
o = wrap([0, 10, 100], 'x')
i = wrap([90, -2, 0], '~x')
i_perm, cost = optimal_perm((i - o) ** 2)
math.assert_close([1, 2, 0], squeeze(i_perm, channel))
math.assert_close(cost, (i[i_perm] - o) ** 2)

0 comments on commit 4b8b51e

Please sign in to comment.