Skip to content

Commit

Permalink
Win matrix fix, TD error calculation correction
Browse files (browse the repository at this point in the history)
  • Loading branch information
fshcat committed Nov 22, 2022
1 parent a98ea87 commit c841825
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
7 changes: 5 additions & 2 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,11 @@ def state_value(self, states, terminal=None):
action_vals = np.where(illegal_actions, np.full(shape=(k, m * n), fill_value=np.NINF, dtype="float32"), action_vals)
max_vals = tf.math.reduce_max(action_vals, axis=1)

# If state is terminal, return an index of -1 for that state
max_inds = np.where(terminal, np.full(shape=k, fill_value=-1, dtype="int32"), np.argmax(action_vals, axis=1))
if terminal is not None:
# If state is terminal, return an index of -1 for that state
max_inds = np.where(terminal, np.full(shape=k, fill_value=-1, dtype="int32"), np.argmax(action_vals, axis=1))
else:
max_inds = np.argmax(action_vals, axis=1)

return max_vals, max_inds

Expand Down
10 changes: 5 additions & 5 deletions plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,13 @@ def sample_histogram(sample_history, bins=100):


# 1v1 winrate matrix for historical models: ideally, newer versions beat earlier ones
def winrate_matrix(mnk, num_games, step):
def winrate_matrix(mnk, hof_dir, num_games, step):
print("Calculating winrate matrix... (may take a while)")
matrix = np.zeros((num_games // step, num_games // step))
for i in range(0, num_games, step):
for j in range(0, num_games, step):
model_i = Model(mnk, "menagerie/{}".format(i))
model_j = Model(mnk, "menagerie/{}".format(j))
model_i = Model(mnk, location="{}/{}".format(hof_dir, i))
model_j = Model(mnk, location="{}/{}".format(hof_dir, j))

side_i = 1
side_j = side_i * -1
Expand All @@ -128,7 +128,7 @@ def get_moving_avg(data, run_length=50):
return arr


def save_plots(mnk, hof, plots_dir, model_name, diagnostics):
def save_plots(mnk, hof, plots_dir, hof_dir, model_name, diagnostics):

# Create model's plots folder
if not os.path.isdir(plots_dir):
Expand Down Expand Up @@ -189,7 +189,7 @@ def save_plots(mnk, hof, plots_dir, model_name, diagnostics):
plt.clf()

step = max(1, hof.pop_size // 40)
matrix = winrate_matrix(mnk, hof.pop_size, step)
matrix = winrate_matrix(mnk, hof_dir, hof.pop_size, step)
plt.imshow(matrix, cmap="bwr")
plt.imsave("plots/{}/Matrix.png".format(model_name), matrix, cmap="bwr")
plt.clf()
Expand Down
10 changes: 6 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
verbose, mcts, model_name = arg_parser(sys.argv)
mnk = (3, 3, 3)
plot_folder = "plots/{}".format(model_name)
hof_folder = "menagerie/{}".format(model_name) # Folder to store the hall-of-fame models

class ResetType(Enum):
NONE = 0 # Reset nothing
Expand Down Expand Up @@ -108,14 +109,15 @@ def train_on_replays(model, lagging_model, replay_buffer, params):
states[i], actions[i], next_states[i], rewards[i], terminal[i] = experience

bootstrap_vals = np.zeros(batch_size, dtype="float32")
state_vals, _ = model.state_value(states)

state_action_vals = np.array(model.action_values(states))[np.arange(batch_size), actions]
next_state_action_vals = lagging_model.action_values(next_states)
_, argmax_inds = model.state_value(next_states, terminal)

for i in range(batch_size):
bootstrap_vals[i] = 0 if argmax_inds[i] == -1 else next_state_action_vals[i][argmax_inds[i]]

td_errors = bootstrap_vals + rewards - state_vals
td_errors = bootstrap_vals + rewards - state_action_vals
weights = tf.math.pow(importance_sampling, params.buffer_beta)

priorities = tf.math.abs(td_errors) + tf.constant(params.min_priority, dtype=tf.float32, shape=(batch_size))
Expand Down Expand Up @@ -294,7 +296,7 @@ def train(hof, params, model):

if game % params.plotting_rate == 0:
save_model(model, model_name)
save_plots(mnk, hof, plot_folder, model_name, diagnostics)
save_plots(mnk, hof, plot_folder, hof_folder, model_name, diagnostics)


except KeyboardInterrupt:
Expand Down Expand Up @@ -381,7 +383,7 @@ def main():
model, diagnostics, games = train(hof, params, Model(mnk, lr=params.lr))

save_model(model, model_name)
save_plots(mnk, hof, plot_folder, model_name, diagnostics)
save_plots(mnk, hof, plot_folder, hof_folder, model_name, diagnostics)

# Can be used after looking at plot to analyze important milestones
# TODO: Put into a function
Expand Down

0 comments on commit c841825

Please sign in to comment.