fix(simclr): use two tapes for representations and embeddings
theolepage committed Jan 18, 2022
1 parent 19515e1 commit 454f71a
Showing 1 changed file with 10 additions and 8 deletions.
sslforslr/models/simclr/SimCLR.py: 10 additions, 8 deletions
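Why two tapes: a non-persistent tf.GradientTape can back-propagate only once, so reusing the single `tape` for both the representations loss and the embeddings loss fails on the second `gradient()` call. A minimal standalone demonstration of that limitation (not code from the repository):

    import tensorflow as tf

    x = tf.Variable(3.0)
    with tf.GradientTape() as tape:
        y = x * x
        z = x + x

    dy = tape.gradient(y, x)  # first call succeeds
    dz = tape.gradient(z, x)  # RuntimeError: a non-persistent tape can only compute one set of gradients

The diff below sidesteps this by opening one tape per loss; a `persistent=True` tape would also work, at the cost of keeping the tape's resources alive until it is deleted.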
@@ -74,7 +74,7 @@ def train_step(self, data):
         X_1, X_2, _ = data
         # X shape: (B, H, W, C) = (B, 40, 200, 1)
 
-        with tf.GradientTape() as tape:
+        with tf.GradientTape() as tape1, tf.GradientTape() as tape2:
             Z_1 = self.encoder(X_1, training=True)
             Z_2 = self.encoder(X_2, training=True)
             representations_loss, representations_accuracy = self.representations_loss(
@@ -91,15 +91,17 @@ def train_step(self, data):
         )
 
         # Apply representations loss
-        params = self.encoder.trainable_weights
-        grads = tape.gradient(representations_loss, params)
-        self.optimizer.apply_gradients(zip(grads, params))
+        if self.representations_loss_vic or self.representations_loss_nce:
+            params = self.encoder.trainable_weights
+            grads = tape1.gradient(representations_loss, params)
+            self.optimizer.apply_gradients(zip(grads, params))
 
         # Aplly embeddings loss
-        params = self.encoder.trainable_weights
-        params += self.mlp.trainable_weights
-        grads = tape.gradient(embeddings_loss, params)
-        self.optimizer.apply_gradients(zip(grads, params))
+        if self.embeddings_loss_vic or self.embeddings_loss_nce:
+            params = self.encoder.trainable_weights
+            params += self.mlp.trainable_weights
+            grads = tape2.gradient(embeddings_loss, params)
+            self.optimizer.apply_gradients(zip(grads, params))
 
         return {
             'representations_loss': representations_loss,
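For reference, a condensed, self-contained sketch of the two-tape pattern used above. It is not the repository's SimCLR.train_step: the Dense encoder, the Dense projection MLP, the mean-squared placeholder losses, and the use of one optimizer per loss are stand-ins for illustration, and the `_vic`/`_nce` loss-selection flags from the diff are omitted.

    import tensorflow as tf

    # Hypothetical stand-ins for the real encoder and projection MLP.
    encoder = tf.keras.Sequential([tf.keras.layers.Dense(16)])
    mlp = tf.keras.Sequential([tf.keras.layers.Dense(8)])
    encoder_opt = tf.keras.optimizers.Adam()
    embed_opt = tf.keras.optimizers.Adam()

    def train_step(x_1, x_2):
        # One tape per loss: each non-persistent tape allows a single gradient() call.
        with tf.GradientTape() as tape1, tf.GradientTape() as tape2:
            z_1 = encoder(x_1, training=True)
            z_2 = encoder(x_2, training=True)
            representations_loss = tf.reduce_mean(tf.square(z_1 - z_2))  # placeholder loss

            e_1 = mlp(z_1, training=True)
            e_2 = mlp(z_2, training=True)
            embeddings_loss = tf.reduce_mean(tf.square(e_1 - e_2))  # placeholder loss

        # Representations loss updates the encoder only.
        params = encoder.trainable_weights
        grads = tape1.gradient(representations_loss, params)
        encoder_opt.apply_gradients(zip(grads, params))

        # Embeddings loss updates the encoder and the projection MLP.
        params = encoder.trainable_weights + mlp.trainable_weights
        grads = tape2.gradient(embeddings_loss, params)
        embed_opt.apply_gradients(zip(grads, params))

        return representations_loss, embeddings_loss

    # Example usage with random feature batches (the repository's inputs are
    # (B, 40, 200, 1) spectrogram-like tensors; plain vectors are used here).
    train_step(tf.random.normal((4, 32)), tf.random.normal((4, 32)))

Using a separate optimizer per loss keeps the sketch compatible with newer Keras optimizers, which reject variables they were not built with; the repository itself applies both gradient updates through a single shared optimizer, as the diff shows.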
