diff --git a/sslforslr/models/simclr/SimCLR.py b/sslforslr/models/simclr/SimCLR.py
index 7bd11ee..1330328 100644
--- a/sslforslr/models/simclr/SimCLR.py
+++ b/sslforslr/models/simclr/SimCLR.py
@@ -74,7 +74,7 @@ def train_step(self, data):
         X_1, X_2, _ = data
         # X shape: (B, H, W, C) = (B, 40, 200, 1)
 
-        with tf.GradientTape() as tape:
+        with tf.GradientTape() as tape1, tf.GradientTape() as tape2:
             Z_1 = self.encoder(X_1, training=True)
             Z_2 = self.encoder(X_2, training=True)
             representations_loss, representations_accuracy = self.representations_loss(
@@ -91,15 +91,17 @@ def train_step(self, data):
             )
 
         # Apply representations loss
-        params = self.encoder.trainable_weights
-        grads = tape.gradient(representations_loss, params)
-        self.optimizer.apply_gradients(zip(grads, params))
+        if self.representations_loss_vic or self.representations_loss_nce:
+            params = self.encoder.trainable_weights
+            grads = tape1.gradient(representations_loss, params)
+            self.optimizer.apply_gradients(zip(grads, params))
 
         # Aplly embeddings loss
-        params = self.encoder.trainable_weights
-        params += self.mlp.trainable_weights
-        grads = tape.gradient(embeddings_loss, params)
-        self.optimizer.apply_gradients(zip(grads, params))
+        if self.embeddings_loss_vic or self.embeddings_loss_nce:
+            params = self.encoder.trainable_weights
+            params += self.mlp.trainable_weights
+            grads = tape2.gradient(embeddings_loss, params)
+            self.optimizer.apply_gradients(zip(grads, params))
 
         return {
             'representations_loss': representations_loss,
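
Below is a minimal, self-contained sketch of the two-tape pattern this change introduces: the shared forward pass is recorded on two separate tf.GradientTape objects so each loss gets its own gradient call and its optimizer step can be applied or skipped independently. The encoder, mlp, and losses here are hypothetical stand-ins for illustration only, not the repository's actual SimCLR model.

import tensorflow as tf

# Toy stand-ins for the real encoder / projection MLP (assumed shapes and names).
encoder = tf.keras.Sequential([tf.keras.layers.Dense(8)])
mlp = tf.keras.Sequential([tf.keras.layers.Dense(4)])
optimizer = tf.keras.optimizers.Adam()

x = tf.random.normal((2, 16))

# Record the same forward pass on two tapes, one per loss.
with tf.GradientTape() as tape1, tf.GradientTape() as tape2:
    z = encoder(x, training=True)
    representations_loss = tf.reduce_mean(tf.square(z))                # placeholder loss
    embeddings_loss = tf.reduce_mean(tf.square(mlp(z, training=True)))  # placeholder loss

# Step 1: update the encoder alone from the representations loss (skippable).
params = encoder.trainable_weights
grads = tape1.gradient(representations_loss, params)
optimizer.apply_gradients(zip(grads, params))

# Step 2: update encoder + MLP from the embeddings loss using the second tape,
# since a non-persistent tape can only produce gradients once.
params = encoder.trainable_weights + mlp.trainable_weights
grads = tape2.gradient(embeddings_loss, params)
optimizer.apply_gradients(zip(grads, params))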