Skip to content

Commit

Permalink
Merge branch 'main' into smc_compatibility_Test
Browse files Browse the repository at this point in the history
  • Loading branch information
albcab authored Apr 20, 2023
2 parents 8cb7d45 + a650f9b commit 0583dad
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 61 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: 3.7
python-version: 3.8
- name: Give PyPI some time to update the index
run: sleep 240
- name: Attempt install from PyPI
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- style
strategy:
matrix:
python-version: [ '3.7', '3.10']
python-version: [ '3.8', '3.10']
steps:
- uses: actions/checkout@v1
- name: Set up Python ${{ matrix.python-version }}
Expand Down
21 changes: 12 additions & 9 deletions blackjax/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,6 @@ class sgld:
.. code::
sgld = blackjax.sgld(grad_fn)
state = sgld.init(position)
Assuming we have an iterator `batches` that yields batches of data we can
perform one step:
Expand All @@ -544,14 +543,14 @@ class sgld:
step_size = 1e-3
minibatch = next(batches)
new_state = sgld.step(rng_key, state, minibatch, step_size)
new_position = sgld.step(rng_key, position, minibatch, step_size)
Kernels are not jit-compiled by default so you will need to do it manually:
.. code::
step = jax.jit(sgld.step)
new_state, info = step(rng_key, state, minibatch, step_size)
new_position, info = step(rng_key, position, minibatch, step_size)
Parameters
----------
Expand Down Expand Up @@ -611,7 +610,6 @@ class sghmc:
.. code::
sghmc = blackjax.sghmc(grad_estimator, num_integration_steps)
state = sghmc.init(position)
Assuming we have an iterator `batches` that yields batches of data we can
perform one step:
Expand All @@ -620,14 +618,14 @@ class sghmc:
step_size = 1e-3
minibatch = next(batches)
new_state = sghmc.step(rng_key, state, minibatch, step_size)
new_position = sghmc.step(rng_key, position, minibatch, step_size)
Kernels are not jit-compiled by default so you will need to do it manually:
.. code::
step = jax.jit(sghmc.step)
new_state, info = step(rng_key, state, minibatch, step_size)
new_position, info = step(rng_key, position, minibatch, step_size)
Parameters
----------
Expand Down Expand Up @@ -668,9 +666,12 @@ class csgld:
Parameters
----------
logdensity_estimator_fn
logdensity_estimator
A function that returns an estimation of the model's logdensity given
a position and a batch of data.
gradient_estimator
A function that takes a position, a batch of data and returns an estimation
of the gradient of the log-density at this position.
zeta
Hyperparameter that controls the geometric property of the flattened
density. If `zeta=0` the function reduces to the SGLD step function.
Expand Down Expand Up @@ -700,7 +701,8 @@ class csgld:

