Skip to content

Commit

Permalink
break long lines
Browse files Browse the repository at this point in the history
  • Loading branch information
thouis committed Sep 22, 2016
1 parent 2f089be commit b5f3986
Show file tree
Hide file tree
Showing 15 changed files with 141 additions and 74 deletions.
12 changes: 8 additions & 4 deletions AlphaGo/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,23 @@ def get_move(self, state):
sensible_moves = [move for move in state.get_legal_moves(include_eyes=False)]
if len(sensible_moves) > 0:
move_probs = self.policy.eval_state(state, sensible_moves)
# zip(*list) is like the 'transpose' of zip; zip(*zip([1,2,3], [4,5,6])) is [(1,2,3), (4,5,6)]
# zip(*list) is like the 'transpose' of zip;
# zip(*zip([1,2,3], [4,5,6])) is [(1,2,3), (4,5,6)]
moves, probabilities = zip(*move_probs)
probabilities = np.array(probabilities)
probabilities = probabilities ** self.beta
probabilities = probabilities / probabilities.sum()
# numpy interprets a list of tuples as 2D, so we must choose an _index_ of moves then apply it in 2 steps
# numpy interprets a list of tuples as 2D, so we must choose an
# _index_ of moves then apply it in 2 steps
choice_idx = np.random.choice(len(moves), p=probabilities)
return moves[choice_idx]
return go.PASS_MOVE

def get_moves(self, states):
"""Batch version of get_move. A list of moves is returned (one per state)
"""
sensible_move_lists = [[move for move in st.get_legal_moves(include_eyes=False)] for st in states]
sensible_move_lists = [[move for move in st.get_legal_moves(include_eyes=False)]
for st in states]
all_moves_distributions = self.policy.batch_eval_state(states, sensible_move_lists)
move_list = [None] * len(states)
for i, move_probs in enumerate(all_moves_distributions):
Expand All @@ -85,7 +88,8 @@ def get_moves(self, states):


class MCTSPlayer(object):
def __init__(self, value_function, policy_function, rollout_function, lmbda=.5, c_puct=5, rollout_limit=500, playout_depth=40, n_playout=100):
def __init__(self, value_function, policy_function, rollout_function, lmbda=.5, c_puct=5,
rollout_limit=500, playout_depth=40, n_playout=100):
self.mcts = mcts.MCTS(value_function, policy_function, rollout_function, lmbda, c_puct,
rollout_limit, playout_depth, n_playout)

Expand Down
15 changes: 10 additions & 5 deletions AlphaGo/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def __init__(self, size=19, komi=7.5, enforce_superko=False):
# similarly to `liberty_sets`, `group_sets[x][y]` points to a set of tuples
# containing all (x',y') pairs in the group connected to (x,y)
self.group_sets = [[set() for _ in range(size)] for _ in range(size)]
# cache of list of legal moves (actually 'sensible' moves, with a separate list for eye-moves on request)
# cache of list of legal moves (actually 'sensible' moves, with a
# separate list for eye-moves on request)
self.__legal_move_cache = None
self.__legal_eyes_cache = None
# on-the-fly record of 'age' of each stone
Expand Down Expand Up @@ -104,7 +105,8 @@ def _create_neighbors_cache(self):
GameState.__NEIGHBORS_CACHE[self.size] = {}
for x in xrange(self.size):
for y in xrange(self.size):
neighbors = [xy for xy in [(x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1)] if self._on_board(xy)]
neighbors = [xy for xy in [(x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1)]
if self._on_board(xy)]
GameState.__NEIGHBORS_CACHE[self.size][(x, y)] = neighbors

def _neighbors(self, position):
Expand All @@ -117,7 +119,8 @@ def _diagonals(self, position):
"""Like _neighbors but for diagonal positions
"""
(x, y) = position
return filter(self._on_board, [(x - 1, y - 1), (x + 1, y + 1), (x + 1, y - 1), (x - 1, y + 1)])
return filter(self._on_board, [(x - 1, y - 1), (x + 1, y + 1),
(x + 1, y - 1), (x - 1, y + 1)])

