Skip to content

Commit

Permalink
Generate Gaussian noise with the same structure as the input PyTree (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao committed Mar 12, 2024
1 parent 9c3af23 commit 573a4f0
Show file tree
Hide file tree
Showing 11 changed files with 115 additions and 65 deletions.
4 changes: 2 additions & 2 deletions blackjax/adaptation/window_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
dual_averaging_adaptation,
)
from blackjax.types import Array, PyTree
from blackjax.util import pytree_size

__all__ = ["base", "schedule"]

Expand Down Expand Up @@ -97,8 +98,7 @@ def init(position: PyTree, initial_step_size: float) -> Tuple:
We may reconsider this choice in the future.
"""
flat_position, _ = jax.flatten_util.ravel_pytree(position)
num_dimensions = flat_position.shape[-1]
num_dimensions = pytree_size(position)
imm_state = mm_init(num_dimensions)

ss_state = da_init(initial_step_size)
Expand Down
9 changes: 2 additions & 7 deletions blackjax/mcmc/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import jax
import jax.numpy as jnp

from blackjax.types import PRNGKey, PyTree
from blackjax.types import PyTree
from blackjax.util import generate_gaussian_noise

__all__ = ["overdamped_langevin"]

Expand All @@ -15,12 +16,6 @@ class DiffusionState(NamedTuple):
logprob_grad: PyTree


def generate_gaussian_noise(rng_key: PRNGKey, position):
position_flat, unravel_fn = jax.flatten_util.ravel_pytree(position)
noise_flat = jax.random.normal(rng_key, shape=jnp.shape(position_flat))
return unravel_fn(noise_flat)


def overdamped_langevin(logprob_and_grad_fn):
"""Euler solver for overdamped Langevin diffusion."""

Expand Down
11 changes: 3 additions & 8 deletions blackjax/mcmc/elliptical_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax.numpy as jnp

from blackjax.types import Array, PRNGKey, PyTree
from blackjax.util import generate_gaussian_noise

__all__ = ["EllipSliceState", "EllipSliceInfo", "init", "kernel"]

Expand Down Expand Up @@ -76,11 +77,9 @@ def kernel(cov_matrix: Array, mean: Array):

if ndim == 1: # diagonal covariance matrix
cov_matrix_sqrt = jnp.sqrt(cov_matrix)
dot = jnp.multiply

elif ndim == 2:
cov_matrix_sqrt = jax.lax.linalg.cholesky(cov_matrix)
dot = jnp.dot

else:
raise ValueError(
Expand All @@ -89,11 +88,7 @@ def kernel(cov_matrix: Array, mean: Array):
)

def momentum_generator(rng_key, position):
p, unravel_fn = jax.flatten_util.ravel_pytree(position)
momentum = mean + dot(
cov_matrix_sqrt, jax.random.normal(rng_key, shape=p.shape)
)
return unravel_fn(momentum)
return generate_gaussian_noise(rng_key, position, mean, cov_matrix_sqrt)

def one_step(
rng_key: PRNGKey,
Expand Down Expand Up @@ -141,7 +136,7 @@ def generate(
key_momentum, key_uniform, key_theta = jax.random.split(rng_key, 3)
# step 1: sample momentum
momentum = momentum_generator(key_momentum, position)
# #step 2: get slice (y)
# step 2: get slice (y)
logy = loglikelihood + jnp.log(jax.random.uniform(key_uniform))
# step 3: get theta (ellipsis move), set inital interval
theta = 2 * jnp.pi * jax.random.uniform(key_theta)
Expand Down
11 changes: 4 additions & 7 deletions blackjax/mcmc/ghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

import blackjax.mcmc.hmc as hmc
import blackjax.mcmc.integrators as integrators
import blackjax.mcmc.metrics as metrics
import blackjax.mcmc.proposal as proposal
from blackjax.types import PRNGKey, PyTree
from blackjax.util import generate_gaussian_noise, pytree_size

__all__ = ["GHMCState", "init", "kernel"]

Expand Down Expand Up @@ -44,9 +44,8 @@ def potential_fn(x):

potential_energy, potential_energy_grad = jax.value_and_grad(potential_fn)(position)

p, unravel_fn = ravel_pytree(position)
key_mometum, key_slice = jax.random.split(rng_key)
momentum = unravel_fn(jax.random.normal(key_mometum, p.shape))
momentum = generate_gaussian_noise(key_mometum, position)
slice = jax.random.uniform(key_slice, minval=-1.0, maxval=1.0)

return GHMCState(position, momentum, potential_energy, potential_energy_grad, slice)
Expand Down Expand Up @@ -220,10 +219,8 @@ def update_momentum(rng_key, state, alpha):

position, momentum, *_ = state

m, _ = ravel_pytree(momentum)
momentum_generator, *_ = metrics.gaussian_euclidean(
1 / alpha * jnp.ones(jnp.shape(m))
)
m_size = pytree_size(momentum)
momentum_generator, *_ = metrics.gaussian_euclidean(1 / alpha * jnp.ones((m_size,)))
momentum = jax.tree_map(
lambda prev_momentum, shifted_momentum: prev_momentum * jnp.sqrt(1.0 - alpha)
+ shifted_momentum,
Expand Down
8 changes: 5 additions & 3 deletions blackjax/mcmc/mala.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Public API for Metropolis Adjusted Langevin kernels."""
import operator
from typing import Callable, NamedTuple, Tuple

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

