This is a Tensorflow super-resolution project that implements the paper Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network. The TensorZoom network can scale up an image by 4 times its width and height (16 times the area) with significantly better quality than bilinear scaling. Training is adversarial: a discriminative network and a generative network are trained alternately to achieve a better visual result.
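For reference, the bilinear baseline that TensorZoom is compared against can be sketched in a few lines of NumPy. This is an illustration, not code from this repository; the half-pixel-centre sampling convention and the function name `bilinear_upscale` are assumptions (libraries differ slightly in how they align pixels).

```python
import numpy as np

def bilinear_upscale(image, scale=4):
    """Upscale an H x W x C float image by `scale` with bilinear interpolation.

    Uses half-pixel centres and clamps samples at the image border.
    A 4x scale in each dimension gives 16x the pixel area.
    """
    h, w = image.shape[:2]
    out_h, out_w = h * scale, w * scale
    # Map each output pixel centre back to a fractional input coordinate.
    ys = np.clip((np.arange(out_h) + 0.5) / scale - 0.5, 0, h - 1)
    xs = np.clip((np.arange(out_w) + 0.5) / scale - 0.5, 0, w - 1)
    y0 = np.floor(ys).astype(int)
    x0 = np.floor(xs).astype(int)
    y1 = np.minimum(y0 + 1, h - 1)
    x1 = np.minimum(x0 + 1, w - 1)
    # Interpolation weights, shaped to broadcast over (rows, cols, channels).
    wy = (ys - y0)[:, None, None]
    wx = (xs - x0)[None, :, None]
    top = image[y0][:, x0] * (1 - wx) + image[y0][:, x1] * wx
    bottom = image[y1][:, x0] * (1 - wx) + image[y1][:, x1] * wx
    return top * (1 - wy) + bottom * wy
```

For example, a 30x30x3 input becomes a 120x120x3 output, i.e. 16 times as many pixels.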
## Android Demo
This project is linked to an Android demo, which can be found here:
The demo already includes the following pre-trained GraphDef files:
- tz6-s-stitch.pb: for small images and level 3 zooming
- tz6-s-stitch-sblur-lowtv-gen.pb: for large images taken with the camera
The demo also supports custom GraphDef files: open Settings -> select "Use custom Tensorflow GraphDef for Tensorzoom" -> "Select GraphDef file", then choose a .pb file on the device.
Hint: depending on your Android version, you may need to tap the "..." icon in the menu and select "Show SD card" to find the files.
See the "Export GraphDef" section below for how to create a .pb file for the app, and for details of all the pre-trained graphs in this project.
## Modifications
- This implementation mainly targets mobile use, so only 6 batch normalisation layers are used.
- In general, a mobile device can only afford to process an image of at most 120x120 pixels, so a large image is sliced into smaller tiles that are rendered separately and then stitched together to form the large output image. To improve the quality of the stitched result, we added a function that slices the input image into 16 smaller images and concatenates them into a batch for training. This greatly improves the quality when multiple scaled tiles are stitched together.
- An additional deblur training step is added: the input image is blurred using `tf.nn.depthwise_conv2d`. The quality is better for images taken with a mobile camera but worse than normal for small images or thumbnails.
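The two training modifications above can be sketched in NumPy as follows. This is a hedged illustration, not the repository's code: the 4x4 tile grid (16 tiles), the row-by-row tile order, the Gaussian kernel size and sigma, and all function names are assumptions. The blur kernel is laid out as `[filter_height, filter_width, in_channels, channel_multiplier]`, which is the filter shape `tf.nn.depthwise_conv2d` expects.

```python
import numpy as np

def slice_into_batch(image, grid=4):
    """Split one H x W x C image into grid*grid tiles stacked as a batch.

    H and W must be divisible by `grid`. Tiles are ordered row by row,
    giving shape (grid*grid, H/grid, W/grid, C).
    """
    h, w, c = image.shape
    th, tw = h // grid, w // grid
    assert th * grid == h and tw * grid == w, "size must be divisible by grid"
    tiles = image.reshape(grid, th, grid, tw, c)   # (tile_row, y, tile_col, x, c)
    tiles = tiles.transpose(0, 2, 1, 3, 4)         # (tile_row, tile_col, y, x, c)
    return tiles.reshape(grid * grid, th, tw, c)

def stitch_batch(tiles, grid=4):
    """Inverse of slice_into_batch: reassemble a batch of tiles into one image."""
    n, th, tw, c = tiles.shape
    assert n == grid * grid
    image = tiles.reshape(grid, grid, th, tw, c).transpose(0, 2, 1, 3, 4)
    return image.reshape(grid * th, grid * tw, c)

def gaussian_depthwise_kernel(size=5, sigma=1.0, channels=3):
    """Gaussian blur filter shaped for tf.nn.depthwise_conv2d:
    [filter_height, filter_width, in_channels, channel_multiplier=1]."""
    ax = np.arange(size) - (size - 1) / 2.0
    g = np.exp(-(ax ** 2) / (2.0 * sigma ** 2))
    kernel_2d = np.outer(g, g)
    kernel_2d /= kernel_2d.sum()  # normalise so overall brightness is preserved
    # The same 2-D kernel is applied independently to every channel.
    return np.tile(kernel_2d[:, :, None, None], (1, 1, channels, 1)).astype(np.float32)
```

In Tensorflow, such a kernel would typically be applied with `tf.nn.depthwise_conv2d(images, kernel, strides=[1, 1, 1, 1], padding='SAME')`; training on a batch of tiles from one image encourages the generator to produce tiles whose borders match when stitched.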
## Samples
Here are some sample results generated by this algorithm (click to download the full-size images to see the results).
## Requirements
- Python 2.7.10
- Tensorflow r0.10
- numpy (`pip install numpy`)
- scikit-image (`pip install scikit-image`)
Additional items are required to run the training. See the section "Train the Network" below.
## How to Use
Test-rendering an image is simple and does not require any extra files or projects. An example is provided in `net_analysis.py`.
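Rendering a large image on a device with limited memory follows the slice-render-stitch approach described under "Modifications". Below is a minimal sketch under stated assumptions: the 120-pixel tile size comes from the text above, while `render_tiled` and `nearest_upscale` are hypothetical helpers, and nearest-neighbour upscaling stands in for the network, since the rendering API in `net_analysis.py` is not described here.

```python
import numpy as np

def render_tiled(image, upscale, tile=120, scale=4):
    """Upscale a large H x W x C image tile by tile and stitch the results.

    `upscale` is any callable mapping an h x w x C tile to an
    (h*scale) x (w*scale) x C tile, e.g. a wrapper around the network.
    Tiles on the right and bottom edges may be smaller than `tile`.
    """
    h, w, c = image.shape
    out = np.zeros((h * scale, w * scale, c), dtype=np.float32)
    for y in range(0, h, tile):
        for x in range(0, w, tile):
            patch = image[y:y + tile, x:x + tile]
            up = upscale(patch)
            ph, pw = patch.shape[:2]
            assert up.shape[:2] == (ph * scale, pw * scale)
            # Each tile lands at its input position scaled by `scale`.
            out[y * scale:(y + ph) * scale, x * scale:(x + pw) * scale] = up
    return out

def nearest_upscale(patch, scale=4):
    """Stand-in for the network: nearest-neighbour upscaling."""
    return patch.repeat(scale, axis=0).repeat(scale, axis=1)
```

With a purely local stand-in like this, the tiled result is identical to upscaling the whole image at once; with a real network, visible seams can appear at tile borders, which is what the stitch training method aims to reduce.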
### Train the Network
To train the network, you need to download additional files and change the following settings in `trainer.py`:
```python
# https://github.com/machrisaa/tensorflow-vgg
VGG_NPY_PATH = '../tensoflow_vgg/vgg19.npy'
# http://msvocds.blob.core.windows.net/coco2014/train2014.zip
COCO2014_PATH = '../../datasets/coco2014/train2014'
```
The training uses the output of the `conv2_2` layer of a pre-trained VGG19 network (Tensorflow-VGG). The dataset is the Microsoft COCO 2014 training set (2014 Training images [80K/13GB]).
Download both and update the paths in `trainer.py` before running it.
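The README does not spell out the full training objective, so the following is only a hedged NumPy sketch of an SRGAN-style generator loss combining a content loss on VGG19 `conv2_2` features, an adversarial term, and a total variation (TV) term (the pre-trained file list below mentions a total variation cost). The weights `w_adv` and `w_tv`, the anisotropic TV form, and all function names are illustrative assumptions, not the values or code used in `trainer.py`.

```python
import numpy as np

def content_loss(feat_sr, feat_hr):
    """Mean squared error between the conv2_2 feature maps of the
    generated (super-resolved) images and the ground-truth images."""
    return np.mean((feat_sr - feat_hr) ** 2)

def adversarial_loss(d_sr, eps=1e-8):
    """Generator's adversarial term, -log D(G(x)) averaged over the batch.
    `d_sr` holds the discriminator's probabilities that generated images are real."""
    return np.mean(-np.log(d_sr + eps))

def total_variation(images):
    """Anisotropic total variation of an N x H x W x C batch: the sum of
    absolute differences between neighbouring pixels, averaged over the batch."""
    dy = np.abs(images[:, 1:, :, :] - images[:, :-1, :, :]).sum()
    dx = np.abs(images[:, :, 1:, :] - images[:, :, :-1, :]).sum()
    return (dx + dy) / images.shape[0]

def generator_loss(feat_sr, feat_hr, d_sr, sr_images, w_adv=1e-3, w_tv=2e-8):
    # Illustrative placeholder weights; a "low TV" variant would shrink w_tv.
    return (content_loss(feat_sr, feat_hr)
            + w_adv * adversarial_loss(d_sr)
            + w_tv * total_variation(sr_images))
```

The discriminator is trained in alternation with the opposite objective, i.e. to output high probabilities for real images and low ones for generated images.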
The training saves every 100 iterations in 2 formats:
- NPY files (dis and gen): easier to use from other Python applications
- standard Tensorflow saves: easier for resuming training, e.g. they keep the global step

Both the NPY files and the Tensorflow saves can be used to resume training. Tensorboard logs are also written every 100 iterations.
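The exact layout of the NPY files is not documented here. The sketch below assumes a dictionary mapping layer names to `[weights, biases]` lists, in the style of the Tensorflow-VGG `vgg19.npy` file, and shows how such a file could be saved every 100 iterations and loaded back. The layer name, the helper names and the `gen-<step>.npy` file naming are hypothetical.

```python
import os
import tempfile
import numpy as np

def save_npy(params, path):
    """Save a {layer_name: [weights, biases]} dict as a single .npy file.
    np.save stores the dict inside a pickled 0-d object array."""
    np.save(path, params)

def load_npy(path):
    """Load the dict back. allow_pickle is needed for object arrays in
    recent NumPy versions, and .item() unwraps the 0-d array."""
    return np.load(path, allow_pickle=True).item()

def maybe_checkpoint(step, params, out_dir, every=100):
    """Write gen-<step>.npy every `every` iterations; return the path or None."""
    if step > 0 and step % every == 0:
        path = os.path.join(out_dir, 'gen-%d.npy' % step)
        save_npy(params, path)
        return path
    return None
```

Because the file is a plain dictionary of arrays, other Python code can read the weights without rebuilding the Tensorflow graph, which matches the motivation given above for the NPY format.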
### Export GraphDef
Exporting a network to a GraphDef file requires selecting an NPY file and calling the `save_graph` method:
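```python
# Build the generator from a pre-trained NPY file (trainable=False for export)
net = TensorZoomNet(npy_path='./results/tz6-s-stitch-sblur-lowtv/tz6-s-stitch-sblur-lowtv-gen.npy', trainable=False)
# Export the graph as ./export.pb
net.save_graph(logdir='./', name='export.pb')
```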
A sample can be found here.
This repository provides several pre-trained NPY files that can be used to export GraphDefs:
- tz6-s-nostitch-gen.npy: plain vanilla training without the stitch or deblur methods. Bad for sliced images on mobile but good for rendering a whole image on the desktop.
- tz6-s-stitch-gen.npy: trained with the stitch method, no deblur. Bad for multi-megapixel photos taken with a mobile device.
- tz6-s-stitch-sblur-gen.npy: trained from the stitched result with the additional deblur training method. Better quality for large images taken with a mobile camera.
- tz6-s-stitch-sblur-lowtv-gen.npy: another deblur version, which tries to train the network without the total variation cost.

Due to file size limits, the NPY files for the discriminative network are not included.