mnist_ebgan_train.py
import sugartensor as tf
import numpy as np
from model import *
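# model.py (star-imported above) provides the generator/discriminator networks
# and the shared hyper parameters referenced below: z_dim, margin, pt_weight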
__author__ = '[email protected]'
# set log level to debug
tf.sg_verbosity(10)
#
# hyper parameters
#
batch_size = 128 # batch size
#
# inputs
#
# MNIST input tensor ( with QueueRunner )
data = tf.sg_data.Mnist(batch_size=batch_size)
# input images
x = data.train.image
# random uniform noise as the generator input
z = tf.random_uniform((batch_size, z_dim))
#
# Computational graph
#
# generator
gen = generator(z)
# add image summary
tf.sg_summary_image(x, name='real')
tf.sg_summary_image(gen, name='fake')
# discriminator
disc_real = discriminator(x)
disc_fake = discriminator(gen)
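# note: the discriminator defined in model.py is an auto-encoder; disc_real and
# disc_fake are its reconstructions of the real and generated images, and the
# reconstruction error (computed below) serves as the EBGAN energy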
#
# pull-away term ( PT ) regularizer
#
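# The pull-away term discourages mode collapse by penalizing pairwise similarity
# between generated samples:
#   PT = 1/(N*(N-1)) * sum_{i != j} (S_i . S_j / (||S_i|| * ||S_j||))^2
# In the EBGAN paper PT is computed on the discriminator's encoder features and
# normalized by both sample norms; this script (as written) applies it to the
# flattened generated images and divides by the squared norm of one sample only.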
sample = gen.sg_flatten()
nom = tf.matmul(sample, tf.transpose(sample, perm=[1, 0]))
denom = tf.reduce_sum(tf.square(sample), reduction_indices=[1], keep_dims=True)
pt = tf.square(nom/denom)
pt -= tf.diag(tf.diag_part(pt))
pt = tf.reduce_sum(pt) / (batch_size * (batch_size - 1))
#
# loss & train ops
#
# mean squared errors
mse_real = tf.reduce_mean(tf.square(disc_real - x), reduction_indices=[1, 2, 3])
mse_fake = tf.reduce_mean(tf.square(disc_fake - gen), reduction_indices=[1, 2, 3])
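# EBGAN objectives, treating the reconstruction MSE as an energy D(.):
#   discriminator: L_D = D(x) + max(0, margin - D(G(z)))
#   generator:     L_G = D(G(z)) + pt_weight * PT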
# discriminator loss
loss_disc = mse_real + tf.maximum(margin - mse_fake, 0)
# generator loss + PT regularizer
loss_gen = mse_fake + pt * pt_weight
train_disc = tf.sg_optim(loss_disc, lr=0.001, category='discriminator') # discriminator train ops
train_gen = tf.sg_optim(loss_gen, lr=0.001, category='generator') # generator train ops
# add summary
tf.sg_summary_loss(loss_disc, name='disc')
tf.sg_summary_loss(loss_gen, name='gen')
#
# training
#
# define alternating training function
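# sg_train_func wraps the function below into sugartensor's training loop
# (session setup, checkpointing and summary writing are assumed to be handled by
# the wrapper); each step runs one discriminator update followed by one
# generator update and returns the combined loss for logging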
@tf.sg_train_func
def alt_train(sess, opt):
    l_disc = sess.run([loss_disc, train_disc])[0]  # training discriminator
    l_gen = sess.run([loss_gen, train_gen])[0]  # training generator
    return np.mean(l_disc) + np.mean(l_gen)
# do training
alt_train(log_interval=10, max_ep=30, ep_size=data.train.num_batch)