From 0dabbdd07f53769e068fdc181b607c1952c4fe86 Mon Sep 17 00:00:00 2001 From: David Ziegler <25408738+InfinityMod@users.noreply.github.com> Date: Mon, 9 Sep 2024 18:05:52 -0400 Subject: [PATCH 1/5] Comments added why the binominal_dispatch function will run into an infinite loop --- numpyro/distributions/util.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index 7b0e26325..26ea2ce87 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -136,6 +136,7 @@ def _binom_inv_body_fn(val): i, key, geom_acc = val key, key_u = random.split(key) u = random.uniform(key_u) + #FIXME: we run here against -inf as log1_p equals zero log(0+1) for p = 0 geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1 geom_acc = geom_acc + geom return i + 1, key, geom_acc @@ -146,6 +147,7 @@ def _binom_inv_cond_fn(val): # this cond_exclude_large_mu is unnecessary for correctness but will # still improve performance. cond_exclude_large_mu = p * n < _binomial_mu_thresh + #FIXME: for p equals 0 the while loop will never end, as -inf is not catched return cond_exclude_large_mu & (geom_acc <= n) log1_p = jnp.log1p(-p) @@ -171,6 +173,7 @@ def dispatch(key, p, n): cond0 = jnp.isfinite(p) & (n > 0) & (p > 0) return lax.cond( cond0 & (p < 1), + # FIXME: at this point it does allow also zero values for p to be executed inside the dispatch function, even if their results are filtered out later., (key, p, n), lambda x: dispatch(*x), (), From baf4d35108914732a907d2922224a5c3f69b42a2 Mon Sep 17 00:00:00 2001 From: David Ziegler <25408738+InfinityMod@users.noreply.github.com> Date: Mon, 9 Sep 2024 18:07:41 -0400 Subject: [PATCH 2/5] Fix for the infinite loop problem of binominal_dispatch. --- numpyro/distributions/util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index 26ea2ce87..760a880f0 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -136,7 +136,6 @@ def _binom_inv_body_fn(val): i, key, geom_acc = val key, key_u = random.split(key) u = random.uniform(key_u) - #FIXME: we run here against -inf as log1_p equals zero log(0+1) for p = 0 geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1 geom_acc = geom_acc + geom return i + 1, key, geom_acc @@ -147,7 +146,6 @@ def _binom_inv_cond_fn(val): # this cond_exclude_large_mu is unnecessary for correctness but will # still improve performance. cond_exclude_large_mu = p * n < _binomial_mu_thresh - #FIXME: for p equals 0 the while loop will never end, as -inf is not catched return cond_exclude_large_mu & (geom_acc <= n) log1_p = jnp.log1p(-p) @@ -158,6 +156,8 @@ def _binom_inv_cond_fn(val): def _binomial_dispatch(key, p, n): def dispatch(key, p, n): is_le_mid = p <= 0.5 + #Make sure p=0 is never taken into account as a fix for possible zeros in p. + p = jnp.sum(jnp.stack((p, jnp.ones(p.shape)*0.001)), where = (p == 0)) pq = jnp.where(is_le_mid, p, 1 - p) mu = n * pq k = lax.cond( @@ -173,7 +173,6 @@ def dispatch(key, p, n): cond0 = jnp.isfinite(p) & (n > 0) & (p > 0) return lax.cond( cond0 & (p < 1), - # FIXME: at this point it does allow also zero values for p to be executed inside the dispatch function, even if their results are filtered out later., (key, p, n), lambda x: dispatch(*x), (), From 1ce6995ab65e438a23c8aa3daa068cc4152e654c Mon Sep 17 00:00:00 2001 From: David Ziegler <25408738+InfinityMod@users.noreply.github.com> Date: Tue, 10 Sep 2024 01:07:17 -0400 Subject: [PATCH 3/5] Update of the fix to a more concise version. --- numpyro/distributions/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index 760a880f0..39bcbedfd 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -157,7 +157,7 @@ def _binomial_dispatch(key, p, n): def dispatch(key, p, n): is_le_mid = p <= 0.5 #Make sure p=0 is never taken into account as a fix for possible zeros in p. - p = jnp.sum(jnp.stack((p, jnp.ones(p.shape)*0.001)), where = (p == 0)) + p = jnp.clip(p, jnp.finfo(jnp.float32).tiny) pq = jnp.where(is_le_mid, p, 1 - p) mu = n * pq k = lax.cond( From ddc8a7c3ba062de9f726e9827a05ea3bd7870a8b Mon Sep 17 00:00:00 2001 From: David Ziegler <25408738+InfinityMod@users.noreply.github.com> Date: Tue, 10 Sep 2024 11:03:50 -0400 Subject: [PATCH 4/5] Linting fix --- numpyro/distributions/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index 39bcbedfd..aee340910 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -156,7 +156,7 @@ def _binom_inv_cond_fn(val): def _binomial_dispatch(key, p, n): def dispatch(key, p, n): is_le_mid = p <= 0.5 - #Make sure p=0 is never taken into account as a fix for possible zeros in p. + # Make sure p=0 is never taken into account as a fix for possible zeros in p. p = jnp.clip(p, jnp.finfo(jnp.float32).tiny) pq = jnp.where(is_le_mid, p, 1 - p) mu = n * pq From ccb4c80f7e90fb01025b9d1d9dfb6252a2fbd516 Mon Sep 17 00:00:00 2001 From: David Ziegler <25408738+InfinityMod@users.noreply.github.com> Date: Tue, 10 Sep 2024 11:32:59 -0400 Subject: [PATCH 5/5] Changed to the proposed solution, to correct log1_p value correctly --- numpyro/distributions/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index aee340910..d32d2da25 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -149,6 +149,8 @@ def _binom_inv_cond_fn(val): return cond_exclude_large_mu & (geom_acc <= n) log1_p = jnp.log1p(-p) + # Make sure p=0 is never taken into account as a fix for possible zeros in p. + log1_p = jnp.where(log1_p == 0, -jnp.finfo(log1_p.dtype).tiny, log1_p) ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.0)) return ret[0] @@ -156,8 +158,6 @@ def _binom_inv_cond_fn(val): def _binomial_dispatch(key, p, n): def dispatch(key, p, n): is_le_mid = p <= 0.5 - # Make sure p=0 is never taken into account as a fix for possible zeros in p. - p = jnp.clip(p, jnp.finfo(jnp.float32).tiny) pq = jnp.where(is_le_mid, p, 1 - p) mu = n * pq k = lax.cond(