From 0791b1ffc51a61f8f4424997a4902b8aa92b835d Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Fri, 4 Oct 2024 23:22:04 +0500 Subject: [PATCH] Refactor log_prob method in _MixtureBase class to handle negative infinity values in sum_log_probs (#1874) --- numpyro/distributions/mixtures.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/mixtures.py b/numpyro/distributions/mixtures.py index 21b00d4a8..19a821c89 100644 --- a/numpyro/distributions/mixtures.py +++ b/numpyro/distributions/mixtures.py @@ -149,7 +149,10 @@ def sample(self, key, sample_shape=()): def log_prob(self, value, intermediates=None): del intermediates sum_log_probs = self.component_log_probs(value) - return jax.nn.logsumexp(sum_log_probs, axis=-1) + safe_sum_log_probs = jnp.where( + jnp.isneginf(sum_log_probs), -jnp.inf, sum_log_probs + ) + return jax.nn.logsumexp(safe_sum_log_probs, axis=-1) class MixtureSameFamily(_MixtureBase):