-
Notifications
You must be signed in to change notification settings - Fork 136
/
Copy pathto_pb.py
67 lines (46 loc) · 2.03 KB
/
to_pb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import pylib as py
import tensorflow as tf
import tflib as tl
import module
from tensorflow.python.framework import graph_util
# ==============================================================================
# =                                   param                                    =
# ==============================================================================

# Command-line options: only the experiment name is read from the CLI; the
# remaining settings come from the YAML file stored with that experiment.
py.arg('--experiment_name', default='default')
args_ = py.args()

# All artifacts of the experiment live under output/<experiment_name>.
output_dir = py.join('output', args_.experiment_name)

# Load the stored settings, then overlay the command-line values on top so
# anything passed on the CLI takes precedence over the YAML file.
args = py.args_from_yaml(py.join(output_dir, 'settings.yml'))
vars(args).update(vars(args_))

# Number of attribute columns the model expects (one per entry of att_names).
n_atts = len(args.att_names)

# Create the session and enter it so it acts as the default session for the
# rest of the script (it is closed explicitly in the freeze section).
sess = tl.session()
sess.__enter__()
# ==============================================================================
# =                                   graph                                    =
# ==============================================================================

def sample_graph():
    """Build the inference graph that the freeze step exports.

    Creates two float32 placeholders and wires them through the generator
    parts returned by ``module.get_model``:

    * ``xa`` -- shape ``[None, args.crop_size, args.crop_size, 3]``
      (presumably a batch of 3-channel input images; confirm against the
      training data pipeline);
    * ``b_`` -- shape ``[None, n_atts]``, one column per entry of
      ``args.att_names``.

    ``xa`` goes through ``Genc``; the result goes through ``Gdec`` together
    with ``b_``. Both calls use ``training=False``. The output is wrapped
    in ``tf.identity`` named ``'xb'`` so the freeze step can reference it by
    that node name.

    Returns:
        The output tensor named ``'xb'``.
    """
    # model (the third value returned by get_model is not used here)
    Genc, Gdec, _ = module.get_model(args.model, n_atts, weight_decay=args.weight_decay)

    # placeholders & inputs; the explicit names are the graph's input node
    # names in the exported .pb
    xa = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, 3], name='xa')
    b_ = tf.placeholder(tf.float32, shape=[None, n_atts], name='b_')

    # sample graph, built in inference mode (training=False)
    x = Gdec(Genc(xa, training=False), b_, training=False)
    x = tf.identity(x, name='xb')
    return x


# Holds the 'xb' output tensor (previously always None).
sample = sample_graph()
# ==============================================================================
# =                                   freeze                                   =
# ==============================================================================

try:
    # checkpoint: track every global variable by name and restore its
    # values from <output_dir>/checkpoints via tl.Checkpoint
    checkpoint = tl.Checkpoint(
        {v.name: v for v in tf.global_variables()},
        py.join(output_dir, 'checkpoints'),
        max_to_keep=1
    )
    checkpoint.restore().run_restore_ops()

    # Replace the variables reachable from the 'xb' output node with
    # constants holding their current (restored) values. Done before opening
    # the output file so the file is only touched once there is data to write.
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['xb'])
    with tf.gfile.GFile(py.join(output_dir, 'generator.pb'), 'wb') as f:
        f.write(constant_graph.SerializeToString())
finally:
    # Release the session even if restoring or freezing raised.
    sess.close()