Merge pull request #159 from wrongu/rl-cleanup
Refactor of RL using custom keras optimizer
wrongu authored Sep 19, 2016
2 parents 9e44e26 + c3644fb commit d4f03c5
Showing 3 changed files with 317 additions and 147 deletions.
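The diff below shows only the player-side changes in AlphaGo/ai.py; the custom Keras optimizer named in the commit message presumably lives in the training code, which is not shown here. For orientation, it implements a REINFORCE-style policy-gradient step, in which the gradient of log pi(move | state) is scaled by the game outcome. A minimal, self-contained sketch of that idea (illustrative names only, not the repo's API):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def reinforce_update(theta, features, chosen, outcome, lr=0.01):
    """One REINFORCE step for a toy linear softmax policy.

    theta:    (n_moves, n_features) policy weights
    features: (n_features,) state features
    chosen:   index of the move actually played
    outcome:  +1 for a win, -1 for a loss
    """
    probs = softmax(theta.dot(features))
    # gradient of log probs[chosen] w.r.t. theta for a linear softmax
    grad = -np.outer(probs, features)
    grad[chosen] += features
    # winning games reinforce their moves; losing games suppress them
    return theta + lr * outcome * grad

theta = np.zeros((3, 4))
theta = reinforce_update(theta, np.ones(4), chosen=1, outcome=+1)
```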
21 changes: 14 additions & 7 deletions AlphaGo/ai.py
```diff
@@ -9,15 +9,18 @@ class GreedyPolicyPlayer(object):
     move each turn)
     """
 
-    def __init__(self, policy_function, pass_when_offered=False):
+    def __init__(self, policy_function, pass_when_offered=False, move_limit=None):
         self.policy = policy_function
         self.pass_when_offered = pass_when_offered
+        self.move_limit = move_limit
 
     def get_move(self, state):
+        if self.move_limit is not None and len(state.history) > self.move_limit:
+            return go.PASS_MOVE
         if self.pass_when_offered:
             if len(state.history) > 100 and state.history[-1] == go.PASS_MOVE:
                 return go.PASS_MOVE
-        sensible_moves = [move for move in state.get_legal_moves() if not state.is_eye(move, state.current_player)]
+        sensible_moves = [move for move in state.get_legal_moves(include_eyes=False)]
         if len(sensible_moves) > 0:
             move_probs = self.policy.eval_state(state, sensible_moves)
             max_prob = max(move_probs, key=lambda (a, p): p)
```
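The new move_limit parameter caps game length: once the move history grows past the limit, the player passes unconditionally, which guarantees self-play games terminate. A hypothetical usage sketch, assuming the repo's CNNPolicy and GameState classes and placeholder model paths:

```python
from AlphaGo.ai import GreedyPolicyPlayer
from AlphaGo.models.policy import CNNPolicy
import AlphaGo.go as go

# 'model.json' and 'weights.hdf5' are placeholder paths, not from the commit
policy = CNNPolicy.load_model('model.json')
policy.model.load_weights('weights.hdf5')

player = GreedyPolicyPlayer(policy, move_limit=500)
state = go.GameState()
# once len(state.history) exceeds 500, get_move returns go.PASS_MOVE
move = player.get_move(state)
```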
```diff
@@ -34,17 +37,21 @@ class ProbabilisticPolicyPlayer(object):
     (high temperature) or towards greedy play (low temperature)
     """
 
-    def __init__(self, policy_function, temperature=1.0, pass_when_offered=False):
+    def __init__(self, policy_function, temperature=1.0, pass_when_offered=False, move_limit=None):
         assert(temperature > 0.0)
         self.policy = policy_function
+        self.move_limit = move_limit
         self.beta = 1.0 / temperature
         self.pass_when_offered = pass_when_offered
+        self.move_limit = move_limit
 
     def get_move(self, state):
+        if self.move_limit is not None and len(state.history) > self.move_limit:
+            return go.PASS_MOVE
         if self.pass_when_offered:
             if len(state.history) > 100 and state.history[-1] == go.PASS_MOVE:
                 return go.PASS_MOVE
-        sensible_moves = [move for move in state.get_legal_moves() if not state.is_eye(move, state.current_player)]
+        sensible_moves = [move for move in state.get_legal_moves(include_eyes=False)]
         if len(sensible_moves) > 0:
             move_probs = self.policy.eval_state(state, sensible_moves)
             # zip(*list) is like the 'transpose' of zip; zip(*zip([1,2,3], [4,5,6])) is [(1,2,3), (4,5,6)]
```
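The constructor stores beta = 1/temperature, and the docstring above describes its effect; the sampling code itself sits outside this hunk. The sketch below shows the usual scheme of raising the policy's distribution to the power beta and renormalizing, which is an assumption about the elided code rather than a quote of it:

```python
import numpy as np

def sample_with_temperature(moves, probs, temperature=1.0):
    beta = 1.0 / temperature        # mirrors self.beta = 1.0 / temperature
    p = np.asarray(probs) ** beta   # beta > 1 sharpens, beta < 1 flattens
    p = p / p.sum()                 # renormalize into a distribution
    return moves[np.random.choice(len(moves), p=p)]

# low temperature approaches greedy play; high temperature approaches uniform
print(sample_with_temperature(['A1', 'B2', 'C3'], [0.7, 0.2, 0.1], temperature=0.1))
```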
```diff
@@ -60,11 +67,11 @@ def get_move(self, state):
     def get_moves(self, states):
         """Batch version of get_move. A list of moves is returned (one per state)
         """
-        sensible_move_lists = [[move for move in st.get_legal_moves() if not st.is_eye(move, st.current_player)] for st in states]
+        sensible_move_lists = [[move for move in st.get_legal_moves(include_eyes=False)] for st in states]
         all_moves_distributions = self.policy.batch_eval_state(states, sensible_move_lists)
         move_list = [None] * len(states)
         for i, move_probs in enumerate(all_moves_distributions):
-            if len(move_probs) == 0:
+            if len(move_probs) == 0 or (self.move_limit is not None and len(states[i].history) > self.move_limit):
                 move_list[i] = go.PASS_MOVE
             else:
                 # this 'else' clause is identical to ProbabilisticPolicyPlayer.get_move
```
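get_moves mirrors get_move but evaluates a whole batch of positions in one network call, which is what makes parallel self-play cheap. A hypothetical driver loop, reusing the assumed names from the sketches above (GameState's do_move and is_end_of_game are taken from the surrounding codebase):

```python
from AlphaGo.ai import ProbabilisticPolicyPlayer

# run 16 self-play games in lockstep, one batched network call per turn
player = ProbabilisticPolicyPlayer(policy, temperature=0.67, move_limit=500)
states = [go.GameState() for _ in range(16)]
while any(not st.is_end_of_game for st in states):
    moves = player.get_moves(states)  # one move per state
    for st, mv in zip(states, moves):
        if not st.is_end_of_game:
            st.do_move(mv)
```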
```diff
@@ -83,7 +90,7 @@ def __init__(self, value_function, policy_function, rollout_function, lmbda=.5,
                               rollout_limit, playout_depth, n_playout)
 
     def get_move(self, state):
-        sensible_moves = [move for move in state.get_legal_moves() if not state.is_eye(move, state.current_player)]
+        sensible_moves = [move for move in state.get_legal_moves(include_eyes=False)]
         if len(sensible_moves) > 0:
             move = self.mcts.get_move(state)
             self.mcts.update_with_move(move)
```