-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added an option to separate the board into two planes and a new hall-of-fame
sampling option, fixed bugs with action-taking, and added diagnostic games to the loop
- Loading branch information
Showing
11 changed files
with
168 additions
and
113 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,35 +1,75 @@ | ||
import mnk | ||
import keras.models | ||
import tensorflow as tf | ||
import random | ||
|
||
|
||
class Agent: | ||
|
||
def __init__(self, board, model, player): | ||
def __init__(self, board, model, player, training): | ||
self.board = board | ||
self.model = model | ||
self.player = player | ||
self.training = training | ||
|
||
def action(self, epsilon=0.01): | ||
def greedy_action(self): | ||
legal_moves = self.board.legal_moves() | ||
assert len(legal_moves) > 0, "No legal moves can be played." | ||
|
||
# Exploration | ||
if (random.random() < epsilon): | ||
print("Played epsilon move ({:.5f})".format(epsilon)) | ||
self.board.move(*legal_moves[random.randint(0, len(legal_moves) - 1)]) | ||
return | ||
|
||
best_move = legal_moves[-1] | ||
max_evaluation = -1 | ||
|
||
for move in legal_moves: | ||
self.board.move(*move) | ||
evaluation = self.player * self.model(self.board.get_board()) | ||
if evaluation > max_evaluation: | ||
|
||
val = self.value() | ||
if val > max_evaluation: | ||
best_move = move | ||
max_evaluation = evaluation | ||
max_evaluation = val | ||
|
||
self.board.undo_move(*move) | ||
self.board.move(*best_move) | ||
|
||
return best_move | ||
|
||
def random_action(self): | ||
legal_moves = self.board.legal_moves() | ||
return legal_moves[random.randint(0, len(legal_moves) - 1)] | ||
|
||
def value(self): | ||
if self.board.who_won() == self.player: | ||
return tf.constant(1, dtype="float32", shape=(1, 1)) | ||
elif self.board.who_won() == -1*self.player: | ||
return tf.constant(-1, dtype="float32", shape=(1, 1)) | ||
elif self.board.who_won() == 0: | ||
return tf.constant(0, dtype="float32", shape=(1, 1)) | ||
else: | ||
return self.player*self.model(self.board.get_board()) | ||
|
||
def action(self, epsilon=0): | ||
legal_moves = self.board.legal_moves() | ||
assert len(legal_moves) > 0, "No legal moves can be played." | ||
|
||
greedy = self.greedy_action() | ||
if self.training and len(self.board.history()) >= (2 + (self.player == -1)): | ||
self.update_model(greedy) | ||
|
||
# Exploration | ||
if random.random() < epsilon: | ||
print("Played epsilon move ({:.5f})".format(epsilon)) | ||
move = self.random_action() | ||
else: | ||
move = greedy | ||
|
||
self.board.move(*move) | ||
|
||
def update_model(self, greedy_move=()): | ||
if greedy_move == (): | ||
assert self.board.who_won() != 2 and self.board.who_won() != self.player | ||
self.model.fit(self.board.history()[-2], self.value(), batch_size=1, verbose=0) | ||
else: | ||
self.board.move(*greedy_move) | ||
self.model.fit(self.board.history()[-3], self.value(), batch_size=1, verbose=0) | ||
self.board.undo_move(*greedy_move) | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,47 @@ | ||
import random | ||
import tensorflow as tf | ||
from math import floor | ||
from matplotlib import pyplot | ||
import os | ||
from math import floor | ||
|
||
|
||
class HOF: | ||
def __init__(self, folder): | ||
self.hof = [] | ||
self.folder = folder | ||
self.sample_history = [] | ||
self.pop_size = 0 | ||
self.basel = 0 # used in limit-uniform sampling | ||
if not os.path.isdir(folder): | ||
os.makedirs(folder) | ||
|
||
def store(self, model, name): | ||
model.save("{}/{}".format(self.folder, name)) | ||
self.hof.append(name) | ||
self.pop_size += 1 | ||
self.basel += 1/self.pop_size**2 | ||
|
||
def sample_hof(self, method='uniform'): | ||
if method == 'limit-uniform': | ||
threshold = random.random()*self.basel | ||
|
||
cum_prob = 0 | ||
ind = self.pop_size-1 | ||
for i in range(self.pop_size): | ||
cum_prob += 1/(self.pop_size-i)**2 | ||
if cum_prob > threshold: | ||
ind = i | ||
break | ||
elif method == 'uniform': | ||
ind = floor(random.random()*self.pop_size) | ||
|
||
self.sample_history.append(ind) | ||
|
||
def sample_hof(self): | ||
pop_size = len(self.hof) | ||
ind = floor(pop_size*random.random()) | ||
name = self.hof[ind] | ||
return tf.keras.models.load_model("{}/{}".format(self.folder, name)) | ||
|
||
def sample_hist(self, num=100): | ||
pyplot.hist(self.sample_history, num) | ||
pyplot.title("Sampling of Model Indices from HOF") | ||
pyplot.show() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.