Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make subsample faster in CPU #865

Merged
merged 6 commits into from
Jan 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import functools
import warnings

from jax import lax, random
from jax import lax, ops, random
from jax.lib import xla_bridge
import jax.numpy as jnp

import numpyro
Expand Down Expand Up @@ -235,7 +236,21 @@ def module(name, nn, input_shape=None):

def _subsample_fn(size, subsample_size, rng_key=None):
assert rng_key is not None, "Missing random key to generate subsample indices."
return random.permutation(rng_key, size)[:subsample_size]
if xla_bridge.get_backend().platform == 'cpu':
# ref: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm
rng_keys = random.split(rng_key, subsample_size)

def body_fn(val, idx):
i_p1 = size - idx
i = i_p1 - 1
j = random.randint(rng_keys[idx], (), 0, i_p1)
val = ops.index_update(val, ops.index[[i, j], ], val[ops.index[[j, i], ]])
return val, None

val, _ = lax.scan(body_fn, jnp.arange(size), jnp.arange(subsample_size))
return val[-subsample_size:]
else:
return random.choice(rng_key, size, (subsample_size,), replace=False)


class plate(Messenger):
Expand Down
19 changes: 19 additions & 0 deletions test/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,3 +612,22 @@ def model(x=None):
z0 = handlers.seed(model, 0)()
assert (z[1:] != z0).all()
assert (z[0] == z0).all()


def test_subsample_fn():
size = 20
subsample_size = 11
num_samples = 1000000

@jit
def subsample_fn(rng_key):
return numpyro.primitives._subsample_fn(size, subsample_size, rng_key)

rng_keys = random.split(random.PRNGKey(0), num_samples)
subsamples = vmap(subsample_fn)(rng_keys)
for k in range(1, 11):
i = random.randint(random.PRNGKey(k), (), 0, size)
assert_allclose(jnp.mean(subsamples == i, axis=0), jnp.full(subsample_size, 1 / size), atol=1e-3)

# test that values are not duplicated
assert len(set(subsamples[k])) == subsample_size