-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathtest.lua
48 lines (39 loc) · 1.25 KB
/
test.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
require 'torch'
require 'optim'
require 'image'
local utils = require 'super_resolution.utils'
local gm = require 'graphicsmagick'
-- Command-line interface. Each option carries a help string so that
-- running the script with -h documents every flag.
local cmd = torch.CmdLine()
-- Generic options
cmd:option('-img', './imgs/comic_input.bmp', 'Path to the input image')
cmd:option('-output', 'output.bmp', 'Path where the output image is written')
-- Super-resolution options
cmd:option('-use_tanh', false,
  'Rescale input pixels from [0, 1] to [-1, 1] before the forward pass')
-- Checkpointing
cmd:option('-model', './models/SRResNet_MSE_100.t7',
  'Path to the serialized Torch model to load')
-- Backend options
cmd:option('-gpu', 0, 'GPU device id passed to utils.setup_gpu')
cmd:option('-use_cudnn', 0, '1 to enable cuDNN')
cmd:option('-backend', 'cuda', 'GPU backend passed to utils.setup_gpu')
--- Run single-image super-resolution with a serialized Torch model.
-- Parses the options declared on `cmd`, loads the model given by -model,
-- runs one forward pass on the image given by -img, and writes the result
-- (clamped to [0, 1]) to -output.
local function main()
  local opt = cmd:parse(arg)

  -- Figure out the backend. -use_cudnn is parsed as a number, and in Lua
  -- `0` is truthy, so convert it to a real boolean here; otherwise
  -- `-use_cudnn 0` would still be treated as "use cuDNN".
  local dtype, use_cudnn =
    utils.setup_gpu(opt.gpu, opt.backend, opt.use_cudnn == 1)

  print('Loading model ...')
  local model = torch.load(opt.model):type(dtype)
  if use_cudnn then
    -- `cudnn` is never required in this file (presumably utils.setup_gpu
    -- loads it when cuDNN is enabled), so only touch the global on this
    -- path; doing it unconditionally crashes CPU runs with
    -- "attempt to index global 'cudnn'".
    cudnn.convert(model, cudnn)
    cudnn.benchmark = false
  end

  -- NOTE(review): the model is put in training mode even though this is
  -- inference (layers such as batch norm, if present, would use per-batch
  -- statistics). Kept as in the original; confirm this is intentional
  -- rather than model:evaluate().
  model:training()

  -- Load the image as a channels x H x W float tensor, then add a leading
  -- batch dimension of size 1.
  local img = gm.Image(opt.img):colorspace('RGB')
  local input = img:toTensor('float', 'RGB', 'DHW')
  input = torch.reshape(input, 1, input:size(1), input:size(2), input:size(3))
  if opt.use_tanh then
    -- Map pixel values from [0, 1] to [-1, 1] (assumes toTensor('float')
    -- yields values in [0, 1]).
    input = input:mul(2.0):add(-1.0)
  end

  local output = model:forward(input:type(dtype))

  -- NOTE(review): with -use_tanh the output is clamped straight to [0, 1];
  -- if the model emits values in [-1, 1] it may need (x + 1) / 2 first —
  -- confirm against how the model was trained.
  -- Named `result` so the global `image` package is not shadowed.
  local result = gm.Image(output[1]:float():clamp(0, 1), 'RGB', 'DHW')
  result:save(opt.output)
end

main()