
[Bug] Sampling from priors doesn't match shape of hyperparameters #1317

Open
mshvartsman opened this issue Oct 20, 2020 · 9 comments
@mshvartsman
Contributor

🐛 Bug

I found some unexpected interactions between ard_num_dims and the shapes of priors for kernels: in several settings, sampling from a hyperparameter prior does not return a tensor with the same shape as the hyperparameter. I'm not sure whether all of these behaviors are intended, but it looks like a bug to me.

To reproduce

import torch
from gpytorch.priors import GammaPrior, NormalPrior
from gpytorch.kernels import RBFKernel

# make a kernel
scales = torch.tensor([1.0, 1.0])
kernel = RBFKernel(
    ard_num_dims=2,
    lengthscale_prior=GammaPrior(3.0, 6.0 / scales),
)
new_lengthscale = kernel.lengthscale_prior.sample(kernel.lengthscale.shape)
print(kernel.lengthscale.shape) # size 1,2
print(new_lengthscale.shape) # size 1,2,2, if I try to assign it back I get an error

# same with another prior
kernel2 = RBFKernel(
    ard_num_dims=2,
    lengthscale_prior=NormalPrior(loc=10, scale=scales)
)

new_lengthscale = kernel2.lengthscale_prior.sample(kernel2.lengthscale.shape)
print(kernel2.lengthscale.shape) # size 1,2
print(new_lengthscale.shape) # size 1, 2, 2

# ard_num_dims is only 1 but we have a higher-dim prior. Is this behavior defined?
kernel3 = RBFKernel(
    ard_num_dims=1,
    lengthscale_prior=NormalPrior(loc=10, scale=scales)
)

new_lengthscale = kernel3.lengthscale_prior.sample(kernel3.lengthscale.shape)
print(kernel3.lengthscale.shape) # size 1, 1 -- but shouldn't we expect 1,2? 
print(new_lengthscale.shape) # size 1, 1, 2

# ok, ard_num_dims is 2 but my prior is 1d, now it works correctly
kernel4 = RBFKernel(
    ard_num_dims=2,
    lengthscale_prior=NormalPrior(loc=10, scale=1)
)

new_lengthscale = kernel4.lengthscale_prior.sample(kernel4.lengthscale.shape)
print(kernel4.lengthscale.shape) # size 1,2
print(new_lengthscale.shape) # size 1, 2

Expected Behavior

It would be nice to get a warning or error earlier for undefined/unsupported behavior, and otherwise to have the sampled shapes match the hyperparameter shapes.

System information

Please complete the following information:

  • GPyTorch Version: 1.2.0
  • PyTorch Version: 1.6.0
  • Computer OS: verified on OSX and CentOS.
@Balandat
Collaborator

Yeah, @dme65 ran into some issues that are likely related; it seems there are some problems with how batch sizes interact with the priors.

@mshvartsman
Contributor Author

More thoughts about this: priors are just thin wrappers around pytorch distributions, right? The sample method on those:

Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.
(https://pytorch.org/docs/stable/distributions.html#torch.distributions.distribution.Distribution.sample)

So if my prior is not batched, then to get a sample the same size as the kernel hyperparameter I need to call prior.sample(hyperparam.shape); but if my prior is batched with batch size hyperparam.shape (or is multivariate), I need to call prior.sample(torch.Size([1])), which is awkward.
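
For concreteness, here's a minimal illustration of that sample_shape behavior with plain torch.distributions (target shape chosen to match kernel.lengthscale.shape from the examples above):

import torch
from torch.distributions import Normal

unbatched = Normal(loc=10.0, scale=1.0)          # batch_shape = torch.Size([])
batched = Normal(loc=10.0, scale=torch.ones(2))  # batch_shape = torch.Size([2])

hyperparam_shape = torch.Size([1, 2])  # e.g. kernel.lengthscale.shape with ard_num_dims=2

print(unbatched.sample(hyperparam_shape).shape)  # torch.Size([1, 2]) -- matches
print(batched.sample(hyperparam_shape).shape)    # torch.Size([1, 2, 2]) -- mismatch
print(batched.sample(torch.Size([1])).shape)     # torch.Size([1, 2]) -- matches, but awkward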

So, what should we do about it? Some options:

  • Option 1: ensure consistency in the kernel constructor. Something like the following:
	If prior is univariate with batch_size == 1: 
		promote prior to batched with batch_size=hyperparam.shape
	elif prior is multivariate or univariate with nonzero batch_size: 
		check that prior size (multivariate) or batch size (univariate) matches hyperparam.shape

(And maybe warn on calls to prior.sample(...) with an argument equal to the prior batch size.) There are probably legitimate reasons to call prior.sample() with various sizes, so this might be an overly stringent warning, but it would reduce this kind of confusion in the future.
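
Roughly, a minimal sketch of that check/promotion (purely illustrative: the helper name is made up, and it assumes the prior supports the standard torch.distributions expand method):

import torch

def promote_prior(prior, hyperparam_shape):
    # Sketch of the logic above: promote an unbatched prior to the
    # hyperparameter's shape, otherwise require that the shapes already agree.
    if prior.batch_shape in (torch.Size([]), torch.Size([1])):
        return prior.expand(hyperparam_shape)
    if prior.batch_shape != hyperparam_shape:
        raise ValueError(
            f"Prior batch shape {prior.batch_shape} does not match "
            f"hyperparameter shape {hyperparam_shape}"
        )
    return prior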

Thoughts? Happy to PR this if there's consensus on how to proceed.

@Balandat
Collaborator

Balandat commented Dec 9, 2020

So Option 1 makes sense to me. We could just use the existing expand method to achieve this promotion.

I think one potential challenge is that, since this expansion will not be in-place, if we pass a prior to a module constructor, the registered prior will refer to a different object than the one passed in. Specifically, say prior is a non-batched prior; then

prior.sample(torch.Size([1]))  # size 1
k = Kernel(batch_shape=torch.Size([2]), lengthscale_prior=prior)
k._priors["lengthscale_prior"].sample(torch.Size([1]))  # size 2 x 1

This is probably ok though...

A similar shape promotion would have to be done for ARD I guess.
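
To make the object-identity point concrete, here's the underlying behavior with a plain torch.distributions.Normal (expand returns a new distribution rather than modifying in place, which is why the registered prior ends up being a different object):

import torch
from torch.distributions import Normal

prior = Normal(0.0, 1.0)
expanded = prior.expand(torch.Size([2]))

print(prior.batch_shape)     # torch.Size([])
print(expanded.batch_shape)  # torch.Size([2])
print(expanded is prior)     # False -- the registered prior would be this new object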

@mshvartsman
Contributor Author

Could we override prior.sample to do the checks, rather than doing it in __init__, and call expand at the last possible moment? Then the object stays the same, though the size mismatch in your example would still come up, and we're eating the extra memcopy cost on every sample, so I guess that's not better unless we also do some caching, at which point they're different objects again.
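
Something like this sketch (illustrative only, and written as a wrapper rather than an actual override of the prior classes; the expand-on-every-call cost is exactly the memcopy concern above):

import torch
from torch.distributions import Normal

class LazyExpandingPrior:
    # Sketch: keep the original prior object and expand it to the target
    # hyperparameter shape at the last possible moment, inside sample().
    def __init__(self, prior, target_shape):
        self.prior = prior
        self.target_shape = torch.Size(target_shape)

    def sample(self, sample_shape=torch.Size()):
        if self.prior.batch_shape != self.target_shape:
            # pays the expand cost on every call unless we cache the result
            return self.prior.expand(self.target_shape).sample(sample_shape)
        return self.prior.sample(sample_shape)

wrapped = LazyExpandingPrior(Normal(10.0, 1.0), torch.Size([1, 2]))
print(wrapped.sample().shape)  # torch.Size([1, 2])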

@Balandat
Collaborator

Yeah we could do that. I guess the one concern would be perf if we repeatedly sample from the prior and have to do the expansion each time. This shouldn't be too much of an issue though unless we want to sample a ton, e.g. when doing fully Bayesian inference (cc @jpchen)

On the other hand, I don't think too many folks will do something like k._priors["lengthscale_prior"].sample, so the point I brought up might not really be a valid concern.

@mshvartsman
Contributor Author

Looking at this again: do we want to do this promotion in all batching cases, or ARD only? Basically, is ARD the only case where pytorch and gpytorch batch semantics don't match?

Option 1: just handle ARD in register_priors

  • Logic: if hasattr(self, "ard_num_dims"), check and promote (using prior.expand as @Balandat suggested); else do nothing.

Option 2: promote last dim only in register_priors

  • Logic: if prior.batch_shape in [torch.Size([]), torch.Size([1])] and closure().shape[-1] > 1, promote.
  • Also requires a reordering of register_prior and register_constraints in all kernel init calls, since the prior closure won't evaluate until the constraint is registered. No idea what the knock-on effects are.

Option 3: promote all dims in register_priors

  • As in Option 2, but promote beyond just the last dim. I don't have a solid enough understanding of what happens elsewhere to gauge the implications here.

Option 4: promote somewhere in kernel rather than register_priors

  • This is what I originally suggested as option 1 above. It seems more elegant than Option 1, except that there's currently no single entry point where it would happen: we can capture lengthscales in the Kernel __init__, but otherwise it's a lot of small changes today (because each kernel might have different hyperparameters), and future kernels would need to know to support this (i.e. they don't get it for free by subclassing).

My lean is Option 1 unless someone has a cleverer solution for Option 4. I think Options 2 and 3 have a blast radius that is too large and that I don't really understand right now.

Under all options, the correct call would become kernel.lengthscale_prior.sample(torch.Size([batch_size])) where batch_size is the gpytorch batch size rather than the pytorch batch size. Of my examples above, all would work except kernel3, which would throw an error (because there's no obvious way, to me, to have more scales than ard_num_dims).
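
For the examples above that would look something like the following (hypothetical, describing the proposed behavior rather than what GPyTorch does today; the unbatched kernels above have a gpytorch batch size of 1):

# Proposed: the prior gets promoted to the ARD shape at registration time, so
new_lengthscale = kernel.lengthscale_prior.sample(torch.Size([1]))  # would be torch.Size([1, 2])
kernel.lengthscale = new_lengthscale                                # and would assign back cleanly

# kernel3 (ard_num_dims=1 but a 2-element prior) would instead raise an error,
# ideally at construction time.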

Can get working on a PR as soon as some folks weigh in.

@Balandat
Collaborator

Balandat commented Mar 2, 2021

I think Option 4 is going to be a pain to get right, in particular since it's not only kernels but also likelihoods and anywhere else where there may be a prior.

I like Option 3 the best, as it seems to me that's the right thing to do, in the sense that we're broadcasting priors. Having all priors carry the appropriate batch shapes will also simplify some of the bookkeeping that we'll definitely need to do in the computation of the MLL (#1318).

I'm not sure, though, how the inconsistency in the shapes of different parameters that @gpleiss raised in #1318 bears on this.

@mshvartsman
Contributor Author

I think as long as we expand based on closure() in register_prior we should be fine, right? Though I'm not sure whether batch_shape will play well with this. I'll make some changes locally and give it a shot. Relatedly, if this and #1318 are the same bug (as seems to be the case), should we close one of them to keep discussion consolidated?

@Balandat
Collaborator

Balandat commented Mar 2, 2021

Relatedly, if this and #1318 are the same bug (as seems to be the case), should we close one of them to keep discussion consolidated?

It's the same root cause but not the same symptom (sampling vs. evaluating the MLL), so maybe let's keep it open until fixed to make it easier to find.
