diff --git a/README.md b/README.md index a78c439fa..4dced1df7 100644 --- a/README.md +++ b/README.md @@ -196,10 +196,12 @@ As discussed above, model [reparameterization](https://num.pyro.ai/en/latest/rep - [HMCGibbs](https://num.pyro.ai/en/latest/mcmc.html#hmcgibbs) combines HMC/NUTS steps with custom Gibbs updates. Gibbs updates must be specified by the user. - [DiscreteHMCGibbs](https://num.pyro.ai/en/latest/mcmc.html#discretehmcgibbs) combines HMC/NUTS steps with Gibbs updates for discrete latent variables. The corresponding Gibbs updates are computed automatically. - [SA](https://num.pyro.ai/en/latest/mcmc.html#sa) is the only MCMC method in NumPyro that does not leverage gradients. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities. Note that SA generally requires a *very* large number of samples, as mixing tends to be slow. On the plus side individual steps can be fast. -- [NestedSampler](https://num.pyro.ai/en/latest/contrib.html#nested-sampling) offers a wrapper for [jaxns](https://github.com/Joshuaalbert/jaxns). See [here](https://github.com/pyro-ppl/numpyro/blob/master/examples/gaussian_shells.py) for an example. Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete latent variables if possible (see [restrictions](https://pyro.ai/examples/enumeration.html#Restriction-1:-conditional-independence)). Enumerated sites need to be marked with `infer={'enumerate': 'parallel'}` like in the [annotation example](https://num.pyro.ai/en/stable/examples/annotation.html). +### Nested Sampling +- [NestedSampler](https://num.pyro.ai/en/latest/contrib.html#nested-sampling) offers a wrapper for [jaxns](https://github.com/Joshuaalbert/jaxns). 
See the [JAXNS documentation](https://jaxns.readthedocs.io/en/latest/) for examples, and the [Nested Sampling for Gaussian Shells](https://num.pyro.ai/en/stable/examples/gaussian_shells.html) example for how to apply the sampler to NumPyro models. The sampler can handle arbitrary models, including ones with discrete random variables and non-invertible transformations. + ### Stochastic variational inference - Variational objectives - [Trace_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.Trace_ELBO) is our basic ELBO implementation. diff --git a/docs/requirements.txt b/docs/requirements.txt index 8d7ca8279..36b2fb48b 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,7 +4,7 @@ funsor ipython<=8.6.0 # strict the version for https://github.com/ipython/ipython/issues/13845 jax jaxlib -jaxns==1.0.0 +jaxns>=2.0.1 Jinja2<3.1 matplotlib multipledispatch diff --git a/docs/source/mcmc.rst b/docs/source/mcmc.rst index 9d6d1b14a..09f53f159 100644 --- a/docs/source/mcmc.rst +++ b/docs/source/mcmc.rst @@ -10,7 +10,6 @@ We provide a high-level overview of the MCMC algorithms in NumPyro: * `HMCGibbs `_ combines HMC/NUTS steps with custom Gibbs updates. Gibbs updates must be specified by the user. * `DiscreteHMCGibbs `_ combines HMC/NUTS steps with Gibbs updates for discrete latent variables. The corresponding Gibbs updates are computed automatically. * `SA `_ is the only MCMC method in NumPyro that does not leverage gradients. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities. Note that SA generally requires a *very* large number of samples, as mixing tends to be slow. On the plus side individual steps can be fast. -* `NestedSampler `_ offers a wrapper for `jaxns `_. See `here `_ for an example. 
Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete latent variables if possible (see `restrictions `_). Enumerated sites need to be marked with `infer={'enumerate': 'parallel'}` like in the `annotation example `_. @@ -20,6 +19,13 @@ Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete l :show-inheritance: :member-order: bysource +Nested Sampling +=============================== + +Nested Sampling is a non-MCMC approach that works for arbitrary probability models, and is particularly well suited to complex posteriors: + +* `NestedSampler `_ offers a wrapper for `jaxns `_. See `JAXNS's readthedocs `_ for examples and `Nested Sampling for Gaussian Shells `_ example for how to apply the sampler on numpyro models. Can handle arbitrary models, including ones with discrete RVs, and non-invertible transformations. + MCMC Kernels ------------ diff --git a/numpyro/contrib/nested_sampling.py b/numpyro/contrib/nested_sampling.py index cc9065b5d..ca7005d65 100644 --- a/numpyro/contrib/nested_sampling.py +++ b/numpyro/contrib/nested_sampling.py @@ -3,24 +3,30 @@ from functools import singledispatch -from jax import nn, random, tree_util +from jax import random import jax.numpy as jnp try: from jaxns import ( - NestedSampler as OrigNestedSampler, + ExactNestedSampler as OrigNestedSampler, + Model, + NestedSamplerResults, + Prior, + PriorModelGen, + TerminationCondition, plot_cornerplot, plot_diagnostics, + resample, summary, ) - from jaxns.prior_transforms import ContinuousPrior, PriorChain - from jaxns.prior_transforms.prior import UniformBase except ImportError as e: raise ImportError( "To use this module, please install `jaxns` package. 
It can be" - " installed with `pip install jaxns`" + " installed with `pip install jaxns` with python>=3.8" ) from e +import tensorflow_probability.substrates.jax as tfp + import numpyro import numpyro.distributions as dist from numpyro.handlers import reparam, seed, trace @@ -30,14 +36,7 @@ __all__ = ["NestedSampler"] - -class UniformPrior(ContinuousPrior): - def __init__(self, name, shape): - prior_base = UniformBase(shape, jnp.result_type(float)) - super().__init__(name, shape, parents=[], tracked=True, prior_base=prior_base) - - def transform_U(self, U, **kwargs): - return U +tfpd = tfp.distributions @singledispatch @@ -118,8 +117,6 @@ def __call__(self, name, fn, obs): return None, transform(x) -# TODO: Consider deprecating this wrapper. It might be better to only provide some -# utilities to help converting a NumPyro model to a Jaxns loglikelihood function. class NestedSampler: """ (EXPERIMENTAL) A wrapper for `jaxns `_ , @@ -189,7 +186,7 @@ def __init__( ) self._samples = None self._log_weights = None - self._results = None + self._results: NestedSamplerResults | None = None def run(self, rng_key, *args, **kwargs): """ @@ -246,24 +243,58 @@ def run(self, rng_key, *args, **kwargs): loglik_fn = local_dict["loglik_fn"] # use NestedSampler with identity prior chain - prior_chain = PriorChain() - for name in param_names: - prior = UniformPrior(name + "_base", prototype_trace[name]["fn"].shape()) - prior_chain.push(prior) - # XXX: the `marginalised` keyword in jaxns can be used to get expectation of some - # quantity over posterior samples; it can be helpful to expose it in this wrapper - ns = OrigNestedSampler( - loglik_fn, - prior_chain, + def prior_model() -> PriorModelGen: + params = [] + for name in param_names: + shape = prototype_trace[name]["fn"].shape() + param = yield Prior( + tfpd.Uniform(low=jnp.zeros(shape), high=jnp.ones(shape)), + name=name + "_base", + ) + params.append(param) + return tuple(params) + + model = Model(prior_model=prior_model, 
log_likelihood=loglik_fn) + + default_constructor_kwargs = dict( + num_live_points=model.U_ndims * 25, + num_parallel_samplers=1, + max_samples=1e4, + uncert_improvement_patience=2, + ) + default_termination_kwargs = dict(live_evidence_frac=1e-4) + # Fill-in missing values with defaults. This allows user to inspect what was actually used by inspecting + # these dictionaries + list( + map( + lambda item: self.constructor_kwargs.setdefault(*item), + default_constructor_kwargs.items(), + ) + ) + list( + map( + lambda item: self.termination_kwargs.setdefault(*item), + default_termination_kwargs.items(), + ) + ) + + exact_ns = OrigNestedSampler( + model=model, **self.constructor_kwargs, ) - results = ns(rng_sampling, **self.termination_kwargs) + + termination_reason, state = exact_ns( + rng_sampling, + term_cond=TerminationCondition(**self.termination_kwargs), + ) + results = exact_ns.to_results(state, termination_reason) + # transform base samples back to original domains # Here we only transform the first valid num_samples samples # NB: the number of weighted samples obtained from jaxns is results.num_samples # and only the first num_samples values of results.samples are valid. num_samples = results.total_num_samples - samples = tree_util.tree_map(lambda x: x[:num_samples], results.samples) + samples = results.samples predictive = Predictive( reparam_model, samples, return_sites=param_names + deterministics ) @@ -283,11 +314,10 @@ def get_samples(self, rng_key, num_samples): raise RuntimeError( "NestedSampler.run(...) method should be called first to obtain results." 
) - - samples, log_weights = self.get_weighted_samples() - p = nn.softmax(log_weights) - idx = random.choice(rng_key, log_weights.shape[0], (num_samples,), p=p) - return {k: v[idx] for k, v in samples.items()} + weighted_samples, sample_weights = self.get_weighted_samples() + return resample( + rng_key, weighted_samples, sample_weights, S=num_samples, replace=True + ) def get_weighted_samples(self): """ @@ -298,8 +328,7 @@ def get_weighted_samples(self): "NestedSampler.run(...) method should be called first to obtain results." ) - num_samples = self._results.total_num_samples - return self._results.samples, self._results.log_dp_mean[:num_samples] + return self._results.samples, self._results.log_dp_mean def print_summary(self): """ diff --git a/setup.py b/setup.py index 8532e4e83..bba87d335 100644 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ "flax", "funsor>=0.4.1", "graphviz", - "jaxns==1.0.0", + "jaxns>=2.0.1", "matplotlib", "optax>=0.0.6", "pylab-sdk", # jaxns dependency diff --git a/test/contrib/test_nested_sampling.py b/test/contrib/test_nested_sampling.py index 17cfde6bf..787cf46cb 100644 --- a/test/contrib/test_nested_sampling.py +++ b/test/contrib/test_nested_sampling.py @@ -9,7 +9,11 @@ import jax.numpy as jnp import numpyro -from numpyro.contrib.nested_sampling import NestedSampler, UniformReparam + +try: + from numpyro.contrib.nested_sampling import NestedSampler, UniformReparam +except ImportError: + pytestmark = pytest.mark.skip(reason="jaxns is not installed") import numpyro.distributions as dist from numpyro.distributions.transforms import AffineTransform, ExpTransform