Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bump jaxns to >=2.0.1 #1546

Merged
merged 16 commits into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,12 @@ As discussed above, model [reparameterization](https://num.pyro.ai/en/latest/rep
- [HMCGibbs](https://num.pyro.ai/en/latest/mcmc.html#hmcgibbs) combines HMC/NUTS steps with custom Gibbs updates. Gibbs updates must be specified by the user.
- [DiscreteHMCGibbs](https://num.pyro.ai/en/latest/mcmc.html#discretehmcgibbs) combines HMC/NUTS steps with Gibbs updates for discrete latent variables. The corresponding Gibbs updates are computed automatically.
- [SA](https://num.pyro.ai/en/latest/mcmc.html#sa) is the only MCMC method in NumPyro that does not leverage gradients. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities. Note that SA generally requires a *very* large number of samples, as mixing tends to be slow. On the plus side individual steps can be fast.
- [NestedSampler](https://num.pyro.ai/en/latest/contrib.html#nested-sampling) offers a wrapper for [jaxns](https://github.com/Joshuaalbert/jaxns). See [here](https://github.com/pyro-ppl/numpyro/blob/master/examples/gaussian_shells.py) for an example.

Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete latent variables if possible (see [restrictions](https://pyro.ai/examples/enumeration.html#Restriction-1:-conditional-independence)). Enumerated sites need to be marked with `infer={'enumerate': 'parallel'}` like in the [annotation example](https://num.pyro.ai/en/stable/examples/annotation.html).

### Nested Sampling
- [NestedSampler](https://num.pyro.ai/en/latest/contrib.html#nested-sampling) offers a wrapper for [jaxns](https://github.com/Joshuaalbert/jaxns). See [JAXNS's readthedocs](https://jaxns.readthedocs.io/en/latest/) for examples and [Nested Sampling for Gaussian Shells](https://num.pyro.ai/en/stable/examples/gaussian_shells.html) example for how to apply the sampler on numpyro models. Can handle arbitrary models, including ones with discrete RVs, and non-invertible transformations.

### Stochastic variational inference
- Variational objectives
- [Trace_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.Trace_ELBO) is our basic ELBO implementation.
Expand Down
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ funsor
ipython<=8.6.0 # strict the version for https://github.com/ipython/ipython/issues/13845
jax
jaxlib
jaxns==1.0.0
jaxns>=2.0.1
Jinja2<3.1
matplotlib
multipledispatch
Expand Down
8 changes: 7 additions & 1 deletion docs/source/mcmc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ We provide a high-level overview of the MCMC algorithms in NumPyro:
* `HMCGibbs <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc_gibbs.HMCGibbs>`_ combines HMC/NUTS steps with custom Gibbs updates. Gibbs updates must be specified by the user.
* `DiscreteHMCGibbs <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc_gibbs.DiscreteHMCGibbs>`_ combines HMC/NUTS steps with Gibbs updates for discrete latent variables. The corresponding Gibbs updates are computed automatically.
* `SA <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.sa.SA>`_ is the only MCMC method in NumPyro that does not leverage gradients. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities. Note that SA generally requires a *very* large number of samples, as mixing tends to be slow. On the plus side individual steps can be fast.
* `NestedSampler <https://num.pyro.ai/en/latest/contrib.html#nested-sampling>`_ offers a wrapper for `jaxns <https://github.com/Joshuaalbert/jaxns>`_. See `here <https://github.com/pyro-ppl/numpyro/blob/master/examples/gaussian_shells.py>`_ for an example.

Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete latent variables if possible (see `restrictions <https://pyro.ai/examples/enumeration.html#Restriction-1:-conditional-independence>`_). Enumerated sites need to be marked with `infer={'enumerate': 'parallel'}` like in the `annotation example <https://num.pyro.ai/en/stable/examples/annotation.html>`_.

Expand All @@ -20,6 +19,13 @@ Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete l
:show-inheritance:
:member-order: bysource

Nested Sampling
===============================

Nested Sampling is a non-MCMC approach that works for arbitrary probability models, and is particularly well suited to complex posteriors:

* `NestedSampler <https://num.pyro.ai/en/latest/contrib.html#nested-sampling>`_ offers a wrapper for `jaxns <https://github.com/Joshuaalbert/jaxns>`_. See `JAXNS's readthedocs <https://jaxns.readthedocs.io/en/latest/>`_ for examples and `Nested Sampling for Gaussian Shells <https://num.pyro.ai/en/stable/examples/gaussian_shells.html>`_ example for how to apply the sampler on numpyro models. Can handle arbitrary models, including ones with discrete RVs, and non-invertible transformations.

MCMC Kernels
------------

Expand Down
97 changes: 63 additions & 34 deletions numpyro/contrib/nested_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,30 @@

from functools import singledispatch

from jax import nn, random, tree_util
from jax import random
import jax.numpy as jnp

try:
from jaxns import (
NestedSampler as OrigNestedSampler,
ExactNestedSampler as OrigNestedSampler,
Model,
NestedSamplerResults,
Prior,
PriorModelGen,
TerminationCondition,
plot_cornerplot,
plot_diagnostics,
resample,
summary,
)
from jaxns.prior_transforms import ContinuousPrior, PriorChain
from jaxns.prior_transforms.prior import UniformBase
except ImportError as e:
raise ImportError(
"To use this module, please install `jaxns` package. It can be"
" installed with `pip install jaxns`"
" installed with `pip install jaxns` with python>=3.8"
) from e

import tensorflow_probability.substrates.jax as tfp

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import reparam, seed, trace
Expand All @@ -30,14 +36,7 @@

__all__ = ["NestedSampler"]


class UniformPrior(ContinuousPrior):
def __init__(self, name, shape):
prior_base = UniformBase(shape, jnp.result_type(float))
super().__init__(name, shape, parents=[], tracked=True, prior_base=prior_base)

def transform_U(self, U, **kwargs):
return U
tfpd = tfp.distributions


@singledispatch
Expand Down Expand Up @@ -118,8 +117,6 @@ def __call__(self, name, fn, obs):
return None, transform(x)


# TODO: Consider deprecating this wrapper. It might be better to only provide some
# utilities to help converting a NumPyro model to a Jaxns loglikelihood function.
class NestedSampler:
"""
(EXPERIMENTAL) A wrapper for `jaxns <https://github.com/Joshuaalbert/jaxns>`_ ,
Expand Down Expand Up @@ -189,7 +186,7 @@ def __init__(
)
self._samples = None
self._log_weights = None
self._results = None
self._results: NestedSamplerResults | None = None

def run(self, rng_key, *args, **kwargs):
"""
Expand Down Expand Up @@ -246,24 +243,58 @@ def run(self, rng_key, *args, **kwargs):
loglik_fn = local_dict["loglik_fn"]

# use NestedSampler with identity prior chain
prior_chain = PriorChain()
for name in param_names:
prior = UniformPrior(name + "_base", prototype_trace[name]["fn"].shape())
prior_chain.push(prior)
# XXX: the `marginalised` keyword in jaxns can be used to get expectation of some
# quantity over posterior samples; it can be helpful to expose it in this wrapper
ns = OrigNestedSampler(
loglik_fn,
prior_chain,
def prior_model() -> PriorModelGen:
params = []
for name in param_names:
shape = prototype_trace[name]["fn"].shape()
param = yield Prior(
tfpd.Uniform(low=jnp.zeros(shape), high=jnp.ones(shape)),
name=name + "_base",
)
params.append(param)
return tuple(params)

model = Model(prior_model=prior_model, log_likelihood=loglik_fn)

default_constructor_kwargs = dict(
num_live_points=model.U_ndims * 25,
num_parallel_samplers=1,
max_samples=1e4,
uncert_improvement_patience=2,
)
default_termination_kwargs = dict(live_evidence_frac=1e-4)
# Fill-in missing values with defaults. This allows user to inspect what was actually used by inspecting
# these dictionaries
list(
map(
lambda item: self.constructor_kwargs.setdefault(*item),
default_constructor_kwargs.items(),
)
)
list(
map(
lambda item: self.termination_kwargs.setdefault(*item),
default_termination_kwargs.items(),
)
)

exact_ns = OrigNestedSampler(
model=model,
**self.constructor_kwargs,
)
results = ns(rng_sampling, **self.termination_kwargs)

termination_reason, state = exact_ns(
rng_sampling,
term_cond=TerminationCondition(**self.termination_kwargs),
)
results = exact_ns.to_results(state, termination_reason)

# transform base samples back to original domains
# Here we only transform the first valid num_samples samples
# NB: the number of weighted samples obtained from jaxns is results.num_samples
# and only the first num_samples values of results.samples are valid.
num_samples = results.total_num_samples
samples = tree_util.tree_map(lambda x: x[:num_samples], results.samples)
samples = results.samples
Joshuaalbert marked this conversation as resolved.
Show resolved Hide resolved
predictive = Predictive(
reparam_model, samples, return_sites=param_names + deterministics
)
Expand All @@ -283,11 +314,10 @@ def get_samples(self, rng_key, num_samples):
raise RuntimeError(
"NestedSampler.run(...) method should be called first to obtain results."
)

samples, log_weights = self.get_weighted_samples()
p = nn.softmax(log_weights)
idx = random.choice(rng_key, log_weights.shape[0], (num_samples,), p=p)
return {k: v[idx] for k, v in samples.items()}
weighted_samples, sample_weights = self.get_weighted_samples()
return resample(
rng_key, weighted_samples, sample_weights, S=num_samples, replace=True
)

def get_weighted_samples(self):
"""
Expand All @@ -298,8 +328,7 @@ def get_weighted_samples(self):
"NestedSampler.run(...) method should be called first to obtain results."
)

num_samples = self._results.total_num_samples
return self._results.samples, self._results.log_dp_mean[:num_samples]
return self._results.samples, self._results.log_dp_mean
Joshuaalbert marked this conversation as resolved.
Show resolved Hide resolved

def print_summary(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
"flax",
"funsor>=0.4.1",
"graphviz",
"jaxns==1.0.0",
"jaxns>=2.0.1",
"matplotlib",
"optax>=0.0.6",
"pylab-sdk", # jaxns dependency
Expand Down
6 changes: 5 additions & 1 deletion test/contrib/test_nested_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
import jax.numpy as jnp

import numpyro
from numpyro.contrib.nested_sampling import NestedSampler, UniformReparam

try:
from numpyro.contrib.nested_sampling import NestedSampler, UniformReparam
except ImportError:
pytestmark = pytest.mark.skip(reason="jaxns is not installed")
import numpyro.distributions as dist
from numpyro.distributions.transforms import AffineTransform, ExpTransform

Expand Down