def _update_neighbors(self, position):
"""A private helper function to update self.group_sets and self.liberty_sets
Expand Down Expand Up @@ -229,8 +232,10 @@ def is_suicide(self, action):
return False

def is_positional_superko(self, action):
"""Find all actions that the current_player has done in the past, taking into account the fact that
history starts with BLACK when there are no handicaps or with WHITE when there are.
"""Find all actions that the current_player has done in the past, taking into
account the fact that history starts with BLACK when there are no
handicaps or with WHITE when there are.
"""
if len(self.handicaps) == 0 and self.current_player == BLACK:
player_history = self.history[0::2]
Expand Down
9 changes: 5 additions & 4 deletions AlphaGo/mcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ class MCTS(object):
fast evaluation from leaf nodes to the end of the game.
"""

def __init__(self, value_fn, policy_fn, rollout_policy_fn, lmbda=0.5, c_puct=5, rollout_limit=500, playout_depth=20, n_playout=10000):
def __init__(self, value_fn, policy_fn, rollout_policy_fn, lmbda=0.5, c_puct=5,
rollout_limit=500, playout_depth=20, n_playout=10000):
"""Arguments:
value_fn -- a function that takes in a state and ouputs a score in [-1, 1], i.e. the
expected value of the end game score from the current player's perspective.
Expand All @@ -115,9 +116,9 @@ def __init__(self, value_fn, policy_fn, rollout_policy_fn, lmbda=0.5, c_puct=5,
lmbda -- controls the relative weight of the value network and fast rollout policy result
in determining the value of a leaf node. lmbda must be in [0, 1], where 0 means use only
the value network and 1 means use only the result from the rollout.
c_puct -- a number in (0, inf) that controls how quickly exploration converges to the maximum-
value policy, where a higher value means relying on the prior more, and should be used only
in conjunction with a large value for n_playout.
c_puct -- a number in (0, inf) that controls how quickly exploration converges to the
maximum-value policy, where a higher value means relying on the prior more, and
should be used only in conjunction with a large value for n_playout.
"""
self._root = TreeNode(None, 1.0)
self._value = value_fn
Expand Down
6 changes: 4 additions & 2 deletions AlphaGo/models/nn_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def _model_forward(self):
# be set to 0 when using the network in prediction mode and is automatically set to 1
# during training.
if self.model.uses_learning_phase:
forward_function = K.function([self.model.input, K.learning_phase()], [self.model.output])
forward_function = K.function([self.model.input, K.learning_phase()],
[self.model.output])

# the forward_function returns a list of tensors
# the first [0] gets the front tensor.
Expand All @@ -68,7 +69,8 @@ def load_model(json_file):
try:
network_class = NeuralNetBase.subclasses[class_name]
except KeyError:
raise ValueError("Unknown neural network type in json file: {}\n(was it registered with the @neuralnet decorator?)".format(class_name))
raise ValueError("Unknown neural network type in json file: {}\n"
"(was it registered with the @neuralnet decorator?)".format(class_name))

# create new object
new_net = network_class(object_specs['feature_list'], init_network=False)
Expand Down
16 changes: 10 additions & 6 deletions AlphaGo/models/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ def batch_eval_state(self, states, moves_lists=None):
raise ValueError("all states must have the same size")
# concatenate together all one-hot encoded states along the 'batch' dimension
nn_input = np.concatenate([self.preprocessor.state_to_tensor(s) for s in states], axis=0)
# pass all input through the network at once (backend makes use of batches if len(states) is large)
# pass all input through the network at once (backend makes use of
# batches if len(states) is large)
network_output = self.forward(nn_input)
# default move lists to all legal moves
moves_lists = moves_lists or [st.get_legal_moves() for st in states]
results = [None] * n_states
for i in range(n_states):
results[i] = self._select_moves_and_normalize(network_output[i], moves_lists[i], state_size)
results[i] = self._select_moves_and_normalize(network_output[i], moves_lists[i],
state_size)
return results

def eval_state(self, state, moves=None):
Expand Down Expand Up @@ -168,10 +170,12 @@ def create_network(**kwargs):
O - output
M - merge
The input is always passed through a Conv2D layer, the output of which layer is counted as '1'.
Each subsequent [R -- C] block is counted as one 'layer'. The 'merge' layer isn't counted; hence
if n_skip_1 is 2, the next valid skip parameter is n_skip_3, which will start at the output
of the merge
The input is always passed through a Conv2D layer, the output of which
layer is counted as '1'. Each subsequent [R -- C] block is counted as
one 'layer'. The 'merge' layer isn't counted; hence if n_skip_1 is 2,
the next valid skip parameter is n_skip_3, which will start at the
output of the merge
"""
defaults = {
"board": 19,
Expand Down
25 changes: 16 additions & 9 deletions AlphaGo/preprocessing/game_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,22 @@ def sgfs_to_hdf5(self, sgf_files, hdf5_file, bd_size=19, ignore_errors=True, ver
- sgf_files : an iterable of relative or absolute paths to SGF files
- hdf5_file : the name of the HDF5 where features will be saved
- bd_size : side length of board of games that are loaded
- ignore_errors : if True, issues a Warning when there is an unknown exception rather than halting. Note
that sgf.ParseException and go.IllegalMove exceptions are always skipped
- ignore_errors : if True, issues a Warning when there is an unknown
exception rather than halting. Note that sgf.ParseException and
go.IllegalMove exceptions are always skipped
The resulting file has the following properties:
states : dataset with shape (n_data, n_features, board width, board height)
actions : dataset with shape (n_data, 2) (actions are stored as x,y tuples of where the move was played)
actions : dataset with shape (n_data, 2) (actions are stored as x,y tuples of
where the move was played)
file_offsets : group mapping from filenames to tuples of (index, length)
For example, to find what positions in the dataset come from 'test.sgf':
index, length = file_offsets['test.sgf']
test_states = states[index:index+length]
test_actions = actions[index:index+length]
"""
# TODO - also save feature list

Expand All @@ -72,9 +76,9 @@ def sgfs_to_hdf5(self, sgf_files, hdf5_file, bd_size=19, ignore_errors=True, ver
'states',
dtype=np.uint8,
shape=(1, self.n_features, bd_size, bd_size),
maxshape=(None, self.n_features, bd_size, bd_size), # 'None' dimension allows it to grow arbitrarily
exact=False, # allow non-uint8 datasets to be loaded, coerced to uint8
chunks=(64, self.n_features, bd_size, bd_size), # approximately 1MB chunks
maxshape=(None, self.n_features, bd_size, bd_size), # 'None' == arbitrary size
exact=False, # allow non-uint8 datasets to be loaded, coerced to uint8
chunks=(64, self.n_features, bd_size, bd_size), # approximately 1MB chunks
compression="lzf")
actions = h5f.require_dataset(
'actions',
Expand Down Expand Up @@ -107,20 +111,23 @@ def sgfs_to_hdf5(self, sgf_files, hdf5_file, bd_size=19, ignore_errors=True, ver
n_pairs += 1
next_idx += 1
except go.IllegalMove:
warnings.warn("Illegal Move encountered in %s\n\tdropping the remainder of the game" % file_name)
warnings.warn("Illegal Move encountered in %s\n"
"\tdropping the remainder of the game" % file_name)
except sgf.ParseException:
warnings.warn("Could not parse %s\n\tdropping game" % file_name)
except SizeMismatchError:
warnings.warn("Skipping %s; wrong board size" % file_name)
except Exception as e:
# catch everything else
if ignore_errors:
warnings.warn("Unkown exception with file %s\n\t%s" % (file_name, e), stacklevel=2)
warnings.warn("Unkown exception with file %s\n\t%s" % (file_name, e),
stacklevel=2)
else:
raise e
finally:
if n_pairs > 0:
# '/' has special meaning in HDF5 key names, so they are replaced with ':' here
# '/' has special meaning in HDF5 key names, so they
# are replaced with ':' here
file_name_key = file_name.replace('/', ':')
file_offsets[file_name_key] = [file_start_idx, n_pairs]
if verbose:
Expand Down
18 changes: 12 additions & 6 deletions AlphaGo/preprocessing/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,25 @@ def get_liberties(state, maximum=8):
"""
planes = np.zeros((maximum, state.size, state.size))
for i in range(maximum):
# single liberties in plane zero (groups won't have zero), double liberties in plane one, etc
# single liberties in plane zero (groups won't have zero), double
# liberties in plane one, etc
planes[i, state.liberty_counts == i + 1] = 1
# the "maximum-or-more" case on the backmost plane
planes[maximum - 1, state.liberty_counts >= maximum] = 1
return planes


def get_capture_size(state, maximum=8):
"""A feature encoding the number of opponent stones that would be captured by playing at each location,
up to 'maximum'
"""A feature encoding the number of opponent stones that would be captured by
playing at each location, up to 'maximum'
Note:
- we currently *do* treat the 0th plane as "capturing zero stones"
- the [maximum-1] plane is used for any capturable group of size greater than or equal to maximum-1
- the [maximum-1] plane is used for any capturable group of size
greater than or equal to maximum-1
- the 0th plane is used for legal moves that would not result in capture
- illegal move locations are all-zero features
"""
planes = np.zeros((maximum, state.size, state.size))
for (x, y) in state.get_legal_moves():
Expand All @@ -71,14 +74,17 @@ def get_capture_size(state, maximum=8):
# (note suicide and ko are not an issue because they are not
# legal moves)
(gx, gy) = next(iter(neighbor_group))
if (state.liberty_counts[gx][gy] == 1) and (state.board[gx, gy] != state.current_player):
if (state.liberty_counts[gx][gy] == 1) and \
(state.board[gx, gy] != state.current_player):
n_captured += len(state.group_sets[gx][gy])
planes[min(n_captured, maximum - 1), x, y] = 1
return planes


def get_self_atari_size(state, maximum=8):
"""A feature encoding the size of the own-stone group that is put into atari by playing at a location
"""A feature encoding the size of the own-stone group that is put into atari by
playing at a location
"""
planes = np.zeros((maximum, state.size, state.size))

Expand Down
Loading

0 comments on commit b5f3986

Please sign in to comment.