Skip to content

Commit

Permalink
Update 0514, add CGAN.
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhang Yuan committed May 14, 2018
1 parent f25bbe1 commit 9b29cf5
Show file tree
Hide file tree
Showing 104 changed files with 166 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tensorflow/TensorGAN/cGAN/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Data path
mnist_path = "D:/datasets/MNIST_DATA/"

# Training hyper paramters
batch_size = 64
input_dim = 100
hidden_dim = 128
16 changes: 16 additions & 0 deletions tensorflow/TensorGAN/cGAN/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import tensorflow as tf
import numpy as np

import config
import net

seed = 42
np.random.seed(seed)
tf.set_random_seed(seed)

def main():
cgan = net.CGAN(config)
cgan.train()

if __name__ == "__main__":
main()
123 changes: 123 additions & 0 deletions tensorflow/TensorGAN/cGAN/net.py
Original file line number Diff line number Diff line change
@@ -1 +1,124 @@
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
import matplotlib.pyplot as plt
import numpy as np

import config
import util

mnist = input_data.read_data_sets(config.mnist_path, one_hot=True)
x_dim = mnist.train.images.shape[1]
y_dim = mnist.train.labels.shape[1]

def xavier_init(size):
in_dim = size[0]
xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
return tf.random_normal(shape=size, stddev=xavier_stddev)

def log(x):
'''
Sometimes discriminator outputs can reach values close to (or even slightly less than) zero due to numerical rounding.
This just make sure that we exclude those values so that we don't up with NaNs during optimization
'''
return tf.log(tf.maximum(x, 1e-5))

with tf.variable_scope('D'):
D_W1 = tf.Variable(xavier_init([x_dim + y_dim, config.hidden_dim]))
D_b1 = tf.Variable(tf.zeros(shape=[1]))

D_W2 = tf.Variable(xavier_init([config.hidden_dim, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))

def discriminator(x, y):
# Concatenate x and y
inputs = tf.concat(axis=1, values=[x, y])

D_h1 = tf.nn.relu(tf.matmul(inputs, D_W1) + D_b1)
D_logit = tf.matmul(D_h1, D_W2) + D_b2
D_prob = tf.nn.sigmoid(D_logit)

return D_prob, D_logit

with tf.variable_scope('G'):
G_W1 = tf.Variable(xavier_init([config.input_dim + y_dim, config.hidden_dim]))
G_b1 = tf.Variable(tf.zeros(shape=[config.hidden_dim]))

G_W2 = tf.Variable(xavier_init([config.hidden_dim, x_dim]))
G_b2 = tf.Variable(tf.zeros(shape=[x_dim]))

def generator(z, y):
# COncatenate z and y
inputs = tf.concat(axis=1, values=[z, y])

G_h1 = tf.nn.relu(tf.matmul(inputs, G_W1) + G_b1)
G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
G_prob = tf.nn.sigmoid(G_log_prob)

return G_prob

def sample_Z(m, n):
return np.random.uniform(-1., 1., size=[m, n])

class CGAN(object):
def __init__(self, config):
self.config = config

self.x = tf.placeholder(tf.float32, shape=[None, 784])
self.y = tf.placeholder(tf.float32, shape=[None, y_dim])
self.z = tf.placeholder(tf.float32, shape=[None, config.input_dim])

self.G_sample = generator(self.z, self.y)

self.D_real, self.D_logit_real = discriminator(self.x, self.y)
self.D_fake, self.D_logit_fake = discriminator(self.G_sample, self.y)

self.D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logit_real, labels=tf.ones_like(self.D_logit_real)))
self.D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logit_fake, labels=tf.zeros_like(self.D_logit_fake)))
self.D_loss = self.D_loss_real + self.D_loss_fake
self.G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logit_fake, labels=tf.ones_like(self.D_logit_fake)))

vars = tf.trainable_variables()
self.d_params = [v for v in vars if v.name.startswith('D/')]
self.g_params = [v for v in vars if v.name.startswith('G/')]

self.D_solver = tf.train.AdamOptimizer().minimize(self.D_loss, var_list=self.d_params)
self.G_solver = tf.train.AdamOptimizer().minimize(self.G_loss, var_list=self.g_params)

def train(self):
with tf.Session() as sess:

sess.run(tf.global_variables_initializer())

if not os.path.exists('./out/'):
os.makedirs('./out/')

save_index = 0

for iter in range(100000):

batch_x, batch_y = mnist.train.next_batch(config.batch_size)
batch_z = sample_Z(config.batch_size, config.input_dim)

_, D_loss_cur = sess.run([self.D_solver, self.D_loss], feed_dict={self.x: batch_x, self.z: batch_z, self.y: batch_y})
_, G_loss_cur = sess.run([self.G_solver, self.G_loss], feed_dict={self.z: batch_z, self.y: batch_y})

if iter % 1000 == 0:
print("iter: ", iter)
print("D_loss: ", D_loss_cur)
print("G_loss: ", G_loss_cur)

samples_num = 16

z_sample = sample_Z(samples_num, config.input_dim)
y_sample = np.zeros(shape=[samples_num, y_dim])
y_sample[:, 7] = 1

samples = sess.run(self.G_sample, feed_dict={self.z: z_sample, self.y: y_sample})

fig = util.plot(samples)
save_filename = './out/' + str(save_index).zfill(3)
save_index += 1
plt.savefig(save_filename, bbox_inches='tight')
plt.close()

Binary file added tensorflow/TensorGAN/cGAN/out/000.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/001.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/002.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/003.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/004.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/005.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/006.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/007.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/008.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/009.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/010.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/011.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/012.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/013.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/014.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/015.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/016.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/017.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/018.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/019.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/020.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/021.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/022.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/023.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/024.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tensorflow/TensorGAN/cGAN/out/025.png
Binary file added tensorflow/TensorGAN/cGAN/out/026.png
Binary file added tensorflow/TensorGAN/cGAN/out/027.png
Binary file added tensorflow/TensorGAN/cGAN/out/028.png
Binary file added tensorflow/TensorGAN/cGAN/out/029.png
Binary file added tensorflow/TensorGAN/cGAN/out/030.png
Binary file added tensorflow/TensorGAN/cGAN/out/031.png
Binary file added tensorflow/TensorGAN/cGAN/out/032.png
Binary file added tensorflow/TensorGAN/cGAN/out/033.png
Binary file added tensorflow/TensorGAN/cGAN/out/034.png
Binary file added tensorflow/TensorGAN/cGAN/out/035.png
Binary file added tensorflow/TensorGAN/cGAN/out/036.png
Binary file added tensorflow/TensorGAN/cGAN/out/037.png
Binary file added tensorflow/TensorGAN/cGAN/out/038.png
Binary file added tensorflow/TensorGAN/cGAN/out/039.png
Binary file added tensorflow/TensorGAN/cGAN/out/040.png
Binary file added tensorflow/TensorGAN/cGAN/out/041.png
Binary file added tensorflow/TensorGAN/cGAN/out/042.png
Binary file added tensorflow/TensorGAN/cGAN/out/043.png
Binary file added tensorflow/TensorGAN/cGAN/out/044.png
Binary file added tensorflow/TensorGAN/cGAN/out/045.png
Binary file added tensorflow/TensorGAN/cGAN/out/046.png
Binary file added tensorflow/TensorGAN/cGAN/out/047.png
Binary file added tensorflow/TensorGAN/cGAN/out/048.png
Binary file added tensorflow/TensorGAN/cGAN/out/049.png
Binary file added tensorflow/TensorGAN/cGAN/out/050.png
Binary file added tensorflow/TensorGAN/cGAN/out/051.png
Binary file added tensorflow/TensorGAN/cGAN/out/052.png
Binary file added tensorflow/TensorGAN/cGAN/out/053.png
Binary file added tensorflow/TensorGAN/cGAN/out/054.png
Binary file added tensorflow/TensorGAN/cGAN/out/055.png
Binary file added tensorflow/TensorGAN/cGAN/out/056.png
Binary file added tensorflow/TensorGAN/cGAN/out/057.png
Binary file added tensorflow/TensorGAN/cGAN/out/058.png
Binary file added tensorflow/TensorGAN/cGAN/out/059.png
Binary file added tensorflow/TensorGAN/cGAN/out/060.png
Binary file added tensorflow/TensorGAN/cGAN/out/061.png
Binary file added tensorflow/TensorGAN/cGAN/out/062.png
Binary file added tensorflow/TensorGAN/cGAN/out/063.png
Binary file added tensorflow/TensorGAN/cGAN/out/064.png
Binary file added tensorflow/TensorGAN/cGAN/out/065.png
Binary file added tensorflow/TensorGAN/cGAN/out/066.png
Binary file added tensorflow/TensorGAN/cGAN/out/067.png
Binary file added tensorflow/TensorGAN/cGAN/out/068.png
Binary file added tensorflow/TensorGAN/cGAN/out/069.png
Binary file added tensorflow/TensorGAN/cGAN/out/070.png
Binary file added tensorflow/TensorGAN/cGAN/out/071.png
Binary file added tensorflow/TensorGAN/cGAN/out/072.png
Binary file added tensorflow/TensorGAN/cGAN/out/073.png
Binary file added tensorflow/TensorGAN/cGAN/out/074.png
Binary file added tensorflow/TensorGAN/cGAN/out/075.png
Binary file added tensorflow/TensorGAN/cGAN/out/076.png
Binary file added tensorflow/TensorGAN/cGAN/out/077.png
Binary file added tensorflow/TensorGAN/cGAN/out/078.png
Binary file added tensorflow/TensorGAN/cGAN/out/079.png
Binary file added tensorflow/TensorGAN/cGAN/out/080.png
Binary file added tensorflow/TensorGAN/cGAN/out/081.png
Binary file added tensorflow/TensorGAN/cGAN/out/082.png
Binary file added tensorflow/TensorGAN/cGAN/out/083.png
Binary file added tensorflow/TensorGAN/cGAN/out/084.png
Binary file added tensorflow/TensorGAN/cGAN/out/085.png
Binary file added tensorflow/TensorGAN/cGAN/out/086.png
Binary file added tensorflow/TensorGAN/cGAN/out/087.png
Binary file added tensorflow/TensorGAN/cGAN/out/088.png
Binary file added tensorflow/TensorGAN/cGAN/out/089.png
Binary file added tensorflow/TensorGAN/cGAN/out/090.png
Binary file added tensorflow/TensorGAN/cGAN/out/091.png
Binary file added tensorflow/TensorGAN/cGAN/out/092.png
Binary file added tensorflow/TensorGAN/cGAN/out/093.png
Binary file added tensorflow/TensorGAN/cGAN/out/094.png
Binary file added tensorflow/TensorGAN/cGAN/out/095.png
Binary file added tensorflow/TensorGAN/cGAN/out/096.png
Binary file added tensorflow/TensorGAN/cGAN/out/097.png
Binary file added tensorflow/TensorGAN/cGAN/out/098.png
Binary file added tensorflow/TensorGAN/cGAN/out/099.png
20 changes: 20 additions & 0 deletions tensorflow/TensorGAN/cGAN/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

def plot(samples):

plt.switch_backend('agg')

fig = plt.figure(figsize=(4, 4))
gs = gridspec.GridSpec(4, 4)
gs.update(wspace=0.05, hspace=0.05)

for i, sample in enumerate(samples):
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

return fig

0 comments on commit 9b29cf5

Please sign in to comment.