-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathrnntest01.lua
81 lines (67 loc) · 2.49 KB
/
rnntest01.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
-- Simple demo of the Element-Research rnn library for Torch (https://github.com/Element-Research/rnn),
-- based on that library's own demos
require 'rnn'
-- Hyper-parameters of the demo. They are plain globals because the rest of
-- the script reads them by name.
-- Data / sequence shape
nIndex = 100      -- length of the dataset, LookupTable input space and number of classifier outputs
rho = 3           -- number of time steps per training window (also passed to nn.Recurrent)
batchSize = 1     -- not referenced in the visible part of this script
inputSize = 1     -- not referenced in the visible part of this script
-- Model / optimisation
hiddenSize = 120  -- output size of the recurrent layer
lr = 0.2          -- learning rate passed to updateParameters
-- Build a dummy dataset; the task is to predict the next item given the previous ones.
-- The data is a Fibonacci-style run 1,1,2,3,5,8,13,... that restarts whenever the
-- next value would exceed nIndex, so every entry is a valid index in 1..nIndex.
-- Because the run starts with 1,1 an input of 1 can be followed by either 1 or 2,
-- i.e. the right answer depends on the previous step.
sequence = torch.Tensor(nIndex):fill(1)
for i = 3, nIndex do
  local nextValue = sequence[i-1] + sequence[i-2]
  if nextValue > nIndex then
    -- Restart the run: both this entry and the one before it become 1
    -- (the previous entry is overwritten), so the following value is 2 again.
    sequence[i-1] = 1
    sequence[i] = 1
  else
    sequence[i] = nextValue
  end
end
print('Sequence:')
print(sequence)
-- Model definition.
-- The sub-modules of the recurrent layer are created first, in the same order
-- in which they would be created when passed inline to nn.Recurrent.
-- Input layer: the inputs are discrete indices, so a LookupTable is used
-- (see https://github.com/Element-Research/rnn/issues/113).
local inputLayer = nn.LookupTable(nIndex, hiddenSize)
-- Recurrent (hidden-to-hidden) connection.
local feedbackLayer = nn.Linear(hiddenSize, hiddenSize)
local transfer = nn.Tanh()
local recurrentLayer = nn.Recurrent(
  hiddenSize, -- output size
  inputLayer,
  feedbackLayer,
  transfer,
  rho
)
-- Per-step network: recurrent layer followed by a classifier over nIndex classes.
local stepModule = nn.Sequential()
stepModule:add(recurrentLayer)
stepModule:add(nn.Linear(hiddenSize, nIndex))
stepModule:add(nn.LogSoftMax())
-- Wrap the per-step network in a Sequencer; the training loop feeds it a
-- table holding one input per time step.
local rnn = nn.Sequencer(stepModule)
-- Criterion: ClassNLLCriterion wrapped so it can be applied to the table of per-step outputs.
criterion = nn.SequencerCriterion(nn.ClassNLLCriterion())
print("model:", rnn)
-- Training: slide a window of rho consecutive items over the sequence and take
-- one gradient step per window. The loop body runs for iteration = 1..99.
local iteration = 1
local seqIndex = 1 -- position in `sequence` of the first input of the current window
while iteration < 100 do
  -- 1. build a window of rho (input, target) pairs; each target is the item
  --    that follows its input in the sequence
  -- Wrap around early enough that the last target (seqIndex + rho) stays inside the sequence.
  if seqIndex > nIndex - rho then seqIndex = 1 end
  local inputs, targets = {}, {}
  for step = 1, rho do
    local pos = seqIndex + step - 1
    inputs[step] = sequence:sub(pos, pos)          -- one-element tensor holding the input
    targets[step] = sequence:sub(pos + 1, pos + 1) -- one-element tensor holding the target
  end
  seqIndex = seqIndex + 1 -- the next window starts one item later

  -- 2. forward the window through the rnn
  rnn:zeroGradParameters()
  local outputs = rnn:forward(inputs)
  -- Run the criterion's forward pass before its backward pass; the loss value
  -- itself is not reported by this demo.
  criterion:forward(outputs, targets)
  -- Classifier decision for the last time step: index of the largest output.
  -- NOTE(review): outputs[rho] is presumably a 1 x nIndex tensor, so [1] selects its only row - confirm.
  -- Kept local so that no globals are (re)defined on every iteration.
  local _, maxIndex = torch.max(outputs[rho][1], 1)
  print('# iteration: ', iteration, 'input:', inputs[rho][1], 'target:', targets[rho][1], 'output:', maxIndex[1])

  -- 3. backward the window through the rnn (i.e. backprop through time)
  local gradOutputs = criterion:backward(outputs, targets)
  rnn:backward(inputs, gradOutputs)
  -- Note that LookupTable does not generate any gradInputs, which can be a problem in more
  -- complicated models; see https://github.com/Element-Research/rnn/issues/185 for a workaround.

  -- 4. update the parameters
  rnn:updateParameters(lr)
  iteration = iteration + 1
end -- end iteration