from blackjax.mcmc.diffusion import overdamped_langevin
from blackjax.types import PRNGKey, PyTree
Expand Down Expand Up @@ -69,8 +69,10 @@ def transition_probability(state, new_state, step_size):
state.position,
state.logprob_grad,
)
theta_ravel, _ = ravel_pytree(theta)
return -0.25 * (1.0 / step_size) * jnp.dot(theta_ravel, theta_ravel)
theta_dot = jax.tree_util.tree_reduce(
operator.add, jax.tree_util.tree_map(lambda x: jnp.sum(x * x), theta)
)
return -0.25 * (1.0 / step_size) * theta_dot

def one_step(
rng_key: PRNGKey, state: MALAState, logprob_fn: Callable, step_size: float
Expand Down
12 changes: 4 additions & 8 deletions blackjax/mcmc/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
"""
from typing import Callable, Tuple

import jax
import jax.numpy as jnp
import jax.scipy as jscipy
from jax.flatten_util import ravel_pytree

from blackjax.types import Array, PRNGKey, PyTree
from blackjax.util import generate_gaussian_noise

__all__ = ["gaussian_euclidean"]

Expand Down Expand Up @@ -79,15 +79,15 @@ def gaussian_euclidean(

if ndim == 1: # diagonal mass matrix
mass_matrix_sqrt = jnp.sqrt(jnp.reciprocal(inverse_mass_matrix))
dot, matmul = jnp.multiply, jnp.multiply
matmul = jnp.multiply

elif ndim == 2:
tril_inv = jscipy.linalg.cholesky(inverse_mass_matrix)
identity = jnp.identity(shape[0])
mass_matrix_sqrt = jscipy.linalg.solve_triangular(
tril_inv, identity, lower=True
)
dot, matmul = jnp.dot, jnp.matmul
matmul = jnp.matmul

else:
raise ValueError(
Expand All @@ -96,11 +96,7 @@ def gaussian_euclidean(
)

def momentum_generator(rng_key: PRNGKey, position: PyTree) -> PyTree:
_, unravel_fn = ravel_pytree(position)
standard_normal_sample = jax.random.normal(rng_key, shape)
momentum = dot(mass_matrix_sqrt, standard_normal_sample)
momentum_unravel = unravel_fn(momentum)
return momentum_unravel
return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt)

def kinetic_energy(momentum: PyTree) -> float:
momentum, _ = ravel_pytree(momentum)
Expand Down
18 changes: 4 additions & 14 deletions blackjax/mcmc/rmh.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from jax import numpy as jnp

from blackjax.types import Array, PRNGKey, PyTree
from blackjax.util import generate_gaussian_noise

__all__ = ["RMHState", "RMHInfo", "init", "kernel"]

Expand Down Expand Up @@ -190,21 +191,10 @@ def normal(sigma: Array) -> Callable:
normal distribution from which we draw the move proposals.
"""
ndim = jnp.ndim(sigma) # type: ignore[arg-type]
shape = jnp.shape(jnp.atleast_1d(sigma))[:1]

if ndim == 1:
dot = jnp.multiply
elif ndim == 2:
dot = jnp.dot
else:
raise ValueError
if jnp.ndim(sigma) > 2:
raise ValueError("sigma must be a vector or a matrix.")

def propose(rng_key: PRNGKey, position: PyTree) -> PyTree:
_, unravel_fn = jax.flatten_util.ravel_pytree(position)
sample = jax.random.normal(rng_key, shape)
move_sample = dot(sigma, sample)
move_unravel = unravel_fn(move_sample)
return move_unravel
return generate_gaussian_noise(rng_key, position, sigma=sigma)

return propose
7 changes: 1 addition & 6 deletions blackjax/sgmcmc/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax.numpy as jnp

from blackjax.types import PRNGKey, PyTree
from blackjax.util import generate_gaussian_noise

__all__ = ["overdamped_langevin"]

Expand All @@ -14,12 +15,6 @@ class DiffusionState(NamedTuple):
logprob_grad: PyTree


def generate_gaussian_noise(rng_key: PRNGKey, position: PyTree):
position_flat, unravel_fn = jax.flatten_util.ravel_pytree(position)
noise_flat = jax.random.normal(rng_key, shape=jnp.shape(position_flat))
return unravel_fn(noise_flat)


def overdamped_langevin(logprob_grad_fn):
"""Euler solver for overdamped Langevin diffusion."""

Expand Down
11 changes: 2 additions & 9 deletions blackjax/sgmcmc/sghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,11 @@
from blackjax.sgmcmc.diffusion import SGHMCState, sghmc
from blackjax.sgmcmc.sgld import SGLDState
from blackjax.types import PRNGKey, PyTree
from blackjax.util import generate_gaussian_noise

__all__ = ["kernel"]


def sample_momentum(rng_key: PRNGKey, position: PyTree, step_size: float):
position_flat, unravel_fn = jax.flatten_util.ravel_pytree(position)
noise_flat = jnp.sqrt(step_size) * jax.random.normal(
rng_key, shape=jnp.shape(position_flat)
)
return unravel_fn(noise_flat)


def kernel(
grad_estimator_fn: Callable, alpha: float = 0.01, beta: float = 0
) -> Callable:
Expand All @@ -29,7 +22,7 @@ def one_step(
) -> SGLDState:

step, position, logprob_grad = state
momentum = sample_momentum(rng_key, position, step_size)
momentum = generate_gaussian_noise(rng_key, position, jnp.sqrt(step_size))
diffusion_state = SGHMCState(position, momentum, logprob_grad)

def body_fn(state, rng_key):
Expand Down
86 changes: 86 additions & 0 deletions blackjax/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Utility functions for BlackJax."""
from functools import partial
from typing import Union

import jax.numpy as jnp
from jax import jit, lax
from jax._src.numpy.util import _promote_dtypes
from jax.flatten_util import ravel_pytree
from jax.random import normal
from jax.tree_util import tree_leaves

from blackjax.types import Array, PRNGKey, PyTree


@partial(jit, static_argnames=("precision",), inline=True)
def linear_map(diag_or_dense_a, b, *, precision="highest"):
"""Perform a linear map of the form y = Ax.
Dispatch matrix multiplication to either jnp.dot or jnp.multiply.
Unlike jax.numpy.dot, this function output an Array that match the dtype
and shape of the 2nd input:
- diag_or_dense_a is a scalar or 1d vector, `diag_or_dense_a * b` is returned
- diag_or_dense_a is a 2d matrix, `diag_or_dense_a @ b` is returned
Note that unlike jax.numpy.dot, here we defaults to full (highest)
precision. This is more useful for numerical algorithms and will be the
default for jax.numpy in the future:
https://github.com/google/jax/pull/7859
Parameters
----------
diag_or_dense_a:
A diagonal (1d vector) or dense matrix (2d square matrix).
b:
A vector.
precision:
The precision of the computation. See jax.lax.dot_general for
more details.
Returns
-------
The result vector of the matrix multiplication.
"""
diag_or_dense_a, b = _promote_dtypes(diag_or_dense_a, b)
ndim = jnp.ndim(diag_or_dense_a)

if ndim <= 1:
return lax.mul(diag_or_dense_a, b)
else:
return lax.dot(diag_or_dense_a, b, precision=precision)


# TODO(https://github.com/blackjax-devs/blackjax/issues/376)
# Refactor this function to not use ravel_pytree might be more performant.
def generate_gaussian_noise(
rng_key: PRNGKey,
position: PyTree,
mu: Union[float, Array] = 0.0,
sigma: Union[float, Array] = 1.0,
) -> PyTree:
"""Generate N(mu, sigma) noise with output structure that match a given PyTree.
Parameters
----------
rng_key:
The pseudo-random number generator key used to generate random numbers.
position:
PyTree that the structure the output should to match.
mu:
The mean of the Gaussian distribution.
sigma:
The standard deviation of the Gaussian distribution.
Returns
-------
Gaussian noise following N(mu, sigma) that match the structure of position.
"""
p, unravel_fn = ravel_pytree(position)
sample = normal(rng_key, shape=p.shape, dtype=p.dtype)
return unravel_fn(mu + linear_map(sigma, sample))


def pytree_size(pytree: PyTree) -> int:
"""Return the dimension of the flatten PyTree."""
return sum(jnp.size(value) for value in tree_leaves(pytree))
3 changes: 2 additions & 1 deletion tests/test_proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def test_normal_univariate(self):
proposal(key, jnp.array([10.0])) for key in jax.random.split(self.key, 100)
]
samples_from_another_position = [
proposal(key, jnp.array([15000])) for key in jax.random.split(self.key, 100)
proposal(key, jnp.array([15000.0]))
for key in jax.random.split(self.key, 100)
]

for samples in [samples_from_initial_position, samples_from_another_position]:
Expand Down

0 comments on commit 573a4f0

Please sign in to comment.