Distributions Entropy Method #1696
Comments
Yeah, it would be great if we had the entropy method. So you can do
Changed the topic title since all (or most) distributions do not have an entropy method.
@stergiosba
I am actively working on it, yes. Let's collaborate if you want, @yayami3.
@stergiosba
I am working on the discrete ones now. I added entropy as a method and not a property, so it matches other Python modules like Distrax and TFP. I have done Categorical and Bernoulli. I double-checked with Distrax and TFP to make sure I get the same results as they do.

```python
@pytest.mark.parametrize(
    "jax_dist, sp_dist, params",
    [
        T(dist.BernoulliProbs, 0.2),
        T(dist.BernoulliProbs, np.array([0.2, 0.7])),
        T(dist.BernoulliLogits, np.array([-1.0, 3.0])),
    ],
)
```

Make sure you cover edge cases like exploding logits. For the Bernoulli distribution you used

```python
def entropy(self):
    return -xlogy(self.probs, self.probs) - xlog1py(1 - self.probs, -self.probs)
```

But I wanted to make the check explicit and did this, for example:

```python
def entropy(self):
    """Calculates the entropy of the Bernoulli distribution with probability p.

    H(p, q) = -q*log(q) - p*log(p), where q = 1 - p.
    With extra care for p = 0 and p = 1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    probs0 = _to_probs_bernoulli(-1.0 * self.logits)
    probs1 = self.probs
    log_probs0 = -jax.nn.softplus(self.logits)
    log_probs1 = -jax.nn.softplus(-1.0 * self.logits)
    # Could make this into a function if we need it elsewhere.
    qlogq = jnp.where(probs0 == 0.0, 0.0, probs0 * log_probs0)
    plogp = jnp.where(probs1 == 0.0, 0.0, probs1 * log_probs1)
    return -qlogq - plogp
```

I compared the performance of both solutions and it is the same. Also, I am not sure which style is better; maybe @fehiepsi can give his take on this.
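For readers following along, a lightweight way to sanity-check the values is below. This is illustrative only and not part of the PR: scipy stands in for the Distrax/TFP comparisons mentioned above, and `scipy.stats.bernoulli.entropy` returns nats, matching the `xlogy`/`xlog1py` formula.

```python
# Illustrative cross-check (assumes jax and scipy are installed).
import numpy as np
from scipy.stats import bernoulli
from jax.scipy.special import xlog1py, xlogy

probs = np.array([0.0, 0.2, 0.7, 1.0])
ours = -xlogy(probs, probs) - xlog1py(1 - probs, -probs)
# Both conventions treat 0*log(0) as 0, so the endpoints give entropy 0.
print(np.allclose(np.asarray(ours), bernoulli(probs).entropy(), atol=1e-6))
```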
I am also adding a `mode` property for the distributions.
@stergiosba
I think you can clip y and use `xlogy`. I remember that the grad needs to be computed correctly at the extreme points. I don't have a strong opinion on the style, though.
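For context on "grad computed correctly at the extreme points", the usual pattern is the "safe input + masked output" (double `jnp.where`) trick. A minimal sketch, assuming plain `jax`/`jax.numpy` and a standalone helper name that is not NumPyro's API:

```python
import jax
import jax.numpy as jnp


def bernoulli_entropy_from_probs(p):
    # Inner where: feed a harmless value into the logs at the endpoints so the
    # backward pass never sees log(0). Outer where: return the exact limit 0.0
    # at p = 0 and p = 1.
    boundary = (p == 0.0) | (p == 1.0)
    safe_p = jnp.where(boundary, 0.5, p)
    h = -safe_p * jnp.log(safe_p) - (1.0 - safe_p) * jnp.log1p(-safe_p)
    return jnp.where(boundary, 0.0, h)


p = jnp.array([0.0, 0.2, 1.0])
print(bernoulli_entropy_from_probs(p))                               # [0.  ~0.5  0.]
print(jax.grad(lambda x: bernoulli_entropy_from_probs(x).sum())(p))  # finite everywhere
```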
Great catch there, @fehiepsi. There is an issue with the gradients when using the `xlogy` version above. I was able to fix the nan by adding `lax.stop_gradient`:

```python
def entropy(self):
    """Calculates the entropy of the Bernoulli distribution with probability p.

    H(p, q) = -q*log(q) - p*log(p), where q = 1 - p.
    With extra care for p = 0 and p = 1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    probs = lax.stop_gradient(self.probs)
    return -xlogy(probs, probs) - xlog1py(1 - probs, -probs)
```

Same with clipping and using `xlogy`:

```python
def entropy(self):
    """Calculates the entropy of the Bernoulli distribution with probability p.

    H(p, q) = -q*log(q) - p*log(p), where q = 1 - p.
    With extra care for p = 0 and p = 1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    probs = lax.stop_gradient(self.probs)
    # Could make this into a function if we need it elsewhere.
    qlogq = jnp.where(probs == 0.0, 0.0, xlog1py(1.0 - probs, -probs))
    plogp = jnp.where(probs == 1.0, 0.0, xlogy(probs, probs))
    return -qlogq - plogp
```

Just for the record, the first entropy calculation I provided was based on Distrax's code and it had no problems with gradients "out of the box". But we can go forward with the
I think it is better to do:
Yeah, I was blind; I see the issue now. This is the clipping:

```python
def entropy(self, eps=1e-9):
    """Calculates the entropy of the Bernoulli distribution with probability p.

    H(p, q) = -q*log(q) - p*log(p), where q = 1 - p.
    With extra care for p = 0 and p = 1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    probs = jnp.clip(self.probs, eps, 1.0 - eps)
    return -xlogy(probs, probs) - xlog1py(1.0 - probs, -probs)
```

Clipping works for the gradients but inherently has errors. For example, we fail to pass the test case with a big negative logit. The following, although not the most beautiful, works for everything, so I vote to go with it.

```python
def entropy(self):
    """Calculates the entropy of the Bernoulli distribution with probability p.

    H(p, q) = -q*log(q) - p*log(p), where q = 1 - p.
    With extra care for p = 0 and p = 1.

    Returns:
        entropy: The entropy of the Bernoulli distribution.
    """
    q = _to_probs_bernoulli(-1.0 * self.logits)
    p = self.probs
    logq = -jax.nn.softplus(self.logits)
    logp = -jax.nn.softplus(-1.0 * self.logits)
    # Could make this into a function if we need it elsewhere.
    qlogq = jnp.where(q == 0.0, 0.0, q * logq)
    plogp = jnp.where(p == 0.0, 0.0, p * logp)
    return -qlogq - plogp
```
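To make the "big negative logit" failure concrete, here is a small standalone comparison. It is a sketch, assuming `jax` with default float32; the function names are local to this snippet and `jax.nn.sigmoid` stands in for `_to_probs_bernoulli`:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import xlog1py, xlogy


def entropy_clipped(logits, eps=1e-9):
    p = jnp.clip(jax.nn.sigmoid(logits), eps, 1.0 - eps)
    return -xlogy(p, p) - xlog1py(1.0 - p, -p)


def entropy_stable(logits):
    q, p = jax.nn.sigmoid(-logits), jax.nn.sigmoid(logits)
    logq, logp = -jax.nn.softplus(logits), -jax.nn.softplus(-logits)
    qlogq = jnp.where(q == 0.0, 0.0, q * logq)
    plogp = jnp.where(p == 0.0, 0.0, p * logp)
    return -qlogq - plogp


logits = -40.0
# True entropy is roughly -p*log(p) ~= 40 * sigmoid(-40) ~= 1.7e-16 nats; clipping
# p up to eps = 1e-9 inflates this by about eight orders of magnitude.
print(entropy_clipped(logits), entropy_stable(logits))
# The softplus/where version also keeps the gradient finite at extreme logits.
print(jax.grad(entropy_stable)(logits))
```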
Any ideas about what to return when logits are very negative in a Geometric distribution? As you can see from the code below, we need to divide by p, and when logits are very negative p = sigmoid(logit) = 0. TFP and PyTorch return nan in this case, and Distrax does not have a Geometric distribution.

```python
def entropy(self):
    """Calculates the entropy of the Geometric distribution with probability p.

    H(p, q) = [-q*log(q) - p*log(p)] / p, where q = 1 - p.
    With extra care for p = 0 and p = 1.

    Returns:
        entropy: The entropy of the Geometric distribution.
    """
    q = _to_probs_bernoulli(-1.0 * self.logits)
    p = self.probs
    logq = -jax.nn.softplus(self.logits)
    logp = -jax.nn.softplus(-1.0 * self.logits)
    # Could make this into a function if we need it elsewhere.
    qlogq = jnp.where(q == 0.0, 0.0, q * logq)
    plogp = jnp.where(p == 0.0, 0.0, p * logp)
    return (-qlogq - plogp) * 1.0 / p
```
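For reference, a quick standalone check of the closed form, plus the limiting behaviour that makes the very-negative-logits case awkward. This is illustrative only; scipy stands in for the TFP/PyTorch comparison, and `scipy.stats.geom` counts trials rather than failures, which does not change the entropy:

```python
import numpy as np
from scipy.stats import geom


def geometric_entropy(p):
    # Closed form: H = [-q*log(q) - p*log(p)] / p with q = 1 - p.
    q = 1.0 - p
    return (-q * np.log(q) - p * np.log(p)) / p


for p in (0.1, 0.3, 0.9):
    print(p, geometric_entropy(p), geom(p).entropy())

# As p -> 0 the exact entropy behaves like 1 - log(p), i.e. it diverges to +inf,
# which is why the naive division by p = sigmoid(logits) = 0 at very negative
# logits produces a 0/0-style nan rather than a meaningful finite number.
```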
You can divide implicitly (rather than directly), e.g. I think you can use (I have not checked yet)
Edit: ignore me, exp(-logits) can be very large.
OK, I will add some tests for the
I think #1787 also covers most of the discrete distributions.
@tillahoffmann Do you want to address the rest? There is a subtle numerical issue in Geometric.
@fehiepsi, do you recall what the numerical issue in Geometric is? Based on which of the
The formula is
Makes sense. See #1852 for a fix. Shall we use the list of items above to track which distributions users have requested an implementation for?
Sorry for the late response! Thanks for addressing the numerical issue. I guess we can close this issue and let users make new feature requests for new distributions.
Hello guys, I come from the TensorFlow Distributions world, was looking for a lightweight alternative, and was pleasantly surprised to see that Pyro is available for JAX via your amazing work.

I have implemented the PPO algorithm for some of my DRL problems, and inside the loss function the entropy of a Categorical distribution is needed. I saw that the `CategoricalLogits` class does not have an `entropy` method, contrary to those found in TFP and Distrax (from DeepMind). Is there a different, and possibly more streamlined, way to get it in numpyro without an external function that has the following form:

Is this a design choice? I have implemented an entropy method on the local numpyro I am using for my projects, but possibly others want this little feature added.

Anyway, let me know what you think.

Cheers!
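The snippet the author refers to did not survive in the thread, but a helper of the kind described (entropy of a categorical distribution computed from raw logits, e.g. for a PPO entropy bonus) might look like the sketch below; the function name is illustrative and not NumPyro's API:

```python
import jax
import jax.numpy as jnp


def categorical_entropy(logits, axis=-1):
    # log_softmax is numerically stable even for large or very negative logits.
    log_probs = jax.nn.log_softmax(logits, axis=axis)
    probs = jnp.exp(log_probs)
    return -jnp.sum(probs * log_probs, axis=axis)


print(categorical_entropy(jnp.array([0.0, 1.0, 2.0])))  # ~0.83 nats
```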