-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconvert_old_checkpoint_into_new.py
55 lines (51 loc) · 2.25 KB
/
convert_old_checkpoint_into_new.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import sys
import torch
def convert_timm_swim_weight_to_hf_swin_weight(timm_weight):
embeddings_weight_mapping = {
'encoder.patch_embed.proj.weight':'encoder.embeddings.patch_embeddings.projection.weight',
'encoder.patch_embed.proj.bias' :'encoder.embeddings.patch_embeddings.projection.bias',
'encoder.patch_embed.norm.weight':'encoder.embeddings.norm.weight',
'encoder.patch_embed.norm.bias' :'encoder.embeddings.norm.bias',
}
state_dict = {}
for key, val in timm_weight.items():
if not key.startswith('encoder.'):
state_dict[key]= val
continue
if key.startswith('encoder.patch_embed'):
new_key = embeddings_weight_mapping[key]
state_dict[new_key]= val
continue
if key.startswith('encoder.norm.'):
new_key = key.replace('norm.','layernorm.')
state_dict[new_key]= val
continue
key = 'encoder.'+key
if 'qkv' in key:
q_weight, k_weight, v_weight = val.chunk(3)
new_key = key.replace('attn.qkv.','attention.self.query.')
state_dict[new_key]=q_weight
new_key = key.replace('attn.qkv.','attention.self.key.')
state_dict[new_key]=k_weight
new_key = key.replace('attn.qkv.','attention.self.value.')
state_dict[new_key]=v_weight
continue
if 'attn_mask' in key:continue
new_key = key.replace('norm1.','layernorm_before.'
).replace('attn.','attention.self.'
).replace('self.proj.','output.dense.'
).replace('mlp.fc1.','intermediate.dense.'
).replace('mlp.fc2.','output.dense.'
).replace('norm2.','layernorm_after.'
)
state_dict[new_key]= val
return state_dict
# CLI: python convert_old_checkpoint_into_new.py <old_checkpoint> <new_checkpoint>
old_path = sys.argv[1]  # e.g. a Lightning checkpoint such as '.../last.ckpt'
new_path = sys.argv[2]

# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine.
checkpoint = torch.load(old_path, map_location='cpu')
# A Lightning .ckpt wraps the weights under 'state_dict'; a plain state
# dict is used as-is (backward compatible with the previous behavior).
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    checkpoint = checkpoint['state_dict']

new_state_dict = {}
for key, val in checkpoint.items():
    # Strip only the *leading* 'model.' prefix; the old replace('model.','')
    # would also have mangled a 'model.' occurring mid-key.
    if key.startswith('model.'):
        key = key[len('model.'):]
    new_state_dict[key] = val

new_state_dict = convert_timm_swim_weight_to_hf_swin_weight(new_state_dict)
torch.save(new_state_dict, new_path)