-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.lua
200 lines (185 loc) · 6.1 KB
/
main.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
require 'cunn'
require 'optim'
local inspect = require('inspect')
local u = require('utils')
local data = require('data')
local Convnet = require('convnet')
local MLP = require('mlp')
local main = {}
include('VBLinear.lua')
torch.setdefaulttensortype('torch.FloatTensor')
--- Train `net` for a single epoch over `dataset`, visiting minibatches
-- in shuffled order.  For variational-Bayes models (opt.type == 'vb')
-- each batch is evaluated with opt.S sampled weight sets and the
-- per-sample error/accuracy are averaged before the parameter update.
-- @param net     model wrapper exposing resetGradients(), run(), update(), sample()
-- @param dataset provides create_minibatch(index, batchSize, trainSize, geometry)
-- @param opt     config table: trainSize, batchSize, cuda, type, S, geometry, ...
-- @return mean accuracy and mean error over all batches
function main:train(net, dataset, opt)
  local total_acc = 0
  -- renamed from `error`: the original shadowed Lua's built-in error()
  local total_err = 0
  local t = 1
  local B = opt.trainSize/opt.batchSize
  -- batch start offsets are built once and reshuffled every epoch
  self.indices = self.indices or torch.range(1, opt.trainSize, opt.batchSize)
  u.shuffle(self.indices):apply(function(batch_index)
    -- local batchtime = sys.clock()
    local inputs, targets = dataset:create_minibatch(batch_index, opt.batchSize, opt.trainSize, opt.geometry)
    if opt.cuda then
      inputs = inputs:cuda()
      targets = targets:cuda()
    end
    collectgarbage()
    -- reset gradients
    net:resetGradients()
    if opt.type == 'vb' then
      -- average loss/accuracy over opt.S weight samples before updating
      local sample_err = 0
      local sample_acc = 0
      for i = 1, opt.S do
        net:sample()
        local err, acc = net:run(inputs, targets)
        sample_err = sample_err + err
        sample_acc = sample_acc + acc
      end
      total_acc = total_acc + sample_acc/opt.S
      total_err = total_err + sample_err/opt.S
      net:update(opt)
      -- opt.S = torch.ceil(torch.pow(opt.S, 0.95))
    else
      local err, acc = net:run(inputs, targets)
      total_acc = total_acc + acc
      total_err = total_err + err
      net:update(opt)
    end
    xlua.progress(t*opt.batchSize, opt.trainSize)
    t = t + 1
  end)
  return total_acc/B, total_err/B
end
--- Evaluate `net` over the whole test split in sequential minibatches.
-- @param net     model wrapper exposing test(inputs, targets) -> err, acc
-- @param dataset provides create_minibatch(index, batchSize, testSize, geometry)
-- @param opt     config table: testSize, testBatchSize, cuda, geometry, ...
-- @return mean accuracy and mean error over all test batches
function main:test(net, dataset, opt)
  -- renamed from `error`: the original shadowed Lua's built-in error()
  local total_err = 0
  local total_acc = 0
  local B = opt.testSize/opt.testBatchSize
  for t = 1, opt.testSize, opt.testBatchSize do
    -- disp progress
    xlua.progress(t, opt.testSize)
    local inputs, targets = dataset:create_minibatch(t, opt.testBatchSize, opt.testSize, opt.geometry)
    if opt.cuda then
      inputs = inputs:cuda()
      targets = targets:cuda()
    end
    collectgarbage()
    local err, acc = net:test(inputs, targets)
    total_acc = total_acc + acc
    total_err = total_err + err
  end
  return total_acc/B, total_err/B
