Skip to content

Commit

Permalink
Add SGNHT (#515)
Browse files Browse the repository at this point in the history
* add sgnht

* reformat

* Restructure kernels

* Reformat

* Clean

* Rename step to kernel
  • Loading branch information
SamDuffield authored May 25, 2023
1 parent 661874d commit c6149e3
Show file tree
Hide file tree
Showing 8 changed files with 259 additions and 14 deletions.
2 changes: 2 additions & 0 deletions blackjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .sgmcmc.csgld import csgld
from .sgmcmc.sghmc import sghmc
from .sgmcmc.sgld import sgld
from .sgmcmc.sgnht import sgnht
from .smc.adaptive_tempered import adaptive_tempered_smc
from .smc.tempered import tempered_smc
from .vi.meanfield_vi import meanfield_vi
Expand All @@ -38,6 +39,7 @@
"ghmc",
"sgld", # stochastic gradient mcmc
"sghmc",
"sgnht",
"csgld",
"window_adaptation", # mcmc adaptation
"meads_adaptation",
Expand Down
4 changes: 2 additions & 2 deletions blackjax/sgmcmc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import csgld, sghmc, sgld
from . import csgld, sghmc, sgld, sgnht
from .gradients import grad_estimator, logdensity_estimator

__all__ = ["grad_estimator", "logdensity_estimator", "csgld", "sgld", "sghmc"]
__all__ = ["grad_estimator", "logdensity_estimator", "csgld", "sgld", "sghmc", "sgnht"]
55 changes: 49 additions & 6 deletions blackjax/sgmcmc/diffusions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Solvers for Langevin diffusions."""
import operator

import jax
import jax.numpy as jnp

from blackjax.types import PRNGKey, PyTree
from blackjax.util import generate_gaussian_noise
from blackjax.util import generate_gaussian_noise, pytree_size

__all__ = ["overdamped_langevin", "sghmc"]
__all__ = ["overdamped_langevin", "sghmc", "sgnht"]


def overdamped_langevin():
Expand Down Expand Up @@ -51,7 +53,8 @@ def one_step(


def sghmc(alpha: float = 0.01, beta: float = 0):
"""Solver for the diffusion equation of the SGHMC algorithm :cite:p:`chen2014stochastic`.
"""Euler solver for the diffusion equation of the SGHMC algorithm :cite:p:`chen2014stochastic`,
with parameters alpha and beta scaled according to :cite:p:`ma2015complete`.
This algorithm was ported from :cite:p:`coullon2022sgmcmcjax`.
Expand All @@ -66,11 +69,13 @@ def one_step(
temperature: float = 1.0,
):
noise = generate_gaussian_noise(rng_key, position)
position = jax.tree_util.tree_map(lambda x, p: x + p, position, momentum)
position = jax.tree_util.tree_map(
lambda x, p: x + step_size * p, position, momentum
)
momentum = jax.tree_util.tree_map(
lambda p, g, n: (1.0 - alpha) * p
lambda p, g, n: (1.0 - alpha * step_size) * p
+ step_size * g
+ jnp.sqrt(2 * step_size * (alpha - beta) * temperature) * n,
+ jnp.sqrt(step_size * (2 * alpha - step_size * beta) * temperature) * n,
momentum,
logdensity_grad,
noise,
Expand All @@ -79,3 +84,41 @@ def one_step(
return position, momentum

return one_step


def sgnht(alpha: float = 0.01, beta: float = 0):
"""Euler solver for the diffusion equation of the SGNHT algorithm :cite:p:`ding2014bayesian`.
This algorithm was ported from :cite:p:`coullon2022sgmcmcjax`.
"""

def one_step(
rng_key: PRNGKey,
position: PyTree,
momentum: PyTree,
xi: float,
logdensity_grad: PyTree,
step_size: float,
temperature: float = 1.0,
):
noise = generate_gaussian_noise(rng_key, position)
position = jax.tree_util.tree_map(
lambda x, p: x + step_size * p, position, momentum
)
momentum = jax.tree_util.tree_map(
lambda p, g, n: (1.0 - xi * step_size) * p
+ step_size * g
+ jnp.sqrt(step_size * (2 * alpha - step_size * beta) * temperature) * n,
momentum,
logdensity_grad,
noise,
)
momentum_dot = jax.tree_util.tree_reduce(
operator.add, jax.tree_util.tree_map(lambda x: jnp.sum(x * x), momentum)
)
d = pytree_size(momentum)
xi = xi + step_size * (momentum_dot / d - 1)
return position, momentum, xi

return one_step
7 changes: 3 additions & 4 deletions blackjax/sgmcmc/sghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Callable

import jax
import jax.numpy as jnp

import blackjax.sgmcmc.diffusions as diffusions
from blackjax.types import PRNGKey, PyTree
Expand Down Expand Up @@ -45,7 +44,7 @@ def body_fn(state, rng_key):
)
return ((position, momentum), position)

momentum = generate_gaussian_noise(rng_key, position, 0, jnp.sqrt(step_size))
momentum = generate_gaussian_noise(rng_key, position)
keys = jax.random.split(rng_key, num_integration_steps)
(position, momentum), _ = jax.lax.scan(body_fn, (position, momentum), keys)

Expand All @@ -57,7 +56,7 @@ def body_fn(state, rng_key):
class sghmc:
"""Implements the (basic) user interface for the SGHMC kernel.
The general sghmc kernel builder (:meth:`blackjax.mcmc.sghmc.build_kernel`, alias
The general sghmc kernel builder (:meth:`blackjax.sgmcmc.sghmc.build_kernel`, alias
`blackjax.sghmc.build_kernel`) can be cumbersome to manipulate. Since most users
only need to specify the kernel parameters at initialization time, we
provide a helper function that specializes the general kernel.
Expand Down Expand Up @@ -103,7 +102,7 @@ class sghmc:
Returns
-------
A ``MCMCSamplingAlgorithm``.
A step function.
"""

Expand Down
4 changes: 2 additions & 2 deletions blackjax/sgmcmc/sgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def kernel(
class sgld:
"""Implements the (basic) user interface for the SGLD kernel.
The general sgld kernel builder (:meth:`blackjax.mcmc.sgld.build_kernel`, alias
The general sgld kernel builder (:meth:`blackjax.sgmcmc.sgld.build_kernel`, alias
`blackjax.sgld.build_kernel`) can be cumbersome to manipulate. Since most users
only need to specify the kernel parameters at initialization time, we
provide a helper function that specializes the general kernel.
Expand Down Expand Up @@ -91,7 +91,7 @@ class sgld:
Returns
-------
A ``MCMCSamplingAlgorithm``.
A step function.
"""

Expand Down
145 changes: 145 additions & 0 deletions blackjax/sgmcmc/sgnht.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API for the Stochastic gradient Nosé-Hoover Thermostat kernel."""
from typing import Callable, NamedTuple

import blackjax.sgmcmc.diffusions as diffusions
from blackjax.base import MCMCSamplingAlgorithm
from blackjax.types import PRNGKey, PyTree
from blackjax.util import generate_gaussian_noise

__all__ = ["SGNHTState", "init", "build_kernel", "sgnht"]


class SGNHTState(NamedTuple):
r"""State of the SGNHT algorithm.
Parameters
----------
position
Current position in the sample space.
momentum
Current momentum in the sample space.
xi
Scalar thermostat controlling kinetic energy.
"""
position: PyTree
momentum: PyTree
xi: float


def init(rng_key: PRNGKey, position: PyTree, alpha: float = 0.01):
momentum = generate_gaussian_noise(rng_key, position)
return SGNHTState(position, momentum, alpha)


def build_kernel(alpha: float = 0.01, beta: float = 0) -> Callable:
"""Stochastic gradient Nosé-Hoover Thermostat (SGNHT) algorithm."""
integrator = diffusions.sgnht(alpha, beta)

def kernel(
rng_key: PRNGKey,
state: SGNHTState,
grad_estimator: Callable,
minibatch: PyTree,
step_size: float,
temperature: float = 1.0,
) -> PyTree:
position, momentum, xi = state
logdensity_grad = grad_estimator(position, minibatch)
position, momentum, xi = integrator(
rng_key, position, momentum, xi, logdensity_grad, step_size, temperature
)
return SGNHTState(position, momentum, xi)

return kernel


class sgnht:
"""Implements the (basic) user interface for the SGNHT kernel.
The general sgnht kernel (:meth:`blackjax.sgmcmc.sgnht.build_kernel`, alias
`blackjax.sgnht.build_kernel`) can be cumbersome to manipulate. Since most users
only need to specify the kernel parameters at initialization time, we
provide a helper function that specializes the general kernel.
Example
-------
To initialize a SGNHT kernel one needs to specify a schedule function, which
returns a step size at each sampling step, and a gradient estimator
function. Here for a constant step size, and `data_size` data samples:
.. code::
grad_estimator = blackjax.sgmcmc.gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size)
We can now initialize the sgnht kernel and the state.
.. code::
sgnht = blackjax.sgnht(grad_estimator)
state = sgnht.init(rng_key, position)
Assuming we have an iterator `batches` that yields batches of data we can
perform one step:
.. code::
step_size = 1e-3
minibatch = next(batches)
new_state = sgnht.step(rng_key, state, minibatch, step_size)
Kernels are not jit-compiled by default so you will need to do it manually:
.. code::
step = jax.jit(sgnht.step)
new_state = step(rng_key, state, minibatch, step_size)
Parameters
----------
grad_estimator
A function that takes a position, a batch of data and returns an estimation
of the gradient of the log-density at this position.
Returns
-------
A ``MCMCSamplingAlgorithm``.
"""

init = staticmethod(init)
build_kernel = staticmethod(build_kernel)

def __new__( # type: ignore[misc]
cls,
grad_estimator: Callable,
) -> MCMCSamplingAlgorithm:
kernel = cls.build_kernel()

def init_fn(position: PyTree, rng_key: PRNGKey):
return cls.init(rng_key, position)

def step_fn(rng_key: PRNGKey, state, minibatch: PyTree, step_size: float):
return kernel(
rng_key,
state,
grad_estimator,
minibatch,
step_size,
)

return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]
16 changes: 16 additions & 0 deletions docs/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -327,3 +327,19 @@ @book{gelman2014bayesian
year={2014},
publisher={Chapman and Hall/CRC}
}

@article{ding2014bayesian,
title={Bayesian sampling using stochastic gradient thermostats},
author={Ding, Nan and Fang, Youhan and Babbush, Ryan and Chen, Changyou and Skeel, Robert D and Neven, Hartmut},
journal={Advances in neural information processing systems},
volume={27},
year={2014}
}

@article{ma2015complete,
title={A complete recipe for stochastic gradient MCMC},
author={Ma, Yi-An and Chen, Tianqi and Fox, Emily},
journal={Advances in neural information processing systems},
volume={28},
year={2015}
}
40 changes: 40 additions & 0 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,46 @@ def test_linear_regression_sghmc_cv(self):
data_batch = X_data[:100, :]
_ = sghmc(rng_key, init_position, data_batch, 1e-3)

def test_linear_regression_sgnht(self):
rng_key, data_key = jax.random.split(self.key, 2)

data_size = 1000
X_data = jax.random.normal(data_key, shape=(data_size, 5))

grad_fn = blackjax.sgmcmc.grad_estimator(
self.logprior_fn, self.loglikelihood_fn, data_size
)
sgnht = blackjax.sgnht(grad_fn)

_, rng_key = jax.random.split(rng_key)
data_batch = X_data[100:200, :]
init_position = 1.0
data_batch = X_data[:100, :]
init_state = sgnht.init(init_position, self.key)
_ = sgnht.step(rng_key, init_state, data_batch, 1e-3)

def test_linear_regression_sgnhtc_cv(self):
rng_key, data_key = jax.random.split(self.key, 2)

data_size = 1000
X_data = jax.random.normal(data_key, shape=(data_size, 5))

centering_position = 1.0
grad_fn = blackjax.sgmcmc.grad_estimator(
self.logprior_fn, self.loglikelihood_fn, data_size
)
cv_grad_fn = blackjax.sgmcmc.gradients.control_variates(
grad_fn, centering_position, X_data
)

sgnht = blackjax.sgnht(cv_grad_fn)

_, rng_key = jax.random.split(rng_key)
init_position = 1.0
data_batch = X_data[:100, :]
init_state = sgnht.init(init_position, self.key)
_ = sgnht.step(rng_key, init_state, data_batch, 1e-3)


class LatentGaussianTest(chex.TestCase):
"""Test sampling of a linear regression model."""
Expand Down

0 comments on commit c6149e3

Please sign in to comment.