"""tf2pytorch.py

Convert a TensorFlow StyleGAN2 checkpoint (Gs network, resolutions up to
512x512) into a state dict compatible with the rosinality-style PyTorch
StyleGAN2 generator. Optionally compares the converted weights against a
reference PyTorch checkpoint.
"""
def build_mapping():
    """Return the TF-variable-name -> PyTorch-state-dict-name table.

    Built programmatically from the repeating per-resolution layer pattern
    of StyleGAN2 (config with 9 mapping layers, 4x4 base, 8..512 synthesis
    blocks, 15 noise buffers). Insertion order matches the original
    hand-written table.

    Note: the PyTorch-side blur/upsample kernels (convs.N.conv.blur.kernel,
    to_rgbs.N.upsample.kernel) have no TF counterpart and are deliberately
    absent.
    """
    m = {}
    # Mapping network: Dense0..Dense8 -> style.1..style.9.
    for i in range(9):
        for p in ("weight", "bias"):
            m["G_mapping/Dense%d/%s" % (i, p)] = "style.%d.%s" % (i + 1, p)
    # Constant 4x4 input tensor.
    m["G_synthesis/4x4/Const/const"] = "input.input"
    conv_parts = (
        ("weight", "conv.weight"),
        ("mod_weight", "conv.modulation.weight"),
        ("mod_bias", "conv.modulation.bias"),
        ("noise_strength", "noise.weight"),
        ("bias", "activate.bias"),
    )
    rgb_parts = (
        ("bias", "bias"),
        ("weight", "conv.weight"),
        ("mod_weight", "conv.modulation.weight"),
        ("mod_bias", "conv.modulation.bias"),
    )
    # 4x4 base block: single conv plus its ToRGB head.
    for tf_p, pt_p in conv_parts:
        m["G_synthesis/4x4/Conv/%s" % tf_p] = "conv1.%s" % pt_p
    for tf_p, pt_p in rgb_parts:
        m["G_synthesis/4x4/ToRGB/%s" % tf_p] = "to_rgb1.%s" % pt_p
    resolutions = (8, 16, 32, 64, 128, 256, 512)
    # Each upsampling block contributes two convs: Conv0_up then Conv1.
    for i, res in enumerate(resolutions):
        for j, layer in enumerate(("Conv0_up", "Conv1")):
            idx = 2 * i + j
            for tf_p, pt_p in conv_parts:
                key = "G_synthesis/%dx%d/%s/%s" % (res, res, layer, tf_p)
                m[key] = "convs.%d.%s" % (idx, pt_p)
    # Per-resolution ToRGB heads.
    for i, res in enumerate(resolutions):
        for tf_p, pt_p in rgb_parts:
            key = "G_synthesis/%dx%d/ToRGB/%s" % (res, res, tf_p)
            m[key] = "to_rgbs.%d.%s" % (i, pt_p)
    # Per-layer noise buffers.
    for i in range(15):
        m["G_synthesis/noise%d" % i] = "noises.noise_%d" % i
    m["lod"] = "lod"
    m["dlatent_avg"] = "dlatent_avg"
    return m


# Translation table between TF checkpoint variables (Gs.vars keys) and the
# PyTorch generator state-dict keys; consumed by main() via plain lookups.
mapping = build_mapping()
import pickle
import dnnlib
import dnnlib.tflib as tflib
import torch
import numpy as np
import re
import argparse
def main(tf_cp, output_path, compare = False, torch_cp = None):
    """Convert a TF StyleGAN2 checkpoint into a dict of PyTorch tensors.

    Args:
        tf_cp: path to the TensorFlow checkpoint pickle (unpickles to the
            G, D, Gs network triple).
        output_path: path where the converted {name: tensor} dict is pickled.
        compare: if True, load `torch_cp` and print the summed mean absolute
            difference against its 'g_ema' weights.
        torch_cp: reference PyTorch checkpoint; only read when `compare` is True.
    """
    tflib.init_tf()
    with open(tf_cp, 'rb') as f:
        # NOTE(review): unpickling a checkpoint executes arbitrary code —
        # only run this on checkpoints from a trusted source.
        G, D, Gs = pickle.load(f, encoding='latin1')
    if compare:
        data2 = torch.load(torch_cp)
    output = {}
    for k,v in Gs.vars.items():
        data = v.eval()  # materialize the TF variable as a numpy array
        if len(data.shape) == 0:
            # Scalars (e.g. noise_strength, lod) become 1-element tensors.
            output[mapping[k]] = torch.tensor([data])
        else:
            # Order of this elif chain matters: the modulation.* checks must
            # run before the generic conv.weight check, since modulation
            # weights would otherwise also match r'.*conv.weight'.
            # NOTE(review): the '.' in these patterns is unescaped (matches
            # any char); harmless for the current key set, but fragile.
            if re.match(r'.*modulation.bias',mapping[k]):
                # Shift by 1: the PyTorch modulation layer expects a bias
                # centered around 1 where TF stores it around 0.
                data = data + 1
            elif re.match(r'.*modulation.weight',mapping[k]):
                # Dense weight: transpose TF (in, out) -> torch (out, in).
                data = np.transpose(data)
            elif re.match(r'.*conv.weight',mapping[k]):
                if len(data.shape) == 4:
                    if mapping[k] in ['convs.0.conv.weight','convs.2.conv.weight','convs.4.conv.weight','convs.6.conv.weight','convs.8.conv.weight','convs.10.conv.weight','convs.12.conv.weight']:
                        # Upsampling (Conv0_up) weights: reorder axes to
                        # (out, in, kh, kw), then reverse both spatial axes —
                        # presumably the transposed-conv kernel mirror; the
                        # [2,1,0] index implies a 3x3 kernel. TODO confirm.
                        data = np.transpose(data, (3, 2, 0, 1))[:,:,[2,1,0],:][:,:,:,[2,1,0]]
                        data = np.expand_dims(data, 0)
                    else:
                        # Plain convs: same axis reorder, no spatial flip;
                        # leading singleton dim added for the modulated conv.
                        data = np.expand_dims(np.transpose(data, (-1, -2, 0, 1)),0)
            elif re.match(r'to_rgb.*bias',mapping[k]):
                # ToRGB bias: reshape (C,) -> (1, C, 1, 1) for broadcasting.
                data = np.expand_dims(np.expand_dims(np.expand_dims(data, 0), -1), -1)
            elif re.match(r'style.*weight', mapping[k]):
                # Mapping-network dense weight: TF (in, out) -> torch (out, in).
                data = np.transpose(data)
            output[mapping[k]] = torch.tensor(data)
    if compare:
        # Accumulate the mean absolute difference over all keys present in
        # the reference checkpoint's 'g_ema' dict; 0 means a perfect match.
        diff = 0
        for k,v in output.items():
            if k in data2['g_ema']:
                diff += np.mean(np.abs(data2['g_ema'][k].numpy() - v.numpy()))
        print(diff)
    with open(output_path, 'wb') as f:
        pickle.dump(output, f)
def parse_cli_args(argv=None):
    """Parse command-line arguments for the conversion script.

    Args:
        argv: optional argument list (for testing); None means sys.argv[1:].

    Returns:
        argparse.Namespace with tf_cp, compare, torch_cp, output_path.
    """
    parser = argparse.ArgumentParser(
        description='Convert a TF StyleGAN2 checkpoint to PyTorch.')
    parser.add_argument('--tf_cp', type=str,
                        default='stylegan2-car-config-f.pkl',
                        help='input TensorFlow checkpoint pickle')
    # Bug fix: the original used type=bool, but argparse passes the raw
    # string through bool(), so even `--compare False` was truthy
    # (bool("False") is True). A store_true flag is the correct form;
    # the default (False) is unchanged.
    parser.add_argument('--compare', action='store_true', default=False,
                        help='compare converted weights against --torch_cp')
    parser.add_argument('--torch_cp', type=str,
                        default='stylegan2_networks_stylegan2-car-config-f.pt',
                        help='reference PyTorch checkpoint used by --compare')
    parser.add_argument('--output_path', type=str,
                        default='stylegan2_pytorch.pkl',
                        help='where to write the converted weight pickle')
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_cli_args()
    main(args.tf_cp, args.output_path, args.compare, args.torch_cp)