-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
123 lines (103 loc) · 3.74 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from collections import defaultdict
from heapq import heappush, heappop
from math import hypot, sqrt
def prim(graph):
    """
    Compute a minimum spanning forest of a graph using Prim's algorithm.

    The graph may contain several connected components; one tree is grown
    per component, each starting from the first not-yet-visited node in the
    graph's iteration order.

    Params:
      graph... a mapping from node to an iterable of (neighbor, weight) pairs.
               Prim's algorithm assumes an undirected graph, i.e. each edge is
               listed under both of its endpoints. A node that appears only as
               a neighbor (not as a key) is treated as having no edges of its own.
    Returns:
      a list of trees, one per connected component, in the order their
      starting nodes are encountered. Each tree is a set of
      (weight, node, parent) tuples. The starting node of each component is
      recorded as (0, root, root), so a component with k nodes yields a set
      of k tuples.

    Prints ('visiting', node) for every node as it is added to a tree.
    """
    def grow_tree(visited, frontier, tree):
        # Iterative (not recursive) so that large graphs do not exceed
        # Python's recursion limit: the heap can hold O(E) entries.
        while frontier:
            weight, node, parent = heappop(frontier)
            if node in visited:
                # Stale heap entry: node is already in a tree.
                continue
            print('visiting', node)
            # record this edge in the tree
            tree.add((weight, node, parent))
            visited.add(node)
            for neighbor, w in graph.get(node, ()):
                # Visited neighbors would be discarded on pop anyway.
                if neighbor not in visited:
                    heappush(frontier, (w, neighbor, node))
        return tree

    trees = []
    visited = set()
    # Each node not reached by an earlier tree starts a new component.
    for source in graph:
        if source in visited:
            continue
        frontier = [(0, source, source)]
        trees.append(grow_tree(visited, frontier, set()))
    return trees
def test_prim():
    """Run prim on a graph with two components: {s, a, b, c, d} and {e, f}."""
    graph = {
        's': {('a', 4), ('b', 8)},
        'a': {('s', 4), ('b', 2), ('c', 5)},
        'b': {('s', 8), ('a', 2), ('c', 3)},
        'c': {('a', 5), ('b', 3), ('d', 3)},
        'd': {('c', 3)},
        'e': {('f', 10)},  # e and f are in a separate component.
        'f': {('e', 10)},
    }
    forest = prim(graph)
    assert len(forest) == 2
    # Each tree is a set, so the edge order is not fixed; compare the sorted
    # sizes and total weights of the two trees instead of exact edges.
    sizes = sorted(len(tree) for tree in forest)
    weights = sorted(sum(weight for weight, _, _ in tree) for tree in forest)
    assert sizes == [2, 5]
    assert weights == [10, 12]
###
def mst_from_points(points):
    """
    Return the minimum spanning tree for a list of points, using euclidean
    distance as the edge weight between each pair of points.
    See test_mst_from_points.

    Params:
      points... a list of tuples (city_name, x-coord, y-coord)
    Returns:
      a set of edges of the form (weight, node, parent) forming the minimum
      spanning tree that connects the cities in the input. As produced by
      prim, the tree also contains a (0, root, root) entry for its starting
      city. An empty input yields an empty set.
    """
    # Seed every city as a node up front so a single point still forms a
    # (trivial) graph instead of leaving the graph empty.
    graph = {point[0]: set() for point in points}
    if not graph:
        return set()
    # Complete graph: compute each pairwise distance once and add it in
    # both directions.
    for i, p in enumerate(points):
        for q in points[i + 1:]:
            weight = euclidean_distance(p, q)
            graph[p[0]].add((q[0], weight))
            graph[q[0]].add((p[0], weight))
    return prim(graph)[0]
def euclidean_distance(p1, p2):
    """
    Return the straight-line distance between two points.

    Params:
      p1, p2... tuples (name, x-coord, y-coord); only the coordinates at
                index 1 and 2 are used.
    Returns:
      the euclidean distance between the two points, as a float.
    """
    # hypot avoids the overflow/underflow that explicitly squaring very large
    # or very small coordinate differences would cause.
    return hypot(p1[1] - p2[1], p1[2] - p2[2])
def test_euclidean_distance():
    """Points 2 apart on each axis are sqrt(8) ~= 2.83 apart."""
    distance = euclidean_distance(('a', 5, 10), ('b', 7, 12))
    assert round(distance, 2) == 2.83
def test_mst_from_points():
    """Check the total weight of the MST over six cities."""
    # Each entry is (city_name, x-coord, y-coord).
    cities = [
        ('a', 5, 10),
        ('b', 7, 12),
        ('c', 2, 3),
        ('d', 12, 3),
        ('e', 4, 6),
        ('f', 6, 7),
    ]
    mst = mst_from_points(cities)
    total_weight = sum(weight for weight, _, _ in mst)
    # The MST's total length, compared to two decimal places.
    assert round(total_weight, 2) == 19.04
if __name__ == "__main__":
    # Run the self-checks only when executed as a script, so importing this
    # module (e.g. from a test runner) does not print or assert as a side effect.
    test_euclidean_distance()
    test_prim()
    test_mst_from_points()