Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CVAE in Flax #1429

Merged
merged 6 commits into from
Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/source/_static/img/examples/cvae.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ NumPyro documentation
examples/holt_winters
examples/mortality
examples/zero_inflated_poisson
examples/cvae

.. nbgallery::
:maxdepth: 1
Expand Down
12 changes: 12 additions & 0 deletions examples/cvae-flax/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
## Conditional Variational Autoencoder in Flax

Trains a *Conditional Variational Autoencoder* (CVAE) on the MNIST data using Flax' neural network API.

The model first trains a baseline to predict an entire MNIST image from a single quadrant of it (i.e., input is one quadrant of an image, output is the entire image (not the other three quadrants)).
Then, in a second model, the generation/prior/recognition nets of the CVAE are trained while keeping the model parameters of the baseline fixed/frozen.
We use Optax' `multi_transform` to apply different gradient transformations to the trainable parameters and the frozen parameters.

Running `main.py` trains the model(s) and plots a figure in the end comparing the baseline prediction with the CVAE prediction like this one:

![CVAE prediction](https://github.com/pyro-ppl/numpyro/tree/master/docs/source/_static/img/examples/cvae.png)

Empty file added examples/cvae-flax/__init__.py
Empty file.
46 changes: 46 additions & 0 deletions examples/cvae-flax/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import numpy as np

from jax import lax

from numpyro.examples.datasets import _load


def load_dataset(
dset, batch_size=None, split="train", shuffle=True, num_datapoints=None, seed=23
):
data = _load(dset, num_datapoints)
if isinstance(data, dict):
arrays = data[split]
num_records = len(arrays[0])
idxs = np.arange(num_records)
if not batch_size:
batch_size = num_records

Y = arrays[0]
X = Y.copy()
_, m, n = X.shape

X = X[:, (m // 2) :, : (n // 2)]
arrays = (X, Y)

def init():
np.random.seed(seed)
return (
num_records // batch_size,
np.random.permutation(idxs) if shuffle else idxs,
)

def get_batch(i=0, idxs=idxs):
ret_idx = lax.dynamic_slice_in_dim(idxs, i * batch_size, batch_size)
d = tuple(
np.take(a, ret_idx, axis=0)
if isinstance(a, list)
else lax.index_take(a, (ret_idx,), axes=(0,))
for a in arrays
)
return d

return init, get_batch
88 changes: 88 additions & 0 deletions examples/cvae-flax/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import argparse

from data import load_dataset
import matplotlib.pyplot as plt
from models import BaselineNet, Decoder, Encoder, cvae_guide, cvae_model
from train_baseline import train_baseline
from train_cvae import train_cvae

from numpyro.examples.datasets import MNIST


def main(args):
train_init, train_fetch = load_dataset(
MNIST, batch_size=args.batch_size, split="train", seed=args.rng_seed
)
test_init, test_fetch = load_dataset(
MNIST, batch_size=args.batch_size, split="test", seed=args.rng_seed
)

num_train, train_idx = train_init()
num_test, test_idx = test_init()

baseline = BaselineNet()
baseline_params = train_baseline(
baseline,
num_train,
train_idx,
train_fetch,
num_test,
test_idx,
test_fetch,
n_epochs=args.num_epochs,
)

cvae_params = train_cvae(
cvae_model,
cvae_guide,
baseline_params,
num_train,
train_idx,
train_fetch,
num_test,
test_idx,
test_fetch,
n_epochs=args.num_epochs,
)

x_test, y_test = test_fetch(0, test_idx)

baseline = BaselineNet()
recognition_net = Encoder()
generation_net = Decoder()

y_hat_base = baseline.apply({"params": cvae_params["baseline$params"]}, x_test)
z_loc, z_scale = recognition_net.apply(
{"params": cvae_params["recognition_net$params"]}, x_test, y_hat_base
)
y_hat_vae = generation_net.apply(
{"params": cvae_params["generation_net$params"]}, z_loc
)

fig, axs = plt.subplots(4, 10, figsize=(15, 5))
for i in range(10):
axs[0][i].imshow(x_test[i])
axs[1][i].imshow(y_test[i])
axs[2][i].imshow(y_hat_base[i])
axs[3][i].imshow(y_hat_vae[i])
for j, label in enumerate(["Input", "Truth", "Baseline", "CVAE"]):
axs[j][i].set_xticks([])
axs[j][i].set_yticks([])
if i == 0:
axs[j][i].set_ylabel(label, rotation="horizontal", labelpad=40)
plt.show()


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Conditional Variational Autoencoder on MNIST using Flax"
)
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--rng_seed", type=int, default=23)
parser.add_argument("--num-epochs", type=int, default=10)

args = parser.parse_args()
main(args)
92 changes: 92 additions & 0 deletions examples/cvae-flax/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from flax import linen as nn
from jax import numpy as jnp

import numpyro
from numpyro.contrib.module import flax_module
import numpyro.distributions as dist


def cross_entropy_loss(y_pred, y):
log_p = jnp.log(y_pred)
log_not_p = jnp.log1p(-y_pred)
return -y * log_p - (1.0 - y) * log_not_p


class BaselineNet(nn.Module):
hidden_size: int = 512

@nn.compact
def __call__(self, x):
batch_size, _, _ = x.shape
y_hat = nn.relu(nn.Dense(self.hidden_size)(x.reshape(-1, 196)))
y_hat = nn.relu(nn.Dense(self.hidden_size)(y_hat))
y_hat = nn.Dense(784)(y_hat)
y_hat = nn.sigmoid(y_hat).reshape((-1, 28, 28))
return y_hat


class Encoder(nn.Module):
hidden_size: int = 512
latent_dim: int = 256

@nn.compact
def __call__(self, x, y):
z = jnp.concatenate([x.reshape(-1, 196), y.reshape(-1, 784)], axis=-1)
hidden = nn.relu(nn.Dense(self.hidden_size)(z))
hidden = nn.relu(nn.Dense(self.hidden_size)(hidden))
z_loc = nn.Dense(self.latent_dim)(hidden)
z_scale = jnp.exp(nn.Dense(self.latent_dim)(hidden))
return z_loc, z_scale


class Decoder(nn.Module):
hidden_size: int = 512
latent_dim: int = 256

@nn.compact
def __call__(self, z):
y_hat = nn.relu(nn.Dense(self.hidden_size)(z))
y_hat = nn.relu(nn.Dense(self.hidden_size)(y_hat))
y_hat = nn.Dense(784)(y_hat)
y_hat = nn.sigmoid(y_hat).reshape((-1, 28, 28))
return y_hat


def cvae_model(x, y=None):
baseline_net = flax_module(
"baseline", BaselineNet(), x=jnp.ones((1, 14, 14), dtype=jnp.float32)
)
prior_net = flax_module(
"prior_net",
Encoder(),
x=jnp.ones((1, 14, 14), dtype=jnp.float32),
y=jnp.ones((1, 28, 28), dtype=jnp.float32),
)
generation_net = flax_module(
"generation_net", Decoder(), z=jnp.ones((1, 256), dtype=jnp.float32)
)

y_hat = baseline_net(x)
z_loc, z_scale = prior_net(x, y_hat)
z = numpyro.sample("z", dist.Normal(z_loc, z_scale))
loc = generation_net(z)

if y is None:
numpyro.deterministic("y", loc)
else:
numpyro.sample("y", dist.Bernoulli(loc), obs=y)
return loc


def cvae_guide(x, y=None):
recognition_net = flax_module(
"recognition_net",
Encoder(),
x=jnp.ones((1, 14, 14), dtype=jnp.float32),
y=jnp.ones((1, 28, 28), dtype=jnp.float32),
)
loc, scale = recognition_net(x, y)
numpyro.sample("z", dist.Normal(loc, scale))
83 changes: 83 additions & 0 deletions examples/cvae-flax/train_baseline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from models import cross_entropy_loss

from flax.training.train_state import TrainState
import jax
from jax import lax, numpy as jnp, random
import optax


def create_train_state(model, x, learning_rate_fn):
params = model.init(random.PRNGKey(0), x)
tx = optax.adam(learning_rate_fn)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
return state


def train_step(state, x_batched, y_batched):
def loss_fn(params):
y_pred = state.apply_fn(params, x_batched)
loss = cross_entropy_loss(y_pred, y_batched)
return jnp.sum(loss)

grad_fn = jax.value_and_grad(loss_fn)
loss, grads = grad_fn(state.params)

new_state = state.apply_gradients(grads=grads)
return new_state, loss


def train_epoch(state, train_fetch, num_train, train_idx, epoch_rng):
def _fn(i, val):
state, loss_sum = val
x_batched, y_batched = train_fetch(i, train_idx)
state, loss = train_step(state, x_batched, y_batched)
loss_sum += loss
return state, loss_sum

return lax.fori_loop(0, num_train, _fn, (state, 0.0))


def eval_epoch(state, test_fetch, num_test, test_idx, epoch_rng):
def _fn(i, loss_sum):
x_batched, y_batched = test_fetch(i, test_idx)
y_pred = state.apply_fn(state.params, x_batched)
loss = cross_entropy_loss(y_pred, y_batched)
loss_sum += jnp.sum(loss)
return loss_sum

loss = lax.fori_loop(0, num_test, _fn, 0.0)
loss = loss / num_test
return loss


def train_baseline(
model,
num_train,
train_idx,
train_fetch,
num_test,
test_idx,
test_fetch,
n_epochs=100,
):

state = create_train_state(model, train_fetch(0, train_idx)[0], 0.003)

rng = random.PRNGKey(0)
best_val_loss = jnp.inf
best_state = state
for i in range(n_epochs):
epoch_rng = jax.random.fold_in(rng, i)
state, train_loss = train_epoch(
state, train_fetch, num_train, train_idx, epoch_rng
)
val_loss = eval_epoch(state, test_fetch, num_test, test_idx, epoch_rng)
print(f"Epoch loss - train loss: {train_loss}, validation loss: {val_loss}")
if val_loss < best_val_loss:
best_val_loss = val_loss
best_state = state

return best_state.params
Loading