-
Notifications
You must be signed in to change notification settings - Fork 24
/
evaluation_smooth.lua
72 lines (57 loc) · 1.92 KB
/
evaluation_smooth.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
require 'torch'
require 'image'
require 'sys'
require 'cunn'
require 'cutorch'
require 'cudnn'
require 'nngraph'
-- cudnn.torch global flags: prefer the fastest convolution algorithms and
-- enable cuDNN's auto-tuner (see the cudnn.torch README).
cudnn.fastest = true
cudnn.benchmark = true

-- Directory scanned for '<stem>-input.png' files, and directory that receives
-- the copies and predictions. savePath is assumed to exist already.
local imgPath = '/mnt/codes/learning_to_optimize/testVOC'
local savePath = '/mnt/codes/reflection/models/l0'

local model = torch.load('/mnt/codes/reflection/models/CEILNet_smooth_L0.net'):cuda()
-- NOTE(review): this is an evaluation script, yet the network is put in
-- training mode (affects e.g. BatchNormalization/Dropout). This may be
-- deliberate (per-image normalisation statistics with batch size 1) -- confirm
-- before switching to model:evaluate().
model:training()

-- Auxiliary network that produces the 4th (edge) input channel.
-- NOTE(review): nn.computeEdge is not defined by any require in this file;
-- presumably it is a custom module installed into the nn package -- confirm.
local model_computeEdge = nn.Sequential()
model_computeEdge:add(nn.computeEdge(100))

-- Collect every file whose name ends exactly in '-input.png'. The pattern is
-- anchored and the '-' / '.' are escaped so that names such as
-- 'x-inputXpng' or 'x-input.png.bak' are not picked up.
local files = {}
for file in paths.files(imgPath) do
   if file:find('%-input%.png$') then
      files[#files + 1] = paths.concat(imgPath, file)
   end
end
-- paths.files returns entries in filesystem order; sort for reproducible runs.
table.sort(files)
--- Rescale each of the 3 colour channels of `pred` in place by
-- alpha = <pred_m, ref_m> / <pred_m, pred_m>, the least-squares scalar that
-- best maps the predicted channel onto the reference channel.
-- A channel whose prediction is identically zero is left unchanged; the plain
-- formula would give 0/0 = NaN there and corrupt the saved image.
-- @param pred 1x3xHxW CudaTensor, modified in place
-- @param ref  1x3xHxW CudaTensor with the reference (input) image
local function matchChannelGains(pred, ref)
   for m = 1, 3 do
      local channel = pred[1][m]
      local energy = torch.dot(channel, channel)
      if energy > 0 then
         channel:mul(torch.dot(channel, ref[1][m]) / energy)
      end
   end
end

--- Process one '<stem>-input.png' file. The function:
-- 1. copies the input and its '<stem>-label-L0smooth.png' label into savePath;
-- 2. runs the smoothing network;
-- 3. writes '<stem>-predict.png' into savePath.
-- File names are built with paths.* helpers instead of string.gsub, so
-- directory names are never treated as Lua patterns and only the
-- '-input.png' suffix is rewritten.
-- @param inputFile full path of the input image
local function processImage(inputFile)
   local inputName = paths.basename(inputFile)
   local stem = inputName:match('^(.*)%-input%.png$')
   if not stem then
      print('skipping file not ending in -input.png: ' .. inputFile)
      return
   end
   local labelName = stem .. '-label-L0smooth.png'

   local labelImg = image.load(paths.concat(paths.dirname(inputFile), labelName))
   -- Force 3 channels so grayscale or RGBA PNGs fit the 1x3xHxW buffer below.
   local inputImg = image.load(inputFile, 3)
   image.save(paths.concat(savePath, labelName), labelImg)
   image.save(paths.concat(savePath, inputName), inputImg)

   local height, width = inputImg:size(2), inputImg:size(3)

   -- RGB input scaled from image.load's [0,1] range to [0,255].
   local input = torch.CudaTensor(1, 3, height, width)
   input[1]:copy(inputImg)
   input:mul(255)
   -- Untouched copy used as the reference for the per-channel gain fix below.
   local inputC = input:clone()

   -- 4-channel network input: RGB plus the edge map, then a constant offset of
   -- 115 on all channels (presumably the training-time mean -- confirm).
   local inputs = torch.CudaTensor(1, 4, height, width)
   inputs[{{}, {1, 3}, {}, {}}] = input
   inputs[{{}, {4}, {}, {}}] = model_computeEdge:forward(input)
   inputs:add(-115)

   local predictions = model:forward({inputs, input})
   -- Element 2 of the model's output table is taken as the smoothed RGB image
   -- (presumably; element 1 is unused by this script).
   local smoothed = predictions[2]
   matchChannelGains(smoothed, inputC)

   image.save(paths.concat(savePath, stem .. '-predict.png'),
              smoothed[1]:float():div(255))
end

for _, inputFile in ipairs(files) do
   processImage(inputFile)
   -- Torch tensors, including their CUDA storage, are released only when Lua's
   -- GC collects them; collect once per image so GPU memory from large images
   -- does not pile up.
   collectgarbage()
end