-
Notifications
You must be signed in to change notification settings - Fork 246
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* rewrote stein mixture and added mixture guide predictive. * added test examples * added mixture guide predictive test. * fixed sample ordering in `mixture_guide_predictive.py` so model samples are returned when sites are in guide and model.
- Loading branch information
1 parent
523162f
commit 0b9e0f0
Showing
15 changed files
with
609 additions
and
334 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,26 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from numpyro.contrib.einstein.kernels import ( | ||
from numpyro.contrib.einstein.stein_kernels import ( | ||
GraphicalKernel, | ||
HessianPrecondMatrix, | ||
IMQKernel, | ||
LinearKernel, | ||
PrecondMatrix, | ||
PrecondMatrixKernel, | ||
MixtureKernel, | ||
ProbabilityProductKernel, | ||
RandomFeatureKernel, | ||
RBFKernel, | ||
) | ||
from numpyro.contrib.einstein.stein_loss import SteinLoss | ||
from numpyro.contrib.einstein.steinvi import SteinVI | ||
|
||
__all__ = [ | ||
"SteinVI", | ||
"SteinLoss", | ||
"RBFKernel", | ||
"PrecondMatrix", | ||
"IMQKernel", | ||
"LinearKernel", | ||
"RandomFeatureKernel", | ||
"HessianPrecondMatrix", | ||
"GraphicalKernel", | ||
"PrecondMatrixKernel", | ||
"MixtureKernel", | ||
"ProbabilityProductKernel", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from functools import partial | ||
from typing import Callable, Dict, Optional, Sequence | ||
|
||
from jax import numpy as jnp, random, tree_map, vmap | ||
from jax.tree_util import tree_flatten | ||
|
||
from numpyro.handlers import substitute | ||
from numpyro.infer import Predictive | ||
from numpyro.infer.util import _predictive | ||
|
||
|
||
class MixtureGuidePredictive: | ||
""" | ||
For single mixture component use numpyro.infer.Predictive. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
model: Callable, | ||
guide: Callable, | ||
params: Dict, | ||
guide_sites: Sequence, | ||
num_samples: Optional[int] = None, | ||
return_sites: Optional[Sequence[str]] = None, | ||
infer_discrete: bool = False, | ||
parallel: bool = False, | ||
mixture_assignment_sitename="mixture_assignments", | ||
): | ||
self.model_predictive = partial( | ||
Predictive, | ||
model=model, | ||
params={ | ||
name: param for name, param in params.items() if name not in guide_sites | ||
}, | ||
num_samples=num_samples, | ||
return_sites=return_sites, | ||
infer_discrete=infer_discrete, | ||
parallel=parallel, | ||
) | ||
self._batch_shape = (num_samples,) | ||
self.parallel = parallel | ||
self.guide_params = { | ||
name: param for name, param in params.items() if name in guide_sites | ||
} | ||
|
||
self.guide = guide | ||
self.return_sites = return_sites | ||
self.num_mixture_components = jnp.shape(tree_flatten(params)[0][0])[0] | ||
self.mixture_assignment_sitename = mixture_assignment_sitename | ||
|
||
def _call_with_params(self, rng_key, params, args, kwargs): | ||
rng_key, guide_rng_key = random.split(rng_key) | ||
# use return_sites='' as a special signal to return all sites | ||
guide = substitute(self.guide, params) | ||
samples = _predictive( | ||
guide_rng_key, | ||
guide, | ||
{}, | ||
self._batch_shape, | ||
return_sites="", | ||
parallel=self.parallel, | ||
model_args=args, | ||
model_kwargs=kwargs, | ||
) | ||
return samples | ||
|
||
def __call__(self, rng_key, *args, **kwargs): | ||
guide_key, assign_key, model_key = random.split(rng_key, 3) | ||
|
||
samples_guide = vmap( | ||
lambda key, params: self._call_with_params( | ||
key, params=params, args=args, kwargs=kwargs | ||
), | ||
in_axes=0, | ||
out_axes=1, | ||
)(random.split(guide_key, self.num_mixture_components), self.guide_params) | ||
|
||
assigns = random.randint( | ||
assign_key, | ||
shape=self._batch_shape, | ||
minval=0, | ||
maxval=self.num_mixture_components, | ||
) | ||
predictive_assign = tree_map( | ||
lambda arr: vmap(lambda i, assign: arr[i, assign])( | ||
jnp.arange(self._batch_shape[0]), assigns | ||
), | ||
samples_guide, | ||
) | ||
predictive_model = self.model_predictive(posterior_samples=predictive_assign) | ||
samples_model = predictive_model(model_key, *args, **kwargs) | ||
if self.return_sites is not None: | ||
samples_guide = { | ||
name: value | ||
for name, value in samples_guide.items() | ||
if name in self.return_sites | ||
} | ||
else: | ||
samples_guide = {} | ||
|
||
return { | ||
self.mixture_assignment_sitename: assigns, | ||
**samples_guide, | ||
**samples_model, # use samples from model if site in model and guide | ||
} |
Oops, something went wrong.