Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom prng key #1642

Merged
merged 4 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ jobs:
XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap"
XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/test_tfp.py -k "chain"
XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_hmc_gibbs.py -k "chain"
- name: Test custom prng
run: |
JAX_ENABLE_CUSTOM_PRNG=1 pytest -vs test/infer/test_mcmc.py


examples:
Expand Down
5 changes: 2 additions & 3 deletions numpyro/contrib/einstein/steinvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import operator

from jax import grad, jacfwd, numpy as jnp, random, vmap
from jax.random import KeyArray
from jax.tree_util import tree_map

from numpyro import handlers
Expand Down Expand Up @@ -370,10 +369,10 @@ def _update_force(attr_force, rep_force, jac):
)
return jnp.linalg.norm(particle_grads), res_grads

def init(self, rng_key: KeyArray, *args, **kwargs):
def init(self, rng_key, *args, **kwargs):
"""Register random variable transformations, constraints and determine initialize positions of the particles.

:param KeyArray rng_key: Random number generator seed.
:param rng_key: Random number generator seed.
:param args: Arguments to the model / guide.
:param kwargs: Keyword arguments to the model / guide.
:return: initial :data:`SteinVIState`
Expand Down
6 changes: 3 additions & 3 deletions numpyro/contrib/tfp/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from numpyro.infer import init_to_uniform
from numpyro.infer.mcmc import MCMCKernel
from numpyro.infer.util import initialize_model
from numpyro.util import identity
from numpyro.util import identity, is_prng_key

TFPKernelState = namedtuple("TFPKernelState", ["z", "kernel_results", "rng_key"])

Expand Down Expand Up @@ -174,7 +174,7 @@ def init(
self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={}
):
# non-vectorized
if rng_key.ndim == 1:
if is_prng_key(rng_key):
rng_key, rng_key_init_model = random.split(rng_key)
# vectorized
else:
Expand All @@ -190,7 +190,7 @@ def init(
" `target_log_prob_fn`."
)

if rng_key.ndim == 1:
if is_prng_key(rng_key):
init_state = self._init_fn(init_params, rng_key)
else:
# XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
Expand Down
3 changes: 2 additions & 1 deletion numpyro/distributions/conjugate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
ZeroInflatedDistribution,
)
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import is_prng_key, promote_shapes, validate_sample
from numpyro.distributions.util import promote_shapes, validate_sample
from numpyro.util import is_prng_key


def _log_beta_1(alpha, value):
Expand Down
2 changes: 1 addition & 1 deletion numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@
betaincinv,
cholesky_of_inverse,
gammaincinv,
is_prng_key,
lazy_property,
matrix_to_tril_vec,
promote_shapes,
signed_stick_breaking_tril,
validate_sample,
vec_to_tril_matrix,
)
from numpyro.util import is_prng_key


class AsymmetricLaplace(Distribution):
Expand Down
8 changes: 2 additions & 6 deletions numpyro/distributions/copula.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,8 @@
import numpyro.distributions.constraints as constraints
from numpyro.distributions.continuous import Beta, MultivariateNormal, Normal
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
clamp_probs,
is_prng_key,
lazy_property,
validate_sample,
)
from numpyro.distributions.util import clamp_probs, lazy_property, validate_sample
from numpyro.util import is_prng_key


class GaussianCopula(Distribution):
Expand Down
3 changes: 1 addition & 2 deletions numpyro/distributions/directional.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@
from numpyro.distributions import constraints
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
is_prng_key,
lazy_property,
promote_shapes,
safe_normalize,
validate_sample,
von_mises_centered,
)
from numpyro.util import while_loop
from numpyro.util import is_prng_key, while_loop


def _numel(shape):
Expand Down
3 changes: 1 addition & 2 deletions numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,12 @@
binomial,
categorical,
clamp_probs,
is_prng_key,
lazy_property,
multinomial,
promote_shapes,
validate_sample,
)
from numpyro.util import not_jax_tracer
from numpyro.util import is_prng_key, not_jax_tracer


def _to_probs_bernoulli(logits):
Expand Down
3 changes: 2 additions & 1 deletion numpyro/distributions/mixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from numpyro.distributions import Distribution, constraints
from numpyro.distributions.discrete import CategoricalLogits, CategoricalProbs
from numpyro.distributions.util import is_prng_key, validate_sample
from numpyro.distributions.util import validate_sample
from numpyro.util import is_prng_key


def Mixture(mixing_distribution, component_distributions, *, validate_args=None):
Expand Down
2 changes: 1 addition & 1 deletion numpyro/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
clamp_probs,
is_prng_key,
lazy_property,
promote_shapes,
validate_sample,
)
from numpyro.util import is_prng_key


class LeftTruncatedDistribution(Distribution):
Expand Down
7 changes: 4 additions & 3 deletions numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import namedtuple
from functools import partial, update_wrapper
import math
import warnings

import numpy as np

Expand Down Expand Up @@ -612,11 +613,11 @@ def safe_normalize(x, *, p=2):
return x


# src: https://github.com/google/jax/blob/5a41779fbe12ba7213cd3aa1169d3b0ffb02a094/jax/_src/random.py#L95
def is_prng_key(key):
if isinstance(key, jax.random.PRNGKeyArray):
return key.shape == ()
warnings.warn("Please use numpyro.util.is_prng_key.", DeprecationWarning)
try:
if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
return key.shape == ()
return key.shape == (2,) and key.dtype == np.uint32
except AttributeError:
return False
Expand Down
16 changes: 8 additions & 8 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
apply_stack,
plate,
)
from numpyro.util import find_stack_level, not_jax_tracer
from numpyro.util import find_stack_level, is_prng_key, not_jax_tracer

__all__ = [
"block",
Expand Down Expand Up @@ -705,15 +705,15 @@ class seed(Messenger):
"""

