Add sample method for LKJ correlation prior #1737
Conversation
(I haven't checked the sampling algorithm.)
Yeah, I'm unable to find a good reference for the sampling algorithm; still need to find one.
Co-authored-by: Max Balandat <[email protected]>
Ended up deciding to check marginal means and variances using the formulas given at https://distribution-explorer.github.io/multivariate_continuous/lkj.html. The tolerances are quite loose, especially for the variance, but look decent as you increase the number of samples.
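As a rough illustration (not the PR's actual test), that marginal check can be sketched with `torch.distributions.LKJCholesky` directly: an off-diagonal entry of a d x d LKJ(eta) correlation matrix is marginally a Beta(a, a) distribution rescaled to [-1, 1] with a = eta + (d - 2)/2, so its mean is 0 and its variance is 1 / (2 * eta + d - 1).

```python
import torch

torch.manual_seed(0)
d, eta, n = 3, 2.0, 20000

# Sample Cholesky factors of correlation matrices, then rebuild the matrices.
L = torch.distributions.LKJCholesky(d, concentration=eta).sample(torch.Size([n]))
C = L @ L.transpose(-1, -2)

# One off-diagonal marginal: expected mean 0, variance 1 / (2 * eta + d - 1).
r = C[:, 0, 1]
expected_var = 1.0 / (2 * eta + d - 1)
print(r.mean().item(), r.var().item(), expected_var)
```

With 20000 samples the empirical mean and variance land close to the theoretical values, which matches the "loose tolerances, decent with more samples" observation above.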
Unit test failure looks unrelated?
@@ -107,7 +162,10 @@ def __init__(self, n, eta, sd_prior, validate_args=False):
        raise ValueError("sd_prior must be an instance of Prior")
    if not isinstance(n, int):
        raise ValueError("n must be an integer")
    if sd_prior.event_shape not in {torch.Size([1]), torch.Size([n])}:
        # bug-fix event shapes if necessary
No, this is actually a hack around a difference in event shapes between exponential family distributions in torch.distributions and the event shape in SmoothedBoxPrior. I think it's a bug in torch.distributions.
Here's the distinction:

```python
from gpytorch.priors import GammaPrior, SmoothedBoxPrior

prior1 = SmoothedBoxPrior(0.1, 1.0)
prior1.event_shape  # torch.Size([1])

prior2 = GammaPrior(3.0, 5.0)
prior2.event_shape  # torch.Size([])
```
I see. Yeah, so I believe all of torch.distributions will have an empty event shape if initialized with scalars. So I'm wondering if we should change SmoothedBoxPrior to do the same.
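For reference, this is easy to check directly: torch.distributions distributions constructed from Python scalars report empty batch and event shapes.

```python
import torch

# Scalar-initialized torch distributions have empty batch and event shapes.
g = torch.distributions.Gamma(3.0, 5.0)
n = torch.distributions.Normal(0.0, 1.0)

print(g.batch_shape, g.event_shape)  # torch.Size([]) torch.Size([])
print(n.batch_shape, n.event_shape)  # torch.Size([]) torch.Size([])
```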
Co-authored-by: Max Balandat <[email protected]>
gpytorch/priors/lkj_prior.py
Outdated
def rsample(self, sample_shape=torch.Size()):
    # mocking the sample call here
    # TODO: determine why torch.distributions.LKJCholesky only implements sample
    return super().sample(sample_shape=sample_shape)
It seems like the proper thing to do here would be to overwrite the sample method and then raise a NotImplementedError for rsample. Or is there anywhere in the code where we're calling rsample where we really would be OK with calling sample?
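A minimal sketch of that suggestion, subclassing torch.distributions.LKJCholesky directly rather than the gpytorch prior (the class name here is hypothetical):

```python
import torch
from torch.distributions import LKJCholesky


class LKJSampleOnly(LKJCholesky):
    """Sketch: expose only `sample`; `rsample` fails loudly because the
    LKJ sampler is not reparameterized."""

    def rsample(self, sample_shape=torch.Size()):
        raise NotImplementedError("LKJ sampling is not reparameterized; use sample().")


p = LKJSampleOnly(3, concentration=1.0)
print(p.sample(torch.Size([2])).shape)  # torch.Size([2, 3, 3])
```

This keeps accidental rsample calls from silently producing non-differentiable samples.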
This was really a placeholder for checking into why pytorch/pytorch#48798 uses sample instead of rsample (and also the name of the PR in the first place :) ).
Switched all methods over to sample from rsample because I still haven't figured out why torch.distributions.LKJCholesky only implements sample.
Thanks, I'd lost track of that issue; I can edit the code to note switching to rsample if that goes in.
Summary: Currently the training data caused the MTGP test to have only one task, which will cause errors due to the check for >1 tasks introduced in LKJPrior upstream by cornellius-gp/gpytorch#1737

Reviewed By: j-wilson
Differential Revision: D34083515
fbshipit-source-id: 75cb2f95db4c5a2002764507370b710734ecbef9
Adds support for sampling from the LKJ correlation prior. I wasn't sure whether the sampled distribution should return correlation or covariance matrices, so I made it return correlation matrices by default.
Example:
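(The inline example did not survive the page capture. As a stand-in, here is a sketch using the underlying torch.distributions.LKJCholesky sampler that the prior builds on; the gpytorch-level API may differ.)

```python
import torch

# Draw Cholesky factors of 4 x 4 correlation matrices, then form the
# correlation matrices themselves (the PR's default return type).
dist = torch.distributions.LKJCholesky(4, concentration=0.5)
chol = dist.sample(torch.Size([8]))   # shape [8, 4, 4], lower triangular
corr = chol @ chol.transpose(-1, -2)  # valid correlation matrices

print(corr.shape)  # torch.Size([8, 4, 4])
print(torch.allclose(corr.diagonal(dim1=-2, dim2=-1), torch.ones(8, 4), atol=1e-5))
```

Each sampled matrix is symmetric with unit diagonal, i.e. a correlation (not covariance) matrix.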
Resolves #1690 and supports multiple batch dimensions.
TODOs: