-
Notifications
You must be signed in to change notification settings - Fork 107
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add sgnht * reformat * Restructure kernels * Reformat * Clean * Rename step to kernel
- Loading branch information
1 parent
661874d
commit c6149e3
Showing
8 changed files
with
259 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from . import csgld, sghmc, sgld | ||
from . import csgld, sghmc, sgld, sgnht | ||
from .gradients import grad_estimator, logdensity_estimator | ||
|
||
__all__ = ["grad_estimator", "logdensity_estimator", "csgld", "sgld", "sghmc"] | ||
__all__ = ["grad_estimator", "logdensity_estimator", "csgld", "sgld", "sghmc", "sgnht"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
# Copyright 2020- The Blackjax Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Public API for the Stochastic gradient Nosé-Hoover Thermostat kernel.""" | ||
from typing import Callable, NamedTuple | ||
|
||
import blackjax.sgmcmc.diffusions as diffusions | ||
from blackjax.base import MCMCSamplingAlgorithm | ||
from blackjax.types import PRNGKey, PyTree | ||
from blackjax.util import generate_gaussian_noise | ||
|
||
__all__ = ["SGNHTState", "init", "build_kernel", "sgnht"] | ||
|
||
|
||
class SGNHTState(NamedTuple): | ||
r"""State of the SGNHT algorithm. | ||
Parameters | ||
---------- | ||
position | ||
Current position in the sample space. | ||
momentum | ||
Current momentum in the sample space. | ||
xi | ||
Scalar thermostat controlling kinetic energy. | ||
""" | ||
position: PyTree | ||
momentum: PyTree | ||
xi: float | ||
|
||
|
||
def init(rng_key: PRNGKey, position: PyTree, alpha: float = 0.01): | ||
momentum = generate_gaussian_noise(rng_key, position) | ||
return SGNHTState(position, momentum, alpha) | ||
|
||
|
||
def build_kernel(alpha: float = 0.01, beta: float = 0) -> Callable: | ||
"""Stochastic gradient Nosé-Hoover Thermostat (SGNHT) algorithm.""" | ||
integrator = diffusions.sgnht(alpha, beta) | ||
|
||
def kernel( | ||
rng_key: PRNGKey, | ||
state: SGNHTState, | ||
grad_estimator: Callable, | ||
minibatch: PyTree, | ||
step_size: float, | ||
temperature: float = 1.0, | ||
) -> PyTree: | ||
position, momentum, xi = state | ||
logdensity_grad = grad_estimator(position, minibatch) | ||
position, momentum, xi = integrator( | ||
rng_key, position, momentum, xi, logdensity_grad, step_size, temperature | ||
) | ||
return SGNHTState(position, momentum, xi) | ||
|
||
return kernel | ||
|
||
|
||
class sgnht: | ||
"""Implements the (basic) user interface for the SGNHT kernel. | ||
The general sgnht kernel (:meth:`blackjax.sgmcmc.sgnht.build_kernel`, alias | ||
`blackjax.sgnht.build_kernel`) can be cumbersome to manipulate. Since most users | ||
only need to specify the kernel parameters at initialization time, we | ||
provide a helper function that specializes the general kernel. | ||
Example | ||
------- | ||
To initialize a SGNHT kernel one needs to specify a schedule function, which | ||
returns a step size at each sampling step, and a gradient estimator | ||
function. Here for a constant step size, and `data_size` data samples: | ||
.. code:: | ||
grad_estimator = blackjax.sgmcmc.gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size) | ||
We can now initialize the sgnht kernel and the state. | ||
.. code:: | ||
sgnht = blackjax.sgnht(grad_estimator) | ||
state = sgnht.init(rng_key, position) | ||
Assuming we have an iterator `batches` that yields batches of data we can | ||
perform one step: | ||
.. code:: | ||
step_size = 1e-3 | ||
minibatch = next(batches) | ||
new_state = sgnht.step(rng_key, state, minibatch, step_size) | ||
Kernels are not jit-compiled by default so you will need to do it manually: | ||
.. code:: | ||
step = jax.jit(sgnht.step) | ||
new_state = step(rng_key, state, minibatch, step_size) | ||
Parameters | ||
---------- | ||
grad_estimator | ||
A function that takes a position, a batch of data and returns an estimation | ||
of the gradient of the log-density at this position. | ||
Returns | ||
------- | ||
A ``MCMCSamplingAlgorithm``. | ||
""" | ||
|
||
init = staticmethod(init) | ||
build_kernel = staticmethod(build_kernel) | ||
|
||
def __new__( # type: ignore[misc] | ||
cls, | ||
grad_estimator: Callable, | ||
) -> MCMCSamplingAlgorithm: | ||
kernel = cls.build_kernel() | ||
|
||
def init_fn(position: PyTree, rng_key: PRNGKey): | ||
return cls.init(rng_key, position) | ||
|
||
def step_fn(rng_key: PRNGKey, state, minibatch: PyTree, step_size: float): | ||
return kernel( | ||
rng_key, | ||
state, | ||
grad_estimator, | ||
minibatch, | ||
step_size, | ||
) | ||
|
||
return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters