Skip to content

Commit

Permalink
Exposing RMH and Random Walk as two different algorithms, generalizin…
Browse files Browse the repository at this point in the history
…g RW to non-gaussian jumps (blackjax-devs#484)

* Exposing RMH and Random Walk as two different sampling algorithms, generalizing Random Walk to non-gassian jumps

* Making rmh depend on proposal and trajectory, removing code duplication. Adding protocol to apply Interface Segregation Principle on proposal.proposal_generator

* assymetric proposal generator

* Moving hmc energy to trajectory.py

* Adding transition aware proposal generator, refactoring so that special cases of RMH are all in the same file, adding docs and tests

* Adding test for proporsal_from_energy_diff

* Adding test for RMH energy

* Renames to be compliant with new convetion, moving kernel parameters from factories to kernel methods
  • Loading branch information
ciguaran authored Mar 20, 2023
1 parent 1908453 commit 7100bca
Show file tree
Hide file tree
Showing 14 changed files with 784 additions and 320 deletions.
2 changes: 2 additions & 0 deletions blackjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .diagnostics import potential_scale_reduction as rhat
from .kernels import (
adaptive_tempered_smc,
additive_step_random_walk,
csgld,
elliptical_slice,
ghmc,
Expand Down Expand Up @@ -34,6 +35,7 @@
"mgrad_gaussian",
"nuts",
"orbital_hmc",
"additive_step_random_walk",
"rmh",
"irmh",
"elliptical_slice",
Expand Down
110 changes: 93 additions & 17 deletions blackjax/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import blackjax.adaptation as adaptation
import blackjax.mcmc as mcmc
import blackjax.mcmc.random_walk
import blackjax.sgmcmc as sgmcmc
import blackjax.smc as smc
import blackjax.vi as vi
Expand All @@ -34,6 +35,7 @@
"nuts",
"ghmc",
"orbital_hmc",
"additive_step_random_walk",
"rmh",
"sgld",
"sghmc",
Expand Down Expand Up @@ -957,17 +959,86 @@ def run(rng_key: PRNGKey, positions: PyTree, num_steps: int = 1000):
return AdaptationAlgorithm(run) # type: ignore[arg-type]


class additive_step_random_walk:
"""Implements the user interface for the Additive Step RMH
Examples
--------
A new kernel can be initialized and used with the following code:
.. code::
rw = blackjax.additive_step_random_walk(logdensity_fn, random_step)
state = rw.init(position)
new_state, info = rw.step(rng_key, state)
The specific case of a Gaussian `random_step` is already implemented, either with independent components
when `covariance_matrix` is a one dimensional array or with dependent components if a two dimensional array:
.. code::
rw_gaussian = blackjax.additive_step_random_walk.normal_random_walk(logdensity_fn, covariance_matrix)
state = rw_gaussian.init(position)
new_state, info = rw_gaussian.step(rng_key, state)
Parameters
----------
logdensity_fn
The log density probability density function from which we wish to sample.
random_step
A Callable that takes a random number generator and the current state and produces a step,
which will be added to the current position to obtain a new position. Must be symmetric
to maintain detailed balance. This means that P(step|position) = P(-step | position+step)
Returns
-------
A ``MCMCSamplingAlgorithm``.
"""

init = staticmethod(blackjax.mcmc.random_walk.init)
build_kernel = staticmethod(blackjax.mcmc.random_walk.build_additive_step)

@classmethod
def normal_random_walk(cls, logdensity_fn: Callable, sigma):
"""
Parameters
----------
logdensity_fn
The log density probability density function from which we wish to sample.
sigma
The value of the covariance matrix of the gaussian proposal distribution.
Returns
-------
A ``MCMCSamplingAlgorithm``.
"""
return cls(logdensity_fn, blackjax.mcmc.random_walk.normal(sigma))

def __new__( # type: ignore[misc]
cls, logdensity_fn: Callable, random_step: Callable
) -> MCMCSamplingAlgorithm:
kernel = cls.build_kernel()

def init_fn(position: PyTree):
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
return kernel(rng_key, random_step, state, logdensity_fn)

return MCMCSamplingAlgorithm(init_fn, step_fn)


class rmh:
"""Implements the (basic) user interface for the gaussian random walk kernel
"""Implements the user interface for the RMH.
Examples
--------
A new Gaussian Random Walk kernel can be initialized and used with the following code:
A new kernel can be initialized and used with the following code:
.. code::
rmh = blackjax.rmh(logdensity_fn, sigma)
rmh = blackjax.rmh(logdensity_fn, proposal_generator)
state = rmh.init(position)
new_state, info = rmh.step(rng_key, state)
Expand All @@ -982,34 +1053,39 @@ class rmh:
----------
logdensity_fn
The log density probability density function from which we wish to sample.
sigma
The value of the covariance matrix of the gaussian proposal distribution.
proposal_generator
A Callable that takes a random number generator and the current state and produces a new proposal.
proposal_logdensity_fn
The logdensity function associated to the proposal_generator. If the generator is non-symmetric,
P(x_t|x_t-1) is not equal to P(x_t-1|x_t), then this parameter must be not None in order to apply
the Metropolis-Hastings correction for detailed balance.
Returns
-------
A ``MCMCSamplingAlgorithm``.
"""

init = staticmethod(mcmc.rmh.init)
kernel = staticmethod(mcmc.rmh.kernel)
init = staticmethod(blackjax.mcmc.random_walk.init)
build_kernel = staticmethod(blackjax.mcmc.random_walk.build_rmh)

def __new__( # type: ignore[misc]
cls,
logdensity_fn: Callable,
sigma: Array,
proposal_generator: Callable[[PRNGKey, PyTree], PyTree],
proposal_logdensity_fn: Optional[Callable[[PyTree], PyTree]] = None,
) -> MCMCSamplingAlgorithm:
step = cls.kernel()
kernel = cls.build_kernel()

def init_fn(position: PyTree):
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
return step(
return kernel(
rng_key,
state,
logdensity_fn,
sigma,
proposal_generator,
state,
proposal_logdensity_fn,
)

return MCMCSamplingAlgorithm(init_fn, step_fn)
Expand Down Expand Up @@ -1050,21 +1126,21 @@ class irmh:
"""

init = staticmethod(mcmc.rmh.init)
kernel = staticmethod(mcmc.irmh.kernel)
init = staticmethod(blackjax.mcmc.random_walk.init)
build_kernel = staticmethod(blackjax.mcmc.random_walk.build_irmh)

def __new__( # type: ignore[misc]
cls,
logdensity_fn: Callable,
proposal_distribution: Callable,
) -> MCMCSamplingAlgorithm:
step = cls.kernel(proposal_distribution)
kernel = cls.build_kernel()

def init_fn(position: PyTree):
return cls.init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
return step(rng_key, state, logdensity_fn)
return kernel(rng_key, proposal_distribution, state, logdensity_fn)

return MCMCSamplingAlgorithm(init_fn, step_fn)

Expand Down
6 changes: 2 additions & 4 deletions blackjax/mcmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
elliptical_slice,
ghmc,
hmc,
irmh,
mala,
marginal_latent_gaussian,
nuts,
periodic_orbital,
rmh,
random_walk,
)

__all__ = [
Expand All @@ -17,7 +16,6 @@
"mala",
"nuts",
"periodic_orbital",
"rmh",
"marginal_latent_gaussian",
"irmh",
"random_walk",
]
3 changes: 2 additions & 1 deletion blackjax/mcmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import blackjax.mcmc.metrics as metrics
import blackjax.mcmc.proposal as proposal
import blackjax.mcmc.trajectory as trajectory
from blackjax.mcmc.trajectory import hmc_energy
from blackjax.types import Array, PRNGKey, PyTree

__all__ = ["HMCState", "HMCInfo", "init", "kernel"]
Expand Down Expand Up @@ -180,7 +181,7 @@ def hmc_proposal(
"""
build_trajectory = trajectory.static_integration(integrator)
init_proposal, generate_proposal = proposal.proposal_generator(
kinetic_energy, divergence_threshold
hmc_energy(kinetic_energy), divergence_threshold
)

def generate(
Expand Down
51 changes: 0 additions & 51 deletions blackjax/mcmc/irmh.py

This file was deleted.

Loading

0 comments on commit 7100bca

Please sign in to comment.