Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test Enzyme and reexport ADTypes.AutoEnzyme #1887

Draft
wants to merge 88 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
88 commits
Select commit Hold shift + click to select a range
29017e7
Add support for Enzyme
devmotion Sep 28, 2022
fdf0d43
Merge branch 'master' into dw/enzyme
yebai Nov 7, 2022
d7ef23e
Merge branch 'master' into dw/enzyme
yebai Nov 12, 2022
874b9e7
Merge branch 'master' into dw/enzyme
devmotion Dec 23, 2022
43ef4c4
Apply suggestions from code review
devmotion Dec 23, 2022
3e5841f
Add Enzyme to test dependencies
devmotion Dec 29, 2022
66bce4e
Test Enzyme
devmotion Dec 29, 2022
7890134
Update ad.jl
devmotion Dec 29, 2022
120e9f5
Merge branch 'master' into dw/enzyme
yebai Feb 3, 2023
f4bd1bf
Update Project.toml
yebai Feb 3, 2023
c8e01d0
Update advi.jl
yebai Feb 3, 2023
d8d7729
Merge branch 'master' into dw/enzyme
yebai Feb 16, 2023
1d1dba0
Merge branch 'master' into dw/enzyme
yebai Mar 1, 2023
946e594
Do not call `Bijectors.setadbackend`
devmotion Mar 7, 2023
edd19a4
Merge branch 'master' into dw/enzyme
yebai Apr 12, 2023
e9eedd1
Update Project.toml
devmotion Apr 13, 2023
9cfb589
Merge branch 'master' into dw/enzyme
yebai May 25, 2023
cf03624
Merge branch 'master' into dw/enzyme
yebai Jun 14, 2023
0a5c42b
Merge branch 'master' into dw/enzyme
devmotion Jun 24, 2023
8d8d031
Address comments
devmotion Jun 26, 2023
e591630
Update runtests.jl
devmotion Jun 27, 2023
568cdac
Update Project.toml
devmotion Jul 7, 2023
d00f297
Merge branch 'master' into dw/enzyme
devmotion Jul 7, 2023
6f0bf67
Update Project.toml
devmotion Jul 7, 2023
5ba7ac6
Update Project.toml
devmotion Jul 13, 2023
162755b
Test against Enzyme#main
devmotion Jul 14, 2023
ce26c3c
Merge branch 'master' into dw/enzyme
devmotion Jul 14, 2023
b35ab28
Merge branch 'master' into dw/enzyme
yebai Jul 21, 2023
e44e756
Try addr13 branch
devmotion Jul 24, 2023
1f1b114
Update runtests.jl
devmotion Jul 27, 2023
1b3fa60
Merge branch 'master' into dw/enzyme
devmotion Jul 27, 2023
aad8a1a
Merge branch 'master' into dw/enzyme
yebai Jul 30, 2023
bb795e6
Disable Gibbs tests temporarily
yebai Jul 31, 2023
1c7f20e
Update test/Project.toml
yebai Jul 31, 2023
2a40639
Merge branch 'master' into dw/enzyme
yebai Aug 8, 2023
1b87d2e
Merge branch 'master' into dw/enzyme
yebai Aug 14, 2023
dad6b97
Merge branch 'master' into dw/enzyme
yebai Sep 4, 2023
012a0cb
disable tests unrelated to enzyme + limit CI to avoid over-use of res…
torfjelde Sep 23, 2023
552b01f
Merge branch 'master' into dw/enzyme
sunxd3 Dec 12, 2023
5777344
import `AutoEnzyme`
sunxd3 Dec 12, 2023
1ffdbca
Merge branch 'master' into dw/enzyme
sunxd3 Dec 16, 2023
121df7d
Test hmc only
sunxd3 Dec 16, 2023
a164707
Update sghmc.jl
wsmoses Dec 21, 2023
97f1fb6
Update runtests.jl
wsmoses Dec 21, 2023
c7b6cf4
disable Type unstable getfield
wsmoses Jan 25, 2024
efdd8e7
use release
wsmoses Jan 25, 2024
2fdf546
Remove seemingly unnecessary definition
devmotion Jan 25, 2024
4d8cd23
Run tests on Enzyme#main again
devmotion Jan 26, 2024
47292a7
Merge branch 'master' into dw/enzyme
wsmoses Jan 27, 2024
4b00f0d
Merge branch 'master' into dw/enzyme
wsmoses Feb 10, 2024
889275e
Merge branch 'master' into dw/enzyme
yebai Feb 27, 2024
578967b
Merge branch 'master' into dw/enzyme
yebai Mar 4, 2024
b8296be
Test with cholesky fixes
devmotion Mar 13, 2024
0385250
Merge branch 'master' into dw/enzyme
yebai Apr 8, 2024
24cc3a9
Merge branch 'master' into dw/enzyme
yebai May 29, 2024
2b54d69
Update Project.toml
yebai May 29, 2024
2823a41
Update Turing.jl
yebai May 29, 2024
f4c72bd
Merge remote-tracking branch 'origin/master' into dw/enzyme
mhauru Jun 21, 2024
6b7159c
Merge branch 'master' into dw/enzyme
devmotion Jul 1, 2024
0c376a6
Attempt at fix for `bnn` tests as outlined in #2277
torfjelde Jul 1, 2024
76b5e48
Update test/runtests.jl
yebai Jul 9, 2024
784b8cb
Update runtests.jl
yebai Jul 9, 2024
5bfd06d
remove implicit usage of `hvcat`
torfjelde Jul 9, 2024
836e29b
Merge branch 'master' into dw/enzyme
yebai Jul 9, 2024
e299042
Merge branch 'master' into dw/enzyme
devmotion Jul 10, 2024
d4d55d6
Merge branch 'master' into dw/enzyme
wsmoses Jul 21, 2024
ce13e03
Re-activate CIs disabled for Enzyme testing
torfjelde Jul 25, 2024
19a3332
Merge branch 'master' into dw/enzyme
yebai Jul 31, 2024
e2c0693
Re-enable tests with other AD backends
devmotion Aug 15, 2024
387018d
Merge branch 'master' into dw/enzyme
devmotion Aug 15, 2024
2115d52
Load `@test_broken`
devmotion Aug 15, 2024
b7ad9db
Merge branch 'master' into dw/enzyme
yebai Sep 3, 2024
1c7319b
Merge remote-tracking branch 'origin/master' into dw/enzyme
mhauru Oct 24, 2024
79d057c
Bump Enzyme to 0.13 in tests
mhauru Oct 24, 2024
c98bbc9
Run JuliaFormatter on more files, remove trailing whitespace
mhauru Oct 24, 2024
aa9abd6
Merge branch 'mhauru/more-formatting' into dw/enzyme
mhauru Oct 24, 2024
9d39193
Restore compat with Enzyme v0.12
mhauru Oct 24, 2024
66cd80a
Import Enzyme in abstractmcmc and gibbs tests
mhauru Oct 24, 2024
ec34e41
Add Enzyme imports to a couple of other tests files
mhauru Oct 24, 2024
120230a
Remove unnecessary version conditions in tests
mhauru Oct 24, 2024
98f2c2b
Merge branch 'master' into dw/enzyme
yebai Oct 25, 2024
2c88d59
Merge branch 'master' into dw/enzyme
devmotion Nov 29, 2024
f9165fa
Update Project.toml
devmotion Nov 29, 2024
e77aa54
Merge branch 'master' into dw/enzyme
mhauru Nov 29, 2024
5698212
Dump DPPL to 0.31
mhauru Dec 2, 2024
8bff239
Merge remote-tracking branch 'origin/master' into dw/enzyme
mhauru Dec 2, 2024
20b055e
Fix ADTypeCheck tests for Enzyme, add testing both Reverse and Forwar…
mhauru Dec 5, 2024
6b2b65f
Merge remote-tracking branch 'origin/master' into dw/enzyme
mhauru Jan 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ export @model, # modelling
AutoForwardDiff, # ADTypes
AutoReverseDiff,
AutoZygote,
AutoEnzyme,
AutoMooncake,
setprogress!, # debugging
Flat,
Expand Down
4 changes: 3 additions & 1 deletion src/essential/Essential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@ using Bijectors: PDMatDistribution
using AdvancedVI
using StatsFuns: logsumexp, softmax
@reexport using DynamicPPL
using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoMooncake
using ADTypes:
ADTypes, AutoForwardDiff, AutoEnzyme, AutoReverseDiff, AutoZygote, AutoMooncake

