Add CVAE in Flax (#1429)
* Add CVAE in Flax

* Lint files

* Remove init calls

* Add README

* Clean up the code a bit

* Move some files and add entry in sphinx docs
dirmeier authored Jun 21, 2022
1 parent acb2cd8 commit a832572
Showing 10 changed files with 458 additions and 0 deletions.
Binary file added docs/source/_static/img/examples/cvae.png
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -73,6 +73,7 @@ NumPyro documentation
   examples/holt_winters
   examples/mortality
   examples/zero_inflated_poisson
   examples/cvae

.. nbgallery::
   :maxdepth: 1
12 changes: 12 additions & 0 deletions examples/cvae-flax/README.md
@@ -0,0 +1,12 @@
## Conditional Variational Autoencoder in Flax

Trains a *Conditional Variational Autoencoder* (CVAE) on MNIST using Flax's neural network API.

The example first trains a baseline network to predict an entire MNIST image from a single quadrant of it; i.e., the input is one quadrant of an image and the output is the full image, not just the remaining three quadrants.
Then, in a second step, the generation, prior, and recognition nets of the CVAE are trained while the parameters of the baseline net are kept frozen.
We use Optax's `multi_transform` to apply different gradient transformations to the trainable and the frozen parameters, as sketched below.
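
A minimal sketch of this freezing scheme, assuming an Optax optimizer built roughly like the following (the label function and learning rate are illustrative, not taken verbatim from this example's code):

```python
import optax

def label_params(params):
    # mark the baseline's parameters as frozen, everything else as trainable
    return {k: "frozen" if k == "baseline$params" else "trainable" for k in params}

tx = optax.multi_transform(
    {"trainable": optax.adam(1e-3), "frozen": optax.set_to_zero()},
    label_params,
)
```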

Running `main.py` trains the model(s) and finally plots a figure comparing the baseline predictions with the CVAE predictions, like this one:

![CVAE prediction](https://github.com/pyro-ppl/numpyro/tree/master/docs/source/_static/img/examples/cvae.png)

Empty file added examples/cvae-flax/__init__.py
Empty file.
46 changes: 46 additions & 0 deletions examples/cvae-flax/data.py
@@ -0,0 +1,46 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import numpy as np

from jax import lax

from numpyro.examples.datasets import _load


def load_dataset(
    dset, batch_size=None, split="train", shuffle=True, num_datapoints=None, seed=23
):
    data = _load(dset, num_datapoints)
    if isinstance(data, dict):
        arrays = data[split]
    num_records = len(arrays[0])
    idxs = np.arange(num_records)
    if not batch_size:
        batch_size = num_records

    Y = arrays[0]
    X = Y.copy()
    _, m, n = X.shape

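    # keep only the bottom-left quadrant of each image as the conditioning input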
    X = X[:, (m // 2) :, : (n // 2)]
    arrays = (X, Y)

    def init():
        np.random.seed(seed)
        return (
            num_records // batch_size,
            np.random.permutation(idxs) if shuffle else idxs,
        )

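    # fetch mini-batch i using lax ops so batching also works inside a lax.fori_loop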
    def get_batch(i=0, idxs=idxs):
        ret_idx = lax.dynamic_slice_in_dim(idxs, i * batch_size, batch_size)
        d = tuple(
            np.take(a, ret_idx, axis=0)
            if isinstance(a, list)
            else lax.index_take(a, (ret_idx,), axes=(0,))
            for a in arrays
        )
        return d

    return init, get_batch
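
For reference, a minimal usage sketch of this loader, mirroring how `main.py` calls it (shapes assume MNIST's 28×28 images and a batch size of 128):

    from numpyro.examples.datasets import MNIST

    train_init, train_fetch = load_dataset(MNIST, batch_size=128, split="train")
    num_batches, idxs = train_init()
    x, y = train_fetch(0, idxs)  # x: (128, 14, 14) quadrants, y: (128, 28, 28) images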
88 changes: 88 additions & 0 deletions examples/cvae-flax/main.py
@@ -0,0 +1,88 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import argparse

from data import load_dataset
import matplotlib.pyplot as plt
from models import BaselineNet, Decoder, Encoder, cvae_guide, cvae_model
from train_baseline import train_baseline
from train_cvae import train_cvae

from numpyro.examples.datasets import MNIST


def main(args):
    train_init, train_fetch = load_dataset(
        MNIST, batch_size=args.batch_size, split="train", seed=args.rng_seed
    )
    test_init, test_fetch = load_dataset(
        MNIST, batch_size=args.batch_size, split="test", seed=args.rng_seed
    )

    num_train, train_idx = train_init()
    num_test, test_idx = test_init()

    baseline = BaselineNet()
    baseline_params = train_baseline(
        baseline,
        num_train,
        train_idx,
        train_fetch,
        num_test,
        test_idx,
        test_fetch,
        n_epochs=args.num_epochs,
    )

    cvae_params = train_cvae(
        cvae_model,
        cvae_guide,
        baseline_params,
        num_train,
        train_idx,
        train_fetch,
        num_test,
        test_idx,
        test_fetch,
        n_epochs=args.num_epochs,
    )

    x_test, y_test = test_fetch(0, test_idx)

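    # rebuild the networks and run one forward pass with the trained parameters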
    baseline = BaselineNet()
    recognition_net = Encoder()
    generation_net = Decoder()

    y_hat_base = baseline.apply({"params": cvae_params["baseline$params"]}, x_test)
    z_loc, z_scale = recognition_net.apply(
        {"params": cvae_params["recognition_net$params"]}, x_test, y_hat_base
    )
    y_hat_vae = generation_net.apply(
        {"params": cvae_params["generation_net$params"]}, z_loc
    )

    fig, axs = plt.subplots(4, 10, figsize=(15, 5))
    for i in range(10):
        axs[0][i].imshow(x_test[i])
        axs[1][i].imshow(y_test[i])
        axs[2][i].imshow(y_hat_base[i])
        axs[3][i].imshow(y_hat_vae[i])
        for j, label in enumerate(["Input", "Truth", "Baseline", "CVAE"]):
            axs[j][i].set_xticks([])
            axs[j][i].set_yticks([])
            if i == 0:
                axs[j][i].set_ylabel(label, rotation="horizontal", labelpad=40)
    plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Conditional Variational Autoencoder on MNIST using Flax"
    )
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--rng-seed", type=int, default=23)
    parser.add_argument("--num-epochs", type=int, default=10)

    args = parser.parse_args()
    main(args)
92 changes: 92 additions & 0 deletions examples/cvae-flax/models.py
@@ -0,0 +1,92 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from flax import linen as nn
from jax import numpy as jnp

import numpyro
from numpyro.contrib.module import flax_module
import numpyro.distributions as dist


def cross_entropy_loss(y_pred, y):
    log_p = jnp.log(y_pred)
    log_not_p = jnp.log1p(-y_pred)
    return -y * log_p - (1.0 - y) * log_not_p


class BaselineNet(nn.Module):
    hidden_size: int = 512

    @nn.compact
    def __call__(self, x):
        batch_size, _, _ = x.shape
        y_hat = nn.relu(nn.Dense(self.hidden_size)(x.reshape(-1, 196)))
        y_hat = nn.relu(nn.Dense(self.hidden_size)(y_hat))
        y_hat = nn.Dense(784)(y_hat)
        y_hat = nn.sigmoid(y_hat).reshape((-1, 28, 28))
        return y_hat


class Encoder(nn.Module):
    hidden_size: int = 512
    latent_dim: int = 256

    @nn.compact
    def __call__(self, x, y):
        z = jnp.concatenate([x.reshape(-1, 196), y.reshape(-1, 784)], axis=-1)
        hidden = nn.relu(nn.Dense(self.hidden_size)(z))
        hidden = nn.relu(nn.Dense(self.hidden_size)(hidden))
        z_loc = nn.Dense(self.latent_dim)(hidden)
        z_scale = jnp.exp(nn.Dense(self.latent_dim)(hidden))
        return z_loc, z_scale


class Decoder(nn.Module):
    hidden_size: int = 512
    latent_dim: int = 256

    @nn.compact
    def __call__(self, z):
        y_hat = nn.relu(nn.Dense(self.hidden_size)(z))
        y_hat = nn.relu(nn.Dense(self.hidden_size)(y_hat))
        y_hat = nn.Dense(784)(y_hat)
        y_hat = nn.sigmoid(y_hat).reshape((-1, 28, 28))
        return y_hat


def cvae_model(x, y=None):
    baseline_net = flax_module(
        "baseline", BaselineNet(), x=jnp.ones((1, 14, 14), dtype=jnp.float32)
    )
    prior_net = flax_module(
        "prior_net",
        Encoder(),
        x=jnp.ones((1, 14, 14), dtype=jnp.float32),
        y=jnp.ones((1, 28, 28), dtype=jnp.float32),
    )
    generation_net = flax_module(
        "generation_net", Decoder(), z=jnp.ones((1, 256), dtype=jnp.float32)
    )

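    # condition the prior on the baseline's guess of the full image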
    y_hat = baseline_net(x)
    z_loc, z_scale = prior_net(x, y_hat)
    z = numpyro.sample("z", dist.Normal(z_loc, z_scale))
    loc = generation_net(z)

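    # at prediction time (y is None) expose the reconstruction; during training, score y under a Bernoulli likelihood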
    if y is None:
        numpyro.deterministic("y", loc)
    else:
        numpyro.sample("y", dist.Bernoulli(loc), obs=y)
    return loc


def cvae_guide(x, y=None):
    recognition_net = flax_module(
        "recognition_net",
        Encoder(),
        x=jnp.ones((1, 14, 14), dtype=jnp.float32),
        y=jnp.ones((1, 28, 28), dtype=jnp.float32),
    )
    loc, scale = recognition_net(x, y)
    numpyro.sample("z", dist.Normal(loc, scale))
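
For orientation, a minimal sketch of fitting this model/guide pair with NumPyro's SVI; the example's actual training loop lives in train_cvae.py (not shown in full here), and `x_batch`/`y_batch` are placeholder arrays:

    from jax import random
    import optax
    from numpyro.infer import SVI, Trace_ELBO

    svi = SVI(cvae_model, cvae_guide, optax.adam(1e-3), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0), x_batch, y_batch)
    svi_state, loss = svi.update(svi_state, x_batch, y_batch)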
83 changes: 83 additions & 0 deletions examples/cvae-flax/train_baseline.py
@@ -0,0 +1,83 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from models import cross_entropy_loss

from flax.training.train_state import TrainState
import jax
from jax import lax, numpy as jnp, random
import optax


def create_train_state(model, x, learning_rate_fn):
    params = model.init(random.PRNGKey(0), x)
    tx = optax.adam(learning_rate_fn)
    state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
    return state


def train_step(state, x_batched, y_batched):
    def loss_fn(params):
        y_pred = state.apply_fn(params, x_batched)
        loss = cross_entropy_loss(y_pred, y_batched)
        return jnp.sum(loss)

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)

    new_state = state.apply_gradients(grads=grads)
    return new_state, loss


def train_epoch(state, train_fetch, num_train, train_idx, epoch_rng):
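    # run the whole epoch as a lax.fori_loop, carrying (state, summed loss)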
    def _fn(i, val):
        state, loss_sum = val
        x_batched, y_batched = train_fetch(i, train_idx)
        state, loss = train_step(state, x_batched, y_batched)
        loss_sum += loss
        return state, loss_sum

    return lax.fori_loop(0, num_train, _fn, (state, 0.0))


def eval_epoch(state, test_fetch, num_test, test_idx, epoch_rng):
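    # average the summed batch losses over the number of test batches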
    def _fn(i, loss_sum):
        x_batched, y_batched = test_fetch(i, test_idx)
        y_pred = state.apply_fn(state.params, x_batched)
        loss = cross_entropy_loss(y_pred, y_batched)
        loss_sum += jnp.sum(loss)
        return loss_sum

    loss = lax.fori_loop(0, num_test, _fn, 0.0)
    loss = loss / num_test
    return loss


def train_baseline(
    model,
    num_train,
    train_idx,
    train_fetch,
    num_test,
    test_idx,
    test_fetch,
    n_epochs=100,
):
    state = create_train_state(model, train_fetch(0, train_idx)[0], 0.003)

    rng = random.PRNGKey(0)
    best_val_loss = jnp.inf
    best_state = state
    for i in range(n_epochs):
        epoch_rng = jax.random.fold_in(rng, i)
        state, train_loss = train_epoch(
            state, train_fetch, num_train, train_idx, epoch_rng
        )
        val_loss = eval_epoch(state, test_fetch, num_test, test_idx, epoch_rng)
        print(f"Epoch {i} - train loss: {train_loss}, validation loss: {val_loss}")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = state

    return best_state.params
