-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
145 lines (104 loc) · 4.4 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import numpy as np
from math import prod
import random
from copy import deepcopy
def init_qtable(env, qtable_save='qlast.npy'):
    """Create an all-zero Q-table for ``env`` and write it to disk.

    The table has one row per grid cell (``prod(env.shape)`` rows) and four
    columns, one per action; ``do_step`` in this module defines four moves.

    Args:
        env: Object exposing a ``shape`` attribute (e.g. a NumPy array).
        qtable_save: Destination passed to ``np.save``. Defaults to
            ``'qlast.npy'``, the same default ``do_step`` writes to.

    Returns:
        None. The table is only persisted, not returned.
    """
    n_states = prod(env.shape)
    n_actions = 4
    qtable = np.zeros((n_states, n_actions))
    np.save(qtable_save, qtable)
    return None
def init_image(env, stop, danger, res=50):
    """Render the first two axes of the grid as a ``uint8`` image.

    Every cell becomes a ``res`` x ``res`` pixel square on a white canvas.
    A black line is drawn on the first pixel row/column of each cell, then
    ``stop`` cells are filled with ``(0, 100, 0)`` and ``danger`` cells with
    ``(0, 0, 100)``. Fills are applied after the grid lines and therefore
    cover the lines inside those cells.

    Args:
        env: Object exposing ``shape``; only ``shape[0]`` and ``shape[1]``
            are used.
        stop: Sequence of cell coordinates; only the first two components of
            each entry are used. May be a list or a NumPy array.
        danger: Same layout as ``stop``.
        res: Edge length of one cell in pixels.

    Returns:
        ``np.ndarray`` of shape ``(shape[0] * res, shape[1] * res, 3)`` and
        dtype ``uint8``.
    """
    n_rows, n_cols = env.shape[0], env.shape[1]
    img = np.full((n_rows * res, n_cols * res, 3), 255, dtype=np.uint8)

    # Grid lines: first pixel row / column of every cell.
    img[0:n_rows * res:res, :] = (0, 0, 0)
    img[:, 0:n_cols * res:res] = (0, 0, 0)

    # Coerce to arrays before scaling: ``list * res`` would repeat the list
    # ``res`` times instead of multiplying the coordinates.
    stop_px = np.asarray(stop, dtype=int) * res
    danger_px = np.asarray(danger, dtype=int) * res

    for cell in stop_px:
        top, left = cell[0], cell[1]
        img[top:top + res, left:left + res] = (0, 100, 0)
    for cell in danger_px:
        top, left = cell[0], cell[1]
        img[top:top + res, left:left + res] = (0, 0, 100)
    return img
def init_environment(env_shape=(15,15,15), danger_ratio=0.2, alpha=-1, stop_len=1):
    """Generate a random grid environment and save it to ``env.npz``.

    ``stop_len`` distinct cells are marked with ``999`` and, of the remaining
    cells, ``int(remaining * danger_ratio)`` distinct cells are marked with
    ``alpha``. All other cells are ``0``.

    The archive contains:
        * ``env``: integer array of shape ``env_shape``.
        * ``stop``: integer array ``(stop_len, len(env_shape))`` with the
          coordinates of the ``999`` cells.
        * ``danger``: integer array ``(danger_len, len(env_shape))`` with the
          coordinates of the ``alpha`` cells.
        * ``indices``: Python ``set`` of the flat indices of the untouched
          cells. NumPy stores it as a pickled object array, so it must be
          read with ``np.load(..., allow_pickle=True)``.

    Args:
        env_shape: Grid shape; any number of dimensions.
        danger_ratio: Fraction of the non-stop cells to mark as danger.
            Negative values yield no danger cells.
        alpha: Value written into danger cells (cast to ``int`` in ``env``).
        stop_len: Number of stop cells. Negative values yield none.

    Returns:
        None. The environment is only written to disk.

    Raises:
        ValueError: If more stop/danger cells are requested than exist.
    """
    shape = tuple(env_shape)
    n_cells = prod(shape)

    n_stop = max(stop_len, 0)
    if n_stop > n_cells:
        raise ValueError(
            f"stop_len={stop_len} exceeds the number of cells ({n_cells})"
        )
    n_free = n_cells - n_stop
    danger_len = max(int(n_free * danger_ratio), 0)
    if danger_len > n_free:
        raise ValueError(
            f"danger_ratio={danger_ratio} requests {danger_len} danger cells "
            f"but only {n_free} cells are available"
        )

    # One draw without replacement; any prefix of a random sample is itself
    # a random sample, so the split below keeps both groups uniform and
    # disjoint without rebuilding a candidate list for every pick.
    picks = random.sample(range(n_cells), n_stop + danger_len)
    stop = np.array(picks[:n_stop], dtype=int)
    danger = np.array(picks[n_stop:], dtype=int)

    env = np.zeros(n_cells, dtype=int)
    env[stop] = 999
    env[danger] = alpha
    env_reshaped = env.reshape(shape)

    # Flat index -> per-axis coordinates (row-major, matching ``reshape``).
    stop_reshaped = np.stack(np.unravel_index(stop, shape), axis=1)
    danger_reshaped = np.stack(np.unravel_index(danger, shape), axis=1)

    indices = set(range(n_cells)) - set(picks)
    np.savez('env.npz', env=env_reshaped, stop=stop_reshaped, danger=danger_reshaped, indices=indices)
    return None
