Commit

misc(configs): prepare complementarity trainings
theolepage committed Jan 18, 2022
1 parent 08ad49b commit 19515e1
Showing 11 changed files with 186 additions and 69 deletions.
25 changes: 25 additions & 0 deletions configs/vicreg_b256_comp_1.yml
@@ -0,0 +1,25 @@
name: 'vicreg_b256_comp_1'
encoder:
  type: 'thinresnet34'
model:
  type: 'simclr'
  enable_mlp: true
  infonce_loss_factor: 1.0
  vic_reg_factor: 1.0
  representations_loss_vic: true
  representations_loss_nce: false
  embeddings_loss_vic: false
  embeddings_loss_nce: true
training:
  epochs: 500
  batch_size: 256
  learning_rate: 0.001
dataset:
  frame_length: 32000
  frame_split: true
  extract_mfcc: true
  train: './data/voxceleb1_train_list'
  val_ratio: 0.0
  spec_augment: false
  wav_augment:
    enable: true
25 changes: 25 additions & 0 deletions configs/vicreg_b256_comp_2.yml
@@ -0,0 +1,25 @@
name: 'vicreg_b256_comp_2'
encoder:
  type: 'thinresnet34'
model:
  type: 'simclr'
  enable_mlp: true
  infonce_loss_factor: 1.0
  vic_reg_factor: 1.0
  representations_loss_vic: false
  representations_loss_nce: true
  embeddings_loss_vic: true
  embeddings_loss_nce: false
training:
  epochs: 500
  batch_size: 256
  learning_rate: 0.001
dataset:
  frame_length: 32000
  frame_split: true
  extract_mfcc: true
  train: './data/voxceleb1_train_list'
  val_ratio: 0.0
  spec_augment: false
  wav_augment:
    enable: true
25 changes: 25 additions & 0 deletions configs/vicreg_b256_comp_3.yml
@@ -0,0 +1,25 @@
name: 'vicreg_b256_comp_3'
encoder:
  type: 'thinresnet34'
model:
  type: 'simclr'
  enable_mlp: true
  infonce_loss_factor: 1.0
  vic_reg_factor: 0.1
  representations_loss_vic: true
  representations_loss_nce: true
  embeddings_loss_vic: false
  embeddings_loss_nce: false
training:
  epochs: 500
  batch_size: 256
  learning_rate: 0.001
dataset:
  frame_length: 32000
  frame_split: true
  extract_mfcc: true
  train: './data/voxceleb1_train_list'
  val_ratio: 0.0
  spec_augment: false
  wav_augment:
    enable: true
25 changes: 25 additions & 0 deletions configs/vicreg_b256_comp_4.yml
@@ -0,0 +1,25 @@
name: 'vicreg_b256_comp_4'
encoder:
  type: 'thinresnet34'
model:
  type: 'simclr'
  enable_mlp: true
  infonce_loss_factor: 1.0
  vic_reg_factor: 0.1
  representations_loss_vic: false
  representations_loss_nce: false
  embeddings_loss_vic: true
  embeddings_loss_nce: true
training:
  epochs: 500
  batch_size: 256
  learning_rate: 0.001
dataset:
  frame_length: 32000
  frame_split: true
  extract_mfcc: true
  train: './data/voxceleb1_train_list'
  val_ratio: 0.0
  spec_augment: false
  wav_augment:
    enable: true
22 changes: 22 additions & 0 deletions configs/vicreg_b256_mlp_512.yml
@@ -0,0 +1,22 @@
name: 'vicreg_b256_mlp_512'
encoder:
  type: 'thinresnet34'
model:
  type: 'simclr'
  enable_mlp: true
  mlp_dim: 512
  infonce_loss_factor: 0.0
  vic_reg_factor: 1.0
training:
  epochs: 500
  batch_size: 256
  learning_rate: 0.001
dataset:
  frame_length: 32000
  frame_split: true
  extract_mfcc: true
  train: './data/voxceleb1_train_list'
  val_ratio: 0.0
  spec_augment: false
  wav_augment:
    enable: true
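
The five new configs share the same encoder, training, and dataset settings and differ only in their model: section (the loss factors and where each loss term is applied). A minimal sketch for inspecting them, assuming PyYAML is installed and the snippet is run from the repository root; the snippet is illustrative and not part of this commit:

    import glob
    import yaml  # PyYAML, assumed to be installed

    # Print the model section of each new config; only the loss flags
    # and factors differ between the comp variants.
    paths = sorted(glob.glob('configs/vicreg_b256_comp_*.yml'))
    paths.append('configs/vicreg_b256_mlp_512.yml')
    for path in paths:
        with open(path) as f:
            cfg = yaml.safe_load(f)
        print(cfg['name'], cfg['model'])
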
1 change: 0 additions & 1 deletion evaluate_label_efficient.py
@@ -96,7 +96,6 @@ def train(
    # Disable features required only by self-supervised training
    config.dataset.wav_augment.enable = False
    config.dataset.frame_split = False
    config.dataset.provide_clean_and_aug = False

    gens, input_shape, nb_classes = load_dataset(config)
    (train_gen, val_gen) = gens
Empty file modified run.sh
100755 → 100644
1 change: 0 additions & 1 deletion sslforslr/configs.py
@@ -24,7 +24,6 @@ class DatasetConfig:
    frame_length: int = 16000
    frame_split: bool = False
    max_samples: int = None
    provide_clean_and_aug: bool = False
    extract_mfcc: bool = False
    spec_augment: bool = False
    val_ratio: float = 0.1
16 changes: 2 additions & 14 deletions sslforslr/dataset/AudioDatasetLoader.py
@@ -41,7 +41,6 @@ def __init__(
        labels,
        indices,
        wav_augment=None,
        provide_clean_and_aug=False,
        extract_mfcc=False
    ):
        self.epoch = 0
@@ -54,7 +53,6 @@ def __init__(
        self.labels = labels
        self.indices = indices
        self.wav_augment = wav_augment
        self.provide_clean_and_aug = provide_clean_and_aug
        self.extract_mfcc = extract_mfcc

    def __len__(self):
@@ -86,16 +84,8 @@ def __getitem__(self, i):
                    min_length=2*self.frame_length
                ) # (1, T)
                frame1, frame2 = sample_frames(data, self.frame_length)
                if self.provide_clean_and_aug:
                    frame1_clean = self.preprocess_data(frame1, augment=False)
                    frame1_aug = self.preprocess_data(frame1)
                    X1.append(np.stack((frame1_clean, frame1_aug), axis=-1))
                    frame2_clean = self.preprocess_data(frame2, augment=False)
                    frame2_aug = self.preprocess_data(frame2)
                    X2.append(np.stack((frame2_clean, frame2_aug), axis=-1))
                else:
                    X1.append(self.preprocess_data(frame1))
                    X2.append(self.preprocess_data(frame2))
                X1.append(self.preprocess_data(frame1))
                X2.append(self.preprocess_data(frame2))
                y.append(self.labels[index])
            elif self.supervised_sampler:
                frame1 = load_audio(self.files[index[0]], self.frame_length)
@@ -184,7 +174,6 @@ def load(self, batch_size):
            self.labels,
            train_indices,
            self.wav_augment,
            self.config.provide_clean_and_aug,
            self.config.extract_mfcc
        )

@@ -198,7 +187,6 @@ def load(self, batch_size):
            self.labels,
            val_indices,
            self.wav_augment,
            self.config.provide_clean_and_aug,
            self.config.extract_mfcc
        )

104 changes: 55 additions & 49 deletions sslforslr/models/simclr/SimCLR.py
@@ -23,13 +23,16 @@ def __init__(self,
        super().__init__()

        self.enable_mlp = config.enable_mlp
        self.enable_mse_clean_aug = config.enable_mse_clean_aug
        self.infonce_loss_factor = config.infonce_loss_factor
        self.vic_reg_factor = config.vic_reg_factor
        self.barlow_twins_factor = config.barlow_twins_factor
        self.mse_clean_aug_factor = config.mse_clean_aug_factor
        self.reg = regularizers.l2(config.weight_reg)

        self.representations_loss_vic = config.representations_loss_vic
        self.representations_loss_nce = config.representations_loss_nce
        self.embeddings_loss_vic = config.embeddings_loss_vic
        self.embeddings_loss_nce = config.embeddings_loss_nce

        self.encoder = encoder
        self.mlp = MLP(config.mlp_dim)
        self.infonce_loss = InfoNCELoss()
@@ -45,56 +48,65 @@ def compile(self, optimizer, **kwargs):
        self.optimizer = optimizer

    def call(self, X):
        if len(X.shape) == 4 and self.enable_mse_clean_aug:
            X, _ = self.extract_clean_and_aug(X)
        return self.encoder(X)

    @tf.function
    def get_embeddings(self, X_1, X_2):
        Z_1 = self.encoder(X_1, training=True)
        Z_2 = self.encoder(X_2, training=True)
        if self.enable_mlp:
            Z_1 = self.mlp(Z_1, training=True)
            Z_2 = self.mlp(Z_2, training=True)
        return Z_1, Z_2
    def representations_loss(self, Z_1, Z_2):
        loss, accuracy = 0, 0
        if self.representations_loss_nce:
            loss, accuracy = self.infonce_loss((Z_1, Z_2))
            loss = self.infonce_loss_factor * loss
        if self.representations_loss_vic:
            loss += self.vic_reg_factor * self.vic_reg((Z_1, Z_2))
        return loss, accuracy

    @tf.function
    def extract_clean_and_aug(self, X):
        X_clean, X_aug = tf.split(X, 2, axis=-1)
        X_clean = tf.squeeze(X_clean, axis=-1)
        X_aug = tf.squeeze(X_aug, axis=-1)
        return X_clean, X_aug
    def embeddings_loss(self, Z_1, Z_2):
        loss, accuracy = 0, 0
        if self.embeddings_loss_nce:
            loss, accuracy = self.infonce_loss((Z_1, Z_2))
            loss = self.infonce_loss_factor * loss
        if self.embeddings_loss_vic:
            loss += self.vic_reg_factor * self.vic_reg((Z_1, Z_2))
        return loss, accuracy

    def train_step(self, data):
        X_1_aug, X_2_aug, _ = data
        X_1, X_2, _ = data
        # X shape: (B, H, W, C) = (B, 40, 200, 1)

        if self.enable_mse_clean_aug:
            X_1_clean, X_1_aug = self.extract_clean_and_aug(X_1_aug)
            X_2_clean, X_2_aug = self.extract_clean_and_aug(X_2_aug)

        with tf.GradientTape(persistent=True) as tape:  # persistent: gradient() is called twice below
            Z_1_aug, Z_2_aug = self.get_embeddings(X_1_aug, X_2_aug)

            loss, accuracy = self.infonce_loss((Z_1_aug, Z_2_aug))
            loss = self.infonce_loss_factor * loss
            loss += self.vic_reg_factor * self.vic_reg((Z_1_aug, Z_2_aug))
            loss += self.barlow_twins_factor * self.barlow_twins((Z_1_aug, Z_2_aug))

            if self.enable_mse_clean_aug:
                Z_1_clean, Z_2_clean = self.get_embeddings(X_1_clean, X_2_clean)
                loss += self.mse_clean_aug_factor * mse_loss(Z_1_clean, Z_1_aug)
                loss += self.mse_clean_aug_factor * mse_loss(Z_2_clean, Z_2_aug)

        trainable_params = self.encoder.trainable_weights
        if self.enable_mlp:
            trainable_params += self.mlp.trainable_weights

        grads = tape.gradient(loss, trainable_params)
        # grads, _ = tf.clip_by_global_norm(grads, 5.0)
        self.optimizer.apply_gradients(zip(grads, trainable_params))

        return { 'loss': loss, 'accuracy': accuracy }
            Z_1 = self.encoder(X_1, training=True)
            Z_2 = self.encoder(X_2, training=True)
            representations_loss, representations_accuracy = self.representations_loss(
                Z_1,
                Z_2
            )

            if self.enable_mlp:
                Z_1 = self.mlp(Z_1, training=True)
                Z_2 = self.mlp(Z_2, training=True)
            embeddings_loss, embeddings_accuracy = self.embeddings_loss(
                Z_1,
                Z_2
            )

        # Apply representations loss
        params = self.encoder.trainable_weights
        grads = tape.gradient(representations_loss, params)
        self.optimizer.apply_gradients(zip(grads, params))

        # Apply embeddings loss
        params = self.encoder.trainable_weights
        params += self.mlp.trainable_weights
        grads = tape.gradient(embeddings_loss, params)
        self.optimizer.apply_gradients(zip(grads, params))

        return {
            'representations_loss': representations_loss,
            'representations_accuracy': representations_accuracy,
            'embeddings_loss': embeddings_loss,
            'embeddings_accuracy': embeddings_accuracy
        }


class MLP(Model):
@@ -156,10 +168,4 @@ def call(self, data):
        preds_acc = tf.math.equal(pred_indices, labels)
        accuracy = tf.math.count_nonzero(preds_acc, dtype=tf.int32) / batch_size

        return loss, accuracy


@tf.function
def mse_loss(Z_clean, Z_aug):
    mse = tf.keras.metrics.mean_squared_error(Z_clean, Z_aug)
    return tf.math.reduce_mean(mse)
        return loss, accuracy
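
The reworked train_step takes two gradients from a single forward pass: the representation-level loss updates the encoder only, while the embedding-level loss updates both the encoder and the MLP, which requires a persistent gradient tape. Below is a minimal, self-contained sketch of that two-pass update, assuming TensorFlow 2.x; the Dense layers, squared-error losses, and plain SGD updates are placeholders standing in for the repository's encoder, MLP, loss functions, and optimizer step, not the actual implementation.

    import tensorflow as tf

    # Toy stand-ins (hypothetical) for the encoder and the projection MLP.
    encoder = tf.keras.layers.Dense(8)
    mlp = tf.keras.layers.Dense(4)
    lr = 1e-3

    x1 = tf.random.normal((16, 32))  # two augmented views of a batch
    x2 = tf.random.normal((16, 32))

    # Both losses come from the same forward pass, so the tape must be
    # persistent to allow two gradient() calls.
    with tf.GradientTape(persistent=True) as tape:
        z1, z2 = encoder(x1), encoder(x2)                          # representations
        representations_loss = tf.reduce_mean(tf.square(z1 - z2))  # placeholder loss
        e1, e2 = mlp(z1), mlp(z2)                                  # embeddings
        embeddings_loss = tf.reduce_mean(tf.square(e1 - e2))       # placeholder loss

    # First pass: the representation-level loss drives the encoder only.
    for g, v in zip(tape.gradient(representations_loss, encoder.trainable_weights),
                    encoder.trainable_weights):
        v.assign_sub(lr * g)

    # Second pass: the embedding-level loss drives the encoder and the MLP.
    params = encoder.trainable_weights + mlp.trainable_weights
    for g, v in zip(tape.gradient(embeddings_loss, params), params):
        v.assign_sub(lr * g)

    del tape  # release resources held by the persistent tape
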
11 changes: 7 additions & 4 deletions sslforslr/models/simclr/SimCLRModelConfig.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import List

from sslforslr.configs import ModelConfig

@@ -16,10 +17,12 @@ class SimCLRModelConfig(ModelConfig):

    barlow_twins_factor: float = 0.0
    barlow_twins_lambda: float = 0.05

    enable_mse_clean_aug: bool = False
    mse_clean_aug_factor: float = 0.1


    representations_loss_vic: bool = False
    representations_loss_nce: bool = False
    embeddings_loss_vic: bool = True
    embeddings_loss_nce: bool = True

    weight_reg: float = 1e-4

SimCLRModelConfig.__NAME__ = 'simclr'
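
With the defaults above, both loss terms remain on the MLP embeddings, which broadly matches the previous behaviour of applying InfoNCE and VICReg on the projector output; only the new comp configs move terms to the representation level. A small hypothetical helper (not part of the repository) that reports where each term is applied for a parsed model: section:

    # Mirrors the dataclass defaults shown above.
    DEFAULTS = {
        'representations_loss_vic': False,
        'representations_loss_nce': False,
        'embeddings_loss_vic': True,
        'embeddings_loss_nce': True,
    }

    def loss_placement(model_section: dict) -> dict:
        cfg = {**DEFAULTS, **model_section}
        return {
            level: [name for name in ('nce', 'vic') if cfg[f'{level}_loss_{name}']]
            for level in ('representations', 'embeddings')
        }

    # vicreg_b256_comp_1: VICReg on the representations, InfoNCE on the embeddings.
    print(loss_placement({'representations_loss_vic': True, 'embeddings_loss_vic': False}))
    # -> {'representations': ['vic'], 'embeddings': ['nce']}
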
