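# ad.jl
# Tests that the gradient of the model log density agrees across AD backends:
# ForwardDiff (used as the reference), ReverseDiff, and Mooncake.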
@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
vns = DynamicPPL.TestUtils.varnames(m)
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
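        # `setup_varinfos` builds several varinfo variants (e.g. typed and untyped
        # `VarInfo`, `SimpleVarInfo`), all initialized with the same parameter values.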
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
f = DynamicPPL.LogDensityFunction(m, varinfo)
            # Use the ForwardDiff result as the reference gradient.
ad_forwarddiff_f = LogDensityProblemsAD.ADgradient(
ADTypes.AutoForwardDiff(; chunksize=0), f
)
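            # Note: chunksize=0 lets the ForwardDiff backend choose the chunk size
            # automatically based on the input dimension.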
            # Convert to `Vector{Float64}` to avoid ReverseDiff initializing the
            # gradients as integer zeros; see
            # https://github.com/TuringLang/DynamicPPL.jl/pull/571#issuecomment-1924304489
θ = convert(Vector{Float64}, varinfo[:])
logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ)
@testset "$adtype" for adtype in [
ADTypes.AutoReverseDiff(; compile=false),
ADTypes.AutoReverseDiff(; compile=true),
ADTypes.AutoMooncake(; config=nothing),
]
# Mooncake can't currently handle something that is going on in
# SimpleVarInfo{<:VarNamedVector}. Disable all SimpleVarInfo tests for now.
if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo
@test_broken 1 == 0
else
ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
_, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
@test grad ≈ ref_grad
end
end
end
end
@testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin
        # Model that used to trigger the failure.
t = 1:0.05:8
σ = 0.3
y = @. rand(sin(t) + Normal(0, σ))
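        # Adding a real number to a `Normal` shifts its location, so each `y[i]`
        # is drawn from Normal(sin(t[i]), σ).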
@model function state_space(y, TT, ::Type{T}=Float64) where {T}
# Priors
α ~ Normal(y[1], 0.001)
τ ~ Exponential(1)
η ~ filldist(Normal(0, 1), TT - 1)
σ ~ Exponential(1)
            # Create the latent variable
x = Vector{T}(undef, TT)
x[1] = α
for t in 2:TT
x[t] = x[t - 1] + η[t - 1] * τ
end
            # Measurement model
y ~ MvNormal(x, σ^2 * I)
return x
end
model = state_space(y, length(t))
        # Dummy sampling algorithm for testing. The test case can only be
        # replicated with a custom sampler; it doesn't work with SampleFromPrior().
        # We need to overload `assume` so that model evaluation doesn't fail due
        # to a missing implementation.
struct MyEmptyAlg end
DynamicPPL.getspace(::DynamicPPL.Sampler{MyEmptyAlg}) = ()
DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi) =
DynamicPPL.assume(dist, vn, vi)
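        # The overload drops the sampler and forwards to the generic `assume`,
        # which uses the value already stored in `vi`.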
# Compiling the ReverseDiff tape used to fail here
spl = Sampler(MyEmptyAlg())
vi = VarInfo(model)
ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl))
@test LogDensityProblemsAD.ADgradient(AutoReverseDiff(; compile=true), ldf) isa Any
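        # `isa Any` is trivially true; the test passes as long as constructing the
        # compiled-tape gradient above does not throw.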
end
end