forked from algolab-inc/tf-dqn-simple
train.py
import numpy as np

from catch_ball import CatchBall
from dqn_agent import DQNAgent


if __name__ == "__main__":
    # parameters
    n_epochs = 1000

    # environment, agent
    env = CatchBall()
    agent = DQNAgent(env.enable_actions, env.name)

    # variables
    win = 0

    for e in range(n_epochs):
        # reset
        frame = 0
        loss = 0.0
        Q_max = 0.0
        env.reset()
        state_t_1, reward_t, terminal = env.observe()

        while not terminal:
            state_t = state_t_1

            # execute action in environment
            action_t = agent.select_action(state_t, agent.exploration)
            env.execute_action(action_t)

            # observe environment
            state_t_1, reward_t, terminal = env.observe()

            # store experience
            agent.store_experience(state_t, action_t, reward_t, state_t_1, terminal)

            # experience replay
            agent.experience_replay()

            # for log
            frame += 1
            loss += agent.current_loss
            Q_max += np.max(agent.Q_values(state_t))

        if reward_t == 1:
            win += 1

        print("EPOCH: {:03d}/{:03d} | WIN: {:03d} | LOSS: {:.4f} | Q_MAX: {:.4f}".format(
            e, n_epochs - 1, win, loss / frame, Q_max / frame))

    # save model
    agent.save_model()