* Add CVAE in Flax
* Lint files
* Remove init calls
* Add README
* Clean up the code a bit
* Move some files and add entry in sphinx docs
Showing 10 changed files with 458 additions and 0 deletions.
@@ -0,0 +1,12 @@
## Conditional Variational Autoencoder in Flax

Trains a *Conditional Variational Autoencoder* (CVAE) on the MNIST data using Flax's neural network API.

The example first trains a baseline network to predict an entire MNIST image from a single quadrant of it (i.e., the input is one quadrant of an image and the output is the full image, not just the remaining three quadrants).
Then, in a second model, the generation/prior/recognition nets of the CVAE are trained while the parameters of the baseline network are kept fixed (frozen).
We use Optax's `multi_transform` to apply different gradient transformations to the trainable parameters and the frozen parameters, as sketched below.
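As a minimal sketch of that freezing mechanism (the parameter group names, learning rate, and `label_fn` below are illustrative assumptions, not the exact setup of this example's `train_cvae.py`):

```python
import jax.numpy as jnp
from jax import tree_util
import optax

# Hypothetical parameter tree: one frozen group, one trainable group.
params = {
    "baseline$params": {"kernel": jnp.ones((3, 3))},
    "prior_net$params": {"kernel": jnp.ones((3, 3))},
}

def label_fn(tree):
    # Label each top-level group; optax accepts a "prefix" tree of labels.
    return {k: "frozen" if k == "baseline$params" else "trainable" for k in tree}

tx = optax.multi_transform(
    {"trainable": optax.adam(1e-3), "frozen": optax.set_to_zero()},
    label_fn,
)
opt_state = tx.init(params)
grads = tree_util.tree_map(jnp.ones_like, params)  # stand-in gradients
updates, opt_state = tx.update(grads, opt_state, params)
# Updates for the "frozen" group are all zeros, so those parameters never move.
```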
Running `main.py` trains the model(s) and at the end plots a figure like the one below, comparing the baseline prediction with the CVAE prediction:

![CVAE prediction](https://github.com/pyro-ppl/numpyro/tree/master/docs/source/_static/img/examples/cvae.png)
Empty file.
@@ -0,0 +1,46 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import numpy as np

from jax import lax

from numpyro.examples.datasets import _load


def load_dataset(
    dset, batch_size=None, split="train", shuffle=True, num_datapoints=None, seed=23
):
    data = _load(dset, num_datapoints)
    if isinstance(data, dict):
        arrays = data[split]
    num_records = len(arrays[0])
    idxs = np.arange(num_records)
    if not batch_size:
        batch_size = num_records

    # The target Y is the full image; the input X is one quadrant of it.
    Y = arrays[0]
    X = Y.copy()
    _, m, n = X.shape

    X = X[:, (m // 2) :, : (n // 2)]
    arrays = (X, Y)

    def init():
        np.random.seed(seed)
        return (
            num_records // batch_size,
            np.random.permutation(idxs) if shuffle else idxs,
        )

    def get_batch(i=0, idxs=idxs):
        ret_idx = lax.dynamic_slice_in_dim(idxs, i * batch_size, batch_size)
        d = tuple(
            np.take(a, ret_idx, axis=0)
            if isinstance(a, list)
            else lax.index_take(a, (ret_idx,), axes=(0,))
            for a in arrays
        )
        return d

    return init, get_batch
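For intuition, the quadrant slice above behaves like this on a dummy batch (shapes chosen to match MNIST's 28x28 images):

```python
import numpy as np

Y = np.zeros((2, 28, 28), dtype=np.float32)  # dummy batch of "images"
X = Y[:, (28 // 2) :, : (28 // 2)]  # keep the bottom rows, left columns
assert X.shape == (2, 14, 14)  # one 14x14 quadrant per image
```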
@@ -0,0 +1,88 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import argparse

from data import load_dataset
import matplotlib.pyplot as plt
from models import BaselineNet, Decoder, Encoder, cvae_guide, cvae_model
from train_baseline import train_baseline
from train_cvae import train_cvae

from numpyro.examples.datasets import MNIST


def main(args):
    train_init, train_fetch = load_dataset(
        MNIST, batch_size=args.batch_size, split="train", seed=args.rng_seed
    )
    test_init, test_fetch = load_dataset(
        MNIST, batch_size=args.batch_size, split="test", seed=args.rng_seed
    )

    num_train, train_idx = train_init()
    num_test, test_idx = test_init()

    # Train the baseline net first, then the CVAE with the baseline frozen.
    baseline = BaselineNet()
    baseline_params = train_baseline(
        baseline,
        num_train,
        train_idx,
        train_fetch,
        num_test,
        test_idx,
        test_fetch,
        n_epochs=args.num_epochs,
    )

    cvae_params = train_cvae(
        cvae_model,
        cvae_guide,
        baseline_params,
        num_train,
        train_idx,
        train_fetch,
        num_test,
        test_idx,
        test_fetch,
        n_epochs=args.num_epochs,
    )

    x_test, y_test = test_fetch(0, test_idx)

    baseline = BaselineNet()
    recognition_net = Encoder()
    generation_net = Decoder()

    # Predict with the trained baseline and with the CVAE (via the recognition net).
    y_hat_base = baseline.apply({"params": cvae_params["baseline$params"]}, x_test)
    z_loc, z_scale = recognition_net.apply(
        {"params": cvae_params["recognition_net$params"]}, x_test, y_hat_base
    )
    y_hat_vae = generation_net.apply(
        {"params": cvae_params["generation_net$params"]}, z_loc
    )

    # Plot input quadrant, ground truth, baseline prediction, and CVAE prediction.
    fig, axs = plt.subplots(4, 10, figsize=(15, 5))
    for i in range(10):
        axs[0][i].imshow(x_test[i])
        axs[1][i].imshow(y_test[i])
        axs[2][i].imshow(y_hat_base[i])
        axs[3][i].imshow(y_hat_vae[i])
        for j, label in enumerate(["Input", "Truth", "Baseline", "CVAE"]):
            axs[j][i].set_xticks([])
            axs[j][i].set_yticks([])
            if i == 0:
                axs[j][i].set_ylabel(label, rotation="horizontal", labelpad=40)
    plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Conditional Variational Autoencoder on MNIST using Flax"
    )
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--rng-seed", type=int, default=23)
    parser.add_argument("--num-epochs", type=int, default=10)

    args = parser.parse_args()
    main(args)
@@ -0,0 +1,92 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from flax import linen as nn
from jax import numpy as jnp

import numpyro
from numpyro.contrib.module import flax_module
import numpyro.distributions as dist


def cross_entropy_loss(y_pred, y):
    # Binary cross entropy between predicted pixel probabilities and targets.
    log_p = jnp.log(y_pred)
    log_not_p = jnp.log1p(-y_pred)
    return -y * log_p - (1.0 - y) * log_not_p


class BaselineNet(nn.Module):
    """MLP that predicts a full 28x28 image from a 14x14 quadrant."""

    hidden_size: int = 512

    @nn.compact
    def __call__(self, x):
        y_hat = nn.relu(nn.Dense(self.hidden_size)(x.reshape(-1, 196)))
        y_hat = nn.relu(nn.Dense(self.hidden_size)(y_hat))
        y_hat = nn.Dense(784)(y_hat)
        y_hat = nn.sigmoid(y_hat).reshape((-1, 28, 28))
        return y_hat


class Encoder(nn.Module):
    """Maps (x, y) to the location and scale of a diagonal Gaussian over z."""

    hidden_size: int = 512
    latent_dim: int = 256

    @nn.compact
    def __call__(self, x, y):
        z = jnp.concatenate([x.reshape(-1, 196), y.reshape(-1, 784)], axis=-1)
        hidden = nn.relu(nn.Dense(self.hidden_size)(z))
        hidden = nn.relu(nn.Dense(self.hidden_size)(hidden))
        z_loc = nn.Dense(self.latent_dim)(hidden)
        z_scale = jnp.exp(nn.Dense(self.latent_dim)(hidden))
        return z_loc, z_scale


class Decoder(nn.Module):
    """Maps a latent z to predicted pixel probabilities of the full image."""

    hidden_size: int = 512
    latent_dim: int = 256

    @nn.compact
    def __call__(self, z):
        y_hat = nn.relu(nn.Dense(self.hidden_size)(z))
        y_hat = nn.relu(nn.Dense(self.hidden_size)(y_hat))
        y_hat = nn.Dense(784)(y_hat)
        y_hat = nn.sigmoid(y_hat).reshape((-1, 28, 28))
        return y_hat


def cvae_model(x, y=None):
    baseline_net = flax_module(
        "baseline", BaselineNet(), x=jnp.ones((1, 14, 14), dtype=jnp.float32)
    )
    prior_net = flax_module(
        "prior_net",
        Encoder(),
        x=jnp.ones((1, 14, 14), dtype=jnp.float32),
        y=jnp.ones((1, 28, 28), dtype=jnp.float32),
    )
    generation_net = flax_module(
        "generation_net", Decoder(), z=jnp.ones((1, 256), dtype=jnp.float32)
    )

    # The prior over z is conditioned on the baseline's prediction for x.
    y_hat = baseline_net(x)
    z_loc, z_scale = prior_net(x, y_hat)
    z = numpyro.sample("z", dist.Normal(z_loc, z_scale))
    loc = generation_net(z)

    if y is None:
        numpyro.deterministic("y", loc)
    else:
        numpyro.sample("y", dist.Bernoulli(loc), obs=y)
    return loc


def cvae_guide(x, y=None):
    # The recognition net approximates the posterior over z given (x, y).
    recognition_net = flax_module(
        "recognition_net",
        Encoder(),
        x=jnp.ones((1, 14, 14), dtype=jnp.float32),
        y=jnp.ones((1, 28, 28), dtype=jnp.float32),
    )
    loc, scale = recognition_net(x, y)
    numpyro.sample("z", dist.Normal(loc, scale))
@@ -0,0 +1,83 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from models import cross_entropy_loss

from flax.training.train_state import TrainState
import jax
from jax import lax, numpy as jnp, random
import optax


def create_train_state(model, x, learning_rate_fn):
    params = model.init(random.PRNGKey(0), x)
    tx = optax.adam(learning_rate_fn)
    state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
    return state


def train_step(state, x_batched, y_batched):
    def loss_fn(params):
        y_pred = state.apply_fn(params, x_batched)
        loss = cross_entropy_loss(y_pred, y_batched)
        return jnp.sum(loss)

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)

    new_state = state.apply_gradients(grads=grads)
    return new_state, loss


def train_epoch(state, train_fetch, num_train, train_idx, epoch_rng):
    # Runs all batches of one epoch inside a single traced loop.
    def _fn(i, val):
        state, loss_sum = val
        x_batched, y_batched = train_fetch(i, train_idx)
        state, loss = train_step(state, x_batched, y_batched)
        loss_sum += loss
        return state, loss_sum

    return lax.fori_loop(0, num_train, _fn, (state, 0.0))


def eval_epoch(state, test_fetch, num_test, test_idx, epoch_rng):
    def _fn(i, loss_sum):
        x_batched, y_batched = test_fetch(i, test_idx)
        y_pred = state.apply_fn(state.params, x_batched)
        loss = cross_entropy_loss(y_pred, y_batched)
        loss_sum += jnp.sum(loss)
        return loss_sum

    loss = lax.fori_loop(0, num_test, _fn, 0.0)
    loss = loss / num_test
    return loss


def train_baseline(
    model,
    num_train,
    train_idx,
    train_fetch,
    num_test,
    test_idx,
    test_fetch,
    n_epochs=100,
):
    state = create_train_state(model, train_fetch(0, train_idx)[0], 0.003)

    # Keep the parameters with the best validation loss seen so far.
    rng = random.PRNGKey(0)
    best_val_loss = jnp.inf
    best_state = state
    for i in range(n_epochs):
        epoch_rng = jax.random.fold_in(rng, i)
        state, train_loss = train_epoch(
            state, train_fetch, num_train, train_idx, epoch_rng
        )
        val_loss = eval_epoch(state, test_fetch, num_test, test_idx, epoch_rng)
        print(f"Epoch loss - train loss: {train_loss}, validation loss: {val_loss}")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = state

    return best_state.params
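`train_epoch` and `eval_epoch` above run their batch loops under `lax.fori_loop`, so each loop body must be traceable and return a carry with the same structure it received. A toy sketch of that pattern (not part of this commit):

```python
from jax import lax
import jax.numpy as jnp

def body(i, carry):
    total, count = carry
    total = total + jnp.float32(i)  # accumulate something per iteration
    return total, count + 1  # carry structure must match the input carry

total, count = lax.fori_loop(0, 5, body, (jnp.float32(0.0), 0))
# total == 10.0, count == 5
```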