Stein mixture #1601
Merged
4 commits
ee7a4f9 rewrote stein mixture and added mixture guide predictive. (OlaRonning)
e9d805a added test examples (OlaRonning)
947f2ec added mixture guide predictive test. (OlaRonning)
69db4f8 fixed sample ordering in `mixture_guide_predictive.py` so model sampl… (OlaRonning)
@@ -1,26 +1,26 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from numpyro.contrib.einstein.kernels import (
from numpyro.contrib.einstein.stein_kernels import (
    GraphicalKernel,
    HessianPrecondMatrix,
    IMQKernel,
    LinearKernel,
    PrecondMatrix,
    PrecondMatrixKernel,
    MixtureKernel,
    ProbabilityProductKernel,
    RandomFeatureKernel,
    RBFKernel,
)
from numpyro.contrib.einstein.stein_loss import SteinLoss
from numpyro.contrib.einstein.steinvi import SteinVI

__all__ = [
    "SteinVI",
    "SteinLoss",
    "RBFKernel",
    "PrecondMatrix",
    "IMQKernel",
    "LinearKernel",
    "RandomFeatureKernel",
    "HessianPrecondMatrix",
    "GraphicalKernel",
    "PrecondMatrixKernel",
    "MixtureKernel",
    "ProbabilityProductKernel",
]
@@ -0,0 +1,108 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from functools import partial
from typing import Callable, Dict, Optional, Sequence

from jax import numpy as jnp, random, tree_map, vmap
from jax.tree_util import tree_flatten

from numpyro.handlers import substitute
from numpyro.infer import Predictive
from numpyro.infer.util import _predictive


class MixtureGuidePredictive:
    """
    For single mixture component use numpyro.infer.Predictive.
    """

    def __init__(
        self,
        model: Callable,
        guide: Callable,
        params: Dict,
        guide_sites: Sequence,
        num_samples: Optional[int] = None,
        return_sites: Optional[Sequence[str]] = None,
        infer_discrete: bool = False,
        parallel: bool = False,
        mixture_assignment_sitename="mixture_assignments",
    ):
        self.model_predictive = partial(
            Predictive,
            model=model,
            params={
                name: param for name, param in params.items() if name not in guide_sites
            },
            num_samples=num_samples,
            return_sites=return_sites,
            infer_discrete=infer_discrete,
            parallel=parallel,
        )
        self._batch_shape = (num_samples,)
        self.parallel = parallel
        self.guide_params = {
            name: param for name, param in params.items() if name in guide_sites
        }

        self.guide = guide
        self.return_sites = return_sites
        self.num_mixture_components = jnp.shape(tree_flatten(params)[0][0])[0]
        self.mixture_assignment_sitename = mixture_assignment_sitename

    def _call_with_params(self, rng_key, params, args, kwargs):
        rng_key, guide_rng_key = random.split(rng_key)
        # use return_sites='' as a special signal to return all sites
        guide = substitute(self.guide, params)
        samples = _predictive(
            guide_rng_key,
            guide,
            {},
            self._batch_shape,
            return_sites="",
            parallel=self.parallel,
            model_args=args,
            model_kwargs=kwargs,
        )
        return samples

    def __call__(self, rng_key, *args, **kwargs):
        guide_key, assign_key, model_key = random.split(rng_key, 3)

        samples_guide = vmap(
            lambda key, params: self._call_with_params(
                key, params=params, args=args, kwargs=kwargs
            ),
            in_axes=0,
            out_axes=1,
        )(random.split(guide_key, self.num_mixture_components), self.guide_params)

        assigns = random.randint(
            assign_key,
            shape=self._batch_shape,
            minval=0,
            maxval=self.num_mixture_components,
        )
        predictive_assign = tree_map(
            lambda arr: vmap(lambda i, assign: arr[i, assign])(
                jnp.arange(self._batch_shape[0]), assigns
            ),
            samples_guide,
        )
        predictive_model = self.model_predictive(posterior_samples=predictive_assign)
        samples_model = predictive_model(model_key, *args, **kwargs)
        if self.return_sites is not None:
            samples_guide = {
                name: value
                for name, value in samples_guide.items()
                if name in self.return_sites
            }
        else:
            samples_guide = {}

        return {
            self.mixture_assignment_sitename: assigns,
            **samples_guide,
            **samples_model,  # use samples from model if site in model and guide
        }
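The selection step in `__call__` can be easier to follow on a toy array. The following is a minimal sketch, not code from the PR: the array `samples_guide` stands in for a single guide site with layout `(num_samples, num_components)`, which is the layout `out_axes=1` produces above, and its values are invented so that the chosen component can be read back from each draw.

```python
import jax.numpy as jnp
from jax import random, vmap

num_samples, num_components = 5, 3

# Stand-in for one guide site: entry [i, k] is draw i from mixture component k.
# The value i * 10 + k is illustrative only, so k can be recovered as value % 10.
samples_guide = (
    jnp.arange(num_samples)[:, None] * 10.0 + jnp.arange(num_components)[None, :]
)

# One uniformly drawn component index per draw, mirroring the random.randint call.
assigns = random.randint(
    random.PRNGKey(0), shape=(num_samples,), minval=0, maxval=num_components
)

# For draw i, keep only the value belonging to its assigned component.
selected = vmap(lambda i, k: samples_guide[i, k])(jnp.arange(num_samples), assigns)

# The same gather written with plain advanced indexing.
selected_alt = samples_guide[jnp.arange(num_samples), assigns]
```

Both forms reduce the component axis, so `selected` has shape `(num_samples,)`; this is the shape the model-side `Predictive` then receives as `posterior_samples`.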
@fehiepsi I believe we can remove the `nbatch_dims=1` from Predictive. Let me know what you think.
It is useful in some cases, e.g. to do predictive with (num_chains, num_samples) or for autoguide posterior samples with multiple sample shapes.
I feel that this is a bit complicated. I think the Predictive class is pretty flexible. Complicated logic to handle mixture stuff should belong to the guide, not the Predictive (correct me if I'm wrong).
Hmm, the core issue is that I need the mixture assignments to sample from the model. I specialized the ELBO into SteinLoss to account for this.
I suppose the guide is the shared element in both cases (loss and predictive sampling), so hiding the complications there seems more modular. I'll try to lift the guide to a mixture guide (and have it as a SteinVI attribute); the lifted guide would then work with the current Predictive and ELBO. That could also simplify some logic in SteinVI. I'll sketch it and get back.