
Can JAX handle the derivatives of expectations in statistics? If yes, how does it work? #4800

Answered by cooijmanstim
shixinxing asked this question in Q&A

In general, no: you will need REINFORCE, also known as the score-function estimator; see e.g. http://blog.shakirm.com/2015/11/machine-learning-trick-of-the-day-5-log-derivative-trick/. For location-scale distributions (which include the normal distribution), you can instead use a straightforward reparameterization: sample from a standard normal, then scale by the desired standard deviation and shift by the desired mean, i.e. x = mu + sigma * eps with eps ~ N(0, 1). See http://blog.shakirm.com/2015/10/machine-learning-trick-of-the-day-4-reparameterisation-tricks/.
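A minimal sketch of both estimators in JAX, under a few assumptions that are not from the thread: the integrand f(x) = x**2 is purely illustrative (chosen because E[x^2] = mu^2 + sigma^2 gives the exact gradient (2*mu, 2*sigma^2) w.r.t. (mu, log_sigma) to compare against), sigma is parameterized through log_sigma, and the function names are hypothetical. The reparameterized version lets `jax.grad` flow through the samples themselves; the score-function version blocks gradients through the samples with `jax.lax.stop_gradient` and differentiates a surrogate whose gradient equals the Monte Carlo average of f(x) * grad log p(x; theta).

```python
import jax
import jax.numpy as jnp


def f(x):
    # Illustrative integrand (assumption): E[x^2] under N(mu, sigma^2) = mu^2 + sigma^2.
    return x ** 2


def reparam_objective(params, key, n=100_000):
    # Reparameterization trick: x = mu + sigma * eps with eps ~ N(0, 1),
    # so gradients w.r.t. (mu, log_sigma) flow through the samples.
    mu, log_sigma = params
    sigma = jnp.exp(log_sigma)
    eps = jax.random.normal(key, (n,))
    x = mu + sigma * eps
    return jnp.mean(f(x))


def log_normal_pdf(x, mu, sigma):
    # Log-density of N(mu, sigma^2) evaluated at x.
    return -0.5 * ((x - mu) / sigma) ** 2 - jnp.log(sigma) - 0.5 * jnp.log(2 * jnp.pi)


def score_function_surrogate(params, key, n=100_000):
    # REINFORCE / score-function estimator: the samples are treated as constants,
    # and the gradient of this surrogate is mean(f(x) * grad log p(x; params)).
    mu, log_sigma = params
    sigma = jnp.exp(log_sigma)
    x = jax.lax.stop_gradient(mu + sigma * jax.random.normal(key, (n,)))
    return jnp.mean(jax.lax.stop_gradient(f(x)) * log_normal_pdf(x, mu, sigma))


if __name__ == "__main__":
    params = jnp.array([1.5, jnp.log(0.5)])  # mu = 1.5, sigma = 0.5
    key = jax.random.PRNGKey(0)
    print("reparameterization:", jax.grad(reparam_objective)(params, key))
    print("score function:    ", jax.grad(score_function_surrogate)(params, key))
    print("exact:             ", jnp.array([2 * 1.5, 2 * 0.5 ** 2]))
```

Both estimates should land near the exact value (3.0, 0.5); the score-function estimate is typically noisier for the same sample count, which is one common reason to prefer reparameterization whenever the distribution allows it.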

Replies: 1 comment 4 replies

4 replies
@jeremiecoullon

@cooijmanstim

@shixinxing

@cooijmanstim

Answer selected by shixinxing
Category
Q&A
Labels
None yet
3 participants