-
Notifications
You must be signed in to change notification settings - Fork 510
/
Copy pathconvert_weights.py
289 lines (239 loc) · 10 KB
/
convert_weights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import re
from typing import Any, Dict
import torch
# State dict key mappings from Meta's format to torchtune's format.
# "{}" stands for a numeric layer index: get_mapped_key replaces the index in
# an incoming key with "{}", looks the result up here, and substitutes the
# index back into the mapped value. tune_to_meta builds the reverse lookup by
# inverting this dict, so every value must be unique for the round trip to be
# lossless.
_FROM_META = {
    "tok_embeddings.weight": "tok_embeddings.weight",
    # Norm parameters are called "weight" by Meta and "scale" by torchtune.
    "norm.weight": "norm.scale",
    "output.weight": "output.weight",
    # Attention projections: Meta's wq/wk/wv/wo -> torchtune's q/k/v/output_proj.
    "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
    "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
    "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
    "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
    "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
    # Feed-forward weights keep their w1/w2/w3 names.
    "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
    "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
    "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
}
# State dict key mappings from HF's format to torchtune's format.
# "{}" stands for a numeric layer index, filled in by get_mapped_key.
# A value of None marks an HF key that has no torchtune counterpart:
# hf_to_tune skips keys containing "rotary_emb.inv_freq" before lookup, and
# tune_to_peft_adapter_weights filters out None values when deriving its table.
_FROM_HF = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight",
    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight",
    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight",
    "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
    # Note the MLP naming: gate_proj -> w1, up_proj -> w3, down_proj -> w2.
    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight",
    "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight",
    "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight",
    "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale",
    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale",
    "model.norm.weight": "norm.scale",
    "lm_head.weight": "output.weight",
}
def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str:
    """
    Translate a single state-dict key using a key-mapping table.

    Keys containing ``"layers"`` are looked up in abstract form: every
    ``.<digits>`` path component (e.g. the ``.3`` in ``layers.3.attn``) is
    replaced with ``.{}``. The abstract key is looked up in ``mapping_dict``,
    and the captured numbers are substituted, in order, into the ``{}``
    placeholders of the mapped key. All other keys are looked up verbatim.

    Args:
        key (str): State-dict key to translate.
        mapping_dict (Dict[str, str]): Mapping from source-format keys (with
            ``{}`` in place of numeric indices) to target-format keys.

    Returns:
        str: The translated key.

    Raises:
        Exception: If ``key`` (or its abstract form) is not in ``mapping_dict``.
    """
    try:
        if "layers" in key:
            # Replace every ".<int>" component with ".{}" to build the lookup
            # key, and collect the same numbers (in order) to re-insert them.
            # Both regexes match exactly the same spans, so the captured
            # indices line up with the "{}" placeholders.
            abstract_key = re.sub(r"(\.\d+)", ".{}", key)
            layer_nums = re.findall(r"\.(\d+)", key)
            new_key = mapping_dict[abstract_key]
            # Pass every captured index. A mapping with a single "{}" still
            # receives the layer index (extra positional args are ignored),
            # while multi-index keys (e.g. per-layer, per-expert) are now
            # filled correctly instead of raising IndexError.
            new_key = new_key.format(*layer_nums)
        else:
            new_key = mapping_dict[key]
    except KeyError as e:
        raise Exception(
            f'Error converting the state dict. Found unexpected key: "{key}". '
            "Please make sure you're loading a checkpoint with the right format. "
        ) from e
    return new_key
