Commit 525f889: Initial commit
umzi2 committed Jun 15, 2024 (0 parents)
Showing 67 changed files with 11,935 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@
experiments/*
options/tmp/
__pycache__/
.ruff_cache/
216 changes: 216 additions & 0 deletions convert.py
@@ -0,0 +1,216 @@
import warnings
from copy import deepcopy
from os import path as osp

import numpy as np
import onnx
import onnxruntime
import torch
from onnxconverter_common.float16 import convert_float_to_float16
from onnxsim import simplify

from neosr.archs import build_network
from neosr.utils.options import parse_options


def load_net():
    # build network
    print(f"\n-------- Attempting to build network [{args.network}].")
    if args.network is None:
        msg = "Please select a network using the -net option"
        raise ValueError(msg)
    net_opt = {"type": args.network}

    if args.network == "omnisr":
        net_opt["upsampling"] = args.scale
        net_opt["window_size"] = args.window

    if args.window:
        net_opt["window_size"] = args.window

    net = build_network(net_opt)
    load_net = torch.load(args.input, map_location=torch.device("cuda"))
    # find parameter key
    print("-------- Finding parameter key...")
    param_key = None
    if "params-ema" in load_net:
        param_key = "params-ema"
    elif "params" in load_net:
        param_key = "params"
    elif "params_ema" in load_net:
        param_key = "params_ema"
    if param_key is not None:
        load_net = load_net[param_key]

    # remove unnecessary 'module.' prefix
    for k, v in deepcopy(load_net).items():
        if k.startswith("module."):
            load_net[k[7:]] = v
            load_net.pop(k)

    # load the state dict and send to device
    net.load_state_dict(load_net, strict=True)
    net = net.to(device="cuda", non_blocking=True)
    print(f"-------- Successfully loaded network [{args.network}].")
    torch.cuda.empty_cache()

    return net


def assert_verify(onnx_model, torch_model) -> None:
    if args.static is not None:
        dummy_input = torch.randn(1, *args.static, requires_grad=True)
    else:
        dummy_input = torch.randn(1, 3, 20, 20, requires_grad=True)
    # onnxruntime output prediction
    # NOTE: "CUDAExecutionProvider" errors if some nvidia libs
    # are not found, so default to cpu
    ort_session = onnxruntime.InferenceSession(
        onnx_model, providers=["CPUExecutionProvider"]
    )
    ort_inputs = {ort_session.get_inputs()[0].name: dummy_input.detach().cpu().numpy()}
    ort_outs = ort_session.run(None, ort_inputs)

    # torch outputs
    torch_outputs = torch_model(dummy_input)

    # final assert, with tolerances loosened to rtol=1e-02, atol=1e-03
    np.testing.assert_allclose(
        torch_outputs.detach().cpu().numpy(), ort_outs[0], rtol=0.01, atol=0.001
    )
    print("-------- Model successfully verified.")


def to_onnx() -> None:
    # error if network can't be converted
    net_error = ["craft", "ditn"]
    if args.network in net_error:
        msg = f"Network [{args.network}] cannot be converted to ONNX."
        raise RuntimeError(msg)

    # load network and send to device
    model = load_net()
    # set model to eval mode
    model.eval()

    # set static or dynamic input shape
    if args.static is not None:
        dummy_input = torch.randn(1, *args.static, requires_grad=True)
    else:
        dummy_input = torch.randn(1, 3, 20, 20, requires_grad=True)

    # dict for dynamic axes
    if args.static is None:
        dyn_axes = {
            "dynamic_axes": {
                "input": {0: "batch_size", 2: "width", 3: "height"},
                "output": {0: "batch_size", 2: "width", 3: "height"},
            },
            "input_names": ["input"],
            "output_names": ["output"],
        }
    else:
        dyn_axes = None

    # add _fp32 suffix to the output filename
    filename, extension = osp.splitext(args.output)
    output_fp32 = filename + "_fp32" + extension
    # begin conversion
    print("-------- Starting ONNX conversion (this can take a while)...")

    with torch.device("cpu"):
        # TODO: switch to dynamo_export once it supports ATen PixelShuffle,
        # then use torch.testing.assert_close for verification
        torch.onnx.export(
            model,
            dummy_input,
            output_fp32,
            export_params=True,
            opset_version=args.opset,
            do_constant_folding=False,
            **(dyn_axes or {}),
        )

    print("-------- Conversion was successful. Verifying...")
    # verify onnx
    load_onnx = onnx.load(output_fp32)
    torch.cuda.empty_cache()
    onnx.checker.check_model(load_onnx)
    print(f"-------- Model successfully converted to ONNX format. Saved at: {output_fp32}.")
    # verify outputs
    if args.nocheck is False:
        assert_verify(output_fp32, model)


    if args.optimize:
        print("-------- Running ONNX optimization...")
        output_optimized = filename + "_fp32_optimized" + extension
        session_opt = onnxruntime.SessionOptions()
        # ENABLE_ALL can cause compatibility issues, leaving EXTENDED as default
        session_opt.graph_optimization_level = (
            onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
        )
        session_opt.optimized_model_filepath = output_optimized
        # creating the session writes the optimized model to disk
        onnxruntime.InferenceSession(output_fp32, session_opt)
        # verify
        onnx.checker.check_model(onnx.load(output_optimized))
        print(f"-------- Model successfully optimized. Saved at: {output_optimized}")

    if args.fp16:
        print("-------- Converting to fp16...")
        output_fp16 = filename + "_fp16" + extension
        # convert to fp16
        if args.optimize:
            to_fp16 = convert_float_to_float16(onnx.load(output_optimized))
        else:
            to_fp16 = convert_float_to_float16(load_onnx)
        # save
        onnx.save(to_fp16, output_fp16)
        # verify
        onnx.checker.check_model(onnx.load(output_fp16))
        print(
            f"-------- Model successfully converted to half-precision. Saved at: {output_fp16}."
        )

    if args.fulloptimization:
        # error if network can't run through onnxsim
        opt_error = ["omnisr"]
        if args.network in opt_error:
            msg = f"Network [{args.network}] doesn't support full optimization."
            raise RuntimeError(msg)

        print("-------- Running full optimization (this can take a while)...")
        output_fp32_fulloptimized = filename + "_fp32_fullyoptimized" + extension
        output_fp16_fulloptimized = filename + "_fp16_fullyoptimized" + extension
        # run onnxsim
        if args.optimize:
            simplified, check = simplify(onnx.load(output_optimized))
        elif args.fp16:
            simplified, check = simplify(onnx.load(output_fp16))
        else:
            simplified, check = simplify(load_onnx)
        assert check, "Couldn't validate ONNX model."

        # save and verify
        if args.fp16:
            output_fulloptimized = output_fp16_fulloptimized
        else:
            output_fulloptimized = output_fp32_fulloptimized
        onnx.save(simplified, output_fulloptimized)
        onnx.checker.check_model(onnx.load(output_fulloptimized))

        print(
            f"-------- Model successfully optimized. Saved at: {output_fulloptimized}\n"
        )


if __name__ == "__main__":
    torch.set_default_device("cuda")
    warnings.filterwarnings("ignore", category=UserWarning)
    root_path = osp.abspath(osp.join(__file__, osp.pardir))
    __, args = parse_options(root_path)
    to_onnx()
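
For reference, here is a minimal sketch of running inference with the exported model through onnxruntime, using only APIs already imported above. The file name `model_fp32.onnx`, the input shape, and the 4x scale are placeholder assumptions for illustration, not outputs guaranteed by this script:

```python
# Sketch: run an exported ONNX model with onnxruntime (CPU provider).
# "model_fp32.onnx" is a placeholder for the file written by convert.py.
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession(
    "model_fp32.onnx", providers=["CPUExecutionProvider"]
)
# a random NCHW tensor stands in for a real low-resolution image
lr = np.random.rand(1, 3, 64, 64).astype(np.float32)
sr = session.run(None, {session.get_inputs()[0].name: lr})[0]
print(sr.shape)  # e.g. (1, 3, 256, 256) for a 4x network exported with dynamic axes
```
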
45 changes: 45 additions & 0 deletions dataset/create_lmdb.py
@@ -0,0 +1,45 @@
import argparse
from os import path as osp

from neosr.utils import scandir
from neosr.utils.lmdb_util import make_lmdb_from_imgs


def create_lmdb():
    """Create lmdb files.

    Before running this script, please run `extract_subimages.py`.
    """
    folder_path = args.input
    lmdb_path = args.output
    img_path_list, keys = prepare_keys(folder_path)
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True)


def prepare_keys(folder_path):
    """Prepare image path list and keys.

    Args:
        folder_path (str): Folder path.

    Returns:
        list[str]: Image path list.
        list[str]: Key list.
    """
    print('Reading image path list ...')
    img_path_list = sorted(scandir(folder_path, suffix='png', recursive=False))
    keys = [img_path.split('.png')[0] for img_path in img_path_list]

    return img_path_list, keys


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, help='Input path')
    parser.add_argument('--output', type=str, help='Output path')
    args = parser.parse_args()
    create_lmdb()
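
As a quick sanity check, the sketch below reads one entry back from the resulting LMDB. It assumes the BasicSR-style layout that `make_lmdb_from_imgs` follows, where values are cv2-encoded image bytes keyed by the file stems from `prepare_keys`; the database path and key are placeholders:

```python
# Sketch: read a single image back from the LMDB created above.
# Assumes values are PNG-encoded bytes keyed by file stem (placeholder names).
import cv2
import lmdb
import numpy as np

env = lmdb.open("output.lmdb", readonly=True, lock=False)
with env.begin() as txn:
    buf = txn.get(b"some_image")  # hypothetical key
img = cv2.imdecode(np.frombuffer(buf, np.uint8), cv2.IMREAD_UNCHANGED)
print(img.shape)
```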

45 changes: 45 additions & 0 deletions dataset/readme.md
@@ -0,0 +1,45 @@
> [!NOTE]
> The package `lmdb` is required. It may only work with Python <= 3.11.
## 📸 datasets

As part of *neosr*, I have released a dataset series called *Nomos*. The purpose of these datasets is to distill only the best images from academic and community datasets. A total of 14 datasets were manually reviewed and processed, including: [Adobe-MIT-5k](https://data.csail.mit.edu/graphics/fivek/), [RAISE](http://loki.disi.unitn.it/RAISE/), [LSDIR](https://data.vision.ee.ethz.ch/yawli/), [LIU4k-v2](https://structpku.github.io/LIU4K_Dataset/LIU4K_v2.html), [KONIQ-10k](https://database.mmsp-kn.de/koniq-10k-database.html), [Nikon LL RAW](https://www.kaggle.com/datasets/razorblade/nikon-camera-dataset), [DIV8k](https://ieeexplore.ieee.org/document/9021973), [FFHQ](https://github.com/NVlabs/ffhq-dataset), [Flickr2k](http://cv.snu.ac.kr/research/EDSR/Flickr2K.tar), [ModernAnimation1080_v2](https://huggingface.co/datasets/Zarxrax/ModernAnimation1080_v2), [Rawsamples](https://www.rawsamples.ch/index.php/en/), [SignatureEdits](https://www.signatureedits.com/free-raw-photos/), [Hasselblad raw samples](https://www.hasselblad.com/learn/sample-images/) and [Unsplash](https://unsplash.com/).

- `Nomos-v2` (*recommended*): contains 6000 images, multipurpose. Data distribution:

```mermaid
pie
title Nomos-v2 distribution
"Animal / fur" : 439
"Interiors" : 280
"Exteriors / misc" : 696
"Architecture / geometric" : 1470
"Drawing / painting / anime" : 1076
"Humans" : 598
"Mountain / Rocks" : 317
"Text" : 102
"Textures" : 439
"Vegetation" : 574
```

- `nomos_uni` (*recommended for lightweight networks*): contains 2989 images, multipurpose. Meant to be used on lightweight networks (<800k parameters).
- `hfa2k`: contains 2568 anime images.

| dataset download | sha256 |
|-----------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------|
| [**nomosv2**](https://drive.google.com/file/d/1vqKWGtpkYHKv8uqK_xgYZbYuPHJS2l8j/view?usp=drive_link) (3GB) | [sha256](https://drive.google.com/file/d/12eNzPqHd2N1rTWMDh_rAv3urNypJexQT/view?usp=drive_link) |
| [**nomosv2.lmdb**](https://drive.google.com/file/d/1Rzdjt3w0qXle7vHa8FeFltmyKTMIwPR5/view?usp=drive_link) (3GB) | [sha256](https://drive.google.com/file/d/1IrDjI37psiCc-Khn3_KSyov-xP4txZYe/view?usp=drive_link) |
| [nomosv2_lq_4x](https://drive.google.com/file/d/1YiCywSFwRuwaYmnZ0TgoWDvcDQifAsZo/view?usp=drive_link) (187MB) | [sha256](https://drive.google.com/file/d/1iOOte6h-AE1iD-i5wl_gVx1uJzNTS4Cq/view?usp=drive_link) |
| [nomosv2_lq_4x.lmdb](https://drive.google.com/file/d/1IrDjI37psiCc-Khn3_KSyov-xP4txZYe/view?usp=drive_link) (187MB) | [sha256](https://drive.google.com/file/d/1bpuuiGFNBrDuZiRSP5hpVgFQx44MImay/view?usp=drive_link) |
| [nomos_uni](https://drive.google.com/file/d/1LVS7i9J3mP9f2Qav2Z9i9vq3y2xsKxA_/view?usp=sharing) (1.3GB) | [sha256](https://drive.google.com/file/d/1cdzVSpXQQhcbRVuFPbNtb6mZx_BoLwyW/view?usp=sharing) |
| [nomos_uni.lmdb](https://drive.google.com/file/d/1MHJCS4Zl3H5nihgpA_VVliziXnhJ3aU7/view?usp=sharing) (1.3GB) | [sha256](https://drive.google.com/file/d/1g3XLV-hFdLUcuAHLv2R6Entye8MkMx0V/view?usp=drive_link) |
| [nomos_uni_lq_4x](https://drive.google.com/file/d/1uvMl8dG8-LXjCOEoO9Aiq5Q9rd_BIUw9/view?usp=sharing) | [sha256](https://drive.google.com/file/d/1MTJBcfaMYdfWhsZCWEEOwbKSdmN5dVwl/view?usp=drive_link) |
| [nomos_uni_lq_4x.lmdb](https://drive.google.com/file/d/1h27AsZze_FFsAsf8eXupcqIZQHhvwa1y/view?usp=sharing) | [sha256](https://drive.google.com/file/d/1dhvIlM_uaIYMEKuijemnlmMTg4qf7bj7/view?usp=drive_link) |
| [hfa2k](https://drive.google.com/file/d/1PonJdHWwCtBdG4i1LwThm06t6RibnVu8/view?usp=sharing) | [sha256](https://drive.google.com/file/d/1ojDSyKCnCDoLOf9C-Zo4-BmuVSNTItEl/view?usp=sharing) |
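
To confirm a download matches its published checksum, here is a short sketch using Python's standard `hashlib` (the archive name is a placeholder):

```python
# Sketch: compute the SHA-256 of a downloaded archive and compare it
# to the published checksum ("nomosv2.zip" is a placeholder name).
import hashlib

sha = hashlib.sha256()
with open("nomosv2.zip", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)
print(sha.hexdigest())
```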

See more datasets on the [**readme section about datasets**](https://github.com/muslll/neosr?tab=readme-ov-file#datasets).


## utils

In the [utils](utils/) folder you can find some tools to help prepare datasets, such as generating meta info files and converting to LMDB.