
Avoid nan acceptance probability in SA #740

Merged: 3 commits into pyro-ppl:master on Sep 22, 2020

Conversation

fehiepsi (Member)

Currently, when `log_weights_[-1] = inf`,

`log_weights_[-1] - logsumexp(log_weights_)`

will produce a nan acceptance probability. This PR clips `log_weights` at the `finfo.max` value to prevent that issue.
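
A minimal sketch of the failure mode, using hypothetical toy weights:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# toy log weights whose last entry is +inf
log_weights_ = jnp.array([0.0, 1.0, jnp.inf])

# logsumexp of an array containing +inf is itself +inf, so the
# log acceptance probability becomes inf - inf = nan
print(log_weights_[-1] - logsumexp(log_weights_))  # nan
```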

cc @awakhloo

fehiepsi (Member, Author)

@martinjankowiak Is it better to clip the nonsensical value inf to finfo.max or to -inf? The former means that we put essentially all the weight on that inf sample, while the latter means that we put zero weight on that value. In the last commit, I used finfo.max, but using -inf makes more sense to me.

fehiepsi (Member, Author)

Anyway, after thinking about this some more, I think it is safer to use -inf log weights for invalid values (nan and +inf).
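
A small sketch contrasting the two choices on hypothetical toy weights, normalizing with `jax.nn.softmax` (equivalent to `exp(log_weights_ - logsumexp(log_weights_))`):

```python
import jax
import jax.numpy as jnp

# toy log weights with one invalid +inf entry
log_weights_ = jnp.array([0.0, 1.0, jnp.inf])

# option 1: clip at finfo.max -- the invalid sample soaks up
# essentially all of the normalized weight
clipped = jnp.minimum(log_weights_, jnp.finfo(log_weights_.dtype).max)
print(jax.nn.softmax(clipped))  # [0. 0. 1.]

# option 2: mask with -inf -- the invalid sample gets zero weight
masked = jnp.where(jnp.isfinite(log_weights_), log_weights_, -jnp.inf)
print(jax.nn.softmax(masked))   # [0.269 0.731 0.   ]
```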

```diff
@@ -172,7 +172,8 @@ def sample_kernel(sa_state, model_args=(), model_kwargs=None):
         log_weights_ = dist.MultivariateNormal(locs_, scale_tril=scales_).log_prob(zs_) + pes_
     else:
         log_weights_ = dist.Normal(locs_, scales_).log_prob(zs_).sum(-1) + pes_
-    log_weights_ = jnp.where(jnp.isnan(log_weights_), -jnp.inf, log_weights_)
+    # mask invalid values (nan, +inf) by -inf
+    log_weights_ = jnp.where(jnp.isfinite(log_weights_), log_weights_, -jnp.inf)
```
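
For reference, `jnp.isfinite` is False for nan and ±inf alike, so the single `where` call masks nan and +inf in one pass (the -inf case is already -inf), generalizing the previous isnan-only check:

```python
import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, jnp.inf, -jnp.inf])
print(jnp.isfinite(x))                          # [ True False False False]
print(jnp.where(jnp.isfinite(x), x, -jnp.inf))  # [  1. -inf -inf -inf]
```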
Member

LGTM, but when would this ever be +inf?

fehiepsi (Member, Author)

It comes from a diffusion drift model in our forum (where the model uses lax.scan). I guess something like 1/0 is involved, or precision loss during the loop... (the model is pretty complicated and involves sequential matrix multiplications...).

@neerajprad merged commit 67c7cd7 into pyro-ppl:master on Sep 22, 2020.