From 9b98d5b70ce8b3fa4199b5f9fb67b3456ae8e224 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 6 Jan 2021 16:14:58 -0600 Subject: [PATCH 1/4] improve the speed of subsample --- numpyro/primitives.py | 18 ++++++++++++++++-- test/test_handlers.py | 20 ++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/numpyro/primitives.py b/numpyro/primitives.py index acc94fd39..3ad6022fa 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -5,7 +5,8 @@ from contextlib import ExitStack, contextmanager import functools -from jax import lax, random +from jax import lax, ops, random +from jax.lib import xla_bridge import jax.numpy as jnp import numpyro @@ -234,7 +235,20 @@ def module(name, nn, input_shape=None): def _subsample_fn(size, subsample_size, rng_key=None): assert rng_key is not None, "Missing random key to generate subsample indices." - return random.permutation(rng_key, size)[:subsample_size] + if xla_bridge.get_backend().platform == 'cpu': + u = random.uniform(rng_key, (subsample_size,)) + + def body_fn(idx, val): + i_p1 = size - idx + i = i_p1 - 1 + j = (u[idx] * i_p1).astype(i.dtype) + val = ops.index_update(val, ops.index[[i, j], ], val[ops.index[[j, i], ]]) + return val + + val = lax.fori_loop(0, subsample_size, body_fn, jnp.arange(size)) + return val[-subsample_size:] + else: + return random.choice(rng_key, size, (subsample_size,), replace=False) class plate(Messenger): diff --git a/test/test_handlers.py b/test/test_handlers.py index c4a232fb0..215716aa7 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -612,3 +612,23 @@ def model(x=None): z0 = handlers.seed(model, 0)() assert (z[1:] != z0).all() assert (z[0] == z0).all() + + +def test_subsample_fn(): + size = 20 + subsample_size = 11 + num_samples = 1000000 + + @jit + def subsample_fn(rng_key): + return numpyro.primitives._subsample_fn(size, subsample_size, rng_key) + + # test that keys are not duplicated + for i in range(10): + x = subsample_fn(random.PRNGKey(i)) + assert len(set(x)) == subsample_size + + rng_keys = random.split(random.PRNGKey(0), num_samples) + subsamples = vmap(subsample_fn)(rng_keys) + i = random.randint(random.PRNGKey(1), (), 0, size) + assert_allclose(jnp.mean(subsamples == i, axis=0), jnp.full(subsample_size, 1 / size), atol=1e-3) From 6e20dd92a914253754410dfaf1ad4e4c281e0367 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 6 Jan 2021 16:17:35 -0600 Subject: [PATCH 2/4] add reference --- numpyro/primitives.py | 1 + 1 file changed, 1 insertion(+) diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 3ad6022fa..5627a4398 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -236,6 +236,7 @@ def module(name, nn, input_shape=None): def _subsample_fn(size, subsample_size, rng_key=None): assert rng_key is not None, "Missing random key to generate subsample indices." if xla_bridge.get_backend().platform == 'cpu': + # ref: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm u = random.uniform(rng_key, (subsample_size,)) def body_fn(idx, val): From 3156cc88ad62e0f9665cb74a3ab72eba59529182 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 6 Jan 2021 16:27:26 -0600 Subject: [PATCH 3/4] also use subsample in gibbs_fn --- numpyro/infer/hmc_gibbs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 0cbab8817..5da8c2329 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -14,6 +14,7 @@ from numpyro.infer.hmc import HMC from numpyro.infer.mcmc import MCMCKernel from numpyro.infer.util import log_likelihood, potential_energy, _guess_max_plate_nesting +from numpyro.primitives import _subsample_fn from numpyro.util import cond, fori_loop, identity, ravel_pytree @@ -424,7 +425,7 @@ def gibbs_fn(rng_key, gibbs_sites, hmc_sites): for name in gibbs_sites: size, subsample_size = plate_sizes[name] rng_key, subkey = random.split(rng_key) - u_new[name] = random.choice(subkey, size, (subsample_size,), replace=False) + u_new[name] = _subsample_fn(size, subsample_size, rng_key=subkey) u_loglik = log_likelihood(_wrap_model(model), hmc_sites, *model_args, batch_ndims=0, **model_kwargs, _gibbs_sites=gibbs_sites) From fbecfc877f00d60306868a30c27f2355688acada Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 6 Jan 2021 22:58:30 -0600 Subject: [PATCH 4/4] make sure that subsample set does not suffer by single precision issue --- numpyro/primitives.py | 10 +++++----- test/test_handlers.py | 13 ++++++------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 5627a4398..8f38b9519 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -237,16 +237,16 @@ def _subsample_fn(size, subsample_size, rng_key=None): assert rng_key is not None, "Missing random key to generate subsample indices." if xla_bridge.get_backend().platform == 'cpu': # ref: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm - u = random.uniform(rng_key, (subsample_size,)) + rng_keys = random.split(rng_key, subsample_size) - def body_fn(idx, val): + def body_fn(val, idx): i_p1 = size - idx i = i_p1 - 1 - j = (u[idx] * i_p1).astype(i.dtype) + j = random.randint(rng_keys[idx], (), 0, i_p1) val = ops.index_update(val, ops.index[[i, j], ], val[ops.index[[j, i], ]]) - return val + return val, None - val = lax.fori_loop(0, subsample_size, body_fn, jnp.arange(size)) + val, _ = lax.scan(body_fn, jnp.arange(size), jnp.arange(subsample_size)) return val[-subsample_size:] else: return random.choice(rng_key, size, (subsample_size,), replace=False) diff --git a/test/test_handlers.py b/test/test_handlers.py index 215716aa7..0423b8bd8 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -623,12 +623,11 @@ def test_subsample_fn(): def subsample_fn(rng_key): return numpyro.primitives._subsample_fn(size, subsample_size, rng_key) - # test that keys are not duplicated - for i in range(10): - x = subsample_fn(random.PRNGKey(i)) - assert len(set(x)) == subsample_size - rng_keys = random.split(random.PRNGKey(0), num_samples) subsamples = vmap(subsample_fn)(rng_keys) - i = random.randint(random.PRNGKey(1), (), 0, size) - assert_allclose(jnp.mean(subsamples == i, axis=0), jnp.full(subsample_size, 1 / size), atol=1e-3) + for k in range(1, 11): + i = random.randint(random.PRNGKey(k), (), 0, size) + assert_allclose(jnp.mean(subsamples == i, axis=0), jnp.full(subsample_size, 1 / size), atol=1e-3) + + # test that values are not duplicated + assert len(set(subsamples[k])) == subsample_size