Speed changes
fshcat committed Nov 18, 2022
1 parent 2ad4657 commit 57088d2
Showing 2 changed files with 30 additions and 29 deletions.
18 changes: 11 additions & 7 deletions replay_buffer.py
@@ -1,3 +1,4 @@
+import numpy as np
 import random
 import math

@@ -73,20 +74,23 @@ def store(self, experience, priority=1e6):
         self.buffer.add(experience, priority)
 
     def sample_batch(self):
-        priorities = [random.uniform(0, self.buffer.get_total()) for _ in range(self.batch_size)]
-        experiences = [None for _ in range(self.batch_size)]
-        indices = [-1 for _ in range(self.batch_size)]
-        imp_sampling = [-1 for _ in range(self.batch_size)]
+        p_total = self.buffer.get_total()
+        segment = p_total / self.batch_size
+
+        experiences = []
+        indices = np.zeros(self.batch_size, dtype="int32")
+        imp_sampling = np.zeros(self.batch_size)
 
         for i in range(self.batch_size):
-            experiences[i], indices[i], imp_sampling[i] = self.buffer.sample_priority(priorities[i])
-
+            priority = random.uniform(segment * i, segment * (i+1))
+            experience, indices[i], imp_sampling[i] = self.buffer.sample_priority(priority)
+            experiences.append(experience)
 
         self.last_batch = indices
         return experiences, imp_sampling
 
     def update_batch(self, priorities):
-        assert self.last_batch != None, "No batches have been sampled from this buffer."
+        assert self.last_batch is not None, "No batches have been sampled from this buffer."
 
         for ind, priority in zip(self.last_batch, priorities):
            self.buffer.update(ind, priority)
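The rewritten sample_batch switches from batch_size independent uniform draws over the whole priority mass to stratified draws, one per equal segment of the total. This is the standard low-variance sampling scheme for prioritized experience replay: expected sampling probabilities stay proportional to priority, but every region of the priority mass is represented once per batch. A minimal sketch of the segment idea, assuming only an object exposing the same get_total() as the buffer here:

    import random

    def stratified_priorities(tree, batch_size):
        # One uniform draw per equal slice of the total priority mass,
        # instead of batch_size draws over the whole range.
        segment = tree.get_total() / batch_size
        return [random.uniform(segment * i, segment * (i + 1))
                for i in range(batch_size)]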
41 changes: 19 additions & 22 deletions train.py
@@ -39,11 +39,11 @@ def get_corrected_action_values(model, lagging_model, states, actions, td_errors
     prev_outputs = model.action_values(states)
 
     # Illegal actions will be ignored. This could be changed to assign -1 to illegal actions
-    # but will likely hinder training.
+    # but would likely hinder training.
     target_outputs = np.copy(prev_outputs)
 
     for i in range(target_outputs.shape[0]):
-        target_outputs[i][actions[i][0] * n + actions[i][1]] += weights[i] * td_errors[i]
+        target_outputs[i][actions[i]] += weights[i] * td_errors[i]
 
     return target_outputs
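With this commit, actions reach the buffer already flattened to a single network-output index, so the target update indexes the output directly instead of recomputing row * n + col here. The mapping, with hypothetical numbers:

    n = 3                              # board width
    action = (2, 1)                    # (row, col) move
    flat = action[0] * n + action[1]   # row-major index: 7
    assert divmod(flat, n) == action   # and the mapping is invertible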

@@ -56,21 +56,19 @@ def train_on_replays(model, lagging_model, replay_buffer, alpha, beta, min_prior
         batch: Batch of (state, action, next_state) tuples being trained on
     """
 
+    m, n, k = model.mnk
     batch_size = replay_buffer.batch_size
     batch, importance_sampling = replay_buffer.sample_batch()
 
-    states_arr = [None for _ in range(batch_size)]
-    next_states_arr = [None for _ in range(batch_size)]
-    actions = [None for _ in range(batch_size)]
+    states = np.zeros(shape=(batch_size, m, n, 2))
+    next_states = np.zeros(shape=(batch_size, m, n, 2))
+    actions = np.zeros(batch_size, dtype="int32")
     rewards = np.zeros(batch_size, dtype="float32")
-    terminal = [None for _ in range(batch_size)]
+    terminal = np.zeros(batch_size, dtype="?")
 
     # Experiences are tuples (state, action, state')
    for i, experience in enumerate(batch):
-        states_arr[i], actions[i], next_states_arr[i], rewards[i], terminal[i] = experience
-
-    states = tf.stack(states_arr)
-    next_states = tf.stack(next_states_arr)
+        states[i], actions[i], next_states[i], rewards[i], terminal[i] = experience
 
     bootstrap_vals = np.zeros(batch_size, dtype="float32")
     state_vals, _ = model.state_value(states)
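This hunk carries the commit's main speedup: the list-then-tf.stack pattern is replaced by preallocated NumPy arrays filled in place, avoiding batch_size intermediate Python objects plus an extra stacking copy per batch. A sketch of the pattern, with illustrative shapes:

    import numpy as np

    batch_size, m, n = 4, 3, 3
    states = np.zeros((batch_size, m, n, 2))
    for i in range(batch_size):
        # Stand-in for get_input_rep(...): each sample is written in place.
        states[i] = np.random.rand(m, n, 2)
    # The result is one contiguous array; no per-batch tf.stack is needed.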
@@ -81,18 +79,16 @@ def train_on_replays(model, lagging_model, replay_buffer, alpha, beta, min_prior
         bootstrap_vals[i] = 0 if argmax_inds[i] == -1 else next_state_action_vals[i][argmax_inds[i]]
 
     td_errors = bootstrap_vals + rewards - state_vals
-    weights = tf.math.pow(tf.convert_to_tensor(importance_sampling), beta)
+    weights = tf.math.pow(importance_sampling, beta)
     weights /= tf.math.reduce_max(weights)
 
     priorities = tf.math.abs(td_errors) + tf.constant(min_priority, dtype=tf.float64, shape=(batch_size))
     priorities = tf.math.pow(priorities, alpha)
 
     replay_buffer.update_batch(priorities)
 
     target_outputs = get_corrected_action_values(model, lagging_model, states, actions, td_errors, weights)
 
-    # Theres a parameter for sample weights. Use if we do importance sampling
     lr_scheduler = tf.keras.callbacks.LearningRateScheduler(scheduler)
 
     model.model.fit(states, target_outputs, epochs=1, batch_size=len(states), steps_per_epoch=1, callbacks=[lr_scheduler], verbose=False)
 
 def run_training_game(transitions, agent_train, agent_versing, lagging_model, replay_buffer, alpha, beta, min_priority, n_steps=1, model_update_freq=4, lagging_freq=100, start_at=5000, epsilon=0, policy_beta=1, mnk=(3, 3, 3), verbose=False):
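Dropping tf.convert_to_tensor works because sample_batch now returns imp_sampling as a NumPy array, which tf.math.pow converts automatically. The correction itself is the usual prioritized-replay scheme: raise the importance ratios to beta, then normalize by the maximum so the weights only ever scale updates down. A sketch, assuming imp_sampling is a NumPy array of the per-sample ratios:

    import numpy as np

    def is_weights(imp_sampling, beta):
        w = np.power(imp_sampling, beta)  # beta anneals the bias correction
        return w / w.max()                # max-normalize: weights in (0, 1]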
@@ -106,8 +102,8 @@ def run_training_game(transitions, agent_train, agent_versing, lagging_model, re
         mnk: Board parameters
         verbose: Whether to print the final board
     """
-
-    board = Board(*mnk, hist_length=-1)
+    m, n, k = mnk
+    board = Board(m, n, k, hist_length=-1)
     game = []
 
     # State queue used for multi-step targets
@@ -124,7 +120,7 @@ def run_training_game(transitions, agent_train, agent_versing, lagging_model, re
         if len(state_queue) >= n_steps:
             # Adds last action to replay buffer
             state, action = state_queue[0]
-            replay_buffer.store((get_input_rep(state)[0], action, get_input_rep(board.get_board())[0], 0, False))
+            replay_buffer.store((get_input_rep(state)[0], action[0] * n + action[1], get_input_rep(board.get_board())[0], 0, False))
 
         if transitions % model_update_freq == 0 and transitions > start_at:
             # Trains on a replay batch
@@ -149,7 +145,7 @@ def run_training_game(transitions, agent_train, agent_versing, lagging_model, re
     while len(state_queue) > 0:
         reward = agent_train.player * winner
         state, action = state_queue.pop(0)
-        replay_buffer.store((get_input_rep(state)[0], action, get_input_rep(board.get_board())[0], reward, True))
+        replay_buffer.store((get_input_rep(state)[0], action[0] * n + action[1], get_input_rep(board.get_board())[0], reward, True))
 
     return winner, game, transitions
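Both store calls now write the flattened action index, and both feed the state queue used for multi-step targets: a transition enters the buffer n_steps moves after it is made (reward 0, non-terminal), and whatever remains queued at game end is flushed with the final reward and terminal=True. A toy illustration of that queue discipline, with hypothetical names rather than the repo's API:

    from collections import deque

    def play(moves, n_steps, store):
        queue = deque()
        for move in moves:
            if len(queue) >= n_steps:
                # Old enough: store with no reward observed yet.
                store(queue.popleft(), reward=0, terminal=False)
            queue.append(move)
        while queue:
            # Game over: remaining moves get the final outcome (+1 here).
            store(queue.popleft(), reward=1, terminal=True)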

@@ -293,18 +289,19 @@ def main():
     batch_size = 32  # Batch size for training
     lr = 0.001  # Learning rate for SGD
 
-    update_freq = 8  # How often to train the model on a replay batch (in moves)
+    update_freq = 4  # How often to train the model on a replay batch (in moves)
     buffer_size = 50000  # Num of moves to store in replay buffer
-    alpha = 0.7
+    alpha = 0.5
     beta = 0.5
     min_priority = 0.01
 
     n_steps = 1  # Num of steps used for temporal difference training targets
     lagging_freq = 500  # How often to update the lagging model (in moves)
-    start_transition = 10000
+    start_transition = 50
 
-    epsilon = 0.1  # Chance of picking a random move
-    policy_beta = 1.5  # The lower this is, the more likely a "worse" move is chosen (don't set < 0)
+    epsilon = 0.2  # Chance of picking a random move
+    policy_beta = 1.0  # The lower this is, the more likely a "worse" move is chosen (don't set < 0)
 
     hof_folder = "menagerie"  # Folder to store the hall-of-fame models
     hof = HOF(mnk, folder=hof_folder)
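One common reading of policy_beta, consistent with the comment above but an assumption rather than something shown in this diff, is a Boltzmann (softmax) move policy in which larger beta concentrates probability on higher-valued moves and beta near 0 approaches uniform random play:

    import numpy as np

    def boltzmann_probs(action_values, policy_beta):
        logits = policy_beta * np.asarray(action_values, dtype="float64")
        logits -= logits.max()   # stabilize exp against overflow
        p = np.exp(logits)
        return p / p.sum()       # higher beta -> greedier move selection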