Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow users to specify total_count_max in Multinomial #1557

Merged
merged 6 commits into from
Mar 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 42 additions & 7 deletions numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,14 +505,17 @@ class MultinomialProbs(Distribution):
"total_count": constraints.nonnegative_integer,
}

def __init__(self, probs, total_count=1, *, validate_args=None):
def __init__(
self, probs, total_count=1, *, total_count_max=None, validate_args=None
):
if jnp.ndim(probs) < 1:
raise ValueError("`probs` parameter must be at least one-dimensional.")
batch_shape, event_shape = self.infer_shapes(
jnp.shape(probs), jnp.shape(total_count)
)
self.probs = promote_shapes(probs, shape=batch_shape + jnp.shape(probs)[-1:])[0]
self.total_count = promote_shapes(total_count, shape=batch_shape)[0]
self.total_count_max = total_count_max
super(MultinomialProbs, self).__init__(
batch_shape=batch_shape,
event_shape=event_shape,
Expand All @@ -522,7 +525,11 @@ def __init__(self, probs, total_count=1, *, validate_args=None):
def sample(self, key, sample_shape=()):
assert is_prng_key(key)
return multinomial(
key, self.probs, self.total_count, shape=sample_shape + self.batch_shape
key,
self.probs,
self.total_count,
shape=sample_shape + self.batch_shape,
total_count_max=self.total_count_max,
)

@validate_sample
Expand Down Expand Up @@ -562,7 +569,9 @@ class MultinomialLogits(Distribution):
"total_count": constraints.nonnegative_integer,
}

def __init__(self, logits, total_count=1, *, validate_args=None):
def __init__(
self, logits, total_count=1, *, total_count_max=None, validate_args=None
):
if jnp.ndim(logits) < 1:
raise ValueError("`logits` parameter must be at least one-dimensional.")
batch_shape, event_shape = self.infer_shapes(
Expand All @@ -572,6 +581,7 @@ def __init__(self, logits, total_count=1, *, validate_args=None):
logits, shape=batch_shape + jnp.shape(logits)[-1:]
)[0]
self.total_count = promote_shapes(total_count, shape=batch_shape)[0]
self.total_count_max = total_count_max
super(MultinomialLogits, self).__init__(
batch_shape=batch_shape,
event_shape=event_shape,
Expand All @@ -581,7 +591,11 @@ def __init__(self, logits, total_count=1, *, validate_args=None):
def sample(self, key, sample_shape=()):
assert is_prng_key(key)
return multinomial(
key, self.probs, self.total_count, shape=sample_shape + self.batch_shape
key,
self.probs,
self.total_count,
shape=sample_shape + self.batch_shape,
total_count_max=self.total_count_max,
)

@validate_sample
Expand Down Expand Up @@ -618,11 +632,32 @@ def infer_shapes(logits, total_count):
return batch_shape, event_shape


def Multinomial(total_count=1, probs=None, logits=None, *, validate_args=None):
def Multinomial(
total_count=1, probs=None, logits=None, *, total_count_max=None, validate_args=None
):
"""Multinomial distribution.

:param total_count: number of trials. If this is a JAX array,
it is required to specify `total_count_max`.
:param probs: event probabilities
:param logits: event log probabilities
:param int total_count_max: the maximum number of trials,
i.e. `max(total_count)`
"""
if probs is not None:
return MultinomialProbs(probs, total_count, validate_args=validate_args)
return MultinomialProbs(
probs,
total_count,
total_count_max=total_count_max,
validate_args=validate_args,
)
elif logits is not None:
return MultinomialLogits(logits, total_count, validate_args=validate_args)
return MultinomialLogits(
logits,
total_count,
total_count_max=total_count_max,
validate_args=validate_args,
)
else:
raise ValueError("One of `probs` or `logits` must be specified.")

Expand Down
14 changes: 9 additions & 5 deletions numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,15 @@ def _multinomial(key, p, n, n_max, shape=()):
return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess


def multinomial(key, p, n, shape=()):
assert not isinstance(
n, jax.core.Tracer
), "The total count parameter `n` should not be a jax abstract array."
n_max = int(np.max(jax.device_get(n)))
def multinomial(key, p, n, shape=(), total_count_max=None):
if total_count_max is None:
if isinstance(n, jax.core.Tracer):
raise ValueError(
"Please specify total_count_max in Multinomial distribution."
)
n_max = int(np.max(jax.device_get(n)))
else:
n_max = total_count_max
return _multinomial(key, p, n, n_max, shape)


Expand Down
4 changes: 4 additions & 0 deletions test/infer/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ def model(data):

@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH])
def test_beta_bernoulli_x64(kernel_cls):
if kernel_cls is SA and "CI" in os.environ and "JAX_ENABLE_X64" in os.environ:
pytest.skip("The test is flaky on CI x64.")
num_warmup, num_samples = (100000, 100000) if kernel_cls is SA else (500, 20000)

def model(data):
Expand Down Expand Up @@ -318,6 +320,8 @@ def model(data):

@pytest.mark.parametrize("with_logits", ["True", "False"])
def test_binomial_stable_x64(with_logits):
if "CI" in os.environ and "JAX_ENABLE_X64" in os.environ:
pytest.skip("The test is flaky on CI x64.")
# Ref: https://github.com/pyro-ppl/pyro/issues/1706
num_warmup, num_samples = 200, 200

Expand Down
15 changes: 15 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2743,3 +2743,18 @@ def sample_binomial_withp0(key):
return dist.Binomial(total_count=n, probs=0).sample(key)

jax.vmap(sample_binomial_withp0)(random.split(random.PRNGKey(0), 1))


def test_multinomial_abstract_total_count():
probs = jnp.array([0.2, 0.5, 0.3])
key = random.PRNGKey(0)

def f(x):
total_count = x.sum(-1)
return dist.Multinomial(total_count, probs=probs, total_count_max=10).sample(
key
)

x = dist.Multinomial(10, probs).sample(key)
y = jax.jit(f)(x)
assert_allclose(x, y, rtol=1e-6)