-- Recursor.lua (forked from Element-Research/rnn)

------------------------------------------------------------------------
--[[ Recursor ]]--
-- Decorates module to be used within an AbstractSequencer.
-- It does this by making the decorated module conform to the
-- AbstractRecurrent interface (which is inherited by LSTM/Recurrent)
------------------------------------------------------------------------
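-- A minimal usage sketch: a Recursor is usually created implicitly by
-- nn.Sequencer (an AbstractSequencer), but it can also be constructed
-- directly to make a stack of step-modules conform to the AbstractRecurrent
-- interface. Layer sizes and sequence length below are illustrative only.
--[[
require 'rnn'

local step = nn.Sequential()
   :add(nn.LSTM(10, 10))
   :add(nn.Linear(10, 5))
   :add(nn.LogSoftMax())
local seq = nn.Sequencer(nn.Recursor(step))

-- a sequence is a table with one tensor per time-step (batches of 4 here)
local inputs = {torch.randn(4, 10), torch.randn(4, 10), torch.randn(4, 10)}
local outputs = seq:forward(inputs) -- table of 3 output tensors of size 4x5
--]]
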
local Recursor, parent = torch.class('nn.Recursor', 'nn.AbstractRecurrent')

function Recursor:__init(module, rho)
   parent.__init(self, rho or 9999999)

   self.recurrentModule = module
   self.recurrentModule:backwardOnline()
   self.onlineBackward = true

   self.module = module
   self.modules = {module}
end

function Recursor:updateOutput(input)
   local output
   if self.train ~= false then
      -- set/save the output states
      self:recycle()
      local recurrentModule = self:getStepModule(self.step)
      output = recurrentModule:updateOutput(input)
   else
      output = self.recurrentModule:updateOutput(input)
   end

   if self.train ~= false then
      local input_ = self.inputs[self.step]
      self.inputs[self.step] = self.copyInputs
         and nn.rnn.recursiveCopy(input_, input)
         or nn.rnn.recursiveSet(input_, input)
   end

   self.outputs[self.step] = output
   self.output = output
   self.step = self.step + 1
   self.updateGradInputStep = nil
   self.accGradParametersStep = nil
   self.gradParametersAccumulated = false
   return self.output
end

function Recursor:backwardThroughTime(timeStep, timeRho)
   timeStep = timeStep or self.step
   local rho = math.min(timeRho or self.rho, timeStep-1)
   local stop = timeStep - rho
   local gradInput
   if self.fastBackward then
      self.gradInputs = {}
      for step=timeStep-1,math.max(stop, 1),-1 do
         -- backward propagate through this step
         local recurrentModule = self:getStepModule(step)
         gradInput = recurrentModule:backward(self.inputs[step], self.gradOutputs[step], self.scales[step])
         table.insert(self.gradInputs, 1, gradInput)
      end
      self.gradParametersAccumulated = true
   else
      gradInput = self:updateGradInputThroughTime(timeStep, timeRho)
      self:accGradParametersThroughTime(timeStep, timeRho)
   end
   return gradInput
end

function Recursor:updateGradInputThroughTime(timeStep, rho)
   assert(self.step > 1, "expecting at least one updateOutput")
   self.gradInputs = {}
   timeStep = timeStep or self.step
   local rho = math.min(rho or self.rho, timeStep-1)
   local stop = timeStep - rho
   local gradInput
   for step=timeStep-1,math.max(stop,1),-1 do
      -- backward propagate through this step
      local recurrentModule = self:getStepModule(step)
      gradInput = recurrentModule:updateGradInput(self.inputs[step], self.gradOutputs[step])
      table.insert(self.gradInputs, 1, gradInput)
   end
   return gradInput
end

function Recursor:accGradParametersThroughTime(timeStep, rho)
   timeStep = timeStep or self.step
   local rho = math.min(rho or self.rho, timeStep-1)
   local stop = timeStep - rho
   for step=timeStep-1,math.max(stop,1),-1 do
      -- backward propagate through this step
      local recurrentModule = self:getStepModule(step)
      recurrentModule:accGradParameters(self.inputs[step], self.gradOutputs[step], self.scales[step])
   end
   self.gradParametersAccumulated = true
end

function Recursor:accUpdateGradParametersThroughTime(lr, timeStep, rho)
   timeStep = timeStep or self.step
   local rho = math.min(rho or self.rho, timeStep-1)
   local stop = timeStep - rho
   for step=timeStep-1,math.max(stop,1),-1 do
      -- backward propagate through this step
      local recurrentModule = self:getStepModule(step)
      recurrentModule:accUpdateGradParameters(self.inputs[step], self.gradOutputs[step], lr*self.scales[step])
   end
end

function Recursor:includingSharedClones(f)
   local modules = self.modules
   self.modules = {}
   local sharedClones = self.sharedClones
   self.sharedClones = nil
   -- temporarily expose both the decorated module and its step clones to f
   for i, moduleTable in ipairs{modules, sharedClones} do
      for j, module in pairs(moduleTable) do
         table.insert(self.modules, module)
      end
   end
   local r = f()
   self.modules = modules
   self.sharedClones = sharedClones
   return r
end

function Recursor:backwardOnline(online)
   assert(online ~= false, "Recursor only supports online backwards")
   parent.backwardOnline(self)
end

function Recursor:forget(offset)
   parent.forget(self, offset)
   nn.Module.forget(self)
   return self
end

function Recursor:maxBPTTstep(rho)
   self.rho = rho
   nn.Module.maxBPTTstep(self, rho)
end

Recursor.__tostring__ = nn.Decorator.__tostring__
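
-- A minimal sketch of driving a Recursor directly, one time-step per call,
-- assuming the online-backward workflow of this version of the package
-- (forward once per step, then backward in reverse order of the forward
-- calls, which drives updateOutput and the *ThroughTime methods above).
-- Module, sizes and number of steps are illustrative only.
--[[
require 'rnn'

local rnn = nn.Recursor(nn.Sequential()
   :add(nn.Linear(10, 10))
   :add(nn.Sigmoid()))

local inputs, outputs, gradOutputs = {}, {}, {}
for step = 1, 3 do
   inputs[step] = torch.randn(4, 10)
   outputs[step] = rnn:forward(inputs[step])   -- one updateOutput per step
   gradOutputs[step] = torch.randn(4, 10)
end

for step = 3, 1, -1 do   -- reverse order of the forward calls
   rnn:backward(inputs[step], gradOutputs[step])
end

rnn:forget()   -- reset the step counter before starting the next sequence
--]]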