forked from michaelnny/alpha_zero
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval_agent_go.py
142 lines (113 loc) · 4.33 KB
/
eval_agent_go.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# Copyright (c) 2023 Michael Hu.
# This code is part of the book "The Art of Reinforcement Learning: Fundamentals, Mathematics, and Implementation with Python.".
# This project is released under the MIT License.
# See the accompanying LICENSE file for details.
"""Evaluate the AlphaZero agent on Go."""
from absl import flags
import os
import sys
import torch
# Command-line configuration for the evaluation run, defined via absl flags.
FLAGS = flags.FLAGS
# Go board dimension (9 => a 9x9 board). Exported to BOARD_SIZE below so the
# env module picks it up at import time.
flags.DEFINE_integer('board_size', 9, 'Board size for Go.')
# Compensation points awarded to white for moving second.
flags.DEFINE_float('komi', 7.5, 'Komi rule for Go.')
flags.DEFINE_integer(
    'num_stack',
    8,
    'Stack N previous states, the state is an image of N x 2 + 1 binary planes.',
)
# Neural-network architecture hyper-parameters; these must match the
# architecture the checkpoint files below were trained with.
flags.DEFINE_integer('num_res_blocks', 10, 'Number of residual blocks in the neural network.')
flags.DEFINE_integer('num_filters', 128, 'Number of filters for the conv2d layers in the neural network.')
flags.DEFINE_integer(
    'num_fc_units',
    128,
    'Number of hidden units in the linear layer of the neural network.',
)
# Checkpoints for each side; both default to the same trained model.
flags.DEFINE_string(
    'black_ckpt',
    './checkpoints/go/9x9/training_steps_154000.ckpt',
    'Load the checkpoint file for black player.',
)
flags.DEFINE_string(
    'white_ckpt',
    './checkpoints/go/9x9/training_steps_154000.ckpt',
    'Load the checkpoint file for white player.',
)
# MCTS search configuration.
flags.DEFINE_integer('num_simulations', 400, 'Number of iterations per MCTS search.')
flags.DEFINE_integer(
    'num_parallel',
    8,
    'Number of leaves to collect before using the neural network to evaluate the positions during MCTS search, 1 means no parallel search.',
)
# PUCT exploration constants (AlphaZero paper defaults).
flags.DEFINE_float('c_puct_base', 19652, 'Exploration constants balancing priors vs. search values.')
flags.DEFINE_float('c_puct_init', 1.25, 'Exploration constants balancing priors vs. search values.')
flags.DEFINE_bool('human_vs_ai', True, 'Black player is human, default on.')
flags.DEFINE_bool('show_steps', False, 'Show step number on stones, default off.')
flags.DEFINE_integer('seed', 1, 'Seed the runtime.')
# Initialize flags
# Parse argv eagerly (instead of via app.run) because BOARD_SIZE must be
# exported to the environment BEFORE the envs.go module is imported below.
FLAGS(sys.argv)
os.environ['BOARD_SIZE'] = str(FLAGS.board_size)
from envs.go import GoEnv
from envs.gui import BoardGameGui
from network import AlphaZeroNet
from pipeline import create_mcts_player, set_seed, disable_auto_grad
from util import create_logger
def main():
    """Run an evaluation match of the AlphaZero Go agent in a GUI.

    Builds one MCTS-backed player per side from the configured checkpoints
    (or uses a human for black when --human_vs_ai is on) and starts the
    board-game GUI loop. All configuration comes from module-level FLAGS.
    """
    set_seed(FLAGS.seed)
    logger = create_logger()

    # Pick the best available accelerator, falling back to CPU.
    runtime_device = 'cpu'
    if torch.cuda.is_available():
        runtime_device = 'cuda'
    elif torch.backends.mps.is_available():
        runtime_device = 'mps'

    # GoEnv reads the board size from the BOARD_SIZE env var set at import time.
    eval_env = GoEnv(komi=FLAGS.komi, num_stack=FLAGS.num_stack)
    input_shape = eval_env.observation_space.shape
    num_actions = eval_env.action_space.n

    def network_builder():
        # Fresh network instance configured from the command-line flags; the
        # architecture must match the one used to produce the checkpoints.
        return AlphaZeroNet(
            input_shape,
            num_actions,
            FLAGS.num_res_blocks,
            FLAGS.num_filters,
            FLAGS.num_fc_units,
        )

    def load_checkpoint_for_net(network, ckpt_file, device):
        # Load weights in place; on a missing/invalid path, warn and continue
        # with randomly initialized weights (best-effort, matches original).
        if ckpt_file and os.path.isfile(ckpt_file):
            loaded_state = torch.load(ckpt_file, map_location=torch.device(device))
            network.load_state_dict(loaded_state['network'])
        else:
            logger.warning(f'Invalid checkpoint file "{ckpt_file}"')

    def mcts_player_builder(ckpt_file, device):
        # Build an inference-only network and wrap it in a deterministic
        # (no root noise) MCTS player for evaluation.
        network = network_builder().to(device)
        disable_auto_grad(network)  # inference only, no gradients needed
        load_checkpoint_for_net(network, ckpt_file, device)
        network.eval()
        return create_mcts_player(
            network=network,
            device=device,
            num_simulations=FLAGS.num_simulations,
            num_parallel=FLAGS.num_parallel,
            root_noise=False,
            deterministic=True,
        )

    # Adapt an MCTS player to the `act(env) -> action` callable the GUI expects.
    # BUG FIX: the original annotated this as `-> int`, but it returns a closure.
    def wrap_player(mcts_player):
        def act(env):
            action, *_ = mcts_player(env, None, FLAGS.c_puct_base, FLAGS.c_puct_init, False)
            return action

        return act

    white_player = wrap_player(mcts_player_builder(FLAGS.white_ckpt, runtime_device))

    if FLAGS.human_vs_ai:
        # The GUI interprets the string 'human' as "take moves from user input".
        black_player = 'human'
    else:
        black_player = wrap_player(mcts_player_builder(FLAGS.black_ckpt, runtime_device))

    game_gui = BoardGameGui(
        eval_env,
        black_player=black_player,
        white_player=white_player,
        show_steps=FLAGS.show_steps,
    )
    game_gui.start()
# Script entry point: only run the evaluation when executed directly.
if __name__ == '__main__':
    main()