support minimum gemma 2 (#1772)
* support minimum gemma 2

* fix ci

* fix ci
minhthuc2502 authored Sep 6, 2024
1 parent 6647945 commit f89fa2b
Showing 6 changed files with 169 additions and 2 deletions.
2 changes: 2 additions & 0 deletions include/ctranslate2/layers/transformer.h
@@ -119,6 +119,8 @@ namespace ctranslate2 {
const std::unique_ptr<const LayerNorm> _shared_layer_norm;
const std::unique_ptr<const LayerNorm> _input_layer_norm;
const std::unique_ptr<const LayerNorm> _post_attention_layer_norm;
const std::unique_ptr<const LayerNorm> _pre_feedforward_layer_norm;
const std::unique_ptr<const LayerNorm> _post_feedforward_layer_norm;
const std::unique_ptr<const AttentionLayer> _encoder_attention;
const FeedForwardNetwork _ff;
};
104 changes: 104 additions & 0 deletions python/ctranslate2/converters/transformers.py
@@ -1421,6 +1421,110 @@ def set_decoder(self, spec, module):
gc.collect()


@register_loader("Gemma2Config")
class Gemma2Loader(ModelLoader):
@property
def architecture_name(self):
return "Gemma2ForCausalLM"

def get_model_spec(self, model):
num_layers = model.config.num_hidden_layers

num_heads = model.config.num_attention_heads
num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
if num_heads_kv == num_heads:
num_heads_kv = None

activation_config = getattr(
model.config, "hidden_activation", "gelu_pytorch_tanh"
)

spec = transformer_spec.TransformerDecoderModelSpec.from_config(
num_layers,
num_heads,
activation=(
common_spec.Activation.GELU
if activation_config == "gelu"
else common_spec.Activation.GELUTanh
),
pre_norm=True,
ffn_glu=True,
rms_norm=True,
rotary_dim=0,
rotary_interleave=False,
rotary_base=getattr(model.config, "rope_theta", 10000),
num_heads_kv=num_heads_kv,
head_dim=model.config.head_dim,
pre_post_layer_norm=True,
)

self.set_decoder(spec.decoder, model.model)
self.set_linear(spec.decoder.projection, model.lm_head)
spec.decoder.embeddings.multiply_by_sqrt_depth = model.config.hidden_size**0.5
return spec

def get_vocabulary(self, model, tokenizer):
tokens = super().get_vocabulary(model, tokenizer)

extra_ids = model.config.vocab_size - len(tokens)
for i in range(extra_ids):
tokens.append("<extra_id_%d>" % i)
if model.config.vocab_size < len(tokens):
tokens = tokens[: model.config.vocab_size]

return tokens

def set_vocabulary(self, spec, tokens):
spec.register_vocabulary(tokens)

def set_config(self, config, model, tokenizer):
config.bos_token = tokenizer.bos_token
config.eos_token = tokenizer.eos_token
config.unk_token = tokenizer.unk_token
config.layer_norm_epsilon = model.config.rms_norm_eps

def set_layer_norm(self, spec, layer_norm):
spec.gamma = layer_norm.weight
spec.layer_norm_use_residual = True

def set_decoder(self, spec, module):
spec.scale_embeddings = True
spec.start_from_zero_embedding = False
self.set_embeddings(spec.embeddings, module.embed_tokens)
self.set_layer_norm(spec.layer_norm, module.norm)

for layer_spec, layer in zip(spec.layer, module.layers):
self.set_layer_norm(layer_spec.input_layer_norm, layer.input_layernorm)

self.set_layer_norm(
layer_spec.post_attention_layer_norm, layer.post_attention_layernorm
)

self.set_layer_norm(
layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
)

self.set_layer_norm(
layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
)

wq = layer.self_attn.q_proj.weight
wk = layer.self_attn.k_proj.weight
wv = layer.self_attn.v_proj.weight
wo = layer.self_attn.o_proj.weight

layer_spec.self_attention.linear[0].weight = torch.cat([wq, wk, wv])
layer_spec.self_attention.linear[1].weight = wo

self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)

delattr(layer, "self_attn")
delattr(layer, "mlp")
gc.collect()


@register_loader("LlamaConfig")
class LlamaLoader(ModelLoader):
@property
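With the Gemma2Loader above registered for "Gemma2Config", a Gemma 2 checkpoint goes through the usual Transformers conversion path. A minimal usage sketch follows; the model name, output directory, and quantization value are illustrative and not part of this commit.

from ctranslate2.converters import TransformersConverter

# The converter selects Gemma2Loader from the checkpoint's Gemma2Config.
converter = TransformersConverter("google/gemma-2-9b")
converter.convert("gemma-2-9b-ct2", quantization="int8")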
22 changes: 22 additions & 0 deletions python/ctranslate2/specs/transformer_spec.py
@@ -101,6 +101,7 @@ def __init__(
max_position_embeddings: int = 0,
parallel_residual: bool = False,
shared_layer_norm: bool = False,
pre_post_layer_norm: bool = False,
multi_query_attention: bool = False,
num_heads_kv: Optional[int] = None,
head_dim: Optional[int] = None,
@@ -147,6 +148,7 @@ def __init__(
by the GPT-J and GPT-NeoX models.
shared_layer_norm: When using parallel residual, share the input and post
attention layer norms.
pre_post_layer_norm: Add a post layer norm for each pre-norm layer.
multi_query_attention: Use multi-query attention (alias for num_heads_kv=1).
num_heads_kv: Number of attention heads for the key and value.
sliding_window: Max sequence length to retain in KV Cache.
@@ -216,6 +218,7 @@ def __init__(
max_position_embeddings=max_position_embeddings,
parallel_residual=parallel_residual,
shared_layer_norm=shared_layer_norm,
pre_post_layer_norm=pre_post_layer_norm,
num_heads_kv=num_heads_kv,
head_dim=head_dim,
sliding_window=sliding_window,
@@ -279,6 +282,7 @@ def __init__(
max_position_embeddings=0,
parallel_residual=False,
shared_layer_norm=False,
pre_post_layer_norm=False,
num_heads_kv=None,
head_dim=None,
sliding_window=None,
@@ -319,6 +323,21 @@ def __init__(
delattr(self.self_attention, "layer_norm")
delattr(self.ffn, "layer_norm")

if pre_post_layer_norm:
self.input_layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
self.post_attention_layer_norm = common_spec.LayerNormSpec(
rms_norm=rms_norm
)
self.pre_feedforward_layer_norm = common_spec.LayerNormSpec(
rms_norm=rms_norm
)
self.post_feedforward_layer_norm = common_spec.LayerNormSpec(
rms_norm=rms_norm
)

delattr(self.self_attention, "layer_norm")
delattr(self.ffn, "layer_norm")


class FeedForwardSpec(model_spec.LayerSpec):
def __init__(self, glu=False, rms_norm=False):
@@ -530,6 +549,7 @@ def from_config(
max_position_embeddings: int = 0,
parallel_residual: bool = False,
shared_layer_norm: bool = False,
pre_post_layer_norm: bool = False,
multi_query_attention: bool = False,
num_heads_kv: Optional[int] = None,
head_dim: Optional[int] = None,
@@ -570,6 +590,7 @@ def from_config(
by the GPT-J and GPT-NeoX models.
shared_layer_norm: When using parallel residual, share the input and post
attention layer norms.
pre_post_layer_norm: Add a post layer norm for each pre-norm layer.
multi_query_attention: Use multi-query attention (alias for num_heads_kv=1).
num_heads_kv: Number of attention heads for the key and value.
head_dim: Number of dimensions per head.
@@ -602,6 +623,7 @@ def from_config(
max_position_embeddings=max_position_embeddings,
parallel_residual=parallel_residual,
shared_layer_norm=shared_layer_norm,
pre_post_layer_norm=pre_post_layer_norm,
multi_query_attention=multi_query_attention,
num_heads_kv=num_heads_kv,
head_dim=head_dim,
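For reference, the sketch below builds a decoder spec directly through from_config() with the new pre_post_layer_norm flag; the dimensions are small illustrative values, not a real Gemma 2 configuration.

from ctranslate2.specs import common_spec, transformer_spec

# pre_post_layer_norm=True adds the pre/post feed-forward layer norms per layer
# and removes the per-sublayer norms, as in the __init__ changes above.
spec = transformer_spec.TransformerDecoderModelSpec.from_config(
    num_layers=2,
    num_heads=4,
    activation=common_spec.Activation.GELUTanh,
    pre_norm=True,
    ffn_glu=True,
    rms_norm=True,
    rotary_dim=0,
    rotary_interleave=False,
    num_heads_kv=2,
    head_dim=64,
    pre_post_layer_norm=True,
)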
2 changes: 1 addition & 1 deletion src/layers/attention.cc
@@ -363,7 +363,7 @@ namespace ctranslate2 {
if (queries_padder)
queries_padder->add_padding(fused_proj);

- const ops::Split split_op(2, {_d_model, _num_heads_kv * _d_head, _num_heads_kv * _d_head});
+ const ops::Split split_op(2, {_num_heads * _d_head, _num_heads_kv * _d_head, _num_heads_kv * _d_head});
split_op(fused_proj, queries_proj, keys_proj, values_proj);

if (_merge_time_and_head_dims) {
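This one-line change matters because Gemma 2 uses a head dimension that is not equal to hidden_size / num_heads, so the query slice of the fused QKV projection must be num_heads * head_dim rather than _d_model. The numbers below are Gemma 2 9B-like values used only for illustration, not read from this commit.

# Fused QKV split widths, assuming hidden_size=3584, num_heads=16,
# head_dim=256, num_heads_kv=8.
d_model, num_heads, head_dim, num_heads_kv = 3584, 16, 256, 8
old_split = [d_model, num_heads_kv * head_dim, num_heads_kv * head_dim]               # [3584, 2048, 2048]: query slice too narrow
new_split = [num_heads * head_dim, num_heads_kv * head_dim, num_heads_kv * head_dim]  # [4096, 2048, 2048]: matches q_proj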
39 changes: 39 additions & 0 deletions src/layers/transformer.cc
@@ -120,6 +120,10 @@ namespace ctranslate2 {
, _input_layer_norm(build_optional_layer<LayerNorm>(model, scope + "/input_layer_norm"))
, _post_attention_layer_norm(build_optional_layer<LayerNorm>(
model, scope + "/post_attention_layer_norm"))
, _pre_feedforward_layer_norm(build_optional_layer<LayerNorm>(
model, scope + "/pre_feedforward_layer_norm"))
, _post_feedforward_layer_norm(build_optional_layer<LayerNorm>(
model, scope + "/post_feedforward_layer_norm"))
, _encoder_attention(build_optional_layer<MultiHeadAttention>(model,
scope + "/attention",
num_heads,
@@ -149,6 +153,41 @@ namespace ctranslate2 {
const DataType dtype = input.dtype();
const Device device = input.device();

const bool pre_post_layer_norm = _post_feedforward_layer_norm && _pre_feedforward_layer_norm;
if (pre_post_layer_norm) {
StorageView hidden(dtype, device);
StorageView context(dtype, device);
(*_input_layer_norm)(input, hidden);

if (_self_attention)
(*_self_attention)(hidden,
hidden,
input_length,
context,
cached_self_attn_keys,
cached_self_attn_values,
nullptr,
input_padder,
input_padder,
true,
position_bias,
offset);

(*_post_attention_layer_norm)(context, output);
ops::Add()(output, input, output);

context = std::move(output);
(*_pre_feedforward_layer_norm)(context, output);
hidden = std::move(output);

_ff(hidden, output);

hidden = std::move(output);
(*_post_feedforward_layer_norm)(hidden, output);
ops::Add()(output, context, output);
return;
}

const bool use_parallel_residual = _shared_layer_norm || _input_layer_norm;

if (use_parallel_residual) {
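The new branch wraps each sub-layer in both a pre-norm and a post-norm before the residual add, which is the layout Gemma 2 expects. The Python pseudo-implementation below mirrors the C++ control flow above; the function and argument names are invented for illustration.

def decoder_layer_pre_post_norm(x, input_norm, self_attention, post_attention_norm,
                                pre_ffn_norm, ffn, post_ffn_norm):
    # Attention block: pre-norm -> attention -> post-norm -> residual add.
    hidden = input_norm(x)
    context = self_attention(hidden)
    x = x + post_attention_norm(context)
    # Feed-forward block: pre-norm -> FFN -> post-norm -> residual add.
    hidden = pre_ffn_norm(x)
    hidden = ffn(hidden)
    return x + post_ffn_norm(hidden)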
2 changes: 1 addition & 1 deletion third_party/googletest
Submodule googletest updated 245 files
