Multiple Validation Set Update #125

Open · wants to merge 2 commits into base: master
7 changes: 5 additions & 2 deletions convert.py
@@ -3,7 +3,6 @@
import mxnet as mx
from tqdm import tqdm
from PIL import Image
import bcolz
import pickle
import cv2
import numpy as np
@@ -49,9 +48,11 @@ def save_rec_to_img_dir(rec_path, swap_color_channel=False, save_as_png=False):
img.save(img_save_path, quality=95)

def load_bin(path, rootdir, image_size=[112,112]):
import os

test_transform = trans.Compose([
trans.ToTensor(),
trans.Resize(image_size, antialias=True),
trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
if not rootdir.exists():
@@ -87,11 +88,13 @@ def load_bin(path, rootdir, image_size=[112,112]):
save_rec_to_img_dir(rec_path, swap_color_channel=args.swap_color_channel)

if args.make_validation_memfiles:
import bcolz

# for saving memory usage during training
# bin_files = ['agedb_30', 'cfp_fp', 'lfw', 'calfw', 'cfp_ff', 'cplfw', 'vgg2_fp']
bin_files = list(filter(lambda x: os.path.splitext(x)[1] in ['.bin'], os.listdir(args.rec_path)))
bin_files = [i.split('.')[0] for i in bin_files]

for i in range(len(bin_files)):
print("dealing with ", bin_files[i], flush=True)
load_bin(rec_path/(bin_files[i]+'.bin'), rec_path/bin_files[i])
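
For reference, a minimal standalone sketch (not part of this diff) of what the dynamic discovery above does: every *.bin file in the validation folder yields one set name. discover_bin_keys is a hypothetical helper; the PR derives the names inline with split('.').

import os

def discover_bin_keys(rec_path):
    # keep only files whose extension is exactly .bin
    bin_files = [f for f in os.listdir(rec_path) if os.path.splitext(f)[1] == '.bin']
    # drop the extension to recover the validation set name, e.g. 'cfp_fp.bin' -> 'cfp_fp'
    return [os.path.splitext(f)[0] for f in bin_files]

# e.g. discover_bin_keys('faces_emore') might return ['agedb_30', 'cfp_fp', 'lfw']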

40 changes: 15 additions & 25 deletions data.py
@@ -6,7 +6,7 @@
import pandas as pd
import evaluate_utils
from dataset.image_folder_dataset import CustomImageFolderDataset
from dataset.five_validation_dataset import FiveValidationDataset
from dataset.five_validation_dataset import MultipleValidationDataset
from dataset.record_dataset import AugmentRecordDataset


@@ -27,7 +27,6 @@ def __init__(self, **kwargs):
self.photometric_augmentation_prob = kwargs['photometric_augmentation_prob']
self.swap_color_channel = kwargs['swap_color_channel']
self.use_mxrecord = kwargs['use_mxrecord']

concat_mem_file_name = os.path.join(self.data_root, self.val_data_path, 'concat_validation_memfile')
self.concat_mem_file_name = concat_mem_file_name

@@ -41,8 +40,16 @@ def prepare_data(self):
if not os.path.isfile(self.concat_mem_file_name):
# create a concat memfile
concat = []
for key in ['agedb_30', 'cfp_fp', 'lfw', 'cplfw', 'calfw']:
np_array, issame = evaluate_utils.get_val_pair(path=os.path.join(self.data_root, self.val_data_path),
keys = []  # ['agedb_30', 'cfp_fp', 'lfw', 'cplfw', 'calfw']
validation_sets_path = os.path.join(self.data_root, self.val_data_path)
for file in os.listdir(validation_sets_path):
if file[-4:] == ".bin":
key = ".".join(file.split(".")[:-1])
if key not in keys:
keys.append(key)
print("collected validation keys in data.py:", keys)
for key in keys:
np_array, issame = evaluate_utils.get_val_pair(path=validation_sets_path,
name=key,
use_memfile=False)
concat.append(np_array)
@@ -69,7 +76,6 @@ def setup(self, stage=None):
with open('assets/ms1mv2_train_subset_index.txt', 'r') as f:
subset_index = [int(i) for i in f.read().split(',')]
self.subset_ms1mv2_dataset(subset_index)

print('creating val dataset')
self.val_dataset = val_dataset(self.data_root, self.val_data_path, self.concat_mem_file_name)

@@ -127,6 +133,7 @@ def train_dataset(data_root, train_data_path,
output_dir):

train_transform = transforms.Compose([
transforms.Resize((112,112), antialias=True),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
@@ -157,32 +164,15 @@ def train_dataset(data_root, train_data_path,

def val_dataset(data_root, val_data_path, concat_mem_file_name):
val_data = evaluate_utils.get_val_data(os.path.join(data_root, val_data_path))
# theses datasets are already normalized with mean 0.5, std 0.5
age_30, cfp_fp, lfw, age_30_issame, cfp_fp_issame, lfw_issame, cplfw, cplfw_issame, calfw, calfw_issame = val_data
val_data_dict = {
'agedb_30': (age_30, age_30_issame),
"cfp_fp": (cfp_fp, cfp_fp_issame),
"lfw": (lfw, lfw_issame),
"cplfw": (cplfw, cplfw_issame),
"calfw": (calfw, calfw_issame),
}
val_dataset = FiveValidationDataset(val_data_dict, concat_mem_file_name)
val_dataset = MultipleValidationDataset(val_data, concat_mem_file_name)

return val_dataset


def test_dataset(data_root, val_data_path, concat_mem_file_name):
def test_dataset(data_root, val_data_path, concat_mem_file_name, output_dir = None):
val_data = evaluate_utils.get_val_data(os.path.join(data_root, val_data_path))
# theses datasets are already normalized with mean 0.5, std 0.5
age_30, cfp_fp, lfw, age_30_issame, cfp_fp_issame, lfw_issame, cplfw, cplfw_issame, calfw, calfw_issame = val_data
val_data_dict = {
'agedb_30': (age_30, age_30_issame),
"cfp_fp": (cfp_fp, cfp_fp_issame),
"lfw": (lfw, lfw_issame),
"cplfw": (cplfw, cplfw_issame),
"calfw": (calfw, calfw_issame),
}
val_dataset = FiveValidationDataset(val_data_dict, concat_mem_file_name)
val_dataset = MultipleValidationDataset(val_data, concat_mem_file_name)
return val_dataset
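
As a rough, hedged illustration of the concat memfile that prepare_data() builds (the exact on-disk format lives in evaluate_utils and is only assumed here): the per-set image arrays are concatenated once and written to a single memory-mappable file, so validation does not keep every .bin fully resident in RAM.

import numpy as np

# hypothetical per-set arrays standing in for get_val_pair() outputs
arrays = [np.zeros((4, 3, 112, 112), dtype=np.float32),
          np.zeros((6, 3, 112, 112), dtype=np.float32)]
concat = np.concatenate(arrays, axis=0)              # shape (10, 3, 112, 112)

mm = np.memmap('concat_validation_memfile', dtype=concat.dtype, mode='w+', shape=concat.shape)
mm[:] = concat
mm.flush()                                           # read back later with np.memmap(..., mode='r')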


24 changes: 17 additions & 7 deletions dataset/five_validation_dataset.py
@@ -3,41 +3,51 @@
import evaluate_utils
import torch

class FiveValidationDataset(Dataset):
class MultipleValidationDataset(Dataset):
def __init__(self, val_data_dict, concat_mem_file_name):
'''
concatenates all validation datasets from emore
concatenates all validation datasets from emore, for instance:
val_data_dict = {
'agedb_30': (agedb_30, agedb_30_issame),
"cfp_fp": (cfp_fp, cfp_fp_issame),
"lfw": (lfw, lfw_issame),
"cplfw": (cplfw, cplfw_issame),
"calfw": (calfw, calfw_issame),
...
}
agedb_30: 0
cfp_fp: 1
lfw: 2
cplfw: 3
calfw: 4
'''
self.dataname_to_idx = {"agedb_30": 0, "cfp_fp": 1, "lfw": 2, "cplfw": 3, "calfw": 4}

self.dataname_to_idx = {}  # previously hardcoded as {"agedb_30": 0, "cfp_fp": 1, "lfw": 2, "cplfw": 3, "calfw": 4}
for i, key in enumerate(val_data_dict.keys()):
print(key, " is present")
self.dataname_to_idx[key] = i
next_idx = 5
self.val_data_dict = val_data_dict
# concat all dataset
all_imgs = []
all_issame = []
all_dataname = []
key_orders = []
print()
for key, (imgs, issame) in val_data_dict.items():
all_imgs.append(imgs)
dup_issame = [] # hacky way to make the issame length same as imgs. [1, 1, 0, 0, ...]
for same in issame:
dup_issame.append(same)
dup_issame.append(same)
assert len(dup_issame) == len(imgs), f"found {len(dup_issame)} labels for {len(imgs)} imgs. Please check the following dataset: {key}"
all_issame.append(dup_issame)
if key not in self.dataname_to_idx:
raise Exception(f"inconsistent keys: {key} is missing from dataname_to_idx")
all_dataname.append([self.dataname_to_idx[key]] * len(imgs))

key_orders.append(key)
assert key_orders == ['agedb_30', 'cfp_fp', 'lfw', 'cplfw', 'calfw']
# this assert is largely irrelevant since the switch to OrderedDict, but it is kept for now; remove it to drop the hardcoded set names
assert np.all([key in key_orders for key in ['agedb_30', 'cfp_fp', 'lfw', 'cplfw', 'calfw']]), "hardcoded five validation sets expected; remove this assert to allow arbitrary sets"

if isinstance(all_imgs[0], np.memmap):
self.all_imgs = evaluate_utils.read_memmap(concat_mem_file_name)
@@ -47,8 +57,8 @@ def __init__(self, val_data_dict, concat_mem_file_name):
self.all_issame = np.concatenate(all_issame)
self.all_dataname = np.concatenate(all_dataname)

assert len(self.all_imgs) == len(self.all_issame)
assert len(self.all_issame) == len(self.all_dataname)
assert len(self.all_imgs) == len(self.all_issame), f"after concatenation {len(self.all_imgs)} images found vs {len(self.all_issame)} labels, maybe regenerate the memfiles"
assert len(self.all_issame) == len(self.all_dataname), f"after concatenation {len(self.all_dataname)} dataname associations found vs {len(self.all_issame)} labels"

def __getitem__(self, index):
x_np = self.all_imgs[index].copy()
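A small standalone sketch of the label duplication done in __init__ above: each issame flag describes an image pair, so it is repeated once per image to keep labels and images aligned (values below are made up).

issame = [True, False, True]          # one flag per pair of consecutive images
dup_issame = []
for same in issame:
    dup_issame.append(same)
    dup_issame.append(same)
assert len(dup_issame) == 2 * len(issame)
print(dup_issame)                     # [True, True, False, False, True, True]
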
14 changes: 7 additions & 7 deletions evaluate_utils.py
@@ -6,22 +6,22 @@
from sklearn.model_selection import KFold
from sklearn.decomposition import PCA
import sklearn
from collections import OrderedDict
from scipy import interpolate

def get_val_data(data_path):
agedb_30, agedb_30_issame = get_val_pair(data_path, 'agedb_30')
cfp_fp, cfp_fp_issame = get_val_pair(data_path, 'cfp_fp')
lfw, lfw_issame = get_val_pair(data_path, 'lfw')
cplfw, cplfw_issame = get_val_pair(data_path, 'cplfw')
calfw, calfw_issame = get_val_pair(data_path, 'calfw')
return agedb_30, cfp_fp, lfw, agedb_30_issame, cfp_fp_issame, lfw_issame, cplfw, cplfw_issame, calfw, calfw_issame
all_val_sets = list(filter(lambda elt : elt[-4:] == ".bin", os.listdir(data_path)))
returned_dict = OrderedDict()
for set_name in all_val_sets:
set_name = set_name[:-4]
returned_dict[set_name] = get_val_pair(data_path, set_name)
return returned_dict

def get_val_pair(path, name, use_memfile=True):
if use_memfile:
mem_file_dir = os.path.join(path, name, 'memfile')
mem_file_name = os.path.join(mem_file_dir, 'mem_file.dat')
if os.path.isdir(mem_file_dir):
print('loading validation data memfile')
np_array = read_memmap(mem_file_name)
else:
os.makedirs(mem_file_dir)
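Because get_val_data() now returns an OrderedDict, insertion order is what determines the integer index each set later receives in MultipleValidationDataset; a hedged sketch with placeholder entries:

from collections import OrderedDict

val_data = OrderedDict()
val_data['agedb_30'] = ('images...', 'issame...')    # placeholders for (np_array, issame)
val_data['cfp_fp'] = ('images...', 'issame...')

dataname_to_idx = {name: i for i, name in enumerate(val_data)}
print(dataname_to_idx)                               # {'agedb_30': 0, 'cfp_fp': 1}
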
6 changes: 5 additions & 1 deletion requirements.txt
@@ -1,8 +1,12 @@
pytorch-lightning==1.8.6
pytorch<=1.13.1
torch<=1.13.1
tqdm
bcolz-zipline
prettytable
menpo
mxnet
opencv-python
scikit-learn
torchvision<=0.14.1
pandas
numpy<=1.23
18 changes: 12 additions & 6 deletions train_val.py
@@ -3,6 +3,7 @@
import torch.optim.lr_scheduler as lr_scheduler
from pytorch_lightning.core import LightningModule
from torch.nn import CrossEntropyLoss
from torch.nn import Module as TorchModule
import evaluate_utils
import head
import net
@@ -14,6 +15,8 @@ class Trainer(LightningModule):
def __init__(self, **kwargs):
super(Trainer, self).__init__()
self.save_hyperparameters() # sets self.hparams
self.val_dataname_to_idx = None
self.test_dataname_to_idx = None

self.class_num = utils.get_num_class(self.hparams)
print('classnum: {}'.format(self.class_num))
@@ -88,7 +91,7 @@ def training_epoch_end(self, outputs):
def validation_step(self, batch, batch_idx):
images, labels, dataname, image_index = batch
embeddings, norms = self.model(images)

fliped_images = torch.flip(images, dims=[3])
flipped_embeddings, flipped_norms = self.model(fliped_images)
stacked_embeddings = torch.stack([embeddings, flipped_embeddings], dim=0)
@@ -115,10 +118,12 @@ def validation_step(self, batch, batch_idx):
}

def validation_epoch_end(self, outputs):

all_output_tensor, all_norm_tensor, all_target_tensor, all_dataname_tensor = self.gather_outputs(outputs)

dataname_to_idx = {"agedb_30": 0, "cfp_fp": 1, "lfw": 2, "cplfw": 3, "calfw": 4}

if self.val_dataname_to_idx is None:
self.val_dataname_to_idx = self.trainer.val_dataloaders[0].dataset.dataname_to_idx
dataname_to_idx = self.val_dataname_to_idx  # previously {"agedb_30": 0, "cfp_fp": 1, "lfw": 2, "cplfw": 3, "calfw": 4}
idx_to_dataname = {val: key for key, val in dataname_to_idx.items()}
val_logs = {}
for dataname_idx in all_dataname_tensor.unique():
@@ -139,7 +144,6 @@
val_logs[f'{dataname}_val_acc'] for dataname in dataname_to_idx.keys() if f'{dataname}_val_acc' in val_logs
])
val_logs['epoch'] = self.current_epoch

for k, v in val_logs.items():
# self.log(name=k, value=v, rank_zero_only=True)
self.log(name=k, value=v)
@@ -150,10 +154,12 @@ def test_step(self, batch, batch_idx):
return self.validation_step(batch, batch_idx)

def test_epoch_end(self, outputs):

all_output_tensor, all_norm_tensor, all_target_tensor, all_dataname_tensor = self.gather_outputs(outputs)

dataname_to_idx = {"agedb_30": 0, "cfp_fp": 1, "lfw": 2, "cplfw": 3, "calfw": 4}
if self.test_dataname_to_idx is None:
self.test_dataname_to_idx = self.trainer.test_dataloaders[0].dataset.dataname_to_idx

dataname_to_idx = self.test_dataname_to_idx # {"agedb_30": 0, "cfp_fp": 1, "lfw": 2, "cplfw": 3, "calfw": 4}
idx_to_dataname = {val: key for key, val in dataname_to_idx.items()}
test_logs = {}
for dataname_idx in all_dataname_tensor.unique():
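For clarity, a toy sketch of how the epoch-end hooks use the mapping pulled from the attached dataloader's dataset to split the gathered outputs per validation set (the tensors below are fabricated):

import torch

dataname_to_idx = {'lfw': 0, 'cfp_fp': 1}             # would come from dataset.dataname_to_idx
idx_to_dataname = {v: k for k, v in dataname_to_idx.items()}

all_dataname = torch.tensor([0, 0, 1, 1, 1])           # per-sample dataset index
for idx in all_dataname.unique():
    mask = all_dataname == idx
    print(idx_to_dataname[int(idx)], int(mask.sum()), "samples")
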
20 changes: 18 additions & 2 deletions utils.py
@@ -3,7 +3,6 @@
import torch
import torch.distributed as dist


class dotdict(dict):
"""dot.notation access to dictionary attributes"""
__getattr__ = dict.get
@@ -37,14 +36,31 @@ def is_dist_avail_and_initialized():
return True

def get_world_size():
try:
return int(os.environ["SLURM_TASKS_PER_NODE"])
except Exception as e:
print("SLURM not detected for inferring world size from SLURM_TASKS_PER_NODE", flush=True)
pass
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()

def get_local_rank():
try:
return int(os.environ["SLURM_LOCALID"])
except Exception as e:
print("SLURM not detected for inferring local rank from SLURM_LOCALID", flush=True)
pass
if not is_dist_avail_and_initialized():
return 0
return int(os.environ["LOCAL_RANK"])

try:
return int(os.environ["LOCAL_RANK"])
except Exception as e:
print("env variable LOCAL_RANK is not set, inferring it through dist", flush=True)
print("local rank is", dist.get_rank(), "and world size is", get_world_size(), flush=True)
return int(dist.get_rank())


def all_gather(data):
"""
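A defensive sketch of the SLURM fallback idea used above, assuming SLURM_TASKS_PER_NODE holds a plain integer (on heterogeneous multi-node jobs it can look like '4(x2)', which this sketch does not parse):

import os

def world_size_from_slurm(default=1):
    value = os.environ.get("SLURM_TASKS_PER_NODE")    # None when not running under SLURM
    try:
        return int(value)
    except (TypeError, ValueError):
        return default                                # fall back to torch.distributed or a single process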