diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py index 60c9b7514..c8c123076 100644 --- a/python/ctranslate2/converters/transformers.py +++ b/python/ctranslate2/converters/transformers.py @@ -8,10 +8,11 @@ import numpy as np +import transformers + try: import huggingface_hub import torch - import transformers except ImportError: pass @@ -1422,7 +1423,7 @@ def set_decoder(self, spec, module): @register_loader("Gemma2Config") -class GemmaLoader(ModelLoader): +class Gemma2Loader(ModelLoader): @property def architecture_name(self): return "Gemma2ForCausalLM" @@ -1494,9 +1495,7 @@ def set_decoder(self, spec, module): self.set_layer_norm(spec.layer_norm, module.norm) for layer_spec, layer in zip(spec.layer, module.layers): - self.set_layer_norm( - layer_spec.input_layer_norm, layer.input_layernorm - ) + self.set_layer_norm(layer_spec.input_layer_norm, layer.input_layernorm) self.set_layer_norm( layer_spec.post_attention_layer_norm, layer.post_attention_layernorm diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py index bb258e24a..230e62cfd 100644 --- a/python/ctranslate2/specs/transformer_spec.py +++ b/python/ctranslate2/specs/transformer_spec.py @@ -325,9 +325,15 @@ def __init__( if pre_post_layer_norm: self.input_layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm) - self.post_attention_layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm) - self.pre_feedforward_layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm) - self.post_feedforward_layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm) + self.post_attention_layer_norm = common_spec.LayerNormSpec( + rms_norm=rms_norm + ) + self.pre_feedforward_layer_norm = common_spec.LayerNormSpec( + rms_norm=rms_norm + ) + self.post_feedforward_layer_norm = common_spec.LayerNormSpec( + rms_norm=rms_norm + ) delattr(self.self_attention, "layer_norm") delattr(self.ffn, "layer_norm") diff --git a/third_party/googletest b/third_party/googletest index b7d472f12..f8d7d77c0 160000 --- a/third_party/googletest +++ b/third_party/googletest @@ -1 +1 @@ -Subproject commit b7d472f1225c5a64943821d8483fecb469d3f382 +Subproject commit f8d7d77c06936315286eb55f8de22cd23c188571