From 4b8b51e0bc08e09394fc53199723ec6dcb04c53a Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Sun, 5 Jan 2025 21:28:28 +0100 Subject: [PATCH] Add perm.optimal_perm() --- phiml/math/perm.py | 23 +++++++++++++++++++++-- tests/commit/math/test_perm.py | 11 +++++++++-- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/phiml/math/perm.py b/phiml/math/perm.py index df08111..1b74631 100644 --- a/phiml/math/perm.py +++ b/phiml/math/perm.py @@ -4,10 +4,10 @@ import numpy as np from ..backend import default_backend -from ._shape import concat_shapes_, batch, DimFilter, Shape, SHAPE_TYPES, shape, non_batch, channel, dual +from ._shape import concat_shapes_, batch, DimFilter, Shape, SHAPE_TYPES, shape, non_batch, channel, dual, primal from ._magic_ops import unpack_dim, expand, stack, slice_ from ._tensors import reshaped_tensor, TensorOrTree, Tensor, wrap -from ._ops import unravel_index +from ._ops import unravel_index, psum, dmin def random_permutation(*shape: Union[Shape, Any], dims=non_batch, index_dim=channel('index')) -> Tensor: @@ -107,3 +107,22 @@ def all_permutations(dims: Shape, list_dim=dual('perm'), index_dim: Optional[Sha assert len(dims) == 1, f"For multi-dim permutations, index_dim must be specified." return perms return unravel_index(perms, dims, index_dim) + + +def optimal_perm(cost_matrix: Tensor): + """ + Given a pair-wise cost matrix of two equal-size vectors, finds the optimal permutation to apply to one vector (corresponding to the dual dim of `cost_matrix`) in order to obtain the minimum total cost for a bijective map. + + Args: + cost_matrix: Pair-wise cost matrix. Must be a square matrix with a dual and primal dim. + + Returns: + dual_perm: Permutation that, when applied along the vector corresponding to the dual dim in `cost_matrix`, yields the minimum cost. + cost: Optimal cost vector, listed along primal dim of `cost_matrix`. + """ + assert dual(cost_matrix) and primal(cost_matrix), f"cost_matrix must have primal and dual dims but got {cost_matrix}" + perms = all_permutations(primal(cost_matrix), index_dim=None) + perms = expand(perms, channel(index=dual(cost_matrix).name_list)) + cost_pairs_by_perm = cost_matrix[perms] + perm, cost = dmin((perms, cost_pairs_by_perm), key=psum(cost_pairs_by_perm)) + return perm, cost diff --git a/tests/commit/math/test_perm.py b/tests/commit/math/test_perm.py index 9aac854..389e8c5 100644 --- a/tests/commit/math/test_perm.py +++ b/tests/commit/math/test_perm.py @@ -1,8 +1,8 @@ from unittest import TestCase from phiml import math -from phiml.math import spatial -from phiml.math.perm import all_permutations +from phiml.math import spatial, wrap, squeeze, channel +from phiml.math.perm import all_permutations, optimal_perm class TestSparse(TestCase): @@ -17,3 +17,10 @@ def test_all_permutations_2d(self): self.assertEqual(24, p.permutations.size) math.assert_close(1, p.max) math.assert_close(0, p.min) + + def test_optimal_perm(self): + o = wrap([0, 10, 100], 'x') + i = wrap([90, -2, 0], '~x') + i_perm, cost = optimal_perm((i - o) ** 2) + math.assert_close([1, 2, 0], squeeze(i_perm, channel)) + math.assert_close(cost, (i[i_perm] - o) ** 2)