
[Bug] Prior losses are incorrectly added to the mll in batch-mode #1318

Closed
dme65 opened this issue Oct 20, 2020 · 7 comments
Comments

@dme65 (Collaborator) commented Oct 20, 2020

🐛 Bug

Prior losses are currently being added up incorrectly in ExactMarginalLogLikelihood. The line:

res.add_(prior.log_prob(closure()).sum())

will sum up all of the prior losses and then add them to the mll. If you are using a batch model, this scalar sum gets added to every batch dimension, so the prior losses are counted multiple times when loss.sum().backward() is eventually called. It also looks like the priors may not support batch mode, which leads to a variety of different log_prob shapes, but the .sum() call masks this issue since it sums everything up anyway.
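
As a minimal illustration with plain PyTorch and made-up numbers (not the actual gpytorch code), the scalar sum gets counted once per batch entry:

import torch

mll_batch = torch.zeros(2)                    # stand-in for the per-batch mll values
prior_log_prob = torch.tensor([-1.0, -1.0])   # one prior loss per batch model

res = mll_batch + prior_log_prob.sum()        # scalar -2.0 broadcasts to both entries
print(res.sum())                              # tensor(-4.) -- prior counted twice

res = mll_batch + prior_log_prob              # keep the batch dimension instead
print(res.sum())                              # tensor(-2.) -- counted once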

To reproduce

Code snippet (taken from test_train_on_batch_test_on_batch):

import math

import torch

import gpytorch
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean

train_x1 = torch.linspace(0, 2, 11).unsqueeze(-1)
train_y1 = torch.sin(train_x1 * (2 * math.pi)).squeeze()
train_x2 = torch.linspace(0, 1, 11).unsqueeze(-1)
train_y2 = torch.sin(train_x2 * (2 * math.pi)).squeeze()
train_x12 = torch.cat((train_x1.unsqueeze(0), train_x2.unsqueeze(0)), dim=0).contiguous()
train_y12 = torch.cat((train_y1.unsqueeze(0), train_y2.unsqueeze(0)), dim=0).contiguous()


class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_inputs, train_targets, likelihood, batch_shape=torch.Size()):
        super(ExactGPModel, self).__init__(train_inputs, train_targets, likelihood)
        self.mean_module = ConstantMean(batch_shape=batch_shape, prior=gpytorch.priors.SmoothedBoxPrior(-1, 1))
        self.covar_module = ScaleKernel(
            RBFKernel(
                batch_shape=batch_shape,
                lengthscale_prior=gpytorch.priors.NormalPrior(
                    loc=torch.zeros(*batch_shape, 1, 1), scale=torch.ones(*batch_shape, 1, 1)
                ),
            ),
            batch_shape=batch_shape,
            outputscale_prior=gpytorch.priors.SmoothedBoxPrior(-2, 2),
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)
    

# Construct a batched likelihood and model with priors on all of the hyperparameters
likelihood = GaussianLikelihood(
    noise_prior=gpytorch.priors.NormalPrior(loc=torch.zeros(2), scale=torch.ones(2)),
    batch_shape=torch.Size([2]),
)
gp_model = ExactGPModel(train_x12, train_y12, likelihood, batch_shape=torch.Size([2]))

for name, prior, closure, _ in gp_model.named_priors():
    print(name, prior.log_prob(closure()).shape)

Output:

likelihood.noise_covar.noise_prior torch.Size([2, 2])
mean_module.mean_prior torch.Size([2])
covar_module.outputscale_prior torch.Size([])
covar_module.base_kernel.lengthscale_prior torch.Size([2, 1, 1])

Expected Behavior

The prior log-probabilities should all have the model's batch shape and be accumulated via res.add_(prior.log_prob(closure())), without the inner .sum() call.
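
Schematically (not the actual gpytorch implementation), assuming every prior's log_prob comes back with shape [*batch_shape]:

import torch

batch_shape = torch.Size([2])
res = torch.zeros(batch_shape)                                  # stand-in for the data term
prior_log_probs = [torch.randn(batch_shape) for _ in range(4)]  # one per registered prior

for log_prob in prior_log_probs:
    res = res + log_prob                                        # no .sum(): batch dims preserved

assert res.shape == batch_shape                                 # loss.sum() now counts each prior once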

System information

GPyTorch Version: 1.2.0
PyTorch Version: 1.6.0
OS: macOS

Additional context

This was originally discovered in PR #1314.

cc: @Balandat

dme65 added the bug label Oct 20, 2020
@dme65 (Collaborator, Author) commented Oct 20, 2020

Potentially related to #1317.

@georgedeath commented Oct 21, 2020

Amazingly, I also came across this bug yesterday and was about to post a bug report.

A simple (hacky) fix I'm currently using is to replace:

res.add_(prior.log_prob(closure()).sum())

with

val = prior.log_prob(closure())
# Reduce any extra parameter dimensions so that only the batch
# dimensions (if any) remain before adding to the mll.
if val.ndim == 3:
    val = val.sum((1, 2))
elif val.ndim == 2:
    val = val.sum(1)
res.add_(val.squeeze())

in ExactMarginalLogLikelihood.

I think this deals with all use cases, but it is particularly ugly.

@jacobrgardner (Member) commented:

Oof, okay, this is pretty bad. I'll get #1317 and #1318 fixed early this week, and we can push out a 1.2.1.

@gpleiss (Member) commented Oct 26, 2020

I think this is a bit complicated because of the event/batch shapes of the different parameters.

  • likelihood.noise_covar.noise_prior - the noise parameter has shape [*batch_shape, 1], so the prior applied to it should have an event_shape of 1 (i.e. a Normal distribution with to_event(1); see the sketch after this comment).
  • mean_module.mean_prior - again, the constant parameter has shape [*batch_shape, 1]. I'm not sure how we end up with the shape we're getting... but I think something might be off.
  • covar_module.outputscale_prior - this parameter has shape [*batch_shape]. I think the SmoothedBoxPrior is treating the batch dimension as an event dimension.
  • covar_module.base_kernel.lengthscale_prior - the lengthscale parameter has shape [*batch_shape, 1, d], so again we need a prior with an event shape of 1 x d.

I'm not sure exactly what we should do to resolve these issues.
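
For reference, a minimal shape sketch with plain torch.distributions (which the gpytorch priors wrap), using the noise prior's shapes from above; Independent plays the role of Pyro's to_event(1):

import torch
from torch.distributions import Independent, Normal

batch_shape = torch.Size([2])
noise = torch.rand(*batch_shape, 1)                # raw noise parameter, shape [2, 1]

prior = Normal(torch.zeros(*batch_shape, 1), torch.ones(*batch_shape, 1))
print(prior.log_prob(noise).shape)                 # torch.Size([2, 1]) -- extra trailing dim

# Declaring the trailing parameter dim as an event dim sums over it,
# leaving exactly one log-prob per batch element.
prior_event = Independent(prior, 1)
print(prior_event.log_prob(noise).shape)           # torch.Size([2])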

@jacobrgardner (Member) commented:

We need to make the priors aware of the shapes we expect. We can either do this manually, or reuse the same kind of logic we use for the Pyro integration, relying on the fact that priors have expand methods:

prior = prior.expand(closure().shape)

We can do this either when registering them in the individual classes, or in Module around this line:

return getattr(self, param_or_closure)

What do you think?
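
A rough sketch of what expanding to the closure's shape buys, again with plain torch.distributions and hypothetical shapes (the gpytorch priors expose expand as well, per the comment above):

import torch
from torch.distributions import Normal

outputscale = torch.rand(2)          # closure() value of shape [*batch_shape] == [2]
lengthscale = torch.rand(2, 1, 1)    # closure() value of shape [*batch_shape, 1, 1]

scalar_prior = Normal(torch.tensor(0.0), torch.tensor(1.0))

for value in (outputscale, lengthscale):
    # Tying the prior's batch shape to the parameter's shape makes log_prob
    # come back with exactly the parameter's shape, whatever that is.
    prior = scalar_prior.expand(value.shape)
    assert prior.log_prob(value).shape == value.shape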

@mshvartsman (Contributor) commented:

Oops, I commented on #1317 without looking at the updates here. What about the case where we're using a batched univariate prior for a vector-valued hyperparameter (like lengthscales with ARD)? It seems like we shouldn't always expand.
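
For concreteness, the ARD shapes in question, sketched with plain torch.distributions and made-up sizes (a per-dimension univariate prior already broadcasts against the lengthscale, and expanding it to the full parameter shape gives the same log_prob values):

import torch
from torch.distributions import Normal

b, d = 2, 3
lengthscale = torch.rand(b, 1, d)                  # ARD lengthscale, shape [b, 1, d]

ard_prior = Normal(torch.zeros(d), torch.ones(d))  # one univariate prior per ARD dimension

log_prob = ard_prior.log_prob(lengthscale)         # broadcasting handles the batch dims
print(log_prob.shape)                              # torch.Size([2, 1, 3])

expanded = ard_prior.expand(lengthscale.shape)     # explicit expand changes nothing here
assert torch.allclose(expanded.log_prob(lengthscale), log_prob)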

@saitcakmak (Collaborator) commented:

This was fixed in #2039. Feel free to close the issue!

gpleiss closed this as completed Oct 18, 2022