Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the function for generating Gaussian noise #377

Merged
merged 5 commits into from
Oct 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions blackjax/adaptation/window_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
dual_averaging_adaptation,
)
from blackjax.types import Array, PyTree
from blackjax.util import pytree_size

__all__ = ["base", "schedule"]

Expand Down Expand Up @@ -97,8 +98,7 @@ def init(position: PyTree, initial_step_size: float) -> Tuple:
We may reconsider this choice in the future.

"""
flat_position, _ = jax.flatten_util.ravel_pytree(position)
num_dimensions = flat_position.shape[-1]
num_dimensions = pytree_size(position)
imm_state = mm_init(num_dimensions)

ss_state = da_init(initial_step_size)
Expand Down
9 changes: 2 additions & 7 deletions blackjax/mcmc/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import jax
import jax.numpy as jnp

from blackjax.types import PRNGKey, PyTree
from blackjax.types import PyTree
from blackjax.util import generate_gaussian_noise

__all__ = ["overdamped_langevin"]

Expand All @@ -15,12 +16,6 @@ class DiffusionState(NamedTuple):
logprob_grad: PyTree


def generate_gaussian_noise(rng_key: PRNGKey, position):
position_flat, unravel_fn = jax.flatten_util.ravel_pytree(position)
noise_flat = jax.random.normal(rng_key, shape=jnp.shape(position_flat))
return unravel_fn(noise_flat)


def overdamped_langevin(logprob_and_grad_fn):
"""Euler solver for overdamped Langevin diffusion."""

Expand Down
11 changes: 3 additions & 8 deletions blackjax/mcmc/elliptical_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax.numpy as jnp

from blackjax.types import Array, PRNGKey, PyTree
from blackjax.util import generate_gaussian_noise

__all__ = ["EllipSliceState", "EllipSliceInfo", "init", "kernel"]

Expand Down Expand Up @@ -76,11 +77,9 @@ def kernel(cov_matrix: Array, mean: Array):

if ndim == 1: # diagonal covariance matrix
cov_matrix_sqrt = jnp.sqrt(cov_matrix)
dot = jnp.multiply

elif ndim == 2:
cov_matrix_sqrt = jax.lax.linalg.cholesky(cov_matrix)
dot = jnp.dot

else:
raise ValueError(
Expand All @@ -89,11 +88,7 @@ def kernel(cov_matrix: Array, mean: Array):
)

def momentum_generator(rng_key, position):
p, unravel_fn = jax.flatten_util.ravel_pytree(position)
momentum = mean + dot(
cov_matrix_sqrt, jax.random.normal(rng_key, shape=p.shape)
)
return unravel_fn(momentum)
return generate_gaussian_noise(rng_key, position, mean, cov_matrix_sqrt)

def one_step(
rng_key: PRNGKey,
Expand Down Expand Up @@ -141,7 +136,7 @@ def generate(
key_momentum, key_uniform, key_theta = jax.random.split(rng_key, 3)
# step 1: sample momentum
momentum = momentum_generator(key_momentum, position)
# #step 2: get slice (y)
# step 2: get slice (y)
logy = loglikelihood + jnp.log(jax.random.uniform(key_uniform))
# step 3: get theta (ellipsis move), set inital interval
theta = 2 * jnp.pi * jax.random.uniform(key_theta)
Expand Down
11 changes: 4 additions & 7 deletions blackjax/mcmc/ghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

import blackjax.mcmc.hmc as hmc
import blackjax.mcmc.integrators as integrators
import blackjax.mcmc.metrics as metrics
import blackjax.mcmc.proposal as proposal
from blackjax.types import PRNGKey, PyTree
from blackjax.util import generate_gaussian_noise, pytree_size

__all__ = ["GHMCState", "init", "kernel"]

Expand Down Expand Up @@ -44,9 +44,8 @@ def potential_fn(x):

potential_energy, potential_energy_grad = jax.value_and_grad(potential_fn)(position)

p, unravel_fn = ravel_pytree(position)
key_mometum, key_slice = jax.random.split(rng_key)
momentum = unravel_fn(jax.random.normal(key_mometum, p.shape))
momentum = generate_gaussian_noise(key_mometum, position)
slice = jax.random.uniform(key_slice, minval=-1.0, maxval=1.0)

return GHMCState(position, momentum, potential_energy, potential_energy_grad, slice)
Expand Down Expand Up @@ -220,10 +219,8 @@ def update_momentum(rng_key, state, alpha):

position, momentum, *_ = state

m, _ = ravel_pytree(momentum)
momentum_generator, *_ = metrics.gaussian_euclidean(
1 / alpha * jnp.ones(jnp.shape(m))
)
m_size = pytree_size(momentum)
momentum_generator, *_ = metrics.gaussian_euclidean(1 / alpha * jnp.ones((m_size,)))
momentum = jax.tree_map(
lambda prev_momentum, shifted_momentum: prev_momentum * jnp.sqrt(1.0 - alpha)
+ shifted_momentum,
Expand Down
8 changes: 5 additions & 3 deletions blackjax/mcmc/mala.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Public API for Metropolis Adjusted Langevin kernels."""
import operator
from typing import Callable, NamedTuple, Tuple

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

from blackjax.mcmc.diffusion import overdamped_langevin
from blackjax.types import PRNGKey, PyTree
Expand Down Expand Up @@ -69,8 +69,10 @@ def transition_probability(state, new_state, step_size):
state.position,
state.logprob_grad,
)
theta_ravel, _ = ravel_pytree(theta)
return -0.25 * (1.0 / step_size) * jnp.dot(theta_ravel, theta_ravel)
theta_dot = jax.tree_util.tree_reduce(
operator.add, jax.tree_util.tree_map(lambda x: jnp.sum(x * x), theta)
)
return -0.25 * (1.0 / step_size) * theta_dot

def one_step(
rng_key: PRNGKey, state: MALAState, logprob_fn: Callable, step_size: float
Expand Down
12 changes: 4 additions & 8 deletions blackjax/mcmc/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
"""
from typing import Callable, Tuple

import jax
import jax.numpy as jnp
import jax.scipy as jscipy
from jax.flatten_util import ravel_pytree

from blackjax.types import Array, PRNGKey, PyTree
from blackjax.util import generate_gaussian_noise

__all__ = ["gaussian_euclidean"]

Expand Down Expand Up @@ -79,15 +79,15 @@ def gaussian_euclidean(

if ndim == 1: # diagonal mass matrix
mass_matrix_sqrt = jnp.sqrt(jnp.reciprocal(inverse_mass_matrix))
dot, matmul = jnp.multiply, jnp.multiply
matmul = jnp.multiply

elif ndim == 2:
tril_inv = jscipy.linalg.cholesky(inverse_mass_matrix)
identity = jnp.identity(shape[0])
mass_matrix_sqrt = jscipy.linalg.solve_triangular(
tril_inv, identity, lower=True
)
dot, matmul = jnp.dot, jnp.matmul
matmul = jnp.matmul

else:
raise ValueError(
Expand All @@ -96,11 +96,7 @@ def gaussian_euclidean(
)

def momentum_generator(rng_key: PRNGKey, position: PyTree) -> PyTree:
_, unravel_fn = ravel_pytree(position)
standard_normal_sample = jax.random.normal(rng_key, shape)
momentum = dot(mass_matrix_sqrt, standard_normal_sample)
momentum_unravel = unravel_fn(momentum)
return momentum_unravel
return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt)

def kinetic_energy(momentum: PyTree) -> float:
momentum, _ = ravel_pytree(momentum)
Expand Down
18 changes: 4 additions & 14 deletions blackjax/mcmc/rmh.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from jax import numpy as jnp

from blackjax.types import Array, PRNGKey, PyTree
from blackjax.util import generate_gaussian_noise

__all__ = ["RMHState", "RMHInfo", "init", "kernel"]

Expand Down Expand Up @@ -190,21 +191,10 @@ def normal(sigma: Array) -> Callable:
normal distribution from which we draw the move proposals.

"""
ndim = jnp.ndim(sigma) # type: ignore[arg-type]
shape = jnp.shape(jnp.atleast_1d(sigma))[:1]

if ndim == 1:
dot = jnp.multiply
elif ndim == 2:
dot = jnp.dot
else:
raise ValueError
if jnp.ndim(sigma) > 2:
raise ValueError("sigma must be a vector or a matrix.")

def propose(rng_key: PRNGKey, position: PyTree) -> PyTree:
_, unravel_fn = jax.flatten_util.ravel_pytree(position)
sample = jax.random.normal(rng_key, shape)
move_sample = dot(sigma, sample)
move_unravel = unravel_fn(move_sample)
return move_unravel
return generate_gaussian_noise(rng_key, position, sigma=sigma)

return propose
7 changes: 1 addition & 6 deletions blackjax/sgmcmc/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax.numpy as jnp

from blackjax.types import PRNGKey, PyTree
from blackjax.util import generate_gaussian_noise

__all__ = ["overdamped_langevin"]

Expand All @@ -14,12 +15,6 @@ class DiffusionState(NamedTuple):
logprob_grad: PyTree


def generate_gaussian_noise(rng_key: PRNGKey, position: PyTree):
position_flat, unravel_fn = jax.flatten_util.ravel_pytree(position)
noise_flat = jax.random.normal(rng_key, shape=jnp.shape(position_flat))
return unravel_fn(noise_flat)


def overdamped_langevin(logprob_grad_fn):
"""Euler solver for overdamped Langevin diffusion."""

Expand Down
11 changes: 2 additions & 9 deletions blackjax/sgmcmc/sghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,11 @@
from blackjax.sgmcmc.diffusion import SGHMCState, sghmc
from blackjax.sgmcmc.sgld import SGLDState
from blackjax.types import PRNGKey, PyTree
from blackjax.util import generate_gaussian_noise

__all__ = ["kernel"]


def sample_momentum(rng_key: PRNGKey, position: PyTree, step_size: float):
position_flat, unravel_fn = jax.flatten_util.ravel_pytree(position)
noise_flat = jnp.sqrt(step_size) * jax.random.normal(
rng_key, shape=jnp.shape(position_flat)
)
return unravel_fn(noise_flat)


def kernel(
grad_estimator_fn: Callable, alpha: float = 0.01, beta: float = 0
) -> Callable:
Expand All @@ -29,7 +22,7 @@ def one_step(
) -> SGLDState:

step, position, logprob_grad = state
momentum = sample_momentum(rng_key, position, step_size)
momentum = generate_gaussian_noise(rng_key, position, jnp.sqrt(step_size))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to somehow keep the sample_momentum mention here; noise is introduced somewhere else in the SgHMC algorithm.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think adding a code comment is sufficient.

diffusion_state = SGHMCState(position, momentum, logprob_grad)

def body_fn(state, rng_key):
Expand Down
86 changes: 86 additions & 0 deletions blackjax/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Utility functions for BlackJax."""
from functools import partial
from typing import Union

import jax.numpy as jnp
from jax import jit, lax
from jax._src.numpy.util import _promote_dtypes
from jax.flatten_util import ravel_pytree
from jax.random import normal
from jax.tree_util import tree_leaves

from blackjax.types import Array, PRNGKey, PyTree


@partial(jit, static_argnames=("precision",), inline=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we rename this file pytrees.py which I find more informative that util so you would call pytrees.random_normal for instance ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about pytree_util.py?

def linear_map(diag_or_dense_a, b, *, precision="highest"):
"""Perform a linear map of the form y = Ax.

Dispatch matrix multiplication to either jnp.dot or jnp.multiply.

Unlike jax.numpy.dot, this function output an Array that match the dtype
and shape of the 2nd input:
- diag_or_dense_a is a scalar or 1d vector, `diag_or_dense_a * b` is returned
- diag_or_dense_a is a 2d matrix, `diag_or_dense_a @ b` is returned

Note that unlike jax.numpy.dot, here we defaults to full (highest)
precision. This is more useful for numerical algorithms and will be the
default for jax.numpy in the future:
https://github.com/google/jax/pull/7859

Parameters
----------
diag_or_dense_a:
A diagonal (1d vector) or dense matrix (2d square matrix).
b:
A vector.
precision:
The precision of the computation. See jax.lax.dot_general for
more details.

Returns
-------
The result vector of the matrix multiplication.
"""
diag_or_dense_a, b = _promote_dtypes(diag_or_dense_a, b)
ndim = jnp.ndim(diag_or_dense_a)

if ndim <= 1:
return lax.mul(diag_or_dense_a, b)
else:
return lax.dot(diag_or_dense_a, b, precision=precision)


# TODO(https://github.com/blackjax-devs/blackjax/issues/376)
# Refactor this function to not use ravel_pytree might be more performant.
def generate_gaussian_noise(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can give a better name than gaussian_noise (which I assume you took from the SgMCMC algorithms)? Like random_normal?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 to random_normal

rng_key: PRNGKey,
position: PyTree,
mu: Union[float, Array] = 0.0,
sigma: Union[float, Array] = 1.0,
) -> PyTree:
"""Generate N(mu, sigma) noise with output structure that match a given PyTree.

Parameters
----------
rng_key:
The pseudo-random number generator key used to generate random numbers.
position:
PyTree that the structure the output should to match.
mu:
The mean of the Gaussian distribution.
sigma:
The standard deviation of the Gaussian distribution.

Returns
-------
Gaussian noise following N(mu, sigma) that match the structure of position.
"""
p, unravel_fn = ravel_pytree(position)
sample = normal(rng_key, shape=p.shape, dtype=p.dtype)
return unravel_fn(mu + linear_map(sigma, sample))


def pytree_size(pytree: PyTree) -> int:
"""Return the dimension of the flatten PyTree."""
return sum(jnp.size(value) for value in tree_leaves(pytree))
3 changes: 2 additions & 1 deletion tests/test_proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def test_normal_univariate(self):
proposal(key, jnp.array([10.0])) for key in jax.random.split(self.key, 100)
]
samples_from_another_position = [
proposal(key, jnp.array([15000])) for key in jax.random.split(self.key, 100)
proposal(key, jnp.array([15000.0]))
for key in jax.random.split(self.key, 100)
]

for samples in [samples_from_initial_position, samples_from_another_position]:
Expand Down