def reward(s, env):
    """Return the reward for standing on cell ``s`` of ``env``.

    Args:
        s: Indexable with at least three components (axis 0, 1, 2).
        env: 3-D array of cell markers.

    Returns:
        ``1`` if the cell holds ``999``, ``-1`` if it holds ``-1``,
        otherwise ``0``.

    Note:
        The danger marker is compared against the literal ``-1``; cells
        written with a different ``alpha`` by ``init_environment`` yield ``0``.
    """
    cell = env[s[0], s[1], s[2]]
    if cell == -1:
        return -1
    if cell == 999:
        return 1
    return 0
def index2lin(s, env):
    """Convert a 3-D cell coordinate into its row-major flat index.

    Args:
        s: Indexable with three components (axis 0, 1, 2).
        env: Object exposing a 3-element ``shape``.

    Returns:
        ``s[0] * shape[1] * shape[2] + s[1] * shape[2] + s[2]``.
    """
    # The stride of axis 0 is the size of one (axis-1, axis-2) plane. Using
    # shape[0] here would only be correct when shape[0] == shape[1].
    plane = env.shape[1] * env.shape[2]
    index_lin = s[0] * plane + s[1] * env.shape[2] + s[2]
    return index_lin
def _move_allowed(origin, target, shape):
    """Return True if ``target`` is inside ``shape`` and differs from ``origin``."""
    for dim in (0, 1, 2):
        if target[dim] < 0 or target[dim] > shape[dim] - 1:
            return False
    unchanged = (
        target[0] == origin[0]
        and target[1] == origin[1]
        and target[2] == origin[2]
    )
    return not unchanged


def _greedy_action(state, env, qtable, steps):
    """Return the highest-valued action of ``state`` whose move is allowed.

    Falls back to the lowest-ranked action when no move is allowed.
    """
    ranked = np.argsort(-qtable[index2lin(state, env)])
    for action in ranked:
        if _move_allowed(state, state + steps[action], env.shape):
            return action
    return ranked[-1]


def do_step(s, env, qtable, random_step_prob=1, lr=0.1, gamma=0.1, qtable_save='qlast.npy'):
    """Take one step from ``s`` and update the Q-table in place.

    A number is drawn uniformly from [0, 1]. If it is strictly greater than
    ``random_step_prob`` the allowed action with the highest Q-value is
    taken; otherwise actions are drawn uniformly at random until an allowed
    one is found. With the default ``random_step_prob=1`` the step is
    therefore always random.

    The four actions move -1/+1 along axis 0 and -1/+1 along axis 1; axis 2
    is never changed. A move is allowed when it stays inside ``env.shape``.

    The update applied is
    ``Q[s, a] += lr * (reward(s_new) + gamma * Q[s_new, a_best] - Q[s, a])``
    where ``a_best`` is the best allowed action from ``s_new``.

    Args:
        s: Current 3-component position. Lists and tuples are converted to
            an array.
        env: 3-D environment array.
        qtable: 2-D array with one row per flat cell index and one column
            per action; modified in place.
        random_step_prob: Threshold below (or at) which a random action is
            taken.
        lr: Learning rate.
        gamma: Discount factor.
        qtable_save: Path the table is saved to whenever the updated entry
            changed.

    Returns:
        Tuple ``(s_new, qtable, reward_value, qtable_updated)``.

    Raises:
        ValueError: If none of the four moves is allowed from ``s``.
    """
    steps = [[-1, 0, 0], [1, 0, 0], [0, -1, 0], [0, 1, 0]]
    # ``list + list`` would concatenate instead of adding the offset.
    s = np.asarray(s)

    # Without this guard the random branch would loop forever and the greedy
    # branch would silently step off the grid.
    if not any(_move_allowed(s, s + step, env.shape) for step in steps):
        raise ValueError(f"no allowed move from state {s!r} in grid of shape {env.shape}")

    eps = random.uniform(0, 1)
    if eps > random_step_prob:
        # Exploit: best allowed action according to the Q-table.
        step_index = _greedy_action(s, env, qtable, steps)
    else:
        # Explore: rejection-sample a uniformly random allowed action.
        candidates = list(range(len(steps)))
        step_index = random.choice(candidates)
        while not _move_allowed(s, s + steps[step_index], env.shape):
            step_index = random.choice(candidates)
    s_new = s + steps[step_index]

    # Best allowed follow-up action from the new state, used for bootstrapping.
    step_index_new = _greedy_action(s_new, env, qtable, steps)

    row = index2lin(s, env)
    old = qtable[row, step_index]  # scalar read yields a copy, not a view
    reward_value = reward(s_new, env)
    future = qtable[index2lin(s_new, env), step_index_new]
    qtable[row, step_index] = old + lr * (reward_value + gamma * future - old)

    # Only hit the disk when the entry actually changed.
    qtable_updated = False
    if qtable[row, step_index] != old:
        np.save(qtable_save, qtable)
        qtable_updated = True
    return s_new, qtable, reward_value, qtable_updated