[Bug] Prior losses are incorrectly added to the mll in batch-mode #1318
Comments
Potentially related to #1317.
Amazingly I also came across this bug yesterday and was about to post a bug report. A simple (hacky) fix I'm currently using is to replace:
with
in the MLL, e.g. here (a sketch of this kind of replacement follows below). I think this deals with all use-cases, but it is particularly ugly.
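A rough sketch of what such a shape-aware replacement could look like (an assumed illustration, not the commenter's actual snippet, but close in spirit to the fix that eventually landed in #2039):

```python
def add_prior_terms(res, named_priors):
    # Assumed sketch: sum each prior's log-prob only over the dimensions that
    # extend beyond res's batch shape, so the prior term is not broadcast and
    # re-counted across batch dimensions when loss.sum() is taken.
    for _, prior, closure, _ in named_priors:
        prior_term = prior.log_prob(closure())
        # collapse everything past res's batch dims into one trailing axis and sum it
        extra = prior_term.reshape(*prior_term.shape[: res.dim()], -1).sum(dim=-1)
        res = res + extra
    return res
```

Scalar priors that carry no batch dimensions would still broadcast with this, which is the remaining shape question discussed below.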
I think this is a bit complicated because of the event/batch shapes of the different parameters.
I'm not sure exactly what we should do to resolve these issues.
We need to make priors aware of the shapes we expect. We can either do this manually, or by reusing the same kind of logic we use for the pyro integration, exploiting the fact that priors have expand methods (see the sketch below): Line 445 in 4912fe9
That could happen either when registering them in the individual classes, or centrally in Module, here: Line 235 in 4912fe9
What do you think?
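A tiny illustration of the expand idea (an assumed example using `GammaPrior`, not the referenced lines from 4912fe9):

```python
import torch
from gpytorch.priors import GammaPrior

# Priors expose `expand`, so a scalar prior can be broadcast to the shape of the
# hyperparameter it governs before its log_prob enters the mll (assumed example).
prior = GammaPrior(3.0, 6.0)                 # batch_shape: torch.Size([])
lengthscale = torch.rand(2, 1, 3) + 0.5      # e.g. batch_shape (2,), ARD over 3 dims
expanded = prior.expand(lengthscale.shape)   # prior now matches the parameter's shape
print(expanded.log_prob(lengthscale).shape)  # torch.Size([2, 1, 3])
```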
Oops, I commented on #1317 without looking at updates here. What about the case where we're using a batched univariate prior for a vector-valued hyperparam (like lengthscales with ARD)? Seems like we shouldn't always expand.
This was fixed in #2039. Feel free to close the issue!
🐛 Bug
Prior losses are currently being added up incorrectly in `ExactMarginalLogLikelihood`. The line in question (see the reconstructed sketch below) will sum up all of the losses and then add them to the mll. If you are using a batch model, this sum gets added to all of the batch dimensions, which counts the losses multiple times when eventually calling `loss.sum().backward()`. It looks like the priors may not support batch mode, which leads to a large variety of different shapes, but the `.sum()` call masks this issue since it just sums everything up anyway.
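A reconstruction of the offending line (inferred from the Expected Behavior section below, not copied verbatim from the repository):

```python
# Reconstructed sketch of the prior-summation loop in ExactMarginalLogLikelihood:
# each prior's log-prob is summed over *all* of its dimensions before being
# added to the (possibly batched) mll value `res`.
for _, prior, closure, _ in self.named_priors():
    res.add_(prior.log_prob(closure()).sum())
```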
To reproduce
Code snippet (taken from `test_train_on_batch_test_on_batch`):
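The following is a minimal, self-contained sketch of the same failure mode (an illustration only, not the `test_train_on_batch_test_on_batch` code; the model and prior choices are assumed):

```python
import torch
import gpytorch

# A batched ExactGP whose RBF kernel carries a scalar GammaPrior on the lengthscale.
batch_shape = torch.Size([3])
train_x = torch.randn(*batch_shape, 10, 1)
train_y = torch.randn(*batch_shape, 10)


class BatchGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(
                batch_shape=batch_shape,
                lengthscale_prior=gpytorch.priors.GammaPrior(3.0, 6.0),
            ),
            batch_shape=batch_shape,
        )

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )


likelihood = gpytorch.likelihoods.GaussianLikelihood(batch_shape=batch_shape)
model = BatchGP(train_x, train_y, likelihood)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

model.train()
likelihood.train()
loss = -mll(model(train_x), train_y)  # shape (3,): one mll value per batch entry
# The prior log-prob is summed over its batch dimension too, and that scalar is
# broadcast onto every batch entry, so it is counted three times here:
loss.sum().backward()
```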
Output:
Expected Behavior
The prior losses should have the same size and be added up via `res.add_(prior.log_prob(closure()))`, without the inner sum call.

System information
Please complete the following information:
GPyTorch Version: 1.2.0
PyTorch Version: 1.6.0
OS: Mac
Additional context
This was originally discovered in PR #1314.
cc: @Balandat