Fixes to state_value
fshcat committed Nov 22, 2022
1 parent 303fca1 commit a98ea87
Showing 2 changed files with 19 additions and 14 deletions.
model.py (27 changes: 15 additions & 12 deletions)
@@ -43,16 +43,16 @@ def initialize_model(self, regularization=0.0001):
m, n, k = self.mnk

self.model = Sequential()
- self.model.add(Conv2D(filters=16, kernel_size=3, padding="same", input_shape=(m, n, 2), kernel_regularizer=l2(regularization)))
- self.model.add(Conv2D(filters=32, kernel_size=3, padding="same", kernel_regularizer=l2(regularization)))
- self.model.add(Conv2D(filters=16, kernel_size=3, padding="same", input_shape=(m, n, 2), kernel_regularizer=l2(regularization)))
- self.model.add(Conv2D(filters=1, kernel_size=3, padding="same", kernel_regularizer=l2(regularization)))
- self.model.add(Flatten())
+ #self.model.add(Conv2D(filters=16, kernel_size=3, padding="same", input_shape=(m, n, 2), kernel_regularizer=l2(regularization)))
+ #self.model.add(Conv2D(filters=32, kernel_size=3, padding="same", kernel_regularizer=l2(regularization)))
+ #self.model.add(Conv2D(filters=16, kernel_size=3, padding="same", input_shape=(m, n, 2), kernel_regularizer=l2(regularization)))
+ #self.model.add(Conv2D(filters=1, kernel_size=3, padding="same", kernel_regularizer=l2(regularization)))
+ #self.model.add(Flatten())

- #model.add(Flatten())
- #model.add(Conv2D(filters=32, kernel_size=3, padding="same", input_shape=(m, n, 2), kernel_regularizer=l2(regularization)))
- #model.add(Dense(128, kernel_initializer='normal', activation='relu', kernel_regularizer=l2(regularization)))
- #model.add(Dense(mnk[0] * mnk[1], kernel_initializer='normal', kernel_regularizer=l2(regularization)))
+ self.model.add(Conv2D(filters=32, kernel_size=3, input_shape=(m, n, 2), kernel_regularizer=l2(regularization)))
+ self.model.add(Flatten())
+ self.model.add(Dense(128, kernel_initializer='normal', activation='relu', kernel_regularizer=l2(regularization)))
+ self.model.add(Dense(m * n, kernel_initializer='normal', kernel_regularizer=l2(regularization)))

self.opt = Adam(learning_rate=self.lr)
self.model.compile(loss='mean_squared_error', optimizer=self.opt)
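For reference, the new dense head can be sanity-checked in isolation. The sketch below rebuilds the four added layers and runs a dummy forward pass; the board size (m = n = 3), the regularization value, and the shape check itself are illustrative assumptions, not taken from this commit.

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Flatten, Dense
from tensorflow.keras.regularizers import l2

m, n, regularization = 3, 3, 0.0001  # example board size and reg strength, assumed for illustration

model = Sequential()
model.add(Conv2D(filters=32, kernel_size=3, input_shape=(m, n, 2), kernel_regularizer=l2(regularization)))
model.add(Flatten())
model.add(Dense(128, kernel_initializer='normal', activation='relu', kernel_regularizer=l2(regularization)))
model.add(Dense(m * n, kernel_initializer='normal', kernel_regularizer=l2(regularization)))

# The head emits one value per board cell: shape (batch, m * n)
print(model(np.zeros((1, m, n, 2), dtype="float32")).shape)  # (1, 9)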
@@ -93,10 +93,13 @@ def state_value(self, states, terminal=None):
k, m, n, _ = states.shape

illegal_actions = (np.sum(states, axis=3) != 0).reshape(k, m * n)
- action_vals += np.where(illegal_actions, np.full(shape=(k, m * n), fill_value=np.NINF, dtype="float32"), illegal_actions == 0)

+ # Replace values for illegal actions with -infinity so they can't be picked as max
+ action_vals = np.where(illegal_actions, np.full(shape=(k, m * n), fill_value=np.NINF, dtype="float32"), action_vals)
max_vals = tf.math.reduce_max(action_vals, axis=1)
- max_inds = np.argmax(action_vals, axis=1)

+ # If state is terminal, return an index of -1 for that state
+ max_inds = np.where(terminal, np.full(shape=k, fill_value=-1, dtype="int32"), np.argmax(action_vals, axis=1))

return max_vals, max_inds
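A toy run of the corrected masking logic, with made-up numbers (two states, three flattened actions each); np.max stands in for tf.math.reduce_max so the sketch stays numpy-only.

import numpy as np

# Made-up values: two states (k=2), boards flattened to 3 candidate actions each
action_vals = np.array([[0.2, 0.9, -0.1],
                        [0.5, 0.3,  0.4]], dtype="float32")
illegal_actions = np.array([[False, True, False],   # action 1 of state 0 is already occupied
                            [False, False, False]])
terminal = np.array([False, True])                  # state 1 is terminal

# Illegal actions become -inf, so neither the max nor the argmax can land on them
action_vals = np.where(illegal_actions, np.full(action_vals.shape, np.NINF, dtype="float32"), action_vals)

max_vals = np.max(action_vals, axis=1)              # tf.math.reduce_max in the actual method
max_inds = np.where(terminal, np.full(2, -1, dtype="int32"), np.argmax(action_vals, axis=1))

print(max_vals)   # [0.2 0.5]  (the 0.9 on the occupied cell is ignored)
print(max_inds)   # [ 0 -1]    (the terminal state reports index -1)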

train.py (6 changes: 4 additions & 2 deletions)
@@ -228,6 +228,8 @@ def train(hof, params, model):
side_hof *= -1
side_best = side_hof * -1

+ assert side_hof != side_best, "Opponents can't be on the same side"

# Regularly attempt to add the model into HOF ("gating")
if game % params.hof_gate_rate == 0 and games_since_hof > params.hof_wait_time:
reward, improvement = diagnostics.get_recent_performance()
@@ -345,9 +347,9 @@ def main():
batch_size = 32 # Batch size for training
lr = 0.001 # Learning rate for SGD

- buffer_size = 50000 # Num of moves to store in replay buffer
+ buffer_size = 20000 # Num of moves to store in replay buffer
alpha = 0.7
- buffer_beta = 0.8
+ buffer_beta = 0.5
min_priority = 0.01

update_rate = 4 # How often to train the model on a replay batch (in moves)
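For context on alpha, the retuned buffer_beta, and min_priority: the sketch below shows the conventional prioritized-experience-replay weighting these knobs usually feed into (priority p = (|TD error| + min_priority)^alpha, importance weight w = (N * P)^(-beta)). Whether this repository's buffer follows exactly this scheme is an assumption; the function and its inputs are illustrative only.

import numpy as np

def per_weights(td_errors, alpha=0.7, beta=0.5, min_priority=0.01):
    """Conventional prioritized-replay weighting; illustrative, not this repo's buffer code."""
    priorities = (np.abs(td_errors) + min_priority) ** alpha  # min_priority keeps zero-error moves sampleable
    probs = priorities / priorities.sum()                     # sampling distribution over stored moves
    weights = (len(td_errors) * probs) ** (-beta)             # importance-sampling correction, tempered by beta
    return probs, weights / weights.max()                     # normalize so the largest weight is 1

probs, weights = per_weights(np.array([0.05, 0.4, 1.2]))
print(probs.round(3), weights.round(3))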
