Skip to content

Commit

Permalink
Added command-line arguments parser (-v = verbose)
Browse files Browse the repository at this point in the history
  • Loading branch information
PedroContipelli committed Feb 20, 2022
1 parent 0a3fdc2 commit dbb2025
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 16 deletions.
10 changes: 2 additions & 8 deletions save_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
from model import Model
import datetime

def save_model(model, argv):
model_name = name(argv)

def save_model(model, model_name):
print("Saving trained model to models/{}".format(model_name))
model.save_to('models/{}'.format(model_name))

def name(argv):
return argv[1] if len(argv) == 2 else "Model__" + str(datetime.datetime.now())[:-7].replace(" ", "__")
model.save_to('models/{}'.format(model_name))
16 changes: 10 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from model import Model
from plot import plot_wins
from hof import HOF
from utils import run_game
from save_model import save_model, name
from utils import run_game, arg_parser
from save_model import save_model
import sys

# Set cmd-line training arguments
verbose, mcts, model_name = arg_parser(sys.argv)
mnk = (3, 3, 3)

def main():
Expand All @@ -18,15 +20,14 @@ def main():
games_per_batch = 5
epsilon = 0.2 # Epsilon is the exploration factor: probability with which a random move is chosen to play

# Declare hall of fame
hof = HOF(mnk, folder="menagerie")

print("Training model: {}".format(name(sys.argv)))
print("\nTraining model: {}\n".format(model_name))

# Run training and store final model
model, end_states, victories, games = train(hof, num_batches, games_per_batch, epsilon, Model())

save_model(model, sys.argv)
save_model(model, model_name)

# Create data plots # All this should be in plot.py preferably
plt.figure()
Expand Down Expand Up @@ -92,14 +93,17 @@ def train(hof, num_batches, games_per_batch, epsilon, model):
agent_hof = Agent(model_hof, side_hof)

# Run a diagnostic (non-training, no exploration) game to collect data
diagnostic_winner, game_data = run_game(agent_best, agent_hof, 0, training=False, mnk=mnk)
diagnostic_winner, game_data = run_game(agent_best, agent_hof, 0, training=False, mnk=mnk, verbose=verbose)

# Store data from diagnostic game for this batch
games.append(game_data)
end_states.append(diagnostic_winner)
victories.append(diagnostic_winner*side_best)

except KeyboardInterrupt:
print("\n=======================")
print("Training interrupted.")
print("=======================")

print("Training completed.")
return model, end_states, victories, games
Expand Down
21 changes: 19 additions & 2 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from mnk import Board
import datetime


def run_game(agent_train, agent_versing, epsilon=0, training=False, mnk=(3, 3, 3)):
def run_game(agent_train, agent_versing, epsilon=0, training=False, mnk=(3, 3, 3), verbose=False):
board = Board(*mnk, form="multiplanar-turnflipped", hist_length=-1)
game = []

Expand All @@ -21,4 +21,21 @@ def run_game(agent_train, agent_versing, epsilon=0, training=False, mnk=(3, 3, 3
if winner != agent_train.player and training:
agent_train.model.td_update(board, terminal=True)

if verbose:
print(board)

return winner, game

def arg_parser(argv):
possible_arguments = ["-v", "-mcts"]

# List of booleans representing if each argument is present (in order above)
present = [1 if arg in argv else 0 for arg in possible_arguments]

# Last value will be model name
if len(argv) > 1 and not argv[1].startswith("-"):
present.append(argv[1])
else:
present.append("Model__" + str(datetime.datetime.now())[:-7].replace(" ", "__"))

return tuple(present)

0 comments on commit dbb2025

Please sign in to comment.