From 02af0a9b7737b0891ad53eba140f00a3ff58f95b Mon Sep 17 00:00:00 2001 From: Theo Lepage Date: Sat, 4 Dec 2021 15:25:41 +0100 Subject: [PATCH] misc(configs): prepare trainings for evaluating vicreg hyperparams --- configs/simclr_vicreg_0.1.yml | 4 +- configs/vicreg_b1024.yml | 6 +- configs/vicreg_b2048.yml | 6 +- configs/vicreg_b256.yml | 6 +- configs/vicreg_b256_1_0.5_0.1.yml | 24 ++++ configs/vicreg_b256_1_1_0.04.yml | 24 ++++ ..._4xencoder.yml => vicreg_b256_1_1_0.1.yml} | 12 +- ...g_b256_2xmse.yml => vicreg_b256_1_1_0.yml} | 13 +- configs/vicreg_b512.yml | 6 +- prepare_data.py | 133 +++++++++--------- run.sh | 23 +-- sslforslr/dataset/AudioDatasetLoader.py | 22 +-- sslforslr/models/simclr/SimCLR.py | 10 +- sslforslr/models/simclr/SimCLRModelConfig.py | 11 +- sslforslr/modules/VICReg.py | 27 ++-- 15 files changed, 189 insertions(+), 138 deletions(-) create mode 100644 configs/vicreg_b256_1_0.5_0.1.yml create mode 100644 configs/vicreg_b256_1_1_0.04.yml rename configs/{vicreg_b256_4xencoder.yml => vicreg_b256_1_1_0.1.yml} (64%) rename configs/{vicreg_b256_2xmse.yml => vicreg_b256_1_1_0.yml} (64%) diff --git a/configs/simclr_vicreg_0.1.yml b/configs/simclr_vicreg_0.1.yml index 9dab785..78c23f3 100644 --- a/configs/simclr_vicreg_0.1.yml +++ b/configs/simclr_vicreg_0.1.yml @@ -3,10 +3,10 @@ encoder: type: 'thinresnet34' model: type: 'simclr' - loss_factor: 1.0 + infonce_loss_factor: 1.0 vic_reg_factor: 0.1 training: - epochs: 150 + epochs: 300 batch_size: 256 learning_rate: 0.001 dataset: diff --git a/configs/vicreg_b1024.yml b/configs/vicreg_b1024.yml index 4406d5b..c039131 100644 --- a/configs/vicreg_b1024.yml +++ b/configs/vicreg_b1024.yml @@ -4,10 +4,10 @@ encoder: model: type: 'simclr' enable_mlp: true - loss_factor: 0.0 - vic_reg_factor: 25.0 + infonce_loss_factor: 0.0 + vic_reg_factor: 1.0 training: - epochs: 500 + epochs: 300 batch_size: 1024 learning_rate: 0.001 dataset: diff --git a/configs/vicreg_b2048.yml b/configs/vicreg_b2048.yml index 29f6f28..c6823fd 100644 --- a/configs/vicreg_b2048.yml +++ b/configs/vicreg_b2048.yml @@ -4,10 +4,10 @@ encoder: model: type: 'simclr' enable_mlp: true - loss_factor: 0.0 - vic_reg_factor: 25.0 + infonce_loss_factor: 0.0 + vic_reg_factor: 1.0 training: - epochs: 500 + epochs: 300 batch_size: 2048 learning_rate: 0.001 dataset: diff --git a/configs/vicreg_b256.yml b/configs/vicreg_b256.yml index ccb82bf..328b989 100644 --- a/configs/vicreg_b256.yml +++ b/configs/vicreg_b256.yml @@ -4,10 +4,10 @@ encoder: model: type: 'simclr' enable_mlp: true - loss_factor: 0.0 - vic_reg_factor: 25.0 + infonce_loss_factor: 0.0 + vic_reg_factor: 1.0 training: - epochs: 500 + epochs: 300 batch_size: 256 learning_rate: 0.001 dataset: diff --git a/configs/vicreg_b256_1_0.5_0.1.yml b/configs/vicreg_b256_1_0.5_0.1.yml new file mode 100644 index 0000000..4855d1e --- /dev/null +++ b/configs/vicreg_b256_1_0.5_0.1.yml @@ -0,0 +1,24 @@ +name: 'vicreg_b256_1_0.5_0.1' +encoder: + type: 'thinresnet34' +model: + type: 'simclr' + enable_mlp: true + infonce_loss_factor: 0.0 + vic_reg_factor: 1.0 + vic_reg_inv_weight: 1.0 + vic_reg_var_weight: 0.5 + vic_reg_cov_weight: 0.1 +training: + epochs: 300 + batch_size: 256 + learning_rate: 0.001 +dataset: + frame_length: 32000 + frame_split: true + extract_mfcc: true + train: './data/voxceleb1_train_list' + val_ratio: 0.0 + spec_augment: false + wav_augment: + enable: true \ No newline at end of file diff --git a/configs/vicreg_b256_1_1_0.04.yml b/configs/vicreg_b256_1_1_0.04.yml new file mode 100644 index 0000000..6f93920 --- /dev/null +++ b/configs/vicreg_b256_1_1_0.04.yml @@ -0,0 +1,24 @@ +name: 'vicreg_b256_1_1_0.04' +encoder: + type: 'thinresnet34' +model: + type: 'simclr' + enable_mlp: true + infonce_loss_factor: 0.0 + vic_reg_factor: 1.0 + vic_reg_inv_weight: 1.0 + vic_reg_var_weight: 1.0 + vic_reg_cov_weight: 0.04 +training: + epochs: 300 + batch_size: 256 + learning_rate: 0.001 +dataset: + frame_length: 32000 + frame_split: true + extract_mfcc: true + train: './data/voxceleb1_train_list' + val_ratio: 0.0 + spec_augment: false + wav_augment: + enable: true \ No newline at end of file diff --git a/configs/vicreg_b256_4xencoder.yml b/configs/vicreg_b256_1_1_0.1.yml similarity index 64% rename from configs/vicreg_b256_4xencoder.yml rename to configs/vicreg_b256_1_1_0.1.yml index 281b7a1..1cd8744 100644 --- a/configs/vicreg_b256_4xencoder.yml +++ b/configs/vicreg_b256_1_1_0.1.yml @@ -1,14 +1,16 @@ -name: 'vicreg_b256_4xencoder' +name: 'vicreg_b256_1_1_0.1' encoder: type: 'thinresnet34' - scale: 2 model: type: 'simclr' enable_mlp: true - loss_factor: 0.0 - vic_reg_factor: 25.0 + infonce_loss_factor: 0.0 + vic_reg_factor: 1.0 + vic_reg_inv_weight: 1.0 + vic_reg_var_weight: 1.0 + vic_reg_cov_weight: 0.1 training: - epochs: 500 + epochs: 300 batch_size: 256 learning_rate: 0.001 dataset: diff --git a/configs/vicreg_b256_2xmse.yml b/configs/vicreg_b256_1_1_0.yml similarity index 64% rename from configs/vicreg_b256_2xmse.yml rename to configs/vicreg_b256_1_1_0.yml index ea6d3e9..6da450a 100644 --- a/configs/vicreg_b256_2xmse.yml +++ b/configs/vicreg_b256_1_1_0.yml @@ -1,14 +1,16 @@ -name: 'vicreg_b256_2xmse' +name: 'vicreg_b256_1_1_0' encoder: type: 'thinresnet34' model: type: 'simclr' enable_mlp: true - enable_mse_clean_aug: true - loss_factor: 0.0 - vic_reg_factor: 25.0 + infonce_loss_factor: 0.0 + vic_reg_factor: 1.0 + vic_reg_inv_weight: 1.0 + vic_reg_var_weight: 1.0 + vic_reg_cov_weight: 0.0 training: - epochs: 500 + epochs: 300 batch_size: 256 learning_rate: 0.001 dataset: @@ -17,7 +19,6 @@ dataset: extract_mfcc: true train: './data/voxceleb1_train_list' val_ratio: 0.0 - provide_clean_and_aug: true spec_augment: false wav_augment: enable: true \ No newline at end of file diff --git a/configs/vicreg_b512.yml b/configs/vicreg_b512.yml index b4d6ed6..0e388ae 100644 --- a/configs/vicreg_b512.yml +++ b/configs/vicreg_b512.yml @@ -4,10 +4,10 @@ encoder: model: type: 'simclr' enable_mlp: true - loss_factor: 0.0 - vic_reg_factor: 25.0 + infonce_loss_factor: 0.0 + vic_reg_factor: 1.0 training: - epochs: 500 + epochs: 300 batch_size: 512 learning_rate: 0.001 dataset: diff --git a/prepare_data.py b/prepare_data.py index b6bbe19..691744b 100644 --- a/prepare_data.py +++ b/prepare_data.py @@ -12,26 +12,26 @@ ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab', 'bbfaaccefab65d82b21903e81a8a8020'), ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac', '017d579a2a96a077f40042ec33e51512'), ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad', '7bb1e9f70fddc7a678fa998ea8b3ba19'), - ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa', 'da070494c573e5c0564b1d11c3b20577'), - ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab', '17fe6dab2b32b48abaf1676429cdd06f'), - ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac', '1de58e086c5edf63625af1cb6d831528'), - ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad', '5a043eb03e15c5a918ee6a52aad477f9'), - ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae', 'cea401b624983e2d0b2a87fb5d59aa60'), - ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf', 'fc886d9ba90ab88e7880ee98effd6ae9'), - ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag', 'd160ecc3f6ee3eed54d55349531cb42e'), - ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah', '6b84a81b9af72a9d9eecbb3b1f602e65'), + # ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa', 'da070494c573e5c0564b1d11c3b20577'), + # ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab', '17fe6dab2b32b48abaf1676429cdd06f'), + # ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac', '1de58e086c5edf63625af1cb6d831528'), + # ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad', '5a043eb03e15c5a918ee6a52aad477f9'), + # ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae', 'cea401b624983e2d0b2a87fb5d59aa60'), + # ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf', 'fc886d9ba90ab88e7880ee98effd6ae9'), + # ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag', 'd160ecc3f6ee3eed54d55349531cb42e'), + # ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah', '6b84a81b9af72a9d9eecbb3b1f602e65'), ('http://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip', '185fdc63c3c739954633d50379a3d102') ] VOX_CONCATENATE = [ ('vox1_dev_wav_parta*', 'vox1_dev_wav.zip', 'ae63e55b951748cc486645f532ba230b'), - ('vox2_dev_aac_parta*', 'vox2_dev_aac.zip', 'bbc063c46078a602ca71605645c2a402') + # ('vox2_dev_aac_parta*', 'vox2_dev_aac.zip', 'bbc063c46078a602ca71605645c2a402') ] VOX_EXTRACT = [ 'vox1_dev_wav.zip', 'vox1_test_wav.zip', - 'vox2_dev_aac.zip' + # 'vox2_dev_aac.zip' ] AUG_DOWNLOAD = [ @@ -57,52 +57,46 @@ def get_md5(path): return hash_md5.hexdigest() -def download(entries, output_path): +def download(entries): for url, md5 in entries: filename = url.split('/')[-1] - out = os.path.join(output_path, filename) - - status = subprocess.call('wget %s -O %s' % (url, out), shell=True) + status = subprocess.call('wget %s -O %s' % (url, filename), shell=True) if status != 0: raise Exception('Download of %s failed' % filename) - if md5 != get_md5(out): + if md5 != get_md5(filename): raise Warning('Checksum of %s failed' % filename) -def concatenate(entries, output_path): - for src_filename, dst_filename, md5 in entries: - src_path = os.path.join(output_path, src_filename) - dst_path = os.path.join(output_path, dst_filename) - - subprocess.call('cat %s > %s' % (src_path, dst_path), shell=True) - subprocess.call('rm %s' % (src_path), shell=True) +def concatenate(entries): + for src, dst, md5 in entries: + subprocess.call('cat %s > %s' % (src, dst), shell=True) + subprocess.call('rm %s' % (src), shell=True) - if md5 != get_md5(dst_path): - raise Warning('Checksum of %s failed' % dst_filename) + if md5 != get_md5(dst): + raise Warning('Checksum of %s failed' % dst) -def extract(entries, output_path): +def extract(entries): for filename in entries: - filename = os.path.join(output_path, filename) if filename.endswith('.tar.gz'): subprocess.call('tar xf %s' % (filename), shell=True) elif filename.endswith('.zip'): subprocess.call('unzip %s' % (filename), shell=True) -def fix_vox_structure(output_path): - subprocess.call('mkdir %s/voxceleb1' % (output_path), shell=True) - subprocess.call('mv %s/wav/* %s/voxceleb1' % (output_path, output_path), shell=True) - subprocess.call('rm -r %s/wav' % (output_path), shell=True) - subprocess.call('mkdir %s/voxceleb2' % (output_path), shell=True) - subprocess.call('mv %s/dev/aac/* %s/voxceleb2' % (output_path, output_path), shell=True) - subprocess.call('rm -r %s/dev' % (output_path), shell=True) - subprocess.call('rm -r %s/vox*.zip' % (output_path), shell=True) +def fix_vox_structure(): + subprocess.call('mkdir voxceleb1', shell=True) + subprocess.call('mv wav/* voxceleb1', shell=True) + subprocess.call('rm -r wav', shell=True) + # subprocess.call('mkdir voxceleb2', shell=True) + # subprocess.call('mv dev/aac/* voxceleb2', shell=True) + # subprocess.call('rm -r dev', shell=True) + subprocess.call('rm -r vox*.zip', shell=True) -def convert_to_wav(output_path): - files = glob.glob('%s/voxceleb2/*/*/*.m4a' % output_path) +def convert_vox2_to_wav(): + files = glob.glob('voxceleb2/*/*/*.m4a') for src in tqdm(files): dst = src.replace('.m4a', '.wav') @@ -115,20 +109,20 @@ def convert_to_wav(output_path): subprocess.call('rm %s' % src, shell=True) -def fix_aug_structure(output_path): - subprocess.call('mv %s/RIRS_NOISES/simulated_rirs %s' % (output_path, output_path), shell=True) - subprocess.call('rm -r %s/RIRS_NOISES' % (output_path), shell=True) - subprocess.call('rm -r %s/rirs_noises.zip' % (output_path), shell=True) - subprocess.call('rm -r %s/musan.tar.gz' % (output_path), shell=True) +def fix_aug_structure(): + subprocess.call('mv RIRS_NOISES/simulated_rirs .', shell=True) + subprocess.call('rm -r RIRS_NOISES', shell=True) + subprocess.call('rm -r rirs_noises.zip', shell=True) + subprocess.call('rm -r musan.tar.gz', shell=True) -def split_musan(output_path, length=16000*8, stride=16000*8): - files = glob.glob('%s/musan/*/*/*.wav' % output_path) +def split_musan(length=16000*8, stride=16000*8): + files = glob.glob('musan/*/*/*.wav') for file in tqdm(files): audio, fs = sf.read(file) - directory = os.path.dirname(file).replace('/musan/', '/musan_split/') + directory = os.path.dirname(file).replace('musan/', 'musan_split/') os.makedirs(directory, exist_ok=True) for st in range(0, len(audio) - length, stride): @@ -136,10 +130,10 @@ def split_musan(output_path, length=16000*8, stride=16000*8): filename = directory + '/' + filename sf.write(filename, audio[st:st+length], fs) - subprocess.call('rm -r %s/musan' % (output_path), shell=True) + subprocess.call('rm -r musan', shell=True) -def split_2_ssd(output_path): +def split_2_ssd(): dirs = glob.glob('dev/aac/*') for src in dirs: @@ -158,9 +152,9 @@ def split_2_ssd(output_path): shutil.copytree(src, dst) -def create_vox1_train_list_file(output_path): +def create_vox1_train_list_file(): test_speakers = set() - with open(os.path.join(output_path, TRIALS_FILENAME)) as trials: + with open(TRIALS_FILENAME) as trials: for line in trials.readlines(): parts = line.rstrip().split() spkr_id_a = parts[1].split('/')[0] @@ -168,9 +162,9 @@ def create_vox1_train_list_file(output_path): test_speakers.add(spkr_id_a) test_speakers.add(spkr_id_b) - files = glob.glob('%s/voxceleb1/*/*/*.wav' % output_path) + files = glob.glob('voxceleb1/*/*/*.wav') files.sort() - out_file = open('%s/%s' % (output_path, VOX1_TRAIN_LIST), 'w') + out_file = open(VOX1_TRAIN_LIST, 'w') for file in files: spkr_id = file.split('/')[-3] file = '/'.join(file.split('/')[-3:]) @@ -180,10 +174,10 @@ def create_vox1_train_list_file(output_path): out_file.close() -def create_vox2_train_list_file(output_path): - files = glob.glob('%s/voxceleb2/*/*/*.wav' % output_path) +def create_vox2_train_list_file(): + files = glob.glob('voxceleb2/*/*/*.wav') files.sort() - out_file = open('%s/%s' % (output_path, VOX2_TRAIN_LIST), 'w') + out_file = open(VOX2_TRAIN_LIST, 'w') for file in files: spkr_id = file.split('/')[-3] file = '/'.join(file.split('/')[-3:]) @@ -192,9 +186,8 @@ def create_vox2_train_list_file(output_path): out_file.close() -def download_trials_file(output_path): - out = os.path.join(output_path, TRIALS_FILENAME) - status = subprocess.call('wget %s -O %s' % (TRIALS_URL, out), shell=True) +def download_trials_file(): + status = subprocess.call('wget %s -O %s' % (TRIALS_URL, TRIALS_FILENAME), shell=True) if status != 0: raise Exception('Download of %s failed' % TRIALS_FILENAME) @@ -204,19 +197,21 @@ def download_trials_file(output_path): parser.add_argument('output_path', help='Path to store datasets.') args = parser.parse_args() + os.chdir(args.output_path) + # VoxCeleb1 and VoxCeleb2 - download(VOX_DOWNLOADS, args.output_path) - concatenate(VOX_CONCATENATE, args.output_path) - extract(VOX_EXTRACT, args.output_path) - fix_vox_structure(args.output_path) - convert_to_wav(args.output_path) + download(VOX_DOWNLOADS) + concatenate(VOX_CONCATENATE) + extract(VOX_EXTRACT) + fix_vox_structure() + # convert_vox2_to_wav() # Augmentation: MUSAN and simulated_rirs - download(AUG_DOWNLOAD, args.output_path) - extract(AUG_EXTRACT, args.output_path) - fix_aug_structure(args.output_path) - split_musan(args.output_path) - - download_trials_file(args.output_path) - create_vox1_train_list_file(args.output_path) - create_vox2_train_list_file(args.output_path) \ No newline at end of file + download(AUG_DOWNLOAD) + extract(AUG_EXTRACT) + fix_aug_structure() + split_musan() + + download_trials_file() + create_vox1_train_list_file() + # create_vox2_train_list_file() \ No newline at end of file diff --git a/run.sh b/run.sh index 0bc2be7..41334d4 100755 --- a/run.sh +++ b/run.sh @@ -3,22 +3,9 @@ mkdir data python prepare_data.py data -export CUDA_VISIBLE_DEVICES=0,1 +# export CUDA_VISIBLE_DEVICES=0,1 -python train.py configs/vicreg_b256.yml -python evaluate.py configs/vicreg_b256.yml - -python train.py configs/vicreg_b512.yml -python evaluate.py configs/vicreg_b512.yml - -python train.py configs/vicreg_b1024.yml -python evaluate.py configs/vicreg_b1024.yml - -python train.py configs/vicreg_b2048.yml -python evaluate.py configs/vicreg_b2048.yml - -python train.py configs/vicreg_b256_4xencoder.yml -python evaluate.py configs/vicreg_b256_4xencoder.yml - -python train.py configs/vicreg_b256_2xmse.yml -python evaluate.py configs/vicreg_b256_2xmse.yml \ No newline at end of file +python train.py configs/vicreg_b256_1_1_0.yml +python train.py configs/vicreg_b256_1_1_0.1.yml +python train.py configs/vicreg_b256_1_1_0.04.yml +python train.py configs/vicreg_b256_1_0.5_0.1.yml \ No newline at end of file diff --git a/sslforslr/dataset/AudioDatasetLoader.py b/sslforslr/dataset/AudioDatasetLoader.py index a65ef76..927906a 100644 --- a/sslforslr/dataset/AudioDatasetLoader.py +++ b/sslforslr/dataset/AudioDatasetLoader.py @@ -97,18 +97,20 @@ def __getitem__(self, i): X1.append(self.preprocess_data(frame1)) X2.append(self.preprocess_data(frame2)) y.append(self.labels[index]) - else: - data1 = load_audio(self.files[index[0]], self.frame_length) # (1, T) - data1 = self.preprocess_data(data1) - X1.append(data1) - data2 = load_audio(self.files[index[1]], self.frame_length) # (1, T) - data2 = self.preprocess_data(data2) - X2.append(data2) + elif self.supervised_sampler: + frame1 = load_audio(self.files[index[0]], self.frame_length) + frame2 = load_audio(self.files[index[1]], self.frame_length) + X1.append(self.preprocess_data(frame1)) + X2.append(self.preprocess_data(frame2)) y.append(self.labels[index[0]]) + else: + frame = load_audio(self.files[index], self.frame_length) + X1.append(self.preprocess_data(frame)) + y.append(self.labels[index]) - #if self.frame_split: - return np.array(X1), np.array(X2), np.array(y) - #return np.array(X1), np.array(y) + if self.frame_split or self.supervised_sampler: + return np.array(X1), np.array(X2), np.array(y) + return np.array(X1), np.array(y) def enable_supervision(self, nb_labels_per_spk=100): self.supervised_sampler = SupervisedTrainingSampler( diff --git a/sslforslr/models/simclr/SimCLR.py b/sslforslr/models/simclr/SimCLR.py index a3bb748..d0c7206 100644 --- a/sslforslr/models/simclr/SimCLR.py +++ b/sslforslr/models/simclr/SimCLR.py @@ -23,7 +23,7 @@ def __init__(self, self.enable_mlp = config.enable_mlp self.enable_mse_clean_aug = config.enable_mse_clean_aug - self.loss_factor = config.loss_factor + self.infonce_loss_factor = config.infonce_loss_factor self.vic_reg_factor = config.vic_reg_factor self.mse_clean_aug_factor = config.mse_clean_aug_factor self.reg = regularizers.l2(config.weight_reg) @@ -31,7 +31,11 @@ def __init__(self, self.encoder = encoder self.mlp = MLP() self.infonce_loss = InfoNCELoss() - self.vic_reg = VICReg() + self.vic_reg = VICReg( + config.vic_reg_inv_weight, + config.vic_reg_var_weight, + config.vic_reg_cov_weight + ) def compile(self, optimizer, **kwargs): super().compile(**kwargs) @@ -70,7 +74,7 @@ def train_step(self, data): Z_1_aug, Z_2_aug = self.get_embeddings(X_1_aug, X_2_aug) loss, accuracy = self.infonce_loss((Z_1_aug, Z_2_aug)) - loss = self.loss_factor * loss + loss = self.infonce_loss_factor * loss loss += self.vic_reg_factor * self.vic_reg((Z_1_aug, Z_2_aug)) if self.enable_mse_clean_aug: diff --git a/sslforslr/models/simclr/SimCLRModelConfig.py b/sslforslr/models/simclr/SimCLRModelConfig.py index 0563879..96a0351 100644 --- a/sslforslr/models/simclr/SimCLRModelConfig.py +++ b/sslforslr/models/simclr/SimCLRModelConfig.py @@ -5,10 +5,17 @@ @dataclass class SimCLRModelConfig(ModelConfig): enable_mlp: bool = False - enable_mse_clean_aug: bool = False - loss_factor: float = 1 + + infonce_loss_factor: float = 1.0 + vic_reg_factor: float = 0.1 + vic_reg_inv_weight: float = 1.0 + vic_reg_var_weight: float = 1.0 + vic_reg_cov_weight: float = 0.04 + + enable_mse_clean_aug: bool = False mse_clean_aug_factor: float = 0.1 + weight_reg: float = 1e-4 SimCLRModelConfig.__NAME__ = 'simclr' diff --git a/sslforslr/modules/VICReg.py b/sslforslr/modules/VICReg.py index 86d929e..8fae017 100644 --- a/sslforslr/modules/VICReg.py +++ b/sslforslr/modules/VICReg.py @@ -8,12 +8,17 @@ def off_diagonal(matrix): class VICReg(Layer): - def __init__(self, lamda=1, mu=1, nu=0.04): + def __init__( + self, + inv_weight=1.0, + var_weight=1.0, + cov_weight=0.04 + ): super().__init__() - self.lamda = lamda - self.mu = mu - self.nu = nu + self.inv_weight = inv_weight + self.var_weight = var_weight + self.cov_weight = cov_weight def call(self, data): X_a, X_b = data @@ -25,14 +30,14 @@ def call(self, data): X_b_mean, X_b_var = tf.nn.moments(X_b, axes=[0]) # Invariance loss - sim_loss = tf.keras.metrics.mean_squared_error(X_a, X_b) - sim_loss = tf.math.reduce_mean(sim_loss) + inv_loss = tf.keras.metrics.mean_squared_error(X_a, X_b) + inv_loss = tf.math.reduce_mean(inv_loss) # Variance loss X_a_std = tf.math.sqrt(X_a_var + 1e-04) X_b_std = tf.math.sqrt(X_b_var + 1e-04) - std_loss = tf.math.reduce_mean(tf.nn.relu(1 - X_a_std)) - std_loss += tf.math.reduce_mean(tf.nn.relu(1 - X_b_std)) + var_loss = tf.math.reduce_mean(tf.nn.relu(1 - X_a_std)) + var_loss += tf.math.reduce_mean(tf.nn.relu(1 - X_b_std)) # Covariance loss X_a = X_a - X_a_mean @@ -42,7 +47,7 @@ def call(self, data): cov_loss = tf.math.reduce_sum(tf.math.pow(off_diagonal(X_a_cov), 2)) / D cov_loss += tf.math.reduce_sum(tf.math.pow(off_diagonal(X_b_cov), 2)) / D - loss = self.lamda * sim_loss - loss += self.mu * std_loss - loss += self.nu * cov_loss + loss = self.inv_weight * inv_loss + loss += self.var_weight * var_loss + loss += self.cov_weight * cov_loss return loss \ No newline at end of file