forked from facebookresearch/segment-anything
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmatching.py
56 lines (51 loc) · 1.8 KB
/
matching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""
Used for visualizing the optimal matching between two trees.
"""
import networkx as nx
from uuid import uuid4
from networkx.algorithms import bipartite
from more_itertools import unzip
def optimalBipartiteMatching (costTable) :
    """
    Return the minimum cost bipartite matching.

    Parameters
    ----------
    costTable : dict()
        For each pair ``(a, b)``, what is the cost of matching
        ``a`` (an element of the left set) with ``b`` (an element
        of the right set). Pairs absent from the table cannot be
        matched together.

    Returns
    -------
    dict
        Maps each matched left element ``a`` to its partner ``b``.
        Left elements that end up unmatched (e.g. when there are more
        left elements than right elements) are not present. An empty
        ``costTable`` yields an empty dict.

    Raises
    ------
    ValueError
        Propagated from networkx when no matching covering the smaller
        side exists using only the pairs in ``costTable``.
    """
    if not costTable:
        # Nothing to match (the original unpacking of the key columns
        # would fail on an empty table).
        return dict()
    # Tag every element with the side it comes from. An element that
    # appears on both sides, or two elements with the same str(), then
    # cannot collide as graph nodes, and the labels are deterministic.
    aMap = {i: ("A", i) for i, _ in costTable}
    bMap = {j: ("B", j) for _, j in costTable}
    bInvMap = {node: j for j, node in bMap.items()}
    # Create bipartite graph with given edge costs.
    G = nx.Graph()
    G.add_nodes_from(aMap.values(), bipartite=0)
    G.add_nodes_from(bMap.values(), bipartite=1)
    for (i, j), cost in costTable.items():
        G.add_edge(aMap[i], bMap[j], weight=cost)
    # Pass the left side explicitly. Without top_nodes, networkx must
    # infer the two sides itself and raises AmbiguousSolution whenever
    # the graph is disconnected (e.g. a sparse cost table).
    matchingWithDups = bipartite.minimum_weight_full_matching(
        G, top_nodes=list(aMap.values()), weight="weight")
    # The result maps left->right AND right->left; keep only the
    # left->right direction, translated back to the caller's elements.
    return {i: bInvMap[matchingWithDups[node]]
            for i, node in aMap.items() if node in matchingWithDups}
def bestAssignmentCost (costTable) :
    """
    Compute the minimum total cost assignment.

    Parameters
    ----------
    costTable : dict()
        For each pair, what is the cost of matching that pair
        together.

    Returns
    -------
    The sum of ``costTable[(a, b)]`` over every pair ``(a, b)`` chosen
    by ``optimalBipartiteMatching(costTable)``.
    """
    # Each (left, right) item of the matching is itself a key of costTable.
    chosenPairs = optimalBipartiteMatching(costTable).items()
    return sum(costTable[pair] for pair in chosenPairs)