Skip to content

Commit

Permalink
Corrected replay weighting, added hyperparam class, added model reset options
Browse files Browse the repository at this point in the history
  • Loading branch information
fshcat committed Nov 21, 2022
1 parent cd00b84 commit 303fca1
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 74 deletions.
28 changes: 20 additions & 8 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,38 @@ def __init__(self, mnk, lr=0.001, location=None, model=None):
is provided a new model is initialized. Defaults to None.
"""
self.mnk = mnk
m, n, k = mnk
self.lr = lr

# If a location is provided, retrieve the model stored at that location
if location is not None:
self.model = self.retrieve(location)
return

if model is not None:
elif model is not None:
self.model = model
return
else:
self.initialize_model()

def reset_optimizer(self):
    """Re-create the Adam optimizer and recompile the network.

    Discards any accumulated optimizer state (moment estimates), which is
    useful when restarting training on an already-initialized model.
    """
    fresh_opt = Adam(learning_rate=self.lr)
    self.opt = fresh_opt
    self.model.compile(loss='mean_squared_error', optimizer=fresh_opt)

self.opt = Adam(learning_rate=lr)
regularization = 0.0001
def initialize_model(self, regularization=0.0001):
    """Build and compile a fresh convolutional value network for the m,n,k board.

    Architecture: three same-padded Conv2D layers (16 -> 32 -> 16 filters)
    followed by a single-filter Conv2D, then a Dense head producing one
    output per board cell (m * n move values).

    Args:
        regularization: L2 penalty applied to every layer's kernel.
    """
    m, n, k = self.mnk

    self.model = Sequential()
    # First layer declares the input shape: board is (m, n, 2) — presumably
    # one channel per player; confirm against the game-state encoder.
    self.model.add(Conv2D(filters=16, kernel_size=3, padding="same", input_shape=(m, n, 2), kernel_regularizer=l2(regularization)))
    self.model.add(Conv2D(filters=32, kernel_size=3, padding="same", kernel_regularizer=l2(regularization)))
    self.model.add(Conv2D(filters=16, kernel_size=3, padding="same", kernel_regularizer=l2(regularization)))
    self.model.add(Conv2D(filters=1, kernel_size=3, padding="same", kernel_regularizer=l2(regularization)))
    self.model.add(Flatten())
    self.model.add(Dense(128, kernel_initializer='normal', activation='relu', kernel_regularizer=l2(regularization)))
    # BUGFIX: original referenced `mnk[0] * mnk[1]`, but `mnk` is not defined
    # in this method's scope (only the unpacked m, n, k are) — NameError.
    self.model.add(Dense(m * n, kernel_initializer='normal', kernel_regularizer=l2(regularization)))

    self.opt = Adam(learning_rate=self.lr)
    self.model.compile(loss='mean_squared_error', optimizer=self.opt)

@staticmethod
Expand Down
5 changes: 2 additions & 3 deletions plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class Diagnostics:
def __init__(self, run_length=50, training_run_length=500):
def __init__(self, run_length=50, training_run_length=200):
self.run_length = run_length
self.training_run_length = training_run_length
self.xo_outcomes = [[], [], []]
Expand Down Expand Up @@ -128,10 +128,9 @@ def get_moving_avg(data, run_length=50):
return arr


def save_plots(mnk, hof, model_name, diagnostics):
def save_plots(mnk, hof, plots_dir, model_name, diagnostics):

# Create model's plots folder
plots_dir = "plots/{}".format(model_name)
if not os.path.isdir(plots_dir):
os.makedirs(plots_dir)

Expand Down
61 changes: 44 additions & 17 deletions replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,45 +4,65 @@

class PrioritySumTree():
def __init__(self, capacity):
    """Initialize an empty priority sum tree holding up to `capacity` items.

    The tree is a complete binary tree stored in flat arrays: internal
    nodes cache the sum (for proportional sampling) and the min (for
    importance-sampling weight normalization) of their subtrees; leaves
    hold per-item priorities.

    Args:
        capacity: maximum number of stored experiences.
    """
    self.size = 0  # number of items currently stored (grows up to capacity)

    # Depth chosen so the leaf level can hold at least `capacity` entries.
    self.depth = math.ceil(math.log2(capacity))
    self.array_len = 2**(self.depth + 1) - 1
    self.leaf_start = 2**self.depth - 1  # flat index of the first leaf

    self.capacity = capacity
    self.priorities = np.zeros(self.array_len, dtype="float32")
    # np.inf (np.Inf was removed in NumPy 2.0); identity element for min.
    self.priorities_min = np.full(shape=self.array_len, fill_value=np.inf, dtype="float32")

    self.datatable = [None for _ in range(self.array_len)]
    self.ind = 0  # next leaf slot to write (ring buffer over the leaves)

def clear(self):
    """Reset the tree to its freshly-constructed empty state.

    Reallocates the sum/min arrays and the data table, and rewinds the
    item count and the ring-buffer write cursor.
    """
    self.priorities = np.zeros(self.array_len, dtype="float32")
    # np.inf (np.Inf was removed in NumPy 2.0); identity element for min.
    self.priorities_min = np.full(shape=self.array_len, fill_value=np.inf, dtype="float32")
    self.datatable = [None for _ in range(self.array_len)]

    self.size = 0
    self.ind = 0

def add(self, data, priority):
    """Insert `data` with the given priority at the next ring-buffer leaf.

    Overwrites the oldest entry once the tree is full, then rebuilds the
    cached sum and min along the leaf-to-root path.

    Args:
        data: the experience payload to store.
        priority: sampling priority (assumed already raised to alpha).
    """
    self.size = min(self.size + 1, self.capacity)

    index = self.ind + self.leaf_start

    # Set the leaf absolutely (not incrementally) so overwrites are exact.
    self.priorities[index] = priority
    self.priorities_min[index] = priority
    self.datatable[index] = data

    # Recompute each ancestor's sum and min from its two children.
    current = (index - 1) // 2
    while current >= 0:
        self.priorities[current] = self.priorities[current * 2 + 1] + self.priorities[current * 2 + 2]
        self.priorities_min[current] = min(self.priorities_min[current * 2 + 1], self.priorities_min[current * 2 + 2])
        current = (current - 1) // 2

    self.ind = (self.ind + 1) % self.capacity

def update_vectorized(self, batch_ancestors, old_priorities, new_priorities):
    """Update the priorities of a sampled batch and repair ancestor caches.

    Args:
        batch_ancestors: int array of shape (batch, path_len); each row is
            the root-to-leaf index path recorded by sample_priority, so the
            last column is the leaf index.
        old_priorities: unused; kept for interface compatibility.
        new_priorities: new leaf priorities for the batch.
    """
    # Write new leaf values into both the sum tree and the min tree.
    self.priorities[batch_ancestors[:, -1]] = self.priorities_min[batch_ancestors[:, -1]] = new_priorities

    for i in range(batch_ancestors.shape[0]):
        # BUGFIX: original slice [:-1:-1] is always empty (start=end of
        # array, stop before index -1), so ancestors were never updated.
        # [-2::-1] walks from the leaf's parent up to the root, which is
        # the required bottom-up order for rebuilding sums/mins.
        for index in batch_ancestors[i][-2::-1]:
            self.priorities[index] = self.priorities[index * 2 + 1] + self.priorities[index * 2 + 2]
            self.priorities_min[index] = min(self.priorities_min[index * 2 + 1], self.priorities_min[index * 2 + 2])

def get_total(self):
    """Return the total priority mass (the root of the sum tree)."""
    return self.priorities[0]

def get_min(self):
    """Return the smallest stored priority (the root of the min tree)."""
    return self.priorities_min[0]

def sample_priority(self, val, record_ancestors=None):
current = 0
index = 0
Expand All @@ -55,30 +75,36 @@ def sample_priority(self, val, record_ancestors=None):
val -= self.priorities[2 * current + 1]
current = 2 * current + 2
else:
if val > self.priorities[2 * current + 1]:
print("Avoided 0 priority node.")

current = 2 * current + 1

index += 1

if record_ancestors is not None:
record_ancestors[index] = current

return self.datatable[current], self.priorities[current]
return self.datatable[current], self.priorities[current], self.get_min(), self.size


class ReplayBuffer:
def __init__(self, capacity, batch_size, alpha=1.0):
    """Prioritized replay buffer backed by a PrioritySumTree.

    Args:
        capacity: maximum number of stored experiences.
        batch_size: number of experiences returned by sample_batch.
        alpha: priority exponent; new experiences enter with the maximum
            priority 2.0**alpha. Defaults to 1.0, reproducing the previous
            hard-coded default priority of 2.0.
    """
    self.capacity = capacity
    self.batch_size = batch_size
    self.buffer = PrioritySumTree(capacity)
    self.index = 0

    # Ancestor paths / priorities of the most recently sampled batch,
    # kept so priorities can be updated after the learning step.
    self.last_batch = None
    self.last_batch_priorities = None
    self.max_priority = 2.0**alpha

def clear(self):
    """Remove all stored experiences by clearing the underlying sum tree."""
    self.buffer.clear()

def store(self, experience, priority=2.0):
def store(self, experience, priority=None):
if priority is None:
priority = self.max_priority
self.buffer.add(experience, priority)

def sample_batch(self):
Expand All @@ -93,10 +119,11 @@ def sample_batch(self):

for i in range(self.batch_size):
priority = random.uniform(segment * i, segment * (i+1))
experience, real_priority = self.buffer.sample_priority(priority, self.last_batch[i])
experience, real_priority, min_priority, size = self.buffer.sample_priority(priority, self.last_batch[i])
self.last_batch_priorities[i] = real_priority

imp_sampling[i] = (1 / self.capacity) * (real_priority / p_total)
imp_sampling[i] = (1 / size) / (real_priority / p_total)
imp_sampling[i] /= (1 / size) / (min_priority / p_total)
experiences.append(experience)

return experiences, imp_sampling
Expand Down
Loading

0 comments on commit 303fca1

Please sign in to comment.