Skip to content

Commit

Permalink
Added interrupting training and saving current model/plots
Browse files (browse the repository at this point in the history)
  • Loading branch information
PedroContipelli committed Nov 18, 2021
1 parent 959f741 commit 2a9d294
Showing 1 changed file with 31 additions and 27 deletions.
58 changes: 31 additions & 27 deletions train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -25,12 +25,12 @@ def main():
plot_run_length = num_loops // 10

# Run training and store final model
model, end_states, victories, games = train(hof, num_loops, loop_length, Model(), epsilon, decay_factor)
model, end_states, victories, games, finished = train(hof, num_loops, loop_length, Model(), epsilon, decay_factor)

print("Training complete." if finished else "")
print("Saving trained model and plots")

print("Training complete.")
print("Saving trained model to models/modelXO and charts to plots folder")

model.save_to('models/' + sys.argv[1])
model.save_to('models/' + sys.argv[1] + ("Interrupted" if not finished else ""))

# Create data plots
plt.subplots(3, 3, constrained_layout=True)
Expand All @@ -45,7 +45,7 @@ def main():
hof.sample_histogram(20) # shows how HOF was sampled

plt.show()
plt.savefig("plots/plot{}.png".format(num_loops * loop_length))
plt.savefig("plots/plot{}{}.png".format(num_loops * loop_length, ("Interrupted" if not finished else "")))

'''
ind = 0
Expand Down Expand Up @@ -96,35 +96,39 @@ def train(hof, loops, loop_length, model, epsilon, decay_factor):
side_best = [-1, 1][random.random() > 0.5]
side_hof = side_best * -1

for loop in range(loops):
print("\nLoop: ", loop)
try:
for loop in range(loops):
print("\nLoop: ", loop)

# Initialize the agents
agent_best = Agent(model, side_best)
agent_hof = Agent(model_hof, side_hof)
# Initialize the agents
agent_best = Agent(model, side_best)
agent_hof = Agent(model_hof, side_hof)

for game in range(loop_length):
run_game(agent_best, agent_hof, epsilon, training=True)
for game in range(loop_length):
run_game(agent_best, agent_hof, epsilon, training=True)

epsilon -= decay_factor
epsilon -= decay_factor

# Run a diagnostic (non-training, no exploration) game to collect data
diagnostic_winner, game_data = run_game(agent_best, agent_hof, 0, training=False)
# Run a diagnostic (non-training, no exploration) game to collect data
diagnostic_winner, game_data = run_game(agent_best, agent_hof, 0, training=False)

# Switch sides for the next loop
side_best *= -1
side_hof = side_best * -1
# Switch sides for the next loop
side_best *= -1
side_hof = side_best * -1

# Update hall of fame and sample from it for the next loop
hof.gate(model)
model_hof = hof.sample("uniform")
# Update hall of fame and sample from it for the next loop
hof.gate(model)
model_hof = hof.sample("uniform")

# Store data from loop
games.append(game_data)
end_states.append(diagnostic_winner)
victories.append(diagnostic_winner*side_best)
# Store data from loop
games.append(game_data)
end_states.append(diagnostic_winner)
victories.append(diagnostic_winner*side_best)
except:
print("Training interrupted.")
return model, end_states, victories, games, False

return model, end_states, victories, games
return model, end_states, victories, games, True

if __name__ == "__main__":
main()

0 comments on commit 2a9d294

Please sign in to comment.