More explicit prior shape broadcasting #1520

Open · wants to merge 6 commits into main

Conversation

mshvartsman
Contributor

This is an attempt to address #1317 and #1318 -- incomplete, but hopefully it helps concretize the discussion. I think these bugs basically come down to unexpected implicit broadcasting behavior of priors: you think your prior should broadcast over events but it silently broadcasts over batches, or vice versa. So this PR makes prior broadcasting universal and explicit, and throws if it can't expand shapes. Without something like to_event (which I think we'd need pyro for) or a restriction against broadcasting priors over event dimensions, we (still) have a mismatch between gpytorch batch/event shapes and the underlying torch distributions shapes, but I think the errors are correct (i.e. they are thrown if "event" and "batch" dimensions don't match, even when they aren't encoded as such attributes).
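Roughly, the check happens when a prior is registered against a hyperparameter. A rough sketch of the idea (not the exact code in the diff; the helper name and the use of torch.broadcast_shapes here are just one way to express it):

# sketch of the explicit shape check at prior registration time
import torch

def _check_prior_shape(hyperparameter_shape, prior_shape):
    # The prior must expand to the hyperparameter's shape; a broadcast that
    # would change the hyperparameter's shape (or fail outright) is an error.
    try:
        broadcasted = torch.broadcast_shapes(prior_shape, hyperparameter_shape)
    except RuntimeError as e:
        raise ValueError(
            f"Prior shape {tuple(prior_shape)} is incompatible with "
            f"hyperparameter shape {tuple(hyperparameter_shape)}"
        ) from e
    if broadcasted != hyperparameter_shape:
        raise ValueError(
            f"Prior shape {tuple(prior_shape)} cannot be expanded to "
            f"hyperparameter shape {tuple(hyperparameter_shape)}"
        )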

I'm not sure this is the correct semantics. For example, you need to call prior.sample() without args to get a full gpytorch batch of priors, and if you call prior.sample(torch.Size([3])) you get 3 batches of batched priors.

Below are some examples of what happens with these changes, based on @gpleiss' taxonomy of parameters that take priors. If they make sense, they can become tests in this PR.

# setup
import torch
import gpytorch
from gpytorch.priors import NormalPrior, GammaPrior
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods.noise_models import HomoskedasticNoise
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean

Case 1: Noise prior

[*batch_shape, 1]

# no broadcast
noise = HomoskedasticNoise(noise_prior=GammaPrior(concentration=2., rate=2.))
new_noise = noise.noise_prior.sample()
assert noise.noise.shape == new_noise.shape

# batched broadcast
noise = HomoskedasticNoise(noise_prior=GammaPrior(concentration=2., rate=2.), batch_shape=torch.Size([3,5]))
new_noise = noise.noise_prior.sample()
assert noise.noise.shape == new_noise.shape

# fails (can't broadcast)
noise = HomoskedasticNoise(noise_prior=GammaPrior(concentration=2., rate=torch.ones(2)))

# also fails, because mixing event and batch dims
noise = HomoskedasticNoise(noise_prior=GammaPrior(concentration=2., rate=torch.ones(2)), batch_shape=torch.Size([2]))

# succeeds with batched priors
noise = HomoskedasticNoise(noise_prior=GammaPrior(concentration=2., rate=torch.ones(2, 1)), batch_shape=torch.Size([2]))
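
And to illustrate the sample() semantics mentioned at the top, using the batched module from above (expected shapes per the description; a sketch rather than a test):

# sampling from a batched prior: no args gives a full gpytorch batch,
# a sample_shape argument prepends extra batches of batched priors
noise = HomoskedasticNoise(noise_prior=GammaPrior(concentration=2., rate=2.), batch_shape=torch.Size([3,5]))
noise.noise_prior.sample().shape                 # torch.Size([3, 5, 1])
noise.noise_prior.sample(torch.Size([3])).shape  # torch.Size([3, 3, 5, 1])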

And here's @dme65's example:

# throws, noise is [*batch_shape, 1] and we give a [*batch_shape] prior which cannot broadcast
likelihood = GaussianLikelihood(
    noise_prior=gpytorch.priors.NormalPrior(loc=torch.zeros(2), scale=torch.ones(2)),
    batch_shape=torch.Size([2]),
)

# after this PR, the correct way of doing this would be:
likelihood = GaussianLikelihood(
    noise_prior=gpytorch.priors.NormalPrior(loc=torch.zeros(2, 1), scale=torch.ones(2, 1)),
    batch_shape=torch.Size([2]),
)

Case 2: Means

[*batch_shape, 1], this is basically the same as noise.

# unbatched, works
mean = ConstantMean(prior=NormalPrior(loc=10, scale=3))
new_mean = mean.mean_prior.sample()
assert mean.constant.shape == new_mean.shape

# fails (can't broadcast)
mean = ConstantMean(prior=NormalPrior(loc=10, scale=torch.ones(2)))

# also fails, batch shape isn't event shape
mean = ConstantMean(prior=NormalPrior(loc=10, scale=torch.ones(2)), batch_shape=torch.Size([2]))

# succeeds
mean = ConstantMean(prior=NormalPrior(loc=10, scale=torch.ones(2, 1)), batch_shape=torch.Size([2]))

Case 3: Outputscales

[*batch_shape]

# unbatched, no broadcast
kernel = ScaleKernel(RBFKernel(), outputscale_prior=NormalPrior(loc=10, scale=1.3))
new_outputscale = kernel.outputscale_prior.sample()
assert kernel.outputscale.shape == new_outputscale.shape

# batched
kernel = ScaleKernel(RBFKernel(), outputscale_prior=NormalPrior(loc=10, scale=1.3), batch_shape=torch.Size([3,4]))
new_outputscale = kernel.outputscale_prior.sample()
assert kernel.outputscale.shape == new_outputscale.shape

# ScaleKernel has no ARD, error
kernel = ScaleKernel(RBFKernel(), outputscale_prior=NormalPrior(loc=10, scale=torch.ones(2)))

# this works, though I'm not sure it should -- since there's no event dim, the last prior dim is actually a batch dim here
kernel = ScaleKernel(RBFKernel(), outputscale_prior=NormalPrior(loc=10, scale=torch.ones(2)), batch_shape=torch.Size([2]))

Case 4: Lengthscales

[*batch_shape, 1, d]

# no broadcasting
kernel = RBFKernel(ard_num_dims=2, lengthscale_prior=GammaPrior(3.0, torch.ones(2)))
new_lengthscale = kernel.lengthscale_prior.sample()
assert kernel.lengthscale.shape == new_lengthscale.shape

# broadcast event
kernel = RBFKernel(ard_num_dims=2, lengthscale_prior=NormalPrior(loc=10, scale=1))
new_lengthscale = kernel.lengthscale_prior.sample()
assert kernel.lengthscale.shape == new_lengthscale.shape

# broadcast batch
kernel = RBFKernel(ard_num_dims=1, lengthscale_prior=NormalPrior(loc=10, scale=1.3), batch_shape=torch.Size([3,5]))
new_lengthscale = kernel.lengthscale_prior.sample()
assert kernel.lengthscale.shape == new_lengthscale.shape

# broadcast batch + event -- this is possibly the most semantically strange one
kernel = RBFKernel(ard_num_dims=2, lengthscale_prior=NormalPrior(loc=10, scale=1.3), batch_shape=torch.Size([3,5]))
new_lengthscale = kernel.lengthscale_prior.sample()
assert kernel.lengthscale.shape == new_lengthscale.shape

# ARD, no way to broadcast, error
kernel = RBFKernel(ard_num_dims=1, lengthscale_prior=NormalPrior(loc=10, scale=torch.ones(2)))

# batch is not event, errors
kernel = RBFKernel(ard_num_dims=1, lengthscale_prior=NormalPrior(loc=10, scale=torch.ones(2)), batch_shape=torch.Size([2]))
kernel = RBFKernel(ard_num_dims=1, lengthscale_prior=NormalPrior(loc=10, scale=torch.ones(2, 1)), batch_shape=torch.Size([2]))

# succeeds, batched prior
kernel = RBFKernel(ard_num_dims=2, lengthscale_prior=NormalPrior(loc=10, scale=torch.ones(2, 1, 1)), batch_shape=torch.Size([2]))
new_lengthscale = kernel.lengthscale_prior.sample()
assert kernel.lengthscale.shape == new_lengthscale.shape

@mshvartsman
Contributor Author

Still failing: SmoothedBoxPrior has an event_shape of 1 whereas all pytorch-imported priors have no event_shape, so if we are being pedantic about broadcasting you shouldn't be able to use a SmoothedBoxPrior as a prior for outputscale (which also has no event_shape); a small snippet below illustrates the mismatch. Some options:

  1. Modify SmoothedBoxPrior to have the same semantics as the pytorch scalar priors (no event_shape). I haven't checked yet whether we'll have the same issue for other gpytorch priors. It's weird because it basically means converting to the incorrect behavior.
  2. Modify all the pytorch-imported priors to have pyro-style event_shapes, redo the broadcasting code to handle batch vs event correctly.
  3. Modify outputscales to have an event dimension.

I'm slightly leaning towards 3, but all of these seem like relatively big changes.
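
For reference, a quick snippet illustrating the mismatch (the event_shape values in the comments come from the description above rather than being re-verified here):

# gpytorch-defined prior vs. pytorch-imported prior
from gpytorch.priors import SmoothedBoxPrior, NormalPrior

SmoothedBoxPrior(0.0, 1.0).event_shape  # torch.Size([1]) per the above
NormalPrior(0.0, 1.0).event_shape       # torch.Size([]) -- no event_shape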

Separately: I added a guard in case prior is getting passed as None, which sometimes happens. Is this intentional?

I'm also wondering if we should instead make a smaller change that just warns when you're doing something that may lead to unexpected prior broadcasting behavior, because this PR feels like it's turning into a larger set of prior changes.

@Balandat
Collaborator

I'm running into a ton of test failures when rebasing on master. Did this largely work in the past (except for the SmoothedBoxPrior)? If so, there may have been some other changes since then that caused issues with this.

> Separately: I added a guard in case prior is getting passed as None, which sometimes happens. Is this intentional?

I don't think so?

@@ -262,6 +262,20 @@ def setting_closure(module, val):
            )
        closure = param_or_closure

        if prior is not None:
            hyperparameter_shape = closure(self).shape
            prior_shape = prior.shape()
Collaborator

Did this work at any point? Is this meant to be event_shape?

Alternatively, we could give Priors a shape property that is just self._batch_shape + self._event_shape.
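
Something like this, roughly (a sketch of the suggestion only, not the existing gpytorch Prior class; _batch_shape and _event_shape come from torch.distributions.Distribution):

# rough sketch of the suggested shape property on the Prior base class
from torch.distributions import Distribution

class Prior(Distribution):
    @property
    def shape(self):
        # combined batch + event shape, so register_prior can compare it
        # directly against the hyperparameter's shape
        return self._batch_shape + self._event_shape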

@mshvartsman
Contributor Author

Yeah, I thought it worked except the box prior stuff, so let me rebase and revisit.
