Skip to content

Commit

Permalink
Resolve numerical instability in entropy of GeometricLogits. (pyro-…
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann authored Aug 26, 2024
1 parent b19a83d commit e0d450b
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,8 +936,11 @@ def variance(self):
return (1.0 / self.probs - 1.0) / self.probs

def entropy(self):
nexp = jnp.exp(-self.logits)
return nexp * self.logits + jnp.log1p(nexp) * (1 + nexp)
logq = -jax.nn.softplus(self.logits)
logp = -jax.nn.softplus(-self.logits)
p = jax.scipy.special.expit(self.logits)
p_clip = jnp.clip(p, min=jnp.finfo(p).tiny)
return -(1 - p) * logq / p_clip - logp


def Geometric(probs=None, logits=None, *, validate_args=None):
Expand Down

0 comments on commit e0d450b

Please sign in to comment.