def __init__(self, fn=None, rng_seed=None, hide_types=None):
if isinstance(rng_seed, int) or (
isinstance(rng_seed, (np.ndarray, jnp.ndarray)) and not jnp.shape(rng_seed)
if not is_prng_key(rng_seed) and (
isinstance(rng_seed, int)
or (
isinstance(rng_seed, (np.ndarray, jnp.ndarray))
and not jnp.shape(rng_seed)
)
):
rng_seed = random.PRNGKey(rng_seed)
if not (
isinstance(rng_seed, (np.ndarray, jnp.ndarray))
and rng_seed.dtype == jnp.uint32
and rng_seed.shape == (2,)
):
if not is_prng_key(rng_seed):
raise TypeError("Incorrect type for rng_seed: {}".format(type(rng_seed)))
self.rng_key = rng_seed
self.hide_types = [] if hide_types is None else hide_types
Expand Down
4 changes: 2 additions & 2 deletions numpyro/infer/barker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from numpyro.infer.initialization import init_to_uniform
from numpyro.infer.mcmc import MCMCKernel
from numpyro.infer.util import initialize_model
from numpyro.util import identity
from numpyro.util import identity, is_prng_key

BarkerMHState = namedtuple(
"BarkerMHState",
Expand Down Expand Up @@ -170,7 +170,7 @@ def _init_state(self, rng_key, model_args, model_kwargs, init_params):
def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
self._num_warmup = num_warmup
# TODO (low-priority): support chain_method="vectorized", i.e. rng_key is a batch of keys
assert rng_key.shape == (2,), (
assert is_prng_key(rng_key), (
"BarkerMH only supports chain_method='parallel' or chain_method='sequential'."
" Please put in a feature request if you think it would be useful to be able "
"to use BarkerMH in vectorized mode."
Expand Down
6 changes: 3 additions & 3 deletions numpyro/infer/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from numpyro.infer.mcmc import MCMCKernel
from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model
from numpyro.util import cond, fori_loop, identity
from numpyro.util import cond, fori_loop, identity, is_prng_key

HMCState = namedtuple(
"HMCState",
Expand Down Expand Up @@ -703,7 +703,7 @@ def init(
self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={}
):
# non-vectorized
if rng_key.ndim == 1:
if is_prng_key(rng_key):
rng_key, rng_key_init_model = random.split(rng_key)
# vectorized
else:
Expand Down Expand Up @@ -749,7 +749,7 @@ def init(
model_kwargs=model_kwargs,
rng_key=rng_key,
)
if rng_key.ndim == 1:
if is_prng_key(rng_key):
init_state = hmc_init_fn(init_params, rng_key)
else:
# XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
Expand Down
12 changes: 9 additions & 3 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
from jax.tree_util import tree_flatten, tree_map

from numpyro.diagnostics import print_summary
from numpyro.util import cached_by, find_stack_level, fori_collect, identity
from numpyro.util import (
cached_by,
find_stack_level,
fori_collect,
identity,
is_prng_key,
)

__all__ = [
"MCMCKernel",
Expand Down Expand Up @@ -418,7 +424,7 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields):
sample_fn, postprocess_fn = self._get_cached_fns()
diagnostics = (
lambda x: self.sampler.get_diagnostics_str(x[0])
if rng_key.ndim == 1
if is_prng_key(rng_key)
else ""
) # noqa: E731
init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
Expand Down Expand Up @@ -595,7 +601,7 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
self._args = args
self._kwargs = kwargs
init_state = self._get_cached_init_state(rng_key, args, kwargs)
if self.num_chains > 1 and rng_key.ndim == 1:
if self.num_chains > 1 and is_prng_key(rng_key):
rng_key = random.split(rng_key, self.num_chains)

if self._warmup_state is not None:
Expand Down
6 changes: 3 additions & 3 deletions numpyro/infer/sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from numpyro.distributions.util import cholesky_update
from numpyro.infer.mcmc import MCMCKernel
from numpyro.infer.util import init_to_uniform, initialize_model
from numpyro.util import identity
from numpyro.util import identity, is_prng_key


def _get_proposal_loc_and_scale(samples, loc, scale, new_sample):
Expand Down Expand Up @@ -331,7 +331,7 @@ def init(
self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={}
):
# non-vectorized
if rng_key.ndim == 1:
if is_prng_key(rng_key):
rng_key, rng_key_init_model = random.split(rng_key)
# vectorized
else:
Expand All @@ -358,7 +358,7 @@ def init(
model_args=model_args,
model_kwargs=model_kwargs,
)
if rng_key.ndim == 1:
if is_prng_key(rng_key):
init_state = sa_init_fn(init_params, rng_key)
else:
init_state = vmap(sa_init_fn)(init_params, rng_key)
Expand Down
8 changes: 5 additions & 3 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from numpyro.util import (
_validate_model,
find_stack_level,
is_prng_key,
not_jax_tracer,
soft_vmap,
while_loop,
Expand Down Expand Up @@ -435,7 +436,7 @@ def _find_valid_params(rng_key, exit_early=False):
return (init_params, pe, z_grad), is_valid

# Handle possible vectorization
if rng_key.ndim == 1:
if is_prng_key(rng_key):
(init_params, pe, z_grad), is_valid = _find_valid_params(
rng_key, exit_early=True
)
Expand Down Expand Up @@ -644,7 +645,7 @@ def initialize_model(
"""
model_kwargs = {} if model_kwargs is None else model_kwargs
substituted_model = substitute(
seed(model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
seed(model, rng_key if is_prng_key(rng_key) else rng_key[0]),
substitute_fn=init_strategy,
)
(
Expand Down Expand Up @@ -816,9 +817,10 @@ def single_prediction(val):
return {name: value for name, value in pred_samples.items() if name in sites}

num_samples = int(np.prod(batch_shape))
key_shape = rng_key.shape
if num_samples > 1:
rng_key = random.split(rng_key, num_samples)
rng_key = rng_key.reshape((*batch_shape, 2))
rng_key = rng_key.reshape(batch_shape + key_shape)
chunk_size = num_samples if parallel else 1
return soft_vmap(
single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size
Expand Down
2 changes: 1 addition & 1 deletion numpyro/ops/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jax.experimental.pjit import pjit_p
from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic
from jax.interpreters.pxla import xla_pmap_p
import jax.linear_util as lu
import jax.extend.linear_util as lu
import jax.numpy as jnp


Expand Down
9 changes: 9 additions & 0 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,15 @@ def fori_loop(lower, upper, body_fun, init_val):
return lax.fori_loop(lower, upper, body_fun, init_val)


def is_prng_key(key):
try:
if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
return key.shape == ()
return key.shape == (2,) and key.dtype == np.uint32
except AttributeError:
return False


def not_jax_tracer(x):
"""
Checks if `x` is not an array generated inside `jit`, `pmap`, `vmap`, or `lax_control_flow`.
Expand Down
4 changes: 2 additions & 2 deletions test/contrib/einstein/test_steinvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def model(obs):
expected_shape = (num_particles, *np.shape(inner_param["value"]))
assert init_value.shape == expected_shape
if "auto_loc" in name or name == "b":
assert np.alltrue(init_value != np.zeros(expected_shape))
assert np.all(init_value != np.zeros(expected_shape))
assert np.unique(init_value).shape == init_value.reshape(-1).shape
elif "scale" in name:
assert_allclose(init_value[init_value != 0.0], 0.1, rtol=1e-6)
Expand Down Expand Up @@ -311,7 +311,7 @@ def model(obs):
expected_shape = (num_particles, latent_dim)

assert expected_shape == init_value.shape
assert np.alltrue(init_value != np.zeros(expected_shape))
assert np.all(init_value != np.zeros(expected_shape))
assert np.unique(init_value).shape == init_value.reshape(-1).shape


Expand Down
Loading