Sampling with HMC can hang if AD is bugged #2389
Comments
Minor comment (of course doesn't apply to the general problem here): the problem in this specific example is the line `V ~ truncated(Normal(0, σ), 0, Inf)`. One should always use `V ~ truncated(Normal(0, σ); lower=0)` or `V ~ truncated(Normal(0, σ), 0, nothing)`. Since the latter is less descriptive, I'd only use it if keyword arguments are problematic (e.g. in broadcasting).
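(A minimal sketch of the distinction above, assuming Distributions.jl's keyword form of `truncated`; the remark about where AD picks up NaNs is an assumption based on the gradient example further down.)

```julia
using Distributions

d = Normal(0.0, 1.0)

# Explicit infinite upper bound: Inf is stored in the truncated distribution
# and enters its log-normalisation term, which is presumably where the dual
# numbers end up producing NaNs (see the gradient example below).
t_inf = truncated(d, 0, Inf)

# One-sided truncation via keyword: no upper bound is stored at all.
t_lower = truncated(d; lower=0)

# Positional form with `nothing` as the upper bound; equivalent to the
# keyword form, just less descriptive.
t_nothing = truncated(d, 0, nothing)

logpdf(t_inf, 1.0) ≈ logpdf(t_lower, 1.0) ≈ logpdf(t_nothing, 1.0)  # true: same density
```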
@devmotion Indeed, I found that out too when experimenting 😄 While you're here: I hadn't gotten round to reporting the actual NaN gradients:

```julia
using DynamicPPL: @model, LogDensityFunction
using Distributions
using LogDensityProblems: logdensity_and_gradient
using LogDensityProblemsAD: ADgradient

@model function model1()
    σ ~ InverseGamma(2, 3)
    V ~ truncated(Normal(0, σ), 0, Inf)
end

import ForwardDiff
ℓ = ADgradient(:ForwardDiff, LogDensityFunction(model1()))
logdensity_and_gradient(ℓ, [1.0, 2.0])
# --> (-3.0285667753085077, [NaN, NaN])
```

Would you consider this a bug in the AD backends (i.e. we can attempt to minimise the issue and report upstream), or improper usage of Distributions (i.e. maybe the user should be told that they shouldn't do that)?
Mainly improper usage of Distributions, I would say. Introducing […] In the ForwardDiff case, sometimes you can actually get around this problem by switching to NaN-safe mode.
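(For reference, a hedged sketch of enabling that mode, assuming a ForwardDiff version where NaN-safe mode is exposed as a Preferences.jl setting; older releases required editing the `NANSAFE_MODE_ENABLED` constant in the source instead.)

```julia
# Sketch: turn on ForwardDiff's NaN-safe mode via Preferences.jl.
using Preferences, ForwardDiff

set_preferences!(ForwardDiff, "nansafe_mode" => true)
# The preference is read when ForwardDiff is (pre)compiled, so restart the
# Julia session for it to take effect.
```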
Minimal working example
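(The collapsed MWE from the issue template is not captured here; the snippet below is a hypothetical reconstruction based on the model discussed in the comments above, not the original code.)

```julia
using Turing

# Same model as in the comment thread above.
@model function model1()
    σ ~ InverseGamma(2, 3)
    V ~ truncated(Normal(0, σ), 0, Inf)
end

# Sampling with a gradient-based sampler never returns: every proposed
# gradient contains NaNs and Turing keeps retrying indefinitely.
chain = sample(model1(), NUTS(), 100)
```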
Description
Hangs.
This isn't because of a bug in Turing; it's actually a bug in ForwardDiff, which returns NaNs when calculating the gradient. (This is true of all the other currently supported AD backends too; see https://discourse.julialang.org/t/help-cant-get-turing-to-work-on-a-simple-model/122107/2)
Because the gradient is always returned with NaNs, `isfinite()` on it returns false, and this block goes into an infinite loop: Turing.jl/src/mcmc/hmc.jl, lines 176 to 194 at commit 397d1a7.
It would probably make sense to just error after a sufficiently large number of attempts (not sure about the exact number, but 1000 seems reasonable?). Alternatively, or additionally, we could check for NaNs and error directly if logp or its gradient contains them.
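(A rough sketch of what that bounded-retry behaviour could look like; the function name, signature, and error message below are hypothetical and not the actual code in hmc.jl.)

```julia
# Hypothetical sketch of the suggested fix: cap the number of re-draws of the
# initial parameters and fail loudly instead of looping forever on NaNs.
function find_initial_params(logdensity_and_gradient, rand_init; max_attempts = 1000)
    for _ in 1:max_attempts
        θ = rand_init()
        logp, grad = logdensity_and_gradient(θ)
        # Accept only if the log density and every gradient component are finite;
        # NaNs fail `isfinite` and trigger another attempt.
        if isfinite(logp) && all(isfinite, grad)
            return θ, logp, grad
        end
    end
    error(
        "Could not find initial parameters with finite log density and gradient after " *
        "$max_attempts attempts; logp or its gradient may contain NaNs (e.g. due to an AD issue).",
    )
end
```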
Julia version info
`versioninfo()`
Manifest
The relevant parts are:
I did paste the whole thing here because I'm the one who made this issue template and I had better abide by it 😄
`]st --manifest`