Skip to content

Commit

Permalink
add code
Browse files Browse the repository at this point in the history
  • Loading branch information
xh-liu committed Sep 26, 2020
1 parent 951eb27 commit 2400c86
Show file tree
Hide file tree
Showing 35 changed files with 2,461 additions and 3 deletions.
13 changes: 13 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
checkpoints/
datasets/
results/
*.tar.gz
*.pth
*.zip
*.pkl
*.pyc
*/__pycache__/
*_example/
visual_results/
test_imgs/
web/
43 changes: 40 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,51 @@
# Open-Edit: Open-Domain Image Manipulation with Open-Vocabulary Instructions

[Xihui Liu](https://xh-liu.github.io), Zhe Lin, Jianming Zhang, Handong Zhao, Quan Tran, Xiaogang Wang, and Hongsheng Li.<br>
[Xihui Liu](https://xh-liu.github.io), [Zhe Lin](https://sites.google.com/site/zhelin625/), [Jianming Zhang](http://cs-people.bu.edu/jmzhang/), [Handong Zhao](https://hdzhao.github.io/), [Quan Tran](https://research.adobe.com/person/quan-hung-tran/), [Xiaogang Wang](https://www.ee.cuhk.edu.hk/~xgwang/), and [Hongsheng Li](https://www.ee.cuhk.edu.hk/~hsli/).<br>
Published in ECCV 2020.

### [Paper](https://arxiv.org/pdf/2008.01576.pdf) | [1-minute video](https://youtu.be/8E3bwvjCHYE)
### [Paper](https://arxiv.org/pdf/2008.01576.pdf) | [1-minute video](https://youtu.be/8E3bwvjCHYE) | [Slides](https://drive.google.com/file/d/1m3JKSUotm6sRImak_qjwBMtMtd037XeK/view?usp=sharing)

![results](results.jpg)

### Code Coming Soon!
### Installation

Clone this repo.
```bash
git clone https://github.com/xh-liu/Open-Edit
cd Open-Edit
```

Install [PyTorch 1.1+](https://pytorch.org/get-started/locally/) and other requirements.
```bash

pip install -r requirements.txt
```

### Download pretrained models

Download pretrained models from [Google Drive](https://drive.google.com/drive/folders/1iG_II7_PytTY6NdzyZ5WDkzPTXcB2NcE?usp=sharing)

### Data preparation

We use [Conceptual Captions dataset](https://ai.google.com/research/ConceptualCaptions/download) for training. Download the dataset and put it under the dataset folder. You can also use other datasets

### Training

The visual-semantic embedding model is trained with [VSE++](https://github.com/fartashf/vsepp).

The image decoder is trained with:

```bash
bash train.sh
```

## Testing

You can specify the image path and text instructions in test.sh.

```bash
bash test.sh
```

### Citation
If you use this code for your research, please cite our papers.
Expand Down
60 changes: 60 additions & 0 deletions data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import importlib
import torch.utils.data

def collate_fn_img(images):
images = torch.stack(images, 0)
input_dicts = {'image': images}
return input_dicts


def find_dataset_using_name(dataset_name):
# Given the option --dataset [datasetname],
# the file "datasets/datasetname_dataset.py"
# will be imported.
dataset_filename = "data." + dataset_name + "_dataset"
datasetlib = importlib.import_module(dataset_filename)

# In the file, the class called DatasetNameDataset() will
# be instantiated. It has to be a subclass of BaseDataset,
# and it is case-insensitive.
dataset = None
target_dataset_name = dataset_name.replace('_', '') + 'dataset'
for name, cls in datasetlib.__dict__.items():
if name.lower() == target_dataset_name.lower():
dataset = cls

if dataset is None:
raise ValueError("In %s.py, there should be a subclass of BaseDataset "
"with class name that matches %s in lowercase." %
(dataset_filename, target_dataset_name))

return dataset

def create_dataloader(opt, world_size, rank):
dataset = find_dataset_using_name(opt.dataset_mode)
instance = dataset(opt)
print("dataset [%s] of size %d was created" %
(type(instance).__name__, len(instance)))

collate_fn = collate_fn_img

if opt.mpdist:
train_sampler = torch.utils.data.distributed.DistributedSampler(instance, num_replicas=world_size, rank=rank)
dataloader = torch.utils.data.DataLoader(
instance,
batch_size=opt.batchSize,
sampler=train_sampler,
shuffle=False,
num_workers=int(opt.nThreads),
collate_fn=collate_fn,
drop_last=opt.isTrain
)
else:
dataloader = torch.utils.data.DataLoader(
instance,
batch_size=opt.batchSize,
shuffle=not opt.serial_batches,
num_workers=int(opt.nThreads),
drop_last=opt.isTrain
)
return dataloader
32 changes: 32 additions & 0 deletions data/conceptual_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import os
from PIL import Image
import json

class ConceptualDataset(data.Dataset):
def __init__(self, opt):
self.path = os.path.join(opt.dataroot, 'images')
if opt.isTrain:
self.ids = json.load(open(os.path.join(opt.dataroot, 'val_index.json'), 'r'))
else:
self.ids = json.load(open(os.path.join(opt.dataroot, 'val_index.json'), 'r'))

transforms_list = []
transforms_list.append(transforms.Resize((opt.img_size, opt.img_size)))
transforms_list += [transforms.ToTensor()]
transforms_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
self.transform = transforms.Compose(transforms_list)

def __getitem__(self, index):
"""This function returns a tuple that is further passed to collate_fn
"""
img_id = self.ids[index]
image = Image.open(os.path.join(self.path, img_id)).convert('RGB')
image = self.transform(image)

return image

def __len__(self):
return len(self.ids)
Loading

0 comments on commit 2400c86

Please sign in to comment.