More explicit prior shape broadcasting #1520
base: main
Conversation
Still failing:
I'm slightly leaning towards 3, but all of these seem like relatively big changes. I'm also wondering if we should do a more minor change that just warns if you're doing something that may lead to unexpected prior broadcasting behavior, because this PR feels like it's turning into a larger set of prior changes.
I'm running into a ton of test failures when rebasing on master. Did this largely work in the past (except for the `SmoothBoxPrior`)? If so, there may have been some other changes since then that caused issues with this.
Separately: I added a guard in case `prior` is getting passed as `None`, which sometimes happens. Is this intentional?
I don't think so?
```diff
@@ -262,6 +262,20 @@ def setting_closure(module, val):
             )
         closure = param_or_closure

+        if prior is not None:
+            hyperparameter_shape = closure(self).shape
+            prior_shape = prior.shape()
```
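For context, a minimal sketch (not the PR's actual code) of what an explicit broadcastability check along these lines could look like, using `torch.broadcast_shapes`:

```python
import torch

def check_prior_shape(hyperparameter_shape: torch.Size, prior_shape: torch.Size) -> None:
    # Raise rather than silently broadcasting: this is the "explicit and
    # universal" behavior the PR description argues for.
    try:
        torch.broadcast_shapes(hyperparameter_shape, prior_shape)
    except RuntimeError:
        raise ValueError(
            f"Prior shape {prior_shape} cannot be broadcast to "
            f"hyperparameter shape {hyperparameter_shape}."
        )
```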
Did this work at any point? Is this meant to be `event_shape`? Alternatively we could give `Prior`s a `shape` property that is just `self._batch_shape + self._event_shape`.
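A hypothetical sketch of that `shape` property, illustrated on a plain torch `Normal` (the class name `NormalPriorSketch` is made up for the example):

```python
import torch
from torch.distributions import Normal

class NormalPriorSketch(Normal):
    # Hypothetical `shape` property from the suggestion above: the full
    # shape of an arg-less sample, i.e. batch dims followed by event dims.
    @property
    def shape(self) -> torch.Size:
        return self._batch_shape + self._event_shape

NormalPriorSketch(torch.zeros(2), torch.ones(2)).shape  # torch.Size([2])
```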
Yeah, I thought it worked except for the box prior stuff, so let me rebase and revisit.
This is an attempt to address #1317 and #1318 -- incomplete but hopefully helps concretize discussion. I think these bugs basically have to do with unexpected implicit broadcasting behavior of priors (where you think your prior should broadcast over events but it silently broadcasts over batches, or vice versa). So this PR makes prior broadcasting universal and explicit, and throws if it can't expand shapes. Without something like `to_event` (which I think we'd need pyro for) or a restriction not to broadcast priors over event dimensions, we (still) have a mismatch between gpytorch batch/event shapes and the underlying torch distribution shapes, but I think that the errors are correct (i.e. they are thrown if "event" and "batch" dimensions don't match even if they're not encoded as such attributes).

I'm not sure this is the right behavior. For example, you need to call `prior.sample()` without args to get a full gpytorch-batch of priors, and if you call `prior.sample(torch.Size([3]))` you get 3 batches of batched priors.
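As a hedged illustration of that sampling behavior (using a simple batched `NormalPrior` as an assumed example, not code from this PR):

```python
import torch
from gpytorch.priors import NormalPrior

# A prior with a gpytorch batch of 2 (assumed example).
prior = NormalPrior(loc=torch.zeros(2), scale=torch.ones(2))

prior.sample().shape                 # torch.Size([2]): one full batch of priors
prior.sample(torch.Size([3])).shape  # torch.Size([3, 2]): 3 batches of batched priors
```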
Below are some examples of what happens with these changes, based on @gpleiss' taxonomy of parameters that take priors. If they make sense, they can become tests in this PR.
Case 1: Noise prior
`[*batch_shape, 1]`
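A hedged sketch of this case (assumed setup; not the snippet from the thread):

```python
import torch
import gpytorch

batch_shape = torch.Size([2])
likelihood = gpytorch.likelihoods.GaussianLikelihood(
    batch_shape=batch_shape,
    noise_prior=gpytorch.priors.GammaPrior(1.1, 0.05),
)
likelihood.noise.shape  # torch.Size([2, 1]), i.e. [*batch_shape, 1]
```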
And here's @dme65's example:
Case 2: means
`[*batch_shape, 1]`; this is basically the same as noise.
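A hedged sketch of the mean case (assumed setup), parallel to the noise one:

```python
import torch
import gpytorch

batch_shape = torch.Size([2])
mean = gpytorch.means.ConstantMean(
    prior=gpytorch.priors.NormalPrior(0.0, 1.0),
    batch_shape=batch_shape,
)
mean.constant.shape  # torch.Size([2, 1]), i.e. [*batch_shape, 1]
```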
Case 3: outputscales
`[*batch_shape]`
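A hedged sketch of the outputscale case (assumed setup):

```python
import torch
import gpytorch

batch_shape = torch.Size([2])
kernel = gpytorch.kernels.ScaleKernel(
    gpytorch.kernels.RBFKernel(batch_shape=batch_shape),
    batch_shape=batch_shape,
    outputscale_prior=gpytorch.priors.GammaPrior(2.0, 0.15),
)
kernel.outputscale.shape  # torch.Size([2]), i.e. [*batch_shape]
```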
Lengthscales
`[*batch_shape, 1, d]`
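A hedged sketch of the lengthscale case (assumed setup, with `d = 3` ARD dimensions):

```python
import torch
import gpytorch

batch_shape, d = torch.Size([2]), 3
kernel = gpytorch.kernels.RBFKernel(
    ard_num_dims=d,
    batch_shape=batch_shape,
    lengthscale_prior=gpytorch.priors.GammaPrior(3.0, 6.0),
)
kernel.lengthscale.shape  # torch.Size([2, 1, 3]), i.e. [*batch_shape, 1, d]
```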