Skip to content

Commit

Permalink
enable non-square input, update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
yiranran committed Mar 20, 2021
1 parent a96adb2 commit 2d89d97
Show file tree
Hide file tree
Showing 9 changed files with 16 additions and 17 deletions.
1 change: 1 addition & 0 deletions data/single_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __getitem__(self, index):
"""
A_path = self.A_paths[index]
A_img = Image.open(A_path).convert('RGB')
self.opt.W, self.opt.H = A_img.size
transform_params_A = get_params(self.opt, A_img.size)
A = get_transform(self.opt, transform_params_A, grayscale=(self.input_nc == 1))(A_img)
item = {'A': A, 'A_paths': A_path}
Expand Down
Binary file added imgs/architecture.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed imgs/architecture.png
Binary file not shown.
Binary file modified imgs/how_to_crop.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/result_html.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/test1/cropped2.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 5 additions & 3 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ This project generates artistic portrait drawings from face photos using a GAN-b

## Our Proposed Framework

<img src = 'imgs/architecture.png'>
<img src = 'imgs/architecture.jpg'>

## Sample Results
From left to right: input, output(style1), output(style2), output(style3)
Expand All @@ -34,7 +34,7 @@ If you use this code for your research, please cite our paper.


## Installation
- Install PyTorch 1.1.0 and torchvision from http://pytorch.org and other dependencies (e.g., [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate)). You can install all the dependencies by
- To install the dependencies, run
```bash
pip install -r requirements.txt
```
Expand All @@ -55,7 +55,7 @@ The result images are saved in `./results/pretrained/test_200/images3styles`,
where `real`, `fake1`, `fake2`, `fake3` correspond to input face photo, style1 drawing, style2 drawing, style3 drawing respectively.

<img src = 'imgs/how_to_crop.jpg'>
- 3. To test on your own photos, the photos need to be square (since the program will load it and resized as 512x512). You can use an image editor to crop a square area of your photo that contains face (or use an optional preprocess [here](preprocess/readme.md)). Then specify the folder that contains test photos using `--dataroot`, specify save folder name using `--savefolder` and run the above command again:
- 3. To test on your own photos: First use an image editor to crop the face region of your photo (or use an optional preprocess [here](preprocess/readme.md)). Then specify the folder that contains test photos using `--dataroot`, specify save folder name using `--savefolder` and run the above command again:

``` bash
# with GPU
Expand All @@ -67,6 +67,8 @@ python test_seq_style.py --gpu -1 --dataroot ./imgs/test1 --savefolder 3styles_t
```
The test results will be saved to a html file here: `./results/pretrained/test_200/index[save_folder_name].html`.
The result images are saved in `./results/pretrained/test_200/images[save_folder_name]`.
An example html screenshot is shown below:
<img src = 'imgs/result_html.jpg'>

You can contact email [email protected] for any questions.

Expand Down
2 changes: 1 addition & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,5 @@
img_path = model.get_image_paths() # get image paths
if i % 5 == 0: # save images to an HTML file
print('processing (%04d)-th image... %s' % (i, img_path))
save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize, W=opt.W, H=opt.H)
webpage.save() # save the HTML
22 changes: 9 additions & 13 deletions util/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,15 @@
import time
from . import util, html
from subprocess import Popen, PIPE
from scipy.misc import imresize
import pdb
from scipy.io import savemat
from PIL import Image

if sys.version_info[0] == 2:
VisdomExceptionBase = Exception
else:
VisdomExceptionBase = ConnectionError


def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256, W=None, H=None):
"""Save images to the disk.
Parameters:
Expand All @@ -37,16 +35,16 @@ def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
for label, im_data in visuals.items():
## tensor to im
im = util.tensor2im(im_data)
#im,imo = util.tensor2im(im_data)
#matname = os.path.join(image_dir, '%s_%s.mat' % (name, label))
#savemat(matname,{'imo':imo})
image_name = '%s_%s.png' % (name, label)
save_path = os.path.join(image_dir, image_name)
h, w, _ = im.shape
if aspect_ratio > 1.0:
im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
if aspect_ratio < 1.0:
im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
if W is not None and H is not None and (W != w or H != h):
im = np.array(Image.fromarray(im).resize((W, H), Image.BICUBIC))
else:
if aspect_ratio > 1.0:
im = np.array(Image.fromarray(im).resize((int(w * aspect_ratio), h), Image.BICUBIC))
if aspect_ratio < 1.0:
im = np.array(Image.fromarray(im).resize((w, int(h / aspect_ratio)), Image.BICUBIC))
util.save_image(im, save_path)

ims.append(image_name)
Expand Down Expand Up @@ -133,7 +131,6 @@ def display_current_results(self, visuals, epoch, save_result):
for label, image in visuals.items():
image_numpy = util.tensor2im(image)
label_html_row += '<td>%s</td>' % label
#pdb.set_trace()
images.append(image_numpy.transpose([2, 0, 1]))
idx += 1
if idx % ncols == 0:
Expand Down Expand Up @@ -203,7 +200,6 @@ def plot_current_losses(self, epoch, counter_ratio, losses):
self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
#X = np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1)
#Y = np.array(self.plot_data['Y'])
#pdb.set_trace()
try:
self.vis.line(
X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
Expand Down

0 comments on commit 2d89d97

Please sign in to comment.