diff --git a/blackjax/__init__.py b/blackjax/__init__.py
index 74d7bd38e..473920f31 100644
--- a/blackjax/__init__.py
+++ b/blackjax/__init__.py
@@ -11,6 +11,7 @@
     irmh,
     mala,
     meads,
+    meanfield_vi,
     mgrad_gaussian,
     nuts,
     orbital_hmc,
@@ -45,7 +46,8 @@
     "pathfinder_adaptation",
     "adaptive_tempered_smc",  # smc
     "tempered_smc",
-    "pathfinder",  # variational inference
+    "meanfield_vi",  # variational inference
+    "pathfinder",
     "ess",  # diagnostics
     "rhat",
 ]
diff --git a/blackjax/base.py b/blackjax/base.py
index 59c9a7c35..3a2ffccd9 100644
--- a/blackjax/base.py
+++ b/blackjax/base.py
@@ -1,5 +1,4 @@
 # Copyright 2020- The Blackjax Authors.
-#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -136,7 +135,8 @@ class VIAlgorithm(NamedTuple):
 
     """
 
-    approximate: Callable
+    init: Callable
+    step: Callable
     sample: Callable
 
 
diff --git a/blackjax/kernels.py b/blackjax/kernels.py
index 9b13ae0d3..ddf62f29e 100644
--- a/blackjax/kernels.py
+++ b/blackjax/kernels.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Blackjax high-level interface with sampling algorithms."""
-from typing import Callable, Dict, NamedTuple, Optional, Union
+from typing import Callable, Dict, NamedTuple, Optional, Tuple, Union
 
 import jax
 import jax.numpy as jnp
+from optax import GradientTransformation
 
 import blackjax.adaptation as adaptation
 import blackjax.mcmc as mcmc
@@ -1251,6 +1252,11 @@ def step_fn(rng_key: PRNGKey, state):
 # -----------------------------------------------------------------------------
 
 
