------------------------------------------------------------------------
--[[ Sequencer ]]--
-- Encapsulates a Module.
-- Input is a sequence (a table) of tensors.
-- Output is a sequence (a table) of tensors of the same length.
-- Applies the module to each element in the sequence.
-- Handles both recurrent modules and non-recurrent modules.
-- The sequences in a batch must all have the same size,
-- but the sequence length can vary from batch to batch.
------------------------------------------------------------------------
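
-- A minimal usage sketch (illustrative only; the layer sizes, batch size and
-- sequence length below are arbitrary):
--
--   local seq = nn.Sequencer(nn.LSTM(10, 10))
--   -- a sequence of 3 steps, each step a batch of 8 samples of size 10
--   local inputs = {torch.randn(8, 10), torch.randn(8, 10), torch.randn(8, 10)}
--   local outputs = seq:forward(inputs) -- table of 3 tensors of size 8x10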
assert(not nn.Sequencer, "update nnx package : luarocks install nnx")
local Sequencer, parent = torch.class('nn.Sequencer', 'nn.AbstractSequencer')

function Sequencer:__init(module)
   parent.__init(self)
   if not torch.isTypeOf(module, 'nn.Module') then
      error"Sequencer: expecting nn.Module instance at arg 1"
   end

   -- we can decorate the module with a Recursor to make it AbstractRecurrent
   self.module = (not torch.isTypeOf(module, 'nn.AbstractRecurrent')) and nn.Recursor(module) or module
   -- backprop through time (BPTT) will be done online (in reverse order of forward)
   self.module:backwardOnline()
   self.modules = {self.module}

   -- the Sequencer keeps the input and gradOutput tables around for BPTT,
   -- so the wrapped recurrent modules need not copy them at each step
   for i,modula in ipairs(self.module:listModules()) do
      if torch.isTypeOf(modula, "nn.AbstractRecurrent") then
         modula.copyInputs = false
         modula.copyGradOutputs = false
      end
   end

   self.output = {}
   -- table of buffers used for evaluation
   self._output = {}
   -- so that these buffers aren't serialized :
   self.dpnn_mediumEmpty = _.clone(self.dpnn_mediumEmpty)
   table.insert(self.dpnn_mediumEmpty, '_output')

   -- default is to forget previous inputs before each forward()
   self._remember = 'neither'
end
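
-- Non-recurrent modules are accepted too: the constructor above wraps them in
-- an nn.Recursor, and they are simply applied to each step.
-- A small sketch (sizes are arbitrary):
--
--   local seq = nn.Sequencer(nn.Linear(10, 5))
--   local outputs = seq:forward{torch.randn(8, 10), torch.randn(8, 10)}
--   -- outputs : table of 2 tensors of size 8x5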

function Sequencer:updateOutput(inputTable)
   assert(torch.type(inputTable) == 'table', "expecting input table")

   -- Note that the Sequencer hijacks the rho attribute of the rnn
   self.module:maxBPTTstep(#inputTable)
   if self.train ~= false then -- training
      if not (self._remember == 'train' or self._remember == 'both') then
         self.module:forget()
      end
      self.output = {}
      for step, input in ipairs(inputTable) do
         self.output[step] = self.module:updateOutput(input)
      end
   else -- evaluation
      if not (self._remember == 'eval' or self._remember == 'both') then
         self.module:forget()
      end
      -- during evaluation, recurrent modules reuse memory (i.e. outputs)
      -- so we need to copy each output into our own table
      for step, input in ipairs(inputTable) do
         self.output[step] = nn.rnn.recursiveCopy(
            self.output[step] or table.remove(self._output, 1),
            self.module:updateOutput(input)
         )
      end
      -- remove extra output tensors (save for later)
      for i=#inputTable+1,#self.output do
         table.insert(self._output, self.output[i])
         self.output[i] = nil
      end
   end

   return self.output
end

function Sequencer:updateGradInput(inputTable, gradOutputTable)
   assert(torch.type(gradOutputTable) == 'table', "expecting gradOutput table")
   assert(#gradOutputTable == #inputTable, "gradOutput should have as many elements as input")

   -- back-propagate through time (BPTT)
   self.gradInput = {}
   for step=#gradOutputTable,1,-1 do
      self.gradInput[step] = self.module:updateGradInput(inputTable[step], gradOutputTable[step])
   end
   assert(#inputTable == #self.gradInput, #inputTable.." ~= "..#self.gradInput)

   return self.gradInput
end

function Sequencer:accGradParameters(inputTable, gradOutputTable, scale)
   assert(torch.type(gradOutputTable) == 'table', "expecting gradOutput table")
   assert(#gradOutputTable == #inputTable, "gradOutput should have as many elements as input")

   -- back-propagate through time (BPTT)
   for step=#gradOutputTable,1,-1 do
      self.module:accGradParameters(inputTable[step], gradOutputTable[step], scale)
   end
end

function Sequencer:accUpdateGradParameters(inputTable, gradOutputTable, lr)
   assert(torch.type(gradOutputTable) == 'table', "expecting gradOutput table")
   assert(#gradOutputTable == #inputTable, "gradOutput should have as many elements as input")

   -- back-propagate through time (BPTT)
   for step=#gradOutputTable,1,-1 do
      self.module:accUpdateGradParameters(inputTable[step], gradOutputTable[step], lr)
   end
end
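
-- A sketch of one training step through the Sequencer (the data, criterion
-- and learning rate here are placeholders; nn.SequencerCriterion decorates a
-- criterion the same way this class decorates a module):
--
--   local seq = nn.Sequencer(nn.LSTM(10, 10))
--   local crit = nn.SequencerCriterion(nn.MSECriterion())
--   seq:zeroGradParameters()
--   local outputs = seq:forward(inputs)              -- inputs : table of tensors
--   local loss = crit:forward(outputs, targets)      -- targets : table of tensors
--   local gradOutputs = crit:backward(outputs, targets)
--   seq:backward(inputs, gradOutputs)                -- BPTT over the whole sequence
--   seq:updateParameters(0.1)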

-- Toggle to feed long sequences using multiple forward() calls.
-- 'eval' only affects evaluation (recommended for RNNs)
-- 'train' only affects training
-- 'neither' affects neither training nor evaluation
-- 'both' affects both training and evaluation (recommended for LSTMs)
-- Essentially, when remember is enabled for a mode, forget() isn't called
-- on the wrapped rnn before each forward() in that mode.
function Sequencer:remember(remember)
   self._remember = (remember == nil) and 'both' or remember
   assert(_.contains({'both','eval','train','neither'}, self._remember),
      "Sequencer : unrecognized value for remember : "..self._remember)
   return self
end
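
-- For example, to carry hidden state across successive forward() calls when a
-- long sequence is fed in chunks (a sketch; 'chunks' is a placeholder, and
-- 'both' is the usual choice for LSTMs, as noted above):
--
--   local seq = nn.Sequencer(nn.LSTM(10, 10)):remember('both')
--   for i, chunk in ipairs(chunks) do -- each chunk is a table of tensors
--      seq:forward(chunk)             -- hidden state persists between chunks
--   end
--   seq:forget()                      -- explicitly reset the state when done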

function Sequencer:training()
   if self.train == false then
      -- empty output table (tensor mem was managed by seq)
      for i,output in ipairs(self.output) do
         table.insert(self._output, output)
         self.output[i] = nil
      end
      -- forget at the start of each training
      self:forget()
   end
   parent.training(self)
end

function Sequencer:evaluate()
   if self.train ~= false then
      -- empty output table (tensor mem was managed by rnn)
      self.output = {}
      -- forget at the start of each evaluation
      self:forget()
   end
   parent.evaluate(self)
   assert(self.train == false)
end

function Sequencer:reinforce(reward)
   if torch.type(reward) == 'table' then
      error"Sequencer Error : step-wise rewards not yet supported"
   end
   return parent.reinforce(self, reward)
end
Sequencer.__tostring__ = nn.Decorator.__tostring__