Commit 14924fc: initial
Tengfei-Wang committed Oct 31, 2021
1 parent 5c66e66
Showing 81 changed files with 3,278 additions and 0 deletions.
Empty file added configs/__init__.py
20 changes: 20 additions & 0 deletions configs/data_configs.py
@@ -0,0 +1,20 @@
from configs import transforms_config
from configs.paths_config import dataset_paths


DATASETS = {
'ffhq_encode': {
'transforms': transforms_config.EncodeTransforms,
'train_source_root': dataset_paths['ffhq'],
'train_target_root': dataset_paths['ffhq'],
'test_source_root': dataset_paths['ffhq_val'],
'test_target_root': dataset_paths['ffhq_val'],
},
'cars_encode': {
'transforms': transforms_config.CarsEncodeTransforms,
'train_source_root': dataset_paths['cars_train'],
'train_target_root': dataset_paths['cars_train'],
'test_source_root': dataset_paths['cars_val'],
'test_target_root': dataset_paths['cars_val'],
}
}
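For context, a minimal sketch (not part of this commit) of how a training script might look up an entry in DATASETS; the load_dataset_args helper and its default dataset type are illustrative only:

from configs.data_configs import DATASETS

def load_dataset_args(dataset_type='ffhq_encode'):
    # Fail early on an unknown dataset type rather than with a bare KeyError.
    if dataset_type not in DATASETS:
        raise ValueError('{} is not a valid dataset_type'.format(dataset_type))
    # Each entry bundles a TransformsConfig class with source/target roots.
    return DATASETS[dataset_type]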
16 changes: 16 additions & 0 deletions configs/paths_config.py
@@ -0,0 +1,16 @@
dataset_paths = {
# Face Datasets (FFHQ - train, CelebA-HQ - test)
'ffhq': '',
'ffhq_val': '',

# Cars Dataset (Stanford cars)
'cars_train': '',
'cars_val': '',
}

model_paths = {
'stylegan_ffhq': './pretrained/stylegan2-ffhq-config-f.pt',
'ir_se50': './pretrained/model_ir_se50.pth',
'shape_predictor': './pretrained/shape_predictor_68_face_landmarks.dat',
'moco': './pretrained/moco_v2_800ep_pretrain.pt'
}
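The dataset paths above are deliberately left empty and must be filled in locally. A small sanity-check sketch (not part of the commit), assuming only that the checkpoints live at the listed relative paths:

import os
from configs.paths_config import model_paths

# Warn about any pretrained checkpoint that has not been downloaded yet.
missing = [name for name, path in model_paths.items() if not os.path.isfile(path)]
if missing:
    print('Missing pretrained files: ' + ', '.join(missing))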
62 changes: 62 additions & 0 deletions configs/transforms_config.py
@@ -0,0 +1,62 @@
from abc import abstractmethod
import torchvision.transforms as transforms


class TransformsConfig(object):

def __init__(self, opts):
self.opts = opts

@abstractmethod
def get_transforms(self):
pass


class EncodeTransforms(TransformsConfig):

def __init__(self, opts):
super(EncodeTransforms, self).__init__(opts)

def get_transforms(self):
transforms_dict = {
'transform_gt_train': transforms.Compose([
transforms.Resize((256, 256)),
transforms.RandomHorizontalFlip(0.5),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_source': None,
'transform_test': transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_inference': transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
}
return transforms_dict


class CarsEncodeTransforms(TransformsConfig):

def __init__(self, opts):
super(CarsEncodeTransforms, self).__init__(opts)

def get_transforms(self):
transforms_dict = {
'transform_gt_train': transforms.Compose([
transforms.Resize((192, 256)),
transforms.RandomHorizontalFlip(0.5),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_source': None,
'transform_test': transforms.Compose([
transforms.Resize((192, 256)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_inference': transforms.Compose([
transforms.Resize((192, 256)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
}
return transforms_dict
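A minimal usage sketch; 'face.jpg' is a hypothetical input path, and opts can be None since neither transform class reads it:

from PIL import Image
from configs.transforms_config import EncodeTransforms

transforms_dict = EncodeTransforms(opts=None).get_transforms()
img = Image.open('face.jpg').convert('RGB')      # hypothetical input image
tensor = transforms_dict['transform_test'](img)  # shape (3, 256, 256), values in [-1, 1]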
Empty file added criteria/__init__.py
47 changes: 47 additions & 0 deletions criteria/id_loss.py
@@ -0,0 +1,47 @@
import torch
from torch import nn
from configs.paths_config import model_paths
from models.encoders.model_irse import Backbone


class IDLoss(nn.Module):
def __init__(self):
super(IDLoss, self).__init__()
print('Loading ResNet ArcFace')
self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
self.facenet.load_state_dict(torch.load(model_paths['ir_se50']))
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
self.facenet.eval()
for module in [self.facenet, self.face_pool]:
for param in module.parameters():
param.requires_grad = False

def extract_feats(self, x):
		x = x[:, :, 35:223, 32:220]  # Crop the central face region (assumes 256x256 input) before pooling to 112x112
x = self.face_pool(x)
x_feats = self.facenet(x)
return x_feats

def forward(self, y_hat, y, x):
n_samples = x.shape[0]
x_feats = self.extract_feats(x)
		y_feats = self.extract_feats(y)  # target features (detached below so no gradient flows to the target)
y_hat_feats = self.extract_feats(y_hat)
y_feats = y_feats.detach()
loss = 0
sim_improvement = 0
id_logs = []
count = 0
for i in range(n_samples):
diff_target = y_hat_feats[i].dot(y_feats[i])
diff_input = y_hat_feats[i].dot(x_feats[i])
diff_views = y_feats[i].dot(x_feats[i])
id_logs.append({'diff_target': float(diff_target),
'diff_input': float(diff_input),
'diff_views': float(diff_views)})
loss += 1 - diff_target
id_diff = float(diff_target) - float(diff_views)
sim_improvement += id_diff
count += 1

return loss / count, sim_improvement / count, id_logs
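A hedged usage sketch with dummy tensors; the crop at [35:223, 32:220] implies 256x256 inputs, and the ir_se50 checkpoint must be present under ./pretrained:

import torch
from criteria.id_loss import IDLoss

id_loss = IDLoss().cuda()                  # loads ./pretrained/model_ir_se50.pth
x = torch.randn(4, 3, 256, 256).cuda()     # source images (dummy data)
y, y_hat = x.clone(), torch.randn_like(x)  # target and reconstruction (dummy data)
loss, sim_improvement, id_logs = id_loss(y_hat, y, x)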
Empty file added criteria/lpips/__init__.py
Binary file added criteria/lpips/__pycache__/__init__.cpython-36.pyc (not shown)
Binary file added criteria/lpips/__pycache__/lpips.cpython-36.pyc (not shown)
Binary file added (name not shown)
Binary file added criteria/lpips/__pycache__/utils.cpython-36.pyc (not shown)
35 changes: 35 additions & 0 deletions criteria/lpips/lpips.py
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn

from criteria.lpips.networks import get_network, LinLayers
from criteria.lpips.utils import get_state_dict


class LPIPS(nn.Module):
r"""Creates a criterion that measures
Learned Perceptual Image Patch Similarity (LPIPS).
Arguments:
net_type (str): the network type to compare the features:
'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
version (str): the version of LPIPS. Default: 0.1.
"""
def __init__(self, net_type: str = 'alex', version: str = '0.1'):

        assert version in ['0.1'], 'only v0.1 is supported for now'

super(LPIPS, self).__init__()

# pretrained network
self.net = get_network(net_type).to("cuda")

# linear layers
self.lin = LinLayers(self.net.n_channels_list).to("cuda")
self.lin.load_state_dict(get_state_dict(net_type, version))

def forward(self, x: torch.Tensor, y: torch.Tensor):
feat_x, feat_y = self.net(x), self.net(y)

diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]

return torch.sum(torch.cat(res, 0)) / x.shape[0]
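A short sketch of the criterion in use; the constructor pins both the backbone and the linear heads to CUDA, so inputs must be CUDA tensors (in [-1, 1], matching the transforms elsewhere in this commit):

import torch
from criteria.lpips.lpips import LPIPS

lpips = LPIPS(net_type='alex')                 # fetches the v0.1 linear weights on first use
a = torch.rand(2, 3, 256, 256).cuda() * 2 - 1  # dummy images in [-1, 1]
b = torch.rand(2, 3, 256, 256).cuda() * 2 - 1
distance = lpips(a, b)                         # scalar distance, averaged over the batch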
98 changes: 98 additions & 0 deletions criteria/lpips/networks.py
@@ -0,0 +1,98 @@
from typing import Sequence

from itertools import chain

import torch
import torch.nn as nn
from torchvision import models

from criteria.lpips.utils import normalize_activation


def get_network(net_type: str):
if net_type == 'alex':
return AlexNet()
elif net_type == 'squeeze':
return SqueezeNet()
elif net_type == 'vgg':
return VGG16()
else:
raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')


class LinLayers(nn.ModuleList):
def __init__(self, n_channels_list: Sequence[int]):
super(LinLayers, self).__init__([
nn.Sequential(
nn.Identity(),
nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
) for nc in n_channels_list
])

for param in self.parameters():
param.requires_grad = False


class BaseNet(nn.Module):
def __init__(self):
super(BaseNet, self).__init__()

# register buffer
self.register_buffer(
'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
self.register_buffer(
'std', torch.Tensor([.458, .448, .450])[None, :, None, None])

def set_requires_grad(self, state: bool):
for param in chain(self.parameters(), self.buffers()):
param.requires_grad = state

def z_score(self, x: torch.Tensor):
return (x - self.mean) / self.std

def forward(self, x: torch.Tensor):
x = self.z_score(x)

output = []
for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
x = layer(x)
if i in self.target_layers:
output.append(normalize_activation(x))
if len(output) == len(self.target_layers):
break
return output


class SqueezeNet(BaseNet):
def __init__(self):
super(SqueezeNet, self).__init__()

self.layers = models.squeezenet1_1(True).features
self.target_layers = [2, 5, 8, 10, 11, 12, 13]
self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]

self.set_requires_grad(False)


class AlexNet(BaseNet):
def __init__(self):
super(AlexNet, self).__init__()

model = models.alexnet(True)
#model.load_state_dict(torch.load('./pretrained/alexnet-owt-4df8aa71.pth'))
self.layers = model.features
self.target_layers = [2, 5, 8, 10, 12]
self.n_channels_list = [64, 192, 384, 256, 256]

self.set_requires_grad(False)


class VGG16(BaseNet):
def __init__(self):
super(VGG16, self).__init__()

self.layers = models.vgg16(True).features
self.target_layers = [4, 9, 16, 23, 30]
self.n_channels_list = [64, 128, 256, 512, 512]

self.set_requires_grad(False)
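The backbones can also be exercised on their own; a sketch assuming torchvision can download its pretrained weights:

import torch
from criteria.lpips.networks import get_network

net = get_network('vgg')                # frozen VGG16 feature extractor
x = torch.rand(1, 3, 224, 224) * 2 - 1  # dummy image in [-1, 1]
feats = net(x)                          # one unit-normalized activation map per target layer
assert len(feats) == len(net.target_layers)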
31 changes: 31 additions & 0 deletions criteria/lpips/utils.py
@@ -0,0 +1,31 @@
from collections import OrderedDict

import torch


def normalize_activation(x, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
return x / (norm_factor + eps)


def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
# build url
url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
+ f'master/lpips/weights/v{version}/{net_type}.pth'

# download
old_state_dict = torch.hub.load_state_dict_from_url(
url, progress=True,
map_location=None if torch.cuda.is_available() else torch.device('cpu')
)

# rename keys
new_state_dict = OrderedDict()
for key, val in old_state_dict.items():
new_key = key
new_key = new_key.replace('lin', '')
new_key = new_key.replace('model.', '')
new_state_dict[new_key] = val

return new_state_dict
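The renaming above strips the upstream checkpoint prefixes so keys line up with LinLayers in networks.py; for illustration (the example key follows the upstream naming, to the best of my understanding):

# 'lin0.model.1.weight' names entry 0 of the ModuleList and index 1 (the Conv2d)
# of its inner Sequential once the 'lin' and 'model.' fragments are stripped.
key = 'lin0.model.1.weight'
print(key.replace('lin', '').replace('model.', ''))  # -> '0.1.weight'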

71 changes: 71 additions & 0 deletions criteria/moco_loss.py
@@ -0,0 +1,71 @@
import torch
from torch import nn
import torch.nn.functional as F

from configs.paths_config import model_paths


class MocoLoss(nn.Module):

def __init__(self, opts):
super(MocoLoss, self).__init__()
print("Loading MOCO model from path: {}".format(model_paths["moco"]))
self.model = self.__load_model()
self.model.eval()
for param in self.model.parameters():
param.requires_grad = False

@staticmethod
def __load_model():
import torchvision.models as models
model = models.__dict__["resnet50"]()
# freeze all layers but the last fc
for name, param in model.named_parameters():
if name not in ['fc.weight', 'fc.bias']:
param.requires_grad = False
checkpoint = torch.load(model_paths['moco'], map_location="cpu")
state_dict = checkpoint['state_dict']
# rename moco pre-trained keys
for k in list(state_dict.keys()):
# retain only encoder_q up to before the embedding layer
if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
# remove prefix
state_dict[k[len("module.encoder_q."):]] = state_dict[k]
# delete renamed or unused k
del state_dict[k]
msg = model.load_state_dict(state_dict, strict=False)
assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
# remove output layer
model = nn.Sequential(*list(model.children())[:-1]).cuda()
return model

def extract_feats(self, x):
x = F.interpolate(x, size=224)
x_feats = self.model(x)
x_feats = nn.functional.normalize(x_feats, dim=1)
		x_feats = x_feats.squeeze(-1).squeeze(-1)  # (N, 2048, 1, 1) -> (N, 2048); keeps the batch dim even when N == 1
return x_feats

def forward(self, y_hat, y, x):
n_samples = x.shape[0]
x_feats = self.extract_feats(x)
y_feats = self.extract_feats(y)
y_hat_feats = self.extract_feats(y_hat)
y_feats = y_feats.detach()
loss = 0
sim_improvement = 0
sim_logs = []
count = 0
for i in range(n_samples):
diff_target = y_hat_feats[i].dot(y_feats[i])
diff_input = y_hat_feats[i].dot(x_feats[i])
diff_views = y_feats[i].dot(x_feats[i])
sim_logs.append({'diff_target': float(diff_target),
'diff_input': float(diff_input),
'diff_views': float(diff_views)})
loss += 1 - diff_target
sim_diff = float(diff_target) - float(diff_views)
sim_improvement += sim_diff
count += 1

return loss / count, sim_improvement / count, sim_logs
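A usage sketch mirroring IDLoss; opts is stored by the constructor but never read, inputs are resized to 224 internally, and the MoCo v2 checkpoint must sit under ./pretrained:

import torch
from criteria.moco_loss import MocoLoss

moco_loss = MocoLoss(opts=None)            # loads ./pretrained/moco_v2_800ep_pretrain.pt
x = torch.randn(4, 3, 256, 256).cuda()     # source images (dummy data)
y, y_hat = x.clone(), torch.randn_like(x)  # target and reconstruction (dummy data)
loss, sim_improvement, sim_logs = moco_loss(y_hat, y, x)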
Empty file added datasets/__init__.py
(Remaining changed files not shown.)