This repository has been archived by the owner on Aug 1, 2023. It is now read-only.

modularize words prediction #180

Open: wants to merge 1 commit into master
258 changes: 0 additions & 258 deletions pytorch_translate/research/word_prediction/word_prediction_model.py

This file was deleted.

4 changes: 2 additions & 2 deletions pytorch_translate/train.py
@@ -40,11 +40,11 @@
     preprocess,
     tasks as pytorch_translate_tasks,
 )
+from pytorch_translate.word_prediction import word_prediction_criterion # noqa
+from pytorch_translate.word_prediction import word_prediction_model # noqa
 from pytorch_translate.research.knowledge_distillation import ( # noqa
     knowledge_distillation_loss
 )
-from pytorch_translate.research.word_prediction import word_prediction_criterion # noqa
-from pytorch_translate.research.word_prediction import word_prediction_model # noqa
 from pytorch_translate.utils import ManagedCheckpoints
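A note on the import churn above: these modules are imported purely for their side effects. The `register_criterion`/`register_model` decorators add the classes to fairseq's registries at import time, which is why the imports carry `# noqa` and why moving the package out of `research/` requires touching train.py at all. A minimal sketch of that pattern (the registry below is illustrative, not fairseq's actual internals):

# Illustrative registry pattern; fairseq's real implementation differs
# in detail but works the same way.
CRITERION_REGISTRY = {}


def register_criterion(name):
    def wrapper(cls):
        # Importing the module that defines a decorated class is enough
        # to populate the registry -- no direct use of the name needed.
        CRITERION_REGISTRY[name] = cls
        return cls
    return wrapper


@register_criterion('word_prediction')
class WordPredictionCriterion:
    pass


assert 'word_prediction' in CRITERION_REGISTRY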


pytorch_translate/word_prediction/word_prediction_criterion.py
@@ -1,47 +1,53 @@
 #!/usr/bin/env python3

 import math
 import torch
 import torch.nn.functional as F

-from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.criterions import register_criterion
+from fairseq.criterions.label_smoothed_cross_entropy \
+    import LabelSmoothedCrossEntropyCriterion
 from fairseq import utils
 from pytorch_translate.utils import maybe_cuda


 @register_criterion('word_prediction')
-class WordPredictionCriterion(FairseqCriterion):
+class WordPredictionCriterion(LabelSmoothedCrossEntropyCriterion):
     """
     Implement a combined loss from translation and target words prediction.
     """
+    def __init__(self, args, task):
+        super().__init__(args, task)
+        self.eps = args.label_smoothing
+
     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.

         Returns a tuple with three elements:
-        1) the loss, as a Variable
+        1) total loss, as a Variable
         2) the sample size, which is used as the denominator for the gradient
         3) logging outputs to display while training
         """
         predictor_output, decoder_output = model(**sample['net_input'])
         # translation loss
-        translation_lprobs = model.get_normalized_probs(decoder_output, log_probs=True)
-        translation_target = model.get_targets(sample, decoder_output).view(-1)
-        translation_loss = F.nll_loss(
-            translation_lprobs,
-            translation_target,
-            size_average=False,
-            ignore_index=self.padding_idx,
-            reduce=reduce
+        translation_loss, _ = super().compute_loss(
+            model,
+            decoder_output,
+            sample,
+            reduce,
         )
+        prediction_target = model.get_target_words(sample)
         # predictor loss
         prediction_lprobs = model.get_predictor_normalized_probs(
             predictor_output, log_probs=True)
         prediction_lprobs = prediction_lprobs.view(-1, prediction_lprobs.size(-1))
         # prevent domination of padding idx
-        non_padding_mask = maybe_cuda(torch.ones(prediction_lprobs.size(1)))
-        non_padding_mask[model.encoder.padding_idx] = 0
-        prediction_lprobs = prediction_lprobs * non_padding_mask.unsqueeze(0)
+        non_pad_mask = prediction_target.ne(model.encoder.padding_idx)

-        prediction_target = model.get_target_words(sample)
         assert prediction_lprobs.size(0) == prediction_target.size(0)
         assert prediction_lprobs.dim() == 2
-        word_prediction_loss = -torch.gather(prediction_lprobs, 1, prediction_target)

+        word_prediction_loss = -prediction_lprobs.gather(
+            dim=-1,
+            index=prediction_target,
+        )[non_pad_mask]
+        # TODO: normalize, sentence avg
         if reduce:
             word_prediction_loss = word_prediction_loss.sum()
         else:
@@ -56,8 +62,8 @@ def forward(self, model, sample, reduce=True):
         sample_size = sample['ntokens']

         logging_output = {
-            'loss': translation_loss,
-            'word_prediction_loss': word_prediction_loss,
+            'translation_loss': translation_loss.data,
+            'word_prediction_loss': word_prediction_loss.data,
             'ntokens': sample['ntokens'],
             'sample_size': sample_size,
         }
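One detail worth noting in the logging change above: storing `translation_loss.data` rather than the live tensor detaches the logged value from the autograd graph, so keeping logging outputs around does not retain the whole backward graph. A quick illustration (`.detach()` is the modern spelling of `.data` for this purpose):

import torch

loss = (torch.randn(3, requires_grad=True) ** 2).sum()  # scalar with a grad graph
logged = loss.data            # detached tensor; requires_grad is False
assert not logged.requires_grad
logged = loss.detach()        # equivalent for logging, preferred since PyTorch 0.4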
@@ -76,11 +82,11 @@ def aggregate_logging_outputs(logging_outputs):
         sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
         agg_output = {'sample_size': sample_size}

-        for loss in ['loss', 'word_prediction_loss']:
+        for loss in ['translation_loss', 'word_prediction_loss']:
             loss_sum = sum(log.get(loss, 0) for log in logging_outputs)

             agg_output[loss] = loss_sum / sample_size / math.log(2)
-            if loss == 'loss' and sample_size != ntokens:
+            if loss == 'translation_loss' and sample_size != ntokens:
                 agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)

         return agg_output
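For context on the renamed keys: `aggregate_logging_outputs` divides each summed loss by the aggregated token count and by ln 2, converting totals accumulated in nats into averages in bits per token. A self-contained rerun of that arithmetic on made-up logging outputs:

import math

# Made-up logging outputs from two workers (values are illustrative only),
# mirroring the aggregation in aggregate_logging_outputs above.
logging_outputs = [
    {'translation_loss': 34.6, 'word_prediction_loss': 12.1,
     'ntokens': 50, 'sample_size': 50},
    {'translation_loss': 27.7, 'word_prediction_loss': 9.3,
     'ntokens': 40, 'sample_size': 40},
]

ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = {'sample_size': sample_size}

for loss in ['translation_loss', 'word_prediction_loss']:
    loss_sum = sum(log.get(loss, 0) for log in logging_outputs)
    # Dividing a summed natural-log loss by token count and ln(2)
    # yields an average in bits per token.
    agg_output[loss] = loss_sum / sample_size / math.log(2)
    if loss == 'translation_loss' and sample_size != ntokens:
        agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)

print(agg_output)  # {'sample_size': 90, 'translation_loss': ..., 'word_prediction_loss': ...}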