Grads w.r.t. weights of MixtureGeneral Distribution are giving nans #1870

Closed
Qazalbash opened this issue Sep 27, 2024 · 3 comments · Fixed by #1874

@Qazalbash
Contributor

Hi,

We have built some models that estimate the weights of a MixtureGeneral distribution. When we compute the gradient with respect to the weights, we get nan values. We enabled jax.config.update("jax_debug_nans", True) to diagnose the issue, and it pointed to the following line:

return jax.nn.logsumexp(sum_log_probs, axis=-1)

I suspect that after the implementation of #1791, extra care is needed to handle inf and nan values, possibly by using a double where for a safe logsumexp.

Important

This is an urgent issue, so a prompt response would be greatly appreciated.

@fehiepsi added the enhancement (New feature or request) label on Sep 28, 2024
@fehiepsi
Member

You can add jax.debug.print(...) to inspect the component log probs. If all of the component log probs are -inf, the gradient will be nan.
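As a minimal sketch of that suggestion, the snippet below prints the component log probs from inside a jitted function and then takes the gradient. The function name mixture_log_prob and the all -inf input are hypothetical stand-ins for illustration; this is not NumPyro's MixtureGeneral code.

```python
import jax
import jax.numpy as jnp


@jax.jit
def mixture_log_prob(component_log_probs):
    # jax.debug.print runs at execution time, so it also works under jit and grad.
    jax.debug.print("component log probs: {x}", x=component_log_probs)
    return jax.nn.logsumexp(component_log_probs, axis=-1)


# Hypothetical input: every component assigns zero probability to the sample.
all_neg_inf = jnp.full((3,), -jnp.inf)

value = mixture_log_prob(all_neg_inf)            # log of a zero total probability
grad = jax.grad(mixture_log_prob)(all_neg_inf)   # inspect this for nan entries
```

If the printed vector is entirely -inf for some sample, that sample is the likely source of the nan gradient.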


@Qazalbash
Contributor Author

I am able to get gradients of the weights even when some component log probs are -jnp.inf, by changing

@validate_sample
def log_prob(self, value, intermediates=None):
    del intermediates
    sum_log_probs = self.component_log_probs(value)
    return jax.nn.logsumexp(sum_log_probs, axis=-1)

to

@validate_sample
def log_prob(self, value, intermediates=None):
    del intermediates
    sum_log_probs = self.component_log_probs(value)
    safe_sum_log_probs = jnp.where(
        jnp.isneginf(sum_log_probs), -jnp.inf, sum_log_probs
    )
    return jax.nn.logsumexp(safe_sum_log_probs, axis=-1)
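The jnp.where above leaves the forward values unchanged, but it changes the backward pass: at positions where the condition is true, the cotangent coming back from logsumexp is replaced by zero instead of being propagated. The toy comparison below illustrates this on plain arrays. The function names plain and masked are made up for the illustration, and this is only a sketch of the idea in this comment, not necessarily what the fix in #1874 does.

```python
import jax
import jax.numpy as jnp


def plain(x):
    # Direct logsumexp, as in the original log_prob.
    return jax.nn.logsumexp(x, axis=-1)


def masked(x):
    # Same forward values, but -inf entries are routed through a constant,
    # so no gradient flows back to them.
    safe = jnp.where(jnp.isneginf(x), -jnp.inf, x)
    return jax.nn.logsumexp(safe, axis=-1)


all_neg_inf = jnp.full((3,), -jnp.inf)   # every component has zero probability
mixed = jnp.array([-jnp.inf, 0.0])       # one component has zero probability

grad_plain_all = jax.grad(plain)(all_neg_inf)    # compare against the masked result
grad_masked_all = jax.grad(masked)(all_neg_inf)  # zeros at the masked positions
grad_masked_mixed = jax.grad(masked)(mixed)      # finite entries keep their gradient
```

Note that a zero gradient at the -inf positions is a convention chosen by the mask, so whether it is the right behavior for a given model is a separate question.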