def meta_to_tune(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Translate a Meta-format state dict into torchtune's naming scheme.

    If the checkpoint is split across several files, merge the per-file state
    dicts into one before calling this function. An example of a Meta-format
    checkpoint is the ``meta-llama/Llama-2-7b`` repo on HF
    (https://huggingface.co/meta-llama/Llama-2-7b).

    Only keys are renamed; tensors are passed through unchanged. The
    ``rope.freqs`` entry is dropped.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in Meta's format.

    Returns:
        Dict[str, torch.Tensor]: State dict in torchtune's format.

    Raises:
        Exception: If a key has no entry in the Meta -> torchtune mapping.
    """
    # Skip loading the position embeddings.
    skipped_keys = ("rope.freqs",)
    return {
        get_mapped_key(meta_key, _FROM_META): tensor
        for meta_key, tensor in state_dict.items()
        if meta_key not in skipped_keys
    }
def tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Translate a torchtune-format state dict back into Meta's naming scheme.

    Only keys are renamed; tensors are passed through unchanged. No sharding
    or splitting is performed: one state dict goes in and one comes out.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.

    Raises:
        Exception: If a key has no entry in the torchtune -> Meta mapping.
    """
    # Reverse the Meta -> torchtune table so torchtune names can be looked up.
    meta_name_for = {tune_name: meta_name for meta_name, tune_name in _FROM_META.items()}
    return {
        get_mapped_key(tune_name, meta_name_for): tensor
        for tune_name, tensor in state_dict.items()
    }
def hf_to_tune(
    state_dict: Dict[str, torch.Tensor],
    num_heads: int = 32,
    num_kv_heads: int = 32,
    dim: int = 4096,
    head_dim: int = None,
) -> Dict[str, torch.Tensor]:
    """
    Translate a HF-format state dict into torchtune's format.

    If the checkpoint is split across several files, merge the per-file state
    dicts into one before calling this function. An example of a HF-format
    checkpoint is the ``meta-llama/Llama-2-7b-hf`` repo on HF
    (https://huggingface.co/meta-llama/Llama-2-7b-hf).

    Besides renaming keys, the rows of every ``q_proj`` and ``k_proj`` weight
    are reordered within each head: each head's two contiguous halves of rows
    become interleaved row pairs. Keys containing ``rotary_emb.inv_freq`` are
    dropped.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in HF's format.
        num_heads (int): Number of attention heads (used for ``q_proj``).
        num_kv_heads (int): Number of heads in the key/value projection layers
            (used for ``k_proj``).
        dim (int): Dimension of the model; also the column count the q/k
            weights are viewed with.
        head_dim (int): Dimension of each head. Defaults to ``dim // num_heads``.

    Returns:
        Dict[str, torch.Tensor]: State dict in torchtune's format.
    """
    if head_dim is None:
        head_dim = dim // num_heads

    def _from_hf_layout(weight, heads):
        # View each head's rows as [2, head_dim // 2] and transpose to
        # [head_dim // 2, 2], so the two halves end up interleaved.
        per_head = weight.view(heads, 2, head_dim // 2, dim)
        interleaved = per_head.transpose(1, 2)
        return interleaved.reshape((head_dim * heads), dim)

    converted_state_dict = {}
    for hf_key, tensor in state_dict.items():
        # Skip loading the position embeddings.
        if "rotary_emb.inv_freq" in hf_key:
            continue
        tune_key = get_mapped_key(hf_key, _FROM_HF)
        if "q_proj" in hf_key:
            heads = num_heads
        elif "k_proj" in hf_key:
            heads = num_kv_heads
        else:
            heads = None
        if heads is not None:
            tensor = _from_hf_layout(tensor, heads)
        converted_state_dict[tune_key] = tensor
    return converted_state_dict
def tune_to_hf(
    state_dict: Dict[str, torch.Tensor],
    num_heads: int = 32,
    num_kv_heads: int = 32,
    dim: int = 4096,
    head_dim: int = None,
) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from torchtune's format to HF's format.

    This function doesn't handle any sharding or splitting of state dicts; it
    follows the state_dict IN -> state_dict OUT pattern.

    Besides renaming keys, the rows of every ``q_proj`` and ``k_proj`` weight
    are reordered within each head: interleaved row pairs are regrouped into
    two contiguous halves. This is the inverse of the reordering applied by
    ``hf_to_tune``.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.
        num_heads (int): Number of attention heads (used for ``q_proj``).
        num_kv_heads (int): Number of heads in the key/value projection layers
            (used for ``k_proj``).
        dim (int): Dimension of the model; also the column count the q/k
            weights are viewed with.
        head_dim (int): Dimension of each head. Defaults to ``dim // num_heads``.

    Returns:
        Dict[str, torch.Tensor]: State dict in HF's format.

    Raises:
        Exception: If a key has no entry in the torchtune -> HF mapping.
    """
    converted_state_dict = {}
    # Reverse the HF -> torchtune table. The None value in _FROM_HF becomes a
    # None key here, which never matches a string state-dict key.
    inverted_mapping_dict = {v: k for k, v in _FROM_HF.items()}
    if head_dim is None:
        head_dim = dim // num_heads

    def _permute(t, n_heads):
        # View each head's rows as [head_dim // 2, 2] and transpose to
        # [2, head_dim // 2]: interleaved pairs become two contiguous halves.
        return (
            t.view(n_heads, head_dim // 2, 2, dim)
            .transpose(1, 2)
            .reshape((head_dim * n_heads), dim)
        )

    for key, value in state_dict.items():
        new_key = get_mapped_key(key, inverted_mapping_dict)
        if "q_proj" in key:
            value = _permute(value, num_heads)
        elif "k_proj" in key:
            value = _permute(value, num_kv_heads)
        converted_state_dict[new_key] = value
    return converted_state_dict
# Mapping from torchtune LoRA module names to PEFT LoRA module names
# (only the capitalization of the A/B suffix differs).
_TO_PEFT_KEYS = {
    "lora_a": "lora_A",
    "lora_b": "lora_B",
}
# Mapping from torchtune module names to the target-module names written into
# the PEFT adapter config. Attention output and MLP modules take HF's names;
# q/k/v projections keep the same name.
_TO_PEFT_TARGET_MODULES = {
    "q_proj": "q_proj",
    "k_proj": "k_proj",
    "v_proj": "v_proj",
    "output_proj": "o_proj",
    "w1": "gate_proj",
    "w2": "down_proj",
    "w3": "up_proj",
    "output": "lm_head",
}
# Keys expected in PEFT's adapter_config.json; tune_to_peft_adapter_config
# raises ValueError if any of them is missing from the input config.
_PEFT_CONFIG_EXPECTED_KEYS = ["target_modules", "r", "lora_alpha"]
def tune_to_peft_adapter_config(
    adapter_config: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Convert a torchtune LoRA adapter config into the form PEFT expects.

    Validates that all keys in ``_PEFT_CONFIG_EXPECTED_KEYS`` are present, then
    renames every entry of ``target_modules`` from torchtune's module name to
    PEFT's (e.g. ``output_proj`` -> ``o_proj``, ``w1`` -> ``gate_proj``).

    Note: ``adapter_config`` is modified in place (its ``target_modules`` entry
    is replaced), and the same dict object is returned.

    Args:
        adapter_config (Dict[str, Any]): Adapter config in torchtune's format.

    Returns:
        Dict[str, Any]: The same dict, with ``target_modules`` in PEFT naming.

    Raises:
        ValueError: If a required key is missing or a target module has no
            PEFT equivalent.
    """
    # Generator (not a list) so all() can short-circuit on the first miss.
    if not all(x in adapter_config for x in _PEFT_CONFIG_EXPECTED_KEYS):
        raise ValueError(
            f"PEFT adapter config requires {_PEFT_CONFIG_EXPECTED_KEYS}, found {adapter_config.keys()}"
        )
    # Validate everything before mutating, so a bad config is left untouched.
    for k in adapter_config["target_modules"]:
        if k not in _TO_PEFT_TARGET_MODULES:
            raise ValueError(f"Unknown target module {k}")
    adapter_config["target_modules"] = [
        _TO_PEFT_TARGET_MODULES[k] for k in adapter_config["target_modules"]
    ]
    return adapter_config
def tune_to_peft_adapter_weights(
    state_dict: Dict[str, torch.Tensor],
    num_heads: int = 32,
    num_kv_heads: int = 32,
    dim: int = 4096,
    head_dim: int = None,
) -> Dict[str, torch.Tensor]:
    """
    Convert torchtune LoRA adapter weights into PEFT's key format.

    Each torchtune adapter key, e.g. ``layers.0.attn.q_proj.lora_a.weight``, is
    renamed to its HF counterpart with PEFT's LoRA suffix
    (``model.layers.0.self_attn.q_proj.lora_A.weight``) and prefixed with
    ``base_model.model.``. The ``lora_B`` matrices of ``q_proj`` and ``k_proj``
    get the same per-head row reordering that ``tune_to_hf`` applies to full
    q/k weights. All other tensors, including every ``lora_A`` matrix, are
    passed through unchanged.

    Args:
        state_dict (Dict[str, torch.Tensor]): torchtune adapter state dict.
            Every key must be a LoRA (``lora_a`` / ``lora_b``) key.
        num_heads (int): Number of attention heads (used for ``q_proj``).
        num_kv_heads (int): Number of heads in the key/value projection layers
            (used for ``k_proj``).
        dim (int): Dimension of the model; only used to derive ``head_dim``.
        head_dim (int): Dimension of each head. Defaults to ``dim // num_heads``.

    Returns:
        Dict[str, torch.Tensor]: Adapter state dict with PEFT keys.

    Raises:
        Exception: If a key has no PEFT counterpart (e.g. a base-model weight).
    """
    converted_state_dict = {}
    # Rather than maintain a separate table for adapter weights, derive one from
    # the _FROM_HF base-weight table: one entry per (LoRA suffix, base weight)
    # pair, all lora_a entries first, then all lora_b entries. Entries mapped
    # to None (no torchtune counterpart) are skipped.
    full_mapping = {
        tune_key.replace(".weight", f".{tune_lora}.weight"): hf_key.replace(
            ".weight", f".{peft_lora}.weight"
        )
        for tune_lora, peft_lora in _TO_PEFT_KEYS.items()
        for hf_key, tune_key in _FROM_HF.items()
        if tune_key is not None
    }
    if head_dim is None:
        head_dim = dim // num_heads

    def _permute_lora_matrix(t, n_heads):
        # Expects a 2-D lora_B matrix of n_heads * head_dim rows and `rank`
        # columns; rows are regrouped per head exactly as tune_to_hf does.
        rank = t.shape[-1]
        return (
            t.view(n_heads, head_dim // 2, 2, rank)
            .transpose(1, 2)
            .reshape((head_dim * n_heads), rank)
        )

    for key, value in state_dict.items():
        new_key = get_mapped_key(key, full_mapping)
        # Only lora_B's output rows line up with the q/k head layout;
        # lora_A's rows span the rank dimension and need no reordering.
        if "lora_B" in new_key:
            if "q_proj" in new_key:
                value = _permute_lora_matrix(value, num_heads)
            elif "k_proj" in new_key:
                value = _permute_lora_matrix(value, num_kv_heads)
        converted_state_dict["base_model.model." + new_key] = value
    return converted_state_dict