First implementation of a Variational Autoencoder CV #27
Merged
@@ -0,0 +1,73 @@
#!/usr/bin/env python


# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Evidence Lower BOund (ELBO) loss functions used to train variational autoencoders.
"""

__all__ = ['elbo_gaussians_loss']


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from typing import Optional
import torch
from mlcvs.core.loss.mse import MSE_loss


# =============================================================================
# LOSS FUNCTIONS
# =============================================================================

def elbo_gaussians_loss(
    diff: torch.Tensor,
    mean: torch.Tensor,
    log_variance: torch.Tensor,
    weights: Optional[torch.Tensor] = None
):
    """ELBO loss function assuming the latent and reconstruction distributions are Gaussian.

    The ELBO uses the MSE as the reconstruction loss (i.e., it assumes that the
    decoder outputs the mean of a Gaussian distribution with variance 1) and
    the KL divergence between two normal distributions ``N(mean, var)`` and
    ``N(0, 1)``, where ``mean`` and ``var`` are the output of the encoder.

    Parameters
    ----------
    diff : torch.Tensor
        Shape ``(n_batches, in_features)``. The difference between the input of
        the encoder and the output of the decoder.
    mean : torch.Tensor
        Shape ``(n_batches, latent_features)``. The means of the Gaussian
        distributions associated to the inputs.
    log_variance : torch.Tensor
        Shape ``(n_batches, latent_features)``. The logarithm of the variances
        of the Gaussian distributions associated to the inputs.
    weights : torch.Tensor, optional
        Shape ``(n_batches,)``. If given, the average over batches is weighted.
        The default (``None``) is unweighted.

    Returns
    -------
    loss : torch.Tensor
        The value of the loss function.
    """
    # KL divergence between N(mean, variance) and N(0, 1).
    # See https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
    kl = -0.5 * (log_variance - log_variance.exp() - mean**2 + 1).sum(dim=1)

    # Weighted mean over batches.
    if weights is None:
        kl = kl.mean()
    else:
        kl = (kl * weights).sum()

    # Reconstruction loss.
    reconstruction = MSE_loss(diff, weights=weights)

    return reconstruction + kl
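
For reference, the closed-form KL term used above can be cross-checked against torch.distributions. A minimal sketch (shapes and values are illustrative, not part of the PR):

import torch
from torch.distributions import Normal, kl_divergence

# Illustrative shapes: 4 samples in the batch, 2 latent features.
mean = torch.randn(4, 2)
log_variance = torch.randn(4, 2)

# Closed-form KL(N(mean, var) || N(0, 1)) summed over the latent dimensions,
# exactly as computed in elbo_gaussians_loss.
kl_closed_form = -0.5 * (log_variance - log_variance.exp() - mean**2 + 1).sum(dim=1)

# The same quantity obtained from torch.distributions for comparison.
std = torch.exp(log_variance / 2)
kl_reference = kl_divergence(Normal(mean, std), Normal(0.0, 1.0)).sum(dim=1)

assert torch.allclose(kl_closed_form, kl_reference, atol=1e-6)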
@@ -0,0 +1,200 @@
#!/usr/bin/env python


# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Variational Autoencoder collective variable.
"""

__all__ = ["VAE_CV"]


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from typing import Any, Optional, Tuple
import torch
import pytorch_lightning as pl
from mlcvs.cvs import BaseCV
from mlcvs.core import FeedForward, Normalization
from mlcvs.core.loss import elbo_gaussians_loss


# =============================================================================
# VARIATIONAL AUTOENCODER CV
# =============================================================================

class VAE_CV(BaseCV, pl.LightningModule):
    """Variational AutoEncoder Collective Variable.

    At training time, the encoder outputs a mean and a variance for each CV,
    defining a Gaussian distribution associated to the input. One sample is
    drawn from this Gaussian and passed through the decoder. Then the ELBO
    loss is minimized. The ELBO sums the MSE of the reconstruction and the KL
    divergence between the generated Gaussian and a N(0, 1) Gaussian.

    At evaluation time, the encoder's output mean is used as the CV, while the
    variance output and the decoder are ignored.

    For training, it requires a DictionaryDataset with the key ``'data'`` and
    optionally ``'weights'``.
    """

    BLOCKS = ['normIn', 'encoder', 'decoder']

    def __init__(self,
                 n_cvs: int,
                 encoder_layers: list,
                 decoder_layers: Optional[list] = None,
                 options: Optional[dict] = None,
                 **kwargs):
        """
        Variational autoencoder constructor.

        Parameters
        ----------
        n_cvs : int
            The dimension of the CV or, equivalently, the dimension of the latent
            space of the autoencoder.
        encoder_layers : list
            Number of neurons per layer of the encoder up to the last hidden layer.
            The size of the output layer is instead specified with ``n_cvs``.
        decoder_layers : list, optional
            Number of neurons per layer of the decoder, except for the input layer,
            which is specified by ``n_cvs``. If ``None`` (default), it automatically
            takes the reversed architecture of the encoder.
        options : dict[str, Any], optional
            Options for the building blocks of the model, by default ``None``.
            Available blocks are: ``'normIn'``, ``'encoder'``, and ``'decoder'``.
            Set ``'block_name' = None`` or ``False`` to turn off a block. The
            encoder and decoder cannot be turned off.
        """
        super().__init__(in_features=encoder_layers[0], out_features=n_cvs, **kwargs)

        # ===== BLOCKS =====

        options = self.sanitize_options(options)

        # Parse info from args.
        if decoder_layers is None:
            decoder_layers = encoder_layers[::-1]

        # Initialize normIn.
        o = 'normIn'
        if (options[o] is not False) and (options[o] is not None):
            self.normIn = Normalization(self.in_features, **options[o])

        # Initialize the encoder.
        # The encoder outputs two values for each CV, representing the mean and
        # the log variance of the latent Gaussian distribution.
        o = 'encoder'
        self.encoder = FeedForward(encoder_layers + [n_cvs*2], **options[o])

        # Initialize the decoder.
        o = 'decoder'
        self.decoder = FeedForward([n_cvs] + decoder_layers, **options[o])

        # ===== LOSS OPTIONS =====
        self.loss_options = {}

    @property
    def n_cvs(self):
        """Number of CVs."""
        return self.decoder.in_features

    def forward_blocks(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the value of the CV from preprocessed input.

        Return the mean output (ignoring the variance output) of the encoder
        after (optionally) applying the normalization to the input.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(n_batches, n_descriptors)`` or ``(n_descriptors,)``. The
            input descriptors of the CV after preprocessing.

        Returns
        -------
        cv : torch.Tensor
            Shape ``(n_batches, n_cvs)``. The CVs, i.e., the mean output of the
            encoder (the variance output is discarded).
        """
        if self.normIn is not None:
            x = self.normIn(x)
        x = self.encoder(x)

        # Take only the means and ignore the log variances.
        return x[..., :self.n_cvs]

    def encode_decode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run a pass of encoding + decoding.

        The function applies the normalization to the inputs and its inverse on
        the output.

        Parameters
        ----------
        x : torch.Tensor
            Shape ``(n_batches, n_descriptors)`` or ``(n_descriptors,)``. The
            input descriptors of the CV after preprocessing.

        Returns
        -------
        mean : torch.Tensor
            Shape ``(n_batches, n_cvs)`` or ``(n_cvs,)``. The mean of the
            Gaussian distribution associated to the input in latent space.
        log_variance : torch.Tensor
            Shape ``(n_batches, n_cvs)`` or ``(n_cvs,)``. The logarithm of the
            variance of the Gaussian distribution associated to the input in
            latent space.
        x_hat : torch.Tensor
            Shape ``(n_batches, n_descriptors)`` or ``(n_descriptors,)``. The
            reconstructed descriptors.
        """
        # Normalize inputs.
        if self.normIn is not None:
            x = self.normIn(x)

        # Encode input into a Gaussian distribution.
        x = self.encoder(x)
        mean, log_variance = x[..., :self.n_cvs], x[..., self.n_cvs:]

        # Sample from the Gaussian distribution in latent space
        # (reparametrization trick through rsample()).
        std = torch.exp(log_variance / 2)
        z = torch.distributions.Normal(mean, std).rsample()

        # Decode the sample and undo the input normalization on the reconstruction.
        x_hat = self.decoder(z)
        if self.normIn is not None:
            x_hat = self.normIn.inverse(x_hat)

        return mean, log_variance, x_hat

    def loss_function(self, diff, mean, log_variance, **kwargs):
        """ELBO loss function when latent space and reconstruction distributions are Gaussians."""
        return elbo_gaussians_loss(diff, mean, log_variance, **kwargs)

    def training_step(self, train_batch, batch_idx):
        """Single training step performed by the PyTorch Lightning Trainer."""
        options = self.loss_options.copy()
        x = train_batch['data']
        if 'weights' in train_batch:
            options['weights'] = train_batch['weights']

        # TODO: Should we do preprocessing here?

        # Encode/decode.
        mean, log_variance, x_hat = self.encode_decode(x)

        # Loss function.
        diff = x - x_hat
        loss = self.loss_function(diff, mean, log_variance, **options)

        # Log.
        name = 'train' if self.training else 'valid'
        self.log(f'{name}_loss', loss, on_epoch=True)

        return loss
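
For reference, a minimal usage sketch of the class above, mirroring the test added later in this PR (sizes are illustrative):

import torch
from mlcvs.cvs.unsupervised.vae import VAE_CV

# Encoder: 8 input descriptors -> hidden layers of 6 and 4 neurons -> 2 CVs.
# The decoder takes the reversed architecture by default.
model = VAE_CV(
    n_cvs=2,
    encoder_layers=[8, 6, 4],
    options={'normIn': None},  # input normalization turned off, as in the test
)

# At evaluation time the model returns only the mean output of the encoder,
# i.e. the CV values.
model.eval()
x = torch.randn(5, 8)
cv = model(x)
print(cv.shape)  # torch.Size([5, 2])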
@@ -0,0 +1,61 @@
#!/usr/bin/env python


# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Test objects and functions in mlcvs.cvs.unsupervised.vae.
"""


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

import pytest
import pytorch_lightning as pl
import torch

from mlcvs.cvs.unsupervised.vae import VAE_CV
from mlcvs.data import DictionaryDataset, DictionaryDataModule


# =============================================================================
# TESTS
# =============================================================================

@pytest.mark.parametrize('weights', [False, True])
def test_vae_cv_training(weights):
    """Run a full training of a VAE_CV."""
    # Create the VAE CV.
    n_cvs = 2
    in_features = 8
    model = VAE_CV(
        n_cvs=n_cvs,
        encoder_layers=[in_features, 6, 4],
        options={
            'normIn': None,
            'encoder': {'activation': 'relu'},
        }
    )

    # Create input data.
    batch_size = 100
    x = torch.randn(batch_size, in_features)
    data = {'data': x}

    # Create weights.
    if weights is True:
        data['weights'] = torch.rand(batch_size)

    # Train.
    datamodule = DictionaryDataModule(DictionaryDataset(data))
    trainer = pl.Trainer(max_epochs=1, log_every_n_steps=2, logger=None, enable_checkpointing=False)
    trainer.fit(model, datamodule)

    # Eval: in evaluation mode the model returns the CVs (the mean output of the encoder).
    model.eval()
    cv = model(x)
    assert cv.shape == (batch_size, n_cvs)
Do you think it makes sense to make the loss function also return the contributions to the loss (reconstruction and KL) and log them?
I think that @EnricoTrizio might also want to keep track of the contributions to the loss for the TDA CV, so we might think of a uniform way of doing this.
Yeah, I think that's a good idea. Maybe we could return only the loss by default and control whether the decomposition into KL and reconstruction is returned with an optional argument of elbo_gaussians_loss()?
Yes, that sounds good. I think we just need to think about a general way of returning and logging that works in every case (and also if a custom loss is used). In general, it would be good to return a dict so that the keys determine the log names. Actually, this would also involve the LDA/TICA CVs (there we want to check the values of the single eigenvalues), so it is pretty general.
Three options come to my mind:
I would go for option no. 3, which is the cleanest one.
What do you think?
I would avoid 2 since it makes logging harder, but both 1 and 3 seem good to me.
If I had to choose, I'd probably go with 1, returning the total loss by default (simply because it's what a user would likely expect when using a loss function object) and having an attribute that controls whether a more detailed dictionary is returned (which can be turned on in the default loss_options in the CVs).
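
For illustration only, a minimal sketch of option no. 1 as discussed above. The return_components flag and the plain-MSE reconstruction term are assumptions made for this sketch and are not part of the PR (which delegates the reconstruction term to MSE_loss):

from typing import Optional, Union
import torch


def elbo_gaussians_loss_sketch(
    diff: torch.Tensor,
    mean: torch.Tensor,
    log_variance: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    return_components: bool = False,
) -> Union[torch.Tensor, dict]:
    """Sketch: same ELBO as in this PR, optionally returning its contributions."""
    # KL divergence between N(mean, var) and N(0, 1), as in elbo_gaussians_loss.
    kl = -0.5 * (log_variance - log_variance.exp() - mean**2 + 1).sum(dim=1)

    # Reconstruction term as a plain mean squared error (illustrative stand-in
    # for MSE_loss).
    mse = (diff**2).mean(dim=1)

    # Weighted or unweighted average over the batch.
    if weights is None:
        kl, mse = kl.mean(), mse.mean()
    else:
        kl, mse = (kl * weights).sum(), (mse * weights).sum()

    loss = mse + kl
    if return_components:
        # The keys determine the log names, e.g. train_loss, train_kl, train_mse.
        return {'loss': loss, 'kl': kl, 'mse': mse}
    return loss

In training_step, the returned dictionary could then be logged with one self.log call per key, which would also work for custom losses that return their own dictionaries.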