Can JAX handle the derivatives of expectation in statistics? If Yes, how does it work? #4800
-
For example, I have a function F(x), where x follows some distribution p(x, theta) with parameters theta (e.g. the mean or covariance of a normal distribution). I compute the expectation of F(x) under p(x, theta) by drawing samples with a random function and taking their mean, which yields an expression whose only unknown parameters are theta. Can I just use grad() to obtain the derivative with respect to theta? Does that make sense?
-
In general, no: you will need REINFORCE, a.k.a. the score-function estimator; see e.g. http://blog.shakirm.com/2015/11/machine-learning-trick-of-the-day-5-log-derivative-trick/. For location-scale distributions (which include the normal distribution) you can use a pretty straightforward reparameterization: sample from a standard normal, then scale and shift by the desired standard deviation and mean. See http://blog.shakirm.com/2015/10/machine-learning-trick-of-the-day-4-reparameterisation-tricks/.
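To make the two options concrete, here is a minimal sketch of both estimators for a normal distribution. The function F, the choice of F(x) = x^2, and the parameter values are illustrative assumptions, not part of the original thread; the analytic gradient of E[x^2] = mean^2 + std^2 gives a reference to sanity-check against.

```python
import jax
import jax.numpy as jnp

def F(x):
    return x ** 2  # example integrand (an assumption for illustration)

def log_prob(params, x):
    """Log-density of N(mean, std^2)."""
    mean, std = params
    return -0.5 * ((x - mean) / std) ** 2 - jnp.log(std) - 0.5 * jnp.log(2.0 * jnp.pi)

# Reparameterization trick: write x = mean + std * eps with eps ~ N(0, 1),
# so the Monte Carlo estimate of E[F(x)] is an ordinary differentiable
# function of the parameters, and jax.grad works directly.
def expected_F_reparam(params, key, n_samples=10_000):
    mean, std = params
    eps = jax.random.normal(key, (n_samples,))
    return jnp.mean(F(mean + std * eps))

# Score-function (REINFORCE) estimator: grad E[F(x)] = E[F(x) * grad log p(x; theta)].
# The samples are treated as fixed data; only log_prob is differentiated.
def reinforce_grad(params, key, n_samples=200_000):
    mean, std = params
    x = mean + std * jax.random.normal(key, (n_samples,))
    scores = jax.vmap(jax.grad(log_prob), in_axes=(None, 0))(params, x)  # shape (n, 2)
    return jnp.mean(F(x)[:, None] * scores, axis=0)

params = jnp.array([0.5, 1.0])  # (mean, std)
key = jax.random.PRNGKey(0)
g_reparam = jax.grad(expected_F_reparam)(params, key)
g_reinforce = reinforce_grad(params, key)
# Both approximate the analytic gradient (2*mean, 2*std) = (1.0, 2.0),
# up to Monte Carlo noise; REINFORCE is typically much noisier.
```

Note the trade-off: the reparameterized estimator usually has far lower variance, but it only applies when sampling can be written as a differentiable transform of parameter-free noise; REINFORCE needs only the log-density and so works for discrete or otherwise non-reparameterizable distributions.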