Sampling with HMC can hang if AD is bugged #2389

penelopeysm opened this issue Nov 4, 2024 · 3 comments · Fixed by #2392


penelopeysm commented Nov 4, 2024

Minimal working example

using Turing

@model function model1()
    σ ~ InverseGamma(2, 3)
    V ~ truncated(Normal(0, σ), 0, Inf)
end

sample(model1(), NUTS(), 100)



This isn't because of a bug in Turing; it's actually a bug in ForwardDiff, which returns NaNs when calculating the gradient. (This is true of all other currently supported AD backends too.)

Because the gradient always contains NaNs, isfinite(z) is false on every iteration, so this block goes into an infinite loop:


Lines 176 to 194 in 397d1a7

# If no initial parameters are provided, resample until the log probability
# and its gradient are finite.
if initial_params === nothing
    init_attempt_count = 1
    while !isfinite(z)
        if init_attempt_count == 10
            @warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `initial_params` keyword"
        end

        # NOTE: This will sample in the unconstrained space.
        vi = last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromUniform()))
        theta = vi[spl]

        hamiltonian = AHMC.Hamiltonian(metric, logπ, ∂logπ∂θ)
        z = AHMC.phasepoint(rng, theta, hamiltonian)

        init_attempt_count += 1
    end
end

It would probably make sense to just error after a sufficiently large number of attempts (not sure about the exact number, but 1000 seems reasonable perhaps?). Alternatively, or additionally, we could also check for NaN's and just directly error if logp or its gradient contains NaN's.
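As a rough illustration of that proposal, here is a minimal sketch of a bounded retry loop with a fail-fast NaN check. It is not Turing's code: `find_finite_initial`, its `draw` argument, and the `max_attempts` keyword are hypothetical names, and `draw` stands in for "resample parameters and evaluate logp and its gradient". The limit of 1000 is just the number floated above.

```julia
# Hypothetical helper, not part of Turing: `draw` is any zero-argument function
# returning a tuple (logp, gradient) for a freshly sampled parameter vector.
function find_finite_initial(draw; max_attempts=1000)
    for attempt in 1:max_attempts
        logp, grad = draw()

        # Assumption from the discussion above: a NaN indicates an AD or model
        # problem that resampling will not fix, so error immediately.
        if isnan(logp) || any(isnan, grad)
            error("logp or its gradient contains NaN (attempt $attempt)")
        end

        # Infinite values can plausibly be escaped by resampling, so only
        # accept a draw once everything is finite.
        if isfinite(logp) && all(isfinite, grad)
            return (logp=logp, grad=grad, attempt=attempt)
        end
    end
    error("failed to find valid initial parameters in $max_attempts attempts")
end
```

With this shape, a model whose gradient is always NaN raises an error on the first attempt instead of hanging, while a model that only occasionally produces infinite values still gets a bounded number of retries.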

bug
Minor comment (of course doesn't apply to the general problem here): The problem in this specific example is the line

    V ~ truncated(Normal(0, σ), 0, Inf)

One should always use

    V ~ truncated(Normal(0, σ); lower=0)

or

    V ~ truncated(Normal(0, σ), 0, nothing)

Since the latter is less descriptive, I'd only use it if keyword arguments are problematic (e.g. in broadcasting).
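To see the difference outside of Turing, the following sketch differentiates the log density of both forms with respect to σ using ForwardDiff directly. The function names `logp_inf` and `logp_kw` and the evaluation point (x = 1.0, σ = 2.0) are arbitrary choices for illustration; the exact result for the `Inf` form may depend on package versions, so no output is claimed for it here.

```julia
using Distributions
import ForwardDiff

# Log density at x = 1.0 as a function of σ, with an explicit Inf upper bound.
logp_inf(σ) = logpdf(truncated(Normal(0, σ), 0, Inf), 1.0)

# Same distribution, but with only the lower bound specified.
logp_kw(σ) = logpdf(truncated(Normal(0, σ); lower=0), 1.0)

grad_inf = ForwardDiff.derivative(logp_inf, 2.0)
grad_kw = ForwardDiff.derivative(logp_kw, 2.0)

# In the thread above the Inf form yields NaN gradients; check what you get locally.
println("with Inf upper bound: ", grad_inf, " (isnan: ", isnan(grad_inf), ")")
println("with lower=0 only:    ", grad_kw)
```

For the half-normal density the derivative has the closed form -1/σ + x²/σ³, which is -0.375 at x = 1, σ = 2, so the second value can be checked by hand.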

penelopeysm commented

@devmotion Indeed, I found that out too when experimenting 😄 While you're here: I hadn't gotten round to reporting the actual NaN gradients:

using DynamicPPL: @model, LogDensityFunction
using Distributions
using LogDensityProblems: logdensity_and_gradient
using LogDensityProblemsAD: ADgradient

@model function model1()
    σ ~ InverseGamma(2, 3)
    V ~ truncated(Normal(0, σ), 0, Inf)
end

import ForwardDiff
ℓ = ADgradient(:ForwardDiff, LogDensityFunction(model1()))
logdensity_and_gradient(ℓ, [1.0, 2.0])
# --> (-3.0285667753085077, [NaN, NaN])

Would you consider this a bug in the AD backends (i.e. we can attempt to minimise the issue and report upstream), or improper usage of Distributions (i.e. maybe the user should be told that they shouldn't do that)?

Mainly improper usage of Distributions, I would say. Introducing Inf in calculations makes it very likely that you'll get NaN derivatives, regardless of the AD system.

In the ForwardDiff case, sometimes you can actually get around this problem by switching to NaN-safe mode. But IMO this is only the second-best alternative and doesn't address all Inf/NaN issues.
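For reference, a sketch of how NaN-safe mode is typically switched on in recent ForwardDiff releases, assuming a version that reads the `nansafe_mode` preference and that Preferences.jl is installed in the active environment (older releases required a different mechanism):

```julia
using ForwardDiff, Preferences

# Records `nansafe_mode = true` for ForwardDiff in the active project's
# LocalPreferences.toml.
set_preferences!(ForwardDiff, "nansafe_mode" => true)

# Restart Julia afterwards so that ForwardDiff is recompiled with the setting.
```

This changes behaviour for every use of ForwardDiff in that project, which is one reason to prefer fixing the model (here, avoiding the explicit `Inf` bound) where possible.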

