From 71605987acb70dba6976929e3c25226dadc5a704 Mon Sep 17 00:00:00 2001
From: Michael Gschwind
Date: Tue, 9 Feb 2021 07:19:02 -0800
Subject: [PATCH] Back out "Open-source non-autoregressive optimization"

Summary: Original commit changeset: 2044306db439 - breaks build docs

Differential Revision: D26339586

fbshipit-source-id: 750d60b3442e97099bf84a2c1e077f81fb06ba60
---
 pytext/loss/__init__.py                      |  16 -
 pytext/loss/loss.py                          | 109 ------
 pytext/loss/regularized_loss.py              | 338 ------------------
 pytext/loss/regularizer.py                   | 148 --------
 .../samplewise_label_smoothing_loss_test.py  |  37 --
 5 files changed, 648 deletions(-)
 delete mode 100644 pytext/loss/regularized_loss.py
 delete mode 100644 pytext/loss/regularizer.py
 delete mode 100644 pytext/loss/tests/samplewise_label_smoothing_loss_test.py

diff --git a/pytext/loss/__init__.py b/pytext/loss/__init__.py
index 38e1ba00c..f1cc6b593 100644
--- a/pytext/loss/__init__.py
+++ b/pytext/loss/__init__.py
@@ -18,15 +18,7 @@
     NLLLoss,
     PairwiseRankingLoss,
     SourceType,
-    MaxMarginLoss,
 )
-from .regularized_loss import (
-    LabelSmoothingLoss,
-    SamplewiseLabelSmoothingLoss,
-    NARSequenceLoss,
-    NARSamplewiseSequenceLoss,
-)
-from .regularizer import UniformRegularizer, EntropyRegularizer, AdaptiveRegularizer


 __all__ = [
@@ -46,12 +38,4 @@
     "PairwiseRankingLoss",
     "LabelSmoothedCrossEntropyLoss",
     "SourceType",
-    "LabelSmoothingLoss",
-    "SamplewiseLabelSmoothingLoss",
-    "MaxMarginLoss",
-    "NARSequenceLoss",
-    "NARSamplewiseSequenceLoss",
-    "UniformRegularizer",
-    "EntropyRegularizer",
-    "AdaptiveRegularizer",
 ]
diff --git a/pytext/loss/loss.py b/pytext/loss/loss.py
index 567533621..4e848a776 100644
--- a/pytext/loss/loss.py
+++ b/pytext/loss/loss.py
@@ -12,19 +12,6 @@
 from torch import nn


-def maybe_log_normalize(logits, logits_type, dim=-1):
-    """Optionally log normalizes logits on the given dimension."""
-
-    if logits_type == SourceType.LOGITS:
-        return F.log_softmax(logits, dim)
-    elif logits_type == SourceType.PROBS:
-        return logits.log()
-    elif logits_type == SourceType.LOG_PROBS:
-        return logits
-    else:
-        raise NotImplementedError
-
-
 class SourceType(Enum):
     LOG_PROBS = "log_probs"
     LOGITS = "logits"
@@ -620,99 +607,3 @@ def __call__(self, logits, targets, reduce=True):
         self.label_smoothing_loss = label_smoothing_loss

         return (1.0 - self.beta) * cross_entropy_loss + self.beta * label_smoothing_loss
-
-
-class MaxMarginLoss(Loss):
-    """
-    Computes a max-margin loss for structured prediction:
-        max(0, m + cost(Y',Y) + S(Y'|X) - S(Y|X))
-
-    Here, we require the score of the gold sequence S(Y|X) to be _at least_
-    m + cost(Y',Y) higher than the score of a hypothesis sequence S(Y'|X),
-    where m = margin and cost(Y',Y) = Hamming distance between Y' and Y.
-
-    To efficiently search for the sequence with the largest margin violation
-    when cost(Y',Y) is included, we greedily decode a sequence with score
-    S(Y'|X) + I[Y'!=Y]. Intuitively, this forces our model to score the
-    gold label above other candidate labels.
-    """
-
-    class Config(ConfigBase):
-        # Enables m (when m > 0)
-        margin: float = 0.0
-        # Enables cost(Y',Y)
-        use_cost: bool = False
-        # Multiplies cost(Y',Y) with this amount
-        cost_scale: float = 1.0
-
-    def __init__(self, config, pad_index=1, *args, **kwargs):
-        self.margin = config.margin
-        self.use_cost = config.use_cost
-        self.cost_scale = config.cost_scale
-        self.pad_index = pad_index
-
-    def get_sequence_scores(self, logits, indices, mask):
-        """
-        Computes the score of the sequence: sum_i S(Y_i|Y_ 0.0:
-            loss += self.margin
-
-        loss.clip_(min=0.0)  # B
-
-        return loss.sum() if reduce else loss
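To make the hinge in the MaxMarginLoss docstring above concrete, here is a small stand-alone PyTorch sketch of max(0, m + cost(Y',Y) + S(Y'|X) - S(Y|X)) on toy per-token scores. The tensor names and the greedy-argmax hypothesis are illustrative assumptions, not the deleted implementation.

import torch

# Toy per-token log-scores: batch of 1 sequence, 3 steps, 4 labels.
token_scores = torch.log_softmax(torch.randn(1, 3, 4), dim=-1)
gold = torch.tensor([[1, 2, 0]])

# Hypothesis Y': greedy argmax decode (stand-in for the augmented decode in the docstring).
hyp = token_scores.argmax(dim=-1)

def sequence_score(scores, labels):
    # S(Y|X): sum of the per-token scores of the chosen labels.
    return scores.gather(-1, labels.unsqueeze(-1)).squeeze(-1).sum(dim=-1)

margin, cost_scale = 1.0, 1.0
cost = cost_scale * (hyp != gold).float().sum(dim=-1)  # Hamming distance cost(Y',Y)
loss = torch.clamp(
    margin + cost + sequence_score(token_scores, hyp) - sequence_score(token_scores, gold),
    min=0.0,
)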
- """ - - class Config(ConfigBase): - # Enables m (when m > 0) - margin: float = 0.0 - # Enables cost(Y',Y) - use_cost: bool = False - # Multiplies cost(Y',Y) with this amount - cost_scale: float = 1.0 - - def __init__(self, config, pad_index=1, *args, **kwargs): - self.margin = config.margin - self.use_cost = config.use_cost - self.cost_scale = config.cost_scale - self.pad_index = pad_index - - def get_sequence_scores(self, logits, indices, mask): - """ - Computes the score of the sequence: sum_i S(Y_i|Y_ 0.0: - loss += self.margin - - loss.clip_(min=0.0) # B - - return loss.sum() if reduce else loss diff --git a/pytext/loss/regularized_loss.py b/pytext/loss/regularized_loss.py deleted file mode 100644 index 9da29a8d0..000000000 --- a/pytext/loss/regularized_loss.py +++ /dev/null @@ -1,338 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved - -from typing import Union - -import torch -from pytext.config import ConfigBase -from pytext.config.component import create_loss - -from .loss import ( - Loss, - NLLLoss, - MaxMarginLoss, - HingeLoss, - maybe_log_normalize, - SourceType, -) -from .regularizer import UniformRegularizer, EntropyRegularizer, AdaptiveRegularizer - - -class LabelSmoothingLoss(Loss): - """Label loss with an optional regularizer for smoothing.""" - - class Config(ConfigBase): - beta: float = 0.1 - label_loss: Union[ - NLLLoss.Config, MaxMarginLoss.Config, HingeLoss.Config - ] = NLLLoss.Config() - smoothing_loss: Union[ - UniformRegularizer.Config, - EntropyRegularizer.Config, - AdaptiveRegularizer.Config, - ] = UniformRegularizer.Config() - - def __init__(self, config, ignore_index=1): - self.beta = config.beta - self.label_loss_fn = create_loss(config.label_loss, ignore_index=ignore_index) - self.smoothing_loss_fn = create_loss( - config.smoothing_loss, ignore_index=ignore_index - ) - self.ignore_index = ignore_index - - # Tracking variables. - self.label_loss = 0 - self.smoothing_loss = 0 - - def __call__(self, logits, targets, reduce=True): - label_loss = self.label_loss_fn(logits, targets, reduce) - - # Flatten logits if we're using a structured label loss. - if isinstance(self.label_loss_fn, MaxMarginLoss): - logits = logits.reshape(-1, logits.size(-1)) - targets = targets.view(-1) - - smoothing_loss = self.smoothing_loss_fn(logits, targets, reduce) - - # Set tracking variables. - self.label_loss = label_loss - self.smoothing_loss = smoothing_loss - - loss = label_loss + self.beta * smoothing_loss - - return loss - - -class SamplewiseLabelSmoothingLoss(LabelSmoothingLoss): - """Label smoothing loss with sample-wise logging.""" - - def __init__(self, config, ignore_index=-1): - super().__init__(config, ignore_index) - - # Sample-wise tracking variables. - self.samplewise_label_loss = 0 - self.samplewise_smoothing_loss = 0 - - def _reduce_mean( - self, logits, targets, batch_size, label_loss, smoothing_loss, reduce=True - ): - """ - Class-specific reduction function to extract sample-wise losses. Currently, - passing in reduce="mean" averages over all samples without providing access - to sample-wise losses. - """ - - # Save original losses. - orig_label_loss = label_loss.clone() - orig_smoothing_loss = smoothing_loss.clone() - - # Create target mask for pad tokens. - mask = targets.ne(self.ignore_index) - - if mask.any(): - # Guarantee ignored tokens have zero contribution to loss. - label_loss[~mask] = 0 - smoothing_loss[~mask] = 0 - - # Lengths after masking. 
-class SamplewiseLabelSmoothingLoss(LabelSmoothingLoss):
-    """Label smoothing loss with sample-wise logging."""
-
-    def __init__(self, config, ignore_index=-1):
-        super().__init__(config, ignore_index)
-
-        # Sample-wise tracking variables.
-        self.samplewise_label_loss = 0
-        self.samplewise_smoothing_loss = 0
-
-    def _reduce_mean(
-        self, logits, targets, batch_size, label_loss, smoothing_loss, reduce=True
-    ):
-        """
-        Class-specific reduction function to extract sample-wise losses. Currently,
-        passing in reduce="mean" averages over all samples without providing access
-        to sample-wise losses.
-        """
-
-        # Save original losses.
-        orig_label_loss = label_loss.clone()
-        orig_smoothing_loss = smoothing_loss.clone()
-
-        # Create target mask for pad tokens.
-        mask = targets.ne(self.ignore_index)
-
-        if mask.any():
-            # Guarantee ignored tokens have zero contribution to loss.
-            label_loss[~mask] = 0
-            smoothing_loss[~mask] = 0
-
-            # Lengths after masking.
-            lengths = torch.sum(mask.reshape(batch_size, -1), dim=1)
-
-            # Sample-wise losses (we do not consider masked tokens in this loss).
-            samplewise_label_loss = (
-                torch.sum(label_loss.reshape(batch_size, -1), dim=-1) / lengths
-            )
-            samplewise_smoothing_loss = (
-                torch.sum(smoothing_loss.reshape(batch_size, -1), dim=-1) / lengths
-            )
-
-            # Replace NaNs with zero (only happens with zero length samples).
-            samplewise_label_loss[torch.isnan(samplewise_label_loss)] = 0
-            samplewise_smoothing_loss[torch.isnan(samplewise_smoothing_loss)] = 0
-
-            # Update original loss to use non-masked samples.
-            label_loss = label_loss[mask]
-            smoothing_loss = smoothing_loss[mask]
-        else:
-            samplewise_label_loss = torch.zeros(batch_size, device=logits.device)
-            samplewise_smoothing_loss = torch.zeros(batch_size, device=logits.device)
-            label_loss = torch.zeros(mask.shape, device=logits.device)
-            smoothing_loss = torch.zeros(mask.shape, device=logits.device)
-
-        # If `reduce` is enabled, compute mean loss over sequence. Otherwise,
-        # revert values before masking.
-        label_loss = torch.mean(label_loss) if reduce else orig_label_loss
-        smoothing_loss = torch.mean(smoothing_loss) if reduce else orig_smoothing_loss
-
-        return (
-            samplewise_label_loss,
-            samplewise_smoothing_loss,
-            label_loss,
-            smoothing_loss,
-        )
-
-    def __call__(self, logits, targets, reduce=True, batch_size=None):
-        label_loss = self.label_loss_fn(logits, targets, reduce=False)
-        smoothing_loss = self.smoothing_loss_fn(logits, targets, reduce=False)
-
-        # Unless specified, batch_size is equal to the length of logits.
-        if batch_size is None:
-            batch_size = logits.shape[0]
-
-        # Extract sample-wise losses and reduce regular losses.
-        (
-            samplewise_label_loss,
-            samplewise_smoothing_loss,
-            label_loss,
-            smoothing_loss,
-        ) = self._reduce_mean(
-            logits=logits,
-            targets=targets,
-            batch_size=batch_size,
-            label_loss=label_loss,
-            smoothing_loss=smoothing_loss,
-            reduce=reduce,
-        )
-
-        # Set sample-wise tracking variables.
-        self.samplewise_label_loss = samplewise_label_loss
-        self.samplewise_smoothing_loss = samplewise_smoothing_loss
-        self.samplewise_total_loss = (
-            (samplewise_label_loss + self.beta * samplewise_smoothing_loss)
-            if samplewise_label_loss is not None
-            and samplewise_smoothing_loss is not None
-            else None
-        )
-
-        # Set tracking variables.
-        self.label_loss = label_loss
-        self.smoothing_loss = smoothing_loss
-
-        loss = label_loss + self.beta * smoothing_loss
-
-        return loss
-
-
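The `_reduce_mean` helper above is, at its core, a masked per-sample mean over token losses. A compact sketch of the same idea, assuming token-level losses of shape (batch, seq_len) and a boolean non-padding mask; the helper name is made up for illustration.

import torch

def samplewise_mean(token_loss, mask):
    # token_loss, mask: (batch, seq_len); mask is True for real (non-pad) tokens.
    token_loss = token_loss * mask           # zero out contributions from pad positions
    lengths = mask.sum(dim=1).clamp(min=1)   # avoid division by zero for all-pad rows
    return token_loss.sum(dim=1) / lengths   # one loss value per sample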
-class NARSequenceLoss(Loss):
-    """Joint loss over labels and length of sequences for non-autoregressive modeling."""
-
-    class Config(ConfigBase):
-        beta: float = 0.1
-        assert_valid_targets: bool = True
-        label_type: SourceType = SourceType.LOG_PROBS
-        length_type: SourceType = SourceType.LOG_PROBS
-        label_loss: LabelSmoothingLoss.Config = LabelSmoothingLoss.Config()
-        length_loss: LabelSmoothingLoss.Config = LabelSmoothingLoss.Config()
-
-    def __init__(self, config, ignore_index=1):
-        self.beta = config.beta
-        self.assert_valid_targets = config.assert_valid_targets
-        self.label_type = config.label_type
-        self.length_type = config.length_type
-
-        # Enforce loss constraints. Specifically, because `MaxMarginLoss` requires
-        # structured outputs, we can't use this as a length loss.
-        if isinstance(config.length_loss.label_loss, MaxMarginLoss):
-            raise ValueError("MaxMarginLoss can't be used as a length loss")
-
-        self.label_loss_fn = create_loss(config.label_loss, ignore_index=ignore_index)
-        self.length_loss_fn = create_loss(config.length_loss, ignore_index=ignore_index)
-
-    def __call__(
-        self,
-        label_logits,
-        label_targets,
-        length_logits,
-        length_targets,
-        reduce=True,
-    ):
-        """
-        label_logits: (B x T) x V_1
-        label_targets: (B x T)
-        length_logits: B x V_2
-        length_targets: B
-        """
-
-        label_logits = maybe_log_normalize(
-            logits=label_logits, logits_type=self.label_type, dim=-1
-        )
-        length_logits = maybe_log_normalize(
-            logits=length_logits, logits_type=self.length_type, dim=-1
-        )
-
-        max_supported_dim = length_logits.size(1)
-        length_targets = length_targets.unsqueeze(-1)  # (B x T) x 1
-
-        if self.assert_valid_targets:
-            if torch.any(length_targets >= max_supported_dim):
-                total_violations = str(
-                    length_targets[length_targets >= max_supported_dim]
-                    .flatten()
-                    .tolist()
-                )
-                raise RuntimeError(
-                    f"max_supported_dim: {max_supported_dim}, "
-                    f"total violations: {total_violations}"
-                )
-        else:
-            length_targets[length_targets >= max_supported_dim] = max_supported_dim - 1
-
-        label_loss = self.label_loss_fn(label_logits, label_targets, reduce)
-        length_loss = self.length_loss_fn(
-            length_logits, length_targets.squeeze(-1), reduce
-        )
-
-        loss = label_loss + self.beta * length_loss
-
-        return (
-            loss,
-            {
-                "label_loss": label_loss,
-                "length_loss": length_loss,
-                "label_label_loss": self.label_loss_fn.label_loss,
-                "label_smoothing_loss": self.label_loss_fn.smoothing_loss,
-                "length_label_loss": self.length_loss_fn.label_loss,
-                "length_smoothing_loss": self.length_loss_fn.smoothing_loss,
-            },
-        )
-
-
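NARSequenceLoss combines a token-label loss over flattened (B x T) logits with a per-sequence length-prediction loss, weighted by `beta`. A minimal sketch of that joint objective using plain cross-entropy for both terms; the shapes follow the docstring above, while the concrete sizes and variable names are made up.

import torch
import torch.nn.functional as F

B, T, V1, V2 = 4, 6, 10, 8  # batch, max length, label vocab, max predictable length
label_logits = torch.randn(B * T, V1)      # (B x T) x V_1
label_targets = torch.randint(V1, (B * T,))
length_logits = torch.randn(B, V2)         # B x V_2
length_targets = torch.randint(V2, (B,))

beta = 0.1
label_loss = F.cross_entropy(label_logits, label_targets)
length_loss = F.cross_entropy(length_logits, length_targets)
joint_loss = label_loss + beta * length_loss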
-class NARSamplewiseSequenceLoss(NARSequenceLoss):
-    """Non-autoregressive sequence loss with sample-wise logging."""
-
-    class Config(NARSequenceLoss.Config):
-        label_loss: SamplewiseLabelSmoothingLoss.Config = (
-            SamplewiseLabelSmoothingLoss.Config()
-        )
-        length_loss: SamplewiseLabelSmoothingLoss.Config = (
-            SamplewiseLabelSmoothingLoss.Config()
-        )
-
-    def __call__(
-        self,
-        label_logits,
-        label_targets,
-        length_logits,
-        length_targets,
-        reduce=True,
-    ):
-        """
-        label_logits: (B x T) x V_1
-        label_targets: (B x T)
-        length_logits: B x V_2
-        length_targets: B
-        """
-
-        label_logits = maybe_log_normalize(
-            logits=label_logits, logits_type=self.label_type, dim=-1
-        )
-        length_logits = maybe_log_normalize(
-            logits=length_logits, logits_type=self.length_type, dim=-1
-        )
-
-        max_length = int(torch.max(length_targets))
-        batch_size = label_logits.shape[0] // max_length
-        max_supported_dim = length_logits.size(1)
-        length_targets = length_targets.unsqueeze(-1)  # (B x T) x 1
-
-        if self.assert_valid_targets:
-            if torch.any(length_targets >= max_supported_dim):
-                total_violations = str(
-                    length_targets[length_targets >= max_supported_dim]
-                    .flatten()
-                    .tolist()
-                )
-                raise RuntimeError(
-                    f"max_supported_dim: {max_supported_dim}, "
-                    f"total violations: {total_violations}"
-                )
-        else:
-            length_targets[length_targets >= max_supported_dim] = max_supported_dim - 1
-
-        label_loss = self.label_loss_fn(label_logits, label_targets, reduce, batch_size)
-        length_loss = self.length_loss_fn(
-            length_logits, length_targets.squeeze(-1), reduce
-        )
-        loss = label_loss + self.beta * length_loss
-
-        # Log sample-wise losses.
-        samplewise_losses = {
-            "samplewise_label_loss": self.label_loss_fn.samplewise_total_loss,
-            "samplewise_length_loss": self.length_loss_fn.samplewise_total_loss,
-            "samplewise_label_label_loss": self.label_loss_fn.samplewise_label_loss,
-            "samplewise_label_smoothing_loss": self.label_loss_fn.samplewise_smoothing_loss,
-            "samplewise_length_label_loss": self.length_loss_fn.samplewise_label_loss,
-            "samplewise_length_smoothing_loss": self.length_loss_fn.samplewise_smoothing_loss,
-        }
-
-        return (
-            loss,
-            {
-                "label_loss": label_loss,
-                "length_loss": length_loss,
-                "label_label_loss": self.label_loss_fn.label_loss,
-                "label_smoothing_loss": self.label_loss_fn.smoothing_loss,
-                "length_label_loss": self.length_loss_fn.label_loss,
-                "length_smoothing_loss": self.length_loss_fn.smoothing_loss,
-                **samplewise_losses,
-            },
-        )
diff --git a/pytext/loss/regularizer.py b/pytext/loss/regularizer.py
deleted file mode 100644
index 2e23957d3..000000000
--- a/pytext/loss/regularizer.py
+++ /dev/null
@@ -1,148 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from pytext.config import ConfigBase
-
-from .loss import Loss
-
-
-class Regularizer(Loss):
-    """Generic regularization function to be added to a surrogate loss (e.g., cross-entropy)."""
-
-    def __init__(self, config, ignore_index=1):
-        self.ignore_index = ignore_index
-
-    def __call__(self, logits, targets, reduce=True):
-        raise NotImplementedError
-
-
-class UniformRegularizer(Regularizer):
-    """
-    Negative KL between the uniform and predicted distribution.
-    Defined as:
-
-    KL(U || P(Y|X)) = - sum_i U_i * log (P(Y_i | X) / U_i)
-                    = - sum_i U_i * log P(Y_i|X) + H[U]
-                    = - (1/n) * sum_i log P(Y_i | X) + H[U]
-
-    H[U] does not depend on X, thus it is omitted during optimization.
-    """
-
-    def __call__(self, logits, targets, reduce=True):
-        mask = targets.ne(self.ignore_index)
-
-        loss = -logits.mean(dim=1)
-
-        if reduce:
-            return (
-                loss[mask].mean()
-                if mask.any()
-                else torch.tensor(0.0, device=logits.device)
-            )
-
-        return loss
-
-
-class EntropyRegularizer(Regularizer):
-    """
-    Entropy of the predicted distribution. Defined as:
-    H[P(Y|X)] = - sum_i P(Y_i|X) * log P(Y_i|X)
-    """
-
-    def __call__(self, logits, targets, reduce=True):
-        mask = targets.ne(self.ignore_index)
-
-        loss = -torch.sum(logits * logits.exp(), dim=1)
-
-        if reduce:
-            return (
-                loss[mask].mean()
-                if mask.any()
-                else torch.tensor(0.0, device=logits.device)
-            )
-
-        return loss
-
-
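Both regularizers above are simple functions of the log-probabilities: UniformRegularizer is the mean negative log-probability (KL to the uniform distribution up to the constant H[U]) and EntropyRegularizer is -sum_i P(Y_i|X) log P(Y_i|X). A small illustration, assuming `log_probs` is already log-normalized over the label dimension and the shapes are toy values.

import torch
import torch.nn.functional as F

log_probs = F.log_softmax(torch.randn(3, 5), dim=-1)  # (num_tokens, num_labels)

uniform_term = -log_probs.mean(dim=-1)                     # KL(U || P) up to the constant H[U]
entropy_term = -(log_probs.exp() * log_probs).sum(dim=-1)  # H[P(Y|X)] per token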
-class AdaptiveRegularizer(Regularizer):
-    """
-    Adaptive variant of `UniformRegularizer` which learns the mix-in noise distribution.
-
-    Learning Better Structured Representations using Low-Rank Adaptive Label Smoothing
-    (Ghoshal+ 2021; https://openreview.net/pdf?id=5NsEIflpbSv)
-    """
-
-    class Config(ConfigBase):
-        # Controls the shape of the noise distribution. Larger values of `eta` result
-        # in a sharper, low-entropy distribution. Must be >= 0.
-        eta: float = 0.1
-        # `label_embedding_dim` and `label_embedding_dropout` control the dimension
-        # and regularization, respectively, of the adaptive label embedding matrix.
-        label_embedding_dim: int = 20
-        label_embedding_dropout: float = 0.4
-
-    def __init__(self, config, ignore_index=1):
-        super().__init__(config, ignore_index)
-
-        if config.eta < 0:
-            raise ValueError("eta must be >= 0")
-        if config.label_embedding_dropout < 0 or config.label_embedding_dropout >= 1:
-            raise ValueError("label_embedding_dropout must be [0, 1)")
-
-        self.eta = config.eta
-        self.label_embedding_dim = config.label_embedding_dim
-        self.label_embedding_dropout = config.label_embedding_dropout
-        self.label_embedding = None
-
-    def compute_adaptive_loss(self, logits, targets, label_embedding):
-        """
-        Using Equation 3 and 4, computes several terms of the adaptive penalty.
-        Specifically, we implement adaptive smoothing (`smooth_term`) and
-        an entropy constraint (`eta_term`).
-        """
-
-        if targets.dim() == logits.dim() - 1:
-            targets = targets.unsqueeze(-1)
-
-        U = torch.mm(
-            torch.index_select(label_embedding, 0, targets.squeeze(-1)),
-            label_embedding.T,
-        )
-        V = F.softmax(U.float(), dim=-1).to(logits.dtype)
-
-        smooth_term = -torch.bmm(V.unsqueeze(1), logits.unsqueeze(2)).squeeze(2)
-        eta_term = -self.eta * (
-            -torch.bmm(U.unsqueeze(1), V.unsqueeze(2)).mean()
-            + torch.logsumexp(U, axis=-1).mean()
-        )
-        loss = smooth_term + eta_term
-
-        return loss
-
-    def __call__(self, logits, targets, reduce=True):
-        mask = targets.ne(self.ignore_index)
-
-        if self.label_embedding is None:
-            # Initialize label embedding matrix to ones.
-            num_labels = logits.shape[1]
-            self.label_embedding = nn.Parameter(
-                torch.ones(num_labels, self.label_embedding_dim),
-                requires_grad=True,
-            ).to(device=logits.device, dtype=logits.dtype)
-
-        loss = self.compute_adaptive_loss(
-            logits,
-            targets,
-            F.dropout(self.label_embedding, self.label_embedding_dropout),
-        )
-
-        if reduce:
-            return (
-                loss[mask].mean()
-                if mask.any()
-                else torch.tensor(0.0, device=logits.device)
-            )
-
-        return loss
diff --git a/pytext/loss/tests/samplewise_label_smoothing_loss_test.py b/pytext/loss/tests/samplewise_label_smoothing_loss_test.py
deleted file mode 100644
index 963c636ff..000000000
--- a/pytext/loss/tests/samplewise_label_smoothing_loss_test.py
+++ /dev/null
@@ -1,37 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
-
-import unittest
-
-import torch
-import torch.nn.functional as F
-from pytext.loss import LabelSmoothingLoss, SamplewiseLabelSmoothingLoss
-
-
-class SamplewiseLabelSmoothingLossTest(unittest.TestCase):
-    def test_samplewise_label_smoothing_loss(self):
-        batch_size = 5
-        num_labels = 5
-
-        label_smoothing_loss = LabelSmoothingLoss(
-            LabelSmoothingLoss.Config(), ignore_index=-1
-        )
-        samplewise_label_smoothing_loss = SamplewiseLabelSmoothingLoss(
-            SamplewiseLabelSmoothingLoss.Config(), ignore_index=-1
-        )
-
-        logits = F.log_softmax(torch.rand(batch_size, num_labels), 1)
-        targets = torch.randint(batch_size, (num_labels,))
-
-        self.assertTrue(
-            torch.isclose(
-                label_smoothing_loss(logits, targets, reduce=True),
-                samplewise_label_smoothing_loss(logits, targets, reduce=True),
-            )
-        )
-        self.assertTrue(
-            torch.isclose(
-                label_smoothing_loss(logits, targets, reduce=False),
-                samplewise_label_smoothing_loss(logits, targets, reduce=False),
-            ).all()
-        )
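For reference, the adaptive penalty in AdaptiveRegularizer smooths each prediction toward a learned, target-conditioned noise distribution V = softmax(U) with U = E[y] E^T, plus an entropy constraint scaled by `eta`. A toy illustration of those two terms; the embedding matrix `E`, the targets, and all sizes here are made-up values, not the deleted module's state.

import torch
import torch.nn.functional as F

num_labels, emb_dim, eta = 4, 3, 0.1
E = torch.randn(num_labels, emb_dim)                        # stand-in for the learned label embeddings
log_probs = F.log_softmax(torch.randn(2, num_labels), -1)   # model log-probabilities, one row per token
targets = torch.tensor([1, 3])

U = E[targets] @ E.T                          # similarity of each gold label to every label
V = F.softmax(U, dim=-1)                      # learned mix-in (noise) distribution per token
smooth_term = -(V * log_probs).sum(dim=-1)    # adaptive smoothing: cross-entropy of predictions under V
entropy_V = torch.logsumexp(U, dim=-1) - (U * V).sum(dim=-1)  # H[V], written via logsumexp
eta_term = -eta * entropy_V.mean()            # entropy constraint on the noise distribution
penalty = smooth_term + eta_term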