Some MHA and RoPE refactoring, llama-3.1 rope_scaling #91

Merged: 9 commits, Aug 30, 2024
eole/bin/convert/convert_HF.py (19 additions & 12 deletions)
@@ -461,16 +461,27 @@ def run(cls, args):
norm_eps = config["layer_norm_eps"]
else:
norm_eps = 1e-6
rope_config = {}
if "rope_theta" in config.keys():
rope_theta = config["rope_theta"]
else:
rope_theta = 1e4
rope_config["rotary_theta"] = config["rope_theta"]
if "rotary_dim" in config.keys():
rotary_dim = config["rotary_dim"]
rope_config["rotary_dim"] = config["rotary_dim"]
elif "partial_rotary_factor" in config.keys():
rotary_dim = int(config["partial_rotary_factor"] * (hidden_size // heads))
else:
rotary_dim = 0
rope_config["rotary_dim"] = int(
config["partial_rotary_factor"] * (hidden_size // heads)
)
if config.get("rope_scaling", None) is not None:
rope_config["scaling_type"] = config["rope_scaling"].get("rope_type", None)
rope_config["scaling_factor"] = config["rope_scaling"].get("factor", 8.0)
rope_config["low_freq_factor"] = config["rope_scaling"].get(
"low_freq_factor", 1.0
)
rope_config["high_freq_factor"] = config["rope_scaling"].get(
"high_freq_factor", 4.0
)
rope_config["original_max_position_embeddings"] = config[
"rope_scaling"
].get("original_max_position_embeddings", 8192)
if "sliding_window" in config.keys():
sliding_window = config["sliding_window"]
if sliding_window is None:
@@ -541,7 +552,6 @@ def run(cls, args):

add_qkvbias = False
add_ffnbias = False
rotary_interleave = False
shared_layer_norm = False
position_encoding = {
"position_encoding_type": "Rotary",
@@ -555,7 +565,6 @@ def run(cls, args):
shared_layer_norm = True
add_qkvbias = True
add_ffnbias = True
rotary_interleave = False
if arch == "GPT2LMHeadModel":
parallel_residual = False
shared_layer_norm = True
@@ -1003,9 +1012,7 @@ def get_weight(checkpoint, tensor_name):
layer_norm=layer_norm,
norm_eps=norm_eps,
mlp_activation_fn=mlp_activation_fn,
rotary_interleave=rotary_interleave,
rotary_theta=rope_theta,
rotary_dim=rotary_dim,
rope_config=rope_config,
sliding_window=sliding_window,
heads_kv=heads_kv,
parallel_residual=parallel_residual,
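For context, a minimal sketch of the mapping this hunk implements for a Llama-3.1-style checkpoint. The hf_config values below are illustrative of a typical Llama-3.1 config.json and are not part of this diff:

# Illustrative HF config values for a Llama-3.1-style model (assumed, not from this diff).
hf_config = {
    "rope_theta": 500000.0,
    "rope_scaling": {
        "rope_type": "llama3",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
}

# Same mapping as in the hunk above: flat HF keys -> nested eole rope_config dict.
rope_config = {"rotary_theta": hf_config["rope_theta"]}
scaling = hf_config.get("rope_scaling")
if scaling is not None:
    rope_config["scaling_type"] = scaling.get("rope_type")
    rope_config["scaling_factor"] = scaling.get("factor", 8.0)
    rope_config["low_freq_factor"] = scaling.get("low_freq_factor", 1.0)
    rope_config["high_freq_factor"] = scaling.get("high_freq_factor", 4.0)
    rope_config["original_max_position_embeddings"] = scaling.get(
        "original_max_position_embeddings", 8192
    )
# rope_config is then passed through to the model config (rope_config=rope_config).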
eole/config/models.py (51 additions & 15 deletions)
@@ -156,6 +156,47 @@ class MeanEncoderConfig(EncoderConfig):
encoder_type: Literal["mean"] = Field(default="mean")


class RotaryPositionConfig(Config):
"""
Configuration for rotary position embeddings used in transformer models.
"""

rotary_interleave: bool = Field(
default=True,
description="Interleave the head dimensions when rotary embeddings are applied. "
"Otherwise the head dimensions are sliced in half. "
"(True=default Llama from Meta (original), "
"False= used by all HuggingFace models)",
)
rotary_theta: int = Field(
default=10000,
description="Rotary theta base length, 1e4 for Llama2.Mistral, 1e6 for Mixtral",
)
rotary_dim: int = Field(
default=0,
description="Rotary dim when model requires it to be different to head dim.",
)
scaling_type: str | None = Field(
default=None,
description="Specifies the type of RoPE scaling to be applied, if any.",
)
scaling_factor: float | None = Field(
default=8.0, description="Factor by which to scale RoPE embeddings."
)
low_freq_factor: float | None = Field(
default=1.0,
description="Scaling factor applied to the lower frequency components of RoPE.",
)
high_freq_factor: float | None = Field(
default=4.0,
description="Scaling factor applied to the higher frequency components of RoPE.",
)
original_max_position_embeddings: int | None = Field(
default=8192,
description="Original maximum position embeddings for RoPE scaling.",
)


class TransformerConfig(Config):
"""
This base TransformerConfig class regroups parameters than can
@@ -182,21 +223,6 @@ class TransformerConfig(Config):
default=ActivationFunction.relu,
description="The activation function to use in MLP layer.",
)
rotary_interleave: bool = Field(
default=True,
description="Interleave the head dimensions when rotary embeddings are applied. "
"Otherwise the head dimensions are sliced in half. "
"(True=default Llama from Meta (original), "
"False= used by all HuggingFace models)",
)
rotary_theta: int = Field(
default=10000,
description="Rotary theta base length, 1e4 for Llama2.Mistral, 1e6 for Mixtral",
)
rotary_dim: int = Field(
default=0,
description="Rotary dim when model requires it to be different to head dim.",
)
layer_norm: Literal["standard", "rms"] = Field(
default="standard",
description="Type of layer normalization in transformer architecture.",
@@ -244,6 +270,16 @@ class TransformerConfig(Config):
"Case 2: Max Relative Positions"
"In the case of position_encoding_type: Relative",
)
rope_config: RotaryPositionConfig | None = Field(
default=None, description="Rotary position config, if relevant."
)

@model_validator(mode="after")
def _validate_transformer_config(self):
if self.position_encoding_type == PositionEncodingType.Rotary:
if self.rope_config is None:
self.rope_config = RotaryPositionConfig()
return self


# could eole.encoders.TransformerEncoder class inherit from this? (it seems not unfortunately)
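A minimal usage sketch of the new nested rope configuration. Class and field names come from the hunk above, the import path matches the file shown (eole/config/models.py), and the specific values are illustrative of a Llama-3.1-style model:

from eole.config.models import RotaryPositionConfig

rope_config = RotaryPositionConfig(
    rotary_interleave=False,   # HuggingFace-style half slicing
    rotary_theta=500000,
    rotary_dim=0,              # 0 -> use the full head dim
    scaling_type="llama3",
    scaling_factor=8.0,
    low_freq_factor=1.0,
    high_freq_factor=4.0,
    original_max_position_embeddings=8192,
)
# If position_encoding_type is Rotary and rope_config is left unset, the new
# _validate_transformer_config validator fills in RotaryPositionConfig() defaults.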
eole/decoders/transformer_decoder.py (27 additions & 2 deletions)
@@ -10,7 +10,8 @@
TransformerDecoderBase,
)
from eole.modules.multi_headed_attn import ContextMHA
from eole.constants import LayerNorm
from eole.constants import LayerNorm, PositionEncodingType
from eole.modules.rope import RotaryPosition


class TransformerDecoderLayer(TransformerDecoderLayerBase):
@@ -57,6 +58,7 @@ def _forward(
step=None,
future=False,
return_attn=False,
position_embeddings=None,
):
"""A naive forward pass for transformer decoder.

@@ -70,6 +72,7 @@ def _forward(
step (int or None): stepwise decoding counter
future (bool): If set True, do not apply future_mask.
return_attn (bool) : if set True requires attns output
position_embeddings (FloatTensor): rotary position encodings, if any

Returns:
(FloatTensor, FloatTensor):
@@ -98,6 +101,7 @@ def _forward(
sliding_window=self.sliding_window,
step=step,
return_attn=return_attn,
position_embeddings=position_embeddings,
)

if self.dropout_p > 0:
@@ -149,6 +153,9 @@ def __init__(
model_config, running_config=running_config
)

if model_config.position_encoding_type == PositionEncodingType.Rotary:
self.rope = RotaryPosition(model_config)

self.transformer_layers = nn.ModuleList(
[
TransformerDecoderLayer(
@@ -191,12 +198,29 @@ def forward(self, emb, **kwargs):
{"keys": torch.tensor([]), "values": torch.tensor([])},
)

if hasattr(self, "rope"):
position_embeddings = self.rope(
emb,
step=step,
device=emb.device,
offset=32
# TODO: this condition is a bit edgy and should probably be better handled
if (
step != 0
and self.transformer_layers[0].self_attn.layer_cache[0]
and self.transformer_layers[0].self_attn.flash
)
else 0,
)
else:
position_embeddings = None

with_align = kwargs.pop("with_align", False)
return_attn = with_align or kwargs.pop("return_attn", False)

attn_aligns = []

for layer in self.transformer_layers:
for i, layer in enumerate(self.transformer_layers):
emb, attn, attn_align = layer(
emb,
enc_out,
@@ -205,6 +229,7 @@ def forward(self, emb, **kwargs):
step=step,
with_align=with_align,
return_attn=return_attn,
position_embeddings=position_embeddings,
)
if attn_align is not None:
attn_aligns.append(attn_align)
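RotaryPosition (eole/modules/rope.py) is referenced here but not included in this diff. As a rough, self-contained sketch of what such a module can compute from the new rope_config fields, including llama3-style frequency rescaling, something along these lines (an approximation of the published Llama-3.1 scheme with hypothetical helper names, not eole's actual implementation):

import math
import torch

def llama3_scaled_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                           high_freq_factor=4.0, original_max_pos=8192):
    # Keep high-frequency components, divide low-frequency ones by `factor`,
    # and interpolate smoothly in between (approximate Llama-3.1 scheme).
    low_wavelen = original_max_pos / low_freq_factor
    high_wavelen = original_max_pos / high_freq_factor
    wavelen = 2 * math.pi / inv_freq
    smooth = (original_max_pos / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    scaled = torch.where(wavelen > low_wavelen, inv_freq / factor, inv_freq)
    mid = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    return torch.where((wavelen <= low_wavelen) & (wavelen >= high_wavelen), mid, scaled)

def rope_cos_sin(maxseqlen, dim, base=500000.0, step=0, offset=0):
    # cos/sin tables for positions starting at step + offset, one angle per dim pair.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    inv_freq = llama3_scaled_inv_freq(inv_freq)
    positions = torch.arange(step + offset, step + offset + maxseqlen).float()
    angles = torch.outer(positions, inv_freq)
    return torch.cos(angles), torch.sin(angles)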
eole/decoders/transformer_lm_decoder.py (35 additions & 2 deletions)
@@ -10,7 +10,8 @@
TransformerDecoderLayerBase,
TransformerDecoderBase,
)
from eole.constants import LayerNorm
from eole.constants import LayerNorm, PositionEncodingType
from eole.modules.rope import RotaryPosition


class TransformerLMDecoderLayer(TransformerDecoderLayerBase):
Expand All @@ -19,7 +20,15 @@ class TransformerLMDecoderLayer(TransformerDecoderLayerBase):
See TransformerDecoderLayerBase
"""

def _forward(self, layer_in, pad_mask, step=None, future=False, return_attn=False):
def _forward(
self,
layer_in,
pad_mask,
step=None,
future=False,
return_attn=False,
position_embeddings=None,
):
"""A naive forward pass for transformer decoder.

# T: could be 1 in the case of stepwise decoding or tgt_len
@@ -31,6 +40,7 @@ def _forward(self, layer_in, pad_mask, step=None, future=False, return_attn=False):
step (int or None): stepwise decoding counter
future (bool): If set True, do not apply future_mask.
return_attn (bool): If set True return attn
position_embeddings (FloatTensor): rotary position encodings, if any

Returns:
(FloatTensor, FloatTensor):
Expand Down Expand Up @@ -59,6 +69,7 @@ def _forward(self, layer_in, pad_mask, step=None, future=False, return_attn=Fals
sliding_window=self.sliding_window,
step=step,
return_attn=return_attn,
position_embeddings=position_embeddings,
)
if self.dropout_p > 0:
attn_output = self.dropout(attn_output)
@@ -91,6 +102,9 @@ def __init__(
):
super(TransformerLMDecoder, self).__init__(model_config)

if model_config.position_encoding_type == PositionEncodingType.Rotary:
self.rope = RotaryPosition(model_config)

self.transformer_layers = nn.ModuleList(
[
TransformerLMDecoderLayer(
@@ -112,6 +126,24 @@ def forward(self, emb, **kwargs):
pad_mask = kwargs.pop("tgt_pad_mask", None)
assert pad_mask is not None, "TransformerLMDecoder requires a pad mask"
step = kwargs.pop("step", None)

if hasattr(self, "rope"):
position_embeddings = self.rope(
emb,
step=step,
device=emb.device,
offset=32
# TODO: this condition is a bit edgy and should probably be better handled
if (
step != 0
and self.transformer_layers[0].self_attn.layer_cache[0]
and self.transformer_layers[0].self_attn.flash
)
else 0,
)
else:
position_embeddings = None

if step == 0:
# decoding mode.
# Initialize KV and key_pad_mask cache.
@@ -140,6 +172,7 @@ def forward(self, emb, **kwargs):
step=step,
with_align=with_align,
return_attn=return_attn,
position_embeddings=position_embeddings,
)

emb = self.layer_norm(emb)
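The SelfMHA changes that consume position_embeddings are not shown in this excerpt. For reference, a sketch of the two conventions the rotary_interleave flag distinguishes when the cos/sin tables are applied to queries and keys (hypothetical helper names, not eole's API):

import torch

def rotate_half(x):
    # "Sliced" convention (rotary_interleave=False, HuggingFace-style):
    # pair dimension i with i + head_dim / 2.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_sliced(q, cos, sin):
    # q: (batch, heads, seq, head_dim); cos/sin: (seq, head_dim / 2)
    cos = torch.cat((cos, cos), dim=-1)   # expand to full head_dim
    sin = torch.cat((sin, sin), dim=-1)
    return q * cos + rotate_half(q) * sin

def apply_rope_interleaved(q, cos, sin):
    # "Interleaved" convention (rotary_interleave=True, original Meta layout):
    # pair dimension 2i with 2i + 1.
    q1, q2 = q[..., 0::2], q[..., 1::2]
    out = torch.empty_like(q)
    out[..., 0::2] = q1 * cos - q2 * sin
    out[..., 1::2] = q1 * sin + q2 * cos
    return out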
eole/encoders/transformer.py (16 additions & 4 deletions)
@@ -7,7 +7,8 @@
from eole.encoders.encoder import EncoderBase
from eole.modules.multi_headed_attn import SelfMHA
from eole.modules.transformer_mlp import MLP
from eole.constants import LayerNorm
from eole.constants import LayerNorm, PositionEncodingType
from eole.modules.rope import RotaryPosition


class TransformerEncoderLayer(nn.Module):
@@ -45,18 +46,21 @@ def __init__(
running_config=running_config,
)

def forward(self, layer_in, mask):
def forward(self, layer_in, mask, position_embeddings=None):
"""
Args:
layer_in (FloatTensor): ``(batch_size, src_len, model_dim)``
mask (LongTensor): ``(batch_size, 1, src_len)``
position_embeddings (FloatTensor): rotary position encodings, if any

Returns:
(FloatTensor):
* layer_out ``(batch_size, src_len, model_dim)``
"""
norm_layer_in = self.input_layernorm(layer_in)
context, _ = self.self_attn(norm_layer_in, mask=mask)
context, _ = self.self_attn(
norm_layer_in, mask=mask, position_embeddings=position_embeddings
)
if self.dropout_p > 0:
context = self.dropout(context)
if self.parallel_residual:
@@ -98,6 +102,9 @@ def __init__(
):
super(TransformerEncoder, self).__init__()

if model_config.position_encoding_type == PositionEncodingType.Rotary:
self.rope = RotaryPosition(model_config)

self.transformer_layers = nn.ModuleList(
[
TransformerEncoderLayer(
@@ -130,8 +137,13 @@ def forward(self, emb, mask=None):
# 1 to be expanded to number of heads in MHA
# Run the forward pass of every layer of the tranformer.

if hasattr(self, "rope"):
position_embeddings = self.rope(emb, step=0, device=emb.device)
else:
position_embeddings = None

for layer in self.transformer_layers:
enc_out = layer(enc_out, mask)
enc_out = layer(enc_out, mask, position_embeddings=position_embeddings)
enc_out = self.layer_norm(enc_out)
return enc_out, None