+class PathFinderAlgorithm(NamedTuple):
+    approximate: Callable
+    sample: Callable
+
+
 class pathfinder:
     """Implements the (basic) user interface for the pathfinder kernel.
 
@@ -1273,7 +1279,7 @@ class pathfinder:
     approximate = staticmethod(vi.pathfinder.approximate)
     sample = staticmethod(vi.pathfinder.sample)
 
-    def __new__(cls, logdensity_fn: Callable) -> VIAlgorithm:  # type: ignore[misc]
+    def __new__(cls, logdensity_fn: Callable) -> PathFinderAlgorithm:  # type: ignore[misc]
         def approximate_fn(
             rng_key: PRNGKey,
             position: PyTree,
@@ -1289,7 +1295,7 @@ def sample_fn(
         ):
             return cls.sample(rng_key, state, num_samples)
 
-        return VIAlgorithm(approximate_fn, sample_fn)
+        return PathFinderAlgorithm(approximate_fn, sample_fn)
 
 
 def pathfinder_adaptation(
@@ -1385,3 +1391,51 @@ def kernel(rng_key, state):
         return AdaptationResults(last_chain_state, kernel, parameters)
 
     return AdaptationAlgorithm(run)
+
+
+class meanfield_vi:
+    """High-level implementation of Mean-Field Variational Inference.
+
+    Parameters
+    ----------
+    logdensity_fn
+        A function that represents the log-density function associated with
+        the distribution we want to approximate.
+    optimizer
+        Optax `GradientTransformation` used to update the variational
+        parameters at each step.
+    num_samples
+        Number of samples drawn from the approximation at each step to
+        estimate the Kullback-Leibler divergence.
+
+    Returns
+    -------
+    A ``VIAlgorithm`` whose ``init``, ``step`` and ``sample`` functions have
+    ``logdensity_fn``, ``optimizer`` and ``num_samples`` already bound.
+
+    """
+
+    init = staticmethod(vi.meanfield_vi.init)
+    step = staticmethod(vi.meanfield_vi.step)
+    sample = staticmethod(vi.meanfield_vi.sample)
+
+    def __new__(  # type: ignore[misc]
+        cls,
+        logdensity_fn: Callable,
+        optimizer: GradientTransformation,
+        num_samples: int = 100,
+    ) -> VIAlgorithm:
+        def init_fn(position: PyTree):
+            return cls.init(position, optimizer)
+
+        def step_fn(
+            rng_key: PRNGKey, state: vi.meanfield_vi.MFVIState
+        ) -> Tuple[vi.meanfield_vi.MFVIState, vi.meanfield_vi.MFVIInfo]:
+            return cls.step(rng_key, state, logdensity_fn, optimizer, num_samples)
+
+        def sample_fn(
+            rng_key: PRNGKey, state: vi.meanfield_vi.MFVIState, num_samples: int
+        ):
+            return cls.sample(rng_key, state, num_samples)
+
+        return VIAlgorithm(init_fn, step_fn, sample_fn)
diff --git a/blackjax/vi/__init__.py b/blackjax/vi/__init__.py
index 2da96c7bb..131fd6c3a 100644
--- a/blackjax/vi/__init__.py
+++ b/blackjax/vi/__init__.py
@@ -1,3 +1,3 @@
-from . import pathfinder
+from . import meanfield_vi, pathfinder
 
-__all__ = ["pathfinder"]
+__all__ = ["pathfinder", "meanfield_vi"]
diff --git a/blackjax/vi/meanfield_vi.py b/blackjax/vi/meanfield_vi.py
new file mode 100644
index 000000000..565507e08
--- /dev/null
+++ b/blackjax/vi/meanfield_vi.py
@@ -0,0 +1,184 @@
+# Copyright 2020- The Blackjax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, NamedTuple, Tuple
+
+import jax
+import jax.numpy as jnp
+import jax.scipy as jsp
+from jax.flatten_util import ravel_pytree
+from optax import GradientTransformation, OptState
+
+from blackjax.types import PRNGKey, PyTree
+
+__all__ = [
+    "MFVIState",
+    "MFVIInfo",
+    "init",
+    "sample",
+    "generate_meanfield_logdensity",
+    "step",
+]
+
+
+class MFVIState(NamedTuple):
+    """State of the mean-field approximation.
+
+    mu
+        Mean of each component of the Gaussian approximation.
+    rho
+        Logarithm of the standard deviation of each component.
+    opt_state
+        State of the optax optimizer for the parameters `(mu, rho)`.
+
+    """
+
+    mu: PyTree
+    rho: PyTree
+    opt_state: OptState
+
+
+class MFVIInfo(NamedTuple):
+    """Extra information returned by `step`.
+
+    elbo
+        Value of the objective minimized by `step`, i.e. the Monte Carlo
+        average of `log q(z) - log p(z)`, at the parameters before the update.
+
+    """
+
+    elbo: float
+
+
+def init(
+    position: PyTree,
+    optimizer: GradientTransformation,
+    *optimizer_args,
+    **optimizer_kwargs
+) -> MFVIState:
+    """Initialize the mean-field VI state.
+
+    Parameters
+    ----------
+    position
+        PyTree whose leaves give the shapes of the variational parameters.
+    optimizer
+        Optax `GradientTransformation` whose state is initialized for the
+        variational parameters `(mu, rho)`.
+    optimizer_args, optimizer_kwargs
+        Accepted but currently unused.
+
+    Returns
+    -------
+    A state with `mu` set to zeros and `rho` set to -2 for every leaf of
+    `position`.
+
+    """
+    mu = jax.tree_map(jnp.zeros_like, position)
+    rho = jax.tree_map(lambda x: -2.0 * jnp.ones_like(x), position)
+    opt_state = optimizer.init((mu, rho))
+    return MFVIState(mu, rho, opt_state)
+
+
+def step(
+    rng_key: PRNGKey,
+    state: MFVIState,
+    logdensity_fn: Callable,
+    optimizer: GradientTransformation,
+    num_samples: int = 5,
+    stl_estimator: bool = True,
+) -> Tuple[MFVIState, MFVIInfo]:
+    """Approximate the target density using the mean-field approximation.
+
+    Parameters
+    ----------
+    rng_key
+        Key for JAX's pseudo-random number generator.
+    state
+        Current state of the mean-field approximation.
+    logdensity_fn
+        Function that represents the target log-density to approximate.
+    optimizer
+        Optax `GradientTransformation` to be used for optimization.
+    num_samples
+        The number of samples that are taken from the approximation
+        at each step to compute the Kullback-Leibler divergence between
+        the approximation and the target log-density.
+    stl_estimator
+        Whether to use stick-the-landing (STL) gradient estimator [1] for gradient estimation.
+        The STL estimator has lower gradient variance by removing the score function term
+        from the gradient. It is suggested by [2] to always keep it in order for better results.
+
+    Returns
+    -------
+    The updated state and an `MFVIInfo` holding the value of the objective
+    at the parameters of the input state.
+
+    References
+    ----------
+    .. [1]: Roeder, G., Wu, Y., & Duvenaud, D. K. (2017).
+        Sticking the landing: Simple, lower-variance gradient estimators for variational inference.
+        Advances in Neural Information Processing Systems, 30.
+    .. [2]: Agrawal, A., Sheldon, D. R., & Domke, J. (2020).
+        Advances in black-box VI: Normalizing flows, importance weighting, and optimization.
+        Advances in Neural Information Processing Systems, 33.
+    """
+
+    parameters = (state.mu, state.rho)
+
+    def kl_divergence_fn(parameters):
+        mu, rho = parameters
+        z = _sample(rng_key, mu, rho, num_samples)
+        if stl_estimator:
+            # Only differentiate through the samples `z`: treating the
+            # parameters of `log q` as constants drops the score-function term.
+            mu = jax.lax.stop_gradient(mu)
+            rho = jax.lax.stop_gradient(rho)
+        logq = jax.vmap(generate_meanfield_logdensity(mu, rho))(z)
+        logp = jax.vmap(logdensity_fn)(z)
+        return (logq - logp).mean()
+
+    elbo, elbo_grad = jax.value_and_grad(kl_divergence_fn)(parameters)
+    updates, new_opt_state = optimizer.update(elbo_grad, state.opt_state, parameters)
+    new_parameters = jax.tree_map(lambda p, u: p + u, parameters, updates)
+    new_state = MFVIState(new_parameters[0], new_parameters[1], new_opt_state)
+    return new_state, MFVIInfo(elbo)
+
+
+def sample(rng_key: PRNGKey, state: MFVIState, num_samples: int = 1):
+    """Sample from the mean-field approximation."""
+    return _sample(rng_key, state.mu, state.rho, num_samples)
+
+
+def _sample(rng_key, mu, rho, num_samples):
+    """Draw `num_samples` reparameterized samples `mu + exp(rho) * eps`."""
+    sigma = jax.tree_map(jnp.exp, rho)
+    mu_flatten, unravel_fn = ravel_pytree(mu)
+    sigma_flat, _ = ravel_pytree(sigma)
+    flatten_sample = (
+        jax.random.normal(rng_key, (num_samples,) + mu_flatten.shape) * sigma_flat
+        + mu_flatten
+    )
+    return jax.vmap(unravel_fn)(flatten_sample)
+
+
+def generate_meanfield_logdensity(mu, rho):
+    """Return the log-density of a diagonal Gaussian N(mu, exp(rho)^2)."""
+    sigma_param = jax.tree_map(jnp.exp, rho)
+
+    def meanfield_logdensity(position):
+        logq_pytree = jax.tree_map(jsp.stats.norm.logpdf, position, mu, sigma_param)
+        logq = jax.tree_map(jnp.sum, logq_pytree)
+        return jax.tree_util.tree_reduce(jnp.add, logq)
+
+    return meanfield_logdensity
diff --git a/docs/vi.rst b/docs/vi.rst
index 55fe97f54..48a9284e1 100644
--- a/docs/vi.rst
+++ b/docs/vi.rst
@@ -7,9 +7,15 @@ Variational Inference
    :nosignatures:
 
    pathfinder
+   meanfield_vi
 
 
 Pathfinder
 ~~~~~~~~~~
 
 .. autoclass:: blackjax.pathfinder
+
+Mean-field VI
+~~~~~~~~~~~~~
+
+.. autoclass:: blackjax.meanfield_vi
diff --git a/pyproject.toml b/pyproject.toml
index ab6aabb7d..30637decb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,6 +36,7 @@ dependencies = [
     "jax>=0.3.13",
     "jaxlib>=0.3.10",
     "jaxopt>=0.5.5",
+    "optax",
     "typing-extensions>=4.4.0",
 ]
 dynamic = ["version"]
diff --git a/pytest.ini b/pytest.ini
index 3f40f905c..46acb2787 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,5 +1,4 @@
 [pytest]
-addopts = -n auto
 testpaths= "tests"
 filterwarnings =
     error
diff --git a/tests/test_meanfield_vi.py b/tests/test_meanfield_vi.py
new file mode 100644
index 000000000..3553b5c49
--- /dev/null
+++ b/tests/test_meanfield_vi.py
@@ -0,0 +1,56 @@
+import chex
+import jax
+import jax.numpy as jnp
+import jax.scipy.stats as stats
+import optax
+from absl.testing import absltest
+
+import blackjax
+
+
+class MFVITest(chex.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.key = jax.random.PRNGKey(42)
+
+    def test_recover_posterior(self):
+        """Check that MFVI recovers the parameters of a Gaussian target."""
+        ground_truth = [
+            # loc, scale
+            (2, 4),
+            (3, 5),
+        ]
+
+        def logdensity_fn(x):
+            logpdf = stats.norm.logpdf(x["x_1"], *ground_truth[0]) + stats.norm.logpdf(
+                x["x_2"], *ground_truth[1]
+            )
+            return jnp.sum(logpdf)
+
+        initial_position = {"x_1": 0.0, "x_2": 0.0}
+
+        num_steps = 50_000
+        num_samples = 500
+
+        optimizer = optax.sgd(1e-2)
+        mfvi = blackjax.meanfield_vi(logdensity_fn, optimizer, num_samples)
+        state = mfvi.init(initial_position)
+
+        # Compile the step once and give it a fresh subkey at every iteration.
+        step = jax.jit(mfvi.step)
+        rng_key = self.key
+        for _ in range(num_steps):
+            rng_key, subkey = jax.random.split(rng_key)
+            state, _ = step(subkey, state)
+
+        loc_1, loc_2 = state.mu["x_1"], state.mu["x_2"]
+        scale = jax.tree_map(jnp.exp, state.rho)
+        scale_1, scale_2 = scale["x_1"], scale["x_2"]
+        self.assertAlmostEqual(loc_1, ground_truth[0][0], delta=0.01)
+        self.assertAlmostEqual(scale_1, ground_truth[0][1], delta=0.01)
+        self.assertAlmostEqual(loc_2, ground_truth[1][0], delta=0.01)
+        self.assertAlmostEqual(scale_2, ground_truth[1][1], delta=0.01)
+
+
+if __name__ == "__main__":
+    absltest.main()