diff --git a/AlphaGo/ai.py b/AlphaGo/ai.py
index f8c695732..593e856da 100644
--- a/AlphaGo/ai.py
+++ b/AlphaGo/ai.py
@@ -9,15 +9,18 @@ class GreedyPolicyPlayer(object):
     move each turn)
     """
 
-    def __init__(self, policy_function, pass_when_offered=False):
+    def __init__(self, policy_function, pass_when_offered=False, move_limit=None):
         self.policy = policy_function
         self.pass_when_offered = pass_when_offered
+        self.move_limit = move_limit
 
     def get_move(self, state):
+        if self.move_limit is not None and len(state.history) > self.move_limit:
+            return go.PASS_MOVE
         if self.pass_when_offered:
             if len(state.history) > 100 and state.history[-1] == go.PASS_MOVE:
                 return go.PASS_MOVE
-        sensible_moves = [move for move in state.get_legal_moves() if not state.is_eye(move, state.current_player)]
+        sensible_moves = [move for move in state.get_legal_moves(include_eyes=False)]
         if len(sensible_moves) > 0:
             move_probs = self.policy.eval_state(state, sensible_moves)
             max_prob = max(move_probs, key=lambda (a, p): p)
@@ -34,17 +37,20 @@ class ProbabilisticPolicyPlayer(object):
     (high temperature) or towards greedy play (low temperature)
     """
 
-    def __init__(self, policy_function, temperature=1.0, pass_when_offered=False):
+    def __init__(self, policy_function, temperature=1.0, pass_when_offered=False, move_limit=None):
         assert(temperature > 0.0)
         self.policy = policy_function
         self.beta = 1.0 / temperature
         self.pass_when_offered = pass_when_offered
+        self.move_limit = move_limit
 
     def get_move(self, state):
+        if self.move_limit is not None and len(state.history) > self.move_limit:
+            return go.PASS_MOVE
         if self.pass_when_offered:
             if len(state.history) > 100 and state.history[-1] == go.PASS_MOVE:
                 return go.PASS_MOVE
-        sensible_moves = [move for move in state.get_legal_moves() if not state.is_eye(move, state.current_player)]
+        sensible_moves = [move for move in state.get_legal_moves(include_eyes=False)]
         if len(sensible_moves) > 0:
             move_probs = self.policy.eval_state(state, sensible_moves)
             # zip(*list) is like the 'transpose' of zip; zip(*zip([1,2,3], [4,5,6])) is [(1,2,3), (4,5,6)]
@@ -60,11 +66,11 @@ def get_move(self, state):
     def get_moves(self, states):
         """Batch version of get_move. A list of moves is returned (one per state)
         """
-        sensible_move_lists = [[move for move in st.get_legal_moves() if not st.is_eye(move, st.current_player)] for st in states]
+        sensible_move_lists = [[move for move in st.get_legal_moves(include_eyes=False)] for st in states]
         all_moves_distributions = self.policy.batch_eval_state(states, sensible_move_lists)
         move_list = [None] * len(states)
         for i, move_probs in enumerate(all_moves_distributions):
-            if len(move_probs) == 0:
+            if len(move_probs) == 0 or (self.move_limit is not None and len(states[i].history) > self.move_limit):
                 move_list[i] = go.PASS_MOVE
             else:
                 # this 'else' clause is identical to ProbabilisticPolicyPlayer.get_move
@@ -83,7 +89,7 @@ def __init__(self, value_function, policy_function, rollout_function, lmbda=.5,
                              rollout_limit, playout_depth, n_playout)
 
     def get_move(self, state):
-        sensible_moves = [move for move in state.get_legal_moves() if not state.is_eye(move, state.current_player)]
+        sensible_moves = [move for move in state.get_legal_moves(include_eyes=False)]
         if len(sensible_moves) > 0:
             move = self.mcts.get_move(state)
             self.mcts.update_with_move(move)
diff --git a/AlphaGo/training/reinforcement_policy_trainer.py b/AlphaGo/training/reinforcement_policy_trainer.py
index 0e208b446..da5766b89 100644
--- a/AlphaGo/training/reinforcement_policy_trainer.py
+++ b/AlphaGo/training/reinforcement_policy_trainer.py
@@ -1,142 +1,193 @@
 import os
-import argparse
 import json
 import numpy as np
-import itertools
 from shutil import copyfile
-from keras.optimizers import SGD
+from keras.optimizers import Optimizer
+import keras.backend as K
 from AlphaGo.ai import ProbabilisticPolicyPlayer
 import AlphaGo.go as go
 from AlphaGo.go import GameState
 from AlphaGo.models.policy import CNNPolicy
-from AlphaGo.preprocessing.preprocessing import Preprocess
 from AlphaGo.util import flatten_idx
 
 
-def make_training_pairs(player, opp, features, mini_batch_size, board_size=19):
-    """Make training pairs for batch of matches, utilizing player.get_moves (parallel form of
-    player.get_move), which calls `CNNPolicy.batch_eval_state`.
-
-    Args:
-        player -- player that we're always updating
-        opp -- batch opponent
-        feature_list -- game features to be one-hot encoded
-        mini_batch_size -- number of games in mini-batch
-
-    Return:
-        X_list -- list of 1-hot board states associated with moves.
-        y_list -- list of 1-hot moves associated with board states.
-        winners -- list of winners associated with each game in batch
-    """
-
-    def record_training_pair(st, mv, X, y):
-        # Convert move to one-hot
-        bsize_flat = bsize * bsize
-        state_1hot = preprocessor.state_to_tensor(st)
-        move_1hot = np.zeros(bsize_flat)
-        move_1hot[flatten_idx(mv, bsize)] = 1
-        X.append(state_1hot)
-        y.append(move_1hot)
-
-    # First we want to prep the states so that half of the boards get a move.
-    # The other half we want to leave alone so that the second player makes the first move (being black).
-    # Decided to alternate every board because this is how humans would play a match.
-    def play_half_of_boards(states, moves, X_list, y_list, player_color):
-        for st, mv, X, y, should_move in zip(states, moves, X_list, y_list, itertools.cycle([True, False])):
-            if should_move:
-                if st.current_player == player_color:
-                    record_training_pair(st, mv, X, y)
-                st.do_move(mv)
-        return states, X_list, y_list
-
-    def do_move(states, moves, X_list, y_list, player_color):
-        for st, mv, X, y in zip(states, moves, X_list, y_list):
-            # Only do more moves if not end of game already
-            if not st.is_end_of_game:
-                # Only want to record moves by the 'player', not the opponent
-                if st.current_player == player_color and mv is not go.PASS_MOVE:
-                    record_training_pair(st, mv, X, y)
-                st.do_move(mv)
-        return states, X_list, y_list
-
-    # Lists of game training pairs (1-hot)
-    X_list = [list() for _ in xrange(mini_batch_size)]
-    y_list = [list() for _ in xrange(mini_batch_size)]
-    preprocessor = Preprocess(features)
-    bsize = player.policy.model.input_shape[-1]
-    states = [GameState(size=board_size) for i in xrange(mini_batch_size)]
-    # Randomly choose who goes first (i.e. color of 'player')
-    player_color = np.random.choice([go.BLACK, go.WHITE])
-    player1, player2 = (player, opp) if player_color == go.BLACK else (opp, player)
-    # We let player1 move first for half of the boards in the minibatch
-    # The other half of the boards will be empty... waiting for a 'black' player
-    moves_player1 = player1.get_moves(states)
-    states, X_list, y_list = play_half_of_boards(states, moves_player1, X_list, y_list, player_color)
-    # Now player2 can move and will act as white for half and black for half
-    moves_player2 = player2.get_moves(states)
-    states, X_list, y_list = do_move(states, moves_player2, X_list, y_list, player_color)
-    # Now the game can continue.. each player acting as white and black split across the boards
-    while True:
-        # Get moves (batch) for player1
-        moves_player1 = player1.get_moves(states)
-        states, X_list, y_list = do_move(states, moves_player1, X_list, y_list, player_color)
-        # Get moves for player2
-        moves_player2 = player2.get_moves(states)
-        states, X_list, y_list = do_move(states, moves_player2, X_list, y_list, player_color)
-        # If all games have ended, we're done. Get winners.
-        done = [st.is_end_of_game for st in states]
-        if all(done):
-            break
-    won_game_list = []
-    # If player was black, every even board is black, odd board white
-    if player_color == go.BLACK:
-        for st, game_color in zip(states, itertools.cycle([go.BLACK, go.WHITE])):
-            won_game_list.append(st.get_winner() == game_color)
-    else:
-        for st, game_color in zip(states, itertools.cycle([go.WHITE, go.BLACK])):
-            won_game_list.append(st.get_winner() == game_color)
-    # Concatenate tensors across turns within each game
-    for i in xrange(mini_batch_size):
-        X_list[i] = np.concatenate(X_list[i], axis=0)
-        y_list[i] = np.vstack(y_list[i])
-    return X_list, y_list, won_game_list
-
-
-def train_batch(player, X_list, y_list, won_game_list, lr):
-    """Given the outcomes of a mini-batch of play against a fixed opponent,
-    update the weights with reinforcement learning.
-
-    Args:
-        player -- player object with policy weights to be updated
-        X_list -- List of one-hot encoded states.
-        y_list -- List of one-hot encoded actions (to pair with X_list).
-        winners -- List of winners corresponding to each item in
-            training_pairs_list
-        lr -- Keras learning rate
-
-    Return:
-        player -- same player, with updated weights.
- """ - - for X, y, won_game in zip(X_list, y_list, won_game_list): - # Update weights in + direction if player won, and - direction if player lost. - # Setting learning rate negative is hack for negative weights update. - if won_game: - player.policy.model.optimizer.lr.set_value(lr) - else: - player.policy.model.optimizer.lr.set_value(-lr) - player.policy.model.fit(X, y, nb_epoch=1, batch_size=len(X)) +class BatchedReinforcementLearningSGD(Optimizer): + '''A Keras Optimizer that sums gradients together for each game, applying them only once the + winner is known. + + It is the responsibility of the calling code to call set_current_game() before each example to + tell the optimizer for which game gradients should be accumulated, and to call set_result() to + tell the optimizer what the sign of the gradient for each game should be and when all games are + over. + + Arguments + lr: float >= 0. Learning rate. + ng: int > 0. Number of games played in parallel. Each one has its own cumulative gradient. + ''' + def __init__(self, lr=0.01, ng=20, **kwargs): + super(BatchedReinforcementLearningSGD, self).__init__(**kwargs) + self.__dict__.update(locals()) + self.lr = K.variable(lr) + self.cumulative_gradients = [] + self.num_games = ng + self.game_idx = K.variable(0) # which gradient to accumulate in the next batch. + self.gradient_sign = [K.variable(0) for _ in range(ng)] + self.running_games = K.variable(self.num_games) + + def set_current_game(self, game_idx): + K.set_value(self.game_idx, game_idx) + + def set_result(self, game_idx, won_game): + '''Mark the outcome of the game at index game_idx. Once all games are complete, updates + are automatically triggered in the next call to a keras fit function. + ''' + K.set_value(self.gradient_sign[game_idx], +1 if won_game else -1) + # Note: using '-= 1' would create a new variable, which would invalidate the dependencies + # in get_updates(). + K.set_value(self.running_games, K.get_value(self.running_games) - 1) + + def get_updates(self, params, constraints, loss): + # Note: get_updates is called *once* by keras. Its job is to return a set of 'update + # operations' to any K.variable (e.g. model weights or self.num_games). Updates are applied + # whenever Keras' train_function is evaluated, i.e. in every batch. Model.fit_on_batch() + # will trigger exactly one update. All updates use the 'old' value of parameters - there is + # no dependency on the order of the list of updates. + self.updates = [] + # Get expressions for gradients of model parameters. + grads = self.get_gradients(loss, params) + # Create a set of accumulated gradients, one for each game. + shapes = [K.get_variable_shape(p) for p in params] + self.cumulative_gradients = [[K.zeros(shape) for shape in shapes] for _ in range(self.num_games)] + + def conditional_update(cond, variable, new_value): + '''Helper function to create updates that only happen when cond is True. Writes to + self.updates and returns the new variable. + + Note: K.update(x, x) is cheap, but K.update_add(x, K.zeros_like(x)) can be expensive. + ''' + maybe_new_value = K.switch(cond, new_value, variable) + self.updates.append(K.update(variable, maybe_new_value)) + return maybe_new_value + + # Update cumulative gradient at index game_idx. This is done by returning an update for all + # gradients that is a no-op everywhere except for the game_idx'th one. When game_idx is + # changed by a call to set_current_game(), it will change the gradient that is getting + # accumulated. 
+        # new_cumulative_gradients keeps references to the updated variables for use below in
+        # updating parameters with the freshly-accumulated gradients.
+        new_cumulative_gradients = [[None] * len(cgs) for cgs in self.cumulative_gradients]
+        for i, cgs in enumerate(self.cumulative_gradients):
+            for j, (g, cg) in enumerate(zip(grads, cgs)):
+                new_gradient = conditional_update(K.equal(self.game_idx, i), cg, cg + g)
+                new_cumulative_gradients[i][j] = new_gradient
+
+        # Compute the net update to parameters, taking into account the sign of each cumulative
+        # gradient.
+        net_grads = [K.zeros_like(g) for g in grads]
+        for i, cgs in enumerate(new_cumulative_gradients):
+            for j, cg in enumerate(cgs):
+                net_grads[j] += self.gradient_sign[i] * cg
+
+        # Trigger a full update when all games have finished.
+        self.trigger_update = K.lesser_equal(self.running_games, 0)
+
+        # Update model parameters conditional on trigger_update (descend the loss for games won, ascend for games lost).
+        for p, g in zip(params, net_grads):
+            new_p = p - g * self.lr
+            if p in constraints:
+                c = constraints[p]
+                new_p = c(new_p)
+            conditional_update(self.trigger_update, p, new_p)
+
+        # 'reset' game counter and gradient signs when parameters are updated.
+        for sign in self.gradient_sign:
+            conditional_update(self.trigger_update, sign, K.variable(0))
+        conditional_update(self.trigger_update, self.running_games, K.variable(self.num_games))
+        return self.updates
+
+    def get_config(self):
+        config = {
+            'lr': float(K.get_value(self.lr)),
+            'ng': self.num_games}
+        base_config = super(BatchedReinforcementLearningSGD, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+
+def _make_training_pair(st, mv, preprocessor):
+    # Convert move to one-hot
+    st_tensor = preprocessor.state_to_tensor(st)
+    mv_tensor = np.zeros((1, st.size * st.size))
+    mv_tensor[(0, flatten_idx(mv, st.size))] = 1
+    return (st_tensor, mv_tensor)
+
+
+def run_n_games(optimizer, learner, opponent, num_games):
+    '''Run num_games games to completion, calling train_on_batch() on each position the learner sees.
+
+    (Note: optimizer only accumulates gradients in its update function until all games have finished)
+    '''
+    board_size = learner.policy.model.input_shape[-1]
+    states = [GameState(size=board_size) for _ in range(num_games)]
+    learner_net = learner.policy.model
+
+    # Start all odd games with moves by 'opponent'. Even games will have 'learner' black.
+    learner_color = [go.BLACK if i % 2 == 0 else go.WHITE for i in range(num_games)]
+    odd_states = states[1::2]
+    moves = opponent.get_moves(odd_states)
+    for st, mv in zip(odd_states, moves):
+        st.do_move(mv)
+
+    current = learner
+    other = opponent
+    # Need to keep track of the index of unfinished states so that we can communicate which one is
+    # being updated to the optimizer.
+    idxs_to_unfinished_states = {i: states[i] for i in range(num_games)}
+    while len(idxs_to_unfinished_states) > 0:
+        # Get next moves by current player for all unfinished states.
+        moves = current.get_moves(idxs_to_unfinished_states.values())
+        just_finished = []
+        # Do each move to each state in order.
+        for (idx, state), mv in zip(idxs_to_unfinished_states.iteritems(), moves):
+            # Order is important here. We must first get the training pair on the unmodified state.
+            # Next, the state is updated and checked to see if the game is over. If it is over, the
+            # optimizer is notified via set_result. Finally, train_on_batch is called, which
+            # will trigger an update of all parameters only if set_result() has been called
+            # for all games already (so set_result must come before train_on_batch).
+            is_learnable = current is learner and mv is not go.PASS_MOVE
+            if is_learnable:
+                (X, y) = _make_training_pair(state, mv, learner.policy.preprocessor)
+            state.do_move(mv)
+            if state.is_end_of_game:
+                learner_is_winner = state.get_winner() == learner_color[idx]
+                optimizer.set_result(idx, learner_is_winner)
+                just_finished.append(idx)
+            if is_learnable:
+                optimizer.set_current_game(idx)
+                learner_net.train_on_batch(X, y)
+
+        # Remove games that have finished from dict.
+        for idx in just_finished:
+            del idxs_to_unfinished_states[idx]
+
+        # Swap 'current' and 'other' for next turn.
+        current, other = other, current
+
+    # Return the win ratio.
+    wins = sum(state.get_winner() == pc for (state, pc) in zip(states, learner_color))
+    return float(wins) / num_games
 
 
 def run_training(cmd_line_args=None):
+    import argparse
     parser = argparse.ArgumentParser(description='Perform reinforcement learning to improve given policy network. Second phase of pipeline.')
     parser.add_argument("model_json", help="Path to policy model JSON.")
     parser.add_argument("initial_weights", help="Path to HDF5 file with inital weights (i.e. result of supervised training).")
     parser.add_argument("out_directory", help="Path to folder where the model params and metadata will be saved after each epoch.")
-    parser.add_argument("--learning-rate", help="Keras learning rate (Default: .03)", type=float, default=.03)
+    parser.add_argument("--learning-rate", help="Keras learning rate (Default: 0.001)", type=float, default=0.001)
     parser.add_argument("--policy-temp", help="Distribution temperature of players using policies (Default: 0.67)", type=float, default=0.67)
     parser.add_argument("--save-every", help="Save policy as a new opponent every n batches (Default: 500)", type=int, default=500)
     parser.add_argument("--game-batch", help="Number of games per mini-batch (Default: 20)", type=int, default=20)
+    parser.add_argument("--move-limit", help="Maximum number of moves per game", type=int, default=500)
     parser.add_argument("--iterations", help="Number of training batches/iterations (Default: 10000)", type=int, default=10000)
     parser.add_argument("--resume", help="Load latest weights in out_directory and resume", default=False, action="store_true")
     parser.add_argument("--verbose", "-v", help="Turn on verbose mode", default=False, action="store_true")
@@ -176,13 +227,12 @@ def run_training(cmd_line_args=None):
     # Set initial conditions
     policy = CNNPolicy.load_model(args.model_json)
     policy.model.load_weights(args.initial_weights)
-    player = ProbabilisticPolicyPlayer(policy, temperature=args.policy_temp)
-    features = policy.preprocessor.feature_list
+    player = ProbabilisticPolicyPlayer(policy, temperature=args.policy_temp, move_limit=args.move_limit)
 
-    # different opponents come from simply changing the weights of
-    # opponent.policy.model "behind the scenes"
+    # different opponents come from simply changing the weights of 'opponent.policy.model'. That
+    # is, only 'opp_policy' needs to be changed, and 'opponent' will change.
     opp_policy = CNNPolicy.load_model(args.model_json)
-    opponent = ProbabilisticPolicyPlayer(opp_policy, temperature=args.policy_temp)
+    opponent = ProbabilisticPolicyPlayer(opp_policy, temperature=args.policy_temp, move_limit=args.move_limit)
 
     if args.verbose:
         print "created player and opponent with temperature {}".format(args.policy_temp)
@@ -208,28 +258,28 @@ def save_metadata():
         with open(os.path.join(args.out_directory, "metadata.json"), "w") as f:
             json.dump(metadata, f, sort_keys=True, indent=2)
 
-    # Set SGD and compile
-    sgd = SGD(lr=args.learning_rate)
-    player.policy.model.compile(loss='binary_crossentropy', optimizer=sgd)
-    board_size = player.policy.model.input_shape[-1]
+    optimizer = BatchedReinforcementLearningSGD(lr=args.learning_rate, ng=args.game_batch)
+    player.policy.model.compile(loss='categorical_crossentropy', optimizer=optimizer)
     for i_iter in xrange(1, args.iterations + 1):
-        # Train mini-batches by randomly choosing opponent from pool (possibly self)
-        # and playing game_batch games against them
+        # Randomly choose opponent from pool (possibly self), and play game_batch games against
+        # them.
         opp_weights = np.random.choice(metadata["opponents"])
         opp_path = os.path.join(args.out_directory, opp_weights)
-        # load new weights into opponent, but otherwise its the same
+
+        # Load new weights into opponent's network, but keep the same opponent object.
         opponent.policy.model.load_weights(opp_path)
         if args.verbose:
            print "Batch {}\tsampled opponent is {}".format(i_iter, opp_weights)
-        # Make training pairs and do RL
-        X_list, y_list, won_game_list = make_training_pairs(player, opponent, features, args.game_batch, board_size)
-        win_ratio = np.sum(won_game_list) / float(args.game_batch)
+
+        # Run games (and learn from results). Keep track of the win ratio vs each opponent over time.
+        win_ratio = run_n_games(optimizer, player, opponent, args.game_batch)
         metadata["win_ratio"][player_weights] = (opp_weights, win_ratio)
-        train_batch(player, X_list, y_list, won_game_list, args.learning_rate)
-        # Save intermediate models
+
+        # Save all intermediate models.
         player_weights = "weights.%05d.hdf5" % i_iter
         player.policy.model.save_weights(os.path.join(args.out_directory, player_weights))
-        # add player to batch of oppenents once in a while
+
+        # Add player to batch of opponents once in a while.
         if i_iter % args.save_every == 0:
             metadata["opponents"].append(player_weights)
             save_metadata()
diff --git a/tests/test_reinforcement_policy_trainer.py b/tests/test_reinforcement_policy_trainer.py
index ecda3a5fe..702d25332 100644
--- a/tests/test_reinforcement_policy_trainer.py
+++ b/tests/test_reinforcement_policy_trainer.py
@@ -1,13 +1,19 @@
 import os
-from AlphaGo.training.reinforcement_policy_trainer import run_training
+from AlphaGo.training.reinforcement_policy_trainer import run_training, _make_training_pair, BatchedReinforcementLearningSGD
 import unittest
+import numpy as np
+import numpy.testing as npt
+import keras.backend as K
+from AlphaGo.models.policy import CNNPolicy
+from AlphaGo.go import GameState
 
 
 class TestReinforcementPolicyTrainer(unittest.TestCase):
+
     def testTrain(self):
-        model = 'tests/test_data/minimodel.json'
-        init_weights = 'tests/test_data/hdf5/random_minimodel_weights.hdf5'
-        output = 'tests/test_data/.tmp.rl.training/'
+        model = os.path.join('tests', 'test_data', 'minimodel.json')
+        init_weights = os.path.join('tests', 'test_data', 'hdf5', 'random_minimodel_weights.hdf5')
+        output = os.path.join('tests', 'test_data', '.tmp.rl.training/')
 
         args = [model, init_weights, output, '--game-batch', '1', '--iterations', '1']
         run_training(args)
@@ -16,5 +22,112 @@ def testTrain(self):
         os.remove(os.path.join(output, 'weights.00001.hdf5'))
         os.rmdir(output)
 
+
+class TestOptimizer(unittest.TestCase):
+
+    def testApplyAndResetOnGamesFinished(self):
+        policy = CNNPolicy.load_model(os.path.join('tests', 'test_data', 'minimodel.json'))
+        state = GameState(size=19)
+        optimizer = BatchedReinforcementLearningSGD(lr=0.01, ng=2)
+        policy.model.compile(loss='categorical_crossentropy', optimizer=optimizer)
+
+        # Helper to check initial conditions of the optimizer.
+        def assertOptimizerInitialConditions():
+            for v in optimizer.gradient_sign:
+                self.assertEqual(K.eval(v), 0)
+            self.assertEqual(K.eval(optimizer.running_games), 2)
+
+        initial_parameters = policy.model.get_weights()
+
+        def assertModelEffect(changed):
+            any_change = False
+            for cur, init in zip(policy.model.get_weights(), initial_parameters):
+                if not np.allclose(init, cur):
+                    any_change = True
+                    break
+            self.assertEqual(any_change, changed)
+
+        assertOptimizerInitialConditions()
+
+        # Make moves on the state and get trainable (state, action) pairs from them.
+        state_tensors = []
+        action_tensors = []
+        moves = [(2, 2), (16, 16), (3, 17), (16, 2), (4, 10), (10, 3)]
+        for m in moves:
+            (st_tensor, mv_tensor) = _make_training_pair(state, m, policy.preprocessor)
+            state_tensors.append(st_tensor)
+            action_tensors.append(mv_tensor)
+            state.do_move(m)
+
+        for i, (s, a) in enumerate(zip(state_tensors, action_tensors)):
+            # Even moves in game 0, odd moves in game 1
+            game_idx = i % 2
+            optimizer.set_current_game(game_idx)
+            is_last_move = i + 2 >= len(moves)
+            if is_last_move:
+                # Mark game 0 as a win and game 1 as a loss.
+                optimizer.set_result(game_idx, game_idx == 0)
+            else:
+                # Games not finished yet; assert no change to optimizer state.
+                assertOptimizerInitialConditions()
+            # train_on_batch accumulates gradients, and should only cause a change to parameters
+            # on the first call after the final set_result() call
+            policy.model.train_on_batch(s, a)
+            if i + 1 < len(moves):
+                assertModelEffect(changed=False)
+            else:
+                assertModelEffect(changed=True)
+        # Once both games finished, the last call to train_on_batch() should have triggered a reset
+        # to the optimizer parameters back to initial conditions.
+        assertOptimizerInitialConditions()
+
+    def testGradientDirectionChangesWithGameResult(self):
+
+        def run_and_get_new_weights(init_weights, win0, win1):
+            state = GameState(size=19)
+            policy = CNNPolicy.load_model(os.path.join('tests', 'test_data', 'minimodel.json'))
+            policy.model.set_weights(init_weights)
+            optimizer = BatchedReinforcementLearningSGD(lr=0.01, ng=2)
+            policy.model.compile(loss='categorical_crossentropy', optimizer=optimizer)
+
+            # Make moves on the state and get trainable (state, action) pairs from them.
+            moves = [(2, 2), (16, 16), (3, 17), (16, 2), (4, 10), (10, 3)]
+            state_tensors = []
+            action_tensors = []
+            for m in moves:
+                (st_tensor, mv_tensor) = _make_training_pair(state, m, policy.preprocessor)
+                state_tensors.append(st_tensor)
+                action_tensors.append(mv_tensor)
+                state.do_move(m)
+
+            for i, (s, a) in enumerate(zip(state_tensors, action_tensors)):
+                # Put even state/action pairs in game 0, odd ones in game 1.
+                game_idx = i % 2
+                optimizer.set_current_game(game_idx)
+                is_last_move = i + 2 >= len(moves)
+                if is_last_move:
+                    if game_idx == 0:
+                        optimizer.set_result(game_idx, win0)
+                    else:
+                        optimizer.set_result(game_idx, win1)
+                # train_on_batch accumulates gradients, and should only cause a change to parameters
+                # on the first call after the final set_result() call
+                policy.model.train_on_batch(s, a)
+            return policy.model.get_weights()
+
+        policy = CNNPolicy.load_model(os.path.join('tests', 'test_data', 'minimodel.json'))
+        initial_parameters = policy.model.get_weights()
+        # Cases 1 and 2 have identical starting models and identical (state, action) pairs,
+        # but they differ in who won the games.
+        parameters1 = run_and_get_new_weights(initial_parameters, True, False)
+        parameters2 = run_and_get_new_weights(initial_parameters, False, True)
+
+        # Changes in case 1 should be equal and opposite to changes in case 2. Allowing 0.1%
+        # difference in precision.
+        for (i, p1, p2) in zip(initial_parameters, parameters1, parameters2):
+            diff1 = p1 - i
+            diff2 = p2 - i
+            npt.assert_allclose(diff1, -diff2, rtol=1e-3)
+
 if __name__ == '__main__':
     unittest.main()
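
Reviewer note (illustration only, not part of the patch): the sketch below walks through the set_current_game()/set_result()/train_on_batch() contract that BatchedReinforcementLearningSGD expects from its caller, the same sequence run_n_games() follows. The two-layer model, the fabricated (state, action) samples, and the hard-coded win/loss outcomes are made up for the example and stand in for CNNPolicy and real self-play games; it assumes the Keras 1.x API the trainer itself targets.

    # Illustration only -- a minimal driver for the optimizer's accumulate-then-apply protocol.
    import numpy as np
    from keras.models import Sequential
    from keras.layers import Dense
    from AlphaGo.training.reinforcement_policy_trainer import BatchedReinforcementLearningSGD

    num_games = 2
    optimizer = BatchedReinforcementLearningSGD(lr=0.01, ng=num_games)

    # Toy stand-in for the policy network.
    model = Sequential()
    model.add(Dense(16, input_dim=8, activation='relu'))
    model.add(Dense(4, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer=optimizer)

    # Pretend each game produced three (state, one-hot action) training pairs.
    fake_games = [[(np.random.rand(1, 8), np.eye(4)[[t % 4]]) for t in range(3)]
                  for _ in range(num_games)]
    fake_results = [True, False]  # game 0 won, game 1 lost

    for game_idx in range(num_games):
        samples = fake_games[game_idx]
        for turn, (X, y) in enumerate(samples):
            # Tell the optimizer which per-game gradient buffer to accumulate into.
            optimizer.set_current_game(game_idx)
            if turn == len(samples) - 1:
                # Report the outcome before the last train_on_batch of this game.
                optimizer.set_result(game_idx, fake_results[game_idx])
            # No weights change until every game has reported a result; the final
            # train_on_batch then applies all accumulated gradients at once.
            model.train_on_batch(X, y)

With the per-game sign z_g in {+1, -1} and categorical cross-entropy on the one-hot played move, the step applied once the last game reports amounts to the REINFORCE update theta <- theta + lr * sum_g z_g * sum_t grad_theta log pi(a_t | s_t), accumulated per game and applied in a single batch.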
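
A second illustrative snippet (also not part of the patch) for the new move_limit argument on the players. The model path is the minimodel test fixture used in the tests above, and 500 mirrors the new --move-limit default; the comment about game termination assumes GameState's usual two-consecutive-passes end condition.

    # Illustration only -- effect of the new move_limit argument.
    import AlphaGo.go as go
    from AlphaGo.ai import ProbabilisticPolicyPlayer
    from AlphaGo.go import GameState
    from AlphaGo.models.policy import CNNPolicy

    policy = CNNPolicy.load_model('tests/test_data/minimodel.json')
    player = ProbabilisticPolicyPlayer(policy, temperature=0.67, move_limit=500)

    state = GameState(size=19)
    print player.get_move(state)  # a sampled move while len(state.history) <= 500
    # Once a game's history exceeds move_limit, get_move()/get_moves() return
    # go.PASS_MOVE; when both players are over the limit they both pass, which
    # ends the game, so batched self-play can no longer run unboundedly long.

Inside run_training() the same limit is passed to both player and opponent, so every game started by run_n_games() terminates even when neither network produces a natural end.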