using AdvancedPS: AdvancedPS

include("container.jl")

export @model,
@varname,
AutoEnzyme,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's export this as Turing.Experimental.AutoEnzyme until Enzyme becomes more stable.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the threshold for being considered stable here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mhauru and @penelopeysm probably have a lot more experience on this.

My heuristic threshold:

  • Enzyme passes all Distributions.jl and Turing.jl tests
  • No known segfaults for Enzyme

for a continuous period of 8 weeks.

AutoForwardDiff,
AutoZygote,
AutoReverseDiff,
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
Expand Down Expand Up @@ -52,6 +53,7 @@ Distributions = "0.25"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.32.2"
Enzyme = "0.13"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
HypothesisTests = "0.11"
Expand Down
107 changes: 59 additions & 48 deletions test/experimental/gibbs.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
module ExperimentalGibbsTests

using ..Models: MoGtest_default, MoGtest_default_z_vector, gdemo
using ..NumericalTests: check_MoGtest_default, check_MoGtest_default_z_vector, check_gdemo,
check_numerical, two_sample_test
using ..NumericalTests:
check_MoGtest_default,
check_MoGtest_default_z_vector,
check_gdemo,
check_numerical,
two_sample_test
using DynamicPPL
using Random
using Test
Expand All @@ -11,10 +15,7 @@ using Turing.Inference: AdvancedHMC, AdvancedMH
using ForwardDiff: ForwardDiff
using ReverseDiff: ReverseDiff

