
Commit

Update GAN.
giangtranml committed Apr 5, 2020
1 parent ce439b8 commit fcbd804
Showing 3 changed files with 24 additions and 6 deletions.
gan/gan.py (22 changes: 19 additions & 3 deletions)
@@ -17,6 +17,17 @@ def _backward_last(self):
         pass
 
     def backward(self, y, y_hat, z, discriminator):
+        """
+        The generator does not compute its loss directly, so we need the
+        discriminator to backpropagate the gradient from its layers (in the backward direction) into the generator.
+
+        Parameters
+        ----------
+        y: a vector of ones (labeling fakes as real); we optimize the generator's parameters to fool the discriminator.
+        y_hat: output of passing the generator's samples through the discriminator, D(G(z)).
+        z: random noise variables.
+        discriminator: the discriminator network.
+        """
         dA = discriminator.return_input_grads(y, y_hat)
         grads = self._backward(dA, None)
         self._update_params(grads)
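
For context: this docstring describes the standard non-saturating GAN generator update, in which fake samples are labeled as real (y = ones) and the discriminator's gradient with respect to its own input flows back into the generator. A minimal numpy sketch of that gradient flow, using toy one-layer linear networks rather than this repo's NeuralNetwork classes (shapes and the learning rate are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)

z_dim, x_dim = 8, 4
Wg = rng.normal(size=(z_dim, x_dim)) * 0.1   # toy generator weights
Wd = rng.normal(size=(x_dim, 1)) * 0.1       # toy discriminator weights

sigmoid = lambda t: 1.0 / (1.0 + np.exp(-t))

z = rng.normal(size=(16, z_dim))             # batch of random noise
x_fake = z @ Wg                              # G(z), linear for simplicity
y_hat = sigmoid(x_fake @ Wd)                 # D(G(z))
y = np.ones_like(y_hat)                      # label fakes as real to fool D

# For sigmoid + binary cross-entropy, dL/d(logits) = y_hat - y; pushing it
# one step further gives dL/d(input of D), i.e. dL/d(generator output),
# analogous to what return_input_grads produces.
dA = (y_hat - y) @ Wd.T
dWg = z.T @ dA / len(z)                      # generator gradient via the chain rule
Wg -= 0.01 * dWg                             # update G only; D stays fixed
```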
@@ -27,14 +38,17 @@ def __init__(self, optimizer:object, layers:list, loss_func:object=BinaryCrossEn
         super().__init__(optimizer, layers, loss_func)
 
     def return_input_grads(self, y, y_hat):
+        """
+        Compute the gradient of the loss w.r.t. the inputs, so the gradient can flow back to the generator's parameters.
+        """
         dA_prev, _ = self._backward_last(y, y_hat)
         for i in range(len(self.layers)-1, 0, -1):
             backward_func = self.layers[i].backward_layer if isinstance(self.layers[i], LearnableLayer) else self.layers[i].backward
             dA_prev = backward_func(dA_prev, self.layers[i-1])
         return dA_prev
 
 
-def main():
+def main(digit=2):
 
     mnist_dim = 784
@@ -86,7 +100,7 @@ def main():
     images, labels = mndata.load_training()
     images, labels = preprocess_data(images, labels, test=True)
 
-    images = images[labels == 2]
+    images = images[labels == digit]
 
     optimizer_G = Adam(alpha=0.006)
     optimizer_D = Adam(alpha=0.006)
@@ -98,8 +112,10 @@
 
     batch_size = 64
     iterations = 10000
+
+    print("Training GAN with MNIST dataset to generate digit %d" % digit)
     trainerGAN = TrainerGAN(generator, discriminator, batch_size, iterations)
     trainerGAN.train(images)
 
 if __name__ == "__main__":
-    main()
+    main(digit=8)
libs/utils.py (6 changes: 4 additions & 2 deletions)
@@ -14,7 +14,6 @@
 from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support
 from tqdm import tqdm
 from .cifar10_lib import get_file, load_batch
-import matplotlib.pyplot as plt
 from mpl_toolkits.axes_grid1 import ImageGrid


@@ -83,7 +82,10 @@ def save_grid_images(images, iteration):
         # Iterating over the grid returns the Axes.
         ax.imshow(im, cmap="gray")
 
-    plt.savefig("gan-%d.png" % iteration)
+    path = os.path.abspath(".")
+    f = "gan-%d.png" % iteration
+    plt.savefig(f)
+    print("Image saved: %s" % os.path.join(path, f))
 
 
 class Trainer:
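
Pieced together from the context lines, the surrounding helper plausibly looks like the sketch below. The grid shape, figure size, and the 28x28 reshape are assumptions not confirmed by the diff:

```python
import os
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

def save_grid_images(images, iteration):
    """Tile generated samples into a grid and save the figure to disk."""
    fig = plt.figure(figsize=(8.0, 8.0))
    grid = ImageGrid(fig, 111, nrows_ncols=(8, 8), axes_pad=0.1)
    for ax, im in zip(grid, images):
        # Iterating over the grid returns the Axes.
        ax.imshow(im.reshape(28, 28), cmap="gray")

    path = os.path.abspath(".")
    f = "gan-%d.png" % iteration
    plt.savefig(f)
    plt.close(fig)
    print("Image saved: %s" % os.path.join(path, f))
```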
nn_components/layers.py (2 changes: 1 addition & 1 deletion)
@@ -116,7 +116,7 @@ def backward(self, d_prev, prev_layer):
         ----------
         d_prev: gradient of J with respect to A[l+1] of the previous layer in the backward direction.
         prev_layer: the previous layer in the forward direction.
         Returns
         -------
         d_prev: gradient of J with respect to A[l] at the current layer.
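
The documented contract (take dJ/dA[l+1] in, return dJ/dA[l] out) is what lets gradients chain from layer to layer. A toy fully-connected layer honoring that contract, as a hypothetical, simplified stand-in for the repo's implementation:

```python
import numpy as np

class ToyFCLayer:
    """Toy affine layer: backward() takes dJ/dA[l+1] and returns dJ/dA[l]."""
    def __init__(self, n_in, n_out):
        self.W = np.random.randn(n_in, n_out) * 0.01

    def forward(self, A_prev):
        self.A_prev = A_prev                 # cache the input for backward
        return self.A_prev @ self.W

    def backward(self, d_prev, prev_layer=None):
        self.dW = self.A_prev.T @ d_prev     # gradient for this layer's weights
        return d_prev @ self.W.T             # dJ/dA[l], handed to the next layer back
```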
