
Sampling from posterior predictive with Categorical and AutoDiagonalNormal occasionally samples outside support #1402

Closed
lumip opened this issue May 2, 2022 · 5 comments · Fixed by #1419
Labels
bug Something isn't working

Comments

@lumip
Contributor

lumip commented May 2, 2022

We encountered a subtle problem when sampling from the posterior predictive distribution for a model containing a Categorical and using the AutoDiagonalNormal guide: in some rare cases, the sample drawn from the Categorical would be one larger than the maximum category index. E.g., a model like

def model(num_categories):
    ps = sample("ps", Dirichlet(jnp.ones(num_categories)))
    sample("x", Categorical(probs=ps))

may sample a value of num_categories even though the support of the categorical distribution is {0, ..., num_categories - 1}.

This seems to arise as follows: we observed samples from the AutoDiagonalNormal guide for which the probability vector ps summed to just shy of one. In rare cases, the uniform (0, 1) sample used in numpyro.distributions.util._categorical then falls above this total sum, causing the function to return the largest category index plus one.
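To make the mechanism concrete, here is a small self-contained sketch (illustrative only, not NumPyro's actual _categorical code, and with the probability deficit exaggerated for clarity) of how a cumulative-sum sampler can return an out-of-support index:

import jax.numpy as jnp

probs = jnp.array([0.3, 0.3, 0.399])  # sums to 0.999 instead of 1
s = jnp.cumsum(probs)                 # [0.3, 0.6, 0.999]
u = 0.9995                            # a uniform(0, 1) draw that lands above the total mass
index = int(jnp.sum(s < u))           # 3, but the support is {0, 1, 2}
print(index)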

Our current quick-and-dirty workaround is to simply clamp the sampled values to the support, but there are surely better ways of handling this; I haven't given it much thought yet. Naively, explicitly renormalising the probability vector at some point would seem to be the way to go, but I'm not sure that would actually be robust against the numerical issues that lead to this in the first place.

@fehiepsi
Member

fehiepsi commented May 2, 2022

the uniform (0, 1) sample used in numpyro.distributions.util._categorical then falls above this total sum, causing the function to return the largest category index plus one.

Good point! I think one way is to remove the last value (which should be 1) there: s = s[..., :-1]. Do you want to submit the fix? :)

fehiepsi added the bug label on May 2, 2022
@lumip
Contributor Author

lumip commented May 2, 2022

I'll give that a go soon.

Edit: Thinking about it a bit more, the solution you suggest would introduce a slight bias towards the last category, which would completely absorb the deficit between the probabilities' sum and one. Given that, I would prefer renormalising s, i.e. s = s / s[..., -1], which keeps the relative proportions of the probabilities intact - that seems to be the solution with the smaller surprise factor to me (and, contrary to my earlier rambling about potential issues, it would be perfectly fine as a mitigation). Thoughts?
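For concreteness, a rough sketch of how the renormalisation could look in a cumulative-sum sampler (illustrative only, not the actual _categorical implementation; slicing with s[..., -1:] keeps the last axis so the division also broadcasts over batched probabilities):

import jax.numpy as jnp
from jax import random

def sample_categorical(key, probs, shape=()):
    s = jnp.cumsum(probs, axis=-1)
    s = s / s[..., -1:]  # renormalise so the last cumulative value is exactly 1
    u = random.uniform(key, shape=shape + (1,))
    return jnp.sum(s < u, axis=-1)  # always in {0, ..., k-1} because u < 1 == s[..., -1]

print(sample_categorical(random.PRNGKey(0), jnp.array([0.3, 0.3, 0.399]), shape=(5,)))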

@ordabayevy
Member

This is also the reason for the tests failing in funsor (pyro-ppl/funsor#594). Funsor assumes that Categorical distribution probs are renormalized by the backend distribution, which PyTorch's Categorical distribution does but NumPyro apparently does not.

@fehiepsi
Member

fehiepsi commented May 6, 2022

prefer renormalising s, i.e. s = s / s[..., -1]

Yeah, please go with it. Maybe we can also do s = s[..., :-1] / s[..., -1] if the division is not perfect (I don't know)

@fehiepsi
Member

fehiepsi commented May 6, 2022

@ordabayevy Thanks for the pointer! We typically don't want to modify the inputs of the distributions in numpyro, so I hope that @lumip's solution will resolve the issue. Maybe we can also fix the issue here for completeness:

def _to_logits_multinom(probs):
    minval = jnp.finfo(jnp.result_type(probs)).min
    # normalise in log space so that the implied probabilities sum to one
    logits = jnp.log(probs)
    return jnp.clip(logits - logsumexp(logits, axis=-1, keepdims=True), a_min=minval)
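As a quick sanity check (assuming jnp is jax.numpy and logsumexp comes from jax.scipy.special, as used in numpyro.distributions.util), the renormalised logits should exponentiate back to a proper probability vector even when the input probs do not quite sum to one:

probs = jnp.array([0.3, 0.3, 0.399])               # sums to 0.999
print(jnp.exp(_to_logits_multinom(probs)).sum())   # ~1.0 after the normalisation above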
