train.py

import tensorflow as tf
# from agents.epsilon_greedy_agent import EpsilonGreedyAgent
from agents.td_leaf_agent import TDLeafAgent
from envs.tic_tac_toe import TicTacToeEnv
from envs.chess import ChessEnv
from value_model import ValueModel
import time
import cProfile


def main():
    # Tic-tac-toe environment with a value network sized to its feature vector
    env = TicTacToeEnv()
    network = ValueModel(env.get_feature_vector_size())
    # agent = EpsilonGreedyAgent('agent_0', network, env, verbose=True)

    # To train on chess instead (requires a chess-specific value model):
    # env = ChessEnv()
    # network = ChessValueModel()

    agent = TDLeafAgent('agent_0', network, env, verbose=True)

    # Merge all TF summaries; checkpoints and logs go under a timestamped directory
    summary_op = tf.summary.merge_all()
    log_dir = "./log/" + str(int(time.time()))

    with tf.train.SingularMonitoredSession(
            checkpoint_dir=log_dir,
            scaffold=tf.train.Scaffold(summary_op=summary_op)) as sess:
        agent.sess = sess
        # cProfile.runctx('agent.train(depth=3)', globals(), locals())
        for i in range(10000):
            # Every 100th episode, evaluate against a random opponent;
            # otherwise run a TD-leaf training episode at search depth 2.
            if i % 100 == 0:
                agent.random_agent_test(depth=2)
                # agent.test(0, depth=3)
            else:
                agent.train(depth=2)
            sess.run(agent.increment_episode_count)


if __name__ == "__main__":
main()
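
# Usage sketch (assumes a TensorFlow 1.x install, where tf.train.SingularMonitoredSession
# and tf.summary.merge_all are available):
#   python train.py              # trains, checkpointing under ./log/<timestamp>
#   tensorboard --logdir ./log   # inspect the merged summaries while training runs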