diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000..ff850b5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,13 @@ +checkpoints/ +datasets/ +results/ +*.tar.gz +*.pth +*.zip +*.pkl +*.pyc +*/__pycache__/ +*_example/ +visual_results/ +test_imgs/ +web/ diff --git a/README.md b/README.md index 0c6c6b8..2d6b6c8 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,51 @@ # Open-Edit: Open-Domain Image Manipulation with Open-Vocabulary Instructions -[Xihui Liu](https://xh-liu.github.io), Zhe Lin, Jianming Zhang, Handong Zhao, Quan Tran, Xiaogang Wang, and Hongsheng Li.
+[Xihui Liu](https://xh-liu.github.io), [Zhe Lin](https://sites.google.com/site/zhelin625/), [Jianming Zhang](http://cs-people.bu.edu/jmzhang/), [Handong Zhao](https://hdzhao.github.io/), [Quan Tran](https://research.adobe.com/person/quan-hung-tran/), [Xiaogang Wang](https://www.ee.cuhk.edu.hk/~xgwang/), and [Hongsheng Li](https://www.ee.cuhk.edu.hk/~hsli/).
Published in ECCV 2020. -### [Paper](https://arxiv.org/pdf/2008.01576.pdf) | [1-minute video](https://youtu.be/8E3bwvjCHYE) +### [Paper](https://arxiv.org/pdf/2008.01576.pdf) | [1-minute video](https://youtu.be/8E3bwvjCHYE) | [Slides](https://drive.google.com/file/d/1m3JKSUotm6sRImak_qjwBMtMtd037XeK/view?usp=sharing) ![results](results.jpg) -### Code Coming Soon! +### Installation +Clone this repo. +```bash +git clone https://github.com/xh-liu/Open-Edit +cd Open-Edit +``` + +Install [PyTorch 1.1+](https://pytorch.org/get-started/locally/) and other requirements. +```bash + +pip install -r requirements.txt +``` + +### Download pretrained models + +Download the pretrained models from [Google Drive](https://drive.google.com/drive/folders/1iG_II7_PytTY6NdzyZ5WDkzPTXcB2NcE?usp=sharing) and put them under the `checkpoints` folder. + +### Data preparation + +We use the [Conceptual Captions dataset](https://ai.google.com/research/ConceptualCaptions/download) for training. Download the dataset and put it under the `datasets` folder. You can also use other datasets. + +### Training + +The visual-semantic embedding model is trained with [VSE++](https://github.com/fartashf/vsepp). + +The image decoder is trained with: + +```bash +bash train.sh +``` + +### Testing + +You can specify the image path and text instructions in `test.sh`. + +```bash +bash test.sh +``` ### Citation If you use this code for your research, please cite our papers. diff --git a/data/__init__.py b/data/__init__.py new file mode 100755 index 0000000..ae1e636 --- /dev/null +++ b/data/__init__.py @@ -0,0 +1,60 @@ +import importlib +import torch.utils.data + +def collate_fn_img(images): + images = torch.stack(images, 0) + input_dicts = {'image': images} + return input_dicts + + +def find_dataset_using_name(dataset_name): + # Given the option --dataset [datasetname], + # the file "data/datasetname_dataset.py" + # will be imported. + dataset_filename = "data." + dataset_name + "_dataset" + datasetlib = importlib.import_module(dataset_filename) + + # In the file, the class called DatasetNameDataset() will + # be instantiated. It has to be a subclass of BaseDataset, + # and it is case-insensitive. + dataset = None + target_dataset_name = dataset_name.replace('_', '') + 'dataset' + for name, cls in datasetlib.__dict__.items(): + if name.lower() == target_dataset_name.lower(): + dataset = cls + + if dataset is None: + raise ValueError("In %s.py, there should be a subclass of BaseDataset " + "with class name that matches %s in lowercase."
% + (dataset_filename, target_dataset_name)) + + return dataset + +def create_dataloader(opt, world_size, rank): + dataset = find_dataset_using_name(opt.dataset_mode) + instance = dataset(opt) + print("dataset [%s] of size %d was created" % + (type(instance).__name__, len(instance))) + + collate_fn = collate_fn_img + + if opt.mpdist: + train_sampler = torch.utils.data.distributed.DistributedSampler(instance, num_replicas=world_size, rank=rank) + dataloader = torch.utils.data.DataLoader( + instance, + batch_size=opt.batchSize, + sampler=train_sampler, + shuffle=False, + num_workers=int(opt.nThreads), + collate_fn=collate_fn, + drop_last=opt.isTrain + ) + else: + dataloader = torch.utils.data.DataLoader( + instance, + batch_size=opt.batchSize, + shuffle=not opt.serial_batches, + num_workers=int(opt.nThreads), + drop_last=opt.isTrain + ) + return dataloader diff --git a/data/conceptual_dataset.py b/data/conceptual_dataset.py new file mode 100644 index 0000000..cea6a47 --- /dev/null +++ b/data/conceptual_dataset.py @@ -0,0 +1,32 @@ +import torch +import torch.utils.data as data +import torchvision.transforms as transforms +import os +from PIL import Image +import json + +class ConceptualDataset(data.Dataset): + def __init__(self, opt): + self.path = os.path.join(opt.dataroot, 'images') + if opt.isTrain: + self.ids = json.load(open(os.path.join(opt.dataroot, 'val_index.json'), 'r')) + else: + self.ids = json.load(open(os.path.join(opt.dataroot, 'val_index.json'), 'r')) + + transforms_list = [] + transforms_list.append(transforms.Resize((opt.img_size, opt.img_size))) + transforms_list += [transforms.ToTensor()] + transforms_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + self.transform = transforms.Compose(transforms_list) + + def __getitem__(self, index): + """This function returns a tuple that is further passed to collate_fn + """ + img_id = self.ids[index] + image = Image.open(os.path.join(self.path, img_id)).convert('RGB') + image = self.transform(image) + + return image + + def __len__(self): + return len(self.ids) diff --git a/models/OpenEdit_model.py b/models/OpenEdit_model.py new file mode 100644 index 0000000..e94db5d --- /dev/null +++ b/models/OpenEdit_model.py @@ -0,0 +1,256 @@ +import torch +import torch.nn.functional as F +import models.networks as networks +import util.util as util +import pickle +from models.networks.txt_enc import EncoderText +from models.networks.perturbation import PerturbationNet +import random + +def l2norm(x, norm_dim=1): + norm = torch.pow(x, 2).sum(dim=norm_dim, keepdim=True).sqrt() + x = torch.div(x, norm) + return x + +class OpenEditModel(torch.nn.Module): + + def __init__(self, opt): + super().__init__() + self.opt = opt + self.FloatTensor = torch.cuda.FloatTensor + self.ByteTensor = torch.cuda.ByteTensor + self.perturbation = opt.perturbation + + self.netG, self.netD, self.netE = self.initialize_networks(opt) + + self.generator = opt.netG + + self.noise_range = opt.noise_range + if self.perturbation: + self.netP = PerturbationNet(opt) + self.netP.cuda() + if self.opt.manipulation: + self.vocab = pickle.load(open(opt.vocab_path, 'rb')) + self.txt_enc = EncoderText(len(self.vocab), 300, 1024, 1) + self.txt_enc.load_state_dict(torch.load(opt.vse_enc_path, map_location='cpu')['model'][1]) + self.txt_enc.eval().cuda() + + # set loss functions + if self.perturbation: + self.criterionPix = torch.nn.L1Loss() + self.criterionVGG = networks.VGGLoss(self.opt.gpu) + elif opt.isTrain: + self.loss_l1pix = opt.l1pix_loss + self.loss_gan = not 
opt.no_disc + self.loss_ganfeat = not opt.no_ganFeat_loss + self.loss_vgg = not opt.no_vgg_loss + + if self.loss_l1pix: + self.criterionPix = torch.nn.L1Loss() + if self.loss_gan: + self.criterionGAN = networks.GANLoss( + opt.gan_mode, tensor=self.FloatTensor, opt=self.opt) + if self.loss_ganfeat: + self.criterionFeat = torch.nn.L1Loss() + if self.loss_vgg: + self.criterionVGG = networks.VGGLoss(self.opt.gpu) + + def forward(self, data, mode, ori_cap=None, new_cap=None, alpha=1, global_edit=False): + if not data['image'].is_cuda: + data['image'] = data['image'].cuda() + if mode == 'generator': + g_loss, generated = self.compute_generator_loss(data) + return g_loss, generated + elif mode == 'discriminator': + d_loss = self.compute_discriminator_loss(data) + return d_loss + elif mode == 'inference': + with torch.no_grad(): + fake_image, _ = self.generate_fake(data['image']) + return fake_image + elif mode == 'manipulate': + with torch.no_grad(): + fake_image = self.manipulate(data['image'], ori_cap, new_cap, alpha, global_edit=global_edit) + return fake_image + elif mode == 'optimize': + g_loss, generated = self.optimizeP(data, ori_cap, new_cap, alpha, global_edit=global_edit) + return g_loss, generated + else: + raise ValueError("|mode| is invalid") + + def create_P_optimizers(self, opt): + P_params = list(self.netP.parameters()) + + beta1, beta2 = 0, 0.9 + P_lr = opt.lr + + optimizer_P = torch.optim.Adam(P_params, lr=P_lr, betas=(beta1, beta2)) + + return optimizer_P + + + def create_optimizers(self, opt): + G_params = list(self.netG.parameters()) + if opt.isTrain and self.loss_gan: + D_params = list(self.netD.parameters()) + + if opt.no_TTUR: + beta1, beta2 = opt.beta1, opt.beta2 + G_lr, D_lr = opt.lr, opt.lr + else: + beta1, beta2 = 0, 0.9 + G_lr, D_lr = opt.lr / 2, opt.lr * 2 + + optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2)) + if self.loss_gan and opt.isTrain: + optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2)) + else: + optimizer_D = None + + return optimizer_G, optimizer_D + + def save(self, epoch): + util.save_network(self.netG, 'G', epoch, self.opt) + if self.loss_gan: + util.save_network(self.netD, 'D', epoch, self.opt) + + def initialize_networks(self, opt): + print(opt.isTrain) + netG = networks.define_G(opt) + netD = networks.define_D(opt) if opt.isTrain and not opt.no_disc else None + netE = networks.define_E(opt) + + if not opt.isTrain or opt.continue_train or opt.manipulation: + netG = util.load_network(netG, 'G', opt.which_epoch, opt) + if opt.isTrain or opt.needs_D: + netD = util.load_network(netD, 'D', opt.which_epoch, opt) + print('network D loaded') + + return netG, netD, netE + + def compute_generator_loss(self, data): + G_losses = {} + + fake_image, spatial_embedding = self.generate_fake(data['image']) + + if self.loss_l1pix: + G_losses['Pix'] = self.criterionPix(fake_image, data['image']) + if self.loss_gan: + pred_fake, pred_real = self.discriminate(fake_image, data) + + if self.loss_ganfeat: + actual_num_D = 0 + num_D = len(pred_fake) + GAN_Feat_loss = self.FloatTensor(1).fill_(0) + for i in range(num_D): # for each discriminator + # last output is the final prediction, so we exclude it + num_intermediate_outputs = len(pred_fake[i]) - 1 + if num_intermediate_outputs == 0: + continue + else: + actual_num_D += 1 + for j in range(num_intermediate_outputs): # for each layer output + unweighted_loss = self.criterionFeat( + pred_fake[i][j], pred_real[i][j].detach()) + GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat + 
G_losses['GAN_Feat'] = GAN_Feat_loss / actual_num_D + + if self.loss_gan: + G_losses['GAN'] = self.criterionGAN(pred_fake, True, + for_discriminator=False) + + if self.loss_vgg: + G_losses['VGG'] = self.criterionVGG(fake_image, data['image']) \ + * self.opt.lambda_vgg + + return G_losses, fake_image + + def compute_discriminator_loss(self, data): + D_losses = {} + with torch.no_grad(): + fake_image, spatial_embedding = self.generate_fake(data['image']) + + pred_fake, pred_real = self.discriminate(fake_image, data) + + D_losses['D_Fake'] = self.criterionGAN(pred_fake, False, + for_discriminator=True) + D_losses['D_real'] = self.criterionGAN(pred_real, True, + for_discriminator=True) + + return D_losses + + def generate_fake(self, real_image, return_skip = False, return_edge = False): + + with torch.no_grad(): + spatial_embedding, edge = self.netE(real_image) + fake_image = self.netG(spatial_embedding, edge) + if return_edge: + return fake_image, spatial_embedding, edge + else: + return fake_image, spatial_embedding + + def generate_fake_withP(self, real_image): + with torch.no_grad(): + spatial_embedding, edge = self.netE(real_image) + p, P_reg = self.netP() + fake_image = self.netG(spatial_embedding, edge, perturbation=True, p=p) + + return fake_image, P_reg + + def discriminate(self, fake_image, data, edge=None): + + fake_and_real_img = torch.cat([fake_image, data['image']], dim=0) + + discriminator_out = self.netD(fake_and_real_img) + + pred_fake, pred_real = self.divide_pred(discriminator_out) + + return pred_fake, pred_real + + def manipulate(self, real_image, ori_txt, new_txt, alpha, global_edit=False): + + spatial_embedding, edge = self.netE(real_image, norm=True) + + with torch.no_grad(): + ori_txt = self.txt_enc(ori_txt, [ori_txt.shape[1]]) + new_txt = self.txt_enc(new_txt, [new_txt.shape[1]]) + + proj = spatial_embedding * ori_txt.unsqueeze(2).unsqueeze(3).repeat(1,1,spatial_embedding.shape[2],spatial_embedding.shape[3]) + proj_s = proj.sum(1, keepdim=True) + proj = proj_s.repeat(1,1024,1,1) + # proj = F.sigmoid(proj) + + if global_edit: + proj[:] = 1 # for global attributes, don't need to do grounding + + spatial_embedding = spatial_embedding - alpha * proj * ori_txt.unsqueeze(2).unsqueeze(3).repeat(1,1,spatial_embedding.shape[2],spatial_embedding.shape[3]) + spatial_embedding = spatial_embedding + alpha * proj * new_txt.unsqueeze(2).unsqueeze(3).repeat(1,1,spatial_embedding.shape[2],spatial_embedding.shape[3]) + + spatial_embedding = l2norm(spatial_embedding) + + if self.perturbation: + p, P_reg = self.netP() + fake_image = self.netG(spatial_embedding, edge, perturbation=True, p=p) + return fake_image + else: + fake_image = self.netG(spatial_embedding, edge) + return fake_image + + def optimizeP(self, data, ori_cap, new_cap, alpha, global_edit=False): + + P_losses = {} + + fake_image, P_reg = self.generate_fake_withP(data['image']) + + manip_image = self.manipulate(data['image'], ori_cap, new_cap, alpha, global_edit=global_edit) + cycle_image = self.manipulate(manip_image, new_cap, ori_cap, alpha, global_edit=global_edit) + + P_losses['Pix'] = self.criterionPix(fake_image, data['image']) + P_losses['VGG'] = self.criterionVGG(fake_image, data['image']) * self.opt.lambda_vgg + + P_losses['cyclePix'] = self.criterionPix(cycle_image, data['image']) + P_losses['cycleVGG'] = self.criterionVGG(cycle_image, data['image']) * self.opt.lambda_vgg + + P_losses['Reg'] = P_reg + + return P_losses, fake_image diff --git a/models/__init__.py b/models/__init__.py new file mode 100755 index 
0000000..112acce --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,34 @@ +import importlib +import torch + + +def find_model_using_name(model_name): + # Given the option --model [modelname], + # the file "models/modelname_model.py" + # will be imported. + model_filename = "models." + model_name + "_model" + modellib = importlib.import_module(model_filename) + + # In the file, the class called ModelNameModel() will + # be instantiated. It has to be a subclass of torch.nn.Module, + # and it is case-insensitive. + model = None + target_model_name = model_name.replace('_', '') + 'model' + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower() \ + and issubclass(cls, torch.nn.Module): + model = cls + + if model is None: + print("In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase." % (model_filename, target_model_name)) + exit(0) + + return model + + +def create_model(opt): + model = find_model_using_name(opt.model) + instance = model(opt) + print("model [%s] was created" % (type(instance).__name__)) + + return instance diff --git a/models/networks/__init__.py b/models/networks/__init__.py new file mode 100644 index 0000000..8254fed --- /dev/null +++ b/models/networks/__init__.py @@ -0,0 +1,59 @@ +import importlib +import torch +import torch.nn as nn +import functools +from torch.autograd import Variable +import numpy as np +import torch.nn.functional as F +from models.networks.base_network import BaseNetwork +from models.networks.loss import * +from models.networks.discriminator import * +from models.networks.generator import * +from models.networks.encoder import * +import util.util as util + + +def find_network_using_name(target_network_name, filename): + target_class_name = target_network_name + filename + module_name = 'models.networks.' 
+ filename + network = util.find_class_in_module(target_class_name, module_name) + + assert issubclass(network, BaseNetwork), \ + "Class %s should be a subclass of BaseNetwork" % network + + return network + +def create_network(cls, opt, init=True, cuda=True): + net = cls(opt) + net.print_network() + assert(torch.cuda.is_available()) + if cuda: + if opt.mpdist: + net.cuda(opt.gpu) + else: + net.cuda() + if init: + net.init_weights(opt.init_type, opt.init_variance) + return net + +def define_G(opt): + netG_cls = find_network_using_name(opt.netG, 'generator') + return create_network(netG_cls, opt) + + +def define_D(opt): + netD_cls = find_network_using_name(opt.netD, 'discriminator') + return create_network(netD_cls, opt) + + +def define_E(opt): + netE_cls = find_network_using_name(opt.netE, 'encoder') + netE = create_network(netE_cls, opt, init=False, cuda=False) + state_dict = torch.load(opt.vse_enc_path, map_location='cpu') + netE.load_state_dict(state_dict['model'][0]) + if opt.mpdist: + netE.cuda(opt.gpu) + else: + netE.cuda() + netE.eval() + return netE diff --git a/models/networks/architecture.py b/models/networks/architecture.py new file mode 100644 index 0000000..31ef92e --- /dev/null +++ b/models/networks/architecture.py @@ -0,0 +1,92 @@ +import math +import re +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F +import torchvision +import torch.nn.utils.spectral_norm as spectral_norm +from torch.nn import Parameter +from models.networks.base_network import BaseNetwork +from models.networks.normalization import get_norm_layer +from models.networks.normalization import SPADE + + +## VGG architecter, used for the perceptual loss using a pretrained VGG network +class VGG19(torch.nn.Module): + def __init__(self, requires_grad=False, local_pretrained_path='checkpoints/vgg19.pth'): + super().__init__() + # if we have network access + # vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features + # if no network access + model = torchvision.models.vgg19() + model.load_state_dict(torch.load(local_pretrained_path)) + vgg_pretrained_features = model.features + + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + return out + +class SPADEResnetBlock(nn.Module): + def __init__(self, fin, fout, opt): + super().__init__() + # Attributes + self.learned_shortcut = (fin != fout) + fmiddle = min(fin, fout) + # create conv layers + self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1) + self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1) + if self.learned_shortcut: + self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) + + self.conv_0 = 
spectral_norm(self.conv_0) + self.conv_1 = spectral_norm(self.conv_1) + if self.learned_shortcut: + self.conv_s = spectral_norm(self.conv_s) + spade_config_str = 'batchnorm3x3' + self.norm_0 = SPADE(spade_config_str, fin, opt.edge_nc, opt) + self.norm_1 = SPADE(spade_config_str, fmiddle, opt.edge_nc, opt) + if self.learned_shortcut: + self.norm_s = SPADE(spade_config_str, fin, opt.edge_nc, opt) + + def forward(self, x, seg): + x_s = self.shortcut(x, seg) + dx = self.conv_0(self.actvn(self.norm_0(x, seg))) + dx = self.conv_1(self.actvn(self.norm_1(dx, seg))) + out = x_s + dx + return out + + def shortcut(self, x, seg): + if self.learned_shortcut: + x_s = self.conv_s(self.norm_s(x, seg)) + else: + x_s = x + return x_s + + def actvn(self, x): + return F.leaky_relu(x, 2e-1) + diff --git a/models/networks/base_network.py b/models/networks/base_network.py new file mode 100644 index 0000000..108e4a8 --- /dev/null +++ b/models/networks/base_network.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn +from torch.nn import init + +class BaseNetwork(nn.Module): + def __init__(self): + super(BaseNetwork, self).__init__() + + def print_network(self): + if isinstance(self, list): + self = self[0] + num_params = 0 + for param in self.parameters(): + num_params += param.numel() + print('Network [%s] was created. Total number of parameters: %.1f million. ' + 'To see the architecture, do print(network).' + % (type(self).__name__, num_params / 1000000)) + + def init_weights(self, init_type='normal', gain=0.02): + def init_func(m): + classname = m.__class__.__name__ + if classname.find('BatchNorm2d') != -1: + if hasattr(m, 'weight') and m.weight is not None: + init.normal_(m.weight.data, 1.0, gain) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'xavier_uniform': + init.xavier_uniform_(m.weight.data, gain=1.0) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=gain) + elif init_type == 'none': # uses pytorch's default init method + m.reset_parameters() + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + + self.apply(init_func) + + # propagate to children + for m in self.children(): + if hasattr(m, 'init_weights'): + m.init_weights(init_type, gain) + diff --git a/models/networks/bdcn.py b/models/networks/bdcn.py new file mode 100644 index 0000000..433d32a --- /dev/null +++ b/models/networks/bdcn.py @@ -0,0 +1,247 @@ +import numpy as np +import torch +import torch.nn as nn +from models.networks.vgg16_c import VGG16_C + +def crop(data1, data2, crop_h, crop_w): + _, _, h1, w1 = data1.size() + _, _, h2, w2 = data2.size() + assert(h2 <= h1 and w2 <= w1) + data = data1[:, :, crop_h:crop_h+h2, crop_w:crop_w+w2] + return data + +def get_upsampling_weight(in_channels, out_channels, kernel_size): + """Make a 2D bilinear kernel suitable for upsampling""" + factor = (kernel_size + 1) // 2 + if kernel_size % 2 == 1: + center = factor - 1 + else: + center = factor - 0.5 + og = np.ogrid[:kernel_size, :kernel_size] + filt = (1 - abs(og[0] - center) / factor) * \ + (1 - abs(og[1] - center) / 
factor) + weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), + dtype=np.float64) + weight[range(in_channels), range(out_channels), :, :] = filt + return torch.from_numpy(weight).float() + +class MSBlock(nn.Module): + def __init__(self, c_in, rate=4): + super(MSBlock, self).__init__() + c_out = c_in + self.rate = rate + + self.conv = nn.Conv2d(c_in, 32, 3, stride=1, padding=1) + self.relu = nn.ReLU(inplace=True) + dilation = self.rate*1 if self.rate >= 1 else 1 + self.conv1 = nn.Conv2d(32, 32, 3, stride=1, dilation=dilation, padding=dilation) + self.relu1 = nn.ReLU(inplace=True) + dilation = self.rate*2 if self.rate >= 1 else 1 + self.conv2 = nn.Conv2d(32, 32, 3, stride=1, dilation=dilation, padding=dilation) + self.relu2 = nn.ReLU(inplace=True) + dilation = self.rate*3 if self.rate >= 1 else 1 + self.conv3 = nn.Conv2d(32, 32, 3, stride=1, dilation=dilation, padding=dilation) + self.relu3 = nn.ReLU(inplace=True) + + self._initialize_weights() + + def forward(self, x): + o = self.relu(self.conv(x)) + o1 = self.relu1(self.conv1(o)) + o2 = self.relu2(self.conv2(o)) + o3 = self.relu3(self.conv3(o)) + out = o + o1 + o2 + o3 + return out + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + m.weight.data.normal_(0, 0.01) + if m.bias is not None: + m.bias.data.zero_() + +class BDCN_multi(nn.Module): + def __init__(self, level=1, pretrain=None, logger=None, rate=4): + super(BDCN_multi, self).__init__() + self.pretrain = pretrain + t = 1 + + self.features = VGG16_C(pretrain, logger) + self.msblock1_1 = MSBlock(64, rate) + self.msblock1_2 = MSBlock(64, rate) + self.conv1_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) + self.conv1_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) + self.score_dsn1 = nn.Conv2d(21, 1, (1, 1), stride=1) + self.score_dsn1_1 = nn.Conv2d(21, 1, 1, stride=1) + self.msblock2_1 = MSBlock(128, rate) + self.msblock2_2 = MSBlock(128, rate) + self.conv2_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) + self.conv2_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) + self.score_dsn2 = nn.Conv2d(21, 1, (1, 1), stride=1) + self.score_dsn2_1 = nn.Conv2d(21, 1, (1, 1), stride=1) + self.msblock3_1 = MSBlock(256, rate) + self.msblock3_2 = MSBlock(256, rate) + self.msblock3_3 = MSBlock(256, rate) + self.conv3_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) + self.conv3_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) + self.conv3_3_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) + self.score_dsn3 = nn.Conv2d(21, 1, (1, 1), stride=1) + self.score_dsn3_1 = nn.Conv2d(21, 1, (1, 1), stride=1) + self.msblock4_1 = MSBlock(512, rate) + self.msblock4_2 = MSBlock(512, rate) + self.msblock4_3 = MSBlock(512, rate) + self.conv4_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) + self.conv4_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) + self.conv4_3_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) + self.score_dsn4 = nn.Conv2d(21, 1, (1, 1), stride=1) + self.score_dsn4_1 = nn.Conv2d(21, 1, (1, 1), stride=1) + self.msblock5_1 = MSBlock(512, rate) + self.msblock5_2 = MSBlock(512, rate) + self.msblock5_3 = MSBlock(512, rate) + self.conv5_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) + self.conv5_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) + self.conv5_3_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) + self.score_dsn5 = nn.Conv2d(21, 1, (1, 1), stride=1) + self.score_dsn5_1 = nn.Conv2d(21, 1, (1, 1), stride=1) + self.upsample_2 = nn.ConvTranspose2d(1, 1, 4, stride=2, bias=False) + self.upsample_4 = nn.ConvTranspose2d(1, 1, 8, stride=4, bias=False) + 
self.upsample_8 = nn.ConvTranspose2d(1, 1, 16, stride=8, bias=False) + self.upsample_8_5 = nn.ConvTranspose2d(1, 1, 16, stride=8, bias=False) + self.fuse = nn.Conv2d(10, 1, 1, stride=1) + + self._initialize_weights(logger) + + self.level = level + + def forward(self, x): + features = self.features(x) + sum1 = self.conv1_1_down(self.msblock1_1(features[0])) + \ + self.conv1_2_down(self.msblock1_2(features[1])) + s1 = self.score_dsn1(sum1) + if self.level == 11: + return s1 + # print(s1.data.shape, s11.data.shape) + sum2 = self.conv2_1_down(self.msblock2_1(features[2])) + \ + self.conv2_2_down(self.msblock2_2(features[3])) + s2 = self.score_dsn2(sum2) + s2 = self.upsample_2(s2) + # print(s2.data.shape, s21.data.shape) + s2 = crop(s2, x, 1, 1) + if self.level == 21: + return s2 + s1 + sum3 = self.conv3_1_down(self.msblock3_1(features[4])) + \ + self.conv3_2_down(self.msblock3_2(features[5])) + \ + self.conv3_3_down(self.msblock3_3(features[6])) + s3 = self.score_dsn3(sum3) + s3 = self.upsample_4(s3) + # print(s3.data.shape) + s3 = crop(s3, x, 2, 2) + if self.level == 31: + return s3 + s2 + s1 + + sum4 = self.conv4_1_down(self.msblock4_1(features[7])) + \ + self.conv4_2_down(self.msblock4_2(features[8])) + \ + self.conv4_3_down(self.msblock4_3(features[9])) + s4 = self.score_dsn4(sum4) + s4 = self.upsample_8(s4) + # print(s4.data.shape) + s4 = crop(s4, x, 4, 4) + if self.level == 41: + return s4 + s3 + s2 + s1 + sum5 = self.conv5_1_down(self.msblock5_1(features[10])) + \ + self.conv5_2_down(self.msblock5_2(features[11])) + \ + self.conv5_3_down(self.msblock5_3(features[12])) + s5 = self.score_dsn5(sum5) + s5 = self.upsample_8_5(s5) + # print(s5.data.shape) + s5 = crop(s5, x, 0, 0) + if self.level == 51: + return s5 + s4 + s3 + s2 + s1 + + s51 = self.score_dsn5_1(sum5) + s51 = self.upsample_8_5(s51) + # print(s51.data.shape) + s51 = crop(s51, x, 0, 0) + if self.level == 52: + return s51 + s41 = self.score_dsn4_1(sum4) + s41 = self.upsample_8(s41) + # print(s41.data.shape) + s41 = crop(s41, x, 4, 4) + if self.level == 42: + return s41 + s51 + s31 = self.score_dsn3_1(sum3) + s31 = self.upsample_4(s31) + # print(s31.data.shape) + s31 = crop(s31, x, 2, 2) + if self.level == 32: + return s31 + s41 + s51 + s21 = self.score_dsn2_1(sum2) + s21 = self.upsample_2(s21) + s21 = crop(s21, x, 1, 1) + if self.level == 22: + return s21 + s31 + s41 + s51 + s11 = self.score_dsn1_1(sum1) + if self.level == 12: + return s11 + s21 + s31 + s41 + s51 + + + o1, o2, o3, o4, o5 = s1.detach(), s2.detach(), s3.detach(), s4.detach(), s5.detach() + o11, o21, o31, o41, o51 = s11.detach(), s21.detach(), s31.detach(), s41.detach(), s51.detach() + p1_1 = s1 + p2_1 = s2 + o1 + p3_1 = s3 + o2 + o1 + p4_1 = s4 + o3 + o2 + o1 + p5_1 = s5 + o4 + o3 + o2 + o1 + p1_2 = s11 + o21 + o31 + o41 + o51 + p2_2 = s21 + o31 + o41 + o51 + p3_2 = s31 + o41 + o51 + p4_2 = s41 + o51 + p5_2 = s51 + + if self.level == 11: + return p1_1 + elif self.level == 21: + return p2_1 + elif self.level == 31: + return p3_1 + elif self.level == 41: + return p4_1 + elif self.level == 51: + return p5_1 + elif self.level == 12: + return p1_2 + elif self.level == 22: + return p2_2 + elif self.level == 32: + return p3_2 + elif self.level == 42: + return p4_2 + elif self.level == 52: + return p5_2 + + def _initialize_weights(self, logger=None): + for name, param in self.state_dict().items(): + if self.pretrain and 'features' in name: + continue + # elif 'down' in name: + # param.zero_() + elif 'upsample' in name: + if logger: + logger.info('init upsample layer %s ' %
name) + k = int(name.split('.')[0].split('_')[1]) + param.copy_(get_upsampling_weight(1, 1, k*2)) + elif 'fuse' in name: + if logger: + logger.info('init params %s ' % name) + if 'bias' in name: + param.zero_() + else: + nn.init.constant(param, 0.080) + else: + if logger: + logger.info('init params %s ' % name) + if 'bias' in name: + param.zero_() + else: + param.normal_(0, 0.01) diff --git a/models/networks/discriminator.py b/models/networks/discriminator.py new file mode 100644 index 0000000..63563f7 --- /dev/null +++ b/models/networks/discriminator.py @@ -0,0 +1,94 @@ +import sys +import torch +import re +import torch.nn as nn +from collections import OrderedDict +import os.path +import functools +from torch.autograd import Variable +import numpy as np +import torch.nn.functional as F +from models.networks.base_network import BaseNetwork +from models.networks.normalization import get_norm_layer +import util.util as util + +def l2norm(X): + norm = torch.pow(X, 2).sum(dim=1, keepdim=True).sqrt() + X = torch.div(X, norm) + return X + +class MultiscaleDiscriminator(BaseNetwork): + + def __init__(self, opt): + super().__init__() + opt.num_D = 2 + self.opt = opt + + for i in range(opt.num_D): + subnetD = NLayerDiscriminator(opt) + self.add_module('discriminator_%d' % i, subnetD) + + def downsample(self, input): + return F.avg_pool2d(input, kernel_size=3, + stride=2, padding=[1, 1], + count_include_pad=False) + + ## Returns list of lists of discriminator outputs. + ## The final result is of size opt.num_D x opt.n_layers_D + def forward(self, input, semantics=None): + result = [] + get_intermediate_features = not self.opt.no_ganFeat_loss + for name, D in self.named_children(): + if semantics is None: + out = D(input) + else: + out = D(input, semantics) + if not get_intermediate_features: + out = [out] + result.append(out) + input = self.downsample(input) + + return result + +# Defines the PatchGAN discriminator with the specified arguments. 
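+# Each NLayerDiscriminator below is a 4-layer PatchGAN: stride-2 4x4 convolutions whose final 1-channel output scores overlapping patches of the input as real or fake, with intermediate activations kept for the GAN feature-matching loss.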
+class NLayerDiscriminator(BaseNetwork): + def __init__(self, opt): + super().__init__() + opt.n_layers_D = 4 + self.opt = opt + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + nf = opt.ndf + input_nc = 3 + + norm_layer = get_norm_layer(opt, opt.norm_D) + sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, False)]] + + for n in range(1, opt.n_layers_D): + nf_prev = nf + nf = min(nf * 2, 512) + stride = 1 if n == opt.n_layers_D - 1 else 2 + sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, + stride=stride, padding=padw), opt), + nn.LeakyReLU(0.2, False) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] + + ## We divide the layers into groups to extract intermediate layer outputs + for n in range(len(sequence)): + self.add_module('model'+str(n), nn.Sequential(*sequence[n])) + + def forward(self, input): + results = [input] + for submodel in self.children(): + intermediate_output = submodel(results[-1]) + results.append(intermediate_output) + + get_intermediate_features = not self.opt.no_ganFeat_loss + if get_intermediate_features: + return results[1:] + else: + return results[-1] diff --git a/models/networks/encoder.py b/models/networks/encoder.py new file mode 100644 index 0000000..461549c --- /dev/null +++ b/models/networks/encoder.py @@ -0,0 +1,70 @@ +import torch +import torch.nn as nn +import functools +from torch.autograd import Variable +import numpy as np +import torch.nn.functional as F +import torchvision.models as models +from models.networks.base_network import BaseNetwork +from models.networks.bdcn import BDCN_multi +import pdb + +def l2norm(x, norm_dim=1): + norm = torch.pow(x, 2).sum(dim=norm_dim, keepdim=True).sqrt() + x = torch.div(x, norm) + return x + +class ResNetbdcnEncoder(BaseNetwork): + def __init__(self, opt): + super().__init__() + img_dim = 2048 + embed_size = 1024 + self.cnn = self.get_cnn('resnet152', False) + self.cnn = nn.Sequential(self.cnn.conv1, self.cnn.bn1, + self.cnn.relu, self.cnn.maxpool, + self.cnn.layer1, self.cnn.layer2, + self.cnn.layer3, self.cnn.layer4) + self.conv1x1 = nn.Conv2d(img_dim, embed_size, 1) + # bdcn + self.edgenet = BDCN_multi(level=opt.edge_level) + edge_params = torch.load(opt.edge_model_path, map_location='cpu') + self.edgenet.load_state_dict(edge_params) + self.avg_pool = nn.AdaptiveAvgPool2d((1,1)) + self.tanh = nn.Tanh() + self.edge_tanh = opt.edge_tanh + + def forward(self, x0, skip=False, norm=True, pool=False): + xe = (x0 + 1) / 2.0 * 255 + xe = xe[:,[2,1,0],:,:] + xe[:,0,:,:] = xe[:,0,:,:] - 123.0 + xe[:,1,:,:] = xe[:,1,:,:] - 117.0 + xe[:,2,:,:] = xe[:,2,:,:] - 104.0 + edge = self.edgenet(xe) + x = self.cnn(x0) + x = self.conv1x1(x) + if pool: + x = self.avg_pool(x) + if norm: + x = l2norm(x) + if self.edge_tanh: + return x, self.tanh(edge) + else: + return x, edge + + def get_cnn(self, arch, pretrained): + if pretrained: + print("=> using pre-trained model '{}'".format(arch)) + model = models.__dict__[arch](pretrained=True) + else: + print("=> creating model '{}'".format(arch)) + model = models.__dict__[arch]() + return model + + def load_state_dict(self, state_dict): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name not in own_state: + continue + if isinstance(param, nn.Parameter): + param = param.data + own_state[name].copy_(param) diff --git a/models/networks/generator.py b/models/networks/generator.py new file mode 100644 index 0000000..4d07891 --- /dev/null +++ b/models/networks/generator.py @@ 
-0,0 +1,64 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from models.networks.base_network import BaseNetwork +from models.networks.architecture import SPADEResnetBlock +import random +import numpy as np + + +class OpenEditGenerator(BaseNetwork): + def __init__(self, opt): + super().__init__() + self.opt = opt + nf = opt.ngf + + self.sw = 7 + self.sh = 7 + + self.head_0 = SPADEResnetBlock(16*nf, 16*nf, opt) + + self.G_middle_0 = SPADEResnetBlock(16*nf, 16*nf, opt) + self.G_middle_1 = SPADEResnetBlock(16*nf, 16*nf, opt) + + self.up_0 = SPADEResnetBlock(16*nf, 8*nf, opt) + self.up_1 = SPADEResnetBlock(8*nf, 4*nf, opt) + self.up_2 = SPADEResnetBlock(4*nf, 2*nf, opt) + self.up_3 = SPADEResnetBlock(2*nf, 1*nf, opt) + + final_nc = nf + + self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1) + + self.up = nn.Upsample(scale_factor=2) + + def forward(self, x, edge, perturbation=False, p=None): + + # input: 1024 x 7 x 7 + x = self.head_0(x, edge) # 1024 x 7 x 7 + + x = self.up(x) # 1024 x 14 x 14 + x = self.G_middle_0(x, edge) # 1024 x 14 x 14 + if perturbation: + x = x + p[0] + + x = self.G_middle_1(x, edge) # 1024 x 14 x 14 + + x = self.up(x) # 1024 x 28 x 28 + x = self.up_0(x, edge) # 512 x 28 x 28 + if perturbation: + x = x + p[1] + x = self.up(x) # 512 x 56 x 56 + x = self.up_1(x, edge) + if perturbation: + x = x + p[2] + x = self.up(x) # 256 x 112 x 112 + x = self.up_2(x, edge) + x = self.up(x) # 128 x 224 x 224 + x = self.up_3(x, edge) + + x = self.conv_img(F.leaky_relu(x, 2e-1)) + x = F.tanh(x) + + return x + diff --git a/models/networks/loss.py b/models/networks/loss.py new file mode 100644 index 0000000..1fad989 --- /dev/null +++ b/models/networks/loss.py @@ -0,0 +1,107 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable +import torch.nn.functional as F +import numpy as np +from models.networks.architecture import VGG19 + +class GANLoss(nn.Module): + def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, + tensor=torch.FloatTensor, opt=None): + super(GANLoss, self).__init__() + self.real_label = target_real_label + self.fake_label = target_fake_label + self.real_label_tensor = None + self.fake_label_tensor = None + self.zero_tensor = None + self.Tensor = tensor + self.gan_mode = gan_mode + self.opt = opt + if gan_mode == 'ls': + pass + elif gan_mode == 'original': + pass + elif gan_mode == 'w': + pass + elif gan_mode == 'hinge': + pass + else: + raise ValueError('Unexpected gan_mode {}'.format(gan_mode)) + + def get_target_tensor(self, input, target_is_real): + if target_is_real: + if self.real_label_tensor is None: + self.real_label_tensor = self.Tensor(1).fill_(self.real_label) + self.real_label_tensor.requires_grad_(False) + return self.real_label_tensor.expand_as(input) + else: + if self.fake_label_tensor is None: + self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label) + self.fake_label_tensor.requires_grad_(False) + return self.fake_label_tensor.expand_as(input) + + def get_zero_tensor(self, input): + if self.zero_tensor is None: + self.zero_tensor = self.Tensor(1).fill_(0) + self.zero_tensor.requires_grad_(False) + return self.zero_tensor.expand_as(input) + + def loss(self, input, target_is_real, for_discriminator=True): + if self.gan_mode == 'original': # cross entropy loss + target_tensor = self.get_target_tensor(input, target_is_real) + batchsize = input.size(0) + loss = F.binary_cross_entropy_with_logits(input, target_tensor) + return loss + elif self.gan_mode == 'ls': + target_tensor = 
self.get_target_tensor(input, target_is_real) + return F.mse_loss(input, target_tensor) + elif self.gan_mode == 'hinge': + if for_discriminator: + if target_is_real: + minval = torch.min(input - 1, self.get_zero_tensor(input)) + loss = -torch.mean(minval) + else: + minval = torch.min(-input - 1, self.get_zero_tensor(input)) + loss = -torch.mean(minval) + else: + assert target_is_real, "The generator's hinge loss must be aiming for real" + loss = -torch.mean(input) + return loss + else: + # wgan + if target_is_real: + return -input.mean() + else: + return input.mean() + + def __call__(self, input, target_is_real, for_discriminator=True): + if isinstance(input, list): + loss = 0 + for pred_i in input: + if isinstance(pred_i, list): + pred_i = pred_i[-1] + loss_tensor = self.loss(pred_i, target_is_real, for_discriminator) + bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0) + new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1) + loss += new_loss + return loss / len(input) + else: + return self.loss(input, target_is_real, for_discriminator) + +## Perceptual loss that uses a pretrained VGG network +class VGGLoss(nn.Module): + def __init__(self, gpu): + super(VGGLoss, self).__init__() + if gpu is not None: + self.vgg = VGG19().cuda(gpu) + else: + self.vgg = VGG19().cuda() + self.criterion = nn.L1Loss() + self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] + + def forward(self, x, y): + x_vgg, y_vgg = self.vgg(x), self.vgg(y) + loss = 0 + for i in range(len(x_vgg)): + loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) + return loss diff --git a/models/networks/normalization.py b/models/networks/normalization.py new file mode 100644 index 0000000..bddae7e --- /dev/null +++ b/models/networks/normalization.py @@ -0,0 +1,73 @@ +import re +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.utils.spectral_norm as spectral_norm + +## Returns a function that creates a normalization function +## that does not condition on semantic map +def get_norm_layer(opt, norm_type='instance'): + # helper function to get # output channels of the previous layer + def get_out_channel(layer): + if hasattr(layer, 'out_channels'): + return getattr(layer, 'out_channels') + return layer.weight.size(0) + + # this function will be returned + def add_norm_layer(layer, opt): + nonlocal norm_type + if norm_type.startswith('spectral'): + layer = spectral_norm(layer) + subnorm_type = norm_type[len('spectral'):] + else: + subnorm_type = norm_type + + if subnorm_type == 'none' or len(subnorm_type) == 0: + return layer + + # remove bias in the previous layer, which is meaningless + # since it has no effect after normalization + if getattr(layer, 'bias', None) is not None: + delattr(layer, 'bias') + layer.register_parameter('bias', None) + + if subnorm_type == 'instance': + norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False) + elif subnorm_type == 'sync_batch' and opt.mpdist: + norm_layer = nn.SyncBatchNorm(get_out_channel(layer), affine=True) + else: + norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True) + + return nn.Sequential(layer, norm_layer) + + return add_norm_layer + +class SPADE(nn.Module): + def __init__(self, config_text, norm_nc, label_nc, opt): + super().__init__() + + if 'instance' in opt.norm_G: + self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) + elif 'sync_batch' in opt.norm_G and opt.mpdist: + self.param_free_norm = nn.SyncBatchNorm(norm_nc, affine=False) + else: + self.param_free_norm = 
nn.BatchNorm2d(norm_nc, affine=False) + + nhidden = 128 + ks = 3 + pw = 1 + self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),nn.ReLU()) + self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) + self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) + + def forward(self, x, condition): + normalized = self.param_free_norm(x) + + if x.size()[2] != condition.size()[2]: + condition = F.interpolate(condition, size=x.size()[2:], mode='nearest') + actv = self.mlp_shared(condition) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + + out = normalized * (1 + gamma) + beta + return out diff --git a/models/networks/perturbation.py b/models/networks/perturbation.py new file mode 100644 index 0000000..b90aa31 --- /dev/null +++ b/models/networks/perturbation.py @@ -0,0 +1,22 @@ +import torch +import torch.nn as nn +import functools +import numpy as np +import torch.nn.functional as F +from models.networks.base_network import BaseNetwork + +class PerturbationNet(BaseNetwork): + def __init__(self, opt): + super().__init__() + self.reg_weight = opt.reg_weight + self.param1 = nn.Parameter(torch.zeros(1, 1024, 14, 14)) + self.param2 = nn.Parameter(torch.zeros(1, 512, 28, 28)) + self.param3 = nn.Parameter(torch.zeros(1, 256, 56, 56)) + + def parameters(self): + return [self.param1, self.param2, self.param3] + + def forward(self): + R_reg = torch.pow(self.param1, 2).sum() + torch.pow(self.param2, 2).sum() + torch.pow(self.param3, 2).sum() + R_reg = R_reg * self.reg_weight + return [self.param1, self.param2, self.param3], R_reg diff --git a/models/networks/txt_enc.py b/models/networks/txt_enc.py new file mode 100644 index 0000000..c14f3a8 --- /dev/null +++ b/models/networks/txt_enc.py @@ -0,0 +1,66 @@ +import torch +import torch.nn as nn +import torch.nn.init +import torch.nn.functional as F +import torchvision.models as models +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence +import torch.backends.cudnn as cudnn +from torch.nn.utils.clip_grad import clip_grad_norm +import numpy as np +from collections import OrderedDict + +def l2norm(X, norm_dim=1): + """L2-normalize columns of X + """ + norm = torch.pow(X, 2).sum(dim=norm_dim, keepdim=True).sqrt() + X = torch.div(X, norm) + return X + +class EncoderText(nn.Module): + + def __init__(self, vocab_size, word_dim, embed_size, num_layers, + use_abs=False, no_txt_norm=False, nonorm=False): + super(EncoderText, self).__init__() + self.no_txtnorm = no_txt_norm + self.no_norm_gd = nonorm + self.use_abs = use_abs + self.embed_size = embed_size + + # word embedding + self.embed = nn.Embedding(vocab_size, word_dim) + + # caption embedding + self.rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True) + + self.init_weights() + + def init_weights(self): + self.embed.weight.data.uniform_(-0.1, 0.1) + + def forward(self, x, lengths, ret=True): + """Handles variable size captions + """ + # Embed word ids to vectors + x = self.embed(x) + packed = pack_padded_sequence(x, lengths, batch_first=True) + + # Forward propagate RNN + out, _ = self.rnn(packed) + + # Reshape *final* output to (batch_size, hidden_size) + padded = pad_packed_sequence(out, batch_first=True) + I = torch.LongTensor(lengths).view(-1, 1, 1) + #I = Variable(I.expand(x.size(0), 1, self.embed_size)-1).cuda() + I = (I.expand(x.size(0), 1, self.embed_size)-1).cuda() + out = torch.gather(padded[0], 1, I).squeeze(1) + + # normalization in the joint embedding space + if not self.no_txtnorm and (not 
self.no_norm_gd or ret): + out = l2norm(out) + + # take absolute value, used by order embeddings + if self.use_abs: + out = torch.abs(out) + + return out + diff --git a/models/networks/vgg16_c.py b/models/networks/vgg16_c.py new file mode 100755 index 0000000..78bf861 --- /dev/null +++ b/models/networks/vgg16_c.py @@ -0,0 +1,112 @@ +import numpy as np +import torch +import torchvision +import torch.nn as nn +import math + +class VGG16_C(nn.Module): + """""" + def __init__(self, pretrain=None, logger=None): + super(VGG16_C, self).__init__() + self.conv1_1 = nn.Conv2d(3, 64, (3, 3), stride=1, padding=1) + self.relu1_1 = nn.ReLU(inplace=True) + self.conv1_2 = nn.Conv2d(64, 64, (3, 3), stride=1, padding=1) + self.relu1_2 = nn.ReLU(inplace=True) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.conv2_1 = nn.Conv2d(64, 128, (3, 3), stride=1, padding=1) + self.relu2_1 = nn.ReLU(inplace=True) + self.conv2_2 = nn.Conv2d(128, 128, (3, 3), stride=1, padding=1) + self.relu2_2 = nn.ReLU(inplace=True) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.conv3_1 = nn.Conv2d(128, 256, (3, 3), stride=1, padding=1) + self.relu3_1 = nn.ReLU(inplace=True) + self.conv3_2 = nn.Conv2d(256, 256, (3, 3), stride=1, padding=1) + self.relu3_2 = nn.ReLU(inplace=True) + self.conv3_3 = nn.Conv2d(256, 256, (3, 3), stride=1, padding=1) + self.relu3_3 = nn.ReLU(inplace=True) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.conv4_1 = nn.Conv2d(256, 512, (3, 3), stride=1, padding=1) + self.relu4_1 = nn.ReLU(inplace=True) + self.conv4_2 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=1) + self.relu4_2 = nn.ReLU(inplace=True) + self.conv4_3 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=1) + self.relu4_3 = nn.ReLU(inplace=True) + self.pool4 = nn.MaxPool2d(2, stride=1, ceil_mode=True) + self.conv5_1 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=2, dilation=2) + self.relu5_1 = nn.ReLU(inplace=True) + self.conv5_2 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=2, dilation=2) + self.relu5_2 = nn.ReLU(inplace=True) + self.conv5_3 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=2, dilation=2) + self.relu5_3 = nn.ReLU(inplace=True) + if pretrain: + if '.npy' in pretrain: + state_dict = np.load(pretrain).item() + for k in state_dict: + state_dict[k] = torch.from_numpy(state_dict[k]) + else: + state_dict = torch.load(pretrain) + own_state_dict = self.state_dict() + for name, param in own_state_dict.items(): + if name in state_dict: + if logger: + logger.info('copy the weights of %s from pretrained model' % name) + param.copy_(state_dict[name]) + else: + if logger: + logger.info('init the weights of %s from mean 0, std 0.01 gaussian distribution'\ + % name) + if 'bias' in name: + param.zero_() + else: + param.normal_(0, 0.01) + else: + self._initialize_weights(logger) + + def forward(self, x): + conv1_1 = self.relu1_1(self.conv1_1(x)) + conv1_2 = self.relu1_2(self.conv1_2(conv1_1)) + pool1 = self.pool1(conv1_2) + conv2_1 = self.relu2_1(self.conv2_1(pool1)) + conv2_2 = self.relu2_2(self.conv2_2(conv2_1)) + pool2 = self.pool2(conv2_2) + conv3_1 = self.relu3_1(self.conv3_1(pool2)) + conv3_2 = self.relu3_2(self.conv3_2(conv3_1)) + conv3_3 = self.relu3_3(self.conv3_3(conv3_2)) + pool3 = self.pool3(conv3_3) + conv4_1 = self.relu4_1(self.conv4_1(pool3)) + conv4_2 = self.relu4_2(self.conv4_2(conv4_1)) + conv4_3 = self.relu4_3(self.conv4_3(conv4_2)) + pool4 = self.pool4(conv4_3) + # pool4 = conv4_3 + conv5_1 = self.relu5_1(self.conv5_1(pool4)) + conv5_2 = self.relu5_2(self.conv5_2(conv5_1)) + conv5_3 = 
self.relu5_3(self.conv5_3(conv5_2)) + + side = [conv1_1, conv1_2, conv2_1, conv2_2, + conv3_1, conv3_2, conv3_3, conv4_1, + conv4_2, conv4_3, conv5_1, conv5_2, conv5_3] + return side + + def _initialize_weights(self, logger=None): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if logger: + logger.info('init the weights of %s from mean 0, std 0.01 gaussian distribution'\ + % m) + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + +if __name__ == '__main__': + model = VGG16_C() + # im = np.zeros((1,3,100,100)) + # out = model(Variable(torch.from_numpy(im))) + + diff --git a/options/__init__.py b/options/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/options/base_options.py b/options/base_options.py new file mode 100755 index 0000000..5165877 --- /dev/null +++ b/options/base_options.py @@ -0,0 +1,145 @@ +import sys +import argparse +import os +from util import util +import torch +import models +import data +import pickle +import pdb + +class BaseOptions(): + def __init__(self): + self.initialized = False + + def initialize(self, parser): + # experiment specifics + parser.add_argument('--dist_url', type=str, default='tcp://127.0.0.1:10002') + parser.add_argument('--num_gpu', type=int, default=8, help='num of gpus for cluter training') + parser.add_argument('--name', type=str, default='open-edit', help='name of the experiment. It decides where to store samples and models') + + parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') + parser.add_argument('--vocab_path', type=str, default='vocab/conceptual_vocab.pkl', help='path to vocabulary') + parser.add_argument('--vse_enc_path', type=str, default='checkpoints/conceptual_model_best.pth.tar', help='path to the pretrained text encoder') + parser.add_argument('--edge_model_path', type=str, default='checkpoints/bdcn_pretrained_on_bsds500.pth', help='path to the pretrained edge extractor') + parser.add_argument('--model', type=str, default='OpenEdit', help='which model to use') + parser.add_argument('--norm_G', type=str, default='spectralsync_batch', help='instance normalization or batch normalization') + parser.add_argument('--norm_D', type=str, default='spectralinstance', help='instance normalization or batch normalization') + parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + + # input/output sizes + parser.add_argument('--batchSize', type=int, default=8, help='input batch size') + parser.add_argument('--img_size', type=int, default=224, help='image size') + parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') + + # for setting inputs + parser.add_argument('--dataroot', type=str, default='./datasets/conceptual/') + parser.add_argument('--dataset_mode', type=str, default='conceptual') + parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + parser.add_argument('--nThreads', default=4, type=int, help='# threads for loading data') + parser.add_argument('--load_from_opt_file', action='store_true', help='load the options from checkpoints and use that as default') + + # for displays + parser.add_argument('--display_winsize', 
type=int, default=256, help='display window size') + + # for generator + parser.add_argument('--netG', type=str, default='openedit', help='generator model') + parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') + parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]') + parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution') + + # for encoder + parser.add_argument('--netE', type=str, default='resnetbdcn') + parser.add_argument('--edge_nc', type=int, default=1) + parser.add_argument('--edge_level', type=int, default=41) + parser.add_argument('--edge_tanh', action='store_true') + + # for image-specific finetuning + parser.add_argument('--reg_weight', type=float, default=1e-4) + parser.add_argument('--perturbation', action='store_true') + parser.add_argument('--manipulation', action='store_true') + parser.add_argument('--img_path', type=str) + parser.add_argument('--ori_cap', type=str) + parser.add_argument('--new_cap', type=str) + parser.add_argument('--global_edit', action='store_true') + parser.add_argument('--alpha', type=int, default=5) + parser.add_argument('--optimize_iter', type=int, default=50) + + self.initialized = True + return parser + + def gather_options(self): + # initialize parser with basic options + if not self.initialized: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + opt, unknown = parser.parse_known_args() + + # if there is opt_file, load it. + # The previous default options will be overwritten + if opt.load_from_opt_file: + parser = self.update_options_from_file(parser, opt) + + opt = parser.parse_args() + self.parser = parser + return opt + + def print_options(self, opt): + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + def option_file_path(self, opt, makedir=False): + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + if makedir: + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, 'opt') + return file_name + + def save_options(self, opt): + file_name = self.option_file_path(opt, makedir=True) + with open(file_name + '.txt', 'wt') as opt_file: + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + opt_file.write('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)) + + with open(file_name + '.pkl', 'wb') as opt_file: + pickle.dump(opt, opt_file) + + def update_options_from_file(self, parser, opt): + new_opt = self.load_options(opt) + for k, v in sorted(vars(opt).items()): + if hasattr(new_opt, k) and v != getattr(new_opt, k): + new_val = getattr(new_opt, k) + parser.set_defaults(**{k: new_val}) + return parser + + def load_options(self, opt): + file_name = self.option_file_path(opt, makedir=False) + new_opt = pickle.load(open(file_name + '.pkl', 'rb')) + return new_opt + + def parse(self, save=False): + + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + self.print_options(opt) + if opt.isTrain and save: + 
self.save_options(opt) + + self.opt = opt + return self.opt diff --git a/options/train_options.py b/options/train_options.py new file mode 100755 index 0000000..b6e75f0 --- /dev/null +++ b/options/train_options.py @@ -0,0 +1,39 @@ +from .base_options import BaseOptions + + +class TrainOptions(BaseOptions): + def initialize(self, parser): + BaseOptions.initialize(self, parser) + # for displays + parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') + parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs') + parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') + parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed') + + # for training + parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') + parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') + parser.add_argument('--niter', type=int, default=50, help='# of iters at starting learning rate. This is NOT the total #epochs. Total #epochs is niter + niter_decay') + parser.add_argument('--niter_decay', type=int, default=50, help='# of iters to linearly decay learning rate to zero') + parser.add_argument('--optimizer', type=str, default='adam') + parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') + parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam') + parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') + parser.add_argument('--D_steps_per_G', type=int, default=1, help='number of discriminator iterations per generator iteration') + parser.add_argument('--G_steps_per_D', type=int, default=1, help='number of generator iterations per discriminator iteration') + + # for discriminators + parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') + parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss') + parser.add_argument('--lambda_vgg', type=float, default=10.0, help='weight for vgg loss') + parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss') + parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss') + parser.add_argument('--no_l1feat_loss', action='store_true', help='if specified, do not use perceptual loss') + parser.add_argument('--l1pix_loss', action='store_true', help='if specified, use l1 loss on image pixels') + parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)') + parser.add_argument('--netD', type=str, default='multiscale', help='(n_layers|multiscale)') + parser.add_argument('--no_TTUR', action='store_true', help='if specified, do *not* use the TTUR training scheme') + parser.add_argument('--no_disc', action='store_true', help='no discriminator') + + self.isTrain = True + return parser diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..ab19c50 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +dominate>=2.3.1 +dill +scikit-image diff --git a/test.py b/test.py new file mode 100755 index 0000000..ca27468 --- /dev/null +++ b/test.py 
@@ -0,0 +1,104 @@ +import os +from collections import OrderedDict + +import data +from options.train_options import TrainOptions +from models.OpenEdit_model import OpenEditModel +from trainers.OpenEdit_optimizer import OpenEditOptimizer +from util.visualizer import Visualizer +from util import html +import torch +import torch.utils.data as data +import torchvision.transforms as transforms +from PIL import Image +import numpy as np +import json +import pickle +from util.vocab import Vocabulary + +train_options = TrainOptions() +opt = train_options.parse() +opt.gpu = 0 + +ori_cap = opt.ori_cap.split() +new_cap = opt.new_cap.split() +global_edit = False + +alpha = 5 +optimize_iter = 10 + +opt.world_size = 1 +opt.rank = 0 +opt.mpdist = False +opt.num_gpu = 1 +opt.batchSize = 1 +opt.manipulation = True +opt.perturbation = True + +open_edit_optimizer = OpenEditOptimizer(opt) +open_edit_optimizer.open_edit_model.netG.eval() + +# optimizer +visualizer = Visualizer(opt, rank=0) + +# create a webpage that summarizes all the results +web_dir = os.path.join('visual_results', opt.name, + '%s_%s' % (opt.phase, opt.which_epoch)) +webpage = html.HTML(web_dir, + 'Experiment = %s, Phase = %s, Epoch = %s' % + (opt.name, opt.phase, opt.which_epoch)) + +# image loader +transforms_list = [] +transforms_list.append(transforms.Resize((opt.img_size, opt.img_size))) +transforms_list += [transforms.ToTensor()] +transforms_list += [transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))] +transform = transforms.Compose(transforms_list) +image = Image.open(opt.img_path).convert('RGB') +image = transform(image) +image = image.unsqueeze(0).cuda() + +# text loader +vocab = pickle.load(open('vocab/'+opt.dataset_mode+'_vocab.pkl', 'rb')) +ori_txt = [] +ori_txt.append(vocab('<start>')) +for word in ori_cap: + ori_txt.append(vocab(word)) +ori_txt.append(vocab('<end>')) +ori_txt = torch.LongTensor(ori_txt).unsqueeze(0).cuda() +new_txt = [] +new_txt.append(vocab('<start>')) +for word in new_cap: + new_txt.append(vocab(word)) +new_txt.append(vocab('<end>')) +new_txt = torch.LongTensor(new_txt).unsqueeze(0).cuda() + +data = {'image': image, 'caption': new_txt, 'length': [new_txt.size(1)]} + +# save input image +visuals = OrderedDict([('input_image', image[0])]) + +# reconstruct original image +reconstructed = open_edit_optimizer.open_edit_model(data, mode='inference')[0] +visuals['reconstructed'] = reconstructed + +# manipulate without optimizing perturbations +manipulated_ori = open_edit_optimizer.open_edit_model(data, mode='manipulate', ori_cap=ori_txt, new_cap=new_txt, alpha=alpha) +visuals['manipulated_ori'] = manipulated_ori[0][0] + +# optimize perturbations +for iter_cnt in range(optimize_iter): + open_edit_optimizer.run_opt_one_step(data, ori_txt, new_txt, alpha, global_edit=global_edit) + message = '(optimization, iters: %d) ' % iter_cnt + errors = open_edit_optimizer.get_latest_losses() + for k, v in errors.items(): + v = v.mean().float() + message += '%s: %.3f ' % (k, v) + print(message) + +# manipulation results after optimizing perturbations +visuals['optimized'] = open_edit_optimizer.get_latest_generated()[0] + + +visualizer.save_images(webpage, visuals, [opt.img_path], gray=True) +webpage.save() diff --git a/test.sh b/test.sh new file mode 100644 index 0000000..bf01a62 --- /dev/null +++ b/test.sh @@ -0,0 +1,8 @@ +python test.py \ + --name OpenEdit \ + --img_path test_imgs/car_blue.jpg \ + --ori_cap 'red car' \ + --new_cap 'blue car' \ + --edge_level 41 \ + --lr 0.002 \ + --which_epoch latest diff --git 
a/train.py b/train.py new file mode 100755 index 0000000..15c3610 --- /dev/null +++ b/train.py @@ -0,0 +1,87 @@ +import sys +from collections import OrderedDict +import data +import torch.multiprocessing as mp +import torch.distributed as dist +import torch +from util.iter_counter import IterationCounter +from util.visualizer import Visualizer +from trainers.OpenEdit_trainer import OpenEditTrainer +from options.train_options import TrainOptions + +def main_worker(gpu, world_size, opt): + print('Use GPU: {} for training'.format(gpu)) + world_size = opt.world_size + rank = gpu + opt.gpu = gpu + dist.init_process_group(backend='nccl', init_method=opt.dist_url, world_size=world_size, rank=rank) + torch.cuda.set_device(gpu) + + # load the dataset + dataloader = data.create_dataloader(opt, world_size, rank) + + # create trainer for our model + trainer = OpenEditTrainer(opt) + + # create tool for counting iterations + iter_counter = IterationCounter(opt, len(dataloader), world_size, rank) + + # create tool for visualization + visualizer = Visualizer(opt, rank) + + for epoch in iter_counter.training_epochs(): + if opt.mpdist: + dataloader.sampler.set_epoch(epoch) + iter_counter.record_epoch_start(epoch) + for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): + iter_counter.record_one_iteration() + + # Training + # train generator + if i % opt.D_steps_per_G == 0: + trainer.run_generator_one_step(data_i) + + # train discriminator + if not opt.no_disc and i % opt.G_steps_per_D == 0: + trainer.run_discriminator_one_step(data_i) + + iter_counter.record_iteration_end() + + # Visualizations + if iter_counter.needs_printing(): + losses = trainer.get_latest_losses() + visualizer.print_current_errors(epoch, iter_counter.epoch_iter, + losses, iter_counter.time_per_iter, + iter_counter.model_time_per_iter) + visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far) + + + visuals = OrderedDict([('synthesized_image', trainer.get_latest_generated()), + ('real_image', data_i['image'])]) + visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far) + + if rank == 0: + print('saving the latest model (epoch %d, total_steps %d)' % + (epoch, iter_counter.total_steps_so_far)) + trainer.save('latest') + iter_counter.record_current_iter() + + trainer.update_learning_rate(epoch) + iter_counter.record_epoch_end() + + if (epoch % opt.save_epoch_freq == 0 or epoch == iter_counter.total_epochs) and (rank == 0): + print('saving the model at the end of epoch %d, iters %d' % + (epoch, iter_counter.total_steps_so_far)) + trainer.save(epoch) + + print('Training was successfully finished.') + +if __name__ == '__main__': + global TrainOptions + TrainOptions = TrainOptions() + opt = TrainOptions.parse(save=True) + opt.world_size = opt.num_gpu + opt.mpdist = True + + mp.set_start_method('spawn', force=True) + mp.spawn(main_worker, nprocs=opt.world_size, args=(opt.world_size, opt)) diff --git a/train.sh b/train.sh new file mode 100644 index 0000000..eb1640d --- /dev/null +++ b/train.sh @@ -0,0 +1,5 @@ +python train.py \ + --name Opendit \ + --edge_level 41 \ + --nThreads 4 \ + --batchSize 4 diff --git a/trainers/OpenEdit_optimizer.py b/trainers/OpenEdit_optimizer.py new file mode 100644 index 0000000..c9c6b18 --- /dev/null +++ b/trainers/OpenEdit_optimizer.py @@ -0,0 +1,35 @@ +from models.OpenEdit_model import OpenEditModel +import torch +import torch.nn as nn + +class OpenEditOptimizer(): + """ + Optimizer perturbation parameters during testing + """ + + def __init__(self, opt): + 
self.opt = opt + self.open_edit_model = OpenEditModel(opt) + self.open_edit_model_on_one_gpu = self.open_edit_model + + self.generated = None + + self.optimizer_P = self.open_edit_model_on_one_gpu.create_P_optimizers(opt) + self.old_lr = opt.lr + + def run_opt_one_step(self, data, ori_cap, new_cap, alpha, global_edit=False): + self.optimizer_P.zero_grad() + r_losses, generated = self.open_edit_model( + data, mode='optimize', ori_cap=ori_cap, new_cap=new_cap, + alpha=alpha, global_edit=global_edit) + r_loss = sum(r_losses.values()).mean() + r_loss.backward() + self.optimizer_P.step() + self.r_losses = r_losses + self.generated = generated + + def get_latest_losses(self): + return {**self.r_losses} + + def get_latest_generated(self): + return self.generated diff --git a/trainers/OpenEdit_trainer.py b/trainers/OpenEdit_trainer.py new file mode 100644 index 0000000..e920958 --- /dev/null +++ b/trainers/OpenEdit_trainer.py @@ -0,0 +1,42 @@ +from models.OpenEdit_model import OpenEditModel +import torch + +class OpenEditTrainer(): + """ + Trainer creates the model and optimizers, and uses them to + updates the weights of the network while reporting losses + and the latest visuals to visualize the progress in training. + """ + + def __init__(self, opt): + self.opt = opt + self.open_edit_model = OpenEditModel(opt) + if opt.mpdist: + self.open_edit_model = torch.nn.parallel.DistributedDataParallel(self.open_edit_model, device_ids=[opt.gpu], find_unused_parameters=True) + self.open_edit_model_on_one_gpu = self.open_edit_model.module + else: + self.open_edit_model_on_one_gpu = self.open_edit_model + + self.generated = None + self.loss_gan = not opt.no_disc + if opt.isTrain: + self.optimizer_G, self.optimizer_D = \ + self.open_edit_model_on_one_gpu.create_optimizers(opt) + self.old_lr = opt.lr + + def run_generator_one_step(self, data): + self.optimizer_G.zero_grad() + g_losses, generated = self.open_edit_model(data, mode='generator') + g_loss = sum(g_losses.values()).mean() + g_loss.backward() + self.optimizer_G.step() + self.g_losses = g_losses + self.generated = generated + + def run_discriminator_one_step(self, data): + self.optimizer_D.zero_grad() + d_losses = self.open_edit_model(data, mode='discriminator') + d_loss = sum(d_losses.values()).mean() + d_loss.backward() + self.optimizer_D.step() + self.d_losses = d_losses diff --git a/trainers/__init__.py b/trainers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/util/__init__.py b/util/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/util/html.py b/util/html.py new file mode 100755 index 0000000..79a64aa --- /dev/null +++ b/util/html.py @@ -0,0 +1,72 @@ +import datetime +import dominate +from dominate.tags import * +import os + + +class HTML: + def __init__(self, web_dir, title, refresh=0): + if web_dir.endswith('.html'): + web_dir, html_name = os.path.split(web_dir) + else: + web_dir, html_name = web_dir, 'index.html' + self.title = title + self.web_dir = web_dir + self.html_name = html_name + self.img_dir = os.path.join(self.web_dir, 'images') + if len(self.web_dir) > 0 and not os.path.exists(self.web_dir): + os.makedirs(self.web_dir) + if len(self.web_dir) > 0 and not os.path.exists(self.img_dir): + os.makedirs(self.img_dir) + + self.doc = dominate.document(title=title) + with self.doc: + h1(datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")) + if refresh > 0: + with self.doc.head: + meta(http_equiv="refresh", content=str(refresh)) + + def get_image_dir(self): + return self.img_dir + + def 
add_header(self, str): + with self.doc: + h3(str) + + def add_table(self, border=1): + self.t = table(border=border, style="table-layout: fixed;") + self.doc.add(self.t) + + def add_images(self, ims, txts, links, width=512): + self.add_table() + with self.t: + with tr(): + for im, txt, link in zip(ims, txts, links): + with td(style="word-wrap: break-word;", halign="center", valign="top"): + with p(): + with a(href=os.path.join('images', link)): + img(style="width:%dpx" % (width), src=os.path.join('images', im)) + br() + p(txt.encode('utf-8')) + + def save(self): + #html_file = '%s/%s' % (self.web_dir, self.html_name) + html_file = os.path.join(self.web_dir, self.html_name) + f = open(html_file, 'wt') + f.write(self.doc.render()) + f.close() + + +if __name__ == '__main__': + html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims = [] + txts = [] + links = [] + for n in range(4): + ims.append('image_%d.jpg' % n) + txts.append('text_%d' % n) + links.append('image_%d.jpg' % n) + html.add_images(ims, txts, links) + html.save() diff --git a/util/iter_counter.py b/util/iter_counter.py new file mode 100644 index 0000000..afdf34b --- /dev/null +++ b/util/iter_counter.py @@ -0,0 +1,71 @@ +import os +import time +import numpy as np + + +## Helper class that keeps track of training iterations +class IterationCounter(): + def __init__(self, opt, dataset_size, world_size=1, rank=0): + self.opt = opt + self.dataset_size = dataset_size + + self.first_epoch = 1 + self.total_epochs = opt.niter + opt.niter_decay + self.epoch_iter = 0 # iter number within each epoch + self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, 'iter.txt') + if opt.isTrain and opt.continue_train: + try: + self.first_epoch, self.epoch_iter = np.loadtxt( + self.iter_record_path, delimiter=',', dtype=int) + print('Resuming from epoch %d at iteration %d' % (self.first_epoch, self.epoch_iter)) + except: + print('Could not load iteration record at %s. Starting from beginning.' % + self.iter_record_path) + + self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter + + self.world_size = world_size + self.rank = rank + + # return the iterator of epochs for the training + def training_epochs(self): + return range(self.first_epoch, self.total_epochs + 1) + + def record_epoch_start(self, epoch): + self.epoch_start_time = time.time() + self.epoch_iter = 0 + self.last_iter_time = time.time() + self.current_epoch = epoch + + def record_one_iteration(self): + current_time = time.time() + + self.time_per_iter = (current_time - self.last_iter_time) / (self.opt.batchSize * self.world_size) + self.last_iter_time = current_time + self.total_steps_so_far += self.opt.batchSize * self.world_size + self.epoch_iter += self.opt.batchSize * self.world_size + + def record_iteration_end(self): + current_time = time.time() + + self.model_time_per_iter = (current_time - self.last_iter_time) / (self.opt.batchSize * self.world_size) + + def record_epoch_end(self): + current_time = time.time() + self.time_per_epoch = current_time - self.epoch_start_time + if self.rank == 0: + print('End of epoch %d / %d \t Time Taken: %d sec' % + (self.current_epoch, self.total_epochs, self.time_per_epoch)) + if self.current_epoch % self.opt.save_epoch_freq == 0: + np.savetxt(self.iter_record_path, (self.current_epoch+1, 0), + delimiter=',', fmt='%d') + print('Saved current iteration count at %s.' 
% self.iter_record_path) + + def record_current_iter(self): + if self.rank == 0: + np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter), + delimiter=',', fmt='%d') + print('Saved current iteration count at %s.' % self.iter_record_path) + + def needs_printing(self): + return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize * self.world_size diff --git a/util/util.py b/util/util.py new file mode 100755 index 0000000..7d09177 --- /dev/null +++ b/util/util.py @@ -0,0 +1,100 @@ +import re +import importlib +import torch +from argparse import Namespace +import numpy as np +from PIL import Image +import os +import argparse + + +# Converts a Tensor into a Numpy array +# |imtype|: the desired type of the converted numpy array +def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False): + if isinstance(image_tensor, list): + image_numpy = [] + for i in range(len(image_tensor)): + image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) + return image_numpy + + if image_tensor.dim() == 4: + # transform each image in the batch + images_np = [] + for b in range(image_tensor.size(0)): + one_image = image_tensor[b] + one_image_np = tensor2im(one_image) + images_np.append(one_image_np.reshape(1, *one_image_np.shape)) + images_np = np.concatenate(images_np, axis=0) + if tile: + images_tiled = tile_images(images_np) + return images_tiled + else: + return images_np + + if image_tensor.dim() == 2: + image_tensor = image_tensor.unsqueeze(0) + image_numpy = image_tensor.detach().cpu().float().numpy() + if normalize: + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + else: + image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 + image_numpy = np.clip(image_numpy, 0, 255) + if image_numpy.shape[2] == 1: + image_numpy = image_numpy[:,:,0] + return image_numpy.astype(imtype) + +def save_image(image_numpy, image_path, create_dir=False, gray=False): + if create_dir: + os.makedirs(os.path.dirname(image_path), exist_ok=True) + ## save to png + if gray: + if len(image_numpy.shape) == 3: + assert(image_numpy.shape[2] == 1) + image_numpy = image_numpy.squeeze(2) + image_pil = Image.fromarray(image_numpy) + image_pil.save(image_path.replace('.jpg', '.png')) + else: + if len(image_numpy.shape) == 2 and not gray: + image_numpy = np.expand_dims(image_numpy, axis=2) + if image_numpy.shape[2] == 1 and not gray: + image_numpy = np.repeat(image_numpy, 3, 2) + image_pil = Image.fromarray(image_numpy) + image_pil.save(image_path.replace('.jpg', '.png')) + +def mkdirs(paths): + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + +def find_class_in_module(target_cls_name, module): + target_cls_name = target_cls_name.replace('_', '').lower() + clslib = importlib.import_module(module) + cls = None + for name, clsobj in clslib.__dict__.items(): + if name.lower() == target_cls_name: + cls = clsobj + + if cls is None: + print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)) + exit(0) + + return cls + +def save_network(net, label, epoch, opt): + save_filename = '%s_net_%s.pth' % (epoch, label) + save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename) + torch.save(net.state_dict(), save_path) # net.cpu() -> net + +def load_network(net, label, epoch, opt): + save_filename = '%s_net_%s.pth' % (epoch, label) + save_dir = 
os.path.join(opt.checkpoints_dir, opt.name) + save_path = os.path.join(save_dir, save_filename) + weights = torch.load(save_path, map_location=torch.device('cpu')) + net.load_state_dict(weights) + return net diff --git a/util/visualizer.py b/util/visualizer.py new file mode 100755 index 0000000..049a3d3 --- /dev/null +++ b/util/visualizer.py @@ -0,0 +1,176 @@ +import os +import ntpath +import time +from . import util +from . import html +import scipy.misc +try: + from StringIO import StringIO # Python 2.7 +except ImportError: + from io import BytesIO # Python 3.x + +class Visualizer(): + def __init__(self, opt, rank=0): + self.rank = rank + self.opt = opt + self.tf_log = opt.isTrain and opt.tf_log + self.use_html = opt.isTrain and not opt.no_html + self.win_size = opt.display_winsize + self.name = opt.name + if self.tf_log: + import tensorflow as tf + self.tf = tf + self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs') + if self.rank == 0: + self.writer = tf.summary.FileWriter(self.log_dir) + + if self.use_html: + self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') + self.img_dir = os.path.join(self.web_dir, 'images') + if self.rank == 0: + print('create web directory %s...' % self.web_dir) + util.mkdirs([self.web_dir, self.img_dir]) + if opt.isTrain: + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + if self.rank == 0: + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + # |visuals|: dictionary of images to display or save + def display_current_results(self, visuals, epoch, step): + + ## convert tensors to numpy arrays + visuals = self.convert_visuals_to_numpy(visuals) + + if 0: # do not show images in tensorboard output + img_summaries = [] + for label, image_numpy in visuals.items(): + # Write the image to a string + try: + s = StringIO() + except: + s = BytesIO() + if len(image_numpy.shape) >= 4: + image_numpy = image_numpy[0] + if self.rank == 0: + scipy.misc.toimage(image_numpy).save(s, format="jpeg") + # Create an Image object + img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1]) + # Create a Summary value + img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum)) + + if self.rank == 0: + # Create and write Summary + summary = self.tf.Summary(value=img_summaries) + self.writer.add_summary(summary, step) + + if self.use_html: # save images to a html file + for label, image_numpy in visuals.items(): + if isinstance(image_numpy, list): + for i in range(len(image_numpy)): + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.png' % (epoch, label, i)) + if self.rank == 0: + util.save_image(image_numpy[i], img_path) + else: + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) + if len(image_numpy.shape) >= 4: + image_numpy = image_numpy[0] + if self.rank == 0: + util.save_image(image_numpy, img_path) + + if self.rank == 0: + # update website + webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=5) + for n in range(epoch, 0, -1): + webpage.add_header('epoch [%d]' % n) + ims = [] + txts = [] + links = [] + + for label, image_numpy in visuals.items(): + if isinstance(image_numpy, list): + for i in range(len(image_numpy)): + img_path = 'epoch%.3d_%s_%d.png' % (n, label, i) + ims.append(img_path) + txts.append(label+str(i)) + links.append(img_path) + else: + img_path = 'epoch%.3d_%s.png' % (n, label) 
+ ims.append(img_path) + txts.append(label) + links.append(img_path) + if len(ims) < 10: + webpage.add_images(ims, txts, links, width=self.win_size) + else: + num = int(round(len(ims)/2.0)) + webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size) + webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size) + webpage.save() + + # errors: dictionary of error labels and values + def plot_current_errors(self, errors, step): + if self.tf_log: + for tag, value in errors.items(): + value = value.mean().float() + if self.rank == 0: + summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) + self.writer.add_summary(summary, step) + + # errors: same format as |errors| of plotCurrentErrors + def print_current_errors(self, epoch, i, errors, t, mt): + message = '(epoch: %d, iters: %d, time: %.3f, model_time: %.3f) ' % (epoch, i, t, mt) + for k, v in errors.items(): + #print(v) + #if v != 0: + v = v.mean().float() + message += '%s: %.3f ' % (k, v) + + if self.rank == 0: + print(message) + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) + + def convert_visuals_to_numpy(self, visuals, gray=False): + for key, t in visuals.items(): + tile = self.opt.batchSize > 1 + if 'input_label' == key and not gray: + t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile) + elif 'input_label' == key and gray: + t = util.tensor2labelgray(t, self.opt.label_nc + 2, tile=tile) + else: + t = util.tensor2im(t, tile=tile) + visuals[key] = t + return visuals + + # save image to the disk + def save_images(self, webpage, visuals, image_path, gray=False): + visuals = self.convert_visuals_to_numpy(visuals, gray=gray) + + image_dir = webpage.get_image_dir() + short_path = ntpath.basename(image_path[0]) + name = os.path.splitext(short_path)[0] + + if self.rank == 0: + webpage.add_header(name) + ims = [] + txts = [] + links = [] + + cnt = 0 + for label, image_numpy in visuals.items(): + image_name = os.path.join(label, '%s.png' % (name)) + save_path = os.path.join(image_dir, image_name) + if self.rank == 0: + util.save_image(image_numpy, save_path, create_dir=True) + + ims.append(image_name) + txts.append(label) + links.append(image_name) + cnt += 1 + if cnt % 4 == 0: + if self.rank == 0: + webpage.add_images(ims, txts, links, width=self.win_size) + ims = [] + txts = [] + links = [] diff --git a/util/vocab.py b/util/vocab.py new file mode 100644 index 0000000..4ad2412 --- /dev/null +++ b/util/vocab.py @@ -0,0 +1,78 @@ +# Create a vocabulary wrapper +import nltk +import pickle +from collections import Counter +import json +import argparse +import os + +annotations = { + 'conceptual': ['train_caption.json'] +} + + +class Vocabulary(object): + """Simple vocabulary wrapper.""" + + def __init__(self): + self.word2idx = {} + self.idx2word = {} + self.idx = 0 + + def add_word(self, word): + if word not in self.word2idx: + self.word2idx[word] = self.idx + self.idx2word[self.idx] = word + self.idx += 1 + + def __call__(self, word): + if word not in self.word2idx: + return self.word2idx['<unk>'] + return self.word2idx[word] + + def __len__(self): + return len(self.word2idx) + + +def build_vocab(data_path, data_name, jsons, threshold): + """Build a simple vocabulary wrapper.""" + counter = Counter() + for path in jsons[data_name]: + full_path = os.path.join(os.path.join(os.path.join(data_path, data_name), 'img_lists'), path) + captions = json.load(open(full_path,'r')) + for i, caption in enumerate(captions): + tokens = 
nltk.tokenize.word_tokenize(caption.lower()) + counter.update(tokens) + + if i % 1000 == 0: + print("[%d/%d] tokenized the captions." % (i, len(captions))) + + # Discard a word if its occurrence count is less than the threshold. + words = [word for word, cnt in counter.items() if cnt >= threshold] + + # Create a vocab wrapper and add some special tokens. + vocab = Vocabulary() + vocab.add_word('<pad>') + vocab.add_word('<start>') + vocab.add_word('<end>') + vocab.add_word('<unk>') + + # Add words to the vocabulary. + for i, word in enumerate(words): + vocab.add_word(word) + return vocab + + +def main(data_path, data_name): + vocab = build_vocab(data_path, data_name, jsons=annotations, threshold=0) + with open('./vocab/%s_vocab.pkl' % data_name, 'wb') as f: + pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL) + print("Saved vocabulary file to ", './vocab/%s_vocab.pkl' % data_name) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--data_path', default='data') + parser.add_argument('--data_name', default='conceptual') + opt = parser.parse_args() + main(opt.data_path, opt.data_name)
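
For reference, the snippet below is a minimal sketch of how test.py turns an instruction into the token-id tensor fed to the model: the caption is wrapped with `<start>`/`<end>` and each word is looked up in the pickled `Vocabulary`, with out-of-vocabulary words falling back to `<unk>`. It assumes the pretrained vocabulary has been downloaded to `vocab/conceptual_vocab.pkl` (the default `--vocab_path`) and that the special tokens follow the usual VSE++ convention; it is illustrative only, not part of the repository.

```python
import pickle
import torch
from util.vocab import Vocabulary  # imported so pickle can resolve the Vocabulary class

# Assumption: the pretrained vocabulary pickle sits at the default --vocab_path.
vocab = pickle.load(open('vocab/conceptual_vocab.pkl', 'rb'))

def encode_caption(caption):
    # Wrap the caption with <start>/<end> and map each word to its index;
    # words missing from the vocabulary fall back to <unk> in Vocabulary.__call__.
    ids = [vocab('<start>')] + [vocab(w) for w in caption.split()] + [vocab('<end>')]
    return torch.LongTensor(ids).unsqueeze(0)  # shape (1, num_tokens)

ori_txt = encode_caption('red car')
new_txt = encode_caption('blue car')
```

If the pickle needs to be rebuilt rather than downloaded, `util/vocab.py` writes it to `./vocab/` from the caption list in `train_caption.json` (e.g. `python util/vocab.py --data_path <your data root> --data_name conceptual`).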