From bc1308fccb67420f5200036664f5359119ab67df Mon Sep 17 00:00:00 2001 From: Xian Li Date: Mon, 30 Jul 2018 17:15:45 -0700 Subject: [PATCH] modularize words prediction Summary: A series of diffs to extend words prediction so that it can be used for vocab reduction purposes on any arch. The code under research/word_prediction has lots of redundant logic, both in the loss and model arch. This diff tries to make words prediction plug-and-play so that it can work with existing translation loss and encoder/decoder classes. Reviewed By: jhcross Differential Revision: D9012503 fbshipit-source-id: 775c265e5fadebb78c9b933f7131485ae8239c09 --- .../word_prediction/word_prediction_model.py | 258 ------------------ pytorch_translate/train.py | 4 +- .../word_prediction_criterion.py | 54 ++-- .../word_prediction/word_prediction_model.py | 117 ++++++++ .../word_prediction/word_predictor.py | 0 5 files changed, 149 insertions(+), 284 deletions(-) delete mode 100644 pytorch_translate/research/word_prediction/word_prediction_model.py rename pytorch_translate/{research => }/word_prediction/word_prediction_criterion.py (65%) create mode 100644 pytorch_translate/word_prediction/word_prediction_model.py rename pytorch_translate/{research => }/word_prediction/word_predictor.py (100%) diff --git a/pytorch_translate/research/word_prediction/word_prediction_model.py b/pytorch_translate/research/word_prediction/word_prediction_model.py deleted file mode 100644 index b8950c73..00000000 --- a/pytorch_translate/research/word_prediction/word_prediction_model.py +++ /dev/null @@ -1,258 +0,0 @@ -from fairseq.models import ( - register_model, - register_model_architecture, - FairseqModel, -) -from pytorch_translate import vocab_reduction -from pytorch_translate.rnn import ( - torch_find, - LSTMSequenceEncoder, - RNNEncoder, - RNNDecoder, -) -from .word_predictor import WordPredictor - - -class FairseqWordPredictionModel(FairseqModel): - def __init__(self, encoder, decoder, predictor): - 
super().__init__(encoder, decoder) - self.predictor = predictor - - def forward(self, src_tokens, src_lengths, prev_output_tokens): - encoder_output = self.encoder(src_tokens, src_lengths) - pred_output = self.predictor(encoder_output) - decoder_output = self.decoder(prev_output_tokens, encoder_output) - return pred_output, decoder_output - - def get_predictor_normalized_probs(self, pred_output, log_probs): - return self.predictor.get_normalized_probs(pred_output, log_probs) - - def get_target_words(self, sample): - return sample['target'] - - -@register_model('rnn_wp') -class RNNWordPredictionModel(FairseqWordPredictionModel): - - @staticmethod - def add_args(parser): - parser.add_argument( - '--dropout', - default=0.1, - type=float, - metavar='D', - help='dropout probability', - ) - parser.add_argument( - '--encoder-embed-dim', - type=int, - metavar='N', - help='encoder embedding dimension', - ) - parser.add_argument( - '--encoder-freeze-embed', - default=False, - action='store_true', - help=('whether to freeze the encoder embedding or allow it to be ' - 'updated during training'), - ) - parser.add_argument( - '--encoder-hidden-dim', - type=int, - metavar='N', - help='encoder cell num units', - ) - parser.add_argument( - '--encoder-layers', - type=int, - metavar='N', - help='number of encoder layers', - ) - parser.add_argument( - '--encoder-bidirectional', - action='store_true', - help='whether the first layer is bidirectional or not', - ) - parser.add_argument( - '--averaging-encoder', - default=False, - action='store_true', - help=( - 'whether use mean encoder hidden states as decoder initial ' - 'states or not' - ), - ) - parser.add_argument( - '--decoder-embed-dim', - type=int, - metavar='N', - help='decoder embedding dimension', - ) - parser.add_argument( - '--decoder-freeze-embed', - default=False, - action='store_true', - help=('whether to freeze the encoder embedding or allow it to be ' - 'updated during training'), - ) - parser.add_argument( - 
'--decoder-hidden-dim', - type=int, - metavar='N', - help='decoder cell num units', - ) - parser.add_argument( - '--decoder-layers', - type=int, - metavar='N', - help='number of decoder layers', - ) - parser.add_argument( - '--decoder-out-embed-dim', - type=int, - metavar='N', - help='decoder output embedding dimension', - ) - parser.add_argument( - '--attention-type', - type=str, - metavar='EXPR', - help='decoder attention, defaults to dot', - ) - parser.add_argument( - '--residual-level', - default=None, - type=int, - help=( - 'First layer where to apply a residual connection. ' - 'The value should be greater than 0 and smaller than the number of ' - 'layers.' - ), - ) - parser.add_argument( - '--cell-type', - default='lstm', - type=str, - metavar='EXPR', - help='cell type, defaults to lstm, values:lstm, milstm, layer_norm_lstm', - ) - - # Granular dropout settings (if not specified these default to --dropout) - parser.add_argument( - '--encoder-dropout-in', - type=float, - metavar='D', - help='dropout probability for encoder input embedding', - ) - parser.add_argument( - '--encoder-dropout-out', - type=float, - metavar='D', - help='dropout probability for encoder output', - ) - parser.add_argument( - '--decoder-dropout-in', - type=float, - metavar='D', - help='dropout probability for decoder input embedding', - ) - parser.add_argument( - '--decoder-dropout-out', - type=float, - metavar='D', - help='dropout probability for decoder output', - ) - parser.add_argument( - '--sequence-lstm', - action='store_true', - help='use nn.LSTM implementation for encoder', - ) - # new arg - parser.add_argument( - '--predictor-hidden-dim', - type=int, - metavar='N', - help='word predictor num units', - ) - - # Args for vocab reduction - vocab_reduction.add_args(parser) - - @classmethod - def build_model(cls, args, task): - """Build a new model instance.""" - src_dict, dst_dict = task.source_dictionary, task.target_dictionary - base_architecture_wp(args) - if args.sequence_lstm: - 
encoder_class = LSTMSequenceEncoder - else: - encoder_class = RNNEncoder - encoder = encoder_class( - src_dict, - embed_dim=args.encoder_embed_dim, - freeze_embed=args.encoder_freeze_embed, - cell_type=args.cell_type, - num_layers=args.encoder_layers, - hidden_dim=args.encoder_hidden_dim, - dropout_in=args.encoder_dropout_in, - dropout_out=args.encoder_dropout_out, - residual_level=args.residual_level, - bidirectional=bool(args.encoder_bidirectional), - ) - decoder = RNNDecoder( - src_dict=src_dict, - dst_dict=dst_dict, - vocab_reduction_params=args.vocab_reduction_params, - encoder_hidden_dim=args.encoder_hidden_dim, - embed_dim=args.decoder_embed_dim, - freeze_embed=args.decoder_freeze_embed, - out_embed_dim=args.decoder_out_embed_dim, - cell_type=args.cell_type, - num_layers=args.decoder_layers, - hidden_dim=args.decoder_hidden_dim, - attention_type=args.attention_type, - dropout_in=args.decoder_dropout_in, - dropout_out=args.decoder_dropout_out, - residual_level=args.residual_level, - averaging_encoder=args.averaging_encoder, - ) - predictor = WordPredictor( - args.encoder_hidden_dim, args.predictor_hidden_dim, len(dst_dict) - ) - return cls(encoder, decoder, predictor) - - def get_targets(self, sample, net_output): - targets = sample['target'].view(-1) - possible_translation_tokens = net_output[-1] - if possible_translation_tokens is not None: - targets = torch_find( - possible_translation_tokens.data, - targets.data, - len(self.dst_dict), - ) - return targets - - -@register_model_architecture('rnn_wp', 'rnn_wp') -def base_architecture_wp(args): - # default architecture - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) - args.encoder_layers = getattr(args, 'encoder_layers', 1) - args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 512) - args.encoder_bidirectional = getattr(args, 'encoder_bidirectional', False) - args.encoder_dropout_in = getattr(args, 'encoder_dropout_in', args.dropout) - args.encoder_dropout_out = getattr(args, 
'encoder_dropout_out', args.dropout) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) - args.decoder_layers = getattr(args, 'decoder_layers', 1) - args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 512) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512) - args.attention_type = getattr(args, 'attention_type', 'dot') - args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout) - args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout) - args.averaging_encoder = getattr(args, 'averaging_encoder', False) - args.encoder_freeze_embed = getattr(args, 'encoder_freeze_embed', False) - args.decoder_freeze_embed = getattr(args, 'decoder_freeze_embed', False) - args.cell_type = getattr(args, 'cell_type', 'lstm') - vocab_reduction.set_arg_defaults(args) - args.sequence_lstm = getattr(args, 'sequence_lstm', False) - args.predictor_hidden_dim = getattr(args, 'predictor_hidden_dim', 512) diff --git a/pytorch_translate/train.py b/pytorch_translate/train.py index 651f43be..b47ca946 100644 --- a/pytorch_translate/train.py +++ b/pytorch_translate/train.py @@ -40,11 +40,11 @@ preprocess, tasks as pytorch_translate_tasks, ) +from pytorch_translate.word_prediction import word_prediction_criterion # noqa +from pytorch_translate.word_prediction import word_prediction_model # noqa from pytorch_translate.research.knowledge_distillation import ( # noqa knowledge_distillation_loss ) -from pytorch_translate.research.word_prediction import word_prediction_criterion # noqa -from pytorch_translate.research.word_prediction import word_prediction_model # noqa from pytorch_translate.utils import ManagedCheckpoints diff --git a/pytorch_translate/research/word_prediction/word_prediction_criterion.py b/pytorch_translate/word_prediction/word_prediction_criterion.py similarity index 65% rename from pytorch_translate/research/word_prediction/word_prediction_criterion.py rename to 
pytorch_translate/word_prediction/word_prediction_criterion.py index d8e377e7..6d138471 100644 --- a/pytorch_translate/research/word_prediction/word_prediction_criterion.py +++ b/pytorch_translate/word_prediction/word_prediction_criterion.py @@ -1,47 +1,53 @@ +#!/usr/bin/env python3 + import math -import torch -import torch.nn.functional as F -from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.criterions import register_criterion +from fairseq.criterions.label_smoothed_cross_entropy \ + import LabelSmoothedCrossEntropyCriterion from fairseq import utils -from pytorch_translate.utils import maybe_cuda @register_criterion('word_prediction') -class WordPredictionCriterion(FairseqCriterion): +class WordPredictionCriterion(LabelSmoothedCrossEntropyCriterion): + """ + Implement a combined loss from translation and target words prediction. + """ + def __init__(self, args, task): + super().__init__(args, task) + self.eps = args.label_smoothing def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. 
Returns a tuple with three elements: - 1) the loss, as a Variable + 1) total loss, as a Variable 2) the sample size, which is used as the denominator for the gradient 3) logging outputs to display while training """ predictor_output, decoder_output = model(**sample['net_input']) # translation loss - translation_lprobs = model.get_normalized_probs(decoder_output, log_probs=True) - translation_target = model.get_targets(sample, decoder_output).view(-1) - translation_loss = F.nll_loss( - translation_lprobs, - translation_target, - size_average=False, - ignore_index=self.padding_idx, - reduce=reduce + translation_loss, _ = super().compute_loss( + model, + decoder_output, + sample, + reduce, ) + prediction_target = model.get_target_words(sample) # predictor loss prediction_lprobs = model.get_predictor_normalized_probs( predictor_output, log_probs=True) + prediction_lprobs = prediction_lprobs.view(-1, prediction_lprobs.size(-1)) # prevent domination of padding idx - non_padding_mask = maybe_cuda(torch.ones(prediction_lprobs.size(1))) - non_padding_mask[model.encoder.padding_idx] = 0 - prediction_lprobs = prediction_lprobs * non_padding_mask.unsqueeze(0) + non_pad_mask = prediction_target.ne(model.encoder.padding_idx) - prediction_target = model.get_target_words(sample) assert prediction_lprobs.size(0) == prediction_target.size(0) assert prediction_lprobs.dim() == 2 - word_prediction_loss = -torch.gather(prediction_lprobs, 1, prediction_target) - + word_prediction_loss = -prediction_lprobs.gather( + dim=-1, + index=prediction_target, + )[non_pad_mask] + # TODO: normalize , sentence avg if reduce: word_prediction_loss = word_prediction_loss.sum() else: @@ -56,8 +62,8 @@ def forward(self, model, sample, reduce=True): sample_size = sample['ntokens'] logging_output = { - 'loss': translation_loss, - 'word_prediction_loss': word_prediction_loss, + 'translation_loss': translation_loss.data, + 'word_prediction_loss': word_prediction_loss.data, 'ntokens': sample['ntokens'], 
'sample_size': sample_size, } @@ -76,11 +82,11 @@ def aggregate_logging_outputs(logging_outputs): sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) agg_output = {'sample_size': sample_size} - for loss in ['loss', 'word_prediction_loss']: + for loss in ['translation_loss', 'word_prediction_loss']: loss_sum = sum(log.get(loss, 0) for log in logging_outputs) agg_output[loss] = loss_sum / sample_size / math.log(2) - if loss == 'loss' and sample_size != ntokens: + if loss == 'translation_loss' and sample_size != ntokens: agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) return agg_output diff --git a/pytorch_translate/word_prediction/word_prediction_model.py b/pytorch_translate/word_prediction/word_prediction_model.py new file mode 100644 index 00000000..eab3197f --- /dev/null +++ b/pytorch_translate/word_prediction/word_prediction_model.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 + +from fairseq.models import ( + register_model, + register_model_architecture, + FairseqModel, +) +from pytorch_translate import rnn +from pytorch_translate.rnn import ( + torch_find, + LSTMSequenceEncoder, + RNNEncoder, + RNNDecoder +) +from pytorch_translate.word_prediction import word_predictor + + +class WordPredictionModel(FairseqModel): + """ + An architecuture which jointly learns translation and target words + prediction, as described in http://aclweb.org/anthology/D17-1013. 
+ """ + def __init__(self, task, encoder, decoder, predictor): + super().__init__(encoder, decoder) + self.predictor = predictor + self.task = task + + def forward(self, src_tokens, src_lengths, prev_output_tokens): + encoder_output = self.encoder(src_tokens, src_lengths) + pred_output = self.predictor(encoder_output) + decoder_output = self.decoder(prev_output_tokens, encoder_output) + return pred_output, decoder_output + + def get_predictor_normalized_probs(self, pred_output, log_probs): + return self.predictor.get_normalized_probs(pred_output, log_probs) + + def get_target_words(self, sample): + return sample['target'] + + +@register_model('rnn_word_pred') +class RNNWordPredictionModel(WordPredictionModel): + """ + A subclass which adds words prediction to RNN arch. + """ + @staticmethod + def add_args(parser): + rnn.RNNModel.add_args(parser) + parser.add_argument( + '--predictor-hidden-dim', + type=int, + metavar='N', + help='word predictor num units', + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + src_dict, dst_dict = task.source_dictionary, task.target_dictionary + base_architecture_wp(args) + if args.sequence_lstm: + encoder_class = LSTMSequenceEncoder + else: + encoder_class = RNNEncoder + decoder_class = RNNDecoder + + encoder = encoder_class( + src_dict, + embed_dim=args.encoder_embed_dim, + freeze_embed=args.encoder_freeze_embed, + cell_type=args.cell_type, + num_layers=args.encoder_layers, + hidden_dim=args.encoder_hidden_dim, + dropout_in=args.encoder_dropout_in, + dropout_out=args.encoder_dropout_out, + residual_level=args.residual_level, + bidirectional=bool(args.encoder_bidirectional), + ) + decoder = decoder_class( + src_dict=src_dict, + dst_dict=dst_dict, + vocab_reduction_params=args.vocab_reduction_params, + encoder_hidden_dim=args.encoder_hidden_dim, + embed_dim=args.decoder_embed_dim, + freeze_embed=args.decoder_freeze_embed, + out_embed_dim=args.decoder_out_embed_dim, + cell_type=args.cell_type, 
+ num_layers=args.decoder_layers, + hidden_dim=args.decoder_hidden_dim, + attention_type=args.attention_type, + dropout_in=args.decoder_dropout_in, + dropout_out=args.decoder_dropout_out, + residual_level=args.residual_level, + averaging_encoder=args.averaging_encoder, + ) + predictor = word_predictor.WordPredictor( + args.encoder_hidden_dim, args.predictor_hidden_dim, len(dst_dict) + ) + return cls(task, encoder, decoder, predictor) + + def get_targets(self, sample, net_output): + targets = sample['target'].view(-1) + possible_translation_tokens = net_output[-1] + if possible_translation_tokens is not None: + targets = torch_find( + possible_translation_tokens, + targets, + len(self.task.target_dictionary), + ) + return targets + + +@register_model_architecture('rnn_word_pred', 'rnn_word_pred') +def base_architecture_wp(args): + # default architecture + rnn.base_architecture(args) + args.predictor_hidden_dim = getattr(args, 'predictor_hidden_dim', 512) diff --git a/pytorch_translate/research/word_prediction/word_predictor.py b/pytorch_translate/word_prediction/word_predictor.py similarity index 100% rename from pytorch_translate/research/word_prediction/word_predictor.py rename to pytorch_translate/word_prediction/word_predictor.py