Some MHA and RoPE refactoring, llama-3.1 rope_scaling (#91)
francoishernandez authored Aug 30, 2024
1 parent 7a4757f commit b81cce1
Showing 10 changed files with 809 additions and 157 deletions.
33 changes: 21 additions & 12 deletions eole/bin/convert/convert_HF.py
@@ -466,16 +466,27 @@ def run(cls, args):
norm_eps = config["layer_norm_eps"]
else:
norm_eps = 1e-6
rope_config = {}
if "rope_theta" in config.keys():
rope_theta = config["rope_theta"]
else:
rope_theta = 1e4
rope_config["rotary_theta"] = config["rope_theta"]
if "rotary_dim" in config.keys():
rotary_dim = config["rotary_dim"]
rope_config["rotary_dim"] = config["rotary_dim"]
elif "partial_rotary_factor" in config.keys():
rotary_dim = int(config["partial_rotary_factor"] * (hidden_size // heads))
else:
rotary_dim = 0
rope_config["rotary_dim"] = int(
config["partial_rotary_factor"] * (hidden_size // heads)
)
if config.get("rope_scaling", None) is not None:
rope_config["scaling_type"] = config["rope_scaling"].get("rope_type", None)
rope_config["scaling_factor"] = config["rope_scaling"].get("factor", 8.0)
rope_config["low_freq_factor"] = config["rope_scaling"].get(
"low_freq_factor", 1.0
)
rope_config["high_freq_factor"] = config["rope_scaling"].get(
"high_freq_factor", 4.0
)
rope_config["original_max_position_embeddings"] = config[
"rope_scaling"
].get("original_max_position_embeddings", 8192)
if "sliding_window" in config.keys():
sliding_window = config["sliding_window"]
if sliding_window is None:
@@ -546,8 +557,8 @@ def run(cls, args):

add_qkvbias = False
add_ffnbias = False
rotary_interleave = False
shared_layer_norm = False
rope_config["rotary_interleave"] = False
position_encoding = {
"position_encoding_type": "Rotary",
"n_positions": 0,
@@ -560,7 +571,7 @@ def run(cls, args):
shared_layer_norm = True
add_qkvbias = True
add_ffnbias = True
rotary_interleave = False
rope_config["rotary_interleave"] = False
if arch == "GPT2LMHeadModel":
parallel_residual = False
shared_layer_norm = True
@@ -1008,9 +1019,7 @@ def get_weight(checkpoint, tensor_name):
layer_norm=layer_norm,
norm_eps=norm_eps,
mlp_activation_fn=mlp_activation_fn,
rotary_interleave=rotary_interleave,
rotary_theta=rope_theta,
rotary_dim=rotary_dim,
rope_config=rope_config,
sliding_window=sliding_window,
heads_kv=heads_kv,
head_dim=head_dim,
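
For reference, here is a minimal sketch of the rope_config dict the extraction logic above yields for a Llama-3.1-style checkpoint. The config.json fragment is illustrative and not taken from this diff.

# Illustrative only: a trimmed config.json fragment in the style of a
# Llama-3.1 checkpoint (values are examples, not part of this diff).
hf_config = {
    "rope_theta": 500000.0,
    "rope_scaling": {
        "rope_type": "llama3",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
}

# Mirrors the converter logic above: collect everything RoPE-related into a
# single dict, later passed to the model config as rope_config=rope_config.
rope_config = {"rotary_theta": hf_config["rope_theta"]}
if hf_config.get("rope_scaling") is not None:
    scaling = hf_config["rope_scaling"]
    rope_config["scaling_type"] = scaling.get("rope_type")
    rope_config["scaling_factor"] = scaling.get("factor", 8.0)
    rope_config["low_freq_factor"] = scaling.get("low_freq_factor", 1.0)
    rope_config["high_freq_factor"] = scaling.get("high_freq_factor", 4.0)
    rope_config["original_max_position_embeddings"] = scaling.get(
        "original_max_position_embeddings", 8192
    )
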
66 changes: 51 additions & 15 deletions eole/config/models.py
@@ -157,6 +157,47 @@ class MeanEncoderConfig(EncoderConfig):
encoder_type: Literal["mean"] = Field(default="mean")


class RotaryPositionConfig(Config):
"""
Configuration for rotary position embeddings used in transformer models.
"""

rotary_interleave: bool = Field(
default=True,
description="Interleave the head dimensions when rotary embeddings are applied. "
"Otherwise the head dimensions are sliced in half. "
"(True=default Llama from Meta (original), "
"False= used by all HuggingFace models)",
)
rotary_theta: int = Field(
default=10000,
description="Rotary theta base length, 1e4 for Llama2.Mistral, 1e6 for Mixtral",
)
rotary_dim: int = Field(
default=0,
description="Rotary dim when model requires it to be different to head dim.",
)
scaling_type: str | None = Field(
default=None,
description="Specifies the type of RoPE scaling to be applied, if any.",
)
scaling_factor: float | None = Field(
default=8.0, description="Factor by which to scale RoPE embeddings."
)
low_freq_factor: float | None = Field(
default=1.0,
description="Scaling factor applied to the lower frequency components of RoPE.",
)
high_freq_factor: float | None = Field(
default=4.0,
description="Scaling factor applied to the higher frequency components of RoPE.",
)
original_max_position_embeddings: int | None = Field(
default=8192,
description="Original maximum position embeddings for RoPE scaling.",
)


class TransformerConfig(Config):
"""
    This base TransformerConfig class regroups parameters that can
@@ -183,21 +224,6 @@ class TransformerConfig(Config):
default=ActivationFunction.relu,
description="The activation function to use in MLP layer.",
)
rotary_interleave: bool = Field(
default=True,
description="Interleave the head dimensions when rotary embeddings are applied. "
"Otherwise the head dimensions are sliced in half. "
"(True=default Llama from Meta (original), "
"False= used by all HuggingFace models)",
)
rotary_theta: int = Field(
default=10000,
description="Rotary theta base length, 1e4 for Llama2.Mistral, 1e6 for Mixtral",
)
rotary_dim: int = Field(
default=0,
description="Rotary dim when model requires it to be different to head dim.",
)
layer_norm: Literal["standard", "rms"] = Field(
default="standard",
description="Type of layer normalization in transformer architecture.",
@@ -249,6 +275,16 @@ class TransformerConfig(Config):
"Case 2: Max Relative Positions"
"In the case of position_encoding_type: Relative",
)
rope_config: RotaryPositionConfig | None = Field(
default=None, description="Rotary position config, if relevant."
)

@model_validator(mode="after")
def _validate_transformer_config(self):
if self.position_encoding_type == PositionEncodingType.Rotary:
if self.rope_config is None:
self.rope_config = RotaryPositionConfig()
return self

@computed_field
@property
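
The scaling fields above map onto the llama-3.1 ("llama3") rope_type. The RotaryPosition module itself (eole/modules/rope.py) is not among the hunks shown here, so the following is only a sketch of how this style of scaling is commonly applied to the rotary inverse frequencies, driven by exactly these parameters; the helper name is hypothetical.

import math
import torch

def llama3_scale_inv_freq(
    inv_freq,                            # (dim/2,) rotary inverse frequencies
    scaling_factor=8.0,
    low_freq_factor=1.0,
    high_freq_factor=4.0,
    original_max_position_embeddings=8192,
):
    # Wavelength thresholds derived from the original training context length.
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    wavelen = 2 * math.pi / inv_freq
    # Smooth interpolation factor for the band between the two thresholds.
    smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    # Low-frequency components are divided by the scaling factor, the middle
    # band is interpolated, and high-frequency components are left untouched.
    scaled = torch.where(
        wavelen > low_freq_wavelen,
        inv_freq / scaling_factor,
        (1 - smooth) * inv_freq / scaling_factor + smooth * inv_freq,
    )
    return torch.where(wavelen < high_freq_wavelen, inv_freq, scaled)
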
19 changes: 18 additions & 1 deletion eole/decoders/transformer_decoder.py
@@ -10,7 +10,8 @@
TransformerDecoderBase,
)
from eole.modules.multi_headed_attn import ContextMHA
from eole.constants import LayerNorm
from eole.constants import LayerNorm, PositionEncodingType
from eole.modules.rope import RotaryPosition


class TransformerDecoderLayer(TransformerDecoderLayerBase):
@@ -57,6 +58,7 @@ def _forward(
step=None,
future=False,
return_attn=False,
position_embeddings=None,
):
"""A naive forward pass for transformer decoder.
@@ -70,6 +72,7 @@ def _forward(
step (int or None): stepwise decoding counter
future (bool): If set True, do not apply future_mask.
return_attn (bool) : if set True requires attns output
position_embeddings (FloatTensor): rotary position encodings, if any
Returns:
(FloatTensor, FloatTensor):
@@ -98,6 +101,7 @@ def _forward(
sliding_window=self.sliding_window,
step=step,
return_attn=return_attn,
position_embeddings=position_embeddings,
)

if self.dropout_p > 0:
@@ -149,6 +153,9 @@ def __init__(
model_config, running_config=running_config
)

if model_config.position_encoding_type == PositionEncodingType.Rotary:
self.rope = RotaryPosition(model_config)

self.transformer_layers = nn.ModuleList(
[
TransformerDecoderLayer(
@@ -191,6 +198,15 @@ def forward(self, emb, **kwargs):
{"keys": torch.tensor([]), "values": torch.tensor([])},
)

if hasattr(self, "rope"):
position_embeddings = self.rope(
emb,
step=step,
device=emb.device,
)
else:
position_embeddings = None

with_align = kwargs.pop("with_align", False)
return_attn = with_align or kwargs.pop("return_attn", False)

@@ -205,6 +221,7 @@ def forward(self, emb, **kwargs):
step=step,
with_align=with_align,
return_attn=return_attn,
position_embeddings=position_embeddings,
)
if attn_align is not None:
attn_aligns.append(attn_align)
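
Note the structural change here: RotaryPosition is now owned by the decoder, evaluated once per forward pass, and the resulting position_embeddings object is threaded down to every layer's self-attention, rather than each attention module maintaining its own rotary buffers. What self.rope(emb, step=step, device=emb.device) actually returns is defined in eole/modules/rope.py, which is not shown in this diff; a hypothetical sketch of cos/sin tables covering the cached prefix plus the current step could look like this.

import torch

def rotary_tables(emb, inv_freq, step=None, device=None):
    # Hypothetical: build cos/sin tables for positions [0, step + seq_len),
    # so cached tokens and the tokens decoded at this step share one table.
    seq_len = emb.size(1)
    max_pos = (step or 0) + seq_len
    positions = torch.arange(max_pos, device=device, dtype=inv_freq.dtype)
    angles = torch.outer(positions, inv_freq.to(device))  # (max_pos, dim/2)
    return torch.cos(angles), torch.sin(angles)
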
33 changes: 27 additions & 6 deletions eole/decoders/transformer_lm_decoder.py
@@ -10,7 +10,8 @@
TransformerDecoderLayerBase,
TransformerDecoderBase,
)
from eole.constants import LayerNorm
from eole.constants import LayerNorm, PositionEncodingType
from eole.modules.rope import RotaryPosition


class TransformerLMDecoderLayer(TransformerDecoderLayerBase):
@@ -19,7 +20,15 @@ class TransformerLMDecoderLayer(TransformerDecoderLayerBase):
See TransformerDecoderLayerBase
"""

def _forward(self, layer_in, pad_mask, step=None, future=False, return_attn=False):
def _forward(
self,
layer_in,
pad_mask,
step=None,
future=False,
return_attn=False,
position_embeddings=None,
):
"""A naive forward pass for transformer decoder.
# T: could be 1 in the case of stepwise decoding or tgt_len
@@ -31,6 +40,7 @@ def _forward(self, layer_in, pad_mask, step=None, future=False, return_attn=Fals
step (int or None): stepwise decoding counter
future (bool): If set True, do not apply future_mask.
return_attn (bool): If set True return attn
position_embeddings (FloatTensor): rotary position encodings, if any
Returns:
(FloatTensor, FloatTensor):
@@ -59,6 +69,7 @@ def _forward(self, layer_in, pad_mask, step=None, future=False, return_attn=Fals
sliding_window=self.sliding_window,
step=step,
return_attn=return_attn,
position_embeddings=position_embeddings,
)
if self.dropout_p > 0:
attn_output = self.dropout(attn_output)
@@ -91,6 +102,9 @@
):
super(TransformerLMDecoder, self).__init__(model_config)

if model_config.position_encoding_type == PositionEncodingType.Rotary:
self.rope = RotaryPosition(model_config)

self.transformer_layers = nn.ModuleList(
[
TransformerLMDecoderLayer(
@@ -112,6 +126,16 @@ def forward(self, emb, **kwargs):
pad_mask = kwargs.pop("tgt_pad_mask", None)
assert pad_mask is not None, "TransformerLMDecoder requires a pad mask"
step = kwargs.pop("step", None)

if hasattr(self, "rope"):
position_embeddings = self.rope(
emb,
step=step,
device=emb.device,
)
else:
position_embeddings = None

if step == 0:
# decoding mode.
# Initialize KV and key_pad_mask cache.
@@ -140,6 +164,7 @@ def forward(self, emb, **kwargs):
step=step,
with_align=with_align,
return_attn=return_attn,
position_embeddings=position_embeddings,
)

emb = self.layer_norm(emb)
@@ -160,7 +185,3 @@ def _init_cache(self, device, mask):
"key_pad_mask": mask,
},
)
if hasattr(layer.self_attn, "rope"):
layer.self_attn.rope = layer.self_attn.rope.to(device)
layer.self_attn.cos = layer.self_attn.cos.to(device)
layer.self_attn.sin = layer.self_attn.sin.to(device)
20 changes: 16 additions & 4 deletions eole/encoders/transformer.py
@@ -7,7 +7,8 @@
from eole.encoders.encoder import EncoderBase
from eole.modules.multi_headed_attn import SelfMHA
from eole.modules.transformer_mlp import MLP
from eole.constants import LayerNorm
from eole.constants import LayerNorm, PositionEncodingType
from eole.modules.rope import RotaryPosition


class TransformerEncoderLayer(nn.Module):
@@ -45,18 +46,21 @@ def __init__(
running_config=running_config,
)

def forward(self, layer_in, mask):
def forward(self, layer_in, mask, position_embeddings=None):
"""
Args:
layer_in (FloatTensor): ``(batch_size, src_len, model_dim)``
mask (LongTensor): ``(batch_size, 1, src_len)``
position_embeddings (FloatTensor): rotary position encodings, if any
Returns:
(FloatTensor):
* layer_out ``(batch_size, src_len, model_dim)``
"""
norm_layer_in = self.input_layernorm(layer_in)
context, _ = self.self_attn(norm_layer_in, mask=mask)
context, _ = self.self_attn(
norm_layer_in, mask=mask, position_embeddings=position_embeddings
)
if self.dropout_p > 0:
context = self.dropout(context)
if self.parallel_residual:
@@ -98,6 +102,9 @@ def __init__(
):
super(TransformerEncoder, self).__init__()

if model_config.position_encoding_type == PositionEncodingType.Rotary:
self.rope = RotaryPosition(model_config)

self.transformer_layers = nn.ModuleList(
[
TransformerEncoderLayer(
@@ -130,8 +137,13 @@ def forward(self, emb, mask=None):
# 1 to be expanded to number of heads in MHA
        # Run the forward pass of every layer of the transformer.

if hasattr(self, "rope"):
position_embeddings = self.rope(emb, step=0, device=emb.device)
else:
position_embeddings = None

for layer in self.transformer_layers:
enc_out = layer(enc_out, mask)
enc_out = layer(enc_out, mask, position_embeddings=position_embeddings)
enc_out = self.layer_norm(enc_out)
return enc_out, None

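
The rotary_interleave flag carried in rope_config distinguishes two conventions for applying these tables inside attention. The refactored MHA code is not part of the hunks shown, so the snippet below is only an illustration of the difference, assuming cos/sin have already been expanded to the full head dimension in the matching layout.

import torch

def rotate_half(x):
    # Non-interleaved ("sliced in half") convention, used by HuggingFace
    # checkpoints and selected with rotary_interleave=False.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, cos, sin, interleave=False):
    if interleave:
        # Interleaved convention (original Meta Llama): rotate adjacent pairs.
        q1, q2 = q[..., ::2], q[..., 1::2]
        rotated = torch.stack((-q2, q1), dim=-1).flatten(-2)
    else:
        rotated = rotate_half(q)
    return q * cos + rotated * sin

Which branch applies is governed by rope_config.rotary_interleave (True for original Meta Llama weights, False for HuggingFace-style checkpoints, per the field description above).
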
(diffs for the remaining 5 changed files are not shown)
