-
Notifications
You must be signed in to change notification settings - Fork 6
/
weight_init.lua
78 lines (65 loc) · 2.12 KB
/
weight_init.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#! /usr/bin/env lua
--[[ Weight initialization methods for nn networks.
--
-- > model = require('weight_init')(model, 'heuristic')
--
-- Supported method names: 'heuristic', 'xavier', 'xavier_caffe', 'kaiming'.
--]]
require("nn")
--- LeCun's heuristic scale, after "Efficient backprop" (Yann LeCun, 1998).
-- Computes sqrt(1 / (3 * fan_in)); only the fan-in contributes.
-- @tparam number fan_in number of inputs feeding each unit
-- @tparam number fan_out ignored; present so every initializer in this file
--   shares the same (fan_in, fan_out) signature
-- @treturn number the scale that w_init hands to a layer's `reset`
local function w_init_heuristic(fan_in, fan_out)
   local scale_sq = 1 / (3 * fan_in)
   return math.sqrt(scale_sq)
end
--- Glorot/Xavier scale, after "Understanding the difficulty of training deep
-- feedforward neural networks" (Xavier Glorot, 2010).
-- Computes sqrt(2 / (fan_in + fan_out)).
-- @tparam number fan_in number of inputs feeding each unit
-- @tparam number fan_out number of outputs each unit feeds
-- @treturn number the scale that w_init hands to a layer's `reset`
local function w_init_xavier(fan_in, fan_out)
   local fan_sum = fan_in + fan_out
   return math.sqrt(2 / fan_sum)
end
--- Caffe-flavoured Xavier scale, after "Understanding the difficulty of
-- training deep feedforward neural networks" (Xavier Glorot, 2010).
-- Computes sqrt(1 / fan_in); unlike w_init_xavier, fan_out is ignored.
-- NOTE(review): presumably mirrors Caffe's "xavier" filler — confirm.
-- @tparam number fan_in number of inputs feeding each unit
-- @tparam number fan_out ignored; kept for a uniform initializer signature
-- @treturn number the scale that w_init hands to a layer's `reset`
local function w_init_xavier_caffe(fan_in, fan_out)
   local inv_fan_in = 1 / fan_in
   return math.sqrt(inv_fan_in)
end
--- He/Kaiming scale, after "Delving Deep into Rectifiers: Surpassing
-- Human-Level Performance on ImageNet Classification" (Kaiming He, 2015).
-- Computes sqrt(4 / (fan_in + fan_out)), which equals sqrt(2 / n) where n is
-- the mean of fan_in and fan_out.
-- NOTE(review): the paper states the rule with a single fan (fan_in or
-- fan_out); averaging both here is this file's choice — confirm intended.
-- @tparam number fan_in number of inputs feeding each unit
-- @tparam number fan_out number of outputs each unit feeds
-- @treturn number the scale that w_init hands to a layer's `reset`
local function w_init_kaiming(fan_in, fan_out)
   local fan_sum = fan_in + fan_out
   return math.sqrt(4 / fan_sum)
end
-- (fan_in, fan_out) for a spatial convolution: planes times kernel area.
local function spatial_conv_fans(m)
   return m.nInputPlane * m.kH * m.kW, m.nOutputPlane * m.kH * m.kW
end

-- (fan_in, fan_out) for per-plane convolutions: the kernel area on both sides.
local function kernel_area_fans(m)
   local area = m.kH * m.kW
   return area, area
end

-- (fan_in, fan_out) read from the weight tensor: dimension 2 is taken as
-- fan_in and dimension 1 as fan_out.
local function weight_matrix_fans(m)
   return m.weight:size(2), m.weight:size(1)
end

-- Per-typename fan calculators for the layers w_init re-initializes.
-- Modules whose typename is not listed keep their current weights.
local FAN_CALCULATORS = {
   ['nn.SpatialConvolution'] = spatial_conv_fans,
   ['nn.SpatialConvolutionMM'] = spatial_conv_fans,
   ['nn.LateralConvolution'] = function(m)
      return m.nInputPlane, m.nOutputPlane
   end,
   ['nn.VerticalConvolution'] = kernel_area_fans,
   ['nn.HorizontalConvolution'] = kernel_area_fans,
   ['nn.Linear'] = weight_matrix_fans,
   ['nn.TemporalConvolution'] = weight_matrix_fans,
}

--- Re-initialize the weights of every supported layer in a network.
--
-- Visits `net` and, recursively, every module reachable through a `.modules`
-- list, so layers nested inside containers are initialized as well. Each
-- supported layer (see FAN_CALCULATORS) is reset with
-- `m:reset(method(fan_in, fan_out))`. Every visited module that has a `bias`
-- field gets it zeroed, whether or not its weights were reset.
--
-- @param net the network to initialize (modified in place)
-- @tparam string arg one of 'heuristic', 'xavier', 'xavier_caffe', 'kaiming'
-- @return net
-- @raise error when `arg` is not one of the supported method names
local function w_init(net, arg)
   -- Built per call so the initializer functions are looked up when used.
   local methods = {
      heuristic = w_init_heuristic,
      xavier = w_init_xavier,
      xavier_caffe = w_init_xavier_caffe,
      kaiming = w_init_kaiming,
   }
   local method = methods[arg]
   if not method then
      -- Level 2 reports the error at the caller, who passed the bad name.
      error(string.format(
         "w_init: unknown method '%s' (expected 'heuristic', 'xavier', "
            .. "'xavier_caffe' or 'kaiming')",
         tostring(arg)), 2)
   end

   -- A module shared by several containers is initialized only once.
   local seen = {}
   local function visit(m)
      if seen[m] then
         return
      end
      seen[m] = true

      local fans = FAN_CALCULATORS[m.__typename]
      if fans then
         m:reset(method(fans(m)))
      end
      if m.bias then
         m.bias:zero()
      end

      -- Descend into containers; previously only the top level was scanned.
      local children = m.modules
      if children then
         for i = 1, #children do
            visit(children[i])
         end
      end
   end

   visit(net)
   return net
end
-- Module export: `require` on this file yields w_init, called as
-- w_init(net, method_name).
return w_init