end
--- k-fold cross-validation (k = 10) over the bacteria dataset.
-- Trains a fresh MLP per fold until the variational lower bound stops
-- improving, then averages test accuracy/error across the folds.
-- @return mean test accuracy, mean test error across the k folds
function main:crossvalidate()
local opt = require('config')
-- max_epoch is currently unused: the epoch-cap loop below is commented out
local max_epoch = 50
local totalacc = 0
local totalerr = 0
local i = 1
local k = 10
-- global logger, shared convention with main:run()
Log = require('logger'):init(opt.network_name)
while i <= k do
local net = MLP:buildModel(opt)
local trainAccuracy, trainError
local testAccuracy, testError
-- lower-bound tracking: the epoch loop runs while the bound increases
local old_lc = 0
local new_lc = 0.00001
local t = 1
while new_lc > old_lc do
old_lc = new_lc
-- while t <= max_epoch do
-- NOTE(review): the fold data is re-fetched every epoch; presumably
-- getBacteriaFold caches internally — verify, otherwise redundant I/O.
local trainSet, testSet = data.getBacteriaFold(i, k)
testAccuracy, testError = self:test(net, testSet, opt)
print(testAccuracy, testError)
trainAccuracy, trainError = self:train(net, trainSet, opt)
print(trainAccuracy, trainError)
if opt.log then
Log:add('devacc-fold='..i, testAccuracy)
Log:add('trainacc-fold='..i, trainAccuracy)
Log:add('deverr-fold='..i, testError)
Log:add('trainerr-fold='..i, trainError)
if opt.type == 'vb' then
-- NOTE(review): new_lc is refreshed only here; when opt.log is false
-- or opt.type ~= 'vb', new_lc == old_lc after one pass and the fold
-- stops after a single epoch — confirm this is intended.
new_lc = net:calc_lc(opt)
Log:add('lc-fold='..i, new_lc)
end
Log:flush()
end
t = t + 1
end
-- checkpoint this fold's model before moving on
u.safe_save(net, opt.network_name, 'model-fold='..i)
totalacc = totalacc + testAccuracy
totalerr = totalerr + testError
i = i + 1
end
return totalacc/k, totalerr/k
end
--- Finite-difference check of the analytic variational gradients.
-- Uses optim.checkgrad to compare net:compute_vargrads() against a
-- numerical estimate of the gradient of net:calc_LC().
-- @param net model wrapper exposing lvars, compute_prior, compute_vargrads, calc_LC
function main:checkgrads(net)
  local params = net.lvars
  -- perturbation scaled by parameter magnitude
  local eps = 2 * torch.sqrt(1e-12) * (1 + torch.norm(params))
  print('epsilon', eps)
  -- objective returns (loss, analytic gradient) as checkgrad expects
  local function objective()
    net:compute_prior()
    local _, grads = net:compute_vargrads()
    -- local lce, lcg = net:compute_mugrads()
    return net:calc_LC(), grads:double()
  end
  local diff, analytic, numeric = optim.checkgrad(objective, params, eps)
  print("difference: ", diff)
  print(analytic:min(), analytic:max())
  print(numeric:min(), numeric:max())
end
--- Entry point: build (or resume) a model, pick a dataset, then train
-- indefinitely, logging metrics and checkpointing after every epoch.
-- Note: the epoch loop never terminates on its own.
function main:run()
  local opt = require('config')
  -- global logger
  torch.setnumthreads(opt.threads)
  print('<torch> set nb of threads to ' .. torch.getnumthreads())
  -- local net = Convnet:buildModel(opt)
  local model
  if opt.network_to_load == "" then
    -- fresh model, fresh log
    model = MLP:buildModel(opt)
    Log = require('logger'):init(opt.network_name, false)
  else
    -- resume a previously saved model and append to its log
    model = torch.load(opt.network_to_load..'/model')
    Log = require('logger'):init(opt.network_name, true)
  end
  local train_data, test_data
  if opt.dataset == 'mnist' then
    train_data, test_data = data.getMnist()
  else
    train_data, test_data = data.getBacteriaFold(1, 10)
  end
  -- self:checkgrads(net)
  -- exit()
  while true do
    local train_acc, train_err = self:train(model, train_data, opt)
    print(train_acc, train_err)
    local test_acc, test_err = self:test(model, test_data, opt)
    print(test_acc, test_err)
    if opt.log then
      Log:add('devacc', test_acc)
      Log:add('trainacc', train_acc)
      Log:add('deverr', test_err)
      Log:add('trainerr', train_err)
      if opt.type == 'vb' then
        local bound = model:calc_lc(opt)
        Log:add('lc', bound)
        print('LC: ', bound)
      end
      Log:flush()
    end
    -- checkpoint after every epoch
    u.safe_save(model, opt.network_name, 'model')
    -- net:save()
  end
end
-- Script entry: main:run() contains an infinite training loop, so it
-- never returns and the `return main` below is effectively unreachable
-- when this file is executed directly.  Alternate entry points and
-- scratch code are kept commented out below.
main:run()
--print(main:crossvalidate())
--local net = MLP:load('vsadf2')
--local net = torch.load('vsadf2/model')
--local opt = net.opt
--opt.testSamples = 5
--opt.quicktest = false
--local opt = require('config')
--local trainSet, testSet = data.getMnist()
--local trainSet, testSet = data.getBacteriaFold(1, 10)
--print(net.model)
--print(testSet)
--print(main:test(net, testSet, opt))
--
return main