diff --git a/.gitignore b/.gitignore
new file mode 100755
index 0000000..ff850b5
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,13 @@
+checkpoints/
+datasets/
+results/
+*.tar.gz
+*.pth
+*.zip
+*.pkl
+*.pyc
+*/__pycache__/
+*_example/
+visual_results/
+test_imgs/
+web/
diff --git a/README.md b/README.md
index 0c6c6b8..2d6b6c8 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,51 @@
# Open-Edit: Open-Domain Image Manipulation with Open-Vocabulary Instructions
-[Xihui Liu](https://xh-liu.github.io), Zhe Lin, Jianming Zhang, Handong Zhao, Quan Tran, Xiaogang Wang, and Hongsheng Li.
+[Xihui Liu](https://xh-liu.github.io), [Zhe Lin](https://sites.google.com/site/zhelin625/), [Jianming Zhang](http://cs-people.bu.edu/jmzhang/), [Handong Zhao](https://hdzhao.github.io/), [Quan Tran](https://research.adobe.com/person/quan-hung-tran/), [Xiaogang Wang](https://www.ee.cuhk.edu.hk/~xgwang/), and [Hongsheng Li](https://www.ee.cuhk.edu.hk/~hsli/).
Published in ECCV 2020.
-### [Paper](https://arxiv.org/pdf/2008.01576.pdf) | [1-minute video](https://youtu.be/8E3bwvjCHYE)
+### [Paper](https://arxiv.org/pdf/2008.01576.pdf) | [1-minute video](https://youtu.be/8E3bwvjCHYE) | [Slides](https://drive.google.com/file/d/1m3JKSUotm6sRImak_qjwBMtMtd037XeK/view?usp=sharing)

-### Code Coming Soon!
+### Installation
+Clone this repo.
+```bash
+git clone https://github.com/xh-liu/Open-Edit
+cd Open-Edit
+```
+
+Install [PyTorch 1.1+](https://pytorch.org/get-started/locally/) and other requirements.
+```bash
+pip install -r requirements.txt
+```
+
+### Download pretrained models
+
+Download the pretrained models from [Google Drive](https://drive.google.com/drive/folders/1iG_II7_PytTY6NdzyZ5WDkzPTXcB2NcE?usp=sharing) and place them under the `checkpoints/` folder, where the default model paths in `options/base_options.py` expect them.
+
+### Data preparation
+
+We use the [Conceptual Captions dataset](https://ai.google.com/research/ConceptualCaptions/download) for training. Download the dataset and put it under the `datasets` folder. You can also use other datasets.
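+
+Based on `data/conceptual_dataset.py`, the dataloader expects an `images/` subfolder and an image index file inside the data root, so a layout along these lines is assumed (paths follow the defaults in `options/base_options.py`):
+
+```
+datasets/conceptual/
+    images/           # downloaded Conceptual Captions images
+    val_index.json    # list of image filenames read by the dataloader
+```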
+
+### Training
+
+The visual-semantic embedding model is trained with [VSE++](https://github.com/fartashf/vsepp).
+
+The image decoder is trained with:
+
+```bash
+bash train.sh
+```
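+
+`train.sh` is not included in this diff; a minimal sketch of an equivalent invocation, assuming the entry point is `train.py` and using only options defined under `options/`, could be:
+
+```bash
+python train.py --name open-edit --dataset_mode conceptual --dataroot ./datasets/conceptual/ \
+    --batchSize 8 --niter 50 --niter_decay 50 --l1pix_loss
+```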
+
+### Testing
+
+You can specify the image path and text instructions in test.sh.
+
+```bash
+bash test.sh
+```
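+
+As a rough guide, the manipulation-related options in `options/base_options.py` suggest a call of the following form (the image path and captions below are placeholders):
+
+```bash
+python test.py --img_path test_imgs/example.jpg \
+    --ori_cap "a white car" --new_cap "a red car" \
+    --alpha 5 --manipulation --perturbation --optimize_iter 50
+```
+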
### Citation
If you use this code for your research, please cite our papers.
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100755
index 0000000..ae1e636
--- /dev/null
+++ b/data/__init__.py
@@ -0,0 +1,60 @@
+import importlib
+import torch.utils.data
+
+def collate_fn_img(images):
+ images = torch.stack(images, 0)
+ input_dicts = {'image': images}
+ return input_dicts
+
+
+def find_dataset_using_name(dataset_name):
+    # Given the option --dataset_mode [datasetname],
+ # the file "datasets/datasetname_dataset.py"
+ # will be imported.
+ dataset_filename = "data." + dataset_name + "_dataset"
+ datasetlib = importlib.import_module(dataset_filename)
+
+ # In the file, the class called DatasetNameDataset() will
+ # be instantiated. It has to be a subclass of BaseDataset,
+ # and it is case-insensitive.
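+    # e.g. --dataset_mode conceptual -> data/conceptual_dataset.py -> ConceptualDataset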
+ dataset = None
+ target_dataset_name = dataset_name.replace('_', '') + 'dataset'
+ for name, cls in datasetlib.__dict__.items():
+ if name.lower() == target_dataset_name.lower():
+ dataset = cls
+
+ if dataset is None:
+ raise ValueError("In %s.py, there should be a subclass of BaseDataset "
+ "with class name that matches %s in lowercase." %
+ (dataset_filename, target_dataset_name))
+
+ return dataset
+
+def create_dataloader(opt, world_size, rank):
+ dataset = find_dataset_using_name(opt.dataset_mode)
+ instance = dataset(opt)
+ print("dataset [%s] of size %d was created" %
+ (type(instance).__name__, len(instance)))
+
+ collate_fn = collate_fn_img
+
+ if opt.mpdist:
+ train_sampler = torch.utils.data.distributed.DistributedSampler(instance, num_replicas=world_size, rank=rank)
+ dataloader = torch.utils.data.DataLoader(
+ instance,
+ batch_size=opt.batchSize,
+ sampler=train_sampler,
+ shuffle=False,
+ num_workers=int(opt.nThreads),
+ collate_fn=collate_fn,
+ drop_last=opt.isTrain
+ )
+ else:
+ dataloader = torch.utils.data.DataLoader(
+ instance,
+ batch_size=opt.batchSize,
+ shuffle=not opt.serial_batches,
+ num_workers=int(opt.nThreads),
+            collate_fn=collate_fn,
+            drop_last=opt.isTrain
+ )
+ return dataloader
diff --git a/data/conceptual_dataset.py b/data/conceptual_dataset.py
new file mode 100644
index 0000000..cea6a47
--- /dev/null
+++ b/data/conceptual_dataset.py
@@ -0,0 +1,32 @@
+import torch
+import torch.utils.data as data
+import torchvision.transforms as transforms
+import os
+from PIL import Image
+import json
+
+class ConceptualDataset(data.Dataset):
+ def __init__(self, opt):
+ self.path = os.path.join(opt.dataroot, 'images')
+ if opt.isTrain:
+ self.ids = json.load(open(os.path.join(opt.dataroot, 'val_index.json'), 'r'))
+ else:
+ self.ids = json.load(open(os.path.join(opt.dataroot, 'val_index.json'), 'r'))
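+        # note: the train and test branches both read val_index.json here; point the
+        # isTrain branch at a separate training index to train on the full split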
+
+ transforms_list = []
+ transforms_list.append(transforms.Resize((opt.img_size, opt.img_size)))
+ transforms_list += [transforms.ToTensor()]
+ transforms_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+ self.transform = transforms.Compose(transforms_list)
+
+ def __getitem__(self, index):
+ """This function returns a tuple that is further passed to collate_fn
+ """
+ img_id = self.ids[index]
+ image = Image.open(os.path.join(self.path, img_id)).convert('RGB')
+ image = self.transform(image)
+
+ return image
+
+ def __len__(self):
+ return len(self.ids)
diff --git a/models/OpenEdit_model.py b/models/OpenEdit_model.py
new file mode 100644
index 0000000..e94db5d
--- /dev/null
+++ b/models/OpenEdit_model.py
@@ -0,0 +1,256 @@
+import torch
+import torch.nn.functional as F
+import models.networks as networks
+import util.util as util
+import pickle
+from models.networks.txt_enc import EncoderText
+from models.networks.perturbation import PerturbationNet
+import random
+
+def l2norm(x, norm_dim=1):
+ norm = torch.pow(x, 2).sum(dim=norm_dim, keepdim=True).sqrt()
+ x = torch.div(x, norm)
+ return x
+
+class OpenEditModel(torch.nn.Module):
+
+ def __init__(self, opt):
+ super().__init__()
+ self.opt = opt
+ self.FloatTensor = torch.cuda.FloatTensor
+ self.ByteTensor = torch.cuda.ByteTensor
+ self.perturbation = opt.perturbation
+
+ self.netG, self.netD, self.netE = self.initialize_networks(opt)
+
+ self.generator = opt.netG
+
+ self.noise_range = opt.noise_range
+ if self.perturbation:
+ self.netP = PerturbationNet(opt)
+ self.netP.cuda()
+ if self.opt.manipulation:
+ self.vocab = pickle.load(open(opt.vocab_path, 'rb'))
+ self.txt_enc = EncoderText(len(self.vocab), 300, 1024, 1)
+ self.txt_enc.load_state_dict(torch.load(opt.vse_enc_path, map_location='cpu')['model'][1])
+ self.txt_enc.eval().cuda()
+
+ # set loss functions
+ if self.perturbation:
+ self.criterionPix = torch.nn.L1Loss()
+ self.criterionVGG = networks.VGGLoss(self.opt.gpu)
+ elif opt.isTrain:
+ self.loss_l1pix = opt.l1pix_loss
+ self.loss_gan = not opt.no_disc
+ self.loss_ganfeat = not opt.no_ganFeat_loss
+ self.loss_vgg = not opt.no_vgg_loss
+
+ if self.loss_l1pix:
+ self.criterionPix = torch.nn.L1Loss()
+ if self.loss_gan:
+ self.criterionGAN = networks.GANLoss(
+ opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
+ if self.loss_ganfeat:
+ self.criterionFeat = torch.nn.L1Loss()
+ if self.loss_vgg:
+ self.criterionVGG = networks.VGGLoss(self.opt.gpu)
+
+ def forward(self, data, mode, ori_cap=None, new_cap=None, alpha=1, global_edit=False):
+ if not data['image'].is_cuda:
+ data['image'] = data['image'].cuda()
+ if mode == 'generator':
+ g_loss, generated = self.compute_generator_loss(data)
+ return g_loss, generated
+ elif mode == 'discriminator':
+ d_loss = self.compute_discriminator_loss(data)
+ return d_loss
+ elif mode == 'inference':
+ with torch.no_grad():
+ fake_image, _ = self.generate_fake(data['image'])
+ return fake_image
+ elif mode == 'manipulate':
+ with torch.no_grad():
+ fake_image = self.manipulate(data['image'], ori_cap, new_cap, alpha, global_edit=global_edit)
+ return fake_image
+ elif mode == 'optimize':
+ g_loss, generated = self.optimizeP(data, ori_cap, new_cap, alpha, global_edit=global_edit)
+ return g_loss, generated
+ else:
+ raise ValueError("|mode| is invalid")
+
+ def create_P_optimizers(self, opt):
+ P_params = list(self.netP.parameters())
+
+ beta1, beta2 = 0, 0.9
+ P_lr = opt.lr
+
+ optimizer_P = torch.optim.Adam(P_params, lr=P_lr, betas=(beta1, beta2))
+
+ return optimizer_P
+
+
+ def create_optimizers(self, opt):
+ G_params = list(self.netG.parameters())
+ if opt.isTrain and self.loss_gan:
+ D_params = list(self.netD.parameters())
+
+ if opt.no_TTUR:
+ beta1, beta2 = opt.beta1, opt.beta2
+ G_lr, D_lr = opt.lr, opt.lr
+ else:
+ beta1, beta2 = 0, 0.9
+ G_lr, D_lr = opt.lr / 2, opt.lr * 2
+
+ optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))
+ if self.loss_gan and opt.isTrain:
+ optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2))
+ else:
+ optimizer_D = None
+
+ return optimizer_G, optimizer_D
+
+ def save(self, epoch):
+ util.save_network(self.netG, 'G', epoch, self.opt)
+ if self.loss_gan:
+ util.save_network(self.netD, 'D', epoch, self.opt)
+
+ def initialize_networks(self, opt):
+ print(opt.isTrain)
+ netG = networks.define_G(opt)
+ netD = networks.define_D(opt) if opt.isTrain and not opt.no_disc else None
+ netE = networks.define_E(opt)
+
+ if not opt.isTrain or opt.continue_train or opt.manipulation:
+ netG = util.load_network(netG, 'G', opt.which_epoch, opt)
+ if opt.isTrain or opt.needs_D:
+ netD = util.load_network(netD, 'D', opt.which_epoch, opt)
+ print('network D loaded')
+
+ return netG, netD, netE
+
+ def compute_generator_loss(self, data):
+ G_losses = {}
+
+ fake_image, spatial_embedding = self.generate_fake(data['image'])
+
+ if self.loss_l1pix:
+ G_losses['Pix'] = self.criterionPix(fake_image, data['image'])
+ if self.loss_gan:
+ pred_fake, pred_real = self.discriminate(fake_image, data)
+
+ if self.loss_ganfeat:
+ actual_num_D = 0
+ num_D = len(pred_fake)
+ GAN_Feat_loss = self.FloatTensor(1).fill_(0)
+ for i in range(num_D): # for each discriminator
+ # last output is the final prediction, so we exclude it
+ num_intermediate_outputs = len(pred_fake[i]) - 1
+ if num_intermediate_outputs == 0:
+ continue
+ else:
+ actual_num_D += 1
+ for j in range(num_intermediate_outputs): # for each layer output
+ unweighted_loss = self.criterionFeat(
+ pred_fake[i][j], pred_real[i][j].detach())
+ GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat
+ G_losses['GAN_Feat'] = GAN_Feat_loss / actual_num_D
+
+ if self.loss_gan:
+ G_losses['GAN'] = self.criterionGAN(pred_fake, True,
+ for_discriminator=False)
+
+ if self.loss_vgg:
+ G_losses['VGG'] = self.criterionVGG(fake_image, data['image']) \
+ * self.opt.lambda_vgg
+
+ return G_losses, fake_image
+
+ def compute_discriminator_loss(self, data):
+ D_losses = {}
+ with torch.no_grad():
+ fake_image, spatial_embedding = self.generate_fake(data['image'])
+
+ pred_fake, pred_real = self.discriminate(fake_image, data)
+
+ D_losses['D_Fake'] = self.criterionGAN(pred_fake, False,
+ for_discriminator=True)
+ D_losses['D_real'] = self.criterionGAN(pred_real, True,
+ for_discriminator=True)
+
+ return D_losses
+
+ def generate_fake(self, real_image, return_skip = False, return_edge = False):
+
+ with torch.no_grad():
+ spatial_embedding, edge = self.netE(real_image)
+ fake_image = self.netG(spatial_embedding, edge)
+ if return_edge:
+ return fake_image, spatial_embedding, edge
+ else:
+ return fake_image, spatial_embedding
+
+ def generate_fake_withP(self, real_image):
+ with torch.no_grad():
+ spatial_embedding, edge = self.netE(real_image)
+ p, P_reg = self.netP()
+ fake_image = self.netG(spatial_embedding, edge, perturbation=True, p=p)
+
+ return fake_image, P_reg
+
+ def discriminate(self, fake_image, data, edge=None):
+
+ fake_and_real_img = torch.cat([fake_image, data['image']], dim=0)
+
+ discriminator_out = self.netD(fake_and_real_img)
+
+ pred_fake, pred_real = self.divide_pred(discriminator_out)
+
+ return pred_fake, pred_real
+
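+    # divide_pred() is referenced by discriminate() but not defined in this file;
+    # the version below is a minimal sketch following the SPADE convention of
+    # concatenating fake and real images along the batch dimension and splitting
+    # each discriminator output back into its fake and real halves.
+    def divide_pred(self, pred):
+        if type(pred) == list:
+            fake = []
+            real = []
+            for p in pred:
+                fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
+                real.append([tensor[tensor.size(0) // 2:] for tensor in p])
+        else:
+            fake = pred[:pred.size(0) // 2]
+            real = pred[pred.size(0) // 2:]
+        return fake, real
+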
+ def manipulate(self, real_image, ori_txt, new_txt, alpha, global_edit=False):
+
+ spatial_embedding, edge = self.netE(real_image, norm=True)
+
+ with torch.no_grad():
+ ori_txt = self.txt_enc(ori_txt, [ori_txt.shape[1]])
+ new_txt = self.txt_enc(new_txt, [new_txt.shape[1]])
+
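+        # spatial grounding: the per-location dot product between the visual feature map
+        # and the original text embedding gives a relevance map that localizes the edit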
+ proj = spatial_embedding * ori_txt.unsqueeze(2).unsqueeze(3).repeat(1,1,spatial_embedding.shape[2],spatial_embedding.shape[3])
+ proj_s = proj.sum(1, keepdim=True)
+ proj = proj_s.repeat(1,1024,1,1)
+ # proj = F.sigmoid(proj)
+
+ if global_edit:
+            proj[:] = 1  # global attributes do not need spatial grounding, so use a uniform relevance map
+
+ spatial_embedding = spatial_embedding - alpha * proj * ori_txt.unsqueeze(2).unsqueeze(3).repeat(1,1,spatial_embedding.shape[2],spatial_embedding.shape[3])
+ spatial_embedding = spatial_embedding + alpha * proj * new_txt.unsqueeze(2).unsqueeze(3).repeat(1,1,spatial_embedding.shape[2],spatial_embedding.shape[3])
+
+ spatial_embedding = l2norm(spatial_embedding)
+
+ if self.perturbation:
+ p, P_reg = self.netP()
+ fake_image = self.netG(spatial_embedding, edge, perturbation=True, p=p)
+ return fake_image
+ else:
+ fake_image = self.netG(spatial_embedding, edge)
+ return fake_image
+
+ def optimizeP(self, data, ori_cap, new_cap, alpha, global_edit=False):
+
+ P_losses = {}
+
+ fake_image, P_reg = self.generate_fake_withP(data['image'])
+
+ manip_image = self.manipulate(data['image'], ori_cap, new_cap, alpha, global_edit=global_edit)
+ cycle_image = self.manipulate(manip_image, new_cap, ori_cap, alpha, global_edit=global_edit)
+
+ P_losses['Pix'] = self.criterionPix(fake_image, data['image'])
+ P_losses['VGG'] = self.criterionVGG(fake_image, data['image']) * self.opt.lambda_vgg
+
+ P_losses['cyclePix'] = self.criterionPix(cycle_image, data['image'])
+ P_losses['cycleVGG'] = self.criterionVGG(cycle_image, data['image']) * self.opt.lambda_vgg
+
+ P_losses['Reg'] = P_reg
+
+ return P_losses, fake_image
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100755
index 0000000..112acce
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1,34 @@
+import importlib
+import torch
+
+
+def find_model_using_name(model_name):
+ # Given the option --model [modelname],
+ # the file "models/modelname_model.py"
+ # will be imported.
+ model_filename = "models." + model_name + "_model"
+ modellib = importlib.import_module(model_filename)
+
+ # In the file, the class called ModelNameModel() will
+ # be instantiated. It has to be a subclass of torch.nn.Module,
+ # and it is case-insensitive.
+ model = None
+ target_model_name = model_name.replace('_', '') + 'model'
+ for name, cls in modellib.__dict__.items():
+ if name.lower() == target_model_name.lower() \
+ and issubclass(cls, torch.nn.Module):
+ model = cls
+
+    if model is None:
+        raise ValueError("In %s.py, there should be a subclass of torch.nn.Module "
+                         "with class name that matches %s in lowercase." %
+                         (model_filename, target_model_name))
+
+ return model
+
+
+def create_model(opt):
+ model = find_model_using_name(opt.model)
+ instance = model(opt)
+ print("model [%s] was created" % (type(instance).__name__))
+
+ return instance
diff --git a/models/networks/__init__.py b/models/networks/__init__.py
new file mode 100644
index 0000000..8254fed
--- /dev/null
+++ b/models/networks/__init__.py
@@ -0,0 +1,59 @@
+import importlib
+import torch
+import torch.nn as nn
+import functools
+from torch.autograd import Variable
+import numpy as np
+import torch.nn.functional as F
+from models.networks.base_network import BaseNetwork
+from models.networks.loss import *
+from models.networks.discriminator import *
+from models.networks.generator import *
+from models.networks.encoder import *
+import util.util as util
+
+
+def find_network_using_name(target_network_name, filename):
+ target_class_name = target_network_name + filename
+ module_name = 'models.networks.' + filename
+ network = util.find_class_in_module(target_class_name, module_name)
+
+ assert issubclass(network, BaseNetwork), \
+ "Class %s should be a subclass of BaseNetwork" % network
+
+ return network
+
+def create_network(cls, opt, init=True, cuda=True):
+ net = cls(opt)
+ net.print_network()
+ assert(torch.cuda.is_available())
+ if cuda:
+ if opt.mpdist:
+ net.cuda(opt.gpu)
+ else:
+ net.cuda()
+ if init:
+ net.init_weights(opt.init_type, opt.init_variance)
+ return net
+
+def define_G(opt):
+ netG_cls = find_network_using_name(opt.netG, 'generator')
+ return create_network(netG_cls, opt)
+
+
+def define_D(opt):
+ netD_cls = find_network_using_name(opt.netD, 'discriminator')
+ return create_network(netD_cls, opt)
+
+
+def define_E(opt):
+ netE_cls = find_network_using_name(opt.netE, 'encoder')
+ netE = create_network(netE_cls, opt, init=False, cuda=False)
+ state_dict = torch.load(opt.vse_enc_path, map_location='cpu')
+ netE.load_state_dict(state_dict['model'][0])
+ if opt.mpdist:
+ netE.cuda(opt.gpu)
+ else:
+ netE.cuda()
+ netE.eval()
+ return netE
diff --git a/models/networks/architecture.py b/models/networks/architecture.py
new file mode 100644
index 0000000..31ef92e
--- /dev/null
+++ b/models/networks/architecture.py
@@ -0,0 +1,92 @@
+import math
+import re
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+import torchvision
+import torch.nn.utils.spectral_norm as spectral_norm
+from torch.nn import Parameter
+from models.networks.base_network import BaseNetwork
+from models.networks.normalization import get_norm_layer
+from models.networks.normalization import SPADE
+
+
+## VGG architecture, used for the perceptual loss using a pretrained VGG network
+class VGG19(torch.nn.Module):
+ def __init__(self, requires_grad=False, local_pretrained_path='checkpoints/vgg19.pth'):
+ super().__init__()
+ # if we have network access
+ # vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
+ # if no network access
+ model = torchvision.models.vgg19()
+ model.load_state_dict(torch.load(local_pretrained_path))
+ vgg_pretrained_features = model.features
+
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ for x in range(2):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(2, 7):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(7, 12):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(12, 21):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(21, 30):
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h_relu1 = self.slice1(X)
+ h_relu2 = self.slice2(h_relu1)
+ h_relu3 = self.slice3(h_relu2)
+ h_relu4 = self.slice4(h_relu3)
+ h_relu5 = self.slice5(h_relu4)
+ out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
+ return out
+
+class SPADEResnetBlock(nn.Module):
+ def __init__(self, fin, fout, opt):
+ super().__init__()
+ # Attributes
+ self.learned_shortcut = (fin != fout)
+ fmiddle = min(fin, fout)
+ # create conv layers
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
+ if self.learned_shortcut:
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
+
+ self.conv_0 = spectral_norm(self.conv_0)
+ self.conv_1 = spectral_norm(self.conv_1)
+ if self.learned_shortcut:
+ self.conv_s = spectral_norm(self.conv_s)
+ spade_config_str = 'batchnorm3x3'
+ self.norm_0 = SPADE(spade_config_str, fin, opt.edge_nc, opt)
+ self.norm_1 = SPADE(spade_config_str, fmiddle, opt.edge_nc, opt)
+ if self.learned_shortcut:
+ self.norm_s = SPADE(spade_config_str, fin, opt.edge_nc, opt)
+
+ def forward(self, x, seg):
+ x_s = self.shortcut(x, seg)
+ dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
+ dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))
+ out = x_s + dx
+ return out
+
+ def shortcut(self, x, seg):
+ if self.learned_shortcut:
+ x_s = self.conv_s(self.norm_s(x, seg))
+ else:
+ x_s = x
+ return x_s
+
+ def actvn(self, x):
+ return F.leaky_relu(x, 2e-1)
+
diff --git a/models/networks/base_network.py b/models/networks/base_network.py
new file mode 100644
index 0000000..108e4a8
--- /dev/null
+++ b/models/networks/base_network.py
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+from torch.nn import init
+
+class BaseNetwork(nn.Module):
+ def __init__(self):
+ super(BaseNetwork, self).__init__()
+
+ def print_network(self):
+ if isinstance(self, list):
+ self = self[0]
+ num_params = 0
+ for param in self.parameters():
+ num_params += param.numel()
+ print('Network [%s] was created. Total number of parameters: %.1f million. '
+ 'To see the architecture, do print(network).'
+ % (type(self).__name__, num_params / 1000000))
+
+ def init_weights(self, init_type='normal', gain=0.02):
+ def init_func(m):
+ classname = m.__class__.__name__
+ if classname.find('BatchNorm2d') != -1:
+ if hasattr(m, 'weight') and m.weight is not None:
+ init.normal_(m.weight.data, 1.0, gain)
+ if hasattr(m, 'bias') and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+ elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+ if init_type == 'normal':
+ init.normal_(m.weight.data, 0.0, gain)
+ elif init_type == 'xavier':
+ init.xavier_normal_(m.weight.data, gain=gain)
+ elif init_type == 'xavier_uniform':
+ init.xavier_uniform_(m.weight.data, gain=1.0)
+ elif init_type == 'kaiming':
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+ elif init_type == 'orthogonal':
+ init.orthogonal_(m.weight.data, gain=gain)
+ elif init_type == 'none': # uses pytorch's default init method
+ m.reset_parameters()
+ else:
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+ if hasattr(m, 'bias') and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+
+ self.apply(init_func)
+
+ # propagate to children
+ for m in self.children():
+ if hasattr(m, 'init_weights'):
+ m.init_weights(init_type, gain)
+
diff --git a/models/networks/bdcn.py b/models/networks/bdcn.py
new file mode 100644
index 0000000..433d32a
--- /dev/null
+++ b/models/networks/bdcn.py
@@ -0,0 +1,247 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from models.networks.vgg16_c import VGG16_C
+
+def crop(data1, data2, crop_h, crop_w):
+ _, _, h1, w1 = data1.size()
+ _, _, h2, w2 = data2.size()
+ assert(h2 <= h1 and w2 <= w1)
+ data = data1[:, :, crop_h:crop_h+h2, crop_w:crop_w+w2]
+ return data
+
+def get_upsampling_weight(in_channels, out_channels, kernel_size):
+ """Make a 2D bilinear kernel suitable for upsampling"""
+ factor = (kernel_size + 1) // 2
+ if kernel_size % 2 == 1:
+ center = factor - 1
+ else:
+ center = factor - 0.5
+ og = np.ogrid[:kernel_size, :kernel_size]
+ filt = (1 - abs(og[0] - center) / factor) * \
+ (1 - abs(og[1] - center) / factor)
+ weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
+ dtype=np.float64)
+ weight[range(in_channels), range(out_channels), :, :] = filt
+ return torch.from_numpy(weight).float()
+
+class MSBlock(nn.Module):
+ def __init__(self, c_in, rate=4):
+ super(MSBlock, self).__init__()
+ c_out = c_in
+ self.rate = rate
+
+ self.conv = nn.Conv2d(c_in, 32, 3, stride=1, padding=1)
+ self.relu = nn.ReLU(inplace=True)
+ dilation = self.rate*1 if self.rate >= 1 else 1
+ self.conv1 = nn.Conv2d(32, 32, 3, stride=1, dilation=dilation, padding=dilation)
+ self.relu1 = nn.ReLU(inplace=True)
+ dilation = self.rate*2 if self.rate >= 1 else 1
+ self.conv2 = nn.Conv2d(32, 32, 3, stride=1, dilation=dilation, padding=dilation)
+ self.relu2 = nn.ReLU(inplace=True)
+ dilation = self.rate*3 if self.rate >= 1 else 1
+ self.conv3 = nn.Conv2d(32, 32, 3, stride=1, dilation=dilation, padding=dilation)
+ self.relu3 = nn.ReLU(inplace=True)
+
+ self._initialize_weights()
+
+ def forward(self, x):
+ o = self.relu(self.conv(x))
+ o1 = self.relu1(self.conv1(o))
+ o2 = self.relu2(self.conv2(o))
+ o3 = self.relu3(self.conv3(o))
+ out = o + o1 + o2 + o3
+ return out
+
+ def _initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ m.weight.data.normal_(0, 0.01)
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+class BDCN_multi(nn.Module):
+ def __init__(self, level=1, pretrain=None, logger=None, rate=4):
+ super(BDCN_multi, self).__init__()
+ self.pretrain = pretrain
+ t = 1
+
+ self.features = VGG16_C(pretrain, logger)
+ self.msblock1_1 = MSBlock(64, rate)
+ self.msblock1_2 = MSBlock(64, rate)
+ self.conv1_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
+ self.conv1_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
+ self.score_dsn1 = nn.Conv2d(21, 1, (1, 1), stride=1)
+ self.score_dsn1_1 = nn.Conv2d(21, 1, 1, stride=1)
+ self.msblock2_1 = MSBlock(128, rate)
+ self.msblock2_2 = MSBlock(128, rate)
+ self.conv2_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
+ self.conv2_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
+ self.score_dsn2 = nn.Conv2d(21, 1, (1, 1), stride=1)
+ self.score_dsn2_1 = nn.Conv2d(21, 1, (1, 1), stride=1)
+ self.msblock3_1 = MSBlock(256, rate)
+ self.msblock3_2 = MSBlock(256, rate)
+ self.msblock3_3 = MSBlock(256, rate)
+ self.conv3_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
+ self.conv3_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
+ self.conv3_3_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
+ self.score_dsn3 = nn.Conv2d(21, 1, (1, 1), stride=1)
+ self.score_dsn3_1 = nn.Conv2d(21, 1, (1, 1), stride=1)
+ self.msblock4_1 = MSBlock(512, rate)
+ self.msblock4_2 = MSBlock(512, rate)
+ self.msblock4_3 = MSBlock(512, rate)
+ self.conv4_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
+ self.conv4_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
+ self.conv4_3_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
+ self.score_dsn4 = nn.Conv2d(21, 1, (1, 1), stride=1)
+ self.score_dsn4_1 = nn.Conv2d(21, 1, (1, 1), stride=1)
+ self.msblock5_1 = MSBlock(512, rate)
+ self.msblock5_2 = MSBlock(512, rate)
+ self.msblock5_3 = MSBlock(512, rate)
+ self.conv5_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
+ self.conv5_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
+ self.conv5_3_down = nn.Conv2d(32*t, 21, (1, 1), stride=1)
+ self.score_dsn5 = nn.Conv2d(21, 1, (1, 1), stride=1)
+ self.score_dsn5_1 = nn.Conv2d(21, 1, (1, 1), stride=1)
+ self.upsample_2 = nn.ConvTranspose2d(1, 1, 4, stride=2, bias=False)
+ self.upsample_4 = nn.ConvTranspose2d(1, 1, 8, stride=4, bias=False)
+ self.upsample_8 = nn.ConvTranspose2d(1, 1, 16, stride=8, bias=False)
+ self.upsample_8_5 = nn.ConvTranspose2d(1, 1, 16, stride=8, bias=False)
+ self.fuse = nn.Conv2d(10, 1, 1, stride=1)
+
+ self._initialize_weights(logger)
+
+ self.level = level
+
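+    # 'level' selects which side output(s) forward() returns; e.g. the encoder
+    # default level=41 returns s1 + s2 + s3 + s4.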
+ def forward(self, x):
+ features = self.features(x)
+ sum1 = self.conv1_1_down(self.msblock1_1(features[0])) + \
+ self.conv1_2_down(self.msblock1_2(features[1]))
+ s1 = self.score_dsn1(sum1)
+ if self.level == 11:
+ return s1
+ # print(s1.data.shape, s11.data.shape)
+ sum2 = self.conv2_1_down(self.msblock2_1(features[2])) + \
+ self.conv2_2_down(self.msblock2_2(features[3]))
+ s2 = self.score_dsn2(sum2)
+ s2 = self.upsample_2(s2)
+ # print(s2.data.shape, s21.data.shape)
+ s2 = crop(s2, x, 1, 1)
+ if self.level == 21:
+ return s2 + s1
+ sum3 = self.conv3_1_down(self.msblock3_1(features[4])) + \
+ self.conv3_2_down(self.msblock3_2(features[5])) + \
+ self.conv3_3_down(self.msblock3_3(features[6]))
+ s3 = self.score_dsn3(sum3)
+        s3 = self.upsample_4(s3)
+ # print(s3.data.shape)
+ s3 = crop(s3, x, 2, 2)
+ if self.level == 31:
+ return s3 + s2 + s1
+
+ sum4 = self.conv4_1_down(self.msblock4_1(features[7])) + \
+ self.conv4_2_down(self.msblock4_2(features[8])) + \
+ self.conv4_3_down(self.msblock4_3(features[9]))
+ s4 = self.score_dsn4(sum4)
+ s4 = self.upsample_8(s4)
+ # print(s4.data.shape)
+ s4 = crop(s4, x, 4, 4)
+ if self.level == 41:
+ return s4 + s3 + s2 + s1
+ sum5 = self.conv5_1_down(self.msblock5_1(features[10])) + \
+ self.conv5_2_down(self.msblock5_2(features[11])) + \
+ self.conv5_3_down(self.msblock5_3(features[12]))
+ s5 = self.score_dsn5(sum5)
+ s5 = self.upsample_8_5(s5)
+ # print(s5.data.shape)
+ s5 = crop(s5, x, 0, 0)
+ if self.level == 51:
+ return s5 + s4 + s3 + s2 + s1
+
+ s51 = self.score_dsn5_1(sum5)
+ s51 = self.upsample_8_5(s51)
+ # print(s51.data.shape)
+ s51 = crop(s51, x, 0, 0)
+ if self.level == 52:
+ return s51
+ s41 = self.score_dsn4_1(sum4)
+ s41 = self.upsample_8(s41)
+ # print(s41.data.shape)
+ s41 = crop(s41, x, 4, 4)
+ if self.level == 42:
+ return s41 + s51
+ s31 = self.score_dsn3_1(sum3)
+        s31 = self.upsample_4(s31)
+ # print(s31.data.shape)
+ s31 = crop(s31, x, 2, 2)
+ if self.level == 32:
+ return s31 + s41 + s51
+ s21 = self.score_dsn2_1(sum2)
+ s21 = self.upsample_2(s21)
+ s21 = crop(s21, x, 1, 1)
+ if self.level == 22:
+            return s21 + s31 + s41 + s51
+ s11 = self.score_dsn1_1(sum1)
+ if self.level == 12:
+ return s11 + s21 + s31 + s41 + s51
+
+
+ o1, o2, o3, o4, o5 = s1.detach(), s2.detach(), s3.detach(), s4.detach(), s5.detach()
+ o11, o21, o31, o41, o51 = s11.detach(), s21.detach(), s31.detach(), s41.detach(), s51.detach()
+ p1_1 = s1
+ p2_1 = s2 + o1
+ p3_1 = s3 + o2 + o1
+ p4_1 = s4 + o3 + o2 + o1
+ p5_1 = s5 + o4 + o3 + o2 + o1
+ p1_2 = s11 + o21 + o31 + o41 + o51
+ p2_2 = s21 + o31 + o41 + o51
+ p3_2 = s31 + o41 + o51
+ p4_2 = s41 + o51
+ p5_2 = s51
+
+ if self.level == 11:
+ return p1_1
+ elif self.level == 21:
+ return p2_1
+ elif self.level == 31:
+ return p3_1
+ elif self.level == 41:
+ return p4_1
+ elif self.level == 51:
+ return p5_1
+ elif self.level == 12:
+ return p1_2
+ elif self.level == 22:
+ return p2_2
+ elif self.level == 32:
+ return p3_2
+ elif self.level == 42:
+ return p4_2
+ elif self.level == 52:
+ return p5_2
+
+ def _initialize_weights(self, logger=None):
+ for name, param in self.state_dict().items():
+ if self.pretrain and 'features' in name:
+ continue
+ # elif 'down' in name:
+ # param.zero_()
+ elif 'upsample' in name:
+ if logger:
+                    logger.info('init upsample layer %s ' % name)
+ k = int(name.split('.')[0].split('_')[1])
+ param.copy_(get_upsampling_weight(1, 1, k*2))
+ elif 'fuse' in name:
+ if logger:
+ logger.info('init params %s ' % name)
+ if 'bias' in name:
+ param.zero_()
+ else:
+                    nn.init.constant_(param, 0.080)
+ else:
+ if logger:
+ logger.info('init params %s ' % name)
+ if 'bias' in name:
+ param.zero_()
+ else:
+ param.normal_(0, 0.01)
diff --git a/models/networks/discriminator.py b/models/networks/discriminator.py
new file mode 100644
index 0000000..63563f7
--- /dev/null
+++ b/models/networks/discriminator.py
@@ -0,0 +1,94 @@
+import sys
+import torch
+import re
+import torch.nn as nn
+from collections import OrderedDict
+import os.path
+import functools
+from torch.autograd import Variable
+import numpy as np
+import torch.nn.functional as F
+from models.networks.base_network import BaseNetwork
+from models.networks.normalization import get_norm_layer
+import util.util as util
+
+def l2norm(X):
+ norm = torch.pow(X, 2).sum(dim=1, keepdim=True).sqrt()
+ X = torch.div(X, norm)
+ return X
+
+class MultiscaleDiscriminator(BaseNetwork):
+
+ def __init__(self, opt):
+ super().__init__()
+ opt.num_D = 2
+ self.opt = opt
+
+ for i in range(opt.num_D):
+ subnetD = NLayerDiscriminator(opt)
+ self.add_module('discriminator_%d' % i, subnetD)
+
+ def downsample(self, input):
+ return F.avg_pool2d(input, kernel_size=3,
+ stride=2, padding=[1, 1],
+ count_include_pad=False)
+
+ ## Returns list of lists of discriminator outputs.
+ ## The final result is of size opt.num_D x opt.n_layers_D
+ def forward(self, input, semantics=None):
+ result = []
+ get_intermediate_features = not self.opt.no_ganFeat_loss
+ for name, D in self.named_children():
+ if semantics is None:
+ out = D(input)
+ else:
+ out = D(input, semantics)
+ if not get_intermediate_features:
+ out = [out]
+ result.append(out)
+ input = self.downsample(input)
+
+ return result
+
+# Defines the PatchGAN discriminator with the specified arguments.
+class NLayerDiscriminator(BaseNetwork):
+ def __init__(self, opt):
+ super().__init__()
+ opt.n_layers_D = 4
+ self.opt = opt
+
+ kw = 4
+ padw = int(np.ceil((kw-1.0)/2))
+ nf = opt.ndf
+ input_nc = 3
+
+ norm_layer = get_norm_layer(opt, opt.norm_D)
+ sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw),
+ nn.LeakyReLU(0.2, False)]]
+
+ for n in range(1, opt.n_layers_D):
+ nf_prev = nf
+ nf = min(nf * 2, 512)
+ stride = 1 if n == opt.n_layers_D - 1 else 2
+ sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw,
+ stride=stride, padding=padw), opt),
+ nn.LeakyReLU(0.2, False)
+ ]]
+
+ sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
+
+ ## We divide the layers into groups to extract intermediate layer outputs
+ for n in range(len(sequence)):
+ self.add_module('model'+str(n), nn.Sequential(*sequence[n]))
+
+ def forward(self, input):
+ results = [input]
+ for submodel in self.children():
+ intermediate_output = submodel(results[-1])
+ results.append(intermediate_output)
+
+ get_intermediate_features = not self.opt.no_ganFeat_loss
+ if get_intermediate_features:
+ return results[1:]
+ else:
+ return results[-1]
diff --git a/models/networks/encoder.py b/models/networks/encoder.py
new file mode 100644
index 0000000..461549c
--- /dev/null
+++ b/models/networks/encoder.py
@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+import functools
+from torch.autograd import Variable
+import numpy as np
+import torch.nn.functional as F
+import torchvision.models as models
+from models.networks.base_network import BaseNetwork
+from models.networks.bdcn import BDCN_multi
+import pdb
+
+def l2norm(x, norm_dim=1):
+ norm = torch.pow(x, 2).sum(dim=norm_dim, keepdim=True).sqrt()
+ x = torch.div(x, norm)
+ return x
+
+class ResNetbdcnEncoder(BaseNetwork):
+ def __init__(self, opt):
+ super().__init__()
+ img_dim = 2048
+ embed_size = 1024
+ self.cnn = self.get_cnn('resnet152', False)
+ self.cnn = nn.Sequential(self.cnn.conv1, self.cnn.bn1,
+ self.cnn.relu, self.cnn.maxpool,
+ self.cnn.layer1, self.cnn.layer2,
+ self.cnn.layer3, self.cnn.layer4)
+ self.conv1x1 = nn.Conv2d(img_dim, embed_size, 1)
+ # bdcn
+ self.edgenet = BDCN_multi(level=opt.edge_level)
+ edge_params = torch.load(opt.edge_model_path, map_location='cpu')
+ self.edgenet.load_state_dict(edge_params)
+ self.avg_pool = nn.AdaptiveAvgPool2d((1,1))
+ self.tanh = nn.Tanh()
+ self.edge_tanh = opt.edge_tanh
+
+ def forward(self, x0, skip=False, norm=True, pool=False):
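+        # the pretrained BDCN edge detector expects 0-255 BGR input with fixed per-channel
+        # means subtracted, so undo the [-1, 1] normalization and reorder the channels first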
+ xe = (x0 + 1) / 2.0 * 255
+ xe = xe[:,[2,1,0],:,:]
+ xe[:,0,:,:] = xe[:,0,:,:] - 123.0
+ xe[:,1,:,:] = xe[:,1,:,:] - 117.0
+ xe[:,2,:,:] = xe[:,2,:,:] - 104.0
+ edge = self.edgenet(xe)
+ x = self.cnn(x0)
+ x = self.conv1x1(x)
+ if pool:
+ x = self.avg_pool(x)
+ if norm:
+ x = l2norm(x)
+ if self.edge_tanh:
+ return x, self.tanh(edge)
+ else:
+ return x, edge
+
+ def get_cnn(self, arch, pretrained):
+ if pretrained:
+ print("=> using pre-trained model '{}'".format(arch))
+ model = models.__dict__[arch](pretrained=True)
+ else:
+ print("=> creating model '{}'".format(arch))
+ model = models.__dict__[arch]()
+ return model
+
+ def load_state_dict(self, state_dict):
+ own_state = self.state_dict()
+ for name, param in state_dict.items():
+ if name not in own_state:
+ continue
+ if isinstance(param, nn.Parameter):
+ param = param.data
+ own_state[name].copy_(param)
diff --git a/models/networks/generator.py b/models/networks/generator.py
new file mode 100644
index 0000000..4d07891
--- /dev/null
+++ b/models/networks/generator.py
@@ -0,0 +1,64 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from models.networks.base_network import BaseNetwork
+from models.networks.architecture import SPADEResnetBlock
+import random
+import numpy as np
+
+
+class OpenEditGenerator(BaseNetwork):
+ def __init__(self, opt):
+ super().__init__()
+ self.opt = opt
+ nf = opt.ngf
+
+ self.sw = 7
+ self.sh = 7
+
+ self.head_0 = SPADEResnetBlock(16*nf, 16*nf, opt)
+
+ self.G_middle_0 = SPADEResnetBlock(16*nf, 16*nf, opt)
+ self.G_middle_1 = SPADEResnetBlock(16*nf, 16*nf, opt)
+
+ self.up_0 = SPADEResnetBlock(16*nf, 8*nf, opt)
+ self.up_1 = SPADEResnetBlock(8*nf, 4*nf, opt)
+ self.up_2 = SPADEResnetBlock(4*nf, 2*nf, opt)
+ self.up_3 = SPADEResnetBlock(2*nf, 1*nf, opt)
+
+ final_nc = nf
+
+ self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)
+
+ self.up = nn.Upsample(scale_factor=2)
+
+ def forward(self, x, edge, perturbation=False, p=None):
+
+ # input: 1024 x 7 x 7
+ x = self.head_0(x, edge) # 1024 x 7 x 7
+
+ x = self.up(x) # 1024 x 14 x 14
+ x = self.G_middle_0(x, edge) # 1024 x 14 x 14
+ if perturbation:
+ x = x + p[0]
+
+ x = self.G_middle_1(x, edge) # 1024 x 14 x 14
+
+ x = self.up(x) # 1024 x 28 x 28
+ x = self.up_0(x, edge) # 512 x 28 x 28
+ if perturbation:
+ x = x + p[1]
+ x = self.up(x) # 512 x 56 x 56
+ x = self.up_1(x, edge)
+ if perturbation:
+ x = x + p[2]
+ x = self.up(x) # 256 x 112 x 112
+ x = self.up_2(x, edge)
+ x = self.up(x) # 128 x 224 x 224
+ x = self.up_3(x, edge)
+
+ x = self.conv_img(F.leaky_relu(x, 2e-1))
+        x = torch.tanh(x)
+
+ return x
+
diff --git a/models/networks/loss.py b/models/networks/loss.py
new file mode 100644
index 0000000..1fad989
--- /dev/null
+++ b/models/networks/loss.py
@@ -0,0 +1,107 @@
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+import torch.nn.functional as F
+import numpy as np
+from models.networks.architecture import VGG19
+
+class GANLoss(nn.Module):
+ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
+ tensor=torch.FloatTensor, opt=None):
+ super(GANLoss, self).__init__()
+ self.real_label = target_real_label
+ self.fake_label = target_fake_label
+ self.real_label_tensor = None
+ self.fake_label_tensor = None
+ self.zero_tensor = None
+ self.Tensor = tensor
+ self.gan_mode = gan_mode
+ self.opt = opt
+ if gan_mode == 'ls':
+ pass
+ elif gan_mode == 'original':
+ pass
+ elif gan_mode == 'w':
+ pass
+ elif gan_mode == 'hinge':
+ pass
+ else:
+ raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
+
+ def get_target_tensor(self, input, target_is_real):
+ if target_is_real:
+ if self.real_label_tensor is None:
+ self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
+ self.real_label_tensor.requires_grad_(False)
+ return self.real_label_tensor.expand_as(input)
+ else:
+ if self.fake_label_tensor is None:
+ self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
+ self.fake_label_tensor.requires_grad_(False)
+ return self.fake_label_tensor.expand_as(input)
+
+ def get_zero_tensor(self, input):
+ if self.zero_tensor is None:
+ self.zero_tensor = self.Tensor(1).fill_(0)
+ self.zero_tensor.requires_grad_(False)
+ return self.zero_tensor.expand_as(input)
+
+ def loss(self, input, target_is_real, for_discriminator=True):
+ if self.gan_mode == 'original': # cross entropy loss
+ target_tensor = self.get_target_tensor(input, target_is_real)
+ batchsize = input.size(0)
+ loss = F.binary_cross_entropy_with_logits(input, target_tensor)
+ return loss
+ elif self.gan_mode == 'ls':
+ target_tensor = self.get_target_tensor(input, target_is_real)
+ return F.mse_loss(input, target_tensor)
+ elif self.gan_mode == 'hinge':
+ if for_discriminator:
+ if target_is_real:
+ minval = torch.min(input - 1, self.get_zero_tensor(input))
+ loss = -torch.mean(minval)
+ else:
+ minval = torch.min(-input - 1, self.get_zero_tensor(input))
+ loss = -torch.mean(minval)
+ else:
+ assert target_is_real, "The generator's hinge loss must be aiming for real"
+ loss = -torch.mean(input)
+ return loss
+ else:
+ # wgan
+ if target_is_real:
+ return -input.mean()
+ else:
+ return input.mean()
+
+ def __call__(self, input, target_is_real, for_discriminator=True):
+ if isinstance(input, list):
+ loss = 0
+ for pred_i in input:
+ if isinstance(pred_i, list):
+ pred_i = pred_i[-1]
+ loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
+ bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
+ new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
+ loss += new_loss
+ return loss / len(input)
+ else:
+ return self.loss(input, target_is_real, for_discriminator)
+
+## Perceptual loss that uses a pretrained VGG network
+class VGGLoss(nn.Module):
+ def __init__(self, gpu):
+ super(VGGLoss, self).__init__()
+ if gpu is not None:
+ self.vgg = VGG19().cuda(gpu)
+ else:
+ self.vgg = VGG19().cuda()
+ self.criterion = nn.L1Loss()
+ self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
+
+ def forward(self, x, y):
+ x_vgg, y_vgg = self.vgg(x), self.vgg(y)
+ loss = 0
+ for i in range(len(x_vgg)):
+ loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
+ return loss
diff --git a/models/networks/normalization.py b/models/networks/normalization.py
new file mode 100644
index 0000000..bddae7e
--- /dev/null
+++ b/models/networks/normalization.py
@@ -0,0 +1,73 @@
+import re
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.utils.spectral_norm as spectral_norm
+
+## Returns a function that creates a normalization function
+## that does not condition on semantic map
+def get_norm_layer(opt, norm_type='instance'):
+ # helper function to get # output channels of the previous layer
+ def get_out_channel(layer):
+ if hasattr(layer, 'out_channels'):
+ return getattr(layer, 'out_channels')
+ return layer.weight.size(0)
+
+ # this function will be returned
+ def add_norm_layer(layer, opt):
+ nonlocal norm_type
+ if norm_type.startswith('spectral'):
+ layer = spectral_norm(layer)
+ subnorm_type = norm_type[len('spectral'):]
+ else:
+ subnorm_type = norm_type
+
+ if subnorm_type == 'none' or len(subnorm_type) == 0:
+ return layer
+
+ # remove bias in the previous layer, which is meaningless
+ # since it has no effect after normalization
+ if getattr(layer, 'bias', None) is not None:
+ delattr(layer, 'bias')
+ layer.register_parameter('bias', None)
+
+ if subnorm_type == 'instance':
+ norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
+ elif subnorm_type == 'sync_batch' and opt.mpdist:
+ norm_layer = nn.SyncBatchNorm(get_out_channel(layer), affine=True)
+ else:
+ norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
+
+ return nn.Sequential(layer, norm_layer)
+
+ return add_norm_layer
+
+class SPADE(nn.Module):
+ def __init__(self, config_text, norm_nc, label_nc, opt):
+ super().__init__()
+
+ if 'instance' in opt.norm_G:
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
+ elif 'sync_batch' in opt.norm_G and opt.mpdist:
+ self.param_free_norm = nn.SyncBatchNorm(norm_nc, affine=False)
+ else:
+ self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
+
+ nhidden = 128
+ ks = 3
+ pw = 1
+ self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),nn.ReLU())
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
+
+ def forward(self, x, condition):
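+        # normalize without affine parameters, then modulate with spatially-varying
+        # gamma/beta predicted from the conditioning (edge) map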
+ normalized = self.param_free_norm(x)
+
+ if x.size()[2] != condition.size()[2]:
+ condition = F.interpolate(condition, size=x.size()[2:], mode='nearest')
+ actv = self.mlp_shared(condition)
+ gamma = self.mlp_gamma(actv)
+ beta = self.mlp_beta(actv)
+
+ out = normalized * (1 + gamma) + beta
+ return out
diff --git a/models/networks/perturbation.py b/models/networks/perturbation.py
new file mode 100644
index 0000000..b90aa31
--- /dev/null
+++ b/models/networks/perturbation.py
@@ -0,0 +1,22 @@
+import torch
+import torch.nn as nn
+import functools
+import numpy as np
+import torch.nn.functional as F
+from models.networks.base_network import BaseNetwork
+
+class PerturbationNet(BaseNetwork):
+ def __init__(self, opt):
+ super().__init__()
+ self.reg_weight = opt.reg_weight
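+        # learnable additive perturbations, sized to match the intermediate feature maps
+        # where OpenEditGenerator.forward injects them (see models/networks/generator.py)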
+ self.param1 = nn.Parameter(torch.zeros(1, 1024, 14, 14))
+ self.param2 = nn.Parameter(torch.zeros(1, 512, 28, 28))
+ self.param3 = nn.Parameter(torch.zeros(1, 256, 56, 56))
+
+ def parameters(self):
+ return [self.param1, self.param2, self.param3]
+
+ def forward(self):
+ R_reg = torch.pow(self.param1, 2).sum() + torch.pow(self.param2, 2).sum() + torch.pow(self.param3, 2).sum()
+ R_reg = R_reg * self.reg_weight
+ return [self.param1, self.param2, self.param3], R_reg
diff --git a/models/networks/txt_enc.py b/models/networks/txt_enc.py
new file mode 100644
index 0000000..c14f3a8
--- /dev/null
+++ b/models/networks/txt_enc.py
@@ -0,0 +1,66 @@
+import torch
+import torch.nn as nn
+import torch.nn.init
+import torch.nn.functional as F
+import torchvision.models as models
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+import torch.backends.cudnn as cudnn
+from torch.nn.utils.clip_grad import clip_grad_norm
+import numpy as np
+from collections import OrderedDict
+
+def l2norm(X, norm_dim=1):
+ """L2-normalize columns of X
+ """
+ norm = torch.pow(X, 2).sum(dim=norm_dim, keepdim=True).sqrt()
+ X = torch.div(X, norm)
+ return X
+
+class EncoderText(nn.Module):
+
+ def __init__(self, vocab_size, word_dim, embed_size, num_layers,
+ use_abs=False, no_txt_norm=False, nonorm=False):
+ super(EncoderText, self).__init__()
+ self.no_txtnorm = no_txt_norm
+ self.no_norm_gd = nonorm
+ self.use_abs = use_abs
+ self.embed_size = embed_size
+
+ # word embedding
+ self.embed = nn.Embedding(vocab_size, word_dim)
+
+ # caption embedding
+ self.rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True)
+
+ self.init_weights()
+
+ def init_weights(self):
+ self.embed.weight.data.uniform_(-0.1, 0.1)
+
+ def forward(self, x, lengths, ret=True):
+ """Handles variable size captions
+ """
+ # Embed word ids to vectors
+ x = self.embed(x)
+ packed = pack_padded_sequence(x, lengths, batch_first=True)
+
+ # Forward propagate RNN
+ out, _ = self.rnn(packed)
+
+ # Reshape *final* output to (batch_size, hidden_size)
+ padded = pad_packed_sequence(out, batch_first=True)
+ I = torch.LongTensor(lengths).view(-1, 1, 1)
+ #I = Variable(I.expand(x.size(0), 1, self.embed_size)-1).cuda()
+ I = (I.expand(x.size(0), 1, self.embed_size)-1).cuda()
+ out = torch.gather(padded[0], 1, I).squeeze(1)
+
+ # normalization in the joint embedding space
+ if not self.no_txtnorm and (not self.no_norm_gd or ret):
+ out = l2norm(out)
+
+ # take absolute value, used by order embeddings
+ if self.use_abs:
+ out = torch.abs(out)
+
+ return out
+
diff --git a/models/networks/vgg16_c.py b/models/networks/vgg16_c.py
new file mode 100755
index 0000000..78bf861
--- /dev/null
+++ b/models/networks/vgg16_c.py
@@ -0,0 +1,112 @@
+import numpy as np
+import torch
+import torchvision
+import torch.nn as nn
+import math
+
+class VGG16_C(nn.Module):
+ """"""
+ def __init__(self, pretrain=None, logger=None):
+ super(VGG16_C, self).__init__()
+ self.conv1_1 = nn.Conv2d(3, 64, (3, 3), stride=1, padding=1)
+ self.relu1_1 = nn.ReLU(inplace=True)
+ self.conv1_2 = nn.Conv2d(64, 64, (3, 3), stride=1, padding=1)
+ self.relu1_2 = nn.ReLU(inplace=True)
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ self.conv2_1 = nn.Conv2d(64, 128, (3, 3), stride=1, padding=1)
+ self.relu2_1 = nn.ReLU(inplace=True)
+ self.conv2_2 = nn.Conv2d(128, 128, (3, 3), stride=1, padding=1)
+ self.relu2_2 = nn.ReLU(inplace=True)
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ self.conv3_1 = nn.Conv2d(128, 256, (3, 3), stride=1, padding=1)
+ self.relu3_1 = nn.ReLU(inplace=True)
+ self.conv3_2 = nn.Conv2d(256, 256, (3, 3), stride=1, padding=1)
+ self.relu3_2 = nn.ReLU(inplace=True)
+ self.conv3_3 = nn.Conv2d(256, 256, (3, 3), stride=1, padding=1)
+ self.relu3_3 = nn.ReLU(inplace=True)
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ self.conv4_1 = nn.Conv2d(256, 512, (3, 3), stride=1, padding=1)
+ self.relu4_1 = nn.ReLU(inplace=True)
+ self.conv4_2 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=1)
+ self.relu4_2 = nn.ReLU(inplace=True)
+ self.conv4_3 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=1)
+ self.relu4_3 = nn.ReLU(inplace=True)
+ self.pool4 = nn.MaxPool2d(2, stride=1, ceil_mode=True)
+ self.conv5_1 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=2, dilation=2)
+ self.relu5_1 = nn.ReLU(inplace=True)
+ self.conv5_2 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=2, dilation=2)
+ self.relu5_2 = nn.ReLU(inplace=True)
+ self.conv5_3 = nn.Conv2d(512, 512, (3, 3), stride=1, padding=2, dilation=2)
+ self.relu5_3 = nn.ReLU(inplace=True)
+ if pretrain:
+ if '.npy' in pretrain:
+ state_dict = np.load(pretrain).item()
+ for k in state_dict:
+ state_dict[k] = torch.from_numpy(state_dict[k])
+ else:
+ state_dict = torch.load(pretrain)
+ own_state_dict = self.state_dict()
+ for name, param in own_state_dict.items():
+ if name in state_dict:
+ if logger:
+ logger.info('copy the weights of %s from pretrained model' % name)
+ param.copy_(state_dict[name])
+ else:
+ if logger:
+ logger.info('init the weights of %s from mean 0, std 0.01 gaussian distribution'\
+ % name)
+ if 'bias' in name:
+ param.zero_()
+ else:
+ param.normal_(0, 0.01)
+ else:
+ self._initialize_weights(logger)
+
+ def forward(self, x):
+ conv1_1 = self.relu1_1(self.conv1_1(x))
+ conv1_2 = self.relu1_2(self.conv1_2(conv1_1))
+ pool1 = self.pool1(conv1_2)
+ conv2_1 = self.relu2_1(self.conv2_1(pool1))
+ conv2_2 = self.relu2_2(self.conv2_2(conv2_1))
+ pool2 = self.pool2(conv2_2)
+ conv3_1 = self.relu3_1(self.conv3_1(pool2))
+ conv3_2 = self.relu3_2(self.conv3_2(conv3_1))
+ conv3_3 = self.relu3_3(self.conv3_3(conv3_2))
+ pool3 = self.pool3(conv3_3)
+ conv4_1 = self.relu4_1(self.conv4_1(pool3))
+ conv4_2 = self.relu4_2(self.conv4_2(conv4_1))
+ conv4_3 = self.relu4_3(self.conv4_3(conv4_2))
+ pool4 = self.pool4(conv4_3)
+ # pool4 = conv4_3
+ conv5_1 = self.relu5_1(self.conv5_1(pool4))
+ conv5_2 = self.relu5_2(self.conv5_2(conv5_1))
+ conv5_3 = self.relu5_3(self.conv5_3(conv5_2))
+
+ side = [conv1_1, conv1_2, conv2_1, conv2_2,
+ conv3_1, conv3_2, conv3_3, conv4_1,
+ conv4_2, conv4_3, conv5_1, conv5_2, conv5_3]
+ return side
+
+ def _initialize_weights(self, logger=None):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ if logger:
+ logger.info('init the weights of %s from mean 0, std 0.01 gaussian distribution'\
+ % m)
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ m.weight.data.normal_(0, 0.01)
+ m.bias.data.zero_()
+
+if __name__ == '__main__':
+ model = VGG16_C()
+ # im = np.zeros((1,3,100,100))
+ # out = model(Variable(torch.from_numpy(im)))
+
+
diff --git a/options/__init__.py b/options/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/options/base_options.py b/options/base_options.py
new file mode 100755
index 0000000..5165877
--- /dev/null
+++ b/options/base_options.py
@@ -0,0 +1,145 @@
+import sys
+import argparse
+import os
+from util import util
+import torch
+import models
+import data
+import pickle
+import pdb
+
+class BaseOptions():
+ def __init__(self):
+ self.initialized = False
+
+ def initialize(self, parser):
+ # experiment specifics
+ parser.add_argument('--dist_url', type=str, default='tcp://127.0.0.1:10002')
+        parser.add_argument('--num_gpu', type=int, default=8, help='num of gpus for cluster training')
+ parser.add_argument('--name', type=str, default='open-edit', help='name of the experiment. It decides where to store samples and models')
+
+ parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
+ parser.add_argument('--vocab_path', type=str, default='vocab/conceptual_vocab.pkl', help='path to vocabulary')
+ parser.add_argument('--vse_enc_path', type=str, default='checkpoints/conceptual_model_best.pth.tar', help='path to the pretrained text encoder')
+ parser.add_argument('--edge_model_path', type=str, default='checkpoints/bdcn_pretrained_on_bsds500.pth', help='path to the pretrained edge extractor')
+ parser.add_argument('--model', type=str, default='OpenEdit', help='which model to use')
+ parser.add_argument('--norm_G', type=str, default='spectralsync_batch', help='instance normalization or batch normalization')
+ parser.add_argument('--norm_D', type=str, default='spectralinstance', help='instance normalization or batch normalization')
+ parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
+
+ # input/output sizes
+ parser.add_argument('--batchSize', type=int, default=8, help='input batch size')
+ parser.add_argument('--img_size', type=int, default=224, help='image size')
+ parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
+
+ # for setting inputs
+ parser.add_argument('--dataroot', type=str, default='./datasets/conceptual/')
+ parser.add_argument('--dataset_mode', type=str, default='conceptual')
+ parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
+ parser.add_argument('--nThreads', default=4, type=int, help='# threads for loading data')
+ parser.add_argument('--load_from_opt_file', action='store_true', help='load the options from checkpoints and use that as default')
+
+ # for displays
+ parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
+
+ # for generator
+ parser.add_argument('--netG', type=str, default='openedit', help='generator model')
+ parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
+ parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]')
+ parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution')
+
+ # for encoder
+ parser.add_argument('--netE', type=str, default='resnetbdcn')
+ parser.add_argument('--edge_nc', type=int, default=1)
+ parser.add_argument('--edge_level', type=int, default=41)
+ parser.add_argument('--edge_tanh', action='store_true')
+
+ # for image-specific finetuning
+ parser.add_argument('--reg_weight', type=float, default=1e-4)
+ parser.add_argument('--perturbation', action='store_true')
+ parser.add_argument('--manipulation', action='store_true')
+ parser.add_argument('--img_path', type=str)
+ parser.add_argument('--ori_cap', type=str)
+ parser.add_argument('--new_cap', type=str)
+ parser.add_argument('--global_edit', action='store_true')
+ parser.add_argument('--alpha', type=int, default=5)
+ parser.add_argument('--optimize_iter', type=int, default=50)
+
+ self.initialized = True
+ return parser
+
+ def gather_options(self):
+ # initialize parser with basic options
+ if not self.initialized:
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser = self.initialize(parser)
+
+ # get the basic options
+ opt, unknown = parser.parse_known_args()
+
+ # if there is opt_file, load it.
+ # The previous default options will be overwritten
+ if opt.load_from_opt_file:
+ parser = self.update_options_from_file(parser, opt)
+
+ opt = parser.parse_args()
+ self.parser = parser
+ return opt
+
+ def print_options(self, opt):
+ message = ''
+ message += '----------------- Options ---------------\n'
+ for k, v in sorted(vars(opt).items()):
+ comment = ''
+ default = self.parser.get_default(k)
+ if v != default:
+ comment = '\t[default: %s]' % str(default)
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
+ message += '----------------- End -------------------'
+ print(message)
+
+ def option_file_path(self, opt, makedir=False):
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
+ if makedir:
+ util.mkdirs(expr_dir)
+ file_name = os.path.join(expr_dir, 'opt')
+ return file_name
+
+ def save_options(self, opt):
+ file_name = self.option_file_path(opt, makedir=True)
+ with open(file_name + '.txt', 'wt') as opt_file:
+ for k, v in sorted(vars(opt).items()):
+ comment = ''
+ default = self.parser.get_default(k)
+ if v != default:
+ comment = '\t[default: %s]' % str(default)
+ opt_file.write('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))
+
+ with open(file_name + '.pkl', 'wb') as opt_file:
+ pickle.dump(opt, opt_file)
+
+ def update_options_from_file(self, parser, opt):
+ new_opt = self.load_options(opt)
+ for k, v in sorted(vars(opt).items()):
+ if hasattr(new_opt, k) and v != getattr(new_opt, k):
+ new_val = getattr(new_opt, k)
+ parser.set_defaults(**{k: new_val})
+ return parser
+
+ def load_options(self, opt):
+ file_name = self.option_file_path(opt, makedir=False)
+ new_opt = pickle.load(open(file_name + '.pkl', 'rb'))
+ return new_opt
+
+ def parse(self, save=False):
+
+ opt = self.gather_options()
+ opt.isTrain = self.isTrain # train or test
+
+ self.print_options(opt)
+ if opt.isTrain and save:
+ self.save_options(opt)
+
+ self.opt = opt
+ return self.opt
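+
+# Usage sketch (illustration only, not part of the original file): options are parsed once at
+# startup and shared by every module, e.g.
+#   opt = TrainOptions().parse(save=True)
+# With save=True the resolved options are written to checkpoints/<name>/opt.txt and opt.pkl,
+# and a later run can restore them as parser defaults via --load_from_opt_file.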
diff --git a/options/train_options.py b/options/train_options.py
new file mode 100755
index 0000000..b6e75f0
--- /dev/null
+++ b/options/train_options.py
@@ -0,0 +1,39 @@
+from .base_options import BaseOptions
+
+
+class TrainOptions(BaseOptions):
+ def initialize(self, parser):
+ BaseOptions.initialize(self, parser)
+ # for displays
+ parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
+ parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
+ parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
+ parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')
+
+ # for training
+ parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
+ parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
+ parser.add_argument('--niter', type=int, default=50, help='# of epochs at the starting learning rate. This is NOT the total #epochs. Total #epochs is niter + niter_decay')
+ parser.add_argument('--niter_decay', type=int, default=50, help='# of epochs over which the learning rate is linearly decayed to zero')
+ parser.add_argument('--optimizer', type=str, default='adam')
+ parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
+ parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
+ parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
+ parser.add_argument('--D_steps_per_G', type=int, default=1, help='number of discriminator iterations per generator iteration.')
+ parser.add_argument('--G_steps_per_D', type=int, default=1, help='number of generator iterations per discriminator iteration')
+
+ # for discriminators
+ parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
+ parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
+ parser.add_argument('--lambda_vgg', type=float, default=10.0, help='weight for vgg loss')
+ parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
+ parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
+ parser.add_argument('--no_l1feat_loss', action='store_true', help='if specified, do *not* use the L1 feature matching (perceptual) loss')
+ parser.add_argument('--l1pix_loss', action='store_true', help='if specified, use l1 loss on image pixels')
+ parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)')
+ parser.add_argument('--netD', type=str, default='multiscale', help='(n_layers|multiscale)')
+ parser.add_argument('--no_TTUR', action='store_true', help='if specified, do *not* use the TTUR training scheme')
+ parser.add_argument('--no_disc', action='store_true', help='if specified, train without a discriminator')
+
+ self.isTrain = True
+ return parser
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..ab19c50
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,3 @@
+dominate>=2.3.1
+dill
+scikit-image
+nltk
+scipy
diff --git a/test.py b/test.py
new file mode 100755
index 0000000..ca27468
--- /dev/null
+++ b/test.py
@@ -0,0 +1,108 @@
+import os
+import pickle
+from collections import OrderedDict
+
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+
+from options.train_options import TrainOptions
+from trainers.OpenEdit_optimizer import OpenEditOptimizer
+from util.visualizer import Visualizer
+from util import html
+from util.vocab import Vocabulary  # needed so the pickled vocabulary can be unpickled
+
+train_options = TrainOptions()
+opt = train_options.parse()
+opt.gpu = 0
+
+ori_cap = opt.ori_cap.split()
+new_cap = opt.new_cap.split()
+
+# editing hyper-parameters come from the command-line options (see options/base_options.py)
+global_edit = opt.global_edit
+alpha = opt.alpha
+optimize_iter = opt.optimize_iter
+
+opt.world_size = 1
+opt.rank = 0
+opt.mpdist = False
+opt.num_gpu = 1
+opt.batchSize = 1
+opt.manipulation = True
+opt.perturbation = True
+
+open_edit_optimizer = OpenEditOptimizer(opt)
+open_edit_optimizer.open_edit_model.netG.eval()
+
+# visualizer
+visualizer = Visualizer(opt, rank=0)
+
+# create a webpage that summarizes all the results
+web_dir = os.path.join('visual_results', opt.name,
+ '%s_%s' % (opt.phase, opt.which_epoch))
+webpage = html.HTML(web_dir,
+ 'Experiment = %s, Phase = %s, Epoch = %s' %
+ (opt.name, opt.phase, opt.which_epoch))
+
+# image loader
+transforms_list = []
+transforms_list.append(transforms.Resize((opt.img_size, opt.img_size)))
+transforms_list += [transforms.ToTensor()]
+transforms_list += [transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]
+transform = transforms.Compose(transforms_list)
+image = Image.open(opt.img_path).convert('RGB')
+image = transform(image)
+image = image.unsqueeze(0).cuda()
+
+# text loader: wrap each caption with the <start>/<end> tokens before encoding
+vocab = pickle.load(open('vocab/' + opt.dataset_mode + '_vocab.pkl', 'rb'))
+ori_txt = []
+ori_txt.append(vocab('<start>'))
+for word in ori_cap:
+ ori_txt.append(vocab(word))
+ori_txt.append(vocab('<end>'))
+ori_txt = torch.LongTensor(ori_txt).unsqueeze(0).cuda()
+new_txt = []
+new_txt.append(vocab('<start>'))
+for word in new_cap:
+ new_txt.append(vocab(word))
+new_txt.append(vocab('<end>'))
+new_txt = torch.LongTensor(new_txt).unsqueeze(0).cuda()
+
+data = {'image': image, 'caption': new_txt, 'length': [new_txt.size(1)]}
+
+# save input image
+visuals = OrderedDict([('input_image', image[0])])
+
+# reconstruct original image
+reconstructed = open_edit_optimizer.open_edit_model(data, mode='inference')[0]
+visuals['reconstructed'] = reconstructed
+
+# manipulate without optimizing perturbations
+manipulated_ori = open_edit_optimizer.open_edit_model(data, mode='manipulate', ori_cap=ori_txt, new_cap=new_txt, alpha=alpha)
+visuals['manipulated_ori'] = manipulated_ori[0][0]
+
+# optimize perturbations
+for iter_cnt in range(optimize_iter):
+ open_edit_optimizer.run_opt_one_step(data, ori_txt, new_txt, alpha, global_edit=global_edit)
+ message = '(optimization, iters: %d) ' % iter_cnt
+ errors = open_edit_optimizer.get_latest_losses()
+ for k, v in errors.items():
+ v = v.mean().float()
+ message += '%s: %.3f ' % (k, v)
+ print(message)
+
+# manipulation results after optimizing perturbations
+visuals['optimized'] = open_edit_optimizer.get_latest_generated()[0]
+
+
+visualizer.save_images(webpage, visuals, [opt.img_path], gray=True)
+webpage.save()
diff --git a/test.sh b/test.sh
new file mode 100644
index 0000000..bf01a62
--- /dev/null
+++ b/test.sh
@@ -0,0 +1,8 @@
+python test.py \
+ --name OpenEdit \
+ --img_path test_imgs/car_blue.jpg \
+ --ori_cap 'red car' \
+ --new_cap 'blue car' \
+ --edge_level 41 \
+ --lr 0.002 \
+ --which_epoch latest
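+# Outputs: test.py writes the input image, the reconstruction, and the manipulation results
+# as an HTML page under visual_results/<name>/<phase>_<which_epoch>/ (see test.py).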
diff --git a/train.py b/train.py
new file mode 100755
index 0000000..15c3610
--- /dev/null
+++ b/train.py
@@ -0,0 +1,87 @@
+import sys
+from collections import OrderedDict
+import data
+import torch.multiprocessing as mp
+import torch.distributed as dist
+import torch
+from util.iter_counter import IterationCounter
+from util.visualizer import Visualizer
+from trainers.OpenEdit_trainer import OpenEditTrainer
+from options.train_options import TrainOptions
+
+def main_worker(gpu, world_size, opt):
+ print('Use GPU: {} for training'.format(gpu))
+ world_size = opt.world_size
+ rank = gpu
+ opt.gpu = gpu
+ dist.init_process_group(backend='nccl', init_method=opt.dist_url, world_size=world_size, rank=rank)
+ torch.cuda.set_device(gpu)
+
+ # load the dataset
+ dataloader = data.create_dataloader(opt, world_size, rank)
+
+ # create trainer for our model
+ trainer = OpenEditTrainer(opt)
+
+ # create tool for counting iterations
+ iter_counter = IterationCounter(opt, len(dataloader), world_size, rank)
+
+ # create tool for visualization
+ visualizer = Visualizer(opt, rank)
+
+ for epoch in iter_counter.training_epochs():
+ if opt.mpdist:
+ dataloader.sampler.set_epoch(epoch)
+ iter_counter.record_epoch_start(epoch)
+ for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
+ iter_counter.record_one_iteration()
+
+ # Training
+ # train generator
+ if i % opt.D_steps_per_G == 0:
+ trainer.run_generator_one_step(data_i)
+
+ # train discriminator
+ if not opt.no_disc and i % opt.G_steps_per_D == 0:
+ trainer.run_discriminator_one_step(data_i)
+
+ iter_counter.record_iteration_end()
+
+ # Visualizations
+ if iter_counter.needs_printing():
+ losses = trainer.get_latest_losses()
+ visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
+ losses, iter_counter.time_per_iter,
+ iter_counter.model_time_per_iter)
+ visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far)
+
+
+ visuals = OrderedDict([('synthesized_image', trainer.get_latest_generated()),
+ ('real_image', data_i['image'])])
+ visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far)
+
+ if rank == 0:
+ print('saving the latest model (epoch %d, total_steps %d)' %
+ (epoch, iter_counter.total_steps_so_far))
+ trainer.save('latest')
+ iter_counter.record_current_iter()
+
+ trainer.update_learning_rate(epoch)
+ iter_counter.record_epoch_end()
+
+ if (epoch % opt.save_epoch_freq == 0 or epoch == iter_counter.total_epochs) and (rank == 0):
+ print('saving the model at the end of epoch %d, iters %d' %
+ (epoch, iter_counter.total_steps_so_far))
+ trainer.save(epoch)
+
+ print('Training was successfully finished.')
+
+if __name__ == '__main__':
+ train_options = TrainOptions()
+ opt = train_options.parse(save=True)
+ opt.world_size = opt.num_gpu
+ opt.mpdist = True
+
+ mp.set_start_method('spawn', force=True)
+ mp.spawn(main_worker, nprocs=opt.world_size, args=(opt.world_size, opt))
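+
+# Note: training always goes through torch.distributed here (opt.mpdist is forced to True),
+# so even a single-GPU run spawns one worker and initializes the NCCL process group via
+# --dist_url; the number of workers comes from opt.num_gpu.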
diff --git a/train.sh b/train.sh
new file mode 100644
index 0000000..eb1640d
--- /dev/null
+++ b/train.sh
@@ -0,0 +1,5 @@
+python train.py \
+ --name OpenEdit \
+ --edge_level 41 \
+ --nThreads 4 \
+ --batchSize 4
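+# train.py launches one worker per GPU through torch.multiprocessing; checkpoints, loss logs,
+# and intermediate visual results are written to checkpoints/<name>/ (see train.py).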
diff --git a/trainers/OpenEdit_optimizer.py b/trainers/OpenEdit_optimizer.py
new file mode 100644
index 0000000..c9c6b18
--- /dev/null
+++ b/trainers/OpenEdit_optimizer.py
@@ -0,0 +1,35 @@
+from models.OpenEdit_model import OpenEditModel
+import torch
+import torch.nn as nn
+
+class OpenEditOptimizer():
+ """
+ Optimizes the perturbation parameters at test time (image-specific finetuning)
+ """
+
+ def __init__(self, opt):
+ self.opt = opt
+ self.open_edit_model = OpenEditModel(opt)
+ self.open_edit_model_on_one_gpu = self.open_edit_model
+
+ self.generated = None
+
+ self.optimizer_P = self.open_edit_model_on_one_gpu.create_P_optimizers(opt)
+ self.old_lr = opt.lr
+
+ def run_opt_one_step(self, data, ori_cap, new_cap, alpha, global_edit=False):
+ self.optimizer_P.zero_grad()
+ r_losses, generated = self.open_edit_model(
+ data, mode='optimize', ori_cap=ori_cap, new_cap=new_cap,
+ alpha=alpha, global_edit=global_edit)
+ r_loss = sum(r_losses.values()).mean()
+ r_loss.backward()
+ self.optimizer_P.step()
+ self.r_losses = r_losses
+ self.generated = generated
+
+ def get_latest_losses(self):
+ return {**self.r_losses}
+
+ def get_latest_generated(self):
+ return self.generated
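+
+# Usage sketch (mirrors test.py): at test time the perturbation parameters are refined by
+# repeatedly calling run_opt_one_step(data, ori_txt, new_txt, alpha), monitoring
+# get_latest_losses(), and finally reading the edited image from get_latest_generated().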
diff --git a/trainers/OpenEdit_trainer.py b/trainers/OpenEdit_trainer.py
new file mode 100644
index 0000000..e920958
--- /dev/null
+++ b/trainers/OpenEdit_trainer.py
@@ -0,0 +1,42 @@
+from models.OpenEdit_model import OpenEditModel
+import torch
+
+class OpenEditTrainer():
+ """
+ Trainer creates the model and optimizers, and uses them to
+ update the weights of the network while reporting losses
+ and the latest visuals to visualize the progress in training.
+ """
+
+ def __init__(self, opt):
+ self.opt = opt
+ self.open_edit_model = OpenEditModel(opt)
+ if opt.mpdist:
+ self.open_edit_model = torch.nn.parallel.DistributedDataParallel(self.open_edit_model, device_ids=[opt.gpu], find_unused_parameters=True)
+ self.open_edit_model_on_one_gpu = self.open_edit_model.module
+ else:
+ self.open_edit_model_on_one_gpu = self.open_edit_model
+
+ self.generated = None
+ self.loss_gan = not opt.no_disc
+ if opt.isTrain:
+ self.optimizer_G, self.optimizer_D = \
+ self.open_edit_model_on_one_gpu.create_optimizers(opt)
+ self.old_lr = opt.lr
+
+ def run_generator_one_step(self, data):
+ self.optimizer_G.zero_grad()
+ g_losses, generated = self.open_edit_model(data, mode='generator')
+ g_loss = sum(g_losses.values()).mean()
+ g_loss.backward()
+ self.optimizer_G.step()
+ self.g_losses = g_losses
+ self.generated = generated
+
+ def run_discriminator_one_step(self, data):
+ self.optimizer_D.zero_grad()
+ d_losses = self.open_edit_model(data, mode='discriminator')
+ d_loss = sum(d_losses.values()).mean()
+ d_loss.backward()
+ self.optimizer_D.step()
+ self.d_losses = d_losses
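+
+# In train.py the trainer alternates run_generator_one_step and run_discriminator_one_step,
+# gated by --D_steps_per_G / --G_steps_per_D; the discriminator step is skipped entirely
+# when --no_disc is set.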
diff --git a/trainers/__init__.py b/trainers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/util/__init__.py b/util/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/util/html.py b/util/html.py
new file mode 100755
index 0000000..79a64aa
--- /dev/null
+++ b/util/html.py
@@ -0,0 +1,72 @@
+import datetime
+import dominate
+from dominate.tags import *
+import os
+
+
+class HTML:
+ def __init__(self, web_dir, title, refresh=0):
+ if web_dir.endswith('.html'):
+ web_dir, html_name = os.path.split(web_dir)
+ else:
+ web_dir, html_name = web_dir, 'index.html'
+ self.title = title
+ self.web_dir = web_dir
+ self.html_name = html_name
+ self.img_dir = os.path.join(self.web_dir, 'images')
+ if len(self.web_dir) > 0 and not os.path.exists(self.web_dir):
+ os.makedirs(self.web_dir)
+ if len(self.web_dir) > 0 and not os.path.exists(self.img_dir):
+ os.makedirs(self.img_dir)
+
+ self.doc = dominate.document(title=title)
+ with self.doc:
+ h1(datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y"))
+ if refresh > 0:
+ with self.doc.head:
+ meta(http_equiv="refresh", content=str(refresh))
+
+ def get_image_dir(self):
+ return self.img_dir
+
+ def add_header(self, text):
+ with self.doc:
+ h3(text)
+
+ def add_table(self, border=1):
+ self.t = table(border=border, style="table-layout: fixed;")
+ self.doc.add(self.t)
+
+ def add_images(self, ims, txts, links, width=512):
+ self.add_table()
+ with self.t:
+ with tr():
+ for im, txt, link in zip(ims, txts, links):
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
+ with p():
+ with a(href=os.path.join('images', link)):
+ img(style="width:%dpx" % (width), src=os.path.join('images', im))
+ br()
+ p(txt)
+
+ def save(self):
+ html_file = os.path.join(self.web_dir, self.html_name)
+ with open(html_file, 'wt') as f:
+ f.write(self.doc.render())
+
+
+if __name__ == '__main__':
+ html = HTML('web/', 'test_html')
+ html.add_header('hello world')
+
+ ims = []
+ txts = []
+ links = []
+ for n in range(4):
+ ims.append('image_%d.jpg' % n)
+ txts.append('text_%d' % n)
+ links.append('image_%d.jpg' % n)
+ html.add_images(ims, txts, links)
+ html.save()
diff --git a/util/iter_counter.py b/util/iter_counter.py
new file mode 100644
index 0000000..afdf34b
--- /dev/null
+++ b/util/iter_counter.py
@@ -0,0 +1,71 @@
+import os
+import time
+import numpy as np
+
+
+## Helper class that keeps track of training iterations
+class IterationCounter():
+ def __init__(self, opt, dataset_size, world_size=1, rank=0):
+ self.opt = opt
+ self.dataset_size = dataset_size
+
+ self.first_epoch = 1
+ self.total_epochs = opt.niter + opt.niter_decay
+ self.epoch_iter = 0 # iter number within each epoch
+ self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, 'iter.txt')
+ if opt.isTrain and opt.continue_train:
+ try:
+ self.first_epoch, self.epoch_iter = np.loadtxt(
+ self.iter_record_path, delimiter=',', dtype=int)
+ print('Resuming from epoch %d at iteration %d' % (self.first_epoch, self.epoch_iter))
+ except Exception:
+ print('Could not load iteration record at %s. Starting from beginning.' %
+ self.iter_record_path)
+
+ self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter
+
+ self.world_size = world_size
+ self.rank = rank
+
+ # return the iterator of epochs for the training
+ def training_epochs(self):
+ return range(self.first_epoch, self.total_epochs + 1)
+
+ def record_epoch_start(self, epoch):
+ self.epoch_start_time = time.time()
+ self.epoch_iter = 0
+ self.last_iter_time = time.time()
+ self.current_epoch = epoch
+
+ def record_one_iteration(self):
+ current_time = time.time()
+
+ self.time_per_iter = (current_time - self.last_iter_time) / (self.opt.batchSize * self.world_size)
+ self.last_iter_time = current_time
+ self.total_steps_so_far += self.opt.batchSize * self.world_size
+ self.epoch_iter += self.opt.batchSize * self.world_size
+
+ def record_iteration_end(self):
+ current_time = time.time()
+
+ self.model_time_per_iter = (current_time - self.last_iter_time) / (self.opt.batchSize * self.world_size)
+
+ def record_epoch_end(self):
+ current_time = time.time()
+ self.time_per_epoch = current_time - self.epoch_start_time
+ if self.rank == 0:
+ print('End of epoch %d / %d \t Time Taken: %d sec' %
+ (self.current_epoch, self.total_epochs, self.time_per_epoch))
+ if self.current_epoch % self.opt.save_epoch_freq == 0:
+ np.savetxt(self.iter_record_path, (self.current_epoch+1, 0),
+ delimiter=',', fmt='%d')
+ print('Saved current iteration count at %s.' % self.iter_record_path)
+
+ def record_current_iter(self):
+ if self.rank == 0:
+ np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter),
+ delimiter=',', fmt='%d')
+ print('Saved current iteration count at %s.' % self.iter_record_path)
+
+ def needs_printing(self):
+ return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize * self.world_size
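+
+# Note: epoch_iter and total_steps_so_far advance by batchSize * world_size per iteration,
+# i.e. they count training images rather than batches, so --print_freq and the resume record
+# in iter.txt are expressed in images.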
diff --git a/util/util.py b/util/util.py
new file mode 100755
index 0000000..7d09177
--- /dev/null
+++ b/util/util.py
@@ -0,0 +1,100 @@
+import re
+import importlib
+import torch
+from argparse import Namespace
+import numpy as np
+from PIL import Image
+import os
+import argparse
+
+
+# Converts a Tensor into a Numpy array
+# |imtype|: the desired type of the converted numpy array
+def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False):
+ if isinstance(image_tensor, list):
+ image_numpy = []
+ for i in range(len(image_tensor)):
+ image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
+ return image_numpy
+
+ if image_tensor.dim() == 4:
+ # transform each image in the batch
+ images_np = []
+ for b in range(image_tensor.size(0)):
+ one_image = image_tensor[b]
+ one_image_np = tensor2im(one_image)
+ images_np.append(one_image_np.reshape(1, *one_image_np.shape))
+ images_np = np.concatenate(images_np, axis=0)
+ if tile:
+ images_tiled = tile_images(images_np)
+ return images_tiled
+ else:
+ return images_np
+
+ if image_tensor.dim() == 2:
+ image_tensor = image_tensor.unsqueeze(0)
+ image_numpy = image_tensor.detach().cpu().float().numpy()
+ if normalize:
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
+ else:
+ image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
+ image_numpy = np.clip(image_numpy, 0, 255)
+ if image_numpy.shape[2] == 1:
+ image_numpy = image_numpy[:,:,0]
+ return image_numpy.astype(imtype)
+
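+# tensor2im calls tile_images when tile=True, but that helper is not included in this diff;
+# a minimal sketch (assumes a fixed grid of 4 images per row, zero-padding the last row):
+def tile_images(imgs, picturesPerRow=4):
+    # pad the batch so it splits evenly into rows
+    if imgs.shape[0] % picturesPerRow != 0:
+        rowPadding = picturesPerRow - imgs.shape[0] % picturesPerRow
+        imgs = np.concatenate([imgs, np.zeros((rowPadding, *imgs.shape[1:]), dtype=imgs.dtype)], axis=0)
+    # concatenate each row of images horizontally, then stack the rows vertically
+    rows = [np.concatenate(list(imgs[i:i + picturesPerRow]), axis=1)
+            for i in range(0, imgs.shape[0], picturesPerRow)]
+    return np.concatenate(rows, axis=0)
+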
+def save_image(image_numpy, image_path, create_dir=False, gray=False):
+ if create_dir:
+ os.makedirs(os.path.dirname(image_path), exist_ok=True)
+ ## save to png
+ if gray:
+ if len(image_numpy.shape) == 3:
+ assert(image_numpy.shape[2] == 1)
+ image_numpy = image_numpy.squeeze(2)
+ image_pil = Image.fromarray(image_numpy)
+ image_pil.save(image_path.replace('.jpg', '.png'))
+ else:
+ if len(image_numpy.shape) == 2 and not gray:
+ image_numpy = np.expand_dims(image_numpy, axis=2)
+ if image_numpy.shape[2] == 1 and not gray:
+ image_numpy = np.repeat(image_numpy, 3, 2)
+ image_pil = Image.fromarray(image_numpy)
+ image_pil.save(image_path.replace('.jpg', '.png'))
+
+def mkdirs(paths):
+ if isinstance(paths, list) and not isinstance(paths, str):
+ for path in paths:
+ mkdir(path)
+ else:
+ mkdir(paths)
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+def find_class_in_module(target_cls_name, module):
+ target_cls_name = target_cls_name.replace('_', '').lower()
+ clslib = importlib.import_module(module)
+ cls = None
+ for name, clsobj in clslib.__dict__.items():
+ if name.lower() == target_cls_name:
+ cls = clsobj
+
+ if cls is None:
+ print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name))
+ exit(0)
+
+ return cls
+
+def save_network(net, label, epoch, opt):
+ save_filename = '%s_net_%s.pth' % (epoch, label)
+ save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)
+ torch.save(net.state_dict(), save_path) # net.cpu() -> net
+
+def load_network(net, label, epoch, opt):
+ save_filename = '%s_net_%s.pth' % (epoch, label)
+ save_dir = os.path.join(opt.checkpoints_dir, opt.name)
+ save_path = os.path.join(save_dir, save_filename)
+ weights = torch.load(save_path, map_location=torch.device('cpu'))
+ net.load_state_dict(weights)
+ return net
diff --git a/util/visualizer.py b/util/visualizer.py
new file mode 100755
index 0000000..049a3d3
--- /dev/null
+++ b/util/visualizer.py
@@ -0,0 +1,176 @@
+import os
+import ntpath
+import time
+from . import util
+from . import html
+import scipy.misc
+try:
+ from StringIO import StringIO # Python 2.7
+except ImportError:
+ from io import BytesIO # Python 3.x
+
+class Visualizer():
+ def __init__(self, opt, rank=0):
+ self.rank = rank
+ self.opt = opt
+ self.tf_log = opt.isTrain and opt.tf_log
+ self.use_html = opt.isTrain and not opt.no_html
+ self.win_size = opt.display_winsize
+ self.name = opt.name
+ if self.tf_log:
+ import tensorflow as tf
+ self.tf = tf
+ self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs')
+ if self.rank == 0:
+ self.writer = tf.summary.FileWriter(self.log_dir)
+
+ if self.use_html:
+ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
+ self.img_dir = os.path.join(self.web_dir, 'images')
+ if self.rank == 0:
+ print('create web directory %s...' % self.web_dir)
+ util.mkdirs([self.web_dir, self.img_dir])
+ if opt.isTrain:
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+ if self.rank == 0:
+ with open(self.log_name, "a") as log_file:
+ now = time.strftime("%c")
+ log_file.write('================ Training Loss (%s) ================\n' % now)
+
+ # |visuals|: dictionary of images to display or save
+ def display_current_results(self, visuals, epoch, step):
+
+ ## convert tensors to numpy arrays
+ visuals = self.convert_visuals_to_numpy(visuals)
+
+ if 0: # do not show images in tensorboard output
+ img_summaries = []
+ for label, image_numpy in visuals.items():
+ # Write the image to a string
+ try:
+ s = StringIO()
+ except:
+ s = BytesIO()
+ if len(image_numpy.shape) >= 4:
+ image_numpy = image_numpy[0]
+ if self.rank == 0:
+ scipy.misc.toimage(image_numpy).save(s, format="jpeg")
+ # Create an Image object
+ img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1])
+ # Create a Summary value
+ img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum))
+
+ if self.rank == 0:
+ # Create and write Summary
+ summary = self.tf.Summary(value=img_summaries)
+ self.writer.add_summary(summary, step)
+
+ if self.use_html: # save images to a html file
+ for label, image_numpy in visuals.items():
+ if isinstance(image_numpy, list):
+ for i in range(len(image_numpy)):
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.png' % (epoch, label, i))
+ if self.rank == 0:
+ util.save_image(image_numpy[i], img_path)
+ else:
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
+ if len(image_numpy.shape) >= 4:
+ image_numpy = image_numpy[0]
+ if self.rank == 0:
+ util.save_image(image_numpy, img_path)
+
+ if self.rank == 0:
+ # update website
+ webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=5)
+ for n in range(epoch, 0, -1):
+ webpage.add_header('epoch [%d]' % n)
+ ims = []
+ txts = []
+ links = []
+
+ for label, image_numpy in visuals.items():
+ if isinstance(image_numpy, list):
+ for i in range(len(image_numpy)):
+ img_path = 'epoch%.3d_%s_%d.png' % (n, label, i)
+ ims.append(img_path)
+ txts.append(label+str(i))
+ links.append(img_path)
+ else:
+ img_path = 'epoch%.3d_%s.png' % (n, label)
+ ims.append(img_path)
+ txts.append(label)
+ links.append(img_path)
+ if len(ims) < 10:
+ webpage.add_images(ims, txts, links, width=self.win_size)
+ else:
+ num = int(round(len(ims)/2.0))
+ webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size)
+ webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size)
+ webpage.save()
+
+ # errors: dictionary of error labels and values
+ def plot_current_errors(self, errors, step):
+ if self.tf_log:
+ for tag, value in errors.items():
+ value = value.mean().float()
+ if self.rank == 0:
+ summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
+ self.writer.add_summary(summary, step)
+
+ # errors: same format as |errors| of plotCurrentErrors
+ def print_current_errors(self, epoch, i, errors, t, mt):
+ message = '(epoch: %d, iters: %d, time: %.3f, model_time: %.3f) ' % (epoch, i, t, mt)
+ for k, v in errors.items():
+ #print(v)
+ #if v != 0:
+ v = v.mean().float()
+ message += '%s: %.3f ' % (k, v)
+
+ if self.rank == 0:
+ print(message)
+ with open(self.log_name, "a") as log_file:
+ log_file.write('%s\n' % message)
+
+ def convert_visuals_to_numpy(self, visuals, gray=False):
+ for key, t in visuals.items():
+ tile = self.opt.batchSize > 1
+ if 'input_label' == key and not gray:
+ t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile)
+ elif 'input_label' == key and gray:
+ t = util.tensor2labelgray(t, self.opt.label_nc + 2, tile=tile)
+ else:
+ t = util.tensor2im(t, tile=tile)
+ visuals[key] = t
+ return visuals
+
+ # save image to the disk
+ def save_images(self, webpage, visuals, image_path, gray=False):
+ visuals = self.convert_visuals_to_numpy(visuals, gray=gray)
+
+ image_dir = webpage.get_image_dir()
+ short_path = ntpath.basename(image_path[0])
+ name = os.path.splitext(short_path)[0]
+
+ if self.rank == 0:
+ webpage.add_header(name)
+ ims = []
+ txts = []
+ links = []
+
+ cnt = 0
+ for label, image_numpy in visuals.items():
+ image_name = os.path.join(label, '%s.png' % (name))
+ save_path = os.path.join(image_dir, image_name)
+ if self.rank == 0:
+ util.save_image(image_numpy, save_path, create_dir=True)
+
+ ims.append(image_name)
+ txts.append(label)
+ links.append(image_name)
+ cnt += 1
+ if cnt % 4 == 0:
+ if self.rank == 0:
+ webpage.add_images(ims, txts, links, width=self.win_size)
+ ims = []
+ txts = []
+ links = []
+ # after the loop, flush any images left over from an incomplete row of 4
+ if self.rank == 0 and ims:
+ webpage.add_images(ims, txts, links, width=self.win_size)
diff --git a/util/vocab.py b/util/vocab.py
new file mode 100644
index 0000000..4ad2412
--- /dev/null
+++ b/util/vocab.py
@@ -0,0 +1,78 @@
+# Create a vocabulary wrapper
+import nltk
+import pickle
+from collections import Counter
+import json
+import argparse
+import os
+
+annotations = {
+ 'conceptual': ['train_caption.json']
+}
+
+
+class Vocabulary(object):
+ """Simple vocabulary wrapper."""
+
+ def __init__(self):
+ self.word2idx = {}
+ self.idx2word = {}
+ self.idx = 0
+
+ def add_word(self, word):
+ if word not in self.word2idx:
+ self.word2idx[word] = self.idx
+ self.idx2word[self.idx] = word
+ self.idx += 1
+
+ def __call__(self, word):
+ if word not in self.word2idx:
+ return self.word2idx['<unk>']
+ return self.word2idx[word]
+
+ def __len__(self):
+ return len(self.word2idx)
+
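+# Usage sketch (matches the text loader in test.py): a caption is encoded as
+#   ids = [vocab('<start>')] + [vocab(w) for w in caption.split()] + [vocab('<end>')]
+# and any word missing from the vocabulary falls back to the '<unk>' index.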
+
+def build_vocab(data_path, data_name, jsons, threshold):
+ """Build a simple vocabulary wrapper."""
+ counter = Counter()
+ for path in jsons[data_name]:
+ full_path = os.path.join(data_path, data_name, 'img_lists', path)
+ captions = json.load(open(full_path,'r'))
+ for i, caption in enumerate(captions):
+ tokens = nltk.tokenize.word_tokenize(caption.lower())
+ counter.update(tokens)
+
+ if i % 1000 == 0:
+ print("[%d/%d] tokenized the captions." % (i, len(captions)))
+
+ # Discard words whose occurrence count is below the threshold.
+ words = [word for word, cnt in counter.items() if cnt >= threshold]
+
+ # Create a vocab wrapper and add some special tokens.
+ vocab = Vocabulary()
+ vocab.add_word('<pad>')
+ vocab.add_word('<start>')
+ vocab.add_word('<end>')
+ vocab.add_word('<unk>')
+
+ # Add words to the vocabulary.
+ for i, word in enumerate(words):
+ vocab.add_word(word)
+ return vocab
+
+
+def main(data_path, data_name):
+ vocab = build_vocab(data_path, data_name, jsons=annotations, threshold=0)
+ with open('./vocab/%s_vocab.pkl' % data_name, 'wb') as f:
+ pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL)
+ print("Saved vocabulary file to ", './vocab/%s_vocab.pkl' % data_name)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data_path', default='data')
+ parser.add_argument('--data_name', default='conceptual')
+ opt = parser.parse_args()
+ main(opt.data_path, opt.data_name)