Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Different supports in component distributions for mixture models #1791

Merged
merged 7 commits into from
May 12, 2024
31 changes: 20 additions & 11 deletions numpyro/distributions/mixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from numpyro.distributions import Distribution, constraints
from numpyro.distributions.discrete import CategoricalLogits, CategoricalProbs
from numpyro.distributions.util import validate_sample
from numpyro.util import is_prng_key


Expand Down Expand Up @@ -145,7 +144,7 @@ def sample_with_intermediates(self, key, sample_shape=()):
def sample(self, key, sample_shape=()):
return self.sample_with_intermediates(key=key, sample_shape=sample_shape)[0]

@validate_sample
# @validate_sample
Qazalbash marked this conversation as resolved.
Show resolved Hide resolved
Qazalbash marked this conversation as resolved.
Show resolved Hide resolved
def log_prob(self, value, intermediates=None):
del intermediates
sum_log_probs = self.component_log_probs(value)
Expand Down Expand Up @@ -308,7 +307,7 @@ def __init__(
for d in component_distributions:
if not isinstance(d, Distribution):
raise ValueError(
"All elements of 'component_distributions' must be instaces of "
"All elements of 'component_distributions' must be instances of "
"numpyro.distributions.Distribution subclasses"
)
if len(component_distributions) != self.mixture_size:
Expand All @@ -320,11 +319,9 @@ def __init__(
# TODO: It would be good to check that the support of all the component
# distributions match, but for now we just check the type, since __eq__
# isn't consistently implemented for all support types.
support_type = type(component_distributions[0].support)
if any(
type(d.support) is not support_type for d in component_distributions[1:]
):
raise ValueError("All component distributions must have the same support.")
# support_type = type(component_distributions[0].support)
Qazalbash marked this conversation as resolved.
Show resolved Hide resolved
# if any(type(d.support) is not support_type for d in component_distributions[1:]):
# raise ValueError("All component distributions must have the same support.")

self._mixing_distribution = mixing_distribution
self._component_distributions = component_distributions
Expand Down Expand Up @@ -389,9 +386,21 @@ def component_sample(self, key, sample_shape=()):
return jnp.stack(samples, axis=self.mixture_dim)

def component_log_probs(self, value):
component_log_probs = jnp.stack(
[d.log_prob(value) for d in self.component_distributions], axis=-1
)
if self._validate_args:
Qazalbash marked this conversation as resolved.
Show resolved Hide resolved
Qazalbash marked this conversation as resolved.
Show resolved Hide resolved
mask = jnp.stack(
[d.support(value) for d in self.component_distributions], axis=-1
)
component_log_probs = jnp.where(
mask,
jnp.stack(
[d.log_prob(value) for d in self.component_distributions], axis=-1
Qazalbash marked this conversation as resolved.
Show resolved Hide resolved
),
-jnp.inf,
)
else:
component_log_probs = jnp.stack(
[d.log_prob(value) for d in self.component_distributions], axis=-1
)
return jax.nn.log_softmax(self.mixing_distribution.logits) + component_log_probs


Expand Down
Loading