Grads w.r.t. weights of `MixtureGeneral` Distribution are giving `nan`s
#1870
Labels: enhancement
Hi,
We have created some models in which we estimate the weights of the `MixtureGeneral` distribution. However, when computing the gradient with respect to this argument, we are encountering `nan` values. We enabled `jax.config.update("jax_debug_nans", True)` to diagnose the issue, and it pointed to the following line:

numpyro/numpyro/distributions/mixtures.py
Line 152 in 8e9313f
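
For context, here is a minimal sketch of how this kind of diagnosis works, assuming the flag in question is JAX's `jax_debug_nans` option. The function `masked_log` is a hypothetical toy and not numpyro code. It only illustrates how a value masked by a single `where` can still produce a `nan` gradient, and how the debug flag turns that silent `nan` into an error.

```python
import jax
import jax.numpy as jnp


def masked_log(x):
    # Hypothetical toy function: log(x) on x > 0, masked to 0.0 elsewhere.
    # The forward value at x = 0 is fine, but the backward pass still
    # differentiates log at 0 and multiplies the result by a zero cotangent.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)


# Default behaviour: the gradient silently comes back as nan.
silent_grad = jax.grad(masked_log)(0.0)
print("gradient without the debug flag:", silent_grad)

# With NaN debugging enabled, JAX raises as soon as a nan is produced.
jax.config.update("jax_debug_nans", True)
raised = False
try:
    jax.grad(masked_log)(0.0)
except FloatingPointError as err:
    raised = True
    print("caught:", err)
finally:
    # Restore the default so later code is not affected.
    jax.config.update("jax_debug_nans", False)
```

The error message names the primitive that produced the `nan`, which is how a report like this one can be traced back to a specific source line.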
I suspect that after the implementation of #1791, extra care is needed to handle `inf` and `nan` values, possibly by using a double `where` for a safe `logsumexp`.

Important: This is an urgent issue, so a prompt response would be greatly appreciated.
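
To make the double-`where` suggestion concrete, here is a rough sketch of what a guarded `logsumexp` could look like. This is not the numpyro implementation or a proposed patch. The names `safe_logsumexp` and `mixture_log_prob` are hypothetical, and the sketch assumes the problematic case is a sample whose component log-probabilities are all `-inf`. Handling `nan` inputs would need an additional mask.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def safe_logsumexp(a, axis=-1):
    """Hypothetical helper: logsumexp guarded by a double where."""
    # Slices along `axis` in which every entry is -inf.
    all_neg_inf = jnp.all(jnp.isneginf(a), axis=axis, keepdims=True)
    # Inner where: give logsumexp harmless zeros for those slices, so both
    # its value and its gradient stay finite there.
    safe_a = jnp.where(all_neg_inf, 0.0, a)
    out = logsumexp(safe_a, axis=axis)
    # Outer where: restore -inf as the value for the masked slices.
    return jnp.where(jnp.squeeze(all_neg_inf, axis=axis), -jnp.inf, out)


def mixture_log_prob(logits, component_log_probs, lse):
    # Hypothetical stand-in for a mixture log-density:
    # log sum_k w_k * p_k(x), computed in log space with reducer `lse`.
    log_weights = jax.nn.log_softmax(logits)
    return lse(log_weights + component_log_probs, axis=-1)


# Two samples, two components; the second sample has -inf under every component.
component_log_probs = jnp.array([[-1.0, -2.0], [-jnp.inf, -jnp.inf]])
logits = jnp.array([0.3, -0.3])

# Gradient w.r.t. the mixture logits with the plain reducer (may contain nan).
plain_grad = jax.grad(
    lambda l: mixture_log_prob(l, component_log_probs, logsumexp).sum()
)(logits)

# Same gradient with the guarded reducer.
safe_grad = jax.grad(
    lambda l: mixture_log_prob(l, component_log_probs, safe_logsumexp).sum()
)(logits)
safe_lp = mixture_log_prob(logits, component_log_probs, safe_logsumexp)

print("plain logsumexp grad:", plain_grad)
print("safe logsumexp grad: ", safe_grad)
print("safe log-probs:      ", safe_lp)
```

The inner `where` keeps the differentiated computation away from the degenerate input, and the outer `where` restores the intended `-inf` value, so the masked sample contributes a zero gradient instead of `0 * inf`.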