def __new__( # type: ignore[misc]
cls,
logdensity_estimator_fn: Callable,
logdensity_estimator: Callable,
gradient_estimator: Callable,
zeta: float = 1,
temperature: float = 0.01,
num_partitions: int = 512,
Expand All @@ -722,7 +724,8 @@ def step_fn(
return step(
rng_key,
state,
logdensity_estimator_fn,
logdensity_estimator,
gradient_estimator,
minibatch,
step_size_diff,
step_size_stoch,
Expand Down
34 changes: 15 additions & 19 deletions blackjax/mcmc/mala.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import jax.numpy as jnp

import blackjax.mcmc.diffusions as diffusions
import blackjax.mcmc.proposal as proposal
from blackjax.types import PRNGKey, PyTree

__all__ = ["MALAState", "MALAInfo", "init", "kernel"]
Expand Down Expand Up @@ -74,8 +75,8 @@ def kernel():
"""

def transition_probability(state, new_state, step_size):
"""Transition probability to go from `state` to `new_state`"""
def transition_energy(state, new_state, step_size):
"""Transition energy to go from `state` to `new_state`"""
theta = jax.tree_util.tree_map(
lambda new_x, x, g: new_x - x - step_size * g,
new_state.position,
Expand All @@ -85,7 +86,12 @@ def transition_probability(state, new_state, step_size):
theta_dot = jax.tree_util.tree_reduce(
operator.add, jax.tree_util.tree_map(lambda x: jnp.sum(x * x), theta)
)
return -0.25 * (1.0 / step_size) * theta_dot
return -state.logdensity + 0.25 * (1.0 / step_size) * theta_dot

init_proposal, generate_proposal = proposal.asymmetric_proposal_generator(
transition_energy, divergence_threshold=jnp.inf
)
sample_proposal = proposal.static_binomial_sampling

def one_step(
rng_key: PRNGKey, state: MALAState, logdensity_fn: Callable, step_size: float
Expand All @@ -97,26 +103,16 @@ def one_step(
key_integrator, key_rmh = jax.random.split(rng_key)

new_state = integrator(key_integrator, state, step_size)
new_state = MALAState(*new_state)

delta = (
new_state.logdensity
- state.logdensity
+ transition_probability(new_state, state, step_size)
- transition_probability(state, new_state, step_size)
proposal = init_proposal(state)
new_proposal, _ = generate_proposal(state, new_state, step_size=step_size)
sampled_proposal, do_accept, p_accept = sample_proposal(
key_rmh, proposal, new_proposal
)
delta = jnp.where(jnp.isnan(delta), -jnp.inf, delta)
p_accept = jnp.clip(jnp.exp(delta), a_max=1)

do_accept = jax.random.bernoulli(key_rmh, p_accept)

new_state = MALAState(*new_state)
info = MALAInfo(p_accept, do_accept)

return jax.lax.cond(
do_accept,
lambda _: (new_state, info),
lambda _: (state, info),
operand=None,
)
return sampled_proposal.state, info

return one_step
44 changes: 22 additions & 22 deletions blackjax/mcmc/proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import jax
import jax.numpy as jnp
import numpy as np

TrajectoryState = NamedTuple

Expand Down Expand Up @@ -49,18 +48,18 @@ def proposal_generator(
Parameters
----------
energy
A callable that computes the energy associated to a given state
A function that computes the energy associated to a given state
divergence_threshold
max value allowed for the difference in energies not to be considered a divergence
max value allowed for the difference in energies not to be considered a divergence
Returns
-------
Two callables, to generate an initial proposal when no step has been taken,
and to generate proposals after each step.
Two functions, one to generate an initial proposal when no step has been taken,
another to generate proposals after each step.
"""

def new(state: TrajectoryState) -> Proposal:
return Proposal(state, energy(state), 0.0, -np.inf)
return Proposal(state, energy(state), 0.0, -jnp.inf)

def update(initial_energy: float, state: TrajectoryState) -> Tuple[Proposal, bool]:
"""Generate a new proposal from a trajectory state.
Expand Down Expand Up @@ -103,13 +102,13 @@ def proposal_from_energy_diff(
Parameters
----------
initial_energy
the energy from the previous state
the energy from the initial state
new_energy
the energy at the new state
the energy at the proposed state
divergence_threshold
max value allowed for the difference in energies not to be considered a divergence
max value allowed for the difference in energies not to be considered a divergence
state
the state to propose
the proposed state
Returns
-------
Expand Down Expand Up @@ -139,36 +138,37 @@ def proposal_from_energy_diff(
def asymmetric_proposal_generator(
transition_energy_fn: Callable,
divergence_threshold: float,
proposal_factory=proposal_from_energy_diff,
proposal_factory: Callable = proposal_from_energy_diff,
) -> Tuple[Callable, Callable]:
"""A proposal generator that takes into account the transition between
two states to compute a new proposal. In particular, both states are
used to compute the energies to consider in weighting the proposal,
to account for asymmetries.
----------
transition_energy_fn
A Callable that computes the energy of a associated with a transition
from one state to another
A function that computes the energy of a transition from an initial state
to a new state, given some optional keyword arguments.
divergence_threshold
A max number to will be used by the proposal_factory to flag a Proposal
as a divergence.
The maximum value allowed for the difference in energies not to be considered a divergence.
proposal_factory
A callable that builds a proposal from the transitions energies
A function that builds a proposal from the transition energies.
Returns
-------
Two callables, to generate an initial proposal when no step has been taken,
and to generate proposals after each step.
Two functions, one to generate an initial proposal when no step has been taken,
another to generate proposals after each step.
"""

def new(state: TrajectoryState) -> Proposal:
return Proposal(state, 0.0, 0.0, -np.inf)
return Proposal(state, 0.0, 0.0, -jnp.inf)

def update(
initial_state: TrajectoryState, state: TrajectoryState
initial_state: TrajectoryState,
state: TrajectoryState,
**energy_params,
) -> Tuple[Proposal, bool]:
new_energy = transition_energy_fn(initial_state, state)
prev_energy = transition_energy_fn(state, initial_state)
new_energy = transition_energy_fn(initial_state, state, **energy_params)
prev_energy = transition_energy_fn(state, initial_state, **energy_params)
return proposal_factory(prev_energy, new_energy, divergence_threshold, state)

return new, update
Expand Down
12 changes: 8 additions & 4 deletions blackjax/sgmcmc/csgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def kernel(num_partitions=512, energy_gap=10, min_energy=0) -> Callable:
def one_step(
rng_key: PRNGKey,
state: ContourSGLDState,
logdensity_estimator_fn: Callable,
logdensity_estimator: Callable,
gradient_estimator: Callable,
minibatch: PyTree,
step_size_diff: float, # step size for Langevin diffusion
step_size_stoch: float = 1e-3, # step size for stochastic approximation
Expand Down Expand Up @@ -95,9 +96,12 @@ def one_step(
State of the pseudo-random number generator.
state
Current state of the CSGLD sampler
logdensity_estimator_fn
logdensity_estimator
Function that returns an estimation of the value of the density
function at the current position.
gradient_estimator
A function that takes a position, a batch of data and returns an estimation
of the gradient of the log-density at this position.
minibatch
Minibatch of data.
step_size_diff
Expand All @@ -123,7 +127,7 @@ def one_step(
/ energy_gap
)

logprob_grad = jax.grad(logdensity_estimator_fn)(position, minibatch)
logprob_grad = gradient_estimator(position, minibatch)
position = integrator(
rng_key,
position,
Expand All @@ -133,7 +137,7 @@ def one_step(
)

# Update the stochastic approximation to the energy histogram
neg_logprob = -logdensity_estimator_fn(position, minibatch)
neg_logprob = -logdensity_estimator(position, minibatch)
idx = jax.lax.min(
jax.lax.max(
jax.lax.floor((neg_logprob - min_energy) / energy_gap + 1).astype(
Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
name = "blackjax"
authors= [{name = "The Blackjax team", email = "[email protected]"}]
description = "Flexible and fast sampling in Python"
requires-python = ">=3.7"
requires-python = ">=3.8"
keywords=[
"probability",
"machine learning",
Expand All @@ -22,7 +22,6 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
"Operating System :: MacOS",
"Operating System :: POSIX",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
Expand All @@ -36,7 +35,7 @@ dependencies = [
"jax>=0.3.13",
"jaxlib>=0.3.10",
"jaxopt>=0.5.5",
"optax",
"optax@git+https://github.com/deepmind/optax.git",
"typing-extensions>=4.4.0",
]
dynamic = ["version"]
Expand Down
2 changes: 1 addition & 1 deletion requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jupytext
myst_nb
numba
numpyro
optax
optax@git+https://github.com/deepmind/optax.git
oryx
pymc
scikit-learn
Expand Down
5 changes: 4 additions & 1 deletion tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,10 @@ def test_linear_regression_contour_sgld(self):
logdensity_fn = blackjax.sgmcmc.logdensity_estimator(
self.logprior_fn, self.loglikelihood_fn, data_size
)
csgld = blackjax.csgld(logdensity_fn)
grad_fn = blackjax.sgmcmc.grad_estimator(
self.logprior_fn, self.loglikelihood_fn, data_size
)
csgld = blackjax.csgld(logdensity_fn, grad_fn)

_, rng_key = jax.random.split(rng_key)
data_batch = X_data[:100, :]
Expand Down

0 comments on commit 0583dad

Please sign in to comment.