Skip to content

Commit

Permalink
Allow users to specify total_count_max in Multinomial (#1557)
Browse files Browse the repository at this point in the history
* allow users to specify total_count_max in Multinomial

* black

* add docstring for Multinomial

* revise docstring of Multinomial

* skip flaky tests on CI
  • Loading branch information
fehiepsi authored Mar 26, 2023
1 parent f98bc4d commit fa5ce64
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 12 deletions.
49 changes: 42 additions & 7 deletions numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,14 +505,17 @@ class MultinomialProbs(Distribution):
"total_count": constraints.nonnegative_integer,
}

def __init__(self, probs, total_count=1, *, validate_args=None):
def __init__(
self, probs, total_count=1, *, total_count_max=None, validate_args=None
):
if jnp.ndim(probs) < 1:
raise ValueError("`probs` parameter must be at least one-dimensional.")
batch_shape, event_shape = self.infer_shapes(
jnp.shape(probs), jnp.shape(total_count)
)
self.probs = promote_shapes(probs, shape=batch_shape + jnp.shape(probs)[-1:])[0]
self.total_count = promote_shapes(total_count, shape=batch_shape)[0]
self.total_count_max = total_count_max
super(MultinomialProbs, self).__init__(
batch_shape=batch_shape,
event_shape=event_shape,
Expand All @@ -522,7 +525,11 @@ def __init__(self, probs, total_count=1, *, validate_args=None):
def sample(self, key, sample_shape=()):
assert is_prng_key(key)
return multinomial(
key, self.probs, self.total_count, shape=sample_shape + self.batch_shape
key,
self.probs,
self.total_count,
shape=sample_shape + self.batch_shape,
total_count_max=self.total_count_max,
)

@validate_sample
Expand Down Expand Up @@ -562,7 +569,9 @@ class MultinomialLogits(Distribution):
"total_count": constraints.nonnegative_integer,
}

def __init__(self, logits, total_count=1, *, validate_args=None):
def __init__(
self, logits, total_count=1, *, total_count_max=None, validate_args=None
):
if jnp.ndim(logits) < 1:
raise ValueError("`logits` parameter must be at least one-dimensional.")
batch_shape, event_shape = self.infer_shapes(
Expand All @@ -572,6 +581,7 @@ def __init__(self, logits, total_count=1, *, validate_args=None):
logits, shape=batch_shape + jnp.shape(logits)[-1:]
)[0]
self.total_count = promote_shapes(total_count, shape=batch_shape)[0]
self.total_count_max = total_count_max
super(MultinomialLogits, self).__init__(
batch_shape=batch_shape,
event_shape=event_shape,
Expand All @@ -581,7 +591,11 @@ def __init__(self, logits, total_count=1, *, validate_args=None):
def sample(self, key, sample_shape=()):
assert is_prng_key(key)
return multinomial(
key, self.probs, self.total_count, shape=sample_shape + self.batch_shape
key,
self.probs,
self.total_count,
shape=sample_shape + self.batch_shape,
total_count_max=self.total_count_max,
)

@validate_sample
Expand Down Expand Up @@ -618,11 +632,32 @@ def infer_shapes(logits, total_count):
return batch_shape, event_shape


def Multinomial(total_count=1, probs=None, logits=None, *, validate_args=None):
def Multinomial(
total_count=1, probs=None, logits=None, *, total_count_max=None, validate_args=None
):
"""Multinomial distribution.
:param total_count: number of trials. If this is a JAX array,
it is required to specify `total_count_max`.
:param probs: event probabilities
:param logits: event log probabilities
:param int total_count_max: the maximum number of trials,
i.e. `max(total_count)`
"""
if probs is not None:
return MultinomialProbs(probs, total_count, validate_args=validate_args)
return MultinomialProbs(
probs,
total_count,
total_count_max=total_count_max,
validate_args=validate_args,
)
elif logits is not None:
return MultinomialLogits(logits, total_count, validate_args=validate_args)
return MultinomialLogits(
logits,
total_count,
total_count_max=total_count_max,
validate_args=validate_args,
)
else:
raise ValueError("One of `probs` or `logits` must be specified.")

Expand Down
14 changes: 9 additions & 5 deletions numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,15 @@ def _multinomial(key, p, n, n_max, shape=()):
return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess


def multinomial(key, p, n, shape=()):
assert not isinstance(
n, jax.core.Tracer
), "The total count parameter `n` should not be a jax abstract array."
n_max = int(np.max(jax.device_get(n)))
def multinomial(key, p, n, shape=(), total_count_max=None):
if total_count_max is None:
if isinstance(n, jax.core.Tracer):
raise ValueError(
"Please specify total_count_max in Multinomial distribution."
)
n_max = int(np.max(jax.device_get(n)))
else:
n_max = total_count_max
return _multinomial(key, p, n, n_max, shape)


Expand Down
4 changes: 4 additions & 0 deletions test/infer/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ def model(data):

@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH])
def test_beta_bernoulli_x64(kernel_cls):
if kernel_cls is SA and "CI" in os.environ and "JAX_ENABLE_X64" in os.environ:
pytest.skip("The test is flaky on CI x64.")
num_warmup, num_samples = (100000, 100000) if kernel_cls is SA else (500, 20000)

def model(data):
Expand Down Expand Up @@ -318,6 +320,8 @@ def model(data):

@pytest.mark.parametrize("with_logits", ["True", "False"])
def test_binomial_stable_x64(with_logits):
if "CI" in os.environ and "JAX_ENABLE_X64" in os.environ:
pytest.skip("The test is flaky on CI x64.")
# Ref: https://github.com/pyro-ppl/pyro/issues/1706
num_warmup, num_samples = 200, 200

Expand Down
15 changes: 15 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2743,3 +2743,18 @@ def sample_binomial_withp0(key):
return dist.Binomial(total_count=n, probs=0).sample(key)

jax.vmap(sample_binomial_withp0)(random.split(random.PRNGKey(0), 1))


def test_multinomial_abstract_total_count():
probs = jnp.array([0.2, 0.5, 0.3])
key = random.PRNGKey(0)

def f(x):
total_count = x.sum(-1)
return dist.Multinomial(total_count, probs=probs, total_count_max=10).sample(
key
)

x = dist.Multinomial(10, probs).sample(key)
y = jax.jit(f)(x)
assert_allclose(x, y, rtol=1e-6)

0 comments on commit fa5ce64

Please sign in to comment.