diff --git a/projects/params_vs_compute/README.md b/projects/params_vs_compute/README.md
new file mode 100644
index 00000000000..50a74cd2592
--- /dev/null
+++ b/projects/params_vs_compute/README.md
@@ -0,0 +1,46 @@
+# Which one is more important: more parameters or more computation?
+
+When we talk about the power of a deep learning model, often the only metric we pay attention to is its size, measured by the number of parameters in the model. The amount of computation needed to run the model is an equally important metric, but it is often overlooked because it is usually tied to the model size, and practitioners therefore tend to treat the two metrics as one. This is true most of the time, as each parameter typically participates in computation only once per input: a model with 1 million parameters takes on the order of 1 million floating point operations to process an input. This applies to feedforward models, recurrent models, and even Transformers.
+
+We are announcing two new methods that together help study this question further -- and show that the computation of a model should be considered separately from its size. First, we can increase the model size without using more computation, and improve its performance: the first paper achieves this with a simple, elegant method called hash layers. The second paper shows that the opposite is also true: we can increase the amount of computation without adding any new parameters, which can also improve performance significantly. It proposes a new family of staircase attention models that achieves this feat. Taken together, we believe these results open up a new way of thinking about deep learning models, requiring us to disentangle the concepts of parameters and computation. Thinking in this way, we believe we can arrive at more powerful models that are architected with regard to the resources available.
+
+## Hash Layers
+

+
+
+In recent years, a trend has emerged of making Transformer models bigger and bigger as a way of achieving impressive results on language tasks. The number of parameters in those models extends to billions, and even a trillion. While this shows the potential of deep learning, bigger models require more computation, which makes them less practical.
+
+One way to make big models use less computation is a sparse mixture-of-experts (MoE) approach. Each expert has its own parameters, which are used for only a small part of the input. Each input is routed to only some of the experts, meaning only some of the parameters need to be used, resulting in less computation. Indeed, recent work has shown that Transformers can be made bigger efficiently this way. The key element of MoE is a router that decides which expert to use on which data.
+
+In [our paper](https://arxiv.org/abs/2106.04426), we propose a routing mechanism based on hashing of the input tokens. Unlike previous approaches, the hashing MoE is much simpler, as it requires no learning and no change to the objective function. Each word in the dictionary is simply assigned to a fixed expert, either chosen at random or assigned such that the distribution is balanced. Despite its simplicity, the method works well on a number of challenging tasks in language and dialogue.
+
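+To make the routing concrete, here is a minimal sketch of the idea in plain PyTorch. It is only an illustration (the names and sizes are ours, not the released implementation): every vocabulary ID is mapped once, up front, to a fixed expert, so there are no router parameters and nothing to learn.
+
+```python
+import random
+
+import torch
+
+vocab_size, n_experts = 1000, 8
+
+# Fixed random token -> expert assignment, chosen once and never trained.
+rng = random.Random(42)
+token_to_expert = torch.tensor([rng.randrange(n_experts) for _ in range(vocab_size)])
+
+
+def route(token_ids: torch.Tensor) -> torch.Tensor:
+    """Return the expert index for each token in a [batch, seq] tensor of IDs."""
+    return token_to_expert[token_ids]
+
+
+batch = torch.randint(vocab_size, (2, 6))
+print(route(batch))  # the same token always lands in the same expert
+```
+
+In the full model, the expert chosen for a token determines which feedforward block processes that token's hidden state inside the Transformer layer.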

+
+On the pushshift.io Reddit language modeling task, our hashing mechanism outperforms the learning-based Switch baseline, especially when there are more experts. The largest models here have 1.28 billion parameters, but only 17% of them are used for any particular input. We go further by training 4.5 billion parameter models on larger data, where we see that hashing outperforms another competitive sparse MoE model, BASE. The natural balancing of the expert assignment also means that training is efficient and scalable across a cluster compared to these existing approaches. In our experiments this gives an improvement of about 11% in updates per second over BASE, and we expect this difference to become more pronounced as the number of expert layers increases.
+
+## Staircase Attention
+

+
+While adding more parameters to Transformers for better performance is a popular topic of study, increasing their computation is underexplored. One reason is that the standard Transformer ties computation and parameters together through its architecture, making it impossible to increase one without the other. In [our paper](https://arxiv.org/abs/2106.04279), we introduce an alternative family of architectures that decouples these concepts, and show that adding more computation is an alternate route to improving performance. In particular, we propose a family of models built from recurrent applications of Transformers, called Staircase and Ladder models.
+
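+Before going into the two variants, the toy snippet below illustrates the basic pattern they are built on: the same Transformer parameters are applied repeatedly to the input. This is only a sketch using a stock PyTorch layer, not the ParlAI implementation.
+
+```python
+import torch
+import torch.nn as nn
+
+# One Transformer layer applied n_repeats times: computation grows with
+# n_repeats while the parameter count stays exactly the same.
+layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
+n_repeats = 4
+
+x = torch.randn(2, 10, 64)  # (batch, seq, dim)
+for _ in range(n_repeats):
+    x = layer(x)  # every pass reuses the same weights
+
+print(sum(p.numel() for p in layer.parameters()))  # unchanged by n_repeats
+```
+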

+
+The Ladder model simply stacks the same Transformer multiple times. This means each parameter in the Transformer participates in the computation multiple times, increasing the amount of computation while keeping the model size fixed. This straightforward modification brings a significant performance improvement on real-world tasks such as language modeling and dialogue. Furthermore, it indicates that increasing computation -- and thus adding more power per parameter -- is a compelling research direction for better performance.
+
+The Staircase model stacks Transformers like Ladder does, but shifts each Transformer several time steps forward. This makes it possible to keep stacking Transformers for as long as inputs keep arriving, forming a model shaped like a staircase. Unlike standard Transformers, this continuation makes Staircase recurrent in time, which is crucial for maintaining an internal state that tracks changes. On simple constructed tasks where the model just needs to maintain an internal state and update it with incoming information, feedforward models like the Transformer and Ladder struggle, but Staircase solves them with ease. In addition, Staircase models enjoy the same performance boost as Ladder models on language modeling tasks, because they too have more compute per parameter.
+
+## Why not both?
+
+A natural question after introducing these two methods is -- can we combine them? The answer is -- yes! The improvements gained from the two approaches appear to be orthogonal, and we observe significant gains from a Hash Layer + Ladder model compared to either alone. Taken together, the two methods give fine-grained control over both the number of parameters and the amount of computation, and that control is what leads to these improvements.
+
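+As a rough illustration of how the two knobs compose, the sketch below puts the earlier ideas together: one block in the stack routes each token to one of several feedforward experts by a fixed hash of its token ID (more parameters, roughly the same compute per token), and the whole stack is applied several times with shared weights (more compute, the same parameters). This is a simplified stand-in for the released `HashLadderAgent`; the class and argument names here are illustrative.
+
+```python
+import torch
+import torch.nn as nn
+
+
+class HashRoutedFFN(nn.Module):
+    """One feedforward 'expert' per hash bin; each token uses exactly one expert."""
+
+    def __init__(self, vocab_size, dim, hidden, n_experts):
+        super().__init__()
+        self.experts = nn.ModuleList(
+            [nn.Sequential(nn.Linear(dim, hidden), nn.ReLU(), nn.Linear(hidden, dim))
+             for _ in range(n_experts)]
+        )
+        # Fixed random token -> expert assignment, never trained.
+        self.register_buffer("bins", torch.randint(n_experts, (vocab_size,)))
+
+    def forward(self, x, token_ids):
+        out = torch.zeros_like(x)
+        for i, expert in enumerate(self.experts):
+            mask = self.bins[token_ids] == i
+            if mask.any():
+                out[mask] = expert(x[mask])
+        return x + out  # residual connection
+
+
+class HashLadderSketch(nn.Module):
+    """A small layer stack applied ladder_size times; its FFN is hash-routed."""
+
+    def __init__(self, vocab_size=1000, dim=64, hidden=128, n_experts=8, ladder_size=3):
+        super().__init__()
+        self.embed = nn.Embedding(vocab_size, dim)
+        self.attn = nn.TransformerEncoderLayer(dim, nhead=4, batch_first=True)
+        self.hash_ffn = HashRoutedFFN(vocab_size, dim, hidden, n_experts)
+        self.ladder_size = ladder_size
+
+    def forward(self, token_ids):
+        x = self.embed(token_ids)
+        for _ in range(self.ladder_size):    # more compute, same parameters
+            x = self.attn(x)
+            x = self.hash_ffn(x, token_ids)  # more parameters, same compute per token
+        return x
+
+
+model = HashLadderSketch()
+print(model(torch.randint(1000, (2, 16))).shape)  # torch.Size([2, 16, 64])
+```
+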

+
+In summary, our work has examined the issue of computation vs. parameter size, and shown that these two concepts should be treated quite differently when designing new methods -- rather than tied together as they are in many standard machine learning models. In particular, we present two new types of architecture that explore this tradeoff -- one increasing the number of parameters, the other the amount of computation -- and show how their ideas can be combined. We believe this way of thinking, and the use of our new methods in particular, can be a fruitful way forward for machine learning research.
+
+For more details, read the [Hash Layers](https://arxiv.org/abs/2106.04426)
+and [Staircase Attention](https://arxiv.org/abs/2106.04279) papers.
+
+Code is available [here](https://github.com/facebookresearch/ParlAI/tree/master/projects/params_vs_compute/hash_ladder).
diff --git a/projects/params_vs_compute/figs/hash.png b/projects/params_vs_compute/figs/hash.png
new file mode 100644
index 00000000000..79e20ab5b3c
Binary files /dev/null and b/projects/params_vs_compute/figs/hash.png differ
diff --git a/projects/params_vs_compute/figs/hash_ladder_results.png b/projects/params_vs_compute/figs/hash_ladder_results.png
new file mode 100644
index 00000000000..3ca8109c067
Binary files /dev/null and b/projects/params_vs_compute/figs/hash_ladder_results.png differ
diff --git a/projects/params_vs_compute/figs/hash_results.png b/projects/params_vs_compute/figs/hash_results.png
new file mode 100644
index 00000000000..dbe3a654208
Binary files /dev/null and b/projects/params_vs_compute/figs/hash_results.png differ
diff --git a/projects/params_vs_compute/figs/staircase.png b/projects/params_vs_compute/figs/staircase.png
new file mode 100644
index 00000000000..9ae69e7a73e
Binary files /dev/null and b/projects/params_vs_compute/figs/staircase.png differ
diff --git a/projects/params_vs_compute/figs/staircase_results.png b/projects/params_vs_compute/figs/staircase_results.png
new file mode 100644
index 00000000000..a948a37b63b
Binary files /dev/null and b/projects/params_vs_compute/figs/staircase_results.png differ
diff --git a/projects/params_vs_compute/hash_ladder/hash_ladder.py b/projects/params_vs_compute/hash_ladder/hash_ladder.py
new file mode 100644
index 00000000000..ed5de6dc988
--- /dev/null
+++ b/projects/params_vs_compute/hash_ladder/hash_ladder.py
@@ -0,0 +1,460 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
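+#
+# Hash Layer + Ladder decoder for ParlAI (see projects/params_vs_compute/README.md).
+# One decoder layer (--hash-layer) replaces its feedforward block with --hash-size
+# experts selected by a fixed hash of each token ID, and the whole decoder stack
+# is applied --ladder-size times with shared parameters.
+# Papers: https://arxiv.org/abs/2106.04426 (Hash Layers) and
+# https://arxiv.org/abs/2106.04279 (Staircase Attention).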
+ +from __future__ import annotations +import torch +from typing import Dict, Optional, Tuple +import torch.nn as nn + +from parlai.agents.transformer.modules import ( + TransformerDecoder, + TransformerDecoderLayer, + TransformerGeneratorModel, +) + +from parlai.agents.transformer.modules import ( + create_position_codes, + get_n_positions_from_options, + LAYER_NORM_EPS, +) + +from parlai.agents.transformer.transformer import TransformerGeneratorAgent +from parlai.core.opt import Opt +from parlai.core.params import ParlaiParser +from parlai.utils.misc import warn_once +import torch.nn.functional as F + +from torch.nn import LayerNorm + +########################################### +# Hash Ladder Transformer # +########################################### + +""" +Use with, e.g.: + +parlai train_model -m projects.params_vs_compute.hash_ladder.hash_ladder:HashLadderAgent -t convai2:normalized -mf /tmp/model_file --ladder-size 1 --hash-size 32 --hash-layer 1 +""" + + +class HashLadderAgent(TransformerGeneratorAgent): + """ + Simple implementation of Hash Layers and the Ladder model from the following papers: + """ + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + TransformerGeneratorAgent.add_cmdline_args(parser, partial_opt=partial_opt) + # Add transformer args. + parser.add_argument( + '--ladder-size', + type=int, + default=1, + help='Number of ladder steps, default is not to use ladder.', + ) + parser.add_argument( + '--hash-size', type=int, default=32, help='Number of hash bins.' + ) + parser.add_argument( + '--hash-layer', + type=int, + default=7, + help='Layer number the Hash Layer appears on.', + ) + return parser + + def build_model(self, states=None): + wrapped_class = TransformerGeneratorModel.with_components(decoder=Decoder) + return wrapped_class(self.opt, self.dict) + + +def _normalize(tensor, norm_layer): + """ + Broadcast layer norm. + """ + is_cpu = tensor.device == 'cpu' or tensor.device.type == 'cpu' + return norm_layer(tensor) + + +class Decoder(TransformerDecoder): + """ + Custom Decoder with Ladder model. 
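+
+    The decoder layer stack is applied --ladder-size times over the input,
+    re-using the same parameters on each pass, and the layer at index
+    --hash-layer is replaced by a HashLayer whose feedforward block is
+    routed by a hash of the input token IDs.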
+ """ + + def __init__( + self, + opt: Opt, + embedding: Optional[nn.Embedding] = None, + n_positions: Optional[int] = None, + **kwargs, + ): + super().__init__(opt, **kwargs) + + def _default(val, default): + return val if val is not None else default + + opt['dict_size'] = embedding.weight.size(0) + self.opt = opt + self.embedding_size = opt['embedding_size'] + self.ffn_size = opt['ffn_size'] + self.n_layers = ( + opt['n_decoder_layers'] + if opt.get('n_decoder_layers', -1) > 0 + else opt['n_layers'] + ) + self.n_heads = opt['n_heads'] + self.dim = self.embedding_size + self.activation = opt.get('activation', 'relu') + self.variant = opt.get('variant', 'aiayn') + + self.embeddings_scale = opt.get('embeddings_scale', True) + dropout_frac = opt.get('dropout', 0.0) + self.dropout = nn.Dropout(p=dropout_frac) # --dropout + + self.n_positions = _default(n_positions, get_n_positions_from_options(opt)) + self.out_dim = self.embedding_size + assert ( + self.embedding_size % self.n_heads == 0 + ), 'Transformer embedding size must be a multiple of n_heads' + + self.embeddings = embedding + + if ( + self.variant == 'xlm' + or self.variant == 'prelayernorm' + or self.variant == 'bart' + ): + self.norm_embeddings = torch.nn.LayerNorm(self.dim, eps=LAYER_NORM_EPS) + if self.variant == 'xlm': + warn_once( + 'DEPRECATED: XLM should only be used for backwards compatibility, ' + 'as it involves a less-stable layernorm operation.' + ) + elif self.variant == 'aiayn': + pass + else: + raise ValueError("Can't handle --variant {}".format(self.variant)) + + # create the positional embeddings + self.position_embeddings = nn.Embedding(self.n_positions, self.embedding_size) + if not opt.get('learn_positional_embeddings', False): + create_position_codes( + self.n_positions, + self.embedding_size, + out=self.position_embeddings.weight, + ) + else: + nn.init.normal_( + self.position_embeddings.weight, 0, self.embedding_size ** -0.5 + ) + + # build the model + self.layers = nn.ModuleList() + for i in range(self.n_layers): + if self.opt['hash_layer'] == i: + self.layers.append( + HashLayer( + self.n_heads, + self.embedding_size, + self.ffn_size, + attention_dropout=opt.get('attention_dropout', 0.0), + relu_dropout=opt.get('relu_dropout', 0.0), + dropout=dropout_frac, + activation=self.activation, + variant=self.variant, + opt=self.opt, + ) # type: ignore + ) + else: + self.layers.append( + self.swappables.layer( + self.n_heads, + self.embedding_size, + self.ffn_size, + attention_dropout=opt.get('attention_dropout', 0.0), + relu_dropout=opt.get('relu_dropout', 0.0), + dropout=dropout_frac, + activation=self.activation, + variant=self.variant, + ) # type: ignore + ) + + def forward_layers( + self, + tensor: torch.Tensor, + encoder_output: torch.Tensor, + encoder_mask: torch.Tensor, + incr_state: Dict[int, Dict[str, Dict[str, torch.Tensor]]], + original_input: torch.Tensor, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Forward pass of decoder layers. + + :param tensor: + embedded input tensor for the decoder + :param enc_out: + encoder outputs + :param enc_mask: + encoder output mask + :param incr_state: + Dict mapping layer_idx to incremental state + + :return (tensor, new_incr_state): + return encoding after applying decoder layers, as well + as new incremental decoding state. 
+ """ + new_incr_state = {} + if getattr(self.layers, 'is_model_parallel', False): + tensor, new_incr_state = self._apply_model_parallel( + tensor, encoder_output, encoder_mask, incr_state + ) + else: + for _s in range(0, self.opt['ladder_size']): + for idx, layer in enumerate(self.layers): + if idx == self.opt['hash_layer']: + tensor, new_incr_state[idx] = layer( + x=tensor, + encoder_output=encoder_output, + encoder_mask=encoder_mask, + incr_state=incr_state.get(idx), + orig_input=original_input, + ) + else: + tensor, new_incr_state[idx] = layer( + x=tensor, + encoder_output=encoder_output, + encoder_mask=encoder_mask, + incr_state=incr_state.get(idx), + ) + + return tensor, new_incr_state + + def forward(self, input, encoder_state, incr_state=None): + """ + Forward pass. + + :param LongTensor[batch,seqlen] input: + The decoder inputs (partial or full decoded token IDs). + :param encoder_state: + Output from the encoder module forward pass. + :param incr_state: + The incremental state: a dictionary whose keys index the layers and whose + values contain the incremental state for each layer. + """ + encoder_output, encoder_mask = encoder_state + + seq_len = input.size(1) + positions = input.new(seq_len).long() + positions = torch.arange(seq_len, out=positions).unsqueeze(0) + + if incr_state is not None: + # We're doing incremental decoding, so select only the most recent position + input = input[:, -1:] + if positions is not None: + positions = positions[:, -1:] + else: + incr_state = {} + + tensor = self.forward_embedding(input, positions) + + tensor = self.dropout(tensor) # --dropout + + tensor, new_incr_state = self.forward_layers( + tensor, encoder_output, encoder_mask, incr_state, original_input=input + ) + + if self.variant == 'prelayernorm': + tensor = _normalize(tensor, self.norm_embeddings) + + return tensor, new_incr_state + + +class HashLayer(TransformerDecoderLayer): + def __init__( + self, + n_heads: int, + embedding_size: int, + ffn_size: int, + opt: Opt, + attention_dropout: float = 0.0, + relu_dropout: float = 0.0, + dropout: float = 0.0, + activation: str = 'relu', + variant: str = 'aiayn', + **kwargs, + ): + super().__init__(n_heads, embedding_size, ffn_size, **kwargs) + self.dim = embedding_size + self.ffn_dim = ffn_size + self.variant = variant + self.activation = activation + self.dropout = nn.Dropout(p=dropout) + + self.self_attention = self.swappables.self_attention( + n_heads, embedding_size, dropout=attention_dropout + ) # type: ignore + self.norm1 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) + + self.encoder_attention = self.swappables.encoder_attention( + n_heads, embedding_size, dropout=attention_dropout + ) # type: ignore + self.norm2 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) + + self.ffn = HashLayerFFN( + opt, + embedding_size, + ffn_size, + relu_dropout=relu_dropout, + activation=activation, + ) # type: ignore + self.norm3 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) + + def forward( + self, x, encoder_output, encoder_mask, incr_state=None, orig_input=None + ): + """ + Forward pass. + + The incremental state is a dict with values for self- and encoder-attention + states. + """ + + if incr_state is None: + incr_state = {} + + decoder_mask = self._create_selfattn_mask(x) + # first self attn + residual = x + if self.variant == 'prelayernorm': + x = _normalize(x, self.norm1) + + # don't peak into the future! 
+ x, final_self_attn_incr_state = self.self_attention( + query=x, + mask=decoder_mask, + incr_state=incr_state.get('self_attn'), + static_kv=False, + )[:2] + x = self.dropout(x) # --dropout + x = x + residual + if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart': + x = _normalize(x, self.norm1) + + residual = x + # encoder_attn_layer_norm norm 2 + if self.variant == 'prelayernorm': + x = _normalize(x, self.norm2) + x, final_encoder_attn_incr_state, dotprod = self.encoder_attention( + query=x, + key=encoder_output, + value=encoder_output, + mask=encoder_mask, + incr_state=incr_state.get('encoder_attn'), + static_kv=True, + ) + x = self.dropout(x) # --dropout + x = residual + x + if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart': + x = _normalize(x, self.norm2) + + # finally the ffn + residual = x + if self.variant == 'prelayernorm': + x = _normalize(x, self.norm3) + x = self.ffn(x, orig_input) + x = self.dropout(x) # --dropout + x = residual + x + if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart': + x = _normalize(x, self.norm3) + + new_incr_state = { + 'self_attn': final_self_attn_incr_state, + 'encoder_attn': final_encoder_attn_incr_state, + } + + self.output = x + + return x, new_incr_state + + +class HashLayerFFN(nn.Module): + """ + Implements the Hash Layer FFN. + """ + + def __init__(self, opt, dim, dim_hidden, relu_dropout=0, activation='relu'): + super(HashLayerFFN, self).__init__() + self.relu_dropout = nn.Dropout(p=relu_dropout) + self.nonlinear = F.relu + self.opt = opt + self.dim = dim + self.dim_hidden = dim_hidden + self.hashsize = opt['hash_size'] + + linears1 = [] + linears2 = [] + norms = [] + + for i in range(0, self.hashsize): + linears1.append(nn.Linear(dim, dim_hidden)) + nn.init.xavier_uniform_(linears1[i].weight) + for i in range(0, self.hashsize): + linears2.append(nn.Linear(dim_hidden, dim)) + nn.init.xavier_uniform_(linears2[i].weight) + + embedding_size = self.opt['embedding_size'] + norms.append(LayerNorm(embedding_size, eps=LAYER_NORM_EPS)) + + self.linears1 = nn.ModuleList(linears1) + self.linears2 = nn.ModuleList(linears2) + self.norms = nn.ModuleList(norms) + + self.alter_tok = -1 + self.alter_bin = -1 + + def hash(self, xi): + # Insert your choice of hash function here. + # In this code we simply randomly hash based on the given token IDs for simplicity. + if not hasattr(self, 'hash_bin_map'): + # create random mapping. + sz = self.opt['dict_size'] + self.hash_bin_map = torch.LongTensor(sz).fill_(0) + import random + + random.seed(42) + for i in range(sz): + self.hash_bin_map[i] = random.randrange(0, self.hashsize) + + # Now compute the hash bins given the mapping function (Whatever it is). + return self.hash_bin_map[xi] + + def forward(self, x, orig_input): + """ + Forward pass. + """ + xhs = self.hash(orig_input) + + # Now do the real work. + # This implementation could be more efficient -- but it works. + index_list = [ + torch.eq(xhs, i).nonzero(as_tuple=True) for i in range(self.hashsize) + ] + final_output = x.new_zeros(x.shape) + + for i in range(self.hashsize): + vecs = x[index_list[i][0], index_list[i][1], :] + if vecs.shape[0] > 0: + residual = vecs + x1 = self.linears1[i](vecs) + x1 = self.nonlinear(x1) + x1 = self.relu_dropout(x1) # --relu-dropout + x1 = self.linears2[i](x1) + x1 = residual + x1 + x1 = _normalize(x1, self.norms[0]) + final_output[index_list[i][0], index_list[i][1], :] = x1 + + return final_output