fix convert_HF
francoishernandez committed Aug 28, 2024
1 parent 4eb4853 commit 2604c7c
Showing 1 changed file with 14 additions and 28 deletions.
42 changes: 14 additions & 28 deletions eole/bin/convert/convert_HF.py
@@ -461,28 +461,27 @@ def run(cls, args):
             norm_eps = config["layer_norm_eps"]
         else:
             norm_eps = 1e-6
+        rope_config = {}
         if "rope_theta" in config.keys():
-            rope_theta = config["rope_theta"]
-        else:
-            rope_theta = 1e4
+            rope_config["rotary_theta"] = config["rope_theta"]
         if "rotary_dim" in config.keys():
-            rotary_dim = config["rotary_dim"]
+            rope_config["rotary_dim"] = config["rotary_dim"]
         elif "partial_rotary_factor" in config.keys():
-            rotary_dim = int(config["partial_rotary_factor"] * (hidden_size // heads))
-        else:
-            rotary_dim = 0
-        if "rope_scaling" in config.keys():
-            rope_scaling_type = config["rope_scaling"].get("rope_type", None)
-            rope_scaling_factor = config["rope_scaling"].get("factor", 8.0)
-            rope_scaling_low_freq_factor = config["rope_scaling"].get(
+            rope_config["rotary_dim"] = int(
+                config["partial_rotary_factor"] * (hidden_size // heads)
+            )
+        if config.get("rope_scaling", None) is not None:
+            rope_config["scaling_type"] = config["rope_scaling"].get("rope_type", None)
+            rope_config["scaling_factor"] = config["rope_scaling"].get("factor", 8.0)
+            rope_config["low_freq_factor"] = config["rope_scaling"].get(
                 "low_freq_factor", 1.0
             )
-            rope_scaling_high_freq_factor = config["rope_scaling"].get(
+            rope_config["high_freq_factor"] = config["rope_scaling"].get(
                 "high_freq_factor", 4.0
             )
-            rope_scaling_original_max_position_embeddings = config["rope_scaling"].get(
-                "original_max_position_embeddings", 8192
-            )
+            rope_config["original_max_position_embeddings"] = config[
+                "rope_scaling"
+            ].get("original_max_position_embeddings", 8192)
         if "sliding_window" in config.keys():
             sliding_window = config["sliding_window"]
             if sliding_window is None:
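
The hunk above collapses the scattered rope_theta / rotary_dim / rope_scaling_* locals into a single rope_config dict built in one pass over the Hugging Face config. A minimal, self-contained sketch of that extraction logic, runnable outside the converter (the helper name build_rope_config and the sample values are illustrative, not part of the commit):

def build_rope_config(config: dict, hidden_size: int, heads: int) -> dict:
    """Collect RoPE settings from a Hugging Face config dict into one
    flat dict, mirroring the converter's new extraction logic."""
    rope_config = {}
    if "rope_theta" in config:
        rope_config["rotary_theta"] = config["rope_theta"]
    if "rotary_dim" in config:
        rope_config["rotary_dim"] = config["rotary_dim"]
    elif "partial_rotary_factor" in config:
        # Partial RoPE: rotate only a fraction of each head's dimensions.
        rope_config["rotary_dim"] = int(
            config["partial_rotary_factor"] * (hidden_size // heads)
        )
    if config.get("rope_scaling", None) is not None:
        scaling = config["rope_scaling"]
        rope_config["scaling_type"] = scaling.get("rope_type", None)
        rope_config["scaling_factor"] = scaling.get("factor", 8.0)
        rope_config["low_freq_factor"] = scaling.get("low_freq_factor", 1.0)
        rope_config["high_freq_factor"] = scaling.get("high_freq_factor", 4.0)
        rope_config["original_max_position_embeddings"] = scaling.get(
            "original_max_position_embeddings", 8192
        )
    return rope_config

# Llama-3.1-style example values (illustrative):
hf_config = {
    "rope_theta": 500000.0,
    "rope_scaling": {
        "rope_type": "llama3",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
}
print(build_rope_config(hf_config, hidden_size=4096, heads=32))
# -> {'rotary_theta': 500000.0, 'scaling_type': 'llama3', 'scaling_factor': 8.0, ...}
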
@@ -984,19 +983,6 @@ def get_weight(checkpoint, tensor_name):
             for tok in vocab_dict["src"]:
                 vocabfile.write(tok + "\n")
 
-        rope_config = None
-        if position_encoding["position_encoding_type"] == "Rotary":
-            rope_config = {
-                "rotary_theta": rope_theta,
-                "rotary_dim": rotary_dim,
-                "rotary_interleave": rotary_interleave,
-                "scaling_type": rope_scaling_type,
-                "scaling_factor": rope_scaling_factor,
-                "low_freq_factor": rope_scaling_low_freq_factor,
-                "high_freq_factor": rope_scaling_high_freq_factor,
-                "original_max_position_embeddings": rope_scaling_original_max_position_embeddings,
-            }
-
         config = TrainConfig(
             data=None,
             skip_empty_level="silent",  # default is "warning"
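
The second hunk deletes the late, duplicate assembly of rope_config: since the dict is now populated near the top of run(), keys absent from the HF config simply stay unset instead of being pinned to the old hand-coded fallbacks (rope_theta = 1e4, rotary_dim = 0), and downstream defaults apply. A hedged sketch of that pattern with a stand-in dataclass (RotaryDefaults is hypothetical, not eole's actual config class; the rotary_interleave default is illustrative):

from dataclasses import dataclass

@dataclass
class RotaryDefaults:
    # Stand-in for the downstream rotary config; field names follow the
    # keys the converter now emits, defaults mirror the old fallbacks.
    rotary_theta: float = 1e4
    rotary_dim: int = 0
    rotary_interleave: bool = True  # illustrative default

# Only keys actually found in the HF config are passed through; anything
# unset falls back to the dataclass defaults instead of converter-side code.
partial = {"rotary_theta": 500000.0}
print(RotaryDefaults(**partial))
# -> RotaryDefaults(rotary_theta=500000.0, rotary_dim=0, rotary_interleave=True)
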
