From 18cdaef38459aaeef038b3d679869e42fca98477 Mon Sep 17 00:00:00 2001
From: Alexandre Fournier Montgieux
Date: Wed, 20 Sep 2023 11:10:35 +0200
Subject: [PATCH 1/2] Validation update, slurm local rank and world size

---
 convert.py                         |  7 ++++--
 data.py                            | 40 +++++++++++-------------
 dataset/five_validation_dataset.py | 24 ++++++++++++------
 evaluate_utils.py                  | 14 +++++------
 train_val.py                       | 18 +++++++++-----
 utils.py                           | 20 +++++++++++++--
 6 files changed, 74 insertions(+), 49 deletions(-)

diff --git a/convert.py b/convert.py
index 0ba32e0..cb6d9a8 100644
--- a/convert.py
+++ b/convert.py
@@ -3,7 +3,6 @@
 import mxnet as mx
 from tqdm import tqdm
 from PIL import Image
-import bcolz
 import pickle
 import cv2
 import numpy as np
@@ -49,9 +48,11 @@ def save_rec_to_img_dir(rec_path, swap_color_channel=False, save_as_png=False):
             img.save(img_save_path, quality=95)


 def load_bin(path, rootdir, image_size=[112,112]):
+    import os
     test_transform = trans.Compose([
         trans.ToTensor(),
+        trans.Resize(image_size, antialias=True),
         trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
     ])
     if not rootdir.exists():
@@ -87,11 +88,13 @@ def load_bin(path, rootdir, image_size=[112,112]):
     save_rec_to_img_dir(rec_path, swap_color_channel=args.swap_color_channel)

     if args.make_validation_memfiles:
+        import bcolz
+
         # for saving memory usage during training
         # bin_files = ['agedb_30', 'cfp_fp', 'lfw', 'calfw', 'cfp_ff', 'cplfw', 'vgg2_fp']
         bin_files = list(filter(lambda x: os.path.splitext(x)[1] in ['.bin'], os.listdir(args.rec_path)))
         bin_files = [i.split('.')[0] for i in bin_files]
-
         for i in range(len(bin_files)):
+            print("dealing with ", bin_files[i], flush=True)
             load_bin(rec_path/(bin_files[i]+'.bin'), rec_path/bin_files[i])
diff --git a/data.py b/data.py
index 8ef327c..233b1ff 100644
--- a/data.py
+++ b/data.py
@@ -6,7 +6,7 @@
 import pandas as pd

 import evaluate_utils
 from dataset.image_folder_dataset import CustomImageFolderDataset
-from dataset.five_validation_dataset import FiveValidationDataset
+from dataset.five_validation_dataset import MultipleValidationDataset
 from dataset.record_dataset import AugmentRecordDataset

@@ -27,7 +27,6 @@ def __init__(self, **kwargs):
         self.photometric_augmentation_prob = kwargs['photometric_augmentation_prob']
         self.swap_color_channel = kwargs['swap_color_channel']
         self.use_mxrecord = kwargs['use_mxrecord']
-
         concat_mem_file_name = os.path.join(self.data_root, self.val_data_path, 'concat_validation_memfile')
         self.concat_mem_file_name = concat_mem_file_name
@@ -41,8 +40,16 @@ def prepare_data(self):
         if not os.path.isfile(self.concat_mem_file_name):
             # create a concat memfile
             concat = []
-            for key in ['agedb_30', 'cfp_fp', 'lfw', 'cplfw', 'calfw']:
-                np_array, issame = evaluate_utils.get_val_pair(path=os.path.join(self.data_root, self.val_data_path),
+            keys = []  # ['agedb_30', 'cfp_fp', 'lfw', 'cplfw', 'calfw']
+            validation_sets_path = os.path.join(self.data_root, self.val_data_path)
+            for file in os.listdir(validation_sets_path):
+                if file[-4:] == ".bin":
+                    key = ".".join(file.split(".")[:-1])
+                    if key not in keys:
+                        keys.append(key)
+            print("collected keys data.py ", keys)
+            for key in keys:
+                np_array, issame = evaluate_utils.get_val_pair(path=validation_sets_path,
                                                                name=key, use_memfile=False)
                 concat.append(np_array)
@@ -69,7 +76,6 @@ def setup(self, stage=None):
                 with open('assets/ms1mv2_train_subset_index.txt', 'r') as f:
                     subset_index = [int(i) for i in f.read().split(',')]
                 self.subset_ms1mv2_dataset(subset_index)
-
             print('creating val dataset')
             self.val_dataset = val_dataset(self.data_root, self.val_data_path, self.concat_mem_file_name)
@@ -127,6 +133,7 @@ def train_dataset(data_root, train_data_path,
                   output_dir):
     train_transform = transforms.Compose([
+        transforms.Resize((112,112), antialias=True),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
@@ -157,32 +164,15 @@ def train_dataset(data_root, train_data_path,

 def val_dataset(data_root, val_data_path, concat_mem_file_name):
     val_data = evaluate_utils.get_val_data(os.path.join(data_root, val_data_path))
-    # theses datasets are already normalized with mean 0.5, std 0.5
-    age_30, cfp_fp, lfw, age_30_issame, cfp_fp_issame, lfw_issame, cplfw, cplfw_issame, calfw, calfw_issame = val_data
-    val_data_dict = {
-        'agedb_30': (age_30, age_30_issame),
-        "cfp_fp": (cfp_fp, cfp_fp_issame),
-        "lfw": (lfw, lfw_issame),
-        "cplfw": (cplfw, cplfw_issame),
-        "calfw": (calfw, calfw_issame),
-    }
-    val_dataset = FiveValidationDataset(val_data_dict, concat_mem_file_name)
+    val_dataset = MultipleValidationDataset(val_data, concat_mem_file_name)
     return val_dataset

-def test_dataset(data_root, val_data_path, concat_mem_file_name):
+def test_dataset(data_root, val_data_path, concat_mem_file_name, output_dir=None):
     val_data = evaluate_utils.get_val_data(os.path.join(data_root, val_data_path))
     # theses datasets are already normalized with mean 0.5, std 0.5
-    age_30, cfp_fp, lfw, age_30_issame, cfp_fp_issame, lfw_issame, cplfw, cplfw_issame, calfw, calfw_issame = val_data
-    val_data_dict = {
-        'agedb_30': (age_30, age_30_issame),
-        "cfp_fp": (cfp_fp, cfp_fp_issame),
-        "lfw": (lfw, lfw_issame),
-        "cplfw": (cplfw, cplfw_issame),
-        "calfw": (calfw, calfw_issame),
-    }
-    val_dataset = FiveValidationDataset(val_data_dict, concat_mem_file_name)
+    val_dataset = MultipleValidationDataset(val_data, concat_mem_file_name)
     return val_dataset
diff --git a/dataset/five_validation_dataset.py b/dataset/five_validation_dataset.py
index 7ff9120..79aac10 100644
--- a/dataset/five_validation_dataset.py
+++ b/dataset/five_validation_dataset.py
@@ -3,16 +3,17 @@
 import evaluate_utils
 import torch

-class FiveValidationDataset(Dataset):
+class MultipleValidationDataset(Dataset):
     def __init__(self, val_data_dict, concat_mem_file_name):
         '''
-        concatenates all validation datasets from emore
+        concatenates all validation datasets from emore, for instance:
         val_data_dict = {
         'agedb_30': (agedb_30, agedb_30_issame),
         "cfp_fp": (cfp_fp, cfp_fp_issame),
         "lfw": (lfw, lfw_issame),
         "cplfw": (cplfw, cplfw_issame),
         "calfw": (calfw, calfw_issame),
+        ...
         }
         agedb_30: 0
         cfp_fp: 1
@@ -20,24 +21,33 @@ def __init__(self, val_data_dict, concat_mem_file_name):
         cplfw: 3
         calfw: 4
         '''
-        self.dataname_to_idx = {"agedb_30": 0, "cfp_fp": 1, "lfw": 2, "cplfw": 3, "calfw": 4}
-
+        self.dataname_to_idx = {}  # {"agedb_30": 0, "cfp_fp": 1, "lfw": 2, "cplfw": 3, "calfw": 4}
+        for i, key in enumerate(val_data_dict.keys()):
+            print(key, " is present")
+            self.dataname_to_idx[key] = i
+        next_idx = 5
         self.val_data_dict = val_data_dict
         # concat all dataset
         all_imgs = []
         all_issame = []
         all_dataname = []
         key_orders = []
+        print()
         for key, (imgs, issame) in val_data_dict.items():
             all_imgs.append(imgs)
             dup_issame = []  # hacky way to make the issame length same as imgs. [1, 1, 0, 0, ...]
             for same in issame:
                 dup_issame.append(same)
                 dup_issame.append(same)
+            assert len(dup_issame) == len(imgs), f"found {len(dup_issame)} labels for {len(imgs)} imgs. Please check the following dataset: {key}"
             all_issame.append(dup_issame)
+            if key not in self.dataname_to_idx.keys():
+                raise Exception("error met, keys inconsistency")
             all_dataname.append([self.dataname_to_idx[key]] * len(imgs))
+            key_orders.append(key)
-        assert key_orders == ['agedb_30', 'cfp_fp', 'lfw', 'cplfw', 'calfw']
+        # this assert is largely irrelevant since the switch to OrderedDict; remove it to get rid of the hardcoded dataset names
+        assert np.all([key in key_orders for key in ['agedb_30', 'cfp_fp', 'lfw', 'cplfw', 'calfw']]), "the hardcoded emore validation sets are expected to be present; remove this assert to drop that requirement"

         if isinstance(all_imgs[0], np.memmap):
             self.all_imgs = evaluate_utils.read_memmap(concat_mem_file_name)
@@ -47,8 +57,8 @@ def __init__(self, val_data_dict, concat_mem_file_name):

         self.all_issame = np.concatenate(all_issame)
         self.all_dataname = np.concatenate(all_dataname)
-        assert len(self.all_imgs) == len(self.all_issame)
-        assert len(self.all_issame) == len(self.all_dataname)
+        assert len(self.all_imgs) == len(self.all_issame), f"after concatenation {len(self.all_imgs)} images found vs {len(self.all_issame)} labels, maybe regenerate the memfiles"
+        assert len(self.all_issame) == len(self.all_dataname), f"after concatenation {len(self.all_dataname)} dataname associations found vs {len(self.all_issame)} labels"

     def __getitem__(self, index):
         x_np = self.all_imgs[index].copy()
diff --git a/evaluate_utils.py b/evaluate_utils.py
index 1ce7f29..45cdbe7 100644
--- a/evaluate_utils.py
+++ b/evaluate_utils.py
@@ -6,22 +6,22 @@
 from sklearn.model_selection import KFold
 from sklearn.decomposition import PCA
 import sklearn
+from collections import OrderedDict
 from scipy import interpolate


 def get_val_data(data_path):
-    agedb_30, agedb_30_issame = get_val_pair(data_path, 'agedb_30')
-    cfp_fp, cfp_fp_issame = get_val_pair(data_path, 'cfp_fp')
-    lfw, lfw_issame = get_val_pair(data_path, 'lfw')
-    cplfw, cplfw_issame = get_val_pair(data_path, 'cplfw')
-    calfw, calfw_issame = get_val_pair(data_path, 'calfw')
-    return agedb_30, cfp_fp, lfw, agedb_30_issame, cfp_fp_issame, lfw_issame, cplfw, cplfw_issame, calfw, calfw_issame
+    all_val_sets = list(filter(lambda elt: elt[-4:] == ".bin", os.listdir(data_path)))
+    returned_dict = OrderedDict()
+    for set_name in all_val_sets:
+        set_name = set_name[:-4]
+        returned_dict[set_name] = get_val_pair(data_path, set_name)
+    return returned_dict

 def get_val_pair(path, name, use_memfile=True):
     if use_memfile:
         mem_file_dir = os.path.join(path, name, 'memfile')
         mem_file_name = os.path.join(mem_file_dir, 'mem_file.dat')
         if os.path.isdir(mem_file_dir):
-            print('laoding validation data memfile')
             np_array = read_memmap(mem_file_name)
         else:
             os.makedirs(mem_file_dir)
diff --git a/train_val.py b/train_val.py
index abba4cc..1507601 100644
--- a/train_val.py
+++ b/train_val.py
@@ -3,6 +3,7 @@
 import torch.optim.lr_scheduler as lr_scheduler
 from pytorch_lightning.core import LightningModule
 from torch.nn import CrossEntropyLoss
+from torch.nn import Module as TorchModule
 import evaluate_utils
 import head
 import net
@@ -14,6 +15,8 @@ class Trainer(LightningModule):
     def __init__(self, **kwargs):
         super(Trainer, self).__init__()
         self.save_hyperparameters()  # sets self.hparams
+        self.val_dataname_to_idx = None
+        self.test_dataname_to_idx = None

         self.class_num = utils.get_num_class(self.hparams)
         print('classnum: {}'.format(self.class_num))
@@ -88,7 +91,7 @@ def training_epoch_end(self, outputs):
     def validation_step(self, batch, batch_idx):
         images, labels, dataname, image_index = batch
         embeddings, norms = self.model(images)
-
+
         fliped_images = torch.flip(images, dims=[3])
         flipped_embeddings, flipped_norms = self.model(fliped_images)
         stacked_embeddings = torch.stack([embeddings, flipped_embeddings], dim=0)
@@ -115,10 +118,12 @@ def validation_step(self, batch, batch_idx):
         }

     def validation_epoch_end(self, outputs):
         all_output_tensor, all_norm_tensor, all_target_tensor, all_dataname_tensor = self.gather_outputs(outputs)
-        dataname_to_idx = {"agedb_30": 0, "cfp_fp": 1, "lfw": 2, "cplfw": 3, "calfw": 4}
+
+        if self.val_dataname_to_idx is None:
+            self.val_dataname_to_idx = self.trainer.val_dataloaders[0].dataset.dataname_to_idx
+        dataname_to_idx = self.val_dataname_to_idx  # {"agedb_30": 0, "cfp_fp": 1, "lfw": 2, "cplfw": 3, "calfw": 4}
         idx_to_dataname = {val: key for key, val in dataname_to_idx.items()}
         val_logs = {}
         for dataname_idx in all_dataname_tensor.unique():
@@ -139,7 +144,6 @@ def validation_epoch_end(self, outputs):
             val_logs[f'{dataname}_val_acc'] for dataname in dataname_to_idx.keys() if f'{dataname}_val_acc' in val_logs
         ])
         val_logs['epoch'] = self.current_epoch
-
         for k, v in val_logs.items():
             # self.log(name=k, value=v, rank_zero_only=True)
             self.log(name=k, value=v)
@@ -150,10 +154,12 @@ def test_step(self, batch, batch_idx):
         return self.validation_step(batch, batch_idx)

     def test_epoch_end(self, outputs):
         all_output_tensor, all_norm_tensor, all_target_tensor, all_dataname_tensor = self.gather_outputs(outputs)
-        dataname_to_idx = {"agedb_30": 0, "cfp_fp": 1, "lfw": 2, "cplfw": 3, "calfw": 4}
+        if self.test_dataname_to_idx is None:
+            self.test_dataname_to_idx = self.trainer.test_dataloaders[0].dataset.dataname_to_idx
+
+        dataname_to_idx = self.test_dataname_to_idx  # {"agedb_30": 0, "cfp_fp": 1, "lfw": 2, "cplfw": 3, "calfw": 4}
         idx_to_dataname = {val: key for key, val in dataname_to_idx.items()}
         test_logs = {}
         for dataname_idx in all_dataname_tensor.unique():
diff --git a/utils.py b/utils.py
index a1fa154..2882d4c 100644
--- a/utils.py
+++ b/utils.py
@@ -3,7 +3,6 @@
 import torch
 import torch.distributed as dist

-
 class dotdict(dict):
     """dot.notation access to dictionary attributes"""
     __getattr__ = dict.get
@@ -37,14 +36,31 @@ def is_dist_avail_and_initialized():
     return True

 def get_world_size():
+    try:
+        return int(os.environ["SLURM_TASKS_PER_NODE"])
+    except Exception as e:
+        print("SLURM not detected for inferring world size (SLURM_TASKS_PER_NODE unset)", flush=True)
+        pass
     if not is_dist_avail_and_initialized():
         return 1
     return dist.get_world_size()

 def get_local_rank():
+    try:
+        return int(os.environ["SLURM_LOCALID"])
+    except Exception as e:
+        print("SLURM not detected for inferring local rank (SLURM_LOCALID unset)", flush=True)
+        pass
     if not is_dist_avail_and_initialized():
         return 0
-    return int(os.environ["LOCAL_RANK"])
+
+    try:
+        return int(os.environ["LOCAL_RANK"])
+    except Exception as e:
+        print("env variable LOCAL_RANK is not set, falling back to torch.distributed", flush=True)
+        print("local rank is ", dist.get_rank(), " and world size is ", get_world_size(), flush=True)
+        return int(dist.get_rank())
+

 def all_gather(data):
     """

From 6782b0e7351de4975d842c38d5ca249bce42abcc Mon Sep 17 00:00:00 2001
From: Alexandre Fournier Montgieux
Date: Wed, 20 Sep 2023 11:15:11 +0200
Subject: [PATCH 2/2] requirements update

---
 requirements.txt | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 9d15561..a75643c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,12 @@
 pytorch-lightning==1.8.6
-pytorch<=1.13.1
+torch<=1.13.1
 tqdm
 bcolz-zipline
 prettytable
 menpo
 mxnet
 opencv-python
+scikit-learn
+torchvision<=0.14.1
+pandas
+numpy<=1.23
\ No newline at end of file
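
Note on the SLURM fallbacks added to utils.py (a reviewer-style sketch, not part of the patch): SLURM_TASKS_PER_NODE is a per-node count and can take forms like "4(x2)" on multi-node jobs, so int() on it can raise or under-report the global world size, whereas SLURM_NTASKS is the total task count across all nodes. The sketch below shows one possible lookup order under those assumptions; the sketch_* function names are hypothetical and only mirror the helpers in utils.py for readability, while SLURM_NTASKS, SLURM_LOCALID, and LOCAL_RANK are the standard SLURM/torchrun variables.

    # Illustrative sketch only: SLURM-aware helpers with the global task count
    # read from SLURM_NTASKS and a torch.distributed fallback.
    import os
    import torch.distributed as dist

    def sketch_get_world_size():
        # SLURM_NTASKS is the total number of tasks across all nodes, which is
        # what a global world size should report. SLURM_TASKS_PER_NODE is per
        # node and may be formatted like "4(x2)", so parsing it with int() can fail.
        if "SLURM_NTASKS" in os.environ:
            return int(os.environ["SLURM_NTASKS"])
        if dist.is_available() and dist.is_initialized():
            return dist.get_world_size()
        return 1

    def sketch_get_local_rank():
        # SLURM_LOCALID is the task's rank on its own node; LOCAL_RANK is what
        # torchrun/Lightning export. dist.get_rank() is a global rank, so it
        # only equals the local rank on single-node jobs.
        for var in ("SLURM_LOCALID", "LOCAL_RANK"):
            if var in os.environ:
                return int(os.environ[var])
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
        return 0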