function check_transition_varnames(
transition::Turing.Inference.Transition,
parent_varnames
)
function check_transition_varnames(transition::Turing.Inference.Transition, parent_varnames)
transition_varnames = mapreduce(vcat, transition.θ) do vn_and_val
[first(vn_and_val)]
end
Expand Down Expand Up @@ -49,40 +50,32 @@ has_dot_assume(::Model) = true
end

samplers = [
Turing.Experimental.Gibbs(
vns_s => NUTS(),
vns_m => NUTS(),
),
Turing.Experimental.Gibbs(
vns_s => NUTS(),
vns_m => HMC(0.01, 4),
)
Turing.Experimental.Gibbs(vns_s => NUTS(), vns_m => NUTS()),
Turing.Experimental.Gibbs(vns_s => NUTS(), vns_m => HMC(0.01, 4)),
]

if !has_dot_assume(model)
# Add in some MH samplers, which are not compatible with `.~`.
append!(
samplers,
[
Turing.Experimental.Gibbs(
vns_s => HMC(0.01, 4),
vns_m => MH(),
),
Turing.Experimental.Gibbs(
vns_s => MH(),
vns_m => HMC(0.01, 4),
)
]
Turing.Experimental.Gibbs(vns_s => HMC(0.01, 4), vns_m => MH()),
Turing.Experimental.Gibbs(vns_s => MH(), vns_m => HMC(0.01, 4)),
],
)
end

@testset "$sampler" for sampler in samplers
# Check that taking steps performs as expected.
rng = Random.default_rng()
transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler))
transition, state = AbstractMCMC.step(
rng, model, DynamicPPL.Sampler(sampler)
)
check_transition_varnames(transition, vns)
for _ = 1:5
transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state)
for _ in 1:5
transition, state = AbstractMCMC.step(
rng, model, DynamicPPL.Sampler(sampler), state
)
check_transition_varnames(transition, vns)
end
end
Expand All @@ -104,8 +97,7 @@ has_dot_assume(::Model) = true
# Sampler to use for Gibbs components.
sampler_inner = HMC(0.1, 32)
sampler = Turing.Experimental.Gibbs(
vns_s => sampler_inner,
vns_m => sampler_inner,
vns_s => sampler_inner, vns_m => sampler_inner
)
Random.seed!(42)
chain = sample(
Expand All @@ -117,7 +109,7 @@ has_dot_assume(::Model) = true
progress=false,
initial_params=initial_params,
discard_initial=1_000,
thinning=thinning
thinning=thinning,
)

# "Ground truth" samples.
Expand All @@ -137,7 +129,7 @@ has_dot_assume(::Model) = true
# Perform KS test to ensure that the chains are similar.
xs = Array(chain)
xs_true = Array(chain_true)
for i = 1:size(xs, 2)
for i in 1:size(xs, 2)
@test two_sample_test(xs[:, i], xs_true[:, i]; warn_on_fail=true)
# Let's make sure that the significance level is not too low by
# checking that the KS test fails for some simple transformations.
Expand Down Expand Up @@ -200,42 +192,58 @@ has_dot_assume(::Model) = true
@varname(mu1) => ESS(),
@varname(mu2) => ESS(),
)
vns = (@varname(z1), @varname(z2), @varname(z3), @varname(z4), @varname(mu1), @varname(mu2))
vns = (
@varname(z1),
@varname(z2),
@varname(z3),
@varname(z4),
@varname(mu1),
@varname(mu2)
)
# `step`
transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg))
check_transition_varnames(transition, vns)
for _ = 1:5
transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state)
for _ in 1:5
transition, state = AbstractMCMC.step(
rng, model, DynamicPPL.Sampler(alg), state
)
check_transition_varnames(transition, vns)
end

# Sample!
Random.seed!(42)
chain = sample(MoGtest_default, alg, 1000; progress=false)
check_MoGtest_default(chain, atol = 0.2)
check_MoGtest_default(chain; atol=0.2)
end

@testset "CSMC + ESS (usage of implicit varname)" begin
rng = Random.default_rng()
model = MoGtest_default_z_vector
alg = Turing.Experimental.Gibbs(
@varname(z) => CSMC(15),
@varname(mu1) => ESS(),
@varname(mu2) => ESS(),
@varname(z) => CSMC(15), @varname(mu1) => ESS(), @varname(mu2) => ESS()
)
vns = (
@varname(z[1]),
@varname(z[2]),
@varname(z[3]),
@varname(z[4]),
@varname(mu1),
@varname(mu2)
)
vns = (@varname(z[1]), @varname(z[2]), @varname(z[3]), @varname(z[4]), @varname(mu1), @varname(mu2))
# `step`
transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg))
check_transition_varnames(transition, vns)
for _ = 1:5
transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state)
for _ in 1:5
transition, state = AbstractMCMC.step(
rng, model, DynamicPPL.Sampler(alg), state
)
check_transition_varnames(transition, vns)
end

# Sample!
Random.seed!(42)
chain = sample(model, alg, 1000; progress=false)
check_MoGtest_default_z_vector(chain, atol = 0.2)
check_MoGtest_default_z_vector(chain; atol=0.2)
end

@testset "externsalsampler" begin
Expand All @@ -252,18 +260,21 @@ has_dot_assume(::Model) = true
model = demo_gibbs_external()
samplers_inner = [
externalsampler(AdvancedMH.RWMH(1)),
externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoForwardDiff()),
externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoReverseDiff()),
externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoReverseDiff(compile=true)),
externalsampler(AdvancedHMC.HMC(1e-1, 32); adtype=AutoForwardDiff()),
externalsampler(AdvancedHMC.HMC(1e-1, 32); adtype=AutoReverseDiff()),
externalsampler(
AdvancedHMC.HMC(1e-1, 32); adtype=AutoReverseDiff(; compile=true)
),
]
@testset "$(sampler_inner)" for sampler_inner in samplers_inner
sampler = Turing.Experimental.Gibbs(
@varname(m1) => sampler_inner,
@varname(m2) => sampler_inner,
@varname(m1) => sampler_inner, @varname(m2) => sampler_inner
)
Random.seed!(42)
chain = sample(model, sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0)
check_numerical(chain, [:m1, :m2], [-0.2, 0.6], atol=0.1)
chain = sample(
model, sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0
)
check_numerical(chain, [:m1, :m2], [-0.2, 0.6]; atol=0.1)
end
end
end
Expand Down
104 changes: 90 additions & 14 deletions test/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using Distributions: Bernoulli, Beta, InverseGamma, Normal
using Distributions: sample
import DynamicPPL
using DynamicPPL: Sampler, getlogp
import Enzyme
import ForwardDiff
using LinearAlgebra: I
import MCMCChains
Expand Down Expand Up @@ -449,15 +450,53 @@ using Turing
alg = HMC(0.01, 5; adtype=adbackend)
res = sample(StableRNG(seed), vdemo2(randn(D, 100)), alg, 10)

# Vector assumptions
N = 10
alg = HMC(0.2, 4; adtype=adbackend)
# TODO(mhauru) Type unstable getfield of tuple not supported in Enzyme yet
if !(adbackend isa AutoEnzyme)
# Vector assumptions
N = 10
alg = HMC(0.2, 4; adtype=adbackend)

@model function vdemo3()
x = Vector{Real}(undef, N)
for i in 1:N
x[i] ~ Normal(0, sqrt(4))
@model function vdemo3()
x = Vector{Real}(undef, N)
for i in 1:N
x[i] ~ Normal(0, sqrt(4))
end
end

t_loop = @elapsed res = sample(vdemo3(), alg, 1000)

# Test for vectorize UnivariateDistribution
@model function vdemo4()
x = Vector{Real}(undef, N)
@. x ~ Normal(0, 2)
end

t_vec = @elapsed res = sample(vdemo4(), alg, 1000)

@model vdemo5() = x ~ MvNormal(zeros(N), 4 * I)

t_mv = @elapsed res = sample(vdemo5(), alg, 1000)

println("Time for")
println(" Loop : ", t_loop)
println(" Vec : ", t_vec)
println(" Mv : ", t_mv)

# Transformed test
@model function vdemo6()
x = Vector{Real}(undef, N)
@. x ~ InverseGamma(2, 3)
end

sample(vdemo6(), alg, 1000)

N = 3
@model function vdemo7()
x = Array{Real}(undef, N, N)
@. x ~ [InverseGamma(2, 3) for i in 1:N]
end

sample(vdemo7(), alg, 1000)
end

# TODO(mhauru) What is the point of the below @elapsed stuff? It prints out some
Expand Down Expand Up @@ -519,15 +558,52 @@ using Turing
alg = HMC(0.01, 5; adtype=adbackend)
res = sample(StableRNG(seed), vdemo2(randn(D, 100)), alg, 10)

# Vector assumptions
N = 10
alg = HMC(0.2, 4; adtype=adbackend)
# TODO(mhauru) Type unstable getfield of tuple not supported in Enzyme yet
if !(adbackend isa AutoEnzyme)
# Vector assumptions
N = 10
alg = HMC(0.2, 4; adtype=adbackend)

@model function vdemo3()
x = Vector{Real}(undef, N)
for i in 1:N
x[i] ~ Normal(0, sqrt(4))
@model function vdemo3()
x = Vector{Real}(undef, N)
for i in 1:N
x[i] ~ Normal(0, sqrt(4))
end
end

t_loop = @elapsed res = sample(vdemo3(), alg, 1000)

# Test for vectorize UnivariateDistribution
@model function vdemo4()
x = Vector{Real}(undef, N)
return x .~ Normal(0, 2)
end

t_vec = @elapsed res = sample(vdemo4(), alg, 1000)

@model vdemo5() = x ~ MvNormal(zeros(N), 4 * I)

t_mv = @elapsed res = sample(vdemo5(), alg, 1000)

println("Time for")
println(" Loop : ", t_loop)
println(" Vec : ", t_vec)
println(" Mv : ", t_mv)

# Transformed test
@model function vdemo6()
x = Vector{Real}(undef, N)
return x .~ InverseGamma(2, 3)
end

sample(vdemo6(), alg, 1000)

@model function vdemo7()
x = Array{Real}(undef, N, N)
return x .~ [InverseGamma(2, 3) for i in 1:N]
end

sample(vdemo7(), alg, 1000)
end

# TODO(mhauru) Same question as above about @elapsed.
Expand Down
1 change: 1 addition & 0 deletions test/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using AdvancedMH: AdvancedMH
using Distributions: sample
using Distributions.FillArrays: Zeros
using DynamicPPL: DynamicPPL
import Enzyme
using ForwardDiff: ForwardDiff
using LinearAlgebra: I
using LogDensityProblems: LogDensityProblems
Expand Down
1 change: 1 addition & 0 deletions test/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import Combinatorics
using Distributions: InverseGamma, Normal
using Distributions: sample
using DynamicPPL: DynamicPPL
using Enzyme: Enzyme
using ForwardDiff: ForwardDiff
using Random: Random
using ReverseDiff: ReverseDiff
Expand Down
Loading
Loading