-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathui.py
136 lines (106 loc) · 5.94 KB
/
ui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from tkinter import *
from tkinter import messagebox
import math
# Pixel geometry used by the grid drawing code.
CELL_SIZE = 80  # side length of one square grid cell
CELL_PADDING = 10  # inset of the inner rectangle drawn inside terminal cells
ARROW_LENGTH = 0.5 * CELL_SIZE  # distance from a cell centre to its arrow head
class GridWorldWindow:
    """Tkinter window that draws a gridworld with value and policy overlays.

    The window holds three rows of control buttons (exposed as ``btn_*``
    attributes; no commands are attached here) and a canvas with one square
    per grid cell.  Every cell owns a rectangle, a text item and an arrow
    line whose canvas ids are stored in ``ids_rect``, ``ids_text`` and
    ``ids_arrow``, all indexed as ``[row][col]``.
    """

    def __init__(self, metadata):
        """Create the window, its buttons and the initial (blank) grid.

        Args:
            metadata: Mapping read with the keys ``'width'`` and ``'height'``
                (grid size in cells), ``'obstacles'`` (iterable of cell
                coordinates) and ``'terminals'`` (iterable of mappings, each
                holding a ``'state'`` coordinate).  Coordinates are converted
                to tuples and later compared against ``(row, col)`` pairs.
        """
        self.window = Tk()
        self.window.title('Gridworld')
        self.window.geometry('{}x{}'.format(1080, 720))

        # Extract the grid description from the metadata.
        self.grid_width = metadata['width']
        self.grid_height = metadata['height']
        self.obstacles = [tuple(obstacle) for obstacle in metadata['obstacles']]
        self.terminals = [tuple(terminal['state']) for terminal in metadata['terminals']]

        self.canvas_width = self.grid_width * CELL_SIZE
        self.canvas_height = self.grid_height * CELL_SIZE

        # Placeholders for the tkinter canvas ids of every modifiable item;
        # the real ids are filled in by _create_grid().
        self.ids_text = self._blank_id_table()
        self.ids_rect = self._blank_id_table()
        self.ids_arrow = self._blank_id_table()

        # Buttons are packed before the canvas so they appear above it.
        self._create_buttons()

        self.canvas = Canvas(self.window, width=self.canvas_width, height=self.canvas_height, bg='black')
        self.canvas.pack(padx=10, pady=10)

        self._create_grid()

    def _blank_id_table(self):
        """Return a fresh ``grid_height`` x ``grid_width`` table of zeros."""
        return [[0] * self.grid_width for _ in range(self.grid_height)]

    def _packed_frame(self):
        """Create a frame inside the window and pack it with 5px padding."""
        frame = Frame(self.window)
        frame.pack(padx=5, pady=5)
        return frame

    @staticmethod
    def _packed_button(parent, label, anchor):
        """Create a button in *parent* and pack it against the left side."""
        button = Button(parent, text=label, anchor=anchor)
        button.pack(side=LEFT)
        return button

    def _create_buttons(self):
        """Create the value-iteration, policy-iteration and reset button rows."""
        # All three frames are packed first so the rows keep this order.
        self.frame_value_buttons = self._packed_frame()
        self.frame_policy_buttons = self._packed_frame()
        self.frame_reset_buttons = self._packed_frame()

        value_row = self.frame_value_buttons
        self.btn_value_iteration_1_step = self._packed_button(value_row, '1-Step Value Iteration', W)
        self.btn_value_iteration_100_steps = self._packed_button(value_row, '100-Step Value Iteration', E)
        self.btn_value_iteration_slow = self._packed_button(value_row, 'Slow Value Iteration', E)

        policy_row = self.frame_policy_buttons
        self.btn_policy_iteration_1_step = self._packed_button(policy_row, '1-Step Policy Iteration', E)
        self.btn_policy_iteration_100_steps = self._packed_button(policy_row, '100-Step Policy Iteration', E)
        self.btn_policy_iteration_slow = self._packed_button(policy_row, 'Slow Policy Iteration', E)

        self.btn_reset = self._packed_button(self.frame_reset_buttons, 'Reset', E)

    def _create_grid(self):
        """Draw every cell and record the canvas ids of its items."""
        for row in range(self.grid_height):
            top = row * CELL_SIZE
            bottom = top + CELL_SIZE
            for col in range(self.grid_width):
                left = col * CELL_SIZE
                right = left + CELL_SIZE

                # Obstacles are grey and carry no value text.  The empty
                # string is Tk's own default for both options (no fill, no
                # text); it is passed explicitly because tkinter silently
                # drops options whose value is None.
                is_obstacle = (row, col) in self.obstacles
                fill = 'grey' if is_obstacle else ''
                text = '' if is_obstacle else '0.00'

                self.ids_rect[row][col] = self.canvas.create_rectangle(
                    left, top, right, bottom, fill=fill, outline='white')

                if (row, col) in self.terminals:
                    # Terminal cells get an additional inset outline.
                    self.canvas.create_rectangle(
                        left + CELL_PADDING, top + CELL_PADDING,
                        right - CELL_PADDING, bottom - CELL_PADDING,
                        fill=fill, outline='white')

                self.ids_text[row][col] = self.canvas.create_text(
                    left + CELL_SIZE/2, top + CELL_SIZE/2, text=text, fill='white')
                # The arrow starts collapsed at the origin; update_grid()
                # moves it into place.
                self.ids_arrow[row][col] = self.canvas.create_line(
                    0, 0, 0, 0, width=2, arrow=LAST, fill='white')

    def _compute_color(self, value):
        """Map a value to a ``'#rrggbb'`` colour string.

        Negative values are redder while positive values are greener; the
        intensity is ``floor(|value| * 256)`` and saturates at 255 once
        ``|value| >= 1``.  Zero maps to black, and so does NaN (which
        compares false against everything), so a colour string is always
        returned.
        """
        if value > 0:
            template = '#00{:02x}00'
        elif value < 0:
            template = '#{:02x}0000'
        else:
            return '#000000'

        magnitude = abs(value)
        # Clamp before flooring so infinite values never reach math.floor().
        level = 255 if magnitude >= 1.0 else math.floor(magnitude * 256)
        return template.format(level)

    def show_dialog(self, text):
        """Show *text* in an information message box titled 'Info'."""
        messagebox.showinfo('Info', text)

    def update_grid(self, values, policy):
        """Redraw the colour, value text and policy arrow of each state.

        Args:
            values: Mapping from state to numeric value.  A state is indexed
                as ``state[0]`` (row) and ``state[1]`` (column).
            policy: Object indexed by state that yields the chosen action as
                a ``(row_delta, col_delta)`` pair.  It is only read for
                states that are not in ``self.terminals``.
        """
        for state, value in values.items():
            row, col = state[0], state[1]

            self.canvas.itemconfig(self.ids_rect[row][col], fill=self._compute_color(value))
            self.canvas.itemconfig(self.ids_text[row][col], text='{:.2f}'.format(value))

            if state in self.terminals:
                continue

            action = policy[state]
            d_row, d_col = action[0], action[1]
            # The arrow head sits ARROW_LENGTH from the cell centre in the
            # action's direction; the tail is one unit of the action behind
            # it, so effectively only the arrow head is visible.
            head_x = col * CELL_SIZE + CELL_SIZE/2 + d_col * ARROW_LENGTH
            head_y = row * CELL_SIZE + CELL_SIZE/2 + d_row * ARROW_LENGTH
            self.canvas.coords(self.ids_arrow[row][col],
                               head_x - d_col, head_y - d_row, head_x, head_y)

    def clear(self):
        """Reset every cell to its initial colour and text and hide its arrow."""
        blank_fill = self._compute_color(0)
        for row in range(self.grid_height):
            for col in range(self.grid_width):
                if (row, col) in self.obstacles:
                    # '' rather than None: tkinter ignores None-valued
                    # options, which would leave any existing text in place.
                    fill, text = 'grey', ''
                else:
                    fill, text = blank_fill, '0.00'

                self.canvas.itemconfig(self.ids_rect[row][col], fill=fill)
                self.canvas.itemconfig(self.ids_text[row][col], text=text)
                # Collapse the arrow back to a zero-length line at the origin.
                self.canvas.coords(self.ids_arrow[row][col], 0, 0, 0, 0)

    def run(self):
        """Run the tkinter event loop of this window."""
        # Use the window's own mainloop instead of the module-level one,
        # which depends on tkinter's implicit default root.
        self.window.mainloop()