First implementation of a Variational Autoencoder CV #27

Merged: 1 commit, Mar 21, 2023
3 changes: 2 additions & 1 deletion mlcvs/core/loss/__init__.py
@@ -2,4 +2,5 @@

from .mse import MSE_loss
from .tda_loss import TDA_loss
from .eigvals import reduce_eigenvalues
from .eigvals import reduce_eigenvalues
from .elbo import elbo_gaussians_loss
73 changes: 73 additions & 0 deletions mlcvs/core/loss/elbo.py
@@ -0,0 +1,73 @@
#!/usr/bin/env python

# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Evidence Lower BOund (ELBO) loss functions used to train variational autoencoders.
"""

__all__ = ['elbo_gaussians_loss']


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from typing import Optional
import torch
from mlcvs.core.loss.mse import MSE_loss


# =============================================================================
# LOSS FUNCTIONS
# =============================================================================

def elbo_gaussians_loss(
diff: torch.Tensor,
mean: torch.Tensor,
log_variance: torch.Tensor,
weights: Optional[torch.Tensor] = None
):
"""ELBO loss function assuming the latent and reconstruction distributions are Gaussian.

The ELBO is the sum of the MSE reconstruction loss (i.e., it assumes that the
decoder outputs the mean of a Gaussian distribution with variance 1) and the
KL divergence between the normal distributions ``N(mean, var)`` and
``N(0, 1)``, where ``mean`` and ``var`` are the outputs of the encoder.

Parameters
----------
diff : torch.Tensor
Shape ``(n_batches, in_features)``. The difference between the input of
the encoder and the output of the decoder.
mean : torch.Tensor
Shape ``(n_batches, latent_features)``. The means of the Gaussian
distributions associated to the inputs.
log_variance : torch.Tensor
Shape ``(n_batches, latent_features)``. The logarithm of the variances
of the Gaussian distributions associated to the inputs.
weights : torch.Tensor, optional
Shape ``(n_batches,)``. If given, the average over batches is weighted.
The default (``None``) is unweighted.

Returns
-------
loss : torch.Tensor
The value of the loss function.
"""
# KL divergence between N(mean, variance) and N(0, 1).
# See https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
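# Per sample: KL = -0.5 * sum(1 + log(var) - mean^2 - var), where the sum runs over the latent dimensions.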
kl = -0.5 * (log_variance - log_variance.exp() - mean**2 + 1).sum(dim=1)

# Weighted mean over batches.
if weights is None:
kl = kl.mean()
else:
kl = (kl * weights).sum()

# Reconstruction loss.
reconstruction = MSE_loss(diff, weights=weights)

return reconstruction + kl
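For reference, a minimal sketch of how this loss could be called on a batch; the shapes and random values below are illustrative only and not part of this PR:

import torch
from mlcvs.core.loss import elbo_gaussians_loss

# Illustrative batch: 10 samples, 8 descriptors, 2 latent CVs.
diff = torch.randn(10, 8)          # encoder input minus decoder output
mean = torch.randn(10, 2)          # means returned by the encoder
log_variance = torch.randn(10, 2)  # log variances returned by the encoder
loss = elbo_gaussians_loss(diff, mean, log_variance)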
200 changes: 200 additions & 0 deletions mlcvs/cvs/unsupervised/vae.py
@@ -0,0 +1,200 @@
#!/usr/bin/env python

# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Variational Autoencoder collective variable.
"""

__all__ = ["VAE_CV"]


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from typing import Any, Optional, Tuple
import torch
import pytorch_lightning as pl
from mlcvs.cvs import BaseCV
from mlcvs.core import FeedForward, Normalization
from mlcvs.core.loss import elbo_gaussians_loss


# =============================================================================
# VARIATIONAL AUTOENCODER CV
# =============================================================================

class VAE_CV(BaseCV, pl.LightningModule):
"""Variational AutoEncoder Collective Variable.

At training time, the encoder outputs a mean and a variance for each CV
defining a Gaussian distribution associated to the input. One sample is
drawn from this Gaussian, and it goes through the decoder. Then the ELBO
loss is minimized. The ELBO sums the MSE of the reconstruction and the KL
divergence between the generated Gaussian and a N(0, 1) Gaussian.

At evaluation time, the encoder's output mean is used as the CV, while the
variance output and the decoder are ignored.

For training, it requires a DictionaryDataset with the key ``'data'`` and
optionally ``'weights'``.
"""

BLOCKS = ['normIn', 'encoder', 'decoder']

def __init__(self,
n_cvs : int,
encoder_layers : list,
decoder_layers : Optional[list] = None,
options : Optional[dict] = None,
**kwargs):
"""
Variational autoencoder constructor.

Parameters
----------
n_cvs : int
The dimension of the CV or, equivalently, the dimension of the latent
space of the autoencoder.
encoder_layers : list
Number of neurons per layer of the encoder up to the last hidden layer.
The size of the output layer is instead specified with ``n_cvs``.
decoder_layers : list, optional
Number of neurons per layer of the decoder, except for the input layer,
which is specified by ``n_cvs``. If ``None`` (default), the reversed
architecture of the encoder is used.
options : dict[str, Any], optional
Options for the building blocks of the model, by default ``None``.
Available blocks are: ``'normIn'``, ``'encoder'``, and ``'decoder'``.
Set ``'block_name' = None`` or ``False`` to turn off a block. Encoder
and decoder cannot be turned off.
"""
super().__init__(in_features=encoder_layers[0], out_features=n_cvs, **kwargs)

# ===== BLOCKS =====

options = self.sanitize_options(options)

# parse info from args
if decoder_layers is None:
decoder_layers = encoder_layers[::-1]

# initialize normIn
o = 'normIn'
if (options[o] is not False) and (options[o] is not None):
self.normIn = Normalization(self.in_features, **options[o])

# initialize encoder
# The encoder outputs two values for each CV, representing the mean and the log variance.
o = 'encoder'
self.encoder = FeedForward(encoder_layers + [n_cvs*2], **options[o])

# initialize decoder
o = 'decoder'
self.decoder = FeedForward([n_cvs] + decoder_layers, **options[o])

# ===== LOSS OPTIONS =====
self.loss_options = {}

@property
def n_cvs(self):
"""Number of CVs."""
return self.decoder.in_features

def forward_blocks(self, x: torch.Tensor) -> torch.Tensor:
"""Compute the value of the CV from preprocessed input.

Return the mean output (ignoring the variance output) of the encoder
after (optionally) applying the normalization to the input.

Parameters
----------
x : torch.Tensor
Shape ``(n_batches, n_descriptors)`` or ``(n_descriptors,)``. The
input descriptors of the CV after preprocessing.

Returns
-------
cv : torch.Tensor
Shape ``(n_batches, n_cvs)``. The CVs, i.e., the mean output of the
encoder (the variance output is discarded).
"""
if self.normIn is not None:
x = self.normIn(x)
x = self.encoder(x)

# Take only the means and ignore the log variances.
return x[..., :self.n_cvs]

def encode_decode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Run a pass of encoding + decoding.

The function applies the normalization to the inputs and its inverse to
the output.

Parameters
----------
x : torch.Tensor
Shape ``(n_batches, n_descriptors)`` or ``(n_descriptors,)``. The
input descriptors of the CV after preprocessing.

Returns
-------
mean : torch.Tensor
Shape ``(n_batches, n_cvs)`` or ``(n_cvs,)``. The mean of the
Gaussian distribution associated to the input in latent space.
log_variance : torch.Tensor
Shape ``(n_batches, n_cvs)`` or ``(n_cvs,)``. The logarithm of the
variance of the Gaussian distribution associated to the input in
latent space.
x_hat : torch.Tensor
Shape ``(n_batches, n_descriptors)`` or ``(n_descriptors,)``. The
reconstructed descriptors.
"""
# Normalize inputs.
if self.normIn is not None:
x = self.normIn(x)

# Encode input into a Gaussian distribution.
x = self.encoder(x)
mean, log_variance = x[..., :self.n_cvs], x[..., self.n_cvs:]

# Sample from the Gaussian distribution in latent space.
std = torch.exp(log_variance / 2)
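# rsample() uses the reparameterization trick (z = mean + std * eps), so gradients flow through the sampling step.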
z = torch.distributions.Normal(mean, std).rsample()

# Decode sample.
x_hat = self.decoder(z)
if self.normIn is not None:
x_hat = self.normIn.inverse(x_hat)

return mean, log_variance, x_hat

def loss_function(self, diff, mean, log_variance, **kwargs):
"""ELBO loss function when latent space and reconstruction distributions are Gaussians."""
return elbo_gaussians_loss(diff, mean, log_variance, **kwargs)

def training_step(self, train_batch, batch_idx):
"""Single training step performed by the PyTorch Lightning Trainer."""
options = self.loss_options.copy()
x = train_batch['data']
if 'weights' in train_batch:
options['weights'] = train_batch['weights']

# TODO: Should we do preprocessing here?

# Encode/decode.
mean, log_variance, x_hat = self.encode_decode(x)

# Loss function.
diff = x - x_hat
loss = self.loss_function(diff, mean, log_variance, **options)
Owner:

do you think it makes sense to make the loss function return also the contributions to the loss (reconstruction and KL) and log them?

Owner:

I think that also for the TDA CV @EnricoTrizio might want to keep track of the contributions to the loss, so we might think of a uniform way of doing this

Collaborator Author:

Yeah, I think that's a good idea. Maybe we could return only the loss as default and control whether the decomposition in kl and reconstruction is returned with an optional argument of elbo_gaussians_loss()?

Owner (@luigibonati, Mar 21, 2023):

yes, that sounds good. i think we just need to think about a general way of returning and logging that works in every case (and also if a custom loss is used). in general, it would be good to return a dict so that the keys determine the log name. actually this would involve also lda/tica CVS (there we want to check the values of the single eigenvalues) so it is pretty general.

Three options come to my mind:

  1. we make the loss functions return either a scalar loss or a dictionary that contains the 'loss' as well as additional keys and items to be logged. we then need to parse the variable in the training_step func and check which case we are in
  2. we return either a scalar loss or a tuple (loss, log_dict) and do the same as in 1)
  3. we change every loss to return not a scalar but always a dict

I would go for option no. 3 which is the cleanest one

what do you think?

Collaborator Author:

I would avoid 2 since it makes logging harder, but both 1 and 3 seem good to me.

If I had to choose, I'd probably go with 1, returning the total loss by default (simply because it's what a user would likely expect when using a loss function object) and have an attribute that controls whether a more detailed dictionary is returned (which can be turned on in the default loss_options in the cvs).
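As a rough illustration of option 1 (the ``return_log`` flag below is hypothetical and not part of this PR), elbo_gaussians_loss could be extended like this:

from mlcvs.core.loss.mse import MSE_loss

def elbo_gaussians_loss(diff, mean, log_variance, weights=None, return_log=False):
    # KL divergence between N(mean, var) and N(0, 1), summed over the latent dimensions.
    kl = -0.5 * (log_variance - log_variance.exp() - mean**2 + 1).sum(dim=1)
    kl = kl.mean() if weights is None else (kl * weights).sum()
    reconstruction = MSE_loss(diff, weights=weights)
    loss = reconstruction + kl
    if return_log:
        # Return a dict whose keys determine the names used for logging.
        return {'loss': loss, 'kl': kl.detach(), 'reconstruction': reconstruction.detach()}
    return loss

training_step would then check whether the returned value is a dict before logging the individual contributions.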


# Log.
name = 'train' if self.training else 'valid'
self.log(f'{name}_loss', loss, on_epoch=True)

return loss
61 changes: 61 additions & 0 deletions mlcvs/tests/test_cvs_unsupervised_vae.py
@@ -0,0 +1,61 @@
#!/usr/bin/env python


# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Test objects and functions in mlcvs.cvs.unsupervised.vae.
"""


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

import pytest
import pytorch_lightning as pl
import torch

from mlcvs.cvs.unsupervised.vae import VAE_CV
from mlcvs.data import DictionaryDataset, DictionaryDataModule


# =============================================================================
# TESTS
# =============================================================================

@pytest.mark.parametrize('weights', [False, True])
def test_vae_cv_training(weights):
"""Run a full training of a VAECv."""
# Create VAE CV.
n_cvs = 2
in_features = 8
model = VAE_CV(
n_cvs=n_cvs,
encoder_layers=[in_features, 6, 4],
options={
'normIn': None,
'encoder': {'activation' : 'relu'},
}
)

# Create input data.
batch_size = 100
x = torch.randn(batch_size, in_features)
data = {'data': x}

# Create weights.
if weights is True:
data['weights'] = torch.rand(batch_size)

# Train.
datamodule = DictionaryDataModule(DictionaryDataset(data))
trainer = pl.Trainer(max_epochs=1, log_every_n_steps=2, logger=None, enable_checkpointing=False)
trainer.fit(model, datamodule)

# Eval.
model.eval()
cv = model(x)
assert cv.shape == (batch_size, n_cvs)