fix loss calculation for RNN (pytorch#732)
* fix loss calculation for RNN

* fixes loss for both RNN & Transformer
Taufiquzzaman Peyash authored Mar 17, 2020
1 parent 5551061 commit 4902431
Showing 2 changed files with 7 additions and 4 deletions.
7 changes: 4 additions & 3 deletions word_language_model/main.py
@@ -145,11 +145,11 @@ def evaluate(data_source):
             data, targets = get_batch(data_source, i)
             if args.model == 'Transformer':
                 output = model(data)
+                output = output.view(-1, ntokens)
             else:
                 output, hidden = model(data, hidden)
                 hidden = repackage_hidden(hidden)
-            output_flat = output.view(-1, ntokens)
-            total_loss += len(data) * criterion(output_flat, targets).item()
+            total_loss += len(data) * criterion(output, targets).item()
     return total_loss / (len(data_source) - 1)


@@ -168,10 +168,11 @@ def train():
         model.zero_grad()
         if args.model == 'Transformer':
             output = model(data)
+            output = output.view(-1, ntokens)
         else:
             hidden = repackage_hidden(hidden)
             output, hidden = model(data, hidden)
-        loss = criterion(output.view(-1, ntokens), targets)
+        loss = criterion(output, targets)
         loss.backward()

         # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
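Note: with this change, both branches in main.py are expected to hand the criterion a flattened (seq_len * batch_size, ntokens) tensor of log-probabilities and a flat vector of target indices; the Transformer branch flattens explicitly, while the RNN branch relies on the model doing it (see model.py below). A minimal shape sketch, assuming the example's criterion is nn.NLLLoss() and using illustrative sizes that are not taken from the repository:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sizes; not taken from the repository.
seq_len, batch_size, ntokens = 35, 20, 10000

# What the updated RNNModel.forward hands back: log-probabilities already
# flattened to (seq_len * batch_size, ntokens).
raw_scores = torch.randn(seq_len, batch_size, ntokens)
output = F.log_softmax(raw_scores.view(-1, ntokens), dim=1)

# get_batch() style targets: a flat vector of token indices of matching length.
targets = torch.randint(0, ntokens, (seq_len * batch_size,))

# Assuming the criterion is nn.NLLLoss(): this is the per-batch loss that
# evaluate() weights by len(data) and train() backpropagates.
criterion = nn.NLLLoss()
loss = criterion(output, targets)
print(loss.item())

In evaluate(), each batch loss is weighted by len(data) before the final division by (len(data_source) - 1), so the returned value is an average per-token loss.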
4 changes: 3 additions & 1 deletion word_language_model/model.py
@@ -8,6 +8,7 @@ class RNNModel(nn.Module):

     def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False):
         super(RNNModel, self).__init__()
+        self.ntoken = ntoken
         self.drop = nn.Dropout(dropout)
         self.encoder = nn.Embedding(ntoken, ninp)
         if rnn_type in ['LSTM', 'GRU']:
@@ -49,7 +50,8 @@ def forward(self, input, hidden):
         output, hidden = self.rnn(emb, hidden)
         output = self.drop(output)
         decoded = self.decoder(output)
-        return decoded, hidden
+        decoded = decoded.view(-1, self.ntoken)
+        return F.log_softmax(decoded, dim=1), hidden

     def init_hidden(self, bsz):
         weight = next(self.parameters())
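Note: the forward pass now flattens the decoder output and applies a log-softmax inside the model, so callers receive a (seq_len * batch_size, ntoken) tensor of log-probabilities directly. A trimmed, illustrative sketch of that shape flow (TinyRNNModel and its sizes are stand-ins, not names from the repository; it assumes model.py imports torch.nn.functional as F, which the new return statement requires):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyRNNModel(nn.Module):
    """Illustrative stand-in for RNNModel; names and sizes are not from the repo."""
    def __init__(self, ntoken=100, ninp=32, nhid=64):
        super().__init__()
        self.ntoken = ntoken
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid)
        self.decoder = nn.Linear(nhid, ntoken)

    def forward(self, input, hidden):
        emb = self.encoder(input)                     # (seq_len, batch, ninp)
        output, hidden = self.rnn(emb, hidden)        # (seq_len, batch, nhid)
        decoded = self.decoder(output)                # (seq_len, batch, ntoken)
        decoded = decoded.view(-1, self.ntoken)       # (seq_len * batch, ntoken)
        return F.log_softmax(decoded, dim=1), hidden  # log-probs, ready for NLLLoss

model = TinyRNNModel()
tokens = torch.randint(0, 100, (35, 20))                   # (seq_len, batch) token indices
hidden = (torch.zeros(1, 20, 64), torch.zeros(1, 20, 64))  # (h_0, c_0) for one LSTM layer
log_probs, hidden = model(tokens, hidden)
print(log_probs.shape)                                     # torch.Size([700, 100])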
