-
Notifications
You must be signed in to change notification settings - Fork 24
/
evaluation_reflection.lua
68 lines (57 loc) · 2.56 KB
/
evaluation_reflection.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
require 'torch'
require 'image'
require 'sys'
require 'cunn'
require 'cutorch'
require 'cudnn'
-- Directory of input images and directory where predictions are written.
imgPath = '/mnt/codes/reflection/reflection_data_blurry_few_400'
savePath = '/mnt/codes/reflection/models/data'
-- Load the pre-trained CEILNet reflection-removal network and move it to GPU.
model = torch.load('/mnt/codes/reflection/models/CEILNet_reflection.net')
model = model:cuda()
-- NOTE(review): the net is put in training mode even though this script only
-- runs inference -- presumably to keep batch-norm/dropout statistics matching
-- how the model was evaluated by the authors; confirm this is intentional
-- rather than model:evaluate().
model:training()
-- Custom nn module that extracts an edge map from an image batch; 0.02 is
-- presumably an edge threshold -- TODO confirm against computeEdge's source.
model_edge = nn.computeEdge(0.02)
-- Collect the full paths of every "*-input.png" image under imgPath.
files = {}
for file in paths.files(imgPath) do
-- Plain (non-pattern) search: in a Lua pattern '.' is magic and matches any
-- character, so the original pattern-based find would also accept names like
-- 'foo-inputXpng'. Passing true as the 4th argument disables pattern magic.
if string.find(file, '-input.png', 1, true) then
table.insert(files, paths.concat(imgPath, file))
end
end
-- For each collected input image: run the network and save the predicted
-- background (reflection-free) layer with suffix 'predict1.png' in savePath.
for _,inputFile in ipairs(files) do
local inputImg = image.load(inputFile)
local height = inputImg:size(2)
local width = inputImg:size(3)
-- Batch of one RGB image on the GPU; image.load returns values in [0,1],
-- rescaled here to [0,255] as the network expects.
local input = torch.CudaTensor(1, 3, height, width)
input[1] = inputImg:cuda()
input = input * 255
-- 4-channel network input: RGB plus an edge map produced by model_edge.
local inputs = torch.CudaTensor(1, 4, height, width)
inputs[{{},{1,3},{},{}}] = input
inputs[{{},{4},{},{}}] = model_edge:forward(input)
-- Mean subtraction; 115 is presumably the training-set mean -- TODO confirm.
inputs = inputs - 115
-- NOTE(review): this re-declaration shadows the 4-channel tensor above; the
-- model's forward takes a table {preprocessed 4-channel tensor, raw RGB}.
local inputs = {inputs,input}
-- Keep an unmodified copy of the raw input for the (disabled) color
-- correction and reflection-layer code below.
local inputC = input:clone()
local predictions = model:forward(inputs)
-- predictions[2] is the predicted background layer (still in [0,255]).
local pred_b = predictions[2]
-- for m = 1,3 do
-- local numerator = torch.dot(pred_b[1][m], inputC[1][m])
-- local denominator = torch.dot(pred_b[1][m], pred_b[1][m])
-- local alpha = numerator/denominator
-- pred_b[1][m] = pred_b[1][m] * alpha
-- end
-- Note we comment the color correction codes since we observe better performance without this module for the reflection removal task, but this module still works well for the image smoothing task. This is simply due to dramatic difference between the color information of input and ground truth output images in the reflection removal task, which makes the alignment from predicted images to input images imprecise.
-- local pred_r = torch.csub(inputC, pred_b)
-- for m = 1,3 do
-- local numerator = torch.dot(pred_r[1][m], inputC[1][m])
-- local denominator = torch.dot(pred_r[1][m], pred_r[1][m])
-- local alpha = numerator/denominator
-- pred_r[1][m] = pred_r[1][m] * alpha
-- end
-- Derive the output path by swapping the directory prefix, then rescale the
-- prediction back to [0,1] for image.save.
local savColor = string.gsub(inputFile,imgPath,savePath)
pred_b = pred_b/255
local sav = string.gsub(savColor,'input.png','predict1.png')
image.save(sav,pred_b[1])
-- -- Our CNN dose not predict refletion layers. The following code computes approximate reflection layers by simply subtracting the predicted background layers from the input images. Note the result so-obtained may not reflect the image structure and appearance of the original reflection scene.
-- pred_r = pred_r/255
-- local sav = string.gsub(savColor,'input.png','predict2.png')
-- image.save(sav,pred_r[1])
-- NOTE(review): no matching goto for this label is visible in this chunk, so
-- it is dead but harmless (labels are valid statements in LuaJIT/Lua 5.2+);
-- it likely remains from a removed per-image skip.
::done::
end