diff --git a/eole/bin/convert/convert_HF.py b/eole/bin/convert/convert_HF.py index c66fa1c3..6ad47d29 100755 --- a/eole/bin/convert/convert_HF.py +++ b/eole/bin/convert/convert_HF.py @@ -461,16 +461,27 @@ def run(cls, args): norm_eps = config["layer_norm_eps"] else: norm_eps = 1e-6 + rope_config = {} if "rope_theta" in config.keys(): - rope_theta = config["rope_theta"] - else: - rope_theta = 1e4 + rope_config["rotary_theta"] = config["rope_theta"] if "rotary_dim" in config.keys(): - rotary_dim = config["rotary_dim"] + rope_config["rotary_dim"] = config["rotary_dim"] elif "partial_rotary_factor" in config.keys(): - rotary_dim = int(config["partial_rotary_factor"] * (hidden_size // heads)) - else: - rotary_dim = 0 + rope_config["rotary_dim"] = int( + config["partial_rotary_factor"] * (hidden_size // heads) + ) + if config.get("rope_scaling", None) is not None: + rope_config["scaling_type"] = config["rope_scaling"].get("rope_type", None) + rope_config["scaling_factor"] = config["rope_scaling"].get("factor", 8.0) + rope_config["low_freq_factor"] = config["rope_scaling"].get( + "low_freq_factor", 1.0 + ) + rope_config["high_freq_factor"] = config["rope_scaling"].get( + "high_freq_factor", 4.0 + ) + rope_config["original_max_position_embeddings"] = config[ + "rope_scaling" + ].get("original_max_position_embeddings", 8192) if "sliding_window" in config.keys(): sliding_window = config["sliding_window"] if sliding_window is None: @@ -541,8 +552,8 @@ def run(cls, args): add_qkvbias = False add_ffnbias = False - rotary_interleave = False shared_layer_norm = False + rope_config["rotary_interleave"] = False position_encoding = { "position_encoding_type": "Rotary", "n_positions": 0, @@ -555,7 +566,7 @@ def run(cls, args): shared_layer_norm = True add_qkvbias = True add_ffnbias = True - rotary_interleave = False + rope_config["rotary_interleave"] = False if arch == "GPT2LMHeadModel": parallel_residual = False shared_layer_norm = True @@ -1003,9 +1014,7 @@ def get_weight(checkpoint, tensor_name): layer_norm=layer_norm, norm_eps=norm_eps, mlp_activation_fn=mlp_activation_fn, - rotary_interleave=rotary_interleave, - rotary_theta=rope_theta, - rotary_dim=rotary_dim, + rope_config=rope_config, sliding_window=sliding_window, heads_kv=heads_kv, parallel_residual=parallel_residual, diff --git a/eole/config/models.py b/eole/config/models.py index fd39fb88..e899dcdf 100644 --- a/eole/config/models.py +++ b/eole/config/models.py @@ -156,6 +156,47 @@ class MeanEncoderConfig(EncoderConfig): encoder_type: Literal["mean"] = Field(default="mean") +class RotaryPositionConfig(Config): + """ + Configuration for rotary position embeddings used in transformer models. + """ + + rotary_interleave: bool = Field( + default=True, + description="Interleave the head dimensions when rotary embeddings are applied. " + "Otherwise the head dimensions are sliced in half. " + "(True=default Llama from Meta (original), " + "False= used by all HuggingFace models)", + ) + rotary_theta: int = Field( + default=10000, + description="Rotary theta base length, 1e4 for Llama2.Mistral, 1e6 for Mixtral", + ) + rotary_dim: int = Field( + default=0, + description="Rotary dim when model requires it to be different to head dim.", + ) + scaling_type: str | None = Field( + default=None, + description="Specifies the type of RoPE scaling to be applied, if any.", + ) + scaling_factor: float | None = Field( + default=8.0, description="Factor by which to scale RoPE embeddings." + ) + low_freq_factor: float | None = Field( + default=1.0, + description="Scaling factor applied to the lower frequency components of RoPE.", + ) + high_freq_factor: float | None = Field( + default=4.0, + description="Scaling factor applied to the higher frequency components of RoPE.", + ) + original_max_position_embeddings: int | None = Field( + default=8192, + description="Original maximum position embeddings for RoPE scaling.", + ) + + class TransformerConfig(Config): """ This base TransformerConfig class regroups parameters than can @@ -182,21 +223,6 @@ class TransformerConfig(Config): default=ActivationFunction.relu, description="The activation function to use in MLP layer.", ) - rotary_interleave: bool = Field( - default=True, - description="Interleave the head dimensions when rotary embeddings are applied. " - "Otherwise the head dimensions are sliced in half. " - "(True=default Llama from Meta (original), " - "False= used by all HuggingFace models)", - ) - rotary_theta: int = Field( - default=10000, - description="Rotary theta base length, 1e4 for Llama2.Mistral, 1e6 for Mixtral", - ) - rotary_dim: int = Field( - default=0, - description="Rotary dim when model requires it to be different to head dim.", - ) layer_norm: Literal["standard", "rms"] = Field( default="standard", description="Type of layer normalization in transformer architecture.", @@ -244,6 +270,16 @@ class TransformerConfig(Config): "Case 2: Max Relative Positions" "In the case of position_encoding_type: Relative", ) + rope_config: RotaryPositionConfig | None = Field( + default=None, description="Rotary position config, if relevant." + ) + + @model_validator(mode="after") + def _validate_transformer_config(self): + if self.position_encoding_type == PositionEncodingType.Rotary: + if self.rope_config is None: + self.rope_config = RotaryPositionConfig() + return self # could eole.encoders.TransformerEncoder class inherit from this? (it seems not unfortunately) diff --git a/eole/decoders/transformer_decoder.py b/eole/decoders/transformer_decoder.py index 7a4ecc45..452d926f 100644 --- a/eole/decoders/transformer_decoder.py +++ b/eole/decoders/transformer_decoder.py @@ -10,7 +10,8 @@ TransformerDecoderBase, ) from eole.modules.multi_headed_attn import ContextMHA -from eole.constants import LayerNorm +from eole.constants import LayerNorm, PositionEncodingType +from eole.modules.rope import RotaryPosition class TransformerDecoderLayer(TransformerDecoderLayerBase): @@ -57,6 +58,7 @@ def _forward( step=None, future=False, return_attn=False, + position_embeddings=None, ): """A naive forward pass for transformer decoder. @@ -70,6 +72,7 @@ def _forward( step (int or None): stepwise decoding counter future (bool): If set True, do not apply future_mask. return_attn (bool) : if set True requires attns output + position_embeddings (FloatTensor): rotary position encodings, if any Returns: (FloatTensor, FloatTensor): @@ -98,6 +101,7 @@ def _forward( sliding_window=self.sliding_window, step=step, return_attn=return_attn, + position_embeddings=position_embeddings, ) if self.dropout_p > 0: @@ -149,6 +153,9 @@ def __init__( model_config, running_config=running_config ) + if model_config.position_encoding_type == PositionEncodingType.Rotary: + self.rope = RotaryPosition(model_config) + self.transformer_layers = nn.ModuleList( [ TransformerDecoderLayer( @@ -191,6 +198,15 @@ def forward(self, emb, **kwargs): {"keys": torch.tensor([]), "values": torch.tensor([])}, ) + if hasattr(self, "rope"): + position_embeddings = self.rope( + emb, + step=step, + device=emb.device, + ) + else: + position_embeddings = None + with_align = kwargs.pop("with_align", False) return_attn = with_align or kwargs.pop("return_attn", False) @@ -205,6 +221,7 @@ def forward(self, emb, **kwargs): step=step, with_align=with_align, return_attn=return_attn, + position_embeddings=position_embeddings, ) if attn_align is not None: attn_aligns.append(attn_align) diff --git a/eole/decoders/transformer_lm_decoder.py b/eole/decoders/transformer_lm_decoder.py index 25307c45..40a66da4 100644 --- a/eole/decoders/transformer_lm_decoder.py +++ b/eole/decoders/transformer_lm_decoder.py @@ -10,7 +10,8 @@ TransformerDecoderLayerBase, TransformerDecoderBase, ) -from eole.constants import LayerNorm +from eole.constants import LayerNorm, PositionEncodingType +from eole.modules.rope import RotaryPosition class TransformerLMDecoderLayer(TransformerDecoderLayerBase): @@ -19,7 +20,15 @@ class TransformerLMDecoderLayer(TransformerDecoderLayerBase): See TransformerDecoderLayerBase """ - def _forward(self, layer_in, pad_mask, step=None, future=False, return_attn=False): + def _forward( + self, + layer_in, + pad_mask, + step=None, + future=False, + return_attn=False, + position_embeddings=None, + ): """A naive forward pass for transformer decoder. # T: could be 1 in the case of stepwise decoding or tgt_len @@ -31,6 +40,7 @@ def _forward(self, layer_in, pad_mask, step=None, future=False, return_attn=Fals step (int or None): stepwise decoding counter future (bool): If set True, do not apply future_mask. return_attn (bool): If set True return attn + position_embeddings (FloatTensor): rotary position encodings, if any Returns: (FloatTensor, FloatTensor): @@ -59,6 +69,7 @@ def _forward(self, layer_in, pad_mask, step=None, future=False, return_attn=Fals sliding_window=self.sliding_window, step=step, return_attn=return_attn, + position_embeddings=position_embeddings, ) if self.dropout_p > 0: attn_output = self.dropout(attn_output) @@ -91,6 +102,9 @@ def __init__( ): super(TransformerLMDecoder, self).__init__(model_config) + if model_config.position_encoding_type == PositionEncodingType.Rotary: + self.rope = RotaryPosition(model_config) + self.transformer_layers = nn.ModuleList( [ TransformerLMDecoderLayer( @@ -112,6 +126,16 @@ def forward(self, emb, **kwargs): pad_mask = kwargs.pop("tgt_pad_mask", None) assert pad_mask is not None, "TransformerLMDecoder requires a pad mask" step = kwargs.pop("step", None) + + if hasattr(self, "rope"): + position_embeddings = self.rope( + emb, + step=step, + device=emb.device, + ) + else: + position_embeddings = None + if step == 0: # decoding mode. # Initialize KV and key_pad_mask cache. @@ -140,6 +164,7 @@ def forward(self, emb, **kwargs): step=step, with_align=with_align, return_attn=return_attn, + position_embeddings=position_embeddings, ) emb = self.layer_norm(emb) @@ -160,7 +185,3 @@ def _init_cache(self, device, mask): "key_pad_mask": mask, }, ) - if hasattr(layer.self_attn, "rope"): - layer.self_attn.rope = layer.self_attn.rope.to(device) - layer.self_attn.cos = layer.self_attn.cos.to(device) - layer.self_attn.sin = layer.self_attn.sin.to(device) diff --git a/eole/encoders/transformer.py b/eole/encoders/transformer.py index 6f1aa048..cb7f729c 100644 --- a/eole/encoders/transformer.py +++ b/eole/encoders/transformer.py @@ -7,7 +7,8 @@ from eole.encoders.encoder import EncoderBase from eole.modules.multi_headed_attn import SelfMHA from eole.modules.transformer_mlp import MLP -from eole.constants import LayerNorm +from eole.constants import LayerNorm, PositionEncodingType +from eole.modules.rope import RotaryPosition class TransformerEncoderLayer(nn.Module): @@ -45,18 +46,21 @@ def __init__( running_config=running_config, ) - def forward(self, layer_in, mask): + def forward(self, layer_in, mask, position_embeddings=None): """ Args: layer_in (FloatTensor): ``(batch_size, src_len, model_dim)`` mask (LongTensor): ``(batch_size, 1, src_len)`` + position_embeddings (FloatTensor): rotary position encodings, if any Returns: (FloatTensor): * layer_out ``(batch_size, src_len, model_dim)`` """ norm_layer_in = self.input_layernorm(layer_in) - context, _ = self.self_attn(norm_layer_in, mask=mask) + context, _ = self.self_attn( + norm_layer_in, mask=mask, position_embeddings=position_embeddings + ) if self.dropout_p > 0: context = self.dropout(context) if self.parallel_residual: @@ -98,6 +102,9 @@ def __init__( ): super(TransformerEncoder, self).__init__() + if model_config.position_encoding_type == PositionEncodingType.Rotary: + self.rope = RotaryPosition(model_config) + self.transformer_layers = nn.ModuleList( [ TransformerEncoderLayer( @@ -130,8 +137,13 @@ def forward(self, emb, mask=None): # 1 to be expanded to number of heads in MHA # Run the forward pass of every layer of the tranformer. + if hasattr(self, "rope"): + position_embeddings = self.rope(emb, step=0, device=emb.device) + else: + position_embeddings = None + for layer in self.transformer_layers: - enc_out = layer(enc_out, mask) + enc_out = layer(enc_out, mask, position_embeddings=position_embeddings) enc_out = self.layer_norm(enc_out) return enc_out, None diff --git a/eole/modules/multi_headed_attn.py b/eole/modules/multi_headed_attn.py index 4721e711..b6e81fea 100644 --- a/eole/modules/multi_headed_attn.py +++ b/eole/modules/multi_headed_attn.py @@ -15,23 +15,6 @@ # Help functions for Rotary Embeddings # https://arxiv.org/pdf/2104.09864.pdf -# too convoluted to make maxseqlen a parameter. -# we suppose src_seq_len at training and max_length at inference -# are both < 2048 tokens. - - -def rotaryembeddings(dim: int, maxseqlen=2048, base=10000, device=None): - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) - tmax = torch.arange(maxseqlen, device=inv_freq.device) - rope = torch.outer(tmax, inv_freq).float() - # rope is now matrix [maxseqlen, dim/2] - rope = torch.polar(torch.ones_like(rope), rope) - rope = torch.cat((rope, rope), dim=1) - if device is not None: - rope = rope.to(device) - cos = rope[:, : rope.size(1) // 2].real.contiguous().half() - sin = rope[:, : rope.size(1) // 2].imag.contiguous().half() - return rope, cos, sin def rotate_half(x): @@ -316,6 +299,10 @@ def __init__( False, {"keys": torch.tensor([]), "values": torch.tensor([])}, ) + # TODO find a cleaner way to initialize? + self.relative_positions_embeddings = None + self.relative_attention_bias = None + self.rotary_interleave = None if model_config.relative_positions_buckets > 0: self.relative_attention_bias = nn.Embedding( model_config.relative_positions_buckets, self.heads @@ -331,27 +318,14 @@ def __init__( self.relative_positions_embeddings = nn.Embedding( vocab_size, self.dim_per_head ) - self.relative_attention_bias = None - else: - self.relative_positions_embeddings = None - self.relative_attention_bias = None - - if self.position_encoding_type == PositionEncodingType.Rotary: - if model_config.rotary_dim == 0: - self.rotary_dim = self.dim_per_head - else: - self.rotary_dim = model_config.rotary_dim - self.rope, self.cos, self.sin = rotaryembeddings( - self.rotary_dim, base=model_config.rotary_theta - ) - self.rotary_interleave = model_config.rotary_interleave - self.rotary_theta = model_config.rotary_theta + elif self.position_encoding_type == PositionEncodingType.Rotary: + if model_config.rope_config.rotary_dim == 0: + self.rotary_dim = self.dim_per_head else: - self.cos = None - self.sin = None - self.rotary_interleave = None - if model_config.position_encoding_type == PositionEncodingType.Alibi: - self.alibi = AlibiPositionalBias(self.heads) + self.rotary_dim = model_config.rope_config.rotary_dim + self.rotary_interleave = model_config.rope_config.rotary_interleave + elif model_config.position_encoding_type == PositionEncodingType.Alibi: + self.alibi = AlibiPositionalBias(self.heads) self.maybe_ckpt = ( checkpoint @@ -373,10 +347,29 @@ def update_dropout(self, dropout: float) -> None: self.dropout.p = dropout self.dropout_p = dropout - def _forward1( - self, key: Tensor, value: Tensor, query: Tensor + def _prepare_inputs( + self, + key: Tensor, + value: Tensor, + query: Tensor, + position_embeddings=None, ) -> Tuple[Tensor, Tensor, Tensor]: - """ """ + """ + Prepare inputs for attention computation. + This method performs the following steps: + 1. Applies linear transformations to key, value, and query inputs. + 2. Reshapes the tensors to the multi-head attention format. + 3. Applies rotary position encoding if configured. + + Args: + key (Tensor): The key tensor of shape [batch_size, seq_len, hidden_size]. + value (Tensor): The value tensor of shape [batch_size, seq_len, hidden_size]. + query (Tensor): The query tensor of shape [batch_size, seq_len, hidden_size]. + + Returns: + Tuple[Tensor, Tensor, Tensor]: Processed key, value, and query tensors, each of shape + [batch_size, num_heads, seq_len, dim_per_head]. + """ # Retrieve keys and values from linear layers (training mode). key = self.maybe_ckpt(self.linear_keys, key) value = self.maybe_ckpt(self.linear_values, value) @@ -388,21 +381,15 @@ def _forward1( if self.position_encoding_type == PositionEncodingType.Rotary: start_pos = 0 seqlen = query.size(2) - if seqlen > self.rope.size(0): - # Resize rotary embeddings. - self.rope, self.cos, self.sin = rotaryembeddings( - self.rotary_dim, - maxseqlen=(seqlen + 2048), - base=self.rotary_theta, - device=query.device, - ) - rope = self.rope[start_pos : start_pos + seqlen].to(query.device) + position_embeddings = position_embeddings[ + start_pos : start_pos + seqlen + ].to(query.device) query, key = apply_rotary_emb( - query, key, rope, interleave=self.rotary_interleave + query, key, position_embeddings, interleave=self.rotary_interleave ) return key, value, query - def _forward2( + def _compute_attention( self, key: Tensor, value: Tensor, @@ -417,13 +404,13 @@ def _forward2( Args: key (Tensor): set of `key_len` - key vectors ``(batch, key_len, dim)`` + key vectors ``(batch, head, key_len, dim)`` value (Tensor): set of `key_len` - value vectors ``(batch, key_len, dim)`` + value vectors ``(batch, head, key_len, dim)`` query (Tensor): set of `query_len` - query vectors ``(batch, query_len, dim)`` + query vectors ``(batch, head, query_len, dim)`` mask: binary mask 1/0 indicating which keys have - zero / non-zero attention ``(batch, query_len, key_len)`` + zero / non-zero attention ``(batch, 1, query_len, key_len)`` Returns: (Tensor, Tensor): @@ -564,6 +551,46 @@ def __init__( self.n_positions = model_config.n_positions super(SelfMHA, self).__init__(model_config, running_config, is_decoder) + def _prepare_inputs_w_cache( + self, + query, + key, + value, + step: Optional[int] = 0, + sliding_window: Optional[int] = 0, + position_embeddings=None, + ): + start_pos = step + seqlen = query.size(2) + if self.position_encoding_type == PositionEncodingType.Rotary: + if position_embeddings is not None: + position_embeddings = position_embeddings[ + start_pos : start_pos + seqlen + ] + query, key = apply_rotary_emb( + query, key, position_embeddings, interleave=self.rotary_interleave + ) + # update the cache + if self.layer_cache[1]["keys"].numel() != 0: + key = torch.cat((self.layer_cache[1]["keys"], key), dim=2) + value = torch.cat((self.layer_cache[1]["values"], value), dim=2) + if sliding_window > 0 and key.size(2) > sliding_window: + key = key[:, :, 1:, :] + value = value[:, :, 1:, :] + # mask values for LM left padding by batch + if step == 0: + key_pad_mask = self.layer_cache[1].get("key_pad_mask", None) + if key_pad_mask is not None: + x = key_pad_mask.expand(-1, value.size(1), -1) + x = x.unsqueeze(3) + x = x.expand(-1, -1, -1, value.size(3)) + value = value.masked_fill(x, 0) + + self.layer_cache[1]["keys"] = key + self.layer_cache[1]["values"] = value + + return key, value, query + def forward( self, query: Tensor, @@ -571,6 +598,7 @@ def forward( sliding_window: Optional[int] = 0, step: Optional[int] = 0, return_attn: Optional[bool] = False, + position_embeddings=None, ) -> Tuple[Tensor, Tensor]: if self.layer_cache[0]: # Inference step decoding @@ -581,7 +609,6 @@ def forward( value = shape(value, self.dim_per_head) query = shape(query, self.dim_per_head) start_pos = step - seqlen = query.size(2) if ( step == 0 or not self.flash @@ -590,38 +617,14 @@ def forward( or query.size(0) > 128 # to check or query.dtype != torch.float16 # to match with flash ): - if self.position_encoding_type == PositionEncodingType.Rotary: - if seqlen + start_pos > self.rope.size(0): - # Resize rotary embeddings. - self.rope, self.cos, self.sin = rotaryembeddings( - self.rotary_dim, - maxseqlen=(seqlen + start_pos + 2048), - base=self.rotary_theta, - device=self.rope.device, - ) - rope = self.rope[start_pos : start_pos + seqlen] - query, key = apply_rotary_emb( - query, key, rope, interleave=self.rotary_interleave - ) - # update the cache - if self.layer_cache[1]["keys"].numel() != 0: - key = torch.cat((self.layer_cache[1]["keys"], key), dim=2) - value = torch.cat((self.layer_cache[1]["values"], value), dim=2) - if sliding_window > 0 and key.size(2) > sliding_window: - key = key[:, :, 1:, :] - value = value[:, :, 1:, :] - # mask values for LM left padding by batch - if step == 0: - key_pad_mask = self.layer_cache[1].get("key_pad_mask", None) - if key_pad_mask is not None: - x = key_pad_mask.expand(-1, value.size(1), -1) - x = x.unsqueeze(3) - x = x.expand(-1, -1, -1, value.size(3)) - value = value.masked_fill(x, 0) - - self.layer_cache[1]["keys"] = key - self.layer_cache[1]["values"] = value - + key, value, query = self._prepare_inputs_w_cache( + query, + key, + value, + step=step, + sliding_window=sliding_window, + position_embeddings=position_embeddings, + ) else: # Fast path with flash_attn_with_kvcache if start_pos >= self.layer_cache[1]["keys"].size(2): @@ -649,20 +652,6 @@ def forward( ], dim=-2, ) - if ( - self.position_encoding_type == PositionEncodingType.Rotary - and start_pos + 32 >= self.rope.size(0) - ): - # Resize rotary embeddings. - # We take a margin of 32 tokens as the kv_cache - # is incremented by 32 tokens every 32 tokens. - self.rope, self.cos, self.sin = rotaryembeddings( - self.rotary_dim, - maxseqlen=(start_pos + 2048), - base=self.rotary_theta, - device=self.rope.device, - ) - if sliding_window > 0 and key.size(2) > sliding_window: self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][ :, :, 1:, : @@ -670,14 +659,25 @@ def forward( self.layer_cache[1]["values"] = self.layer_cache[1]["values"][ :, :, 1:, : ] + if position_embeddings is not None: + cos = ( + position_embeddings[:, : position_embeddings.size(1) // 2] + .real.contiguous() + .half() + ) + sin = ( + position_embeddings[:, : position_embeddings.size(1) // 2] + .imag.contiguous() + .half() + ) context = self.flash_attn_with_kvcache( query.transpose(1, 2), self.layer_cache[1]["keys"].transpose(1, 2), self.layer_cache[1]["values"].transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), - rotary_cos=self.cos, - rotary_sin=self.sin, + rotary_cos=cos, + rotary_sin=sin, cache_seqlens=step, rotary_interleaved=self.rotary_interleave, ).transpose(1, 2) @@ -687,9 +687,11 @@ def forward( return attn_output, None else: - key, value, query = super()._forward1(query, query, query) + key, value, query = super()._prepare_inputs( + query, query, query, position_embeddings=position_embeddings + ) - return super()._forward2( + return super()._compute_attention( key, value, query, @@ -710,6 +712,19 @@ def __init__( self.n_positions = 0 super(ContextMHA, self).__init__(model_config, running_config, True) + def _prepare_inputs_w_cache(self, key, value, query): + query = self.linear_query(query) + query = shape(query, self.dim_per_head) + if self.layer_cache[1]["keys"].numel() == 0: + key, value = self.linear_keys(key), self.linear_values(value) + self.layer_cache[1]["keys"] = shape(key, self.dim_per_head) + self.layer_cache[1]["values"] = shape(value, self.dim_per_head) + key, value = ( + self.layer_cache[1]["keys"], + self.layer_cache[1]["values"], + ) + return key, value, query + def forward( self, key: Tensor, @@ -722,21 +737,12 @@ def forward( ) -> Tuple[Tensor, Tensor]: if self.layer_cache[0]: # inference: we fill the cross-attention cache only once - query = self.linear_query(query) - query = shape(query, self.dim_per_head) - if self.layer_cache[1]["keys"].numel() == 0: - key, value = self.linear_keys(key), self.linear_values(value) - self.layer_cache[1]["keys"] = shape(key, self.dim_per_head) - self.layer_cache[1]["values"] = shape(value, self.dim_per_head) - key, value = ( - self.layer_cache[1]["keys"], - self.layer_cache[1]["values"], - ) + key, value, query = self._prepare_inputs_w_cache(key, value, query) else: # training: we project key, value query and apply rotary if required - key, value, query = super()._forward1(key, value, query) + key, value, query = super()._prepare_inputs(key, value, query) - return super()._forward2( + return super()._compute_attention( key, value, query, diff --git a/eole/modules/rope.py b/eole/modules/rope.py new file mode 100644 index 00000000..66026dc7 --- /dev/null +++ b/eole/modules/rope.py @@ -0,0 +1,144 @@ +import torch +import torch.nn as nn +import math + + +class RotaryPosition(nn.Module): + """ + Handles rotary position embeddings for transformer models. + + This module was refactored from multi-headed attention for improved clarity + and to support future enhancements, such as additional scaling types. + """ + + def __init__(self, model_config): + """ + Initializes the RotaryPosition module. + + Args: + model_config: Configuration object that contains model parameters, + including rotary embedding settings. + + Attributes: + model_config: The configuration object passed during initialization. + dim_per_head: The dimensionality of each attention head, computed + as `hidden_size // heads`. + rotary_interleave: Boolean flag to determine if head dimensions should + be interleaved or split when applying rotary embeddings. + rotary_theta: The base frequency for rotary embeddings. + inv_freq: Inverse frequency values used to calculate the rotary embeddings. + + Notes: + - If `rotary_dim` is set to 0 in the configuration, it defaults to + `dim_per_head`. + - Additional scaling types can be added in the future by extending this class. + """ + super(RotaryPosition, self).__init__() + self.model_config = model_config + self.dim_per_head = model_config.hidden_size // model_config.heads + if model_config.rope_config.rotary_dim == 0: + rotary_dim = self.dim_per_head + else: + rotary_dim = model_config.rope_config.rotary_dim + self.rotary_interleave = model_config.rope_config.rotary_interleave + self.rotary_theta = model_config.rope_config.rotary_theta + self.inv_freq = 1.0 / ( + self.rotary_theta ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim) + ) + # TODO: extend with other scaling types + if getattr(self.model_config.rope_config, "scaling_type", None) == "llama3": + self.llama3_scaling() + # cache rope tensor to limit unnecessary computations + self.rope = None + + def llama3_scaling(self): + """ + Applies the LLaMA3.1-specific scaling to the inverse frequencies. + + This scaling is based on LLaMA3.1's handling of different frequency components + within rotary embeddings. The method modifies `self.inv_freq` in place. + + Notes: + - Original values for `factor`, `low_freq_factor`, `high_freq_factor`, + and `original_max_position_embeddings` are taken from the configuration. + - The scaling factors are applied conditionally based on the wavelength + derived from the inverse frequencies. + """ + rope_config = self.model_config.rope_config + factor = rope_config.scaling_factor # `8` in the original implementation + low_freq_factor = ( + rope_config.low_freq_factor + ) # `1` in the original implementation + high_freq_factor = ( + rope_config.high_freq_factor + ) # `4` in the original implementation + old_context_len = ( + rope_config.original_max_position_embeddings + ) # `8192` in the original implementation + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + wavelen = 2 * math.pi / self.inv_freq + inv_freq_llama = torch.where( + wavelen > low_freq_wavelen, self.inv_freq / factor, self.inv_freq + ) + + smooth_factor = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + smoothed_inv_freq = ( + 1 - smooth_factor + ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + self.inv_freq = inv_freq_llama + + def forward(self, emb, step=0, device=None, prefetch=1024): + """ + Computes the rotary position embeddings for a given input. + + Args: + emb: The input embeddings to which rotary embeddings will be applied. + step: The current step or position within the sequence. Defaults to 0. + device: The device on which the computations should be performed. + If None, defaults to the device of `self.inv_freq`. + offset: An optional offset to apply to the position indices. + This is used for the specific `flash_attn_with_kvcache` path, + which requires processes by chunks of 32 tokens. Defaults to 0. + + Returns: + torch.Tensor: A tensor containing the computed rotary embeddings. + + Notes: + - The returned tensor contains cosine and sine values representing the + rotary embeddings, concatenated along the last dimension. + - The output tensor's dimensions are `[maxseqlen, dim]`, where `dim` is + twice the size of the original inverse frequency tensor (`inv_freq`). + """ + if step is None: + step = 0 + maxseqlen = emb.size(1) + offset = ( + 32 # make sure we have at least 32 positions for flash_attn_with_kvcache + ) + # This could probably a bit cleaner/homogenized with the offset case + if self.rope is not None: + if self.rope.size(0) >= max(offset + step, 0) + maxseqlen: + return self.rope + else: + maxseqlen = maxseqlen + prefetch + tmax = torch.arange( + max(offset + step, 0) + maxseqlen, device=self.inv_freq.device + ) + rope = torch.outer(tmax, self.inv_freq).float() + # rope is now matrix [maxseqlen, dim/2] + rope = torch.polar(torch.ones_like(rope), rope) + rope = torch.cat((rope, rope), dim=1) + if device is not None: + rope = rope.to(device) + # cos = rope[:, : rope.size(1) // 2].real.contiguous().half() + # sin = rope[:, : rope.size(1) // 2].imag.contiguous().half() + # return rope, cos, sin + self.rope = rope + return rope diff --git a/recipes/llama3.1/README.md b/recipes/llama3.1/README.md new file mode 100644 index 00000000..a768f7e7 --- /dev/null +++ b/recipes/llama3.1/README.md @@ -0,0 +1,371 @@ +# Llama3 + +--- +**NOTE** +To make your life easier, run these commands from the recipe directory (here `recipes/llama3.1`). +--- + +## Retrieve and convert model + +### Set environment variables + +``` +export EOLE_MODEL_DIR= +export HF_TOKEN= +``` + +### Download and convert model + +``` +eole convert HF --model_dir meta-llama/Meta-Llama-3.1-8B --output $EOLE_MODEL_DIR/llama3.1-8b --token $HF_TOKEN +``` + + +## Inference + +### Write test prompt to text file + +(Example prompt inspired from this HF PR: https://github.com/huggingface/transformers/pull/24653) + +``` +echo -e "You are given this machine learning research paper, please read it carefully and answer the follow up question. + +=== BEGIN === + +2306.15595v2 [cs.CL] 28 Jun 2023 + +arXiv + +EXTENDING CONTEXT WINDOW OF LARGE LAN- +GUAGE MODELS VIA POSITION INTERPOLATION + +Shouyuan Chen Sherman Wong Liangjian Chen Yuandong Tian +Meta Platforms Inc. +{chenshouyuan, shermanwong, cli, yuandong}@meta . com + +1 INTRODUCTION + +Large language models (LLMs) typically come with a pre-defined context window size. For exam- +ple, inputs to LLaMA models (Touvron et al., 2023) must be fewer than 2048 tokens. This pre-set +context window limit is frequently exceeded in applications such as conducting long conversations, +summarizing long documents, or executing long-term planning. For these applications, LLMs with +longer context windows are preferred. However, training an LLM from scratch with long context +windows requires significant investments. This naturally leads to a question: Can we extend the +context window of an existing pre-trained LLM? + +One straightforward approach is to fine-tune an existing pre-trained Transformer with a longer con- +text window. However, empirically, we found that models trained this way adapt to long context +windows very slowly. After training for more than 10000 batches, the effective context window +saw a minimal increase, moving from 2048 to 2560 (Table 4). This suggests that such method is +inefficient for extending to substantially longer context windows. + +While certain techniques such as ALiBi (Press et al., 2022) and LeX (Sun et al., 2022) enable length +extrapolation of Transformers, i.e. train on short context windows and inference on longer ones, +many existing pre-trained LLMs, including LLaMA (Touvron et al., 2023), use positional encodings +that have weak extrapolation properties (e.g., RoPE (Su et al., 2021)). Therefore, the applicability +of these techniques for extending the context window sizes of such LLMs remains limited. + +In this work, we introduce Position Interpolation to enable context window extensions for certain +existing pre-trained LLMs, including LLaMA. The key idea is, instead of extrapolation, we directly +down-scale the position indices so that the maximum position index matches the previous context +window limit in the pre-training stage. See Figure 1 for an illustration. In other words, to accom- +modate more input tokens, we interpolate the position encodings at neighboring integer positions, +utilizing the fact that position encodings can be applied on non-integer positions, as opposed to +extrapolating outside the trained positions, which may lead to catastrophic values. We verify our +approach theoretically, by showing that the interpolated attention score has a much smaller upper + +bound (~ 600x smaller in LLaMA 7B setting) than the extrapolated one, and is thus much more +stable. Therefore, interpolated position encodings are easier for the model to adapt. + +Empirically, we found that Position Interpolation is highly effective and efficient, requiring only a +very short period of fine-tuning for the model to fully adapt to greatly extended context windows. +We present experimental results for extending the context window to up to 32768 from the initial +2048 across 7B to 65B LLaMA models using Position Interpolation. Our results show that + +1. Position Interpolation can easily enable very long context windows (e.g. 32768), requiring +only fine-tuning for 1000 steps on the Pile (Gao et al., 2020) to achieve a good quality. +The cost of fine-tuning is negligible compared to the pre-training costs. This confirms +our hypothesis that it is relatively easy for the models to adapt to interpolated position +encodings. + +2. Position Interpolation generates strong models that can effectively make use of much ex- +tended context window. We show that models extended by Position Interpolation enjoy +significant perplexity gains from greatly extended context windows for text modeling, and +we show that the perplexity reduces graceful with the enlargement of context windows. +We also applied Position Interpolation in a long text summarization task, and demonstrate +competitive performances. + +3. Position Interpolation preserves model quality relatively well for tasks within its original +context window sizes. We present a variety of evaluation results for the extended LLaMA +models on the original LLaMA benchmark. Compared with original LLaMA models, the +extended LLLaM A models saw a minor degradation on several standard benchmarks within +a 2048 token limit. + +Our results highlight the innate ability of Transformer models to “extrapolate to sequence lengths +longer than the ones encountered during training” as hypothesized in the seminal work of Vaswani +et al. (2017). We reaffirm this hypothesis and suggest that the previously known weakness of ex- +trapolating to longer sequences for language modeling (Press et al., 2022) may be due to direct + +extrapolation of positional encodings and it can be largely mitigated by interpolating position en- +codings instead. + +Concurrent work. Right before our release, we are informed with a concurrent blogpost (Super- +HOT kaiokendev (2023)) that also interpolates positional encoding in RoPE to extend the context +window from 2K to 8K. Recently, open source community picks it up in Reddit post ! and Github +Issues 2, which shows that fine-tuning with LoRA (Hu et al., 2021) also seems to work well. Our +paper shows a full fine-tuning with up to 65B model work well with Position Interpolation, and we +also give theoretical explanations why interpolation achieves much more stable results than extrap- +olation, by showing that the upper bound of interplated attention score is much lower than that of +extrapolated ones. + +2 METHOD + +2.1 BACKGROUND: ROTARY POSITION EMBEDDING (ROPE) + +Transformer models require explicit positional information to be injected, typically in the form of +positional encodings, to represent the order of inputs. We consider Rotary Position Embedding +(ROPE) (Su et al., 2021), which is the position encoding used in the LLLaMA model (Touvron et al., +2023). Given a position index m € [0, ¢) and an embedding vector x := [zg, 71,..., 241], Where +d is the dimension of the attention head, RoPE defines a vector-valued complex function f{x, m) as +follows + +Using RoPE, the self-attention score +is only dependent on relative position m — 7 through trigonometric functions. Here q and k are the +query and key vector for a specific attention head. At each layer, RoPE is applied on both query and +key embeddings for computing attention scores. + +2.2 DIRECT EXTRAPOLATION + +While the attention score in RoPE only depends on the relative positions, which is what we want, +its extrapolation performance is not great . In particular, when directly extending to larger context +windows unseen in the training, the perplexity may shoot up to very high numbers (i.e., > 10%), +comparable to untrained models. + +Ideally, we want to see the model trained on a context window of size L = 2048 to still work +reasonably well on longer context window, but may not have the capability to leverage information +that appears beyond L. For example, to answer a question located at 3000, the model trained on +maximal window size of I = 2048 cannot leverage evidences provided at location 0, but still +can leverage the evidences provided at location 2900. In contrast, in reality we see catastrophic +behaviors, i.e., question at location 3000 cannot be answered correctly, even if the evidences are +located at location 2900. + +What is the reason behind? How could this happen if the attention score a,,,—,, decays as the relative +distance |m — n/| increases, according to Section 3.4.3 of (Su et al., 2021), and content from very +far distances should not matter that much? It turns out that the upper bound derived in Section 3.4.3 +of (Su et al., 2021) may be too loose: while it indeed decays with respect to |m — nl, the bound +can still be quite large (i.e., the bound can be critically depends on the magnitude of v;) and thus +vacuous. In fact, if we treat all trigonometric functions as basis functions (i.e, ¢;(s) := #93), and +think about Eqn. 2 as basis expansion as the following: + +where s is the positional span between a query and a key and h; := (ga; + igaj+1){k2j — tk2j+1) +are complex coefficients depending on q and k (here the definition of h; is exactly the same as the +definition of k; in Sec 3.4.3 in RoPE (Su et al., 2021)). Now the the issue becomes clear: as shown +in Fig. 2, a, can be small in magnitude in the range of [0, 2048], but gives huge values out of the +region. The underlying reason is that the trigonometric family {¢;} (with sufficiently large d) is +a universal approximator and can fit any arbitrary functions. Therefore, for a, there always exist +coefficients {h;} (i.e. key and query) that corresponds to small function values in [0, 2048] but + +much larger in regions beyond. + +2.3 PROPOSED APPROACH: POSITION INTERPOLATION (PI) + +In Fig. 2, thanks to the smoothness of bases functions ¢; interpolation is much more stable and will +not lead to wild values. Therefore, instead of extrapolate the attention score in Eqn. 3 to s > L, +how about we define an attention score a{s) = a(Ls/L’) where L’ is the longer context window? +Formally, we replace RoPE f by {’ defined as follows + +We call this transformation on the position encoding Position Interpolation. In this step, we reduce +position indices from [0, L') to [0, L) to match the original range of indices before computing RoPE. +Consequently, as inputs to RoPE, the maximum relative distance between any two tokens has been +reduced from I’ to L. Since we align the ranges of position indices and relative distances before +and after extension, we mitigate the effect on attention score computation due to context window +extensions, which can allow the model easier to adapt. To further demonstrate this is the case, in the +following theorem, we show that the interpolated attention score is well-behaved: + +While there is no close form for B(s) := 4/21 |Ag41(s)|, numerically it is at least larger than d, and for many positional difference s, B(s) is much larger than d +(check Appendix B for the plot). Therefore, the interpolation bound is at least 2 - 294.73 ~ 600 x +smaller than the extrapolation bound, and thus the interpolated attention score is much more stable +than extrapolated one. + +Notably, our method of rescaling of position indices does not introduce extra weight, or modify +the model architecture in any way. This makes it attractive in practical applications, since most +infrastructure and optimization for the original model can be reused after the extension. + +Fine-tuning. We can further fine-tune the interpolated model using the next token prediction task +with interpolated position encodings on the extended context window size using a pre-training cor- +pus such as the Pile (Gao et al., 2020). In the next section, we show that our fine-tuning process +only needs tens to hundreds thousands of examples. We also find that the result of the fine-tuning +is not sensitive to the choice of examples. The reason may be that the model is only adapting to the +new context window during the fine-tuning phase, starting from a good initialization, as opposed to +acquiring new knowledge. + +Other ways to reduce interpolation/extrapolation bound. From the expression of the interpola- +tion (Eqn. 5) and extrapolation bound (Eqn. 8), a common term is max; ||, which is the maximal +magnitude of query/key products. If we enforce a regularization on || during LLM training, it is +possible that the catastrophic extrapolation error can be mitigated or even resolved. In fact, if we +apply ridge regression with proper regularization to fit a curve in Fig. 2, the magnitude of extrapo- +lated a(s) when s > L can be comparable to that within [0, L]. To our knowledge, we are not aware +of existing LLM pre-training techniques that leverage this regularization and will leave it for future +work. + +3 EXPERIMENTS + +We show Position Interpolation can effectively extend context window up to 32 times of the original +size, and such extension can be done with only several hundreds of training steps. We show the +resulting models are strong LLMs with fully effective long context windows. We demonstrate its +performance in a number of tasks including language modeling, passkey retrieval, and long doc- +ument summarization. We also present benchmark results of the extended models on the original +LLaMA evaluation benchmarks. +3.1 SETUP + +Model Variants. We extended the pre-trained 7B, 13B, 33B and 65B LLaMA models (Touvron +et al., 2023) to various context window of sizes up to 32768, using either direct fine-tuning or +Position Interpoloation method. Except for rescaling the position indices for models extended with +Position Interpolation, we did not modify LLaMA model architectures (Touvron et al., 2023) in any +ways. + +Training Procedure. We fine-tune all model variants using the next token prediction objective. We +use AdamW (Loshchilov & Hutter, 2019) with 5; = 0.9 and 2 = 0.95. We use a linear learning +rate warmup of 20 steps starting from 10% of the maximum learning rate. For 7B and 13B models, +we set the learning rate to 2 x 1075 and for 33B and 65B models we set the learning rate to 1072. We +set the weight decay to zero. For extending 7B, 13B and 33B models to the 8192 context window +size, we use 32 A100 GPUs and 64 global batch size. For all other cases we use 128 A100 GPUs and +128 global batch size. We note that the main need of using more GPUs is memory limitation during +fine-tuning, and it is possible to use fewer GPUs in certain cases. We train all models using PyTorch +(Paszke et al., 2019) with Fully Sharded Data Parallel (Zhao et al., 2023) and Flash Attention (Dao +et al., 2022). + +If not specified otherwise, for the Position Interpolation method, we fine-tune the models for 1000 +steps. For the direct fine-tuning method, we use 10000 steps. We primarily fine-tune using the Pile +training dataset (Gao et al., 2020). In Section 3.4 we also compared fine-tuning performance on the +RedPajama dataset (Computer, 2023). + +3.2 LONG SEQUENCE LANGUAGE MODELING + +We evaluate the long sequence language modeling performance of our extended models and base- +lines on two datasets: book corpus (PG-19) (Rae et al., 2020) and cleaned Arxiv Math proof-pile +dataset (Azerbayev et al., 2022). + +We use the test splits of PG19 (Rae et al., 2020) and proof-pile (Azerbayev et al., 2022). For PG19, +we use the whole test split consisting of 100 documents. For the proof-pile dataset, we use a random +subsample of 128 documents with at least 32768 SentencePiece (Kudo & Richardson, 2018) tokens +and truncate to the first 32768 tokens for each test document. We evaluate perplexity at various +context window size by using a sliding window approach following Press et al. (2022) with stride +S = 256. + +In Table 1 and Table 2, we report the perplexity results for our models and baselines on the datasets. +From the results, we found that models extended with our method enjoy a significantly improved +perplexity from longer context window sizes. By increasing the context window size from 2048 to +16384, we observed -0.28 and -0.5 reductions of perplexity for extending LLaMA 7B models on +both datasets, -0.27 and -0.48 reductions for extending LL.aMA 13B models, and -0.14 and -0.42 +reductions for extending LLaMA 33B models. For LLaMA 65B models, we observed -0.12 and +-0.3 reductions of perplexity by extending to the 8192 context window size. + +In general, we observed a consistent trend of our models achieving better perplexity with longer +context windows. This indicates our models can effectively make use of the longer context windows +to better predict next tokens in language modeling tasks. Moreover, we found this trend extends to +32768 window size without diminishing on the PG19 dataset for LLaMA 7B and 13B models. This +indicates that our method may enable extension to even longer context windows. + +In contrast, we observed that models extended via the direct fine-tuning method has shown regres- +sion (up to +0.48) or minor improvement (up to -0.12) on the perplexity at longer context windows. +This indicates that models extended this way have limited capability of making use of context win- +dows longer than their pre-trained settings. + +We saw a minor degradation of the perplexity on the original context window of 2048 for our ex- +tended models in some cases. For example, on the Proof-pile dataset, we saw a degradation ranging +from 0.01 to 0.05 across all models with extended with Position Interpolation. A small degradation +of performance within original evaluation context window is expected since Position Interpolation +forces position encodings in original context window to reside in a much narrower region, which +may negatively affect the language model’s performance. We present more benchmark results on +the original context window size in Section 3.4. + +In Table 3 we report the relationship between perplexity and the number of fine-tuning steps for +LLaMA 7B model extending to 8192 and 16384 context window sizes using Position Interpolation +evaluated on the PG19 dataset. We can see without fine-tuning (at step 0) the model can exhibit +certain language modeling capability, as indicated by < 20 perplexity for extending to 8192 context +window (in contrast, the direct extrapolation method leads to > 10% perplexity). With fine-tuning, +we observed that the perplexity improves quickly. At 200 steps the models surpassed the original +model’s perplexity on 2048 context window size, indicating the models gaining ability of effectively +using sequences longer than the pre-training settings for language modeling. At 1000 steps, we can +see the models have improved steadily and achieve a significantly better perplexity. + +3.3 MEASURING EFFECTIVE CONTEXT WINDOW SIZE THROUGH PASSKEY RETRIEVAL + +We study the effective context window size, i.e. the maximum distance of a token can effectively +attend to during inference, of our models after extension. To measure this, we follow a synthetic +evaluation task of passkey retrieval proposed by Mohtashami & Jaggi (2023). In this task, the models +are asked to recover a random passkey hidden in a long document. See Figure 3 for the format of +the document. + +Given a language model, we estimate the upper and lower bounds of effective context windows as +follows. Suppose the random passkey is k tokens away from the end of the input. When a model +persistently fails to retrieve the correct passkey value across several independent attempts, it suggests +that the effective context window size of the model is less than k. Conversely, if a model consistently +succeeds in retrieving the correct passkey value, we deduce that the effective context window size +of the model is at least k. + +We evaluate the 7B and 33B LLaMA model variants that are extended via Position Interpolation or +direct fine-tuning. For each model, we use 32 different &£ uniformly spaced in the targeted context +window L’ and run the above tests for 10 times for each k, where each time a random passkey of 5 +random digits is used. In Table 4, we report kyax as a function of the number of fine-tuning steps, + +We can see that models extended via Position Interpolation all successfully attain their desired ex- +tension objectives in terms of effective context window sizes, indicating by the effective context +window size reaching maximum kp, = L/, after merely fine-tuning for 200 steps, consistently +across both 7B and 33B model sizes and up to 32768 context windows. In contrast, LLLaMA models +that are extended via direct fine-tuning only saw a minimal increase of the effective context win- +dow size kay from 2048 to 2560, even after fine-tuning for more than 10000 steps, with no clear +indication of an acceleration in the increase of window size. + +3.4 BENCHMARKS ON ORIGINAL CONTEXT WINDOW SIZE + +We evaluate the models extended by Position Interpolation on several standard benchmark tasks +within the original context window size of 2048. The evaluation results are listed in Table 5. From +the results, we saw that models extended to 8192 produce comparable results on the original bench- +mark which is designed for a much smaller context window, with a degradation of up to 2% on +the benchmark tasks, for both 7B and 33B model sizes. Models extended to longer context win- +dows regressed more on the benchmarks, but still in reasonable ranges for most tasks. We also note +that the choice of fine-tuning datasets does not seem to lead significant difference in the benchmark +performances, which may be due to the limited number of fine-tuning steps used in our method. +The regression on benchmark tasks is consistent with our observation on perplexity regression in +Section 3.2. + +3.5 LONG DOCUMENT SUMMARIZATION + +In this task, we evaluate our models’ performance on the long document summarization task. In +particular, we consider the GovReport (Huang et al., 2021) dataset, which contains 17457 documents +for training and 972 documents for evaluation. Each document comes with a human generated +summary. We truncate all input documents to their first 15000 tokens. + +We fine-tune the LL.aMA models extended with Position Interpolation with a context window of +16384. Note the rescaling of position indices are still required during this fine-tuning step. We first +Model Size Context Window Fine-tune on BoolQ PIQA Race-M Race-H WinoGrande + +format the raw document using the prompt template in Figure 4, and then concatenate the prompt +with the ground-truth summary (truncate to 1000 tokens) associated with each document. We fine- +tune the model using the next token prediction task with the above setup for 10 epochs. The losses +from the input prompt proportion of training examples are excluded during our fine-tuning. + +We use a generation temperature of 0.5 and top, = 0.95 as our inference parameter to generate a +summarization of each document in the test set. The final output is truncated at 1000 tokens. We +used the ROUGE-1/ROUGE-2/ROUGE-L scores (Lin, 2004) as the evaluation metrics to evaluate +the models’ outputs vs the ground-truth summaries. + +In Table 6 we report our evaluation results. We have also included results from two baselines in +existing SCROLLS Leaderboard (Shaham et al., 2022; Ainslie et al., 2023). In general, we have +obtained competitive R1 score among other models with minimal tuning of hyper-parameters. This +result suggests our models with 16384 context window can effectively handle the long document +summarization task. + +=== END OF FILE === + +Question: What is the paper about? +Answer: " | sed ':a;N;$!ba;s/\n/⦅newline⦆/g' > test_prompt.txt +``` + +### Run inference + +``` +eole predict -c llama-inference.yaml -src test_prompt.txt -output test_output.txt +``` diff --git a/recipes/llama3.1/llama-inference.yaml b/recipes/llama3.1/llama-inference.yaml new file mode 100755 index 00000000..1bc672e3 --- /dev/null +++ b/recipes/llama3.1/llama-inference.yaml @@ -0,0 +1,36 @@ +transforms: [onmt_tokenize] + +transforms_configs: + onmt_tokenize: + src_subword_type: bpe + src_subword_model: "${EOLE_MODEL_DIR}/llama3.1-8b/bpe.model" + tgt_subword_type: bpe + tgt_subword_model: "${EOLE_MODEL_DIR}/llama3.1-8b/bpe.model" + gpt2_pretok: true + +# Model info +model_path: "${EOLE_MODEL_DIR}/llama3.1-8b" + +# Inference +seed: 42 +max_length: 256 +# max_length: 1 +gpu: 0 +batch_type: sents +batch_size: 4 +world_size: 1 +gpu_ranks: [0] +# world_size: 2 +# gpu_ranks: [0, 1] +# parallel_mode: "tensor_parallel" +# quant_layers: ['gate_up_proj', 'down_proj', 'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear'] +# quant_type: "bnb_NF4" +compute_dtype: bf16 +#random_sampling_topk: 1 +#random_sampling_topp: 0.0 +#random_sampling_temp: 0.9 +beam_size: 1 +n_best: 1 +report_time: true +src: None + diff --git a/recipes/llama3/llama-inference.yaml b/recipes/llama3/llama-inference.yaml index d0f0611f..ebbd3e27 100755 --- a/recipes/llama3/llama-inference.yaml +++ b/recipes/llama3/llama-inference.yaml @@ -25,7 +25,7 @@ gpu_ranks: [0] # parallel_mode: "tensor_parallel" # quant_layers: ['gate_up_proj', 'down_proj', 'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear'] # quant_type: "bnb_NF4" -precision: fp16 +compute_dtype: fp16 #random_sampling_topk: 1 #random_sampling_topp: 0.0 #random_sampling_temp: 0.9