Skip to content

Commit

Permalink
Fixed play.py and added command-line argument naming of saved models
Browse files Browse the repository at this point in the history
  • Loading branch information
PedroContipelli committed Nov 18, 2021
1 parent 4ebe47f commit 88b8326
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 9 deletions.
1 change: 1 addition & 0 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(self, location=False):
self.model = Sequential()
self.model.add(Dense(27, input_shape=(1, 9), kernel_initializer='normal', activation='tanh'))
self.model.add(Dense(27, kernel_initializer='normal', activation='tanh'))
self.model.add(Dense(9, kernel_initializer='normal', activation='relu'))
self.model.add(Dense(1, kernel_initializer='normal', activation='tanh'))

self.model.compile(loss='mean_squared_error', optimizer=opt)
Expand Down
11 changes: 6 additions & 5 deletions models/modelXO/keras_metadata.pb

Large diffs are not rendered by default.

Binary file modified models/modelXO/saved_model.pb
Binary file not shown.
Binary file modified models/modelXO/variables/variables.data-00000-of-00001
Binary file not shown.
Binary file modified models/modelXO/variables/variables.index
Binary file not shown.
6 changes: 5 additions & 1 deletion play.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from agent import Agent
import mnk
import tensorflow as tf
import model
import sys

board = mnk.Board(3, 3, 3, form="flatten")
model = tf.keras.models.load_model('models/modelXO')

assert len(sys.argv) == 2, "Please specify which model you would like to play against (ex: python3 play.py 3LayersModel)"
model = model.Model('models/' + sys.argv[1])

print("\n\n" + str(board))
current_player = input("\nWho plays first (Me/AI)? ")
Expand Down
Binary file modified plots/plot10000.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 8 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,20 @@
from model import Model
from plot import plot_wins
from hof import HOF
import sys

mnk = (3, 3, 3)

def main():

assert len(sys.argv) == 2, "Please name the model you are training (ex: python3 train.py 3LayersModel)"

# Initialize hall of fame
hof = HOF("menagerie")

# Each loop trains n games and then does one diagnostic game without exploration moves
num_loops = 20
loop_length = 5 # also doubles as print_frequency
num_loops = 200
loop_length = 50 # also doubles as print_frequency
epsilon = 0.2 # exploration constant
decay_freq = 30 # number of games between each epsilon decrement [Currently not being used]
decay_factor = 0.0005 # how much to decrease by [Currently not being used]
Expand All @@ -28,7 +31,7 @@ def main():
print("Training complete.")
print("Saving trained model to models/modelXO and charts to plots folder")

model.save_to('models/modelXO')
model.save_to('models/' + sys.argv[1])

# Create data plots
plt.subplots(3, 3, constrained_layout=True)
Expand All @@ -45,12 +48,14 @@ def main():
plt.show()
plt.savefig("plots/plot{}.png".format(num_loops * loop_length))

'''
ind = 0
while ind != -1:
ind = int(input("Query a game: "))
for move in games[ind]:
print(move)
pass
'''

# Runs a game from start to end
def run_game(agent_train, agent_versing, epsilon, training):
Expand Down

0 comments on commit 88b8326

Please sign in to comment.