From 1b72dc26c0cecd9d0cc9d1c3595d80f5063bfd34 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 13 Jun 2022 08:41:54 +0700 Subject: [PATCH] Fix edge case of Categorical due to simplex numerical issues (#1419) * Fix edge case due to simplex numerical issues * Fix indexing issues * fix wrong normalization of probs_to_logits --- docs/requirements.txt | 3 ++- numpyro/distributions/util.py | 2 ++ setup.py | 1 + 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 1cf616275..c82228fba 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -9,9 +9,10 @@ multipledispatch nbsphinx==0.8.6 numpy optax>=0.0.6 +pyyaml readthedocs-sphinx-search==0.1.0 sphinx==4.0.3 sphinx-gallery sphinx_rtd_theme==0.5.2 tensorflow_probability>=0.15.0 -tqdm \ No newline at end of file +tqdm diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index e418e9ac6..08e537301 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -185,6 +185,8 @@ def _categorical(key, p, shape): # Ref: https://stackoverflow.com/a/34190035 shape = shape or p.shape[:-1] s = jnp.cumsum(p, axis=-1) + # Normalize s to deal with numerical issues. + s = s[..., :-1] / s[..., -1:] r = random.uniform(key, shape=shape + (1,)) # FIXME: replace this computation by using binary search as suggested in the above # reference. A while_loop + vmap for a reshaped 2D array would be enough. diff --git a/setup.py b/setup.py index ad92e8bc1..f9ccf0901 100644 --- a/setup.py +++ b/setup.py @@ -63,6 +63,7 @@ "graphviz", "jaxns==1.0.0", "optax>=0.0.6", + "pyyaml", # flax dependency "tensorflow_probability>=0.15.0", ], "examples": [