-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathexp-ae-celeba-mafl-10.py
78 lines (66 loc) · 2.46 KB
/
exp-ae-celeba-mafl-10.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import tensorflow as tf
import os
import sys
from copy import copy
from model.pipeline import Pipeline
from tensorflow.python import debug as tf_debug
if __name__ == "__main__":
    # Experiment: unsupervised keypoint autoencoder on CelebA (MAFL split),
    # 10 keypoints, 80x80 crops from 100x100 images.
    # NOTE(review): the scraped source had lost all indentation (a SyntaxError
    # as pasted); this restores the conventional layout without changing any value.
    num_keypoints = 10
    patch_feature_dim = 8       # feature channels per keypoint patch
    decoding_levels = 5         # number of heatmap pyramid levels in the decoder
    kp_transform_loss = 1e4     # weight on the keypoint equivariance/transform loss

    # Reconstruction weight ramps up 10x at step 100k and 100x at step 200k,
    # letting keypoint-structure losses dominate early training.
    base_recon_weight = 0.1
    recon_weight = Pipeline.ValueScheduler(
        "piecewise_constant",
        [100000, 200000],
        [base_recon_weight, base_recon_weight * 10, base_recon_weight * 100],
    )

    # Learning rate decays 10x at the same milestones as the recon ramp.
    base_learning_rate = 0.001
    learning_rate = Pipeline.ValueScheduler(
        "piecewise_constant",
        [100000, 200000],
        [base_learning_rate, base_learning_rate * 0.1, base_learning_rate * 0.01],
    )

    keypoint_separation_bandwidth = 0.08    # kernel bandwidth for the separation penalty
    keypoint_separation_loss_weight = 20

    # Top-level pipeline options. Latent layout: 2 coords per keypoint plus one
    # patch-feature vector per keypoint and one for the background (hence +1).
    opt = {
        "optimizer": "Adam",
        "data_name": "celeba_mafl_100x100_80x80",
        "recon_name": "gaussian_fixedvar_in_01",
        "encoder_name": "general_80x80",
        "decoder_name": "general_80x80",
        "latent_dim": num_keypoints * 2 + (num_keypoints + 1) * patch_feature_dim,
        "train_color_jittering": True,
        "train_random_mirroring": False,
        "train_batch_size": 16,
        "train_shuffle_capacity": 1000,
        "learning_rate": learning_rate,
        "max_epochs": 2000,
        "weight_decay": 1e-6,
        "test_steps": 5000,     # evaluate every 5k training steps
        "test_limit": 200,      # cap on test samples per evaluation
        "recon_weight": recon_weight,
        # "keep_checkpoint_every_n_hours": 0.1
    }
    opt["encoder_options"] = {
        "keypoint_num": num_keypoints,
        "patch_feature_dim": patch_feature_dim,
        "ae_recon_type": opt["recon_name"],
        "keypoint_concentration_loss_weight": 100.,
        "keypoint_axis_balancing_loss_weight": 200.,
        "keypoint_separation_loss_weight": keypoint_separation_loss_weight,
        "keypoint_separation_bandwidth": keypoint_separation_bandwidth,
        "keypoint_transform_loss_weight": kp_transform_loss,
        "keypoint_decoding_heatmap_levels": decoding_levels,
        # each pyramid level halves resolution by sqrt(2)
        "keypoint_decoding_heatmap_level_base": 0.5**(1/2),
        "image_channels": 3,
    }
    # Decoder mirrors the encoder configuration; copy() so later per-side
    # tweaks cannot alias the encoder dict.
    opt["decoder_options"] = copy(opt["encoder_options"])

    # -------------------------------------
    model_dir = os.path.join("results/celeba_10")
    vp = Pipeline(None, opt, model_dir=model_dir)
    print(vp.opt)
    with vp.graph.as_default():
        sess = vp.create_session()
        # restore=True resumes from the latest checkpoint in model_dir if present
        vp.run_full_train(sess, restore=True)
        vp.run_full_test(sess)