diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index e6fc06200..524095855 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -13,6 +13,7 @@ dual_averaging_adaptation, ) from blackjax.types import Array, PyTree +from blackjax.util import pytree_size __all__ = ["base", "schedule"] @@ -97,8 +98,7 @@ def init(position: PyTree, initial_step_size: float) -> Tuple: We may reconsider this choice in the future. """ - flat_position, _ = jax.flatten_util.ravel_pytree(position) - num_dimensions = flat_position.shape[-1] + num_dimensions = pytree_size(position) imm_state = mm_init(num_dimensions) ss_state = da_init(initial_step_size) diff --git a/blackjax/mcmc/diffusion.py b/blackjax/mcmc/diffusion.py index 48ea90e89..e138b6b72 100644 --- a/blackjax/mcmc/diffusion.py +++ b/blackjax/mcmc/diffusion.py @@ -4,7 +4,8 @@ import jax import jax.numpy as jnp -from blackjax.types import PRNGKey, PyTree +from blackjax.types import PyTree +from blackjax.util import generate_gaussian_noise __all__ = ["overdamped_langevin"] @@ -15,12 +16,6 @@ class DiffusionState(NamedTuple): logprob_grad: PyTree -def generate_gaussian_noise(rng_key: PRNGKey, position): - position_flat, unravel_fn = jax.flatten_util.ravel_pytree(position) - noise_flat = jax.random.normal(rng_key, shape=jnp.shape(position_flat)) - return unravel_fn(noise_flat) - - def overdamped_langevin(logprob_and_grad_fn): """Euler solver for overdamped Langevin diffusion.""" diff --git a/blackjax/mcmc/elliptical_slice.py b/blackjax/mcmc/elliptical_slice.py index 3b8d3c9e1..a2721c03f 100644 --- a/blackjax/mcmc/elliptical_slice.py +++ b/blackjax/mcmc/elliptical_slice.py @@ -5,6 +5,7 @@ import jax.numpy as jnp from blackjax.types import Array, PRNGKey, PyTree +from blackjax.util import generate_gaussian_noise __all__ = ["EllipSliceState", "EllipSliceInfo", "init", "kernel"] @@ -76,11 +77,9 @@ def kernel(cov_matrix: Array, mean: Array): if 
ndim == 1: # diagonal covariance matrix cov_matrix_sqrt = jnp.sqrt(cov_matrix) - dot = jnp.multiply elif ndim == 2: cov_matrix_sqrt = jax.lax.linalg.cholesky(cov_matrix) - dot = jnp.dot else: raise ValueError( @@ -89,11 +88,7 @@ def kernel(cov_matrix: Array, mean: Array): ) def momentum_generator(rng_key, position): - p, unravel_fn = jax.flatten_util.ravel_pytree(position) - momentum = mean + dot( - cov_matrix_sqrt, jax.random.normal(rng_key, shape=p.shape) - ) - return unravel_fn(momentum) + return generate_gaussian_noise(rng_key, position, mean, cov_matrix_sqrt) def one_step( rng_key: PRNGKey, @@ -141,7 +136,7 @@ def generate( key_momentum, key_uniform, key_theta = jax.random.split(rng_key, 3) # step 1: sample momentum momentum = momentum_generator(key_momentum, position) - # #step 2: get slice (y) + # step 2: get slice (y) logy = loglikelihood + jnp.log(jax.random.uniform(key_uniform)) # step 3: get theta (ellipsis move), set inital interval theta = 2 * jnp.pi * jax.random.uniform(key_theta) diff --git a/blackjax/mcmc/ghmc.py b/blackjax/mcmc/ghmc.py index 1d68c45ef..821ccba79 100644 --- a/blackjax/mcmc/ghmc.py +++ b/blackjax/mcmc/ghmc.py @@ -3,13 +3,13 @@ import jax import jax.numpy as jnp -from jax.flatten_util import ravel_pytree import blackjax.mcmc.hmc as hmc import blackjax.mcmc.integrators as integrators import blackjax.mcmc.metrics as metrics import blackjax.mcmc.proposal as proposal from blackjax.types import PRNGKey, PyTree +from blackjax.util import generate_gaussian_noise, pytree_size __all__ = ["GHMCState", "init", "kernel"] @@ -44,9 +44,8 @@ def potential_fn(x): potential_energy, potential_energy_grad = jax.value_and_grad(potential_fn)(position) - p, unravel_fn = ravel_pytree(position) key_mometum, key_slice = jax.random.split(rng_key) - momentum = unravel_fn(jax.random.normal(key_mometum, p.shape)) + momentum = generate_gaussian_noise(key_mometum, position) slice = jax.random.uniform(key_slice, minval=-1.0, maxval=1.0) return GHMCState(position, 
momentum, potential_energy, potential_energy_grad, slice) @@ -220,10 +219,8 @@ def update_momentum(rng_key, state, alpha): position, momentum, *_ = state - m, _ = ravel_pytree(momentum) - momentum_generator, *_ = metrics.gaussian_euclidean( - 1 / alpha * jnp.ones(jnp.shape(m)) - ) + m_size = pytree_size(momentum) + momentum_generator, *_ = metrics.gaussian_euclidean(1 / alpha * jnp.ones((m_size,))) momentum = jax.tree_map( lambda prev_momentum, shifted_momentum: prev_momentum * jnp.sqrt(1.0 - alpha) + shifted_momentum, diff --git a/blackjax/mcmc/mala.py b/blackjax/mcmc/mala.py index f809c2d4e..8623e9314 100644 --- a/blackjax/mcmc/mala.py +++ b/blackjax/mcmc/mala.py @@ -1,9 +1,9 @@ """Public API for Metropolis Adjusted Langevin kernels.""" +import operator from typing import Callable, NamedTuple, Tuple import jax import jax.numpy as jnp -from jax.flatten_util import ravel_pytree from blackjax.mcmc.diffusion import overdamped_langevin from blackjax.types import PRNGKey, PyTree @@ -69,8 +69,10 @@ def transition_probability(state, new_state, step_size): state.position, state.logprob_grad, ) - theta_ravel, _ = ravel_pytree(theta) - return -0.25 * (1.0 / step_size) * jnp.dot(theta_ravel, theta_ravel) + theta_dot = jax.tree_util.tree_reduce( + operator.add, jax.tree_util.tree_map(lambda x: jnp.sum(x * x), theta) + ) + return -0.25 * (1.0 / step_size) * theta_dot def one_step( rng_key: PRNGKey, state: MALAState, logprob_fn: Callable, step_size: float diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py index f7d953fd1..ba80c801b 100644 --- a/blackjax/mcmc/metrics.py +++ b/blackjax/mcmc/metrics.py @@ -22,12 +22,12 @@ """ from typing import Callable, Tuple -import jax import jax.numpy as jnp import jax.scipy as jscipy from jax.flatten_util import ravel_pytree from blackjax.types import Array, PRNGKey, PyTree +from blackjax.util import generate_gaussian_noise __all__ = ["gaussian_euclidean"] @@ -79,7 +79,7 @@ def gaussian_euclidean( if ndim == 1: # diagonal mass 
matrix mass_matrix_sqrt = jnp.sqrt(jnp.reciprocal(inverse_mass_matrix)) - dot, matmul = jnp.multiply, jnp.multiply + matmul = jnp.multiply elif ndim == 2: tril_inv = jscipy.linalg.cholesky(inverse_mass_matrix) @@ -87,7 +87,7 @@ def gaussian_euclidean( mass_matrix_sqrt = jscipy.linalg.solve_triangular( tril_inv, identity, lower=True ) - dot, matmul = jnp.dot, jnp.matmul + matmul = jnp.matmul else: raise ValueError( @@ -96,11 +96,7 @@ def gaussian_euclidean( ) def momentum_generator(rng_key: PRNGKey, position: PyTree) -> PyTree: - _, unravel_fn = ravel_pytree(position) - standard_normal_sample = jax.random.normal(rng_key, shape) - momentum = dot(mass_matrix_sqrt, standard_normal_sample) - momentum_unravel = unravel_fn(momentum) - return momentum_unravel + return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt) def kinetic_energy(momentum: PyTree) -> float: momentum, _ = ravel_pytree(momentum) diff --git a/blackjax/mcmc/rmh.py b/blackjax/mcmc/rmh.py index 15d453804..0fe583305 100644 --- a/blackjax/mcmc/rmh.py +++ b/blackjax/mcmc/rmh.py @@ -5,6 +5,7 @@ from jax import numpy as jnp from blackjax.types import Array, PRNGKey, PyTree +from blackjax.util import generate_gaussian_noise __all__ = ["RMHState", "RMHInfo", "init", "kernel"] @@ -190,21 +191,10 @@ def normal(sigma: Array) -> Callable: normal distribution from which we draw the move proposals. 
""" - ndim = jnp.ndim(sigma) # type: ignore[arg-type] - shape = jnp.shape(jnp.atleast_1d(sigma))[:1] - - if ndim == 1: - dot = jnp.multiply - elif ndim == 2: - dot = jnp.dot - else: - raise ValueError + if jnp.ndim(sigma) > 2: + raise ValueError("sigma must be a vector or a matrix.") def propose(rng_key: PRNGKey, position: PyTree) -> PyTree: - _, unravel_fn = jax.flatten_util.ravel_pytree(position) - sample = jax.random.normal(rng_key, shape) - move_sample = dot(sigma, sample) - move_unravel = unravel_fn(move_sample) - return move_unravel + return generate_gaussian_noise(rng_key, position, sigma=sigma) return propose diff --git a/blackjax/sgmcmc/diffusion.py b/blackjax/sgmcmc/diffusion.py index da04cce48..86cefd09d 100644 --- a/blackjax/sgmcmc/diffusion.py +++ b/blackjax/sgmcmc/diffusion.py @@ -5,6 +5,7 @@ import jax.numpy as jnp from blackjax.types import PRNGKey, PyTree +from blackjax.util import generate_gaussian_noise __all__ = ["overdamped_langevin"] @@ -14,12 +15,6 @@ class DiffusionState(NamedTuple): logprob_grad: PyTree -def generate_gaussian_noise(rng_key: PRNGKey, position: PyTree): - position_flat, unravel_fn = jax.flatten_util.ravel_pytree(position) - noise_flat = jax.random.normal(rng_key, shape=jnp.shape(position_flat)) - return unravel_fn(noise_flat) - - def overdamped_langevin(logprob_grad_fn): """Euler solver for overdamped Langevin diffusion.""" diff --git a/blackjax/sgmcmc/sghmc.py b/blackjax/sgmcmc/sghmc.py index a1840fab7..5bd455fa0 100644 --- a/blackjax/sgmcmc/sghmc.py +++ b/blackjax/sgmcmc/sghmc.py @@ -7,18 +7,11 @@ from blackjax.sgmcmc.diffusion import SGHMCState, sghmc from blackjax.sgmcmc.sgld import SGLDState from blackjax.types import PRNGKey, PyTree +from blackjax.util import generate_gaussian_noise __all__ = ["kernel"] -def sample_momentum(rng_key: PRNGKey, position: PyTree, step_size: float): - position_flat, unravel_fn = jax.flatten_util.ravel_pytree(position) - noise_flat = jnp.sqrt(step_size) * jax.random.normal( - rng_key, 
shape=jnp.shape(position_flat) - ) - return unravel_fn(noise_flat) - - def kernel( grad_estimator_fn: Callable, alpha: float = 0.01, beta: float = 0 ) -> Callable: @@ -29,7 +22,7 @@ def one_step( ) -> SGLDState: step, position, logprob_grad = state - momentum = sample_momentum(rng_key, position, step_size) + momentum = generate_gaussian_noise(rng_key, position, jnp.sqrt(step_size)) diffusion_state = SGHMCState(position, momentum, logprob_grad) def body_fn(state, rng_key): diff --git a/blackjax/util.py b/blackjax/util.py new file mode 100644 index 000000000..08786be97 --- /dev/null +++ b/blackjax/util.py @@ -0,0 +1,86 @@ +"""Utility functions for BlackJax.""" +from functools import partial +from typing import Union + +import jax.numpy as jnp +from jax import jit, lax +from jax._src.numpy.util import _promote_dtypes +from jax.flatten_util import ravel_pytree +from jax.random import normal +from jax.tree_util import tree_leaves + +from blackjax.types import Array, PRNGKey, PyTree + + +@partial(jit, static_argnames=("precision",), inline=True) +def linear_map(diag_or_dense_a, b, *, precision="highest"): + """Perform a linear map of the form y = Ax. + + Dispatch the multiplication to either lax.dot or lax.mul. + + Unlike jax.numpy.dot, this function outputs an Array that matches the dtype + and shape of the 2nd input: + - if diag_or_dense_a is a scalar or 1d vector, `diag_or_dense_a * b` is returned + - if diag_or_dense_a is a 2d matrix, `diag_or_dense_a @ b` is returned + + Note that unlike jax.numpy.dot, here we default to full (highest) + precision. This is more useful for numerical algorithms and will be the + default for jax.numpy in the future: + https://github.com/google/jax/pull/7859 + + Parameters + ---------- + diag_or_dense_a: + A scalar, a diagonal (1d vector) or a dense matrix (2d square matrix). + b: + A vector. + precision: + The precision of the computation. See jax.lax.dot_general for + more details.
+ + Returns + ------- + The result vector of the matrix multiplication. + """ + diag_or_dense_a, b = _promote_dtypes(diag_or_dense_a, b) + ndim = jnp.ndim(diag_or_dense_a) + + if ndim <= 1: + return lax.mul(diag_or_dense_a, b) + else: + return lax.dot(diag_or_dense_a, b, precision=precision) + + +# TODO(https://github.com/blackjax-devs/blackjax/issues/376) +# Refactoring this function to not use ravel_pytree might be more performant. +def generate_gaussian_noise( + rng_key: PRNGKey, + position: PyTree, + mu: Union[float, Array] = 0.0, + sigma: Union[float, Array] = 1.0, +) -> PyTree: + """Generate N(mu, sigma) noise with output structure that matches a given PyTree. + + Parameters + ---------- + rng_key: + The pseudo-random number generator key used to generate random numbers. + position: + PyTree whose structure the output should match. + mu: + The mean of the Gaussian distribution. + sigma: + The standard deviation of the Gaussian distribution. + + Returns + ------- + Gaussian noise following N(mu, sigma) that matches the structure of position. + """ + p, unravel_fn = ravel_pytree(position) + sample = normal(rng_key, shape=p.shape, dtype=p.dtype) + return unravel_fn(mu + linear_map(sigma, sample)) + + +def pytree_size(pytree: PyTree) -> int: + """Return the dimension of the flattened PyTree.""" + return sum(jnp.size(value) for value in tree_leaves(pytree)) diff --git a/tests/test_proposal.py b/tests/test_proposal.py index 078e2d7de..5498b1ccf 100644 --- a/tests/test_proposal.py +++ b/tests/test_proposal.py @@ -22,7 +22,8 @@ def test_normal_univariate(self): proposal(key, jnp.array([10.0])) for key in jax.random.split(self.key, 100) ] samples_from_another_position = [ - proposal(key, jnp.array([15000])) for key in jax.random.split(self.key, 100) + proposal(key, jnp.array([15000.0])) + for key in jax.random.split(self.key, 100) ] for samples in [samples_from_initial_position, samples_from_another_position]: