# SelfDZSR++ (TPAMI 2024)

Official PyTorch implementation of **SelfDZSR++ and SelfTZSR++**

This work is extended from [SelfDZSR](https://arxiv.org/abs/2203.01325) (ECCV 2022, [Github](https://github.com/cszhilu1998/SelfDZSR)).

> [**Self-Supervised Learning for Real-World Super-Resolution from Dual and Multiple Zoomed Observations**](https://ieeexplore.ieee.org/abstract/document/10476716) <br>
> IEEE TPAMI, 2024 <br>
> [Zhilu Zhang](https://scholar.google.com/citations?user=8pIq2N0AAAAJ), [Ruohao Wang](https://scholar.google.com/citations?user=o1FPNwQAAAAJ), [Hongzhi Zhang](https://scholar.google.com/citations?user=Ysk4WBwAAAAJ), [Wangmeng Zuo](https://scholar.google.com/citations?user=rUOpCEYAAAAJ) <br>
> Harbin Institute of Technology, China

![]()
## 1. Improvements Compared to SelfDZSR

### 1.1 Improvements in Methodology

- Introduce patch-based optical flow alignment to further mitigate the effect of misalignment between data pairs. (**Section 3.2.1: Patch-based Optical Flow Alignment**)
- Present LOSW (Local Overlapped Sliced Wasserstein) loss to generate visually pleasing results; a generic sliced-Wasserstein sketch follows this list. (**Section 3.5: LOSW Loss and Learning Objective**)
- Extend DZSR (Super-Resolution based on Dual Zoomed Observations) to multiple zoomed observations, where we present a progressive fusion scheme for better restoration. (**Section 3.6: Extension to Multiple Zoomed Observations**)
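For intuition, the core of a sliced Wasserstein loss fits in a few lines of PyTorch. This is a minimal, generic sketch, not the repository's LOSW implementation: the local overlapped slicing and the exact weighting of Section 3.5 are omitted, and all names here are illustrative.

```python
import torch

def sliced_wasserstein(x, y, n_proj=32):
    """Generic sliced Wasserstein-2 distance between two point sets.

    x, y: (N, C) tensors holding N feature vectors each (equal N assumed).
    """
    proj = torch.randn(x.shape[1], n_proj, device=x.device)
    proj = proj / proj.norm(dim=0, keepdim=True)  # random unit directions
    xp = (x @ proj).sort(dim=0).values            # 1-D projections, sorted
    yp = (y @ proj).sort(dim=0).values
    # Comparing sorted projections is the closed form of the 1-D Wasserstein-2 distance.
    return ((xp - yp) ** 2).mean()
```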
### 1.2 Improvements in Experiments

- Evaluate the proposed method on a larger iPhone camera dataset. (**Section 4.2: Quantitative and Qualitative Results**)
- Evaluate the extended method that is based on triple zoomed images. (**Section 4.2: Quantitative and Qualitative Results**)
- Compare \#FLOPs of different methods. (**Section 4.3: Comparison of \#Parameters and \#FLOPs**)
- More ablation experiments, including (1) the effect of different loss terms (**Section 5.3: Effect of LOSW Loss**); (2) the effect of different reference images and fusion schemes (**Section 5.4: Effect of Different Refs and Fusion Schemes**); (3) the effect of scaling up models (**Section 5.5: Effect of Scaling up Models**).
## 2. Preparation and Datasets

### 2.1 Prerequisites

- Python 3.x and **PyTorch 1.12**.
- OpenCV, NumPy, Pillow, tqdm, lpips, scikit-image, and tensorboardX.

### 2.2 Dataset

- **Nikon camera images** can be downloaded from this [link](https://pan.baidu.com/s/1yEPBCMjJzFsEpTWU8W4SgQ?pwd=2rbh).
- **iPhone camera images** can be downloaded from this [link](https://pan.baidu.com/s/1_AXwyn-nDhSckcH5phnEzQ?pwd=ripz).
## 3. Quick Start

### 3.1 Pre-Trained Models

- All pre-trained models are provided at this [link](https://pan.baidu.com/s/1ZvdCqZZVY36GyX67qjtl0g?pwd=bkd9). Please place the `ckpt` folder under the `SelfDZSR_PlusPlus` folder.

- To simplify the training process, we provide the pre-trained models of the feature extractors and the auxiliary-LR generator. The models for Nikon and iPhone camera images are placed in the `./ckpt/nikon_pretrain_models/` and `./ckpt/iphone_pretrain_models/` folders, respectively.

- For direct testing, we provide 8 pre-trained DZSR and TZSR models (`dzsr_nikon_l1`, `dzsr_nikon_l1sw`, `dzsr_iphone_l1`, `dzsr_iphone_l1sw`, `tzsr_nikon_l1`, `tzsr_nikon_l1sw`, `tzsr_iphone_l1`, and `tzsr_iphone_l1sw`) in the `./ckpt/` folder. Taking `tzsr_nikon_l1sw` as an example: it denotes the TZSR model trained on the Nikon camera images using the $l_1$ and local overlapped sliced Wasserstein loss terms.
### 3.2 Training

- Modify `dataroot`, `camera`, `data`, `model`, and `name` in [`train.sh`](train.sh).
- Run `sh train.sh`.
### 3.3 Testing

- Modify `dataroot`, `camera`, `data`, `model`, and `name` in [`test.sh`](test.sh).
- Run `sh test.sh`.
### 3.4 Note

- You can specify which GPUs to use via `--gpu_ids`, e.g., `--gpu_ids 0,1`, `--gpu_ids 3`, or `--gpu_ids -1` (for CPU mode). By default, all GPUs are used.
- See [options](./options/base_options.py) for more arguments.
## 4. Results

<p align="center"><img src="./imgs/results1.png" width="95%"></p>

<p align="center"><img src="./imgs/results2.png" width="95%"></p>
## 5. Citation

If you find this work useful in your research, please consider citing:

    @inproceedings{SelfDZSR,
        title={Self-Supervised Learning for Real-World Super-Resolution from Dual Zoomed Observations},
        author={Zhang, Zhilu and Wang, Ruohao and Zhang, Hongzhi and Chen, Yunjin and Zuo, Wangmeng},
        booktitle={European Conference on Computer Vision (ECCV)},
        year={2022}
    }

    @article{SelfDZSR_PlusPlus,
        title={Self-Supervised Learning for Real-World Super-Resolution from Dual and Multiple Zoomed Observations},
        author={Zhang, Zhilu and Wang, Ruohao and Zhang, Hongzhi and Zuo, Wangmeng},
        journal={IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
        year={2024},
        publisher={IEEE}
    }
## 6. Acknowledgement

This repo is built upon the framework of [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix), and we borrow some code from [C2-Matching](https://github.com/yumingj/C2-Matching) and [DCSR](https://github.com/Tengfei-Wang/DCSR). Thanks for their excellent work!
---

**`data/__init__.py`** (path inferred from the imports below) — dataset discovery and loading:
```python
import importlib
import torch.utils.data
from data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name, split='train'):
    """Import data/<dataset_name>_dataset.py and return the dataset class in it.

    The class name must match <dataset_name> with underscores removed plus
    'dataset' (case-insensitive) and must subclass BaseDataset. The 'split'
    argument is accepted for interface symmetry but is not used here.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() \
                and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of "
                                  "BaseDataset with class name that matches %s in "
                                  "lowercase." % (dataset_filename, target_dataset_name))
    return dataset


def create_dataset(dataset_name, split, opt):
    data_loader = CustomDatasetDataLoader(dataset_name, split, opt)
    dataset = data_loader.load_data()
    return dataset


class CustomDatasetDataLoader():
    """Wraps a dataset in a torch DataLoader with options-driven settings."""

    def __init__(self, dataset_name, split, opt):
        self.opt = opt
        dataset_class = find_dataset_using_name(dataset_name, split)
        self.dataset = dataset_class(opt, split, dataset_name)
        self.imio = self.dataset.imio
        print("dataset [%s(%s)] created" % (dataset_name, split))
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size if split == 'train' else 1,
            shuffle=opt.shuffle and split == 'train',
            num_workers=int(opt.num_dataloader),
            drop_last=opt.drop_last)

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of data in the dataset."""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Yield batches, stopping after max_dataset_size samples."""
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
```
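A hedged usage sketch of this module (the dataset name and the `opt` fields are illustrative; the real options object is built by `options/base_options.py`):

```python
from types import SimpleNamespace
from data import create_dataset

# Illustrative options; the real ones come from options/base_options.py.
opt = SimpleNamespace(dataroot='/path/to/data', batch_size=16, shuffle=True,
                      num_dataloader=4, drop_last=True,
                      max_dataset_size=float('inf'))

# Resolves data/toy_dataset.py and its ToyDataset class ('toy' is hypothetical).
loader = create_dataset('toy', 'train', opt)
for batch in loader:
    pass  # each `batch` is whatever the dataset's __getitem__ returns, collated
```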
---

**`data/base_dataset.py`** — the abstract base class that every dataset subclasses:
```python
import torch.utils.data as data
from abc import ABC, abstractmethod


class BaseDataset(data.Dataset, ABC):
    """Abstract base class for datasets; subclasses implement __len__ and __getitem__."""

    def __init__(self, opt, split, dataset_name):
        self.opt = opt
        self.split = split
        self.root = opt.dataroot
        self.dataset_name = dataset_name.lower()

    @abstractmethod
    def __len__(self):
        return 0

    @abstractmethod
    def __getitem__(self, index):
        pass
```
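To illustrate the contract enforced by `find_dataset_using_name`: a dataset named `toy` must live in `data/toy_dataset.py` and define a `BaseDataset` subclass whose lowercased class name is `toydataset`. A minimal hypothetical sketch, not part of the repository (the `imio` attribute is included because `CustomDatasetDataLoader` reads it):

```python
# data/toy_dataset.py -- hypothetical example, not part of this commit
from data.base_dataset import BaseDataset


class ToyDataset(BaseDataset):  # 'ToyDataset'.lower() == 'toy' + 'dataset'
    def __init__(self, opt, split, dataset_name):
        super().__init__(opt, split, dataset_name)
        self.imio = None  # CustomDatasetDataLoader exposes dataset.imio
        self.samples = list(range(100))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return {'lr': self.samples[index]}
```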
---

Finally, a module implementing a randomized image degradation pipeline (bilinear downsampling, Gaussian noise, and JPEG compression, applied in shuffled order):
```python
import cv2
import numpy as np
import random
import torch
from scipy import ndimage
from scipy.interpolate import interp2d  # note: interp2d was removed in SciPy >= 1.14
from .unprocess import unprocess, random_noise_levels, add_noise
from .process import process
from PIL import Image


def get_degrade_seq(sf):
    """Build a random degradation sequence: downsample, noise, and JPEG, shuffled."""
    degrade_seq = []

    # -----------
    # down sample
    # -----------
    B_down = {
        "mode": "down",
        "sf": sf
    }
    B_down["down_mode"] = "bilinear"
    degrade_seq.append(B_down)

    # --------------
    # gaussian noise
    # --------------
    B_noise = {
        "mode": "noise",
        "noise_level": random.randint(5, 30)  # alternative range: (1, 19)
    }
    degrade_seq.append(B_noise)

    # ----------
    # jpeg noise
    # ----------
    B_jpeg = {
        "mode": "jpeg",
        "qf": random.randint(60, 95)  # alternative range: (40, 95)
    }
    degrade_seq.append(B_jpeg)

    # -------
    # shuffle
    # -------
    random.shuffle(degrade_seq)
    return degrade_seq


def degrade_kernel(img, sf=4):
    """Apply a random degradation sequence to a PIL image.

    Note: get_degrade_seq currently emits only 'down', 'noise', and 'jpeg';
    the 'blur', 'camera', and 'restore' branches support extended sequences.
    """
    h, w, c = np.array(img).shape
    degrade_seq = get_degrade_seq(sf)
    for degrade_dict in degrade_seq:
        mode = degrade_dict["mode"]
        if mode == "blur":
            img = get_blur(img, degrade_dict)
        elif mode == "down":
            img = get_down(img, degrade_dict)
        elif mode == "noise":
            img = get_noise(img, degrade_dict)
        elif mode == 'jpeg':
            img = get_jpeg(img, degrade_dict)
        elif mode == 'camera':
            img = get_camera(img, h, w, degrade_dict)
        elif mode == 'restore':
            img = get_restore(img, w, h, degrade_dict)
        else:
            raise ValueError("unknown degradation mode: %s" % mode)
    return img, degrade_seq


def get_blur(img, degrade_dict):
    """Convolve with an (optionally anisotropic, rotated) Gaussian kernel."""
    img = np.array(img)
    k_size = degrade_dict["kernel_size"]
    if degrade_dict["is_aniso"]:
        sigma_x = degrade_dict["x_sigma"]
        sigma_y = degrade_dict["y_sigma"]
        angle = degrade_dict["rotation"]
    else:
        sigma_x = degrade_dict["sigma"]
        sigma_y = degrade_dict["sigma"]
        angle = 0

    kernel = np.zeros((k_size, k_size))
    d = k_size // 2
    for x in range(-d, d+1):
        for y in range(-d, d+1):
            kernel[x+d][y+d] = get_kernel_pixel(x, y, sigma_x, sigma_y)
    M = cv2.getRotationMatrix2D((k_size//2, k_size//2), angle, 1)
    kernel = cv2.warpAffine(kernel, M, (k_size, k_size))
    kernel = kernel / np.sum(kernel)
    img = ndimage.convolve(img, np.expand_dims(kernel, axis=2), mode='reflect')

    return Image.fromarray(np.uint8(np.clip(img, 0.0, 255.0)))


def get_down(img, degrade_dict):
    """Downsample by the scale factor with nearest, bilinear, or bicubic mode."""
    img = np.array(img)
    sf = degrade_dict["sf"]
    mode = degrade_dict["down_mode"]
    h, w, c = img.shape
    if mode == "nearest":
        img = img[0::sf, 0::sf, :]
    elif mode == "bilinear":
        new_h, new_w = int(h/sf), int(w/sf)
        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    elif mode == "bicubic":
        new_h, new_w = int(h/sf), int(w/sf)
        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
    return Image.fromarray(np.uint8(np.clip(img, 0.0, 255.0)))


def get_noise(img, degrade_dict):
    """Add zero-mean Gaussian noise with the given standard deviation."""
    noise_level = degrade_dict["noise_level"]
    img = np.array(img)
    img = img + np.random.normal(0, noise_level, img.shape)
    return Image.fromarray(np.uint8(np.clip(img, 0.0, 255.0)))


def get_jpeg(img, degrade_dict):
    """Round-trip through JPEG compression at the given quality factor."""
    qf = degrade_dict["qf"]
    img = np.array(img)
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), qf]  # quality in (0, 100); higher is better, default 95
    _, enc = cv2.imencode('.jpg', img, encode_param)
    img = cv2.imdecode(enc, 1)
    return Image.fromarray(np.uint8(np.clip(img, 0.0, 255.0)))


def get_camera(img, h, w, degrade_dict):
    """Simulate camera noise: unprocess to RAW, add shot/read noise, reprocess."""
    if h % 2 == 0 and w % 2 == 0:  # Bayer unprocessing requires even dimensions
        img = torch.from_numpy(np.array(img)) / 255.0
        deg_img, features = unprocess(img)
        shot_noise, read_noise = random_noise_levels()
        deg_img = add_noise(deg_img, shot_noise, read_noise)
        deg_img = deg_img.unsqueeze(0)
        features['red_gain'] = features['red_gain'].unsqueeze(0)
        features['blue_gain'] = features['blue_gain'].unsqueeze(0)
        features['cam2rgb'] = features['cam2rgb'].unsqueeze(0)
        deg_img = process(deg_img, features['red_gain'], features['blue_gain'], features['cam2rgb'])
        deg_img = deg_img.squeeze(0)
        deg_img = torch.clamp(deg_img * 255.0, 0.0, 255.0).numpy()
        deg_img = deg_img.astype(np.uint8)
        return Image.fromarray(deg_img)
    else:
        return img


def get_restore(img, w, h, degrade_dict):
    """Upsample back to (w, h) and optionally shift pixels to re-center the grid."""
    need_shift = degrade_dict["need_shift"]
    sf = degrade_dict["sf"]
    img = np.array(img)
    mode = degrade_dict["up_mode"]
    if mode == "bilinear":
        img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)  # cv2 size is (width, height)
    else:
        img = cv2.resize(img, (w, h), interpolation=cv2.INTER_CUBIC)
    if need_shift:
        img = shift_pixel(img, int(sf))
    return Image.fromarray(img)


def get_kernel_pixel(x, y, sigma_x, sigma_y):
    # Axis-aligned 2-D Gaussian: exp(-(x^2/(2*sx^2) + y^2/(2*sy^2))) / (2*pi*sx*sy)
    return 1/(2*np.pi*sigma_x*sigma_y)*np.exp(-((x*x/(2*sigma_x*sigma_x))+(y*y/(2*sigma_y*sigma_y))))


def shift_pixel(x, sf, upper_left=True):
    """Shift pixels for super-resolution with different scale factors.

    Args:
        x: HxWxC or HxW array
        sf: scale factor
        upper_left: shift direction
    """
    h, w = x.shape[:2]
    shift = (sf-1)*0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    x1 = np.clip(x1, 0, w-1)
    y1 = np.clip(y1, 0, h-1)

    if x.ndim == 2:
        x = interp2d(xv, yv, x)(x1, y1)
    if x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)

    return x


def print_degrade_seg(degrade_seq):
    """Print each step of a degradation sequence."""
    for degrade_dict in degrade_seq:
        print(degrade_dict)
```
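A hedged usage sketch of the pipeline (the file paths are illustrative, and `degrade_kernel` / `print_degrade_seg` should be imported from this module, whose path is not shown in this commit view):

```python
# Hypothetical usage: degrade one clean image and inspect the applied steps.
from PIL import Image

img = Image.open('example.png')            # illustrative input path
deg_img, seq = degrade_kernel(img, sf=4)   # PIL image out, plus the op sequence
print_degrade_seg(seq)                     # e.g. {'mode': 'noise', 'noise_level': 17} ...
deg_img.save('example_degraded.png')
```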