Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/th-truong/tj-chess
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Truong committed Apr 24, 2022
2 parents 1aa48df + c104ac5 commit 9460499
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
24 changes: 13 additions & 11 deletions src/engine/mcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ def __init__(
value: Optional[float],
parent: Optional[Node] = None,
children: Dict[str, Node] = None,
cache: Any = None,
):
self.state = state
self.value = value
self.parent = parent
self.children = children or {}
self.cache = cache


def _select(node: Node) -> Node:
Expand All @@ -38,13 +40,14 @@ def _select(node: Node) -> Node:
return _select(child)


def expand_state_chess(state: chess.Board) -> Dict[str, chess.Board]:
def expand_node_chess(node: Node) -> Dict[str, Node]:
children = {}
for move in state.legal_moves:
for move in node.state.legal_moves:
# TODO: can we push/pop to avoid copies?
board = state.copy()
board = node.state.copy()
board.push(move)
children[move.uci()] = board
child_node = Node(board, None, parent=node, cache={})
children[move.uci()] = child_node
return children


Expand All @@ -63,9 +66,9 @@ def node_to_all_layers(node):
# hist_layers.extend(board_to_layers(None, None))
else:
if i % 2 == 0:
hist_layers.extend(cur.layers)
hist_layers.extend(cur.cache['layers'])
else:
hist_layers.extend(flip_layers(cur.layers))
hist_layers.extend(flip_layers(cur.cache['layers']))
cur = cur.parent
hist_layers = np.array(hist_layers[:112])
return hist_layers
Expand All @@ -76,9 +79,9 @@ def simulate_states_chess(nodes: List[Node]) -> List[float]:
for node in nodes:
# TODO: jamming layers into the node is hacky
if node.parent is None:
node.layers = board_to_all_layers(node.state.copy())
node.cache['layers'] = board_to_all_layers(node.state.copy())
else:
node.layers = board_to_layers(node.state, node.state.turn)
node.cache['layers'] = board_to_layers(node.state, node.state.turn)

all_layers = []
for node in nodes:
Expand Down Expand Up @@ -139,7 +142,7 @@ def _back_propagate(node: Optional[Node]):

def mcts(
root: Node,
expand_state: Callable[[Any], Dict[str, Any]],
expand_node: Callable[[Node], Dict[str, Node]],
simulate_states: Callable[[List[Node]], float],
limit_search: Optional[Callable[[Node], bool]] = None,
batches=10,
Expand All @@ -159,8 +162,7 @@ def mcts(
# TODO: maybe eval blindly instead?
if node is None:
continue
child_states = expand_state(node.state)
node.children = {m: Node(s, None, node) for m, s in child_states.items()}
node.children = expand_node(node)
nodes.append(node)

child_nodes = [c for n in nodes for c in n.children.values()]
Expand Down
6 changes: 3 additions & 3 deletions src/network_utils/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import chess
import torch
import numpy as np
from engine.mcts import mcts, Node, build_chess_state_simulator, expand_state_chess, build_chess_limit_search
from engine.mcts import mcts, Node, build_chess_state_simulator, expand_node_chess, build_chess_limit_search

from network_utils import network_out_interpreter as noi
from network_utils.load_tj_model import load_tj_model
Expand All @@ -22,11 +22,11 @@ def analyse(self, board, limit: chess.engine.Limit, multipv=5):

def play(self, board, limit=None):
if self.root is None or self.root.state != board:
self.root = Node(board, None)
self.root = Node(board, None, cache={})
with torch.no_grad():
simulate_states_chess = build_chess_state_simulator(self.model)
limit_search_chess = build_chess_limit_search(chess.engine.Limit(nodes=20000))
uci, _value = mcts(self.root, expand_state_chess, simulate_states_chess, limit_search=limit_search_chess)
uci, _value = mcts(self.root, expand_node_chess, simulate_states_chess, limit_search=limit_search_chess)
self.root = self.root.children[uci]
self.root.parent = None
move = chess.Move.from_uci(uci)
Expand Down

0 comments on commit 9460499

Please sign in to comment.