
First implementation of a Variational Autoencoder CV #27

Merged: 1 commit merged into lightning from vae on Mar 21, 2023
Conversation

@andrrizzi (Collaborator) commented Mar 21, 2023

Description

This implements a CV based on a standard variational autoencoder, trained with an MSE reconstruction loss and regularized with a KL divergence w.r.t. a Normal(0, 1) distribution. In production, only the mean output of the encoder is used as the CV.
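
As a reference for reviewers, here is a minimal sketch of such a loss, assuming the same (diff, mean, log_variance) arguments that appear in the diff below; the kl_weight argument, the function name, and the batch-averaging conventions are illustrative assumptions rather than the PR's exact implementation.

def vae_loss_sketch(diff, mean, log_variance, kl_weight=1.0):
    # MSE reconstruction term on diff = x - x_hat.
    reconstruction = (diff ** 2).mean()
    # Analytical KL divergence between N(mean, exp(log_variance)) and N(0, 1),
    # summed over the latent dimensions and averaged over the batch.
    kl = -0.5 * (1 + log_variance - mean ** 2 - log_variance.exp()).sum(dim=-1).mean()
    return reconstruction + kl_weight * kl

At inference time only the encoder mean is evaluated as the CV; the decoder and the latent sampling matter only during training.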

Questions

  • I forgot, did we decide to have a mechanism to do the preprocessing inside training_step()? For now, I'm not doing any preprocessing during training.

Status

  • Ready to go


# Loss function: compares the reconstruction x_hat to the input x and
# regularizes the latent Gaussian parameters (mean, log_variance).
diff = x - x_hat
loss = self.loss_function(diff, mean, log_variance, **options)
@luigibonati (Owner)

do you think it makes sense to make the loss function return also the contributions to the loss (reconstruction and KL) and log them?

@luigibonati (Owner)

I think that also for the TDA CV @EnricoTrizio might want to keep track of the contributions to the loss, so we might think of a uniform way of doing this

@andrrizzi (Collaborator, author)

Yeah, I think that's a good idea. Maybe we could return only the loss by default and control whether the decomposition into KL and reconstruction is returned with an optional argument of elbo_gaussians_loss()?

@luigibonati (Owner) commented Mar 21, 2023

Yes, that sounds good. I think we just need a general way of returning and logging that works in every case (including when a custom loss is used). In general, it would be good to return a dict so that the keys determine the log names. This would actually also involve the LDA/TICA CVs (there we want to check the values of the single eigenvalues), so it is pretty general.

Three options come to my mind:

  1. We make the loss functions return either a scalar loss or a dictionary that contains the 'loss' as well as additional keys and items to be logged; we then need to parse the variable in the training_step function and check which case we are in.
  2. We return either a scalar loss or a tuple (loss, log_dict) and do the same as in 1.
  3. We change every loss to always return a dict instead of a scalar.

I would go for option no. 3, which is the cleanest one.

What do you think?
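
For illustration, a minimal sketch of option 3 applied to the ELBO loss; the dict keys ("loss", "reconstruction", "kl") and the Lightning-style logging in the usage comment are assumptions, not the final API.

def elbo_gaussians_loss(diff, mean, log_variance):
    # Option 3 (sketch): always return a dict; the keys determine the log names.
    reconstruction = (diff ** 2).mean()
    kl = -0.5 * (1 + log_variance - mean ** 2 - log_variance.exp()).sum(dim=-1).mean()
    return {"loss": reconstruction + kl, "reconstruction": reconstruction, "kl": kl}

# In training_step (sketch): backpropagate on "loss" and log every entry.
# loss_dict = self.loss_function(diff, mean, log_variance, **options)
# for name, value in loss_dict.items():
#     self.log(f"train_{name}", value)
# return loss_dict["loss"]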

@andrrizzi (Collaborator, author)

I would avoid 2 since it makes logging harder, but both 1 and 3 seem good to me.

If I had to choose, I'd probably go with 1, returning the total loss by default (simply because that's what a user would likely expect when using a loss function object), with an attribute that controls whether a more detailed dictionary is returned (which can be turned on in the default loss_options in the CVs).
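
A minimal sketch of this variant; the class name ELBOGaussiansLoss and the return_loss_dict attribute are assumptions for illustration only.

class ELBOGaussiansLoss:
    # Option 1 (sketch): the scalar total loss is returned by default; an
    # attribute, settable through the CV's default loss_options, switches on
    # the detailed dictionary. The attribute name is an assumption.
    def __init__(self, return_loss_dict=False):
        self.return_loss_dict = return_loss_dict

    def __call__(self, diff, mean, log_variance):
        reconstruction = (diff ** 2).mean()
        kl = -0.5 * (1 + log_variance - mean ** 2 - log_variance.exp()).sum(dim=-1).mean()
        loss = reconstruction + kl
        if self.return_loss_dict:
            return {"loss": loss, "reconstruction": reconstruction, "kl": kl}
        return loss

training_step would then check isinstance(result, dict) to decide what to log, as in case 1 of the list above.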

@luigibonati (Owner)

Questions

  • I forgot, did we decide to have a mechanism to do the preprocessing inside training_step()? For now, I'm not doing any preprocessing during training.

No, the way you implemented it is correct and consistent with the other classes (e.g., overloading forward_blocks): the base CV class has a forward method that adds the pre- and post-processing modules around forward_blocks. As it is now, this is not supposed to take place during training.
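
A simplified sketch of the pattern described above, not the actual base-class code; apart from forward_blocks, the attribute names (preprocessing, postprocessing) are assumptions.

import torch

class BaseCV(torch.nn.Module):
    # Sketch: pre/postprocessing are applied only in forward(), which wraps
    # forward_blocks(); training calls forward_blocks() directly, so no
    # preprocessing happens during training.
    def __init__(self, preprocessing=None, postprocessing=None):
        super().__init__()
        self.preprocessing = preprocessing
        self.postprocessing = postprocessing

    def forward_blocks(self, x):
        raise NotImplementedError  # implemented by each CV (e.g. the VAE encoder)

    def forward(self, x):
        if self.preprocessing is not None:
            x = self.preprocessing(x)
        x = self.forward_blocks(x)
        if self.postprocessing is not None:
            x = self.postprocessing(x)
        return x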

@luigibonati merged commit 886fc51 into lightning on Mar 21, 2023
@andrrizzi deleted the vae branch on Mar 21, 2023, 15:21