Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Different supports in component distributions for mixture models #1791

Merged
merged 7 commits into from
May 12, 2024
63 changes: 52 additions & 11 deletions numpyro/distributions/mixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ class MixtureGeneral(_MixtureBase):
``mixture_size``.
:param component_distributions: A list of ``mixture_size``
:class:`~numpyro.distributions.Distribution` objects.
:param support: A :class:`~numpyro.distributions.constraints.Constraint`
object specifying the support of the mixture distribution. If not
provided, the support will be inferred from the component distributions.

**Example**

Expand All @@ -288,13 +291,36 @@ class MixtureGeneral(_MixtureBase):
>>> mixture = dist.MixtureGeneral(mixing_dist, component_dists)
>>> mixture.sample(jax.random.PRNGKey(42)).shape
()

.. doctest::

>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> mixing_dist = dist.Categorical(probs=jnp.ones(2) / 2.)
>>> component_dists = [
... dist.Normal(loc=0.0, scale=1.0),
... dist.HalfNormal(scale=0.3),
... ]
>>> mixture = dist.MixtureGeneral(mixing_dist, component_dists, support=dist.constraints.real)
>>> mixture.sample(jax.random.PRNGKey(42)).shape
()
"""

pytree_data_fields = ("_mixing_distribution", "_component_distributions")
pytree_data_fields = (
"_mixing_distribution",
"_component_distributions",
"_support",
)
pytree_aux_fields = ("_mixture_size",)

def __init__(
self, mixing_distribution, component_distributions, *, validate_args=None
self,
mixing_distribution,
component_distributions,
*,
support=None,
Qazalbash marked this conversation as resolved.
Show resolved Hide resolved
validate_args=None,
):
_check_mixing_distribution(mixing_distribution)

Expand All @@ -308,7 +334,7 @@ def __init__(
for d in component_distributions:
if not isinstance(d, Distribution):
raise ValueError(
"All elements of 'component_distributions' must be instaces of "
"All elements of 'component_distributions' must be instances of "
"numpyro.distributions.Distribution subclasses"
)
if len(component_distributions) != self.mixture_size:
Expand All @@ -320,11 +346,19 @@ def __init__(
# TODO: It would be good to check that the support of all the component
# distributions match, but for now we just check the type, since __eq__
# isn't consistently implemented for all support types.
support_type = type(component_distributions[0].support)
if any(
type(d.support) is not support_type for d in component_distributions[1:]
):
raise ValueError("All component distributions must have the same support.")
self._support = support
Qazalbash marked this conversation as resolved.
Show resolved Hide resolved
if support is None:
support_type = type(component_distributions[0].support)
if any(
type(d.support) is not support_type for d in component_distributions[1:]
):
raise ValueError(
"All component distributions must have the same support."
)
else:
assert isinstance(
support, constraints.Constraint
), "support must be a Constraint object"

self._mixing_distribution = mixing_distribution
self._component_distributions = component_distributions
Expand Down Expand Up @@ -357,6 +391,8 @@ def component_distributions(self):

@constraints.dependent_property
def support(self):
    """Support of the mixture.

    Returns the support explicitly supplied at construction time when one
    was given; otherwise falls back to the support of the first component
    distribution.
    """
    explicit = self._support
    if explicit is None:
        # No user-supplied support: all components were checked at
        # construction to share a support type, so the first one is
        # representative.
        return self.component_distributions[0].support
    return explicit

@property
Expand Down Expand Up @@ -389,9 +425,14 @@ def component_sample(self, key, sample_shape=()):
return jnp.stack(samples, axis=self.mixture_dim)

def component_log_probs(self, value):
component_log_probs = jnp.stack(
[d.log_prob(value) for d in self.component_distributions], axis=-1
)
component_log_probs = []
for d in self.component_distributions:
log_prob = d.log_prob(value)
if (self._support is not None) and (not d._validate_args):
mask = d.support(value)
log_prob = jnp.where(mask, log_prob, -jnp.inf)
component_log_probs.append(log_prob)
component_log_probs = jnp.stack(component_log_probs, axis=-1)
return jax.nn.log_softmax(self.mixing_distribution.logits) + component_log_probs


Expand Down
45 changes: 45 additions & 0 deletions test/test_distributions_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ def get_normal(batch_shape):
return normal


def get_half_normal(batch_shape):
    """Build a HalfNormal distribution with unit scale and the given batch shape."""
    return dist.HalfNormal(scale=jnp.ones(batch_shape))


def get_mvn(batch_shape):
"""Get parameterized MultivariateNormal with given batch shape."""
dimensions = 2
Expand Down Expand Up @@ -78,6 +85,44 @@ def test_mixture_broadcast_batch_shape(
_test_mixture(mixing_distribution, component_distribution)


@pytest.mark.parametrize("batch_shape", [(), (1,), (7,), (2, 5)])
@pytest.mark.filterwarnings(
    "ignore:Out-of-support values provided to log prob method."
    " The value argument should be within the support.:UserWarning"
)
def test_mixture_with_different_support(batch_shape):
    """MixtureGeneral with an explicit ``support`` can mix components whose
    individual supports differ (Normal on R, HalfNormal on R+)."""
    weights = jnp.ones(2) / 2
    mixing = dist.Categorical(probs=weights)
    components = [
        get_normal(batch_shape),
        get_half_normal(batch_shape),
    ]
    mixture = dist.MixtureGeneral(
        mixing_distribution=mixing,
        component_distributions=components,
        support=dist.constraints.real,
    )
    assert mixture.batch_shape == batch_shape

    # Turn on per-component validation so log_prob of out-of-support values
    # yields -inf, matching the mixture's masking behavior.
    for component in components:
        component._validate_args = True

    samples = components[0].sample(rng_key, (11,))
    weighted = jnp.stack(
        [
            components[0].log_prob(samples) + jnp.log(weights[0]),
            components[1].log_prob(samples) + jnp.log(weights[1]),
        ],
        axis=-1,
    )
    expected = jax.scipy.special.logsumexp(weighted, axis=-1)
    assert jnp.allclose(mixture.log_prob(samples), expected)


def _test_mixture(mixing_distribution, component_distribution):
# Create mixture
mixture = dist.Mixture(
Expand Down
Loading