Skip to content

Commit

Permalink
Fix edge case of Categorical due to simplex numerical issues (#1419)
Browse files Browse the repository at this point in the history
* Fix edge case due to simplex numerical issues

* Fix indexing issues

* fix wrong normalization of probs_to_logits
  • Loading branch information
fehiepsi authored Jun 13, 2022
1 parent b213022 commit 1b72dc2
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 1 deletion.
3 changes: 2 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ multipledispatch
nbsphinx==0.8.6
numpy
optax>=0.0.6
pyyaml
readthedocs-sphinx-search==0.1.0
sphinx==4.0.3
sphinx-gallery
sphinx_rtd_theme==0.5.2
tensorflow_probability>=0.15.0
tqdm
tqdm
2 changes: 2 additions & 0 deletions numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ def _categorical(key, p, shape):
# Ref: https://stackoverflow.com/a/34190035
shape = shape or p.shape[:-1]
s = jnp.cumsum(p, axis=-1)
# Normalize s to deal with numerical issues.
s = s[..., :-1] / s[..., -1:]
r = random.uniform(key, shape=shape + (1,))
# FIXME: replace this computation by using binary search as suggested in the above
# reference. A while_loop + vmap for a reshaped 2D array would be enough.
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
"graphviz",
"jaxns==1.0.0",
"optax>=0.0.6",
"pyyaml", # flax dependency
"tensorflow_probability>=0.15.0",
],
"examples": [
Expand Down

0 comments on commit 1b72dc2

Please sign in to comment.