train.lua
require 'torch'
require 'nn'
require 'nngraph'
require 'optim'
local CharLMMinibatchLoader = require 'data.CharLMMinibatchLoader'
local LSTM = require 'LSTM' -- LSTM timestep and utilities
require 'Embedding' -- class name is Embedding (not namespaced)
local model_utils=require 'model_utils'
local cmd = torch.CmdLine()
cmd:text()
cmd:text('Training a simple character-level LSTM language model')
cmd:text()
cmd:text('Options')
cmd:option('-vocabfile','vocabfile.t7','filename of the string->int table')
cmd:option('-datafile','datafile.t7','filename of the serialized torch ByteTensor to load')
cmd:option('-batch_size',16,'number of sequences to train on in parallel')
cmd:option('-seq_length',16,'number of timesteps to unroll to')
cmd:option('-rnn_size',256,'size of LSTM internal state')
cmd:option('-max_epochs',1,'number of full passes through the training data')
cmd:option('-savefile','model_autosave','filename to autosave the model (protos) to; the parameter string and .t7 are appended')
cmd:option('-save_every',100,'save the model every this many steps, overwriting the existing file')
cmd:option('-print_every',10,'how many steps/minibatches between printing out the loss')
cmd:option('-seed',123,'torch manual random number generator seed')
cmd:text()
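-- example invocation (the data file names below are placeholders for your own
-- preprocessed vocab/data tensors, not files shipped with this script):
--   th train.lua -vocabfile data/vocab.t7 -datafile data/train.t7 -batch_size 16 -seq_length 16 -rnn_size 256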
-- parse input params
local opt = cmd:parse(arg)
-- preparation stuff:
torch.manualSeed(opt.seed)
opt.savefile = cmd:string(opt.savefile, opt,
    {save_every=true, print_every=true, savefile=true, vocabfile=true, datafile=true})
    .. '.t7'
local loader = CharLMMinibatchLoader.create(
    opt.datafile, opt.vocabfile, opt.batch_size, opt.seq_length)
local vocab_size = loader.vocab_size -- the number of distinct characters
-- define model prototypes for ONE timestep, then clone them
local protos = {}
protos.embed = Embedding(vocab_size, opt.rnn_size)
-- lstm timestep's input: {x, prev_c, prev_h}, output: {next_c, next_h}
protos.lstm = LSTM.lstm(opt)
protos.softmax = nn.Sequential():add(nn.Linear(opt.rnn_size, vocab_size)):add(nn.LogSoftMax())
protos.criterion = nn.ClassNLLCriterion()
-- put the above things into one flattened parameters tensor
local params, grad_params = model_utils.combine_all_parameters(protos.embed, protos.lstm, protos.softmax)
params:uniform(-0.08, 0.08)
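-- note: a small uniform initialization like this is common practice for character-level
-- LSTMs; it keeps the gates away from saturation at the start of training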
-- make a bunch of clones, AFTER flattening, as that reallocates memory
local clones = {}
for name, proto in pairs(protos) do
    print('cloning ' .. name)
    clones[name] = model_utils.clone_many_times(proto, opt.seq_length, not proto.parameters)
end
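-- each clone reuses the prototype's parameter storage, so the network is unrolled
-- through opt.seq_length timesteps while the weights stay tied across time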
-- LSTM initial state (zero initially, but final state gets sent to initial state when we do BPTT)
local initstate_c = torch.zeros(opt.batch_size, opt.rnn_size)
local initstate_h = initstate_c:clone()
-- LSTM final state's backward message (dloss/dfinalstate) is 0, since it doesn't influence predictions
local dfinalstate_c = initstate_c:clone()
local dfinalstate_h = initstate_c:clone()
-- do fwd/bwd and return loss, grad_params
function feval(params_)
    if params_ ~= params then
        params:copy(params_)
    end
    grad_params:zero()
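    -- nn modules accumulate gradients across :backward() calls, so clear them before this pass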
    ------------------ get minibatch -------------------
    local x, y = loader:next_batch()

    ------------------- forward pass -------------------
    local embeddings = {}            -- input embeddings
    local lstm_c = {[0]=initstate_c} -- internal cell states of LSTM
    local lstm_h = {[0]=initstate_h} -- output values of LSTM
    local predictions = {}           -- softmax outputs
    local loss = 0

    for t=1,opt.seq_length do
        embeddings[t] = clones.embed[t]:forward(x[{{}, t}])

        -- we're feeding the *correct* things in here, alternatively
        -- we could sample from the previous timestep and embed that, but that's
        -- more commonly done for LSTM encoder-decoder models
        lstm_c[t], lstm_h[t] = unpack(clones.lstm[t]:forward{embeddings[t], lstm_c[t-1], lstm_h[t-1]})

        predictions[t] = clones.softmax[t]:forward(lstm_h[t])
        loss = loss + clones.criterion[t]:forward(predictions[t], y[{{}, t}])
    end
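    -- loss is summed over timesteps; ClassNLLCriterion averages over the batch by default,
    -- so loss / opt.seq_length (printed below) is the average per-character NLL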
    ------------------ backward pass -------------------
    -- complete reverse order of the above
    local dembeddings = {}                           -- d loss / d input embeddings
    local dlstm_c = {[opt.seq_length]=dfinalstate_c} -- d loss / d internal cell states of LSTM
    local dlstm_h = {}                               -- d loss / d output values of LSTM

    for t=opt.seq_length,1,-1 do
        -- backprop through loss, and softmax/linear
        local doutput_t = clones.criterion[t]:backward(predictions[t], y[{{}, t}])

        -- Two cases for dloss/dh_t:
        --   1. h_T is only used once, sent to the softmax (but not to the next LSTM timestep).
        --   2. h_t is used twice, for the softmax and for the next step. To obey the
        --      multivariate chain rule, we add them.
        if t == opt.seq_length then
            assert(dlstm_h[t] == nil)
            dlstm_h[t] = clones.softmax[t]:backward(lstm_h[t], doutput_t)
        else
            dlstm_h[t]:add(clones.softmax[t]:backward(lstm_h[t], doutput_t))
        end

        -- backprop through LSTM timestep
        dembeddings[t], dlstm_c[t-1], dlstm_h[t-1] = unpack(clones.lstm[t]:backward(
            {embeddings[t], lstm_c[t-1], lstm_h[t-1]},
            {dlstm_c[t], dlstm_h[t]}
        ))

        -- backprop through embeddings
        clones.embed[t]:backward(x[{{}, t}], dembeddings[t])
    end
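    -- parameter gradients from every clone accumulate into the single flattened
    -- grad_params tensor, since all clones share the same underlying storage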
    ------------------------ misc ----------------------
    -- transfer final state to initial state (BPTT)
    initstate_c:copy(lstm_c[#lstm_c])
    initstate_h:copy(lstm_h[#lstm_h])

    -- clip gradient element-wise
    grad_params:clamp(-5, 5)

    return loss, grad_params
end
-- optimization stuff
local losses = {}
local optim_state = {learningRate = 1e-1}
local iterations = opt.max_epochs * loader.nbatches
for i = 1, iterations do
    local _, loss = optim.adagrad(feval, params, optim_state)
    losses[#losses + 1] = loss[1]

    if i % opt.save_every == 0 then
        torch.save(opt.savefile, protos)
    end
    if i % opt.print_every == 0 then
        print(string.format("iteration %4d, loss = %6.8f, loss/seq_len = %6.8f, gradnorm = %6.4e", i, loss[1], loss[1] / opt.seq_length, grad_params:norm()))
    end
end
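-- a saved checkpoint can later be reloaded for inspection or sampling, e.g.
-- (the filename below is a placeholder; the real name includes the parameter string):
--   local protos = torch.load('model_autosave_<params>.t7')
--   print(protos.lstm)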