Commit dd0e77b: first commit
cszhilu1998 committed May 3, 2024
Showing 32 changed files with 5,546 additions and 0 deletions.

`README.md` (105 additions)
# SelfDZSR++ (TPAMI 2024)

Official PyTorch implementation of **SelfDZSR++ and SelfTZSR++**

This work extends [SelfDZSR](https://arxiv.org/abs/2203.01325) (ECCV 2022, [GitHub](https://github.com/cszhilu1998/SelfDZSR)).



> [**Self-Supervised Learning for Real-World Super-Resolution from Dual and Multiple Zoomed Observations**](https://ieeexplore.ieee.org/abstract/document/10476716) <br>
> IEEE TPAMI, 2024 <br>
> [Zhilu Zhang](https://scholar.google.com/citations?user=8pIq2N0AAAAJ), [Ruohao Wang](https://scholar.google.com/citations?user=o1FPNwQAAAAJ), [Hongzhi Zhang](https://scholar.google.com/citations?user=Ysk4WBwAAAAJ), [Wangmeng Zuo](https://scholar.google.com/citations?user=rUOpCEYAAAAJ)
<br>Harbin Institute of Technology, China

![visitors](https://visitor-badge.laobi.icu/badge?page_id=cszhilu1998.SelfDZSR_PlusPlus)



## 1. Improvements Compared to SelfDZSR

### 1.1 Improvements in Methodology

- Introduce patch-based optical flow alignment to further mitigate the effect of misalignment between data pairs. (**Section 3.2.1: Patch-based Optical Flow Alignment**)
- Present the LOSW (Local Overlapped Sliced Wasserstein) loss to generate visually pleasing results; see the sketch after this list. (**Section 3.5: LOSW Loss and Learning Objective**)
- Extend DZSR (Super-Resolution based on Dual Zoomed Observations) to multiple zoomed observations, where we present a progressive fusion scheme for better restoration. (**Section 3.6: Extension to Multiple Zoomed Observations**)
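
For intuition, the sliced Wasserstein distance compares two feature distributions through random 1D projections. Below is a generic sliced Wasserstein sketch in PyTorch; it is not the paper's LOSW implementation, which additionally computes the distance over local overlapped slices of features:

```python
import torch

def sliced_wasserstein(x, y, num_proj=64):
    # x, y: (N, C) feature sets of equal size N.
    c = x.shape[1]
    proj = torch.randn(c, num_proj, device=x.device)
    proj = proj / proj.norm(dim=0, keepdim=True)   # random unit directions
    x_proj, _ = torch.sort(x @ proj, dim=0)        # project onto each direction and sort
    y_proj, _ = torch.sort(y @ proj, dim=0)
    return (x_proj - y_proj).abs().mean()          # 1D Wasserstein-1, averaged over projections
```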

### 1.2 Improvements in Experiments

- Evaluate the proposed method on a larger iPhone camera dataset. (**Section 4.2: Quantitative and Qualitative Results**)
- Evaluate the extended method that is based on triple zoomed images. (**Section 4.2: Quantitative and Qualitative Results**)
- Compare \#FLOPs of different methods. (**Section 4.3: Comparison of \#Parameters and \#FLOPs**)
- Conduct more ablation experiments, including (1) the effect of different loss terms (**Section 5.3: Effect of LOSW Loss**); (2) the effect of different reference images and fusion schemes (**Section 5.4: Effect of Different Refs and Fusion Schemes**); (3) the effect of scaling up models (**Section 5.5: Effect of Scaling up Models**).



## 2. Preparation and Datasets

### 2.1 Prerequisites
- Python 3.x and **PyTorch 1.12**.
- OpenCV, NumPy, Pillow, tqdm, lpips, scikit-image, and tensorboardX.

### 2.2 Dataset
- **Nikon camera images** can be downloaded from this [link](https://pan.baidu.com/s/1yEPBCMjJzFsEpTWU8W4SgQ?pwd=2rbh).
- **iPhone camera images** can be downloaded from this [link](https://pan.baidu.com/s/1_AXwyn-nDhSckcH5phnEzQ?pwd=ripz).



## 3. Quick Start

### 3.1 Pre-Trained Models

- All pre-trained models are provided in this [link](https://pan.baidu.com/s/1ZvdCqZZVY36GyX67qjtl0g?pwd=bkd9). Please place the `ckpt` folder under the `SelfDZSR_PlusPlus` folder.

- To simplify the training process, we provide pre-trained models of the feature extractors and the auxiliary-LR generator. The models for Nikon and iPhone camera images are placed in the `./ckpt/nikon_pretrain_models/` and `./ckpt/iphone_pretrain_models/` folders, respectively.

- For direct testing, we provide 8 pre-trained DZSR and TZSR models (`dzsr_nikon_l1`, `dzsr_nikon_l1sw`, `dzsr_iphone_l1`, `dzsr_iphone_l1sw`, `tzsr_nikon_l1`, `tzsr_nikon_l1sw`, `tzsr_iphone_l1`, and `tzsr_iphone_l1sw`) in the `./ckpt/` folder. For example, `tzsr_nikon_l1sw` denotes the TZSR model trained on Nikon camera images with the $l_1$ and local overlapped sliced Wasserstein (LOSW) loss terms; see the naming sketch below.
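
As a small illustration of this naming convention, here is a hypothetical helper (not part of the repository) that decomposes a checkpoint name:

```python
def parse_ckpt_name(name):
    # Hypothetical helper: split e.g. 'tzsr_nikon_l1sw' into its parts.
    model, camera, loss = name.split('_')
    loss_desc = {'l1': 'L1 loss only', 'l1sw': 'L1 + LOSW losses'}[loss]
    return model.upper(), camera, loss_desc

print(parse_ckpt_name('tzsr_nikon_l1sw'))  # ('TZSR', 'nikon', 'L1 + LOSW losses')
```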


### 3.2 Training


- Modify the `dataroot`, `camera`, `data`, `model`, and `name` arguments
- Run [`sh train.sh`](train.sh)



### 3.3 Testing

- Modify the `dataroot`, `camera`, `data`, `model`, and `name` arguments
- Run [`sh test.sh`](test.sh)

### 3.4 Note

- You can specify which GPU to use via `--gpu_ids`, e.g., `--gpu_ids 0,1`, `--gpu_ids 3`, or `--gpu_ids -1` (CPU mode). By default, all GPUs are used.
- You can refer to [options](./options/base_options.py) for more arguments.


## 4. Results


<p align="center"><img src="./imgs/results1.png" width="95%"></p>

<p align="center"><img src="./imgs/results2.png" width="95%"></p>

## 5. Citation
If you find this work useful in your research, please consider citing:

    @inproceedings{SelfDZSR,
        title={Self-Supervised Learning for Real-World Super-Resolution from Dual Zoomed Observations},
        author={Zhang, Zhilu and Wang, Ruohao and Zhang, Hongzhi and Chen, Yunjin and Zuo, Wangmeng},
        booktitle={European Conference on Computer Vision (ECCV)},
        year={2022}
    }

    @article{SelfDZSR_PlusPlus,
        title={Self-Supervised Learning for Real-World Super-Resolution from Dual and Multiple Zoomed Observations},
        author={Zhang, Zhilu and Wang, Ruohao and Zhang, Hongzhi and Zuo, Wangmeng},
        journal={IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
        year={2024},
        publisher={IEEE}
    }

## 6. Acknowledgement

This repo is built upon the framework of [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix), and we borrow some code from [C2-Matching](https://github.com/yumingj/C2-Matching) and [DCSR](https://github.com/Tengfei-Wang/DCSR). Thanks for their excellent work!
`data/__init__.py` (56 additions)
import importlib
import torch.utils.data
from data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name, split='train'):
    # Import "data/[dataset_name]_dataset.py" and return the BaseDataset
    # subclass whose name matches "[dataset_name]dataset" (case-insensitive).
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() \
                and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of "
                                  "BaseDataset with class name that matches %s in "
                                  "lowercase." % (dataset_filename, target_dataset_name))
    return dataset


def create_dataset(dataset_name, split, opt):
    data_loader = CustomDatasetDataLoader(dataset_name, split, opt)
    dataset = data_loader.load_data()
    return dataset


class CustomDatasetDataLoader():
    def __init__(self, dataset_name, split, opt):
        self.opt = opt
        dataset_class = find_dataset_using_name(dataset_name, split)
        self.dataset = dataset_class(opt, split, dataset_name)
        self.imio = self.dataset.imio
        print("dataset [%s(%s)] created" % (dataset_name, split))
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size if split == 'train' else 1,
            shuffle=opt.shuffle and split == 'train',
            num_workers=int(opt.num_dataloader),
            drop_last=opt.drop_last)

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of data in the dataset."""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Return a batch of data."""
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data

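A minimal usage sketch of this loader follows. It is hypothetical: `'nikon'` assumes a `data/nikon_dataset.py` module exists, the `Namespace` fields mirror exactly those the loader reads, and the dataset class is assumed to define the `imio` attribute the loader accesses:

```python
from argparse import Namespace
from data import create_dataset

# Hypothetical options object; only fields read by CustomDatasetDataLoader are set.
opt = Namespace(dataroot='/path/to/NikonCamera', batch_size=16, shuffle=True,
                num_dataloader=4, drop_last=True, max_dataset_size=float('inf'))

dataset = create_dataset('nikon', 'train', opt)  # imports data/nikon_dataset.py
print(len(dataset))                              # capped by opt.max_dataset_size
for batch in dataset:                            # yields batches from the inner DataLoader
    break                                        # each batch is what __getitem__ returns, collated
```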
`data/base_dataset.py` (21 additions)
import random
import numpy as np
import torch.utils.data as data
from abc import ABC, abstractmethod


class BaseDataset(data.Dataset, ABC):
    # Abstract base class for all datasets; concrete datasets live in
    # data/[name]_dataset.py and must implement __len__ and __getitem__.
    def __init__(self, opt, split, dataset_name):
        self.opt = opt
        self.split = split
        self.root = opt.dataroot
        self.dataset_name = dataset_name.lower()

    @abstractmethod
    def __len__(self):
        return 0

    @abstractmethod
    def __getitem__(self, index):
        pass

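To illustrate the contract this base class defines, here is a minimal hypothetical subclass (the file name, class name, and returned keys are invented for this sketch; the repository's real datasets are more involved):

```python
# data/toy_dataset.py (hypothetical example, not part of the repository)
import numpy as np
from data.base_dataset import BaseDataset


class ToyDataset(BaseDataset):
    def __init__(self, opt, split, dataset_name):
        super(ToyDataset, self).__init__(opt, split, dataset_name)
        self.imio = None  # CustomDatasetDataLoader expects an `imio` attribute
        self.names = ['0001', '0002', '0003']

    def __len__(self):
        return len(self.names)

    def __getitem__(self, index):
        # Placeholder LR/HR arrays standing in for real image patches.
        lr = np.zeros((3, 48, 48), dtype=np.float32)
        hr = np.zeros((3, 192, 192), dtype=np.float32)
        return {'lr': lr, 'hr': hr, 'fname': self.names[index]}
```

`find_dataset_using_name('toy')` would locate this class, since `'toy' + 'dataset'` matches `ToyDataset` case-insensitively.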
`data/degrade/degrade_kernel.py` (197 additions)
import cv2
import numpy as np
import random
import torch
from scipy import ndimage
from scipy.interpolate import interp2d
from .unprocess import unprocess, random_noise_levels, add_noise
from .process import process
from PIL import Image


def get_degrade_seq(sf):
    # Build a random degradation sequence: bilinear downsampling, Gaussian
    # noise, and JPEG compression, applied in a shuffled order.
    degrade_seq = []

    # -----------
    # down sample
    # -----------
    B_down = {
        "mode": "down",
        "sf": sf
    }
    B_down["down_mode"] = "bilinear"
    degrade_seq.append(B_down)

    # --------------
    # gaussian noise
    # --------------
    B_noise = {
        "mode": "noise",
        "noise_level": random.randint(5, 30)  # 1, 19
    }
    degrade_seq.append(B_noise)

    # ----------
    # jpeg noise
    # ----------
    B_jpeg = {
        "mode": "jpeg",
        "qf": random.randint(60, 95)  # 40, 95
    }
    degrade_seq.append(B_jpeg)

    # -------
    # shuffle
    # -------
    random.shuffle(degrade_seq)
    return degrade_seq


def degrade_kernel(img, sf=4):
    # Apply a random degradation sequence to a PIL image. The "blur",
    # "camera", and "restore" branches are reached only if the sequence is
    # extended beyond what get_degrade_seq currently produces.
    h, w, c = np.array(img).shape
    degrade_seq = get_degrade_seq(sf)
    for degrade_dict in degrade_seq:
        mode = degrade_dict["mode"]
        if mode == "blur":
            img = get_blur(img, degrade_dict)
        elif mode == "down":
            img = get_down(img, degrade_dict)
        elif mode == "noise":
            img = get_noise(img, degrade_dict)
        elif mode == 'jpeg':
            img = get_jpeg(img, degrade_dict)
        elif mode == 'camera':
            img = get_camera(img, h, w, degrade_dict)
        elif mode == 'restore':
            img = get_restore(img, w, h, degrade_dict)
        else:
            raise ValueError("unknown degradation mode: %s" % mode)
    return img, degrade_seq


def get_blur(img, degrade_dict):
    # Convolve the image with an (optionally anisotropic, rotated) Gaussian kernel.
    img = np.array(img)
    k_size = degrade_dict["kernel_size"]
    if degrade_dict["is_aniso"]:
        sigma_x = degrade_dict["x_sigma"]
        sigma_y = degrade_dict["y_sigma"]
        angle = degrade_dict["rotation"]
    else:
        sigma_x = degrade_dict["sigma"]
        sigma_y = degrade_dict["sigma"]
        angle = 0

    kernel = np.zeros((k_size, k_size))
    d = k_size // 2
    for x in range(-d, d+1):
        for y in range(-d, d+1):
            kernel[x+d][y+d] = get_kernel_pixel(x, y, sigma_x, sigma_y)
    M = cv2.getRotationMatrix2D((k_size//2, k_size//2), angle, 1)
    kernel = cv2.warpAffine(kernel, M, (k_size, k_size))
    kernel = kernel / np.sum(kernel)
    img = ndimage.convolve(img, np.expand_dims(kernel, axis=2), mode='reflect')

    return Image.fromarray(np.uint8(np.clip(img, 0.0, 255.0)))


def get_down(img, degrade_dict):
    # Downsample by the scale factor, via decimation or OpenCV resizing.
    img = np.array(img)
    sf = degrade_dict["sf"]
    mode = degrade_dict["down_mode"]
    h, w, c = img.shape
    if mode == "nearest":
        img = img[0::sf, 0::sf, :]
    elif mode == "bilinear":
        new_h, new_w = int(h/sf), int(w/sf)
        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    elif mode == "bicubic":
        new_h, new_w = int(h/sf), int(w/sf)
        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
    return Image.fromarray(np.uint8(np.clip(img, 0.0, 255.0)))


def get_noise(img, degrade_dict):
    # Add i.i.d. Gaussian noise with the given standard deviation.
    noise_level = degrade_dict["noise_level"]
    img = np.array(img)
    img = img + np.random.normal(0, noise_level, img.shape)
    return Image.fromarray(np.uint8(np.clip(img, 0.0, 255.0)))


def get_jpeg(img, degrade_dict):
    # Round-trip the image through JPEG encoding at quality factor qf.
    qf = degrade_dict["qf"]
    img = np.array(img)
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), qf]  # (0, 100); higher is better, default is 95
    _, encA = cv2.imencode('.jpg', img, encode_param)
    img = cv2.imdecode(encA, 1)
    return Image.fromarray(np.uint8(np.clip(img, 0.0, 255.0)))


def get_camera(img, h, w, degrade_dict):
    # Simulate camera noise by unprocessing to RAW, adding shot/read noise,
    # and processing back to RGB. The RAW (Bayer) pipeline requires even
    # height and width, so odd-sized images are returned unchanged.
    if h % 2 == 0 and w % 2 == 0:
        img = torch.from_numpy(np.array(img)) / 255.0
        deg_img, features = unprocess(img)
        shot_noise, read_noise = random_noise_levels()
        deg_img = add_noise(deg_img, shot_noise, read_noise)
        deg_img = deg_img.unsqueeze(0)
        features['red_gain'] = features['red_gain'].unsqueeze(0)
        features['blue_gain'] = features['blue_gain'].unsqueeze(0)
        features['cam2rgb'] = features['cam2rgb'].unsqueeze(0)
        deg_img = process(deg_img, features['red_gain'], features['blue_gain'], features['cam2rgb'])
        deg_img = deg_img.squeeze(0)
        deg_img = torch.clamp(deg_img * 255.0, 0.0, 255.0).numpy()
        deg_img = deg_img.astype(np.uint8)
        return Image.fromarray(deg_img)
    else:
        return img


def get_restore(img, w, h, degrade_dict):
    # Upsample back to the original size; note cv2.resize takes (width, height).
    need_shift = degrade_dict["need_shift"]
    sf = degrade_dict["sf"]
    img = np.array(img)
    mode = degrade_dict["up_mode"]
    if mode == "bilinear":
        img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)
    else:
        img = cv2.resize(img, (w, h), interpolation=cv2.INTER_CUBIC)
    if need_shift:
        img = shift_pixel(img, int(sf))
    return Image.fromarray(img)


def get_kernel_pixel(x, y, sigma_x, sigma_y):
    # Value of an axis-aligned 2D Gaussian at (x, y).
    return 1/(2*np.pi*sigma_x*sigma_y)*np.exp(-((x*x/(2*sigma_x*sigma_x))+(y*y/(2*sigma_y*sigma_y))))


def shift_pixel(x, sf, upper_left=True):
    """Shift pixels for super-resolution with different scale factors.

    Args:
        x: HxWxC or HxW image array
        sf: scale factor
        upper_left: shift direction
    """
    h, w = x.shape[:2]
    shift = (sf-1)*0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    x1 = np.clip(x1, 0, w-1)
    y1 = np.clip(y1, 0, h-1)

    if x.ndim == 2:
        x = interp2d(xv, yv, x)(x1, y1)
    if x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)

    return x


def print_degrade_seg(degrade_seq):
    # Print each degradation step in the sequence.
    for degrade_dict in degrade_seq:
        print(degrade_dict)

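For reference, `get_kernel_pixel` evaluates the axis-aligned 2D Gaussian

$$k(x, y) = \frac{1}{2\pi\sigma_x\sigma_y}\exp\!\left(-\frac{x^2}{2\sigma_x^2}-\frac{y^2}{2\sigma_y^2}\right),$$

which `get_blur` samples on a `k_size` x `k_size` grid, rotates, and normalizes into a blur kernel. A small smoke test of the degradation pipeline (hypothetical usage; it assumes `data.degrade` is an importable package from the repo root):

```python
import numpy as np
from PIL import Image
from data.degrade.degrade_kernel import degrade_kernel, print_degrade_seg

# Degrade a random 128x128 RGB image with the x4 pipeline.
img = Image.fromarray(np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8))
deg_img, seq = degrade_kernel(img, sf=4)

print_degrade_seg(seq)  # the shuffled order of the down/noise/jpeg steps
print(deg_img.size)     # (32, 32): downsampled by sf regardless of step order
```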