inference.py
'''
Inference script: loads the test light-field datasets, runs a pre-trained model
patch-by-patch on the low-resolution sub-aperture images (SAIs), and saves the
super-resolved views as BMP files.
'''
import importlib
import os
import numpy as np
import torch
import torch.backends.cudnn as cudnn
# helpers such as LFdivide, LFintegrate, rearrange, ycbcr2rgb and create_dir
# are expected to come from the project's utils module via this wildcard import
from utils.utils import *
from collections import OrderedDict
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import h5py
from torchvision.transforms import ToTensor
import imageio
from tqdm import tqdm


def MultiTestSetDataLoader(args):
    # build a test dataloader for every test dataset
    data_list = None
    if args.data_name in ['ALL', 'RE_Lytro', 'RE_HCI']:
        if args.task == 'SR':
            dataset_dir = args.path_for_test + 'SR_' + str(args.angRes_in) + 'x' + str(args.angRes_in) + '_' + \
                          str(args.scale_factor) + 'x/'
            data_list = os.listdir(dataset_dir)
        elif args.task == 'RE':
            dataset_dir = args.path_for_test + 'RE_' + str(args.angRes_in) + 'x' + str(args.angRes_in) + '_' + \
                          str(args.angRes_out) + 'x' + str(args.angRes_out) + '/' + args.data_name
            data_list = os.listdir(dataset_dir)
    else:
        data_list = [args.data_name]

    test_Loaders = []
    length_of_tests = 0
    for data_name in data_list:
        test_Dataset = TestSetDataLoader(args, data_name, Lr_Info=data_list.index(data_name))
        length_of_tests += len(test_Dataset)
        test_Loaders.append(DataLoader(dataset=test_Dataset, num_workers=args.num_workers,
                                       batch_size=1, shuffle=False))

    return data_list, test_Loaders, length_of_tests
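
# Expected on-disk layout, as implied by the path construction above and in
# TestSetDataLoader below:
#   SR task: <path_for_test>/SR_<angRes_in>x<angRes_in>_<scale_factor>x/<data_name>/*.h5
#   RE task: <path_for_test>/RE_<angRes_in>x<angRes_in>_<angRes_out>x<angRes_out>/<data_name>/*.h5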


class TestSetDataLoader(Dataset):
    def __init__(self, args, data_name='ALL', Lr_Info=None):
        super(TestSetDataLoader, self).__init__()
        self.angRes_in = args.angRes_in
        self.angRes_out = args.angRes_out
        if args.task == 'SR':
            self.dataset_dir = args.path_for_test + 'SR_' + str(args.angRes_in) + 'x' + str(args.angRes_in) + '_' + \
                               str(args.scale_factor) + 'x/'
            self.data_list = [data_name]
        elif args.task == 'RE':
            self.dataset_dir = args.path_for_test + 'RE_' + str(args.angRes_in) + 'x' + str(args.angRes_in) + '_' + \
                               str(args.angRes_out) + 'x' + str(args.angRes_out) + '/' + args.data_name + '/'
            self.data_list = [data_name]

        self.file_list = []
        for data_name in self.data_list:
            tmp_list = os.listdir(self.dataset_dir + data_name)
            for index, _ in enumerate(tmp_list):
                tmp_list[index] = data_name + '/' + tmp_list[index]
            self.file_list.extend(tmp_list)
        self.item_num = len(self.file_list)

    def __getitem__(self, index):
        file_name = [self.dataset_dir + self.file_list[index]]
        with h5py.File(file_name[0], 'r') as hf:
            Lr_SAI_y = np.array(hf.get('Lr_SAI_y'))
            Hr_SAI_y = np.array(hf.get('Hr_SAI_y'))
            Sr_SAI_cbcr = np.array(hf.get('Sr_SAI_cbcr'), dtype='single')
            # transpose from the storage order of the .h5 files (presumably
            # column-major, MATLAB-generated) to row-major numpy order
            Lr_SAI_y = np.transpose(Lr_SAI_y, (1, 0))
            Hr_SAI_y = np.transpose(Hr_SAI_y, (1, 0))
            Sr_SAI_cbcr = np.transpose(Sr_SAI_cbcr, (2, 1, 0))

        Lr_SAI_y = ToTensor()(Lr_SAI_y.copy())
        Hr_SAI_y = ToTensor()(Hr_SAI_y.copy())
        Sr_SAI_cbcr = ToTensor()(Sr_SAI_cbcr.copy())

        Lr_angRes_in = self.angRes_in
        Lr_angRes_out = self.angRes_out
        LF_name = self.file_list[index].split('/')[-1].split('.')[0]

        return Lr_SAI_y, Hr_SAI_y, Sr_SAI_cbcr, [Lr_angRes_in, Lr_angRes_out], LF_name

    def __len__(self):
        return self.item_num


def main(args):
    ''' Create Dir for Save '''
    _, _, result_dir = create_dir(args)
    result_dir = result_dir.joinpath('TEST')
    result_dir.mkdir(exist_ok=True)

    ''' CPU or Cuda '''
    device = torch.device(args.device)
    if 'cuda' in args.device:
        torch.cuda.set_device(device)

    ''' DATA TEST LOADING '''
    print('\nLoad Test Dataset ...')
    test_Names, test_Loaders, length_of_tests = MultiTestSetDataLoader(args)
    print("The number of test data is: %d" % length_of_tests)

    ''' MODEL LOADING '''
    print('\nModel Initial ...')
    MODEL_PATH = 'model.' + args.task + '.' + args.model_name
    MODEL = importlib.import_module(MODEL_PATH)
    net = MODEL.get_model(args)

    ''' Load Pre-Trained PTH '''
    if not args.use_pre_ckpt:
        net.apply(MODEL.weights_init)
    else:
        ckpt_path = args.path_pre_pth
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        try:
            # first try adding a `module.` prefix to each key, as needed when
            # the network is wrapped in nn.DataParallel
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict'].items():
                name = 'module.' + k  # add `module.`
                new_state_dict[name] = v
            net.load_state_dict(new_state_dict)
            print('Use pretrain model!')
        except Exception:
            # otherwise load the state dict with its original keys
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict'].items():
                new_state_dict[k] = v
            net.load_state_dict(new_state_dict)
            print('Use pretrain model!')

    net = net.to(device)
    cudnn.benchmark = True

    ''' Print Parameters '''
    print('PARAMETER ...')
    print(args)

    ''' TEST on every dataset '''
    print('\nStart test...')
    with torch.no_grad():
        for index, test_name in enumerate(test_Names):
            test_loader = test_Loaders[index]

            save_dir = result_dir.joinpath(test_name)
            save_dir.mkdir(exist_ok=True)

            test(test_loader, device, net, save_dir)
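
# test() processes each light field patch-by-patch: the low-resolution SAIs are
# cropped into overlapping patches (LFdivide), the patches are super-resolved in
# minibatches of args.minibatch_for_test (together with torch.cuda.empty_cache(),
# presumably to keep GPU memory in check), the outputs are merged back into a
# full light field (LFintegrate), and the Y channel is then recombined with the
# chroma channels from the .h5 file before the RGB views are written to disk.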


def test(test_loader, device, net, save_dir=None):
    # note: `args` here is the module-level object imported in the __main__ block below
    for idx_iter, (Lr_SAI_y, Hr_SAI_y, Sr_SAI_cbcr, data_info, LF_name) in tqdm(enumerate(test_loader),
                                                                                total=len(test_loader), ncols=70):
        [Lr_angRes_in, Lr_angRes_out] = data_info
        data_info[0] = Lr_angRes_in[0].item()
        data_info[1] = Lr_angRes_out[0].item()

        Lr_SAI_y = Lr_SAI_y.squeeze().to(device)  # numU, numV, h*angRes, w*angRes

        ''' Crop LFs into Patches '''
        subLFin = LFdivide(Lr_SAI_y, args.angRes_in, args.patch_size_for_test, args.stride_for_test)
        numU, numV, H, W = subLFin.size()
        subLFin = rearrange(subLFin, 'n1 n2 a1h a2w -> (n1 n2) 1 a1h a2w')
        subLFout = torch.zeros(numU * numV, 1, args.angRes_in * args.patch_size_for_test * args.scale_factor,
                               args.angRes_in * args.patch_size_for_test * args.scale_factor)

        ''' SR the Patches '''
        for i in range(0, numU * numV, args.minibatch_for_test):
            tmp = subLFin[i:min(i + args.minibatch_for_test, numU * numV), :, :, :]
            with torch.no_grad():
                net.eval()
                torch.cuda.empty_cache()
                out = net(tmp.to(device), data_info)
                subLFout[i:min(i + args.minibatch_for_test, numU * numV), :, :, :] = out
        subLFout = rearrange(subLFout, '(n1 n2) 1 a1h a2w -> n1 n2 a1h a2w', n1=numU, n2=numV)

        ''' Restore the Patches to LFs '''
        Sr_4D_y = LFintegrate(subLFout, args.angRes_out, args.patch_size_for_test * args.scale_factor,
                              args.stride_for_test * args.scale_factor,
                              Hr_SAI_y.size(-2) // args.angRes_out, Hr_SAI_y.size(-1) // args.angRes_out)
        Sr_SAI_y = rearrange(Sr_4D_y, 'a1 a2 h w -> 1 1 (a1 h) (a2 w)')

        ''' Save RGB '''
        if save_dir is not None:
            save_dir_ = save_dir.joinpath(LF_name[0])
            save_dir_.mkdir(exist_ok=True)

            Sr_SAI_ycbcr = torch.cat((Sr_SAI_y, Sr_SAI_cbcr), dim=1)
            Sr_SAI_rgb = (ycbcr2rgb(Sr_SAI_ycbcr.squeeze().permute(1, 2, 0).numpy()).clip(0, 1) * 255).astype('uint8')
            Sr_4D_rgb = rearrange(Sr_SAI_rgb, '(a1 h) (a2 w) c -> a1 a2 h w c',
                                  a1=args.angRes_out, a2=args.angRes_out)

            # save all views
            for i in range(args.angRes_out):
                for j in range(args.angRes_out):
                    img = Sr_4D_rgb[i, j, :, :, :]
                    path = str(save_dir_) + '/' + 'View' + '_' + str(i) + '_' + str(j) + '.bmp'
                    imageio.imwrite(path, img)


if __name__ == '__main__':
    from option import args

    args.scale_factor = 4
    args.path_for_test = './data_for_inference/'

    # args.data_name = 'NTIRE_Val_Real'
    # args.model_name = 'LFT'
    # args.path_pre_pth = './pth/LFT_5x5_4x_model.pth'
    # main(args)

    args.data_name = 'NTIRE_Val_Synth'
    args.model_name = 'LFT'
    args.path_pre_pth = './pth/LFT_5x5_4x_model.pth'
    main(args)
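
# Usage sketch (assumptions, not part of the original script): the remaining
# options (angRes_in/angRes_out, device, use_pre_ckpt, patch_size_for_test,
# stride_for_test, minibatch_for_test, num_workers, ...) are taken from the
# defaults in option.py. With the overrides above, the test .h5 files are
# expected under ./data_for_inference/SR_<angRes_in>x<angRes_in>_4x/NTIRE_Val_Synth/
# and the pre-trained weights at ./pth/LFT_5x5_4x_model.pth; then run:
#   python inference.py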