forked from lazyprogrammer/machine_learning_examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcomparing_epsilons.py
78 lines (60 loc) · 1.78 KB
/
comparing_epsilons.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# https://deeplearningcourses.com/c/artificial-intelligence-reinforcement-learning-in-python
# https://www.udemy.com/artificial-intelligence-reinforcement-learning-in-python
from __future__ import print_function, division
from builtins import range
# Note: you may need to update your version of future
# sudo pip install -U future
import numpy as np
import matplotlib.pyplot as plt
class Bandit:
    """A single slot-machine arm with Gaussian rewards.

    Each pull returns a sample from a normal distribution centred on ``m``
    with unit variance. The arm also tracks the sample mean of every reward
    passed to :meth:`update`, plus how many rewards that mean is built from.
    """

    def __init__(self, m):
        # True mean of this arm's reward distribution.
        self.m = m
        # Running sample-mean estimate; 0 before any reward has been seen.
        self.mean = 0
        # Count of rewards folded into ``mean``.
        self.N = 0

    def pull(self):
        """Return one reward: a standard-normal draw shifted by ``m``."""
        noise = np.random.randn()
        return noise + self.m

    def update(self, x):
        """Fold reward ``x`` into the running mean.

        Uses the incremental form mean_N = (1 - 1/N) * mean_{N-1} + (1/N) * x,
        so no reward history needs to be stored.
        """
        self.N += 1
        weight = 1.0 / self.N
        self.mean = (1 - weight) * self.mean + weight * x
def run_experiment(m1, m2, m3, eps, N, show_plot=True):
    """Run an epsilon-greedy agent for N pulls on three Gaussian bandits.

    On each step a uniform draw below ``eps`` triggers exploration (an arm
    picked uniformly at random). Otherwise the agent exploits the arm with
    the highest current sample mean. After the run, the cumulative average
    reward is optionally plotted against the three true means on a log-scaled
    x axis. Each arm's final estimated mean is then printed.

    Args:
        m1, m2, m3: true reward means of the three arms.
        eps: exploration probability.
        N: number of pulls to simulate.
        show_plot: when False, skip this run's plot window (useful when
            comparing several runs or running headless). Defaults to True,
            which matches the original behaviour.

    Returns:
        A length-N numpy array; element i is the average of the first
        i + 1 rewards received.
    """
    bandits = [Bandit(m1), Bandit(m2), Bandit(m3)]
    n_arms = len(bandits)  # derive from the list instead of hard-coding 3
    data = np.empty(N)

    for i in range(N):
        # Epsilon-greedy: explore with probability eps, otherwise exploit.
        if np.random.random() < eps:
            j = np.random.choice(n_arms)
        else:
            # np.argmax returns the first index on ties, so while all
            # estimates are equal (e.g. at the start) arm 0 is chosen.
            j = np.argmax([b.mean for b in bandits])
        x = bandits[j].pull()
        bandits[j].update(x)
        # Record every reward for the cumulative-average curve.
        data[i] = x

    # Running mean of all rewards so far: prefix sums / number of pulls.
    cumulative_average = np.cumsum(data) / (np.arange(N) + 1)

    if show_plot:
        # Cumulative average reward vs. each arm's true mean.
        plt.plot(cumulative_average)
        for true_mean in (m1, m2, m3):
            plt.plot(np.ones(N) * true_mean)
        plt.xscale('log')
        plt.show()

    for b in bandits:
        print(b.mean)

    return cumulative_average
if __name__ == '__main__':
    # Compare three exploration rates on the same arms (means 1, 2, 3).
    # Runs execute in this order, so the shared RNG stream is consumed
    # exactly as before.
    settings = [
        (0.1, 'eps = 0.1'),
        (0.05, 'eps = 0.05'),
        (0.01, 'eps = 0.01'),
    ]
    curves = [
        (run_experiment(1.0, 2.0, 3.0, eps, 100000), label)
        for eps, label in settings
    ]

    # Draw the comparison twice: first with a log-scaled x axis, then linear.
    for use_log_scale in (True, False):
        for curve, label in curves:
            plt.plot(curve, label=label)
        plt.legend()
        if use_log_scale:
            plt.xscale('log')
        plt.show()