# test_superres.py (forked from nhatsmrt/superres)
from superres.metrics import PSNR
from superres.learner import SuperResolutionLearner, MultiResolutionLearner
from superres.models import PixelShuffleUpsampler, DeepLaplacianPyramidNetV2
from nntoolbox.vision.utils import UnlabelledImageListDataset, UnsupervisedFromSupervisedDataset
from nntoolbox.vision.losses import CharbonnierLossV2
from nntoolbox.vision.transforms import RandomRescale
from nntoolbox.callbacks import Tensorboard, LossLogger, \
    ModelCheckpoint, ToDeviceCallback
from generative_models.metrics import SSIM
import copy  # needed below to give each data split its own dataset copy
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, RandomCrop, ToTensor, RandomHorizontalFlip
from torchvision.datasets import CIFAR10
from torch.optim import SGD, Adam, AdamW
# Standard ImageNet channel statistics (defined for reference; normalization
# is not applied in the transforms below).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
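# If normalization is wanted once ToTensor is re-enabled below, the stats
# above could be applied with torchvision's Normalize (a sketch, not part of
# the original pipeline):
# from torchvision.transforms import Normalize
# transform = Compose([..., ToTensor(), Normalize(mean, std)])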
print("Begin creating dataset")
images = UnlabelledImageListDataset("data/train2014/")
# images = UnsupervisedFromSupervisedDataset(CIFAR10(root="data/CIFAR/", download=True, train=True))
upscale_factor = 4  # target 4x super-resolution
batch_size = 64
print("Begin splitting data")
train_size = int(0.80 * len(images))
val_size = len(images) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(images, [train_size, val_size])
train_dataset.dataset.transform = Compose(
[
Resize((512, 512)),
RandomCrop((128, 128)),
RandomHorizontalFlip(0.5),
RandomRescale(),
# ToTensor()
]
)
val_dataset.dataset.transform = Compose(
[
Resize((512, 512)),
RandomCrop((128, 128)),
# ToTensor()
]
)
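# Every image is resized to a common 512x512 and a random 128x128 patch is
# cropped; at 4x super-resolution the low-res training inputs would therefore
# be 32x32. RandomHorizontalFlip and RandomRescale (presumably a random
# rescaling augmentation from nntoolbox) add variety to the training split
# only; validation keeps just the resize and crop.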
print("Begin creating data dataloaders")
dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dataloader_val = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
# print(len(dataloader))
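# Optional sanity check before training (a sketch; assumes samples are
# tensors, which requires re-enabling ToTensor in the transforms above):
# batch = next(iter(dataloader))
# print(batch.shape)  # expected: (batch_size, 3, 128, 128)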
print("Creating models")
model = DeepLaplacianPyramidNetV2(max_scale_factor=upscale_factor)
print("Finish creating model")
# optimizer = SGD(model.parameters(), lr=1e-5, momentum=0.9, weight_decay=1e-4)
optimizer = Adam(model.parameters())  # PyTorch's default lr of 1e-3
# optimizer = AdamW(model.parameters(), weight_decay=1e-4)
# Charbonnier loss (a smooth approximation of L1) is the standard training
# loss for Laplacian-pyramid super-resolution networks.
learner = MultiResolutionLearner(
    dataloader, dataloader_val,
    model, criterion=CharbonnierLossV2(), optimizer=optimizer
)
metrics = {
    "psnr": PSNR(batch_size=batch_size),
    "ssim": SSIM()
}
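# PSNR and SSIM are standard full-reference image-quality metrics; higher is
# better for both. 'ssim' is presumably also used to pick the final model,
# via final_metric below.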
callbacks = [
    ToDeviceCallback(),
    Tensorboard(),
    LossLogger(),
    # lr_scheduler,
    ModelCheckpoint(learner=learner, save_best_only=False, filepath='weights/model.pt'),
]
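# ToDeviceCallback presumably moves the model and batches to the GPU when one
# is available; Tensorboard and LossLogger record training curves; with
# save_best_only=False, ModelCheckpoint presumably saves weights/model.pt
# after every epoch (the weights/ directory may need to exist beforehand).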
learner.learn(
    n_epoch=10, downsampling_mode='bicubic', metrics=metrics, callbacks=callbacks,
    max_upscale_factor=upscale_factor, final_metric='ssim'
)
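# learn() trains for 10 epochs, presumably generating low-res inputs on the
# fly (downsampling_mode='bicubic') and evaluating at scale factors up to 4x,
# with SSIM as the final metric. Afterwards, a saved checkpoint could be used
# for inference roughly like this (a sketch; the exact forward signature of
# DeepLaplacianPyramidNetV2 is an assumption):
# model.load_state_dict(torch.load('weights/model.pt'))
# model.eval()
# with torch.no_grad():
#     sr = model(lr_batch)  # lr_batch: (N, 3, H, W) low-res tensor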