diff --git a/numpyro/distributions/mixtures.py b/numpyro/distributions/mixtures.py index 21b00d4a8..19a821c89 100644 --- a/numpyro/distributions/mixtures.py +++ b/numpyro/distributions/mixtures.py @@ -149,7 +149,10 @@ def sample(self, key, sample_shape=()): def log_prob(self, value, intermediates=None): del intermediates sum_log_probs = self.component_log_probs(value) - return jax.nn.logsumexp(sum_log_probs, axis=-1) + safe_sum_log_probs = jnp.where( + jnp.isneginf(sum_log_probs), -jnp.inf, sum_log_probs + ) + return jax.nn.logsumexp(safe_sum_log_probs, axis=-1) class MixtureSameFamily(_MixtureBase):