Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SGNHT #515

Merged
merged 8 commits into from
May 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions blackjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .sgmcmc.csgld import csgld
from .sgmcmc.sghmc import sghmc
from .sgmcmc.sgld import sgld
from .sgmcmc.sgnht import sgnht
from .smc.adaptive_tempered import adaptive_tempered_smc
from .smc.tempered import tempered_smc
from .vi.meanfield_vi import meanfield_vi
Expand All @@ -38,6 +39,7 @@
"ghmc",
"sgld", # stochastic gradient mcmc
"sghmc",
"sgnht",
"csgld",
"window_adaptation", # mcmc adaptation
"meads_adaptation",
Expand Down
4 changes: 2 additions & 2 deletions blackjax/sgmcmc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import csgld, sghmc, sgld
from . import csgld, sghmc, sgld, sgnht
from .gradients import grad_estimator, logdensity_estimator

__all__ = ["grad_estimator", "logdensity_estimator", "csgld", "sgld", "sghmc"]
__all__ = ["grad_estimator", "logdensity_estimator", "csgld", "sgld", "sghmc", "sgnht"]
55 changes: 49 additions & 6 deletions blackjax/sgmcmc/diffusions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Solvers for Langevin diffusions."""
import operator

import jax
import jax.numpy as jnp

from blackjax.types import PRNGKey, PyTree
from blackjax.util import generate_gaussian_noise
from blackjax.util import generate_gaussian_noise, pytree_size

__all__ = ["overdamped_langevin", "sghmc"]
__all__ = ["overdamped_langevin", "sghmc", "sgnht"]


def overdamped_langevin():
Expand Down Expand Up @@ -51,7 +53,8 @@ def one_step(


def sghmc(alpha: float = 0.01, beta: float = 0):
"""Solver for the diffusion equation of the SGHMC algorithm :cite:p:`chen2014stochastic`.
"""Euler solver for the diffusion equation of the SGHMC algorithm :cite:p:`chen2014stochastic`,
with parameters alpha and beta scaled according to :cite:p:`ma2015complete`.

This algorithm was ported from :cite:p:`coullon2022sgmcmcjax`.

Expand All @@ -66,11 +69,13 @@ def one_step(
temperature: float = 1.0,
):
noise = generate_gaussian_noise(rng_key, position)
position = jax.tree_util.tree_map(lambda x, p: x + p, position, momentum)
position = jax.tree_util.tree_map(
lambda x, p: x + step_size * p, position, momentum
)
momentum = jax.tree_util.tree_map(
lambda p, g, n: (1.0 - alpha) * p
lambda p, g, n: (1.0 - alpha * step_size) * p
+ step_size * g
+ jnp.sqrt(2 * step_size * (alpha - beta) * temperature) * n,
+ jnp.sqrt(step_size * (2 * alpha - step_size * beta) * temperature) * n,
momentum,
logdensity_grad,
noise,
Expand All @@ -79,3 +84,41 @@ def one_step(
return position, momentum

return one_step


def sgnht(alpha: float = 0.01, beta: float = 0):
"""Euler solver for the diffusion equation of the SGNHT algorithm :cite:p:`ding2014bayesian`.

This algorithm was ported from :cite:p:`coullon2022sgmcmcjax`.

"""

def one_step(
rng_key: PRNGKey,
position: PyTree,
momentum: PyTree,
xi: float,
logdensity_grad: PyTree,
step_size: float,
temperature: float = 1.0,
):
noise = generate_gaussian_noise(rng_key, position)
position = jax.tree_util.tree_map(
lambda x, p: x + step_size * p, position, momentum
)
momentum = jax.tree_util.tree_map(
lambda p, g, n: (1.0 - xi * step_size) * p
+ step_size * g
+ jnp.sqrt(step_size * (2 * alpha - step_size * beta) * temperature) * n,
momentum,
logdensity_grad,
noise,
)
momentum_dot = jax.tree_util.tree_reduce(
operator.add, jax.tree_util.tree_map(lambda x: jnp.sum(x * x), momentum)
)
d = pytree_size(momentum)
xi = xi + step_size * (momentum_dot / d - 1)
return position, momentum, xi

return one_step
7 changes: 3 additions & 4 deletions blackjax/sgmcmc/sghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Callable

import jax
import jax.numpy as jnp

import blackjax.sgmcmc.diffusions as diffusions
from blackjax.types import PRNGKey, PyTree
Expand Down Expand Up @@ -45,7 +44,7 @@ def body_fn(state, rng_key):
)
return ((position, momentum), position)

momentum = generate_gaussian_noise(rng_key, position, 0, jnp.sqrt(step_size))
momentum = generate_gaussian_noise(rng_key, position)
keys = jax.random.split(rng_key, num_integration_steps)
(position, momentum), _ = jax.lax.scan(body_fn, (position, momentum), keys)

Expand All @@ -57,7 +56,7 @@ def body_fn(state, rng_key):
class sghmc:
"""Implements the (basic) user interface for the SGHMC kernel.

