Skip to content

Commit

Permalink
Update GAN.
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhang Yuan committed May 10, 2018
1 parent d79158f commit cbcfadc
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 23 deletions.
2 changes: 1 addition & 1 deletion tensorflow/TensorGAN/GAN/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

def main():
gan = net.GAN(config)
net.train(gan, util.DataDistribution(), util.GeneratorDistribution(range=8), config)
gan.train(util.DataDistribution(), util.GeneratorDistribution(range=8), config)

if __name__=="__main__":
main()
44 changes: 22 additions & 22 deletions tensorflow/TensorGAN/GAN/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,25 @@ def __init__(self, config):
self.opt_d = optimizer(self.loss_d, self.d_params)
self.opt_g = optimizer(self.loss_g, self.g_params)

def train(model, data, gen, config):

with tf.Session() as session:
tf.local_variables_initializer().run()
tf.global_variables_initializer().run()

for step in range(config.num_steps + 1):
# update discriminator
x = data.sample(config.batch_size)
z = gen.sample(config.batch_size)
loss_d, _ = session.run([model.loss_d, model.opt_d], {model.x: np.reshape(x, (config.batch_size, 1)), model.z: np.reshape(z, (config.batch_size, 1))})

# update generator
z = gen.sample(config.batch_size)
loss_g, _ = session.run([model.loss_g, model.opt_g], {model.z: np.reshape(z, (config.batch_size, 1))})

if step % config.log_every == 0:
log_info = "step " + str(step) + " loss_d: " + str(loss_d) + " loss_g: " + str(loss_g)
print(log_info)

samps = util.samples(model, session, data, gen.range, config.batch_size)
util.plot_distributions(samps, gen.range)
def train(self, data, gen, config):
with tf.Session() as session:
tf.local_variables_initializer().run()
tf.global_variables_initializer().run()

for step in range(config.num_steps + 1):
# update discriminator
x = data.sample(config.batch_size)
z = gen.sample(config.batch_size)
loss_d, _ = session.run([self.loss_d, self.opt_d], {self.x: np.reshape(x, (config.batch_size, 1)), self.z: np.reshape(z, (config.batch_size, 1))})

# update generator
z = gen.sample(config.batch_size)
loss_g, _ = session.run([self.loss_g, self.opt_g], {self.z: np.reshape(z, (config.batch_size, 1))})

if step % config.log_every == 0:
log_info = "step " + str(step) + " loss_d: " + str(loss_d) + " loss_g: " + str(loss_g)
print(log_info)

samps = util.samples(self, session, data, gen.range, config.batch_size)
util.plot_distributions(samps, gen.range)

0 comments on commit cbcfadc

Please sign in to comment.