Commit 8a56e19
Zhang Yuan committed on May 9, 2018
1 parent: b40e5ae
Showing 4 changed files with 188 additions and 0 deletions.
config.py (new file, +7 lines)
```python
# Training hyperparameters

num_steps = 1000
hidden_size = 4
batch_size = 8
minibatch = True
log_every = 10
```
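Since `config` is an ordinary Python module, these hyperparameters are plain module attributes and can be overridden before the model is built. A small, hypothetical sketch (the attribute names are the ones defined above):

```python
import config

config.minibatch = False                     # hypothetical override: train without the minibatch layer
print(config.num_steps, config.batch_size)   # 1000 8
```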
main.py (new file, +20 lines)
```python
import net
import util
import config

import tensorflow as tf
import seaborn as sns
import numpy as np

# Seaborn is only used for plot styling.
sns.set(color_codes=True)

# Fix the NumPy and TensorFlow seeds for reproducibility.
seed = 42
np.random.seed(seed)
tf.set_random_seed(seed)


def main():
    gan = net.GAN(config)
    net.train(gan, util.DataDistribution(), util.GeneratorDistribution(range=8), config)


if __name__ == "__main__":
    main()
```
net.py (new file, +102 lines)
```python
import numpy as np
import tensorflow as tf

import util


def linear(input, output_dim, scope=None, stddev=1.0):
    # Fully connected layer: returns input @ w + b.
    with tf.variable_scope(scope or 'linear'):
        w = tf.get_variable('w', [input.get_shape()[1], output_dim], initializer=tf.random_normal_initializer(stddev=stddev))
        b = tf.get_variable('b', [output_dim], initializer=tf.constant_initializer(0.0))
        return tf.matmul(input, w) + b
```
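`linear` is a plain affine layer, y = xW + b, with normally initialized weights and zero biases. A standalone NumPy sketch of the same computation (the shapes here are illustrative only):

```python
import numpy as np

batch_size, in_dim, out_dim = 8, 1, 4
x = np.random.randn(batch_size, in_dim)      # input batch
w = np.random.randn(in_dim, out_dim)         # weights, stddev=1.0 as in linear()
b = np.zeros(out_dim)                        # biases start at zero

y = x @ w + b                                # same as tf.matmul(input, w) + b
print(y.shape)                               # (8, 4)
```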
```python
def generator(input, h_dim):
    h0 = tf.nn.softplus(linear(input, h_dim, 'g0'))
    h1 = linear(h0, 1, 'g1')
    return h1


def discriminator(input, h_dim, minibatch_layer=True):
    h0 = tf.nn.relu(linear(input, h_dim * 2, 'd0'))
    h1 = tf.nn.relu(linear(h0, h_dim * 2, 'd1'))

    # Without the minibatch layer, the discriminator needs an additional layer
    # to have enough capacity to separate the two distributions correctly.
    if minibatch_layer:
        print("Activate minibatch")
        h2 = minibatch(h1)
    else:
        h2 = tf.nn.relu(linear(h1, h_dim * 2, scope='d2'))

    h3 = tf.sigmoid(linear(h2, 1, scope='d3'))
    return h3
```
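The generator's hidden layer uses softplus, log(1 + e^x): a smooth, everywhere-positive relative of ReLU, a common choice when the generator's output should vary smoothly with its noise input; the discriminator sticks with plain ReLU. A quick standalone comparison:

```python
import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))   # log(1 + e^x), smooth and always > 0

for x in (-2.0, 0.0, 2.0):
    print(x, softplus(x), max(x, 0.0))   # softplus vs. ReLU
# -2.0  0.1269...  0.0
#  0.0  0.6931...  0.0
#  2.0  2.1269...  2.0
```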
```python
def minibatch(input, num_kernels=5, kernel_dim=3):
    x = linear(input, num_kernels * kernel_dim, scope='minibatch', stddev=0.02)
    activation = tf.reshape(x, (-1, num_kernels, kernel_dim))
    diffs = tf.expand_dims(activation, 3) - tf.expand_dims(tf.transpose(activation, [1, 2, 0]), 0)
    abs_diffs = tf.reduce_sum(tf.abs(diffs), 2)
    minibatch_features = tf.reduce_sum(tf.exp(-abs_diffs), 2)
    return tf.concat([input, minibatch_features], 1)
```
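This is minibatch discrimination from Salimans et al. (2016), "Improved Techniques for Training GANs": each sample is augmented with features measuring its similarity to the other samples in the batch, which lets the discriminator detect a generator that collapses onto a single point. A standalone NumPy sketch of the feature computation, with the same shapes as the TensorFlow version:

```python
import numpy as np

batch, num_kernels, kernel_dim = 4, 5, 3
activation = np.random.randn(batch, num_kernels, kernel_dim)

# Pairwise L1 distances between samples, per kernel:
diffs = activation[:, :, :, None] - activation.transpose(1, 2, 0)[None, :, :, :]
abs_diffs = np.abs(diffs).sum(axis=2)       # (batch, num_kernels, batch)

# Similarity to every other sample; near-identical samples give large values.
features = np.exp(-abs_diffs).sum(axis=2)   # (batch, num_kernels)
print(features.shape)                       # (4, 5)
```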
```python
def optimizer(loss, var_list):
    learning_rate = 0.001
    step = tf.Variable(0, trainable=False)
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=step, var_list=var_list)
    return optimizer


def log(x):
    '''
    Sometimes discriminator outputs can reach values close to (or even slightly less than) zero
    due to numerical rounding. This just makes sure that we exclude those values so that we
    don't end up with NaNs during optimization.
    '''
    return tf.log(tf.maximum(x, 1e-5))
```
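A quick standalone illustration of why the clamp matters: log(0) is -inf, and a single -inf in the loss poisons every gradient with NaNs, while clamping at 1e-5 bounds the loss term at about -11.5:

```python
import numpy as np

d_out = np.array([0.0, 1e-7, 0.5])        # possible raw discriminator outputs
print(np.log(d_out))                      # [-inf, -16.11..., -0.69...]
print(np.log(np.maximum(d_out, 1e-5)))    # [-11.51..., -11.51..., -0.69...]
```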
```python
class GAN(object):
    def __init__(self, config):
        # The generator network: it takes samples from a noise distribution
        # as input and passes them through an MLP.
        with tf.variable_scope('G'):
            self.z = tf.placeholder(tf.float32, shape=(config.batch_size, 1))
            self.G = generator(self.z, config.hidden_size)

        # The discriminator tries to tell the difference between samples from
        # the true data distribution (self.x) and the generated samples (self.G).
        # The second scope reuses the first one's variables, so D1 and D2 share weights.
        self.x = tf.placeholder(tf.float32, shape=(config.batch_size, 1))
        with tf.variable_scope('D'):
            self.D1 = discriminator(self.x, config.hidden_size, config.minibatch)

        with tf.variable_scope('D', reuse=True):
            self.D2 = discriminator(self.G, config.hidden_size, config.minibatch)

        # Define the loss for the discriminator and generator networks.
        self.loss_d = tf.reduce_mean(-log(self.D1) - log(1 - self.D2))
        self.loss_g = tf.reduce_mean(-log(self.D2))

        vars = tf.trainable_variables()
        self.d_params = [v for v in vars if v.name.startswith('D/')]
        self.g_params = [v for v in vars if v.name.startswith('G/')]

        self.opt_d = optimizer(self.loss_d, self.d_params)
        self.opt_g = optimizer(self.loss_g, self.g_params)
```
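The two loss terms are the standard GAN objectives, with the non-saturating heuristic for the generator: rather than minimizing log(1 - D(G(z))), which saturates early in training when the discriminator easily rejects generated samples, the generator maximizes log D(G(z)). In formulas:

```latex
\mathcal{L}_D = -\,\mathbb{E}_{x \sim p_{\mathrm{data}}}\big[\log D(x)\big]
               -\,\mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big],
\qquad
\mathcal{L}_G = -\,\mathbb{E}_{z \sim p_z}\big[\log D(G(z))\big]
```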
```python
def train(model, data, gen, config):
    with tf.Session() as session:
        tf.local_variables_initializer().run()
        tf.global_variables_initializer().run()

        for step in range(config.num_steps + 1):
            # update discriminator
            x = data.sample(config.batch_size)
            z = gen.sample(config.batch_size)
            loss_d, _ = session.run([model.loss_d, model.opt_d], {model.x: np.reshape(x, (config.batch_size, 1)), model.z: np.reshape(z, (config.batch_size, 1))})

            # update generator
            z = gen.sample(config.batch_size)
            loss_g, _ = session.run([model.loss_g, model.opt_g], {model.z: np.reshape(z, (config.batch_size, 1))})

            if step % config.log_every == 0:
                log_info = "step " + str(step) + " loss_d: " + str(loss_d) + " loss_g: " + str(loss_g)
                print(log_info)

        samps = util.samples(model, session, data, gen.range, config.batch_size)
        util.plot_distributions(samps, gen.range)
```
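`train` follows the TensorFlow 1.x idiom: the graph is built first (in `GAN.__init__`), then variables are initialized and ops executed inside a `Session`. A minimal standalone sketch of that pattern (TF 1.x API):

```python
import tensorflow as tf

x = tf.Variable(3.0)   # graph construction: nothing runs yet
y = x * 2.0

with tf.Session() as session:
    tf.global_variables_initializer().run()  # assign initial values
    print(session.run(y))                    # 6.0: the graph executes here
```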
util.py (new file, +59 lines)
```python
import numpy as np
import matplotlib.pyplot as plt


class DataDistribution(object):
    def __init__(self):
        self.mu = 2
        self.sigma = 0.5

    def sample(self, N):
        samples = np.random.normal(self.mu, self.sigma, N)
        samples.sort()
        return samples


class GeneratorDistribution(object):
    def __init__(self, range):
        self.range = range

    def sample(self, N):
        return np.linspace(-self.range, self.range, N) + np.random.random(N) * 0.01
```
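Note that the generator's noise is not i.i.d. uniform: each batch is an evenly spaced grid over [-range, range] plus a little jitter. This stratified sampling covers the input range evenly in every batch, a trick commonly used in 1D GAN demos to avoid gaps in coverage. A standalone comparison:

```python
import numpy as np

N, rng = 8, 8.0
iid = np.random.uniform(-rng, rng, N)                             # can cluster and leave gaps
stratified = np.linspace(-rng, rng, N) + np.random.random(N) * 0.01

print(np.sort(iid))   # irregular spacing
print(stratified)     # one jittered point per sub-interval
```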
```python
def samples(model, session, data, sample_range, batch_size, num_points=10000, num_bins=100):
    '''
    Return a tuple (db, pd, pg), where db is the current decision boundary,
    pd is a histogram of samples from the data distribution, and pg is a
    histogram of generated samples.
    '''
    xs = np.linspace(-sample_range, sample_range, num_points)
    bins = np.linspace(-sample_range, sample_range, num_bins)

    # decision boundary
    db = np.zeros((num_points, 1))
    for i in range(num_points // batch_size):
        db[batch_size * i: batch_size * (i + 1)] = session.run(model.D1, {model.x: np.reshape(xs[batch_size * i: batch_size * (i + 1)], (batch_size, 1))})

    # data distribution
    d = data.sample(num_points)
    pd, _ = np.histogram(d, bins=bins, density=True)

    # generated samples
    zs = np.linspace(-sample_range, sample_range, num_points)
    g = np.zeros((num_points, 1))
    for i in range(num_points // batch_size):
        g[batch_size * i: batch_size * (i + 1)] = session.run(model.G, {model.z: np.reshape(zs[batch_size * i: batch_size * (i + 1)], (batch_size, 1))})
    pg, _ = np.histogram(g, bins=bins, density=True)

    return db, pd, pg
```
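One implicit constraint: the loops above fill db and g in whole batches, so num_points should be a multiple of batch_size or the tail stays at its zero initialization. The defaults are safe, but it is worth checking when changing config.batch_size:

```python
num_points, batch_size = 10000, 8
assert num_points % batch_size == 0, "tail of db/g would stay zero"
print(num_points // batch_size)   # 1250 full batches
```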
```python
def plot_distributions(samps, sample_range):
    db, pd, pg = samps
    db_x = np.linspace(-sample_range, sample_range, len(db))
    p_x = np.linspace(-sample_range, sample_range, len(pd))
    f, ax = plt.subplots(1)
    ax.plot(db_x, db, label="decision boundary")
    ax.set_ylim(0, 1)
    plt.plot(p_x, pd, label="real data")
    plt.plot(p_x, pg, label="generated data")
    plt.title('1D Generative Adversarial Network')
    plt.xlabel('Data values')
    plt.ylabel('Probability density')
    plt.legend()
    plt.show()
```