The general sghmc kernel builder (:meth:`blackjax.mcmc.sghmc.build_kernel`, alias
The general sghmc kernel builder (:meth:`blackjax.sgmcmc.sghmc.build_kernel`, alias
`blackjax.sghmc.build_kernel`) can be cumbersome to manipulate. Since most users
only need to specify the kernel parameters at initialization time, we
provide a helper function that specializes the general kernel.
Expand Down Expand Up @@ -103,7 +102,7 @@ class sghmc:

Returns
-------
A ``MCMCSamplingAlgorithm``.
A step function.

"""

Expand Down
4 changes: 2 additions & 2 deletions blackjax/sgmcmc/sgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def kernel(
class sgld:
"""Implements the (basic) user interface for the SGLD kernel.

The general sgld kernel builder (:meth:`blackjax.mcmc.sgld.build_kernel`, alias
The general sgld kernel builder (:meth:`blackjax.sgmcmc.sgld.build_kernel`, alias
`blackjax.sgld.build_kernel`) can be cumbersome to manipulate. Since most users
only need to specify the kernel parameters at initialization time, we
provide a helper function that specializes the general kernel.
Expand Down Expand Up @@ -91,7 +91,7 @@ class sgld:

Returns
-------
A ``MCMCSamplingAlgorithm``.
A step function.

"""

Expand Down
145 changes: 145 additions & 0 deletions blackjax/sgmcmc/sgnht.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API for the Stochastic gradient Nosé-Hoover Thermostat kernel."""
from typing import Callable, NamedTuple

import blackjax.sgmcmc.diffusions as diffusions
from blackjax.base import MCMCSamplingAlgorithm
from blackjax.types import PRNGKey, PyTree
from blackjax.util import generate_gaussian_noise

__all__ = ["SGNHTState", "init", "build_kernel", "sgnht"]


class SGNHTState(NamedTuple):
r"""State of the SGNHT algorithm.

Parameters
----------
position
Current position in the sample space.
momentum
Current momentum in the sample space.
xi
Scalar thermostat controlling kinetic energy.

"""
position: PyTree
momentum: PyTree
xi: float


def init(rng_key: PRNGKey, position: PyTree, alpha: float = 0.01):
momentum = generate_gaussian_noise(rng_key, position)
return SGNHTState(position, momentum, alpha)


def build_kernel(alpha: float = 0.01, beta: float = 0) -> Callable:
"""Stochastic gradient Nosé-Hoover Thermostat (SGNHT) algorithm."""
integrator = diffusions.sgnht(alpha, beta)

def kernel(
rng_key: PRNGKey,
state: SGNHTState,
grad_estimator: Callable,
minibatch: PyTree,
step_size: float,
temperature: float = 1.0,
) -> PyTree:
position, momentum, xi = state
logdensity_grad = grad_estimator(position, minibatch)
position, momentum, xi = integrator(
rng_key, position, momentum, xi, logdensity_grad, step_size, temperature
)
return SGNHTState(position, momentum, xi)

return kernel


class sgnht:
"""Implements the (basic) user interface for the SGNHT kernel.

The general sgnht kernel (:meth:`blackjax.sgmcmc.sgnht.build_kernel`, alias
`blackjax.sgnht.build_kernel`) can be cumbersome to manipulate. Since most users
only need to specify the kernel parameters at initialization time, we
provide a helper function that specializes the general kernel.

Example
-------

To initialize a SGNHT kernel one needs to specify a schedule function, which
returns a step size at each sampling step, and a gradient estimator
function. Here for a constant step size, and `data_size` data samples:

.. code::

grad_estimator = blackjax.sgmcmc.gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size)

We can now initialize the sgnht kernel and the state.

.. code::

sgnht = blackjax.sgnht(grad_estimator)
state = sgnht.init(rng_key, position)

Assuming we have an iterator `batches` that yields batches of data we can
perform one step:

.. code::

step_size = 1e-3
minibatch = next(batches)
new_state = sgnht.step(rng_key, state, minibatch, step_size)

Kernels are not jit-compiled by default so you will need to do it manually:

.. code::

step = jax.jit(sgnht.step)
new_state = step(rng_key, state, minibatch, step_size)

Parameters
----------
grad_estimator
A function that takes a position, a batch of data and returns an estimation
of the gradient of the log-density at this position.

Returns
-------
A ``MCMCSamplingAlgorithm``.

"""

init = staticmethod(init)
build_kernel = staticmethod(build_kernel)

def __new__( # type: ignore[misc]
cls,
grad_estimator: Callable,
) -> MCMCSamplingAlgorithm:
kernel = cls.build_kernel()

def init_fn(position: PyTree, rng_key: PRNGKey):
return cls.init(rng_key, position)

def step_fn(rng_key: PRNGKey, state, minibatch: PyTree, step_size: float):
return kernel(
rng_key,
state,
grad_estimator,
minibatch,
step_size,
)

return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]
16 changes: 16 additions & 0 deletions docs/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -327,3 +327,19 @@ @book{gelman2014bayesian
year={2014},
publisher={Chapman and Hall/CRC}
}

@article{ding2014bayesian,
title={Bayesian sampling using stochastic gradient thermostats},
author={Ding, Nan and Fang, Youhan and Babbush, Ryan and Chen, Changyou and Skeel, Robert D and Neven, Hartmut},
journal={Advances in neural information processing systems},
volume={27},
year={2014}
}

@article{ma2015complete,
title={A complete recipe for stochastic gradient MCMC},
author={Ma, Yi-An and Chen, Tianqi and Fox, Emily},
journal={Advances in neural information processing systems},
volume={28},
year={2015}
}
40 changes: 40 additions & 0 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,46 @@ def test_linear_regression_sghmc_cv(self):
data_batch = X_data[:100, :]
_ = sghmc(rng_key, init_position, data_batch, 1e-3)

def test_linear_regression_sgnht(self):
rng_key, data_key = jax.random.split(self.key, 2)

data_size = 1000
X_data = jax.random.normal(data_key, shape=(data_size, 5))

grad_fn = blackjax.sgmcmc.grad_estimator(
self.logprior_fn, self.loglikelihood_fn, data_size
)
sgnht = blackjax.sgnht(grad_fn)

_, rng_key = jax.random.split(rng_key)
data_batch = X_data[100:200, :]
init_position = 1.0
data_batch = X_data[:100, :]
init_state = sgnht.init(init_position, self.key)
_ = sgnht.step(rng_key, init_state, data_batch, 1e-3)

def test_linear_regression_sgnhtc_cv(self):
rng_key, data_key = jax.random.split(self.key, 2)

data_size = 1000
X_data = jax.random.normal(data_key, shape=(data_size, 5))

centering_position = 1.0
grad_fn = blackjax.sgmcmc.grad_estimator(
self.logprior_fn, self.loglikelihood_fn, data_size
)
cv_grad_fn = blackjax.sgmcmc.gradients.control_variates(
grad_fn, centering_position, X_data
)

sgnht = blackjax.sgnht(cv_grad_fn)

_, rng_key = jax.random.split(rng_key)
init_position = 1.0
data_batch = X_data[:100, :]
init_state = sgnht.init(init_position, self.key)
_ = sgnht.step(rng_key, init_state, data_batch, 1e-3)


class LatentGaussianTest(chex.TestCase):
"""Test sampling of a linear regression model."""
Expand Down