From fcbd804749c4fad0df7fbd4f17c3f429f2bfbcdf Mon Sep 17 00:00:00 2001
From: giangtranml
Date: Sun, 5 Apr 2020 19:09:03 +0700
Subject: [PATCH] Update GAN.

---
 gan/gan.py              | 22 +++++++++++++++++++---
 libs/utils.py           |  6 ++++--
 nn_components/layers.py |  2 +-
 3 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/gan/gan.py b/gan/gan.py
index a1be94c..9fb8d99 100644
--- a/gan/gan.py
+++ b/gan/gan.py
@@ -17,6 +17,17 @@ def _backward_last(self):
         pass
 
     def backward(self, y, y_hat, z, discriminator):
+        """
+        The generator does not compute a loss directly, so the discriminator
+        must backpropagate the loss gradient through its own layers to the generator.
+
+        Parameters
+        ----------
+        y: vector of ones (labeling fakes as real); we optimize the generator parameters to fool the discriminator.
+        y_hat: discriminator output for samples produced by the generator.
+        z: random noise variables.
+        discriminator: discriminator network.
+        """
         dA = discriminator.return_input_grads(y, y_hat)
         grads = self._backward(dA, None)
         self._update_params(grads)
@@ -27,6 +38,9 @@ def __init__(self, optimizer:object, layers:list, loss_func:object=BinaryCrossEn
         super().__init__(optimizer, layers, loss_func)
 
     def return_input_grads(self, y, y_hat):
+        """
+        Compute the gradient of the loss w.r.t. the inputs, so it can flow onward to the gradients of the generator parameters.
+        """
         dA_prev, _ = self._backward_last(y, y_hat)
         for i in range(len(self.layers)-1, 0, -1):
             backward_func = self.layers[i].backward_layer if isinstance(self.layers[i], LearnableLayer) else self.layers[i].backward
@@ -34,7 +48,7 @@ def return_input_grads(self, y, y_hat):
         return dA_prev
 
 
-def main():
+def main(digit=2):
 
     mnist_dim = 784
 
@@ -86,7 +100,7 @@ def main():
     images, labels = mndata.load_training()
     images, labels = preprocess_data(images, labels, test=True)
 
-    images = images[labels == 2]
+    images = images[labels == digit]
 
     optimizer_G = Adam(alpha=0.006)
     optimizer_D = Adam(alpha=0.006)
@@ -98,8 +112,10 @@ def main():
 
     batch_size = 64
     iterations = 10000
+
+    print("Training GAN on the MNIST dataset to generate digit %d" % digit)
     trainerGAN = TrainerGAN(generator, discriminator, batch_size, iterations)
     trainerGAN.train(images)
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main(digit=8)
\ No newline at end of file
diff --git a/libs/utils.py b/libs/utils.py
index de60c65..37212c1 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -14,7 +14,6 @@
 from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support
 from tqdm import tqdm
 from .cifar10_lib import get_file, load_batch
-import matplotlib.pyplot as plt
 from mpl_toolkits.axes_grid1 import ImageGrid
 
 
@@ -83,7 +82,10 @@ def save_grid_images(images, iteration):
         # Iterating over the grid returns the Axes.
         ax.imshow(im, cmap="gray")
 
-    plt.savefig("gan-%d.png" % iteration)
+    path = os.path.abspath(".")
+    f = "gan-%d.png" % iteration
+    plt.savefig(f)
+    print("Image saved: %s" % os.path.join(path, f))
 
 
 class Trainer:
diff --git a/nn_components/layers.py b/nn_components/layers.py
index 6fe8789..6823baa 100644
--- a/nn_components/layers.py
+++ b/nn_components/layers.py
@@ -116,7 +116,7 @@ def backward(self, d_prev, prev_layer):
         ----------
         d_prev: gradient of J respect to A[l+1] of the previous layer according backward direction.
         prev_layer: previous layer according forward direction.
-
+
         Returns
         -------
         d_prev: gradient of J respect to A[l] at the current layer.
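
Reviewer note (not part of the patch): the new docstrings describe training the
generator *through* the discriminator; the generator has no loss of its own, so
its parameter gradients start from the gradient of the discriminator's loss with
respect to the discriminator's input, which is what return_input_grads hands
back. Below is a minimal numpy sketch of that flow; the shapes and variable
names are illustrative assumptions, not this repo's actual NeuralNetwork/Adam
API.

    import numpy as np

    rng = np.random.default_rng(0)
    noise_dim, data_dim, batch = 8, 4, 16

    W_g = rng.normal(scale=0.1, size=(noise_dim, data_dim))  # generator weights
    W_d = rng.normal(scale=0.1, size=(data_dim, 1))          # discriminator weights

    z = rng.normal(size=(batch, noise_dim))        # random noise variables
    x_fake = z @ W_g                               # generator forward pass
    y_hat = 1.0 / (1.0 + np.exp(-(x_fake @ W_d)))  # discriminator forward (sigmoid)

    y = np.ones((batch, 1))                        # label fakes as real to fool D

    # For sigmoid output with binary cross-entropy, dL/d(logits) = (y_hat - y) / batch.
    d_logits = (y_hat - y) / batch
    dA = d_logits @ W_d.T   # dL w.r.t. D's input, i.e. dL w.r.t. G's output;
                            # the analogue of discriminator.return_input_grads(y, y_hat)

    dW_g = z.T @ dA         # chain rule through the (one-layer) generator
    W_g -= 0.006 * dW_g     # one plain SGD step (the patch uses Adam)

Note that only W_g is updated on this step; the discriminator's weights are
used solely as a conduit for the gradient, which is exactly why
Generator.backward takes the discriminator as an argument.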