-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 525f889
Showing
67 changed files
with
11,935 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
experiments/* | ||
options/tmp/ | ||
__pycache__/ | ||
.ruff_cache/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
import warnings | ||
from copy import deepcopy | ||
from os import path as osp | ||
|
||
import numpy as np | ||
import onnx | ||
import onnxruntime | ||
import torch | ||
from onnxconverter_common.float16 import convert_float_to_float16 | ||
from onnxsim import simplify | ||
|
||
from neosr.archs import build_network | ||
from neosr.utils.options import parse_options | ||
|
||
|
||
def load_net(): | ||
# build_network | ||
print(f"\n-------- Attempting to build network [{args.network}].") | ||
if args.network is None: | ||
msg = "Please select a network using the -net option" | ||
raise ValueError(msg) | ||
net_opt = {"type": args.network} | ||
|
||
if args.network == "omnisr": | ||
net_opt["upsampling"] = args.scale | ||
net_opt["window_size"] = args.window | ||
|
||
if args.window: | ||
net_opt["window_size"] = args.window | ||
|
||
net = build_network(net_opt) | ||
load_net = torch.load(args.input, map_location=torch.device("cuda")) | ||
# find parameter key | ||
print("-------- Finding parameter key...") | ||
try: | ||
if "params-ema" in load_net: | ||
param_key = "params-ema" | ||
elif "params" in load_net: | ||
param_key = "params" | ||
elif "params_ema" in load_net: | ||
param_key = "params_ema" | ||
load_net = load_net[param_key] | ||
except: | ||
pass | ||
|
||
# remove unnecessary 'module.' | ||
for k, v in deepcopy(load_net).items(): | ||
if k.startswith("module."): | ||
load_net[k[7:]] = v | ||
load_net.pop(k) | ||
|
||
# load_network and send to device | ||
net.load_state_dict(load_net, strict=True) | ||
net = net.to(device="cuda", non_blocking=True) | ||
print(f"-------- Successfully loaded network [{args.network}].") | ||
torch.cuda.empty_cache() | ||
|
||
return net | ||
|
||
|
||
def assert_verify(onnx_model, torch_model) -> None: | ||
if args.static is not None: | ||
dummy_input = torch.randn(1, *args.static, requires_grad=True) | ||
else: | ||
dummy_input = torch.randn(1, 3, 20, 20, requires_grad=True) | ||
# onnxruntime output prediction | ||
# NOTE: "CUDAExecutionProvider" errors if some nvidia libs | ||
# are not found, defaulting to cpu | ||
ort_session = onnxruntime.InferenceSession( | ||
onnx_model, providers=["CPUExecutionProvider"] | ||
) | ||
ort_inputs = {ort_session.get_inputs()[0].name: dummy_input.detach().cpu().numpy()} | ||
ort_outs = ort_session.run(None, ort_inputs) | ||
|
||
# torch outputs | ||
torch_outputs = torch_model(dummy_input) | ||
|
||
# final assert - default tolerance values - rtol=1e-03, atol=1e-05 | ||
np.testing.assert_allclose( | ||
torch_outputs.detach().cpu().numpy(), ort_outs[0], rtol=0.01, atol=0.001 | ||
) | ||
print(f"-------- Model successfully verified.") | ||
|
||
|
||
def to_onnx() -> None: | ||
# error if network can't be converted | ||
net_error = ["craft", "ditn"] | ||
if args.network in net_error: | ||
msg = f"Network [{args.network}] cannot be converted to ONNX." | ||
raise RuntimeError(msg) | ||
|
||
# load network and send to device | ||
model = load_net() | ||
# set model to eval mode | ||
model.eval() | ||
|
||
# set static or dynamic | ||
if args.static is not None: | ||
dummy_input = torch.randn(1, *args.static, requires_grad=True) | ||
else: | ||
dummy_input = torch.randn(1, 3, 20, 20, requires_grad=True) | ||
|
||
# dict for dynamic axes | ||
if args.static is None: | ||
dyn_axes = { | ||
'dynamic_axes': { | ||
'input': {0: 'batch_size', 2: 'width', 3: 'height'}, | ||
'output': {0: 'batch_size', 2: 'width', 3: 'height'}, | ||
}, | ||
'input_names': ["input"], | ||
'output_names': ["output"], | ||
} | ||
else: | ||
dyn_axes = None | ||
|
||
# add _fp32 suffix to output str | ||
filename, extension = osp.splitext(args.output) | ||
output_fp32 = filename + "_fp32" + extension | ||
# begin conversion | ||
print("-------- Starting ONNX conversion (this can take a while)...") | ||
|
||
with torch.device("cpu"): | ||
# TODO: switch to dynamo_export once it supports ATen PixelShuffle | ||
# then torch.testing.assert_close for verification | ||
|
||
torch.onnx.export( | ||
model, | ||
dummy_input, | ||
output_fp32, | ||
export_params=True, | ||
opset_version=args.opset, | ||
do_constant_folding=False, | ||
**(dyn_axes or {}), | ||
) | ||
|
||
print("-------- Conversion was successful. Verifying...") | ||
# verify onnx | ||
load_onnx = onnx.load(output_fp32) | ||
torch.cuda.empty_cache() | ||
onnx.checker.check_model(load_onnx) | ||
print(f"-------- Model successfully converted to ONNX format. Saved at: {output_fp32}.") | ||
# verify outputs | ||
if args.nocheck is False: | ||
assert_verify(output_fp32, model) | ||
|
||
|
||
if args.optimize: | ||
print("-------- Running ONNX optimization...") | ||
#filename, extension = osp.splitext(args.output) | ||
#output_optimized = filename + "_fp32_optimized" + extension | ||
session_opt = onnxruntime.SessionOptions() | ||
# ENABLE_ALL can cause compatibility issues, leaving EXTENDED as default | ||
session_opt.graph_optimization_level = ( | ||
onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED | ||
) | ||
session_opt.optimized_model_filepath = output_optimized | ||
# save | ||
onnxruntime.InferenceSession(output_fp32, session_opt) | ||
# verify | ||
onnx.checker.check_model(onnx.load(output_optimized)) | ||
print(f"-------- Model successfully optimized. Saved at: {output_optimized}") | ||
|
||
if args.fp16: | ||
print("-------- Converting to fp16...") | ||
output_fp16 = filename + "_fp16" + extension | ||
# convert to fp16 | ||
if args.optimize: | ||
to_fp16 = convert_float_to_float16(onnx.load(output_optimized)) | ||
else: | ||
to_fp16 = convert_float_to_float16(load_onnx) | ||
# save | ||
onnx.save(to_fp16, output_fp16) | ||
# verify | ||
onnx.checker.check_model(onnx.load(output_fp16)) | ||
print( | ||
f"-------- Model successfully converted to half-precision. Saved at: {output_fp16}." | ||
) | ||
|
||
if args.fulloptimization: | ||
# error if network can't run through onnxsim | ||
opt_error = ["omnisr"] | ||
if args.network in opt_error: | ||
msg = f"Network [{args.network}] doesnt support full optimization." | ||
raise RuntimeError(msg) | ||
|
||
print("-------- Running full optimization (this can take a while)...") | ||
output_fp32_fulloptimized = filename + "_fp32_fullyoptimized" + extension | ||
output_fp16_fulloptimized = filename + "_fp16_fullyoptimized" + extension | ||
# run onnxsim | ||
if args.optimize: | ||
simplified, check = simplify(onnx.load(output_optimized)) | ||
elif args.fp16: | ||
simplified, check = simplify(onnx.load(output_fp16)) | ||
else: | ||
simplified, check = simplify(load_onnx) | ||
assert check, "Couldn't validate ONNX model." | ||
|
||
# save and verify | ||
if args.fp16: | ||
onnx.save(simplified, output_fp16_fulloptimized) | ||
onnx.checker.check_model(onnx.load(output_fp16_fulloptimized)) | ||
else: | ||
onnx.save(simplified, output_fp32_fulloptimized) | ||
onnx.checker.check_model(onnx.load(output_fp32_fulloptimized)) | ||
|
||
print( | ||
f"-------- Model successfully optimized. Saved at: {output_fp32_fulloptimized}\n" | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
torch.set_default_device("cuda") | ||
warnings.filterwarnings("ignore", category=UserWarning) | ||
root_path = osp.abspath(osp.join(__file__, osp.pardir)) | ||
__, args = parse_options(root_path) | ||
to_onnx() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import argparse | ||
from os import path as osp | ||
|
||
from neosr.utils import scandir | ||
from neosr.utils.lmdb_util import make_lmdb_from_imgs | ||
|
||
|
||
def create_lmdb(): | ||
""" | ||
Create lmdb files. | ||
Before run this script, please run `extract_subimages.py`. | ||
""" | ||
|
||
folder_path = args.input | ||
lmdb_path = args.output | ||
img_path_list, keys = prepare_keys(folder_path) | ||
make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True) | ||
|
||
|
||
def prepare_keys(folder_path): | ||
"""Prepare image path list and keys. | ||
Args: | ||
folder_path (str): Folder path. | ||
Returns: | ||
list[str]: Image path list. | ||
list[str]: Key list. | ||
""" | ||
|
||
print('Reading image path list ...') | ||
img_path_list = sorted(list(scandir(folder_path, suffix='png', recursive=False))) | ||
keys = [img_path.split('.png')[0] for img_path in sorted(img_path_list)] | ||
|
||
return img_path_list, keys | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
|
||
parser.add_argument('--input', type=str, help=("Input Path")) | ||
parser.add_argument('--output', type=str, help=("Output Path")) | ||
args = parser.parse_args() | ||
create_lmdb() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
> [!NOTE] | ||
> The package `lmdb` is required. It may only work with python <=3.11 | ||
## 📸 datasets | ||
|
||
As part of *neosr*, I have released a dataset series called *Nomos*. The purpose of these dataset is to distill only the best images from the academic and community datasets. A total of 14 datasets were manually reviewed and processed, including: [Adobe-MIT-5k](https://data.csail.mit.edu/graphics/fivek/), [RAISE](http://loki.disi.unitn.it/RAISE/), [LSDIR](https://data.vision.ee.ethz.ch/yawli/), [LIU4k-v2](https://structpku.github.io/LIU4K_Dataset/LIU4K_v2.html), [KONIQ-10k](https://database.mmsp-kn.de/koniq-10k-database.html), [Nikon LL RAW](https://www.kaggle.com/datasets/razorblade/nikon-camera-dataset), [DIV8k](https://ieeexplore.ieee.org/document/9021973), [FFHQ](https://github.com/NVlabs/ffhq-dataset), [Flickr2k](http://cv.snu.ac.kr/research/EDSR/Flickr2K.tar), [ModernAnimation1080_v2](https://huggingface.co/datasets/Zarxrax/ModernAnimation1080_v2), [Rawsamples](https://www.rawsamples.ch/index.php/en/), [SignatureEdits](https://www.signatureedits.com/free-raw-photos/), [Hasselblad raw samples](https://www.hasselblad.com/learn/sample-images/) and [Unsplash](https://unsplash.com/). | ||
|
||
- `Nomos-v2` (*recommended*): contains 6000 images, multipurpose. Data distribution: | ||
|
||
```mermaid | ||
pie | ||
title Nomos-v2 distribution | ||
"Animal / fur" : 439 | ||
"Interiors" : 280 | ||
"Exteriors / misc" : 696 | ||
"Architecture / geometric" : 1470 | ||
"Drawing / painting / anime" : 1076 | ||
"Humans" : 598 | ||
"Mountain / Rocks" : 317 | ||
"Text" : 102 | ||
"Textures" : 439 | ||
"Vegetation" : 574 | ||
``` | ||
|
||
- `nomos_uni` (*recommended for lightweight networks*): contains 2989 images, multipurpose. Meant to be used on lightweight networks (<800k parameters). | ||
- `hfa2k`: contains 2568 anime images. | ||
|
||
| dataset download | sha256 | | ||
|-----------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------| | ||
| [**nomosv2**](https://drive.google.com/file/d/1vqKWGtpkYHKv8uqK_xgYZbYuPHJS2l8j/view?usp=drive_link) (3GB) | [sha256](https://drive.google.com/file/d/12eNzPqHd2N1rTWMDh_rAv3urNypJexQT/view?usp=drive_link) | | ||
| [**nomosv2.lmdb**](https://drive.google.com/file/d/1Rzdjt3w0qXle7vHa8FeFltmyKTMIwPR5/view?usp=drive_link) (3GB) | [sha256](https://drive.google.com/file/d/1IrDjI37psiCc-Khn3_KSyov-xP4txZYe/view?usp=drive_link) | | ||
| [nomosv2_lq_4x](https://drive.google.com/file/d/1YiCywSFwRuwaYmnZ0TgoWDvcDQifAsZo/view?usp=drive_link) (187MB) | [sha256](https://drive.google.com/file/d/1iOOte6h-AE1iD-i5wl_gVx1uJzNTS4Cq/view?usp=drive_link) | | ||
| [nomosv2_lq_4x.lmdb](https://drive.google.com/file/d/1IrDjI37psiCc-Khn3_KSyov-xP4txZYe/view?usp=drive_link) (187MB) | [sha256](https://drive.google.com/file/d/1bpuuiGFNBrDuZiRSP5hpVgFQx44MImay/view?usp=drive_link) | | ||
| [nomos_uni](https://drive.google.com/file/d/1LVS7i9J3mP9f2Qav2Z9i9vq3y2xsKxA_/view?usp=sharing) (1.3GB) | [sha256](https://drive.google.com/file/d/1cdzVSpXQQhcbRVuFPbNtb6mZx_BoLwyW/view?usp=sharing) | | ||
| [nomos_uni.lmdb](https://drive.google.com/file/d/1MHJCS4Zl3H5nihgpA_VVliziXnhJ3aU7/view?usp=sharing) (1.3GB) | [sha256](https://drive.google.com/file/d/1g3XLV-hFdLUcuAHLv2R6Entye8MkMx0V/view?usp=drive_link) | | ||
| [nomos_uni_lq_4x](https://drive.google.com/file/d/1uvMl8dG8-LXjCOEoO9Aiq5Q9rd_BIUw9/view?usp=sharing) | [sha256](https://drive.google.com/file/d/1MTJBcfaMYdfWhsZCWEEOwbKSdmN5dVwl/view?usp=drive_link) | | ||
| [nomos_uni_lq_4x.lmdb](https://drive.google.com/file/d/1h27AsZze_FFsAsf8eXupcqIZQHhvwa1y/view?usp=sharing) | [sha256](https://drive.google.com/file/d/1dhvIlM_uaIYMEKuijemnlmMTg4qf7bj7/view?usp=drive_link) | | ||
| [hfa2k](https://drive.google.com/file/d/1PonJdHWwCtBdG4i1LwThm06t6RibnVu8/view?usp=sharing) | [sha256](https://drive.google.com/file/d/1ojDSyKCnCDoLOf9C-Zo4-BmuVSNTItEl/view?usp=sharing) | | ||
|
||
See more datasets on the [**readme section about datasets**](https://github.com/muslll/neosr?tab=readme-ov-file#datasets). | ||
|
||
|
||
## utils | ||
|
||
In the [utils](utils/) folder you can find some tools to help prepare datasets, such as generating meta info files and converting to LMDB. |
Oops, something went wrong.