
Add sample method for LKJ correlation prior #1737

Merged
merged 28 commits into from
Feb 8, 2022

Conversation

@wjmaddox (Collaborator) commented Aug 25, 2021

Adds support for sampling from the LKJ correlation prior. Wasn't sure if the sampled distribution should return correlation or covariance matrices, so I made it return correlation matrices as a default.

Example:

import torch
from gpytorch.priors import LKJPrior

prior = LKJPrior(50, 0.3)

samples = prior.rsample(torch.Size((256,)))
samples.shape  # 256 x 50 x 50

samples.diagonal(dim2=-2, dim1=-1)  # all ones

Resolves #1690 and supports multiple batch dimensions.
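The multiple-batch-dimension behavior can be illustrated with a short sketch (this uses torch.distributions.LKJCholesky directly rather than gpytorch's prior class, so it is an approximation of what the PR supports, not the PR's code): leading batch dimensions in the sample shape pass straight through, giving a batch of correlation matrices.

```python
import torch

# Illustrative sketch: sample a 4 x 16 batch of 10 x 10 Cholesky factors
# and convert them to correlation matrices via L @ L^T.
dist = torch.distributions.LKJCholesky(10, concentration=0.5)
chol = dist.sample(torch.Size((4, 16)))   # shape 4 x 16 x 10 x 10
corr = chol @ chol.transpose(-1, -2)      # correlation matrices, unit diagonal
print(corr.shape)                         # torch.Size([4, 16, 10, 10])
```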

TODOs:

  • verification that this is the correct sampling algorithm
  • unit tests

Review comments on gpytorch/priors/lkj_prior.py (outdated, resolved)
@Balandat (Collaborator) commented Sep 1, 2021

(I haven't checked the sampling algorithm)

@wjmaddox (Collaborator, Author) commented Sep 2, 2021

Yeah, I haven't been able to find a good reference for the sampling algorithm; I still need to track one down.

@wjmaddox (Collaborator, Author) commented

Ended up deciding to check marginal means and variances using the formulas given at https://distribution-explorer.github.io/multivariate_continuous/lkj.html. The tolerances are quite loose, especially for the variance, but the estimates look decent as the number of samples increases.

@wjmaddox wjmaddox requested a review from Balandat November 23, 2021 18:48
@wjmaddox wjmaddox changed the title [Draft] Add rsample method for LKJ correlation prior Add rsample method for LKJ correlation prior Nov 23, 2021
@wjmaddox (Collaborator, Author) commented

Unit test failure looks unrelated?

Review comments on gpytorch/priors/lkj_prior.py (outdated, resolved)
@@ -107,7 +162,10 @@ def __init__(self, n, eta, sd_prior, validate_args=False):
raise ValueError("sd_prior must be an instance of Prior")
if not isinstance(n, int):
raise ValueError("n must be an integer")
if sd_prior.event_shape not in {torch.Size([1]), torch.Size([n])}:
# bug-fix event shapes if necessary
Collaborator:

Is this basically a hack around #1317/#1318?

Collaborator (Author):

No, this is actually a hack around a difference in event shapes between exponential family distributions in torch.distributions and the event shape in SmoothedBoxPrior. I think it's a bug in torch.distributions.

Collaborator (Author):

Here's the distinction:

from gpytorch.priors import GammaPrior, SmoothedBoxPrior
prior1 = SmoothedBoxPrior(0.1, 1.0)
prior1.event_shape # torch.Size([1])

prior2 = GammaPrior(3.0, 5.0)
prior2.event_shape # torch.Size([])

Collaborator:

I see. Yeah so I believe all of torch.distributions will have an empty event shape if initialized with scalars. So I'm wondering if we should change SmoothedBoxPrior to do the same.

Review comment on gpytorch/priors/lkj_prior.py (outdated, resolved)
Comment on lines 38 to 41
def rsample(self, sample_shape=torch.Size()):
# mocking the sample call here
# TODO: determine why torch.distributions.LKJCholesky only implements sample
return super().sample(sample_shape=sample_shape)
Collaborator:

It seems like the proper thing to do here would be to overwrite the sample method and then raise a NotImplementedError for rsample. Or is there anywhere in the code where we're calling rsample where we really would be ok with calling sample?
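This suggestion could be sketched roughly as follows (the class name LKJCorrelationSampler is illustrative and this is not the PR's final code; it assumes torch.distributions.LKJCholesky, which implements sample but has no reparameterized sampler):

```python
import torch
from torch.distributions import LKJCholesky

# Hedged sketch of the reviewer's suggestion: expose `sample` (returning
# correlation matrices) and make `rsample` fail loudly instead of silently
# falling back to non-reparameterized sampling.
class LKJCorrelationSampler(LKJCholesky):
    def sample(self, sample_shape=torch.Size()):
        chol = super().sample(sample_shape=sample_shape)
        return chol @ chol.transpose(-1, -2)  # L @ L^T is a correlation matrix

    def rsample(self, sample_shape=torch.Size()):
        raise NotImplementedError("LKJ prior does not support reparameterized sampling")
```

Raising NotImplementedError makes any downstream code that silently relies on rsample fail immediately rather than produce non-differentiable samples.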

Collaborator (Author):

This was really a placeholder for checking into why pytorch/pytorch#48798 uses sample instead of rsample (and also the name of the PR in the first place :) ).

Review comment on test/priors/test_lkj_prior.py (outdated, resolved)
@wjmaddox wjmaddox changed the title Add rsample method for LKJ correlation prior Add sample method for LKJ correlation prior Feb 3, 2022
@wjmaddox (Collaborator, Author) commented Feb 3, 2022

Switched all methods from rsample over to sample because I still haven't figured out why torch.distributions.LKJCholesky only implements sample.

@Balandat (Collaborator) left a comment:

> Switched all methods from rsample over to sample because I still haven't figured out why torch.distributions.LKJCholesky only implements sample.

See pytorch/pytorch#69281

Review comment on test/priors/test_lkj_prior.py (outdated, resolved)
@wjmaddox (Collaborator, Author) commented Feb 5, 2022

Thanks, I'd lost track of that issue. I can update the code to switch back to rsample if that lands.

@Balandat Balandat merged commit 4df3048 into cornellius-gp:master Feb 8, 2022
facebook-github-bot pushed a commit to facebook/Ax that referenced this pull request Feb 9, 2022
Summary: Currently the training data caused the MTGP test to have only one task, which triggers errors from the >1-task check introduced upstream in LKJPrior by cornellius-gp/gpytorch#1737.

Reviewed By: j-wilson

Differential Revision: D34083515

fbshipit-source-id: 75cb2f95db4c5a2002764507370b710734ecbef9
Development

Successfully merging this pull request may close these issues.

Trying to sample from the LKJCovariancePrior