# https://deeplearningcourses.com/c/artificial-intelligence-reinforcement-learning-in-python
# https://www.udemy.com/artificial-intelligence-reinforcement-learning-in-python
from __future__ import print_function, division
from builtins import range
# Note: you may need to update your version of future
# sudo pip install -U future
import numpy as np
import matplotlib.pyplot as plt
from grid_world import standard_grid, negative_grid
from iterative_policy_evaluation import print_values, print_policy
from monte_carlo_es import max_dict
from td0_prediction import random_action
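
# GAMMA is the discount factor, ALPHA is the base learning rate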
GAMMA = 0.9
ALPHA = 0.1
ALL_POSSIBLE_ACTIONS = ('U', 'D', 'L', 'R')

if __name__ == '__main__':
  # NOTE: if we use the standard grid, there's a good chance we will end up with
  # a suboptimal policy, e.g.
  # ---------------------------
  #   R  |  R  |  R  |     |
  # ---------------------------
  #   R* |     |  U  |     |
  # ---------------------------
  #   U  |  R  |  U  |  L  |
  # since going R at (1,0) (shown with a *) incurs no cost, it's OK to keep doing that.
  # we'll either end up staying in the same spot, or back at the start (2,0), at which
  # point we would then just go back up, or at (0,0), at which point we can continue
  # on right.
  # instead, let's penalize each movement so the agent will find a shorter route.
  #
  # grid = standard_grid()
  grid = negative_grid(step_cost=-0.1)

  # print rewards
  print("rewards:")
  print_values(grid.rewards, grid)

  # no policy initialization, we will derive our policy from the most recent Q
  # initialize Q(s,a)
  Q = {}
  states = grid.all_states()
  for s in states:
    Q[s] = {}
    for a in ALL_POSSIBLE_ACTIONS:
      Q[s][a] = 0

  # let's also keep track of how many times Q[s] has been updated
  update_counts = {}
  update_counts_sa = {}
  for s in states:
    update_counts_sa[s] = {}
    for a in ALL_POSSIBLE_ACTIONS:
      update_counts_sa[s][a] = 1.0
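
  # update_counts_sa drives an adaptive learning rate in the loop below:
  # alpha = ALPHA / update_counts_sa[s][a], so frequently visited state-action
  # pairs receive progressively smaller updates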

  # repeat until convergence
  t = 1.0
  deltas = []
  for it in range(10000):
    if it % 100 == 0:
      t += 1e-2
    if it % 2000 == 0:
      print("it:", it)

    # instead of 'generating' an episode, we will PLAY
    # an episode within this loop
    s = (2, 0) # start state
    grid.set_state(s)

    # the first (s, r) tuple is the state we start in and 0
    # (since we don't get a reward for simply starting the game)
    # the last (s, r) tuple is the terminal state and the final reward
    # the value for the terminal state is by definition 0, so we don't
    # care about updating it.
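
    # choose the first action epsilon-greedily from the current Q;
    # eps = 0.5/t decays slowly (t grows by 0.01 every 100 episodes),
    # so we explore heavily early on and exploit more as training progresses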
    a = max_dict(Q[s])[0]
    a = random_action(a, eps=0.5/t)
    biggest_change = 0
    while not grid.game_over():
      r = grid.move(a)
      s2 = grid.current_state()

      # we need the next action as well since Q(s,a) depends on Q(s',a')
      # if s2 not in policy then it's a terminal state, all Q are 0
      a2 = max_dict(Q[s2])[0]
      a2 = random_action(a2, eps=0.5/t) # epsilon-greedy

      # we will update Q(s,a) AS we experience the episode
      alpha = ALPHA / update_counts_sa[s][a]
      update_counts_sa[s][a] += 0.005
      old_qsa = Q[s][a]
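      # SARSA update (on-policy TD(0) control):
      #   Q(s,a) <- Q(s,a) + alpha * [r + GAMMA*Q(s',a') - Q(s,a)]
      # a' is the action we will actually take next (chosen epsilon-greedily above);
      # Q-learning would use max over a' instead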
      Q[s][a] = Q[s][a] + alpha*(r + GAMMA*Q[s2][a2] - Q[s][a])
      biggest_change = max(biggest_change, np.abs(old_qsa - Q[s][a]))

      # we would like to know how often Q(s) has been updated too
      update_counts[s] = update_counts.get(s,0) + 1

      # next state becomes current state
      s = s2
      a = a2

    deltas.append(biggest_change)

  plt.plot(deltas)
  plt.show()
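
  # deltas holds the largest Q change in each episode; the plot should
  # trend toward zero as the Q values converge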

  # determine the policy from Q*
  # find V* from Q*
  policy = {}
  V = {}
  for s in grid.actions.keys():
    a, max_q = max_dict(Q[s])
    policy[s] = a
    V[s] = max_q
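  # i.e. the greedy policy pi(s) = argmax_a Q(s,a) and V(s) = max_a Q(s,a),
  # both read directly from the learned Q table via max_dict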

  # what's the proportion of time we spend updating each part of Q?
  print("update counts:")
  total = np.sum(list(update_counts.values()))
  for k, v in update_counts.items():
    update_counts[k] = float(v) / total
  print_values(update_counts, grid)

  print("values:")
  print_values(V, grid)
  print("policy:")
  print_policy(policy, grid)