From 29017e76db7ea3be39942f7f7a36876254f0282a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 28 Sep 2022 12:34:25 +0200 Subject: [PATCH 01/47] Add support for Enzyme --- src/essential/Essential.jl | 1 + src/essential/ad.jl | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/src/essential/Essential.jl b/src/essential/Essential.jl index df0a9b5ac0..331dbeed9f 100644 --- a/src/essential/Essential.jl +++ b/src/essential/Essential.jl @@ -42,6 +42,7 @@ export @model, setadbackend, setadsafe, ForwardDiffAD, + EnzymeAD, TrackerAD, ZygoteAD, ReverseDiffAD, diff --git a/src/essential/ad.jl b/src/essential/ad.jl index b56ce01407..b2530435d1 100644 --- a/src/essential/ad.jl +++ b/src/essential/ad.jl @@ -12,6 +12,9 @@ end function _setadbackend(::Val{:forwarddiff}) ADBACKEND[] = :forwarddiff end +function _setadbackend(::Val{:enzyme}) + ADBACKEND[] = :enzyme +end function _setadbackend(::Val{:tracker}) ADBACKEND[] = :tracker end @@ -47,6 +50,7 @@ getchunksize(::ForwardDiffAD{chunk}) where chunk = chunk standardtag(::ForwardDiffAD{<:Any,true}) = true standardtag(::ForwardDiffAD) = false +struct EnzymeAD <: ADBackend end struct TrackerAD <: ADBackend end struct ZygoteAD <: ADBackend end @@ -64,6 +68,7 @@ ADBackend() = ADBackend(ADBACKEND[]) ADBackend(T::Symbol) = ADBackend(Val(T)) ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]} +ADBackend(::Val{:enzyme}) = EnzymeAD ADBackend(::Val{:tracker}) = TrackerAD ADBackend(::Val{:zygote}) = ZygoteAD ADBackend(::Val{:reversediff}) = ReverseDiffAD{getrdcache()} @@ -102,6 +107,10 @@ function LogDensityProblems.ADgradient(ad::ForwardDiffAD, ℓ::Turing.LogDensity return LogDensityProblems.ADgradient(Val(:ForwardDiff), ℓ; gradientconfig=config) end +function LogDensityProblems.ADgradient(::EnzymeAD, ℓ::Turing.LogDensityFunction) + return LogDensityProblems.ADgradient(Val(:Enzyme), ℓ) +end + function LogDensityProblems.ADgradient(::TrackerAD, ℓ::Turing.LogDensityFunction) return LogDensityProblems.ADgradient(Val(:Tracker), ℓ) end From 43ef4c4a10eb5d9deaed10665f79815c716fb931 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 23 Dec 2022 22:47:19 +0100 Subject: [PATCH 02/47] Apply suggestions from code review --- src/essential/ad.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/essential/ad.jl b/src/essential/ad.jl index 61b678edc4..848f10f68f 100644 --- a/src/essential/ad.jl +++ b/src/essential/ad.jl @@ -108,11 +108,11 @@ function LogDensityProblemsAD.ADgradient(ad::ForwardDiffAD, ℓ::Turing.LogDensi end function LogDensityProblemsAD.ADgradient(::EnzymeAD, ℓ::Turing.LogDensityFunction) - return LogDensityProblems.ADgradient(Val(:Enzyme), ℓ) + return LogDensityProblemsAD.ADgradient(Val(:Enzyme), ℓ) end function LogDensityProblemsAD.ADgradient(::TrackerAD, ℓ::Turing.LogDensityFunction) - return LogDensityProblems.ADgradient(Val(:Tracker), ℓ) + return LogDensityProblemsAD.ADgradient(Val(:Tracker), ℓ) end function LogDensityProblemsAD.ADgradient(::ZygoteAD, ℓ::Turing.LogDensityFunction) From 3e5841f30492f70559c5f197f657f203db74467f Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 29 Dec 2022 12:21:12 +0100 Subject: [PATCH 03/47] Add Enzyme to test dependencies --- test/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index a858a9ec60..02ce8cc1d6 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,6 +8,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -41,6 +42,7 @@ Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" DynamicPPL = "0.21" +Enzyme = "0.10.13" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32" LogDensityProblems = "2" From 66bce4ed46a86cf89b37a15d2b52a421cb1de5e9 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 29 Dec 2022 12:22:56 +0100 Subject: [PATCH 04/47] Test Enzyme --- test/runtests.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 7eee468139..b32605fa4a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,6 +41,7 @@ using Turing: BinomialLogit, ForwardDiffAD, Sampler, SampleFromPrior, NUTS, Trac using Turing.Essential: TuringDenseMvNormal, TuringDiagMvNormal using Turing.Variational: TruncatedADAGrad, DecayedADAGrad, AdvancedVI +import Enzyme import LogDensityProblems import LogDensityProblemsAD @@ -65,7 +66,7 @@ macro timeit_include(path::AbstractString) :(@timeit TIMEROUTPUT $path include($ end Turing.setrdcache(false) - for adbackend in (:forwarddiff, :tracker, :reversediff) + for adbackend in (:forwarddiff, :tracker, :reversediff, :enzyme) @timeit TIMEROUTPUT "inference: $adbackend" begin Turing.setadbackend(adbackend) @info "Testing $(adbackend)" From 789013467cc6dbe0b5a3f9774b936c3a1084d2d3 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 29 Dec 2022 17:40:31 +0100 Subject: [PATCH 05/47] Update ad.jl --- src/essential/ad.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/essential/ad.jl b/src/essential/ad.jl index 848f10f68f..59289c1842 100644 --- a/src/essential/ad.jl +++ b/src/essential/ad.jl @@ -9,6 +9,10 @@ function setadbackend(backend::Val) Bijectors.setadbackend(backend) end +# TODO: Add support to AdvancedVI and Bijectors +# (or better: use common interface package) +setadbackend(backend::Val{:enzyme}) = _setadbackend(backend) + function _setadbackend(::Val{:forwarddiff}) ADBACKEND[] = :forwarddiff end From f4bd1bf0f2c6e4f4f2b429d6a14d1cc8cba6f9a0 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Fri, 3 Feb 2023 22:59:05 +0000 Subject: [PATCH 06/47] Update Project.toml --- Project.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index abae4a539c..567ed47f62 100644 --- a/Project.toml +++ b/Project.toml @@ -40,14 +40,14 @@ AbstractMCMC = "4" AdvancedHMC = "0.3.0, 0.4" AdvancedMH = "0.6.8, 0.7" AdvancedPS = "0.4" -AdvancedVI = "0.1" +AdvancedVI = "0.2" BangBang = "0.3" -Bijectors = "0.8, 0.9, 0.10" +Bijectors = "0.11, 0.12" DataStructures = "0.18" Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" -DynamicPPL = "0.21.5" +DynamicPPL = "0.22" EllipticalSliceSampling = "0.5, 1" ForwardDiff = "0.10.3" Libtask = "0.7, 0.8" From c8e01d0da369cfb55dab045efb8ee27d0c15fbdf Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Fri, 3 Feb 2023 22:59:32 +0000 Subject: [PATCH 07/47] Update advi.jl --- src/variational/advi.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/variational/advi.jl b/src/variational/advi.jl index d91f8a897a..47b35ddb0a 100644 --- a/src/variational/advi.jl +++ b/src/variational/advi.jl @@ -1,5 +1,5 @@ # TODO(torfjelde): Find a better solution. -struct Vec{N, B<:Bijectors.Bijector{N}} <: Bijectors.Bijector{1} +struct Vec{N, B<:Bijectors.Transform} <: Bijectors.Transform b::B size::NTuple{N, Int} end From 946e594d6e6c11decc0b9f49b519de4b837c14f5 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 7 Mar 2023 10:20:12 +0100 Subject: [PATCH 08/47] Do not call `Bijectors.setadbackend` --- Project.toml | 2 +- src/essential/ad.jl | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 12c2ad79ec..b64082ae6d 100644 --- a/Project.toml +++ b/Project.toml @@ -42,7 +42,7 @@ AdvancedMH = "0.6.8, 0.7" AdvancedPS = "0.4" AdvancedVI = "0.2" BangBang = "0.3" -Bijectors = "0.11, 0.12" +Bijectors = "0.12" DataStructures = "0.18" Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" diff --git a/src/essential/ad.jl b/src/essential/ad.jl index d3d73326f4..4df698da3c 100644 --- a/src/essential/ad.jl +++ b/src/essential/ad.jl @@ -6,13 +6,8 @@ setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym)) function setadbackend(backend::Val) _setadbackend(backend) AdvancedVI.setadbackend(backend) - Bijectors.setadbackend(backend) end -# TODO: Add support to AdvancedVI and Bijectors -# (or better: use common interface package) -setadbackend(backend::Val{:enzyme}) = _setadbackend(backend) - function _setadbackend(::Val{:forwarddiff}) ADBACKEND[] = :forwarddiff end From e9eedd10cd452962b23e8269c0dce8390ce34461 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 13 Apr 2023 11:04:14 +0200 Subject: [PATCH 09/47] Update Project.toml --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index ca7b133908..cc92ba98bb 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -42,7 +42,7 @@ Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" DynamicPPL = "0.22" -Enzyme = "0.10.13" +Enzyme = "0.11" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" LogDensityProblems = "2" From 8d8d0310c672734b20a760a9faa0c0647193832b Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 27 Jun 2023 00:24:44 +0200 Subject: [PATCH 10/47] Address comments --- test/Project.toml | 2 +- test/runtests.jl | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 1cec0ec197..aab5004faa 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -43,7 +43,7 @@ Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" DynamicPPL = "0.23" -Enzyme = "0.11" +Enzyme = "0.11.2" FillArrays = "=1.0.0" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" diff --git a/test/runtests.jl b/test/runtests.jl index 64d33ee17f..282594576e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -47,6 +47,9 @@ import LogDensityProblemsAD setprogress!(false) +# Disable Enzyme warnings +Enzyme.API.typeWarning!(false) + include(pkgdir(Turing)*"/test/test_utils/AllUtils.jl") # Collect timing and allocations information to show in a clear way. From e5916304d417b031fd2f82e5d1a6363a51600955 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 27 Jun 2023 02:11:04 +0200 Subject: [PATCH 11/47] Update runtests.jl --- test/runtests.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 282594576e..8c7b729bb7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -50,6 +50,9 @@ setprogress!(false) # Disable Enzyme warnings Enzyme.API.typeWarning!(false) +# Enable runtime activity (workaround) +Enzyme.API.runtimeActivity!(true) + include(pkgdir(Turing)*"/test/test_utils/AllUtils.jl") # Collect timing and allocations information to show in a clear way. From 568cdaceb8e1bd58cca981edd4924a430a78e9d3 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 7 Jul 2023 19:48:43 +0200 Subject: [PATCH 12/47] Update Project.toml --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index aab5004faa..5933cf7174 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -43,7 +43,7 @@ Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" DynamicPPL = "0.23" -Enzyme = "0.11.2" +Enzyme = "0.11.3" FillArrays = "=1.0.0" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" From 6f0bf67e079f46ba7be130b89e80c8df5b31d306 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 7 Jul 2023 22:09:56 +0200 Subject: [PATCH 13/47] Update Project.toml --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 5933cf7174..138eea80df 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -43,7 +43,7 @@ Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" DynamicPPL = "0.23" -Enzyme = "0.11.3" +Enzyme = "0.11.4" FillArrays = "=1.0.0" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" From 5ba7ac6bdbf0eac941ef28fe08dd09b1ad717769 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 13 Jul 2023 21:37:15 +0200 Subject: [PATCH 14/47] Update Project.toml --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 138eea80df..1772782c86 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -43,7 +43,7 @@ Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" DynamicPPL = "0.23" -Enzyme = "0.11.4" +Enzyme = "0.11.5" FillArrays = "=1.0.0" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" From 162755be3427a8400d83ddc9e2c32c5beeebbfc7 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 14 Jul 2023 22:03:12 +0200 Subject: [PATCH 15/47] Test against Enzyme#main --- test/runtests.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 6c618d4d76..d6b3c2a7d3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,6 @@ +using Pkg +Pkg.add(Pkg.PackageSpec(; url="https://github.com/EnzymeAD/Enzyme.jl.git", rev="main")) + using AbstractMCMC using AdvancedMH using Clustering From e44e7560bc79a45f01c4b74ca1813140ae2f6924 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 24 Jul 2023 10:28:21 +0200 Subject: [PATCH 16/47] Try addr13 branch --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index d6b3c2a7d3..a75c1c7712 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using Pkg -Pkg.add(Pkg.PackageSpec(; url="https://github.com/EnzymeAD/Enzyme.jl.git", rev="main")) +Pkg.add(Pkg.PackageSpec(; url="https://github.com/EnzymeAD/Enzyme.jl.git", rev="addr13")) using AbstractMCMC using AdvancedMH From 1f1b1140e3ab1d6b41324dcb36d4e79061088f01 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 27 Jul 2023 20:08:37 +0200 Subject: [PATCH 17/47] Update runtests.jl --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index a75c1c7712..d6b3c2a7d3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using Pkg -Pkg.add(Pkg.PackageSpec(; url="https://github.com/EnzymeAD/Enzyme.jl.git", rev="addr13")) +Pkg.add(Pkg.PackageSpec(; url="https://github.com/EnzymeAD/Enzyme.jl.git", rev="main")) using AbstractMCMC using AdvancedMH From bb795e6e03b6b65e664ec8052784e3ae79571dc8 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Mon, 31 Jul 2023 20:46:32 +0100 Subject: [PATCH 18/47] Disable Gibbs tests temporarily --- test/runtests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index d6b3c2a7d3..1b99986f33 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -81,8 +81,8 @@ macro timeit_include(path::AbstractString) :(@timeit TIMEROUTPUT $path include($ @info "Testing $(adbackend)" @testset "inference: $adbackend" begin @testset "samplers" begin - @timeit_include("inference/gibbs.jl") - @timeit_include("inference/gibbs_conditional.jl") + # @timeit_include("inference/gibbs.jl") + # @timeit_include("inference/gibbs_conditional.jl") @timeit_include("inference/hmc.jl") @timeit_include("inference/Inference.jl") @timeit_include("contrib/inference/dynamichmc.jl") From 1c7f20efd54a74148b33cda574f5efa29f19b5b2 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Mon, 31 Jul 2023 22:13:11 +0100 Subject: [PATCH 19/47] Update test/Project.toml Co-authored-by: David Widmann --- test/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 6eae1dcfdc..8c777f43f4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -61,4 +61,3 @@ StatsFuns = "0.9.5, 1" TimerOutputs = "0.5" Tracker = "0.2.11" Zygote = "0.5.4, 0.6" -julia = "1.6" From 012a0cbd3dc26e6e877f0d2a9cc731b1f5a53d3e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 23 Sep 2023 09:54:56 +0100 Subject: [PATCH 20/47] disable tests unrelated to enzyme + limit CI to avoid over-use of resources --- .github/workflows/DynamicHMC.yml | 1 - .github/workflows/Numerical.yml | 1 - .github/workflows/TuringCI.yml | 14 ---------- test/runtests.jl | 45 ++++++++++++++++---------------- 4 files changed, 23 insertions(+), 38 deletions(-) diff --git a/.github/workflows/DynamicHMC.yml b/.github/workflows/DynamicHMC.yml index 099f70fcf8..d66c6988b4 100644 --- a/.github/workflows/DynamicHMC.yml +++ b/.github/workflows/DynamicHMC.yml @@ -12,7 +12,6 @@ jobs: strategy: matrix: version: - - '1.7' - '1' os: - ubuntu-latest diff --git a/.github/workflows/Numerical.yml b/.github/workflows/Numerical.yml index 314241fbef..977fc86f7b 100644 --- a/.github/workflows/Numerical.yml +++ b/.github/workflows/Numerical.yml @@ -12,7 +12,6 @@ jobs: strategy: matrix: version: - - '1.7' - '1' os: - ubuntu-latest diff --git a/.github/workflows/TuringCI.yml b/.github/workflows/TuringCI.yml index 88cc27bcb2..cc8648a7a8 100644 --- a/.github/workflows/TuringCI.yml +++ b/.github/workflows/TuringCI.yml @@ -13,7 +13,6 @@ jobs: strategy: matrix: version: - - '1.7' - '1' os: - ubuntu-latest @@ -22,19 +21,6 @@ jobs: num_threads: - 1 - 2 - include: - - version: '1.7' - os: ubuntu-latest - arch: x86 - num_threads: 2 - - version: '1.7' - os: windows-latest - arch: x64 - num_threads: 2 - - version: '1.7' - os: macOS-latest - arch: x64 - num_threads: 2 steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 diff --git a/test/runtests.jl b/test/runtests.jl index 88b3e8a047..23f5fcc303 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -63,19 +63,20 @@ const TIMEROUTPUT = TimerOutputs.TimerOutput() macro timeit_include(path::AbstractString) :(@timeit TIMEROUTPUT $path include($path)) end @testset "Turing" begin - @testset "essential" begin - @timeit_include("essential/ad.jl") - end - - @testset "samplers (without AD)" begin - @timeit_include("mcmc/particle_mcmc.jl") - @timeit_include("mcmc/emcee.jl") - @timeit_include("mcmc/ess.jl") - @timeit_include("mcmc/is.jl") - end + # NOTE: Doesn't contain Enzyme tests. + # @testset "essential" begin + # @timeit_include("essential/ad.jl") + # end + + # @testset "samplers (without AD)" begin + # @timeit_include("mcmc/particle_mcmc.jl") + # @timeit_include("mcmc/emcee.jl") + # @timeit_include("mcmc/ess.jl") + # @timeit_include("mcmc/is.jl") + # end Turing.setrdcache(false) - for adbackend in (:forwarddiff, :reversediff, :enzyme) + for adbackend in (:enzyme,) @timeit TIMEROUTPUT "inference: $adbackend" begin Turing.setadbackend(adbackend) @info "Testing $(adbackend)" @@ -104,19 +105,19 @@ macro timeit_include(path::AbstractString) :(@timeit TIMEROUTPUT $path include($ end end - @testset "variational optimisers" begin - @timeit_include("variational/optimisers.jl") - end + # @testset "variational optimisers" begin + # @timeit_include("variational/optimisers.jl") + # end - Turing.setadbackend(:forwarddiff) - @testset "stdlib" begin - @timeit_include("stdlib/distributions.jl") - @timeit_include("stdlib/RandomMeasures.jl") - end + # Turing.setadbackend(:forwarddiff) + # @testset "stdlib" begin + # @timeit_include("stdlib/distributions.jl") + # @timeit_include("stdlib/RandomMeasures.jl") + # end - @testset "utilities" begin - @timeit_include("mcmc/utilities.jl") - end + # @testset "utilities" begin + # @timeit_include("mcmc/utilities.jl") + # end end show(TIMEROUTPUT; compact=true, sortby=:firstexec) From 577734477186d8c02307398d809816905a8fb19e Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 12 Dec 2023 08:30:31 +0000 Subject: [PATCH 21/47] import `AutoEnzyme` --- src/essential/Essential.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/essential/Essential.jl b/src/essential/Essential.jl index 5ae03d9c32..ab1460fe6c 100644 --- a/src/essential/Essential.jl +++ b/src/essential/Essential.jl @@ -11,7 +11,7 @@ using Bijectors: PDMatDistribution using AdvancedVI using StatsFuns: logsumexp, softmax @reexport using DynamicPPL -using ADTypes: ADTypes, AutoForwardDiff, AutoTracker, AutoReverseDiff, AutoZygote +using ADTypes: ADTypes, AutoForwardDiff, AutoEnzyme, AutoTracker, AutoReverseDiff, AutoZygote import AdvancedPS import LogDensityProblems From 121df7d5246bc53803ec4f4b67d3686d391805c7 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sat, 16 Dec 2023 10:19:26 +0000 Subject: [PATCH 22/47] Test hmc only --- test/mcmc/Inference.jl | 3 ++- test/mcmc/hmc.jl | 3 ++- test/mcmc/sghmc.jl | 4 +++- test/runtests.jl | 52 +++++++++++++++++++++--------------------- 4 files changed, 33 insertions(+), 29 deletions(-) diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 1f5a148699..d8c05136f4 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -1,4 +1,5 @@ -@testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)) +# @testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)) +@testset "Testing inference.jl with $adbackend" for adbackend in (AutoEnzyme(),) # Only test threading if 1.3+. if VERSION > v"1.2" @testset "threaded sampling" begin diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index fe18fa7733..0d408f01bb 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -1,4 +1,5 @@ -@testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)) +# @testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)) +@testset "Testing hmc.jl with $adbackend" for adbackend in (AutoEnzyme(),) # Set a seed rng = StableRNG(123) @numerical_testset "constrained bounded" begin diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index 4405b505ab..9079f94c00 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -1,4 +1,5 @@ @testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)) +@testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoEnzyme(),) @turing_testset "sghmc constructor" begin alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend) @test alg isa SGHMC @@ -24,7 +25,8 @@ end end -@testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)) +# @testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)) +@testset "Testing sgld.jl with $adbackend" for adbackend in (AutoEnzyme(),) @turing_testset "sgld constructor" begin alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adbackend) @test alg isa SGLD diff --git a/test/runtests.jl b/test/runtests.jl index ab4b8b7b10..71344d62d8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -64,47 +64,47 @@ const TIMEROUTPUT = TimerOutputs.TimerOutput() macro timeit_include(path::AbstractString) :(@timeit TIMEROUTPUT $path include($path)) end @testset "Turing" begin -# @testset "essential" begin -# @timeit_include("essential/ad.jl") -# end - -# @testset "samplers (without AD)" begin -# @timeit_include("mcmc/particle_mcmc.jl") -# @timeit_include("mcmc/emcee.jl") -# @timeit_include("mcmc/ess.jl") -# @timeit_include("mcmc/is.jl") -# end + # @testset "essential" begin + # @timeit_include("essential/ad.jl") + # end + + # @testset "samplers (without AD)" begin + # @timeit_include("mcmc/particle_mcmc.jl") + # @timeit_include("mcmc/emcee.jl") + # @timeit_include("mcmc/ess.jl") + # @timeit_include("mcmc/is.jl") + # end @timeit TIMEROUTPUT "inference" begin @testset "inference with samplers" begin -# @timeit_include("mcmc/gibbs.jl") -# @timeit_include("mcmc/gibbs_conditional.jl") + # @timeit_include("mcmc/gibbs.jl") + # @timeit_include("mcmc/gibbs_conditional.jl") @timeit_include("mcmc/hmc.jl") @timeit_include("mcmc/Inference.jl") @timeit_include("mcmc/sghmc.jl") - @timeit_include("mcmc/abstractmcmc.jl") - @timeit_include("mcmc/mh.jl") - @timeit_include("ext/dynamichmc.jl") + # @timeit_include("mcmc/abstractmcmc.jl") + # @timeit_include("mcmc/mh.jl") + # @timeit_include("ext/dynamichmc.jl") end - @testset "variational algorithms" begin - @timeit_include("variational/advi.jl") - end + # @testset "variational algorithms" begin + # @timeit_include("variational/advi.jl") + # end - @testset "mode estimation" begin - @timeit_include("optimisation/OptimInterface.jl") - @timeit_include("ext/Optimisation.jl") - end + # @testset "mode estimation" begin + # @timeit_include("optimisation/OptimInterface.jl") + # @timeit_include("ext/Optimisation.jl") + # end end # @testset "variational optimisers" begin # @timeit_include("variational/optimisers.jl") # end -# @testset "stdlib" begin -# @timeit_include("stdlib/distributions.jl") -# @timeit_include("stdlib/RandomMeasures.jl") -# end + # @testset "stdlib" begin + # @timeit_include("stdlib/distributions.jl") + # @timeit_include("stdlib/RandomMeasures.jl") + # end # @testset "utilities" begin # @timeit_include("mcmc/utilities.jl") From a164707f10ddf8bacf3ade53a292f6f41229f71b Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 21 Dec 2023 13:50:40 -0600 Subject: [PATCH 23/47] Update sghmc.jl --- test/mcmc/sghmc.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index 9079f94c00..16b6508ee5 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -1,4 +1,4 @@ -@testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)) +# @testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)) @testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoEnzyme(),) @turing_testset "sghmc constructor" begin alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend) From 97f1fb6bb5b7d4f90bfd63ea86d12003dc687a6a Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 21 Dec 2023 13:51:38 -0600 Subject: [PATCH 24/47] Update runtests.jl --- test/runtests.jl | 64 ++++++++++++++++++++++++------------------------ 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 71344d62d8..b4a6d37e80 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -64,51 +64,51 @@ const TIMEROUTPUT = TimerOutputs.TimerOutput() macro timeit_include(path::AbstractString) :(@timeit TIMEROUTPUT $path include($path)) end @testset "Turing" begin - # @testset "essential" begin - # @timeit_include("essential/ad.jl") - # end - - # @testset "samplers (without AD)" begin - # @timeit_include("mcmc/particle_mcmc.jl") - # @timeit_include("mcmc/emcee.jl") - # @timeit_include("mcmc/ess.jl") - # @timeit_include("mcmc/is.jl") - # end + @testset "essential" begin + @timeit_include("essential/ad.jl") + end + + @testset "samplers (without AD)" begin + @timeit_include("mcmc/particle_mcmc.jl") + @timeit_include("mcmc/emcee.jl") + @timeit_include("mcmc/ess.jl") + @timeit_include("mcmc/is.jl") + end @timeit TIMEROUTPUT "inference" begin @testset "inference with samplers" begin - # @timeit_include("mcmc/gibbs.jl") - # @timeit_include("mcmc/gibbs_conditional.jl") + @timeit_include("mcmc/gibbs.jl") + @timeit_include("mcmc/gibbs_conditional.jl") @timeit_include("mcmc/hmc.jl") @timeit_include("mcmc/Inference.jl") @timeit_include("mcmc/sghmc.jl") - # @timeit_include("mcmc/abstractmcmc.jl") - # @timeit_include("mcmc/mh.jl") - # @timeit_include("ext/dynamichmc.jl") + @timeit_include("mcmc/abstractmcmc.jl") + @timeit_include("mcmc/mh.jl") + @timeit_include("ext/dynamichmc.jl") end - # @testset "variational algorithms" begin - # @timeit_include("variational/advi.jl") - # end + @testset "variational algorithms" begin + @timeit_include("variational/advi.jl") + end - # @testset "mode estimation" begin - # @timeit_include("optimisation/OptimInterface.jl") - # @timeit_include("ext/Optimisation.jl") - # end + @testset "mode estimation" begin + @timeit_include("optimisation/OptimInterface.jl") + @timeit_include("ext/Optimisation.jl") + end end - # @testset "variational optimisers" begin - # @timeit_include("variational/optimisers.jl") - # end + @testset "variational optimisers" begin + @timeit_include("variational/optimisers.jl") + end - # @testset "stdlib" begin - # @timeit_include("stdlib/distributions.jl") - # @timeit_include("stdlib/RandomMeasures.jl") - # end + @testset "stdlib" begin + @timeit_include("stdlib/distributions.jl") + @timeit_include("stdlib/RandomMeasures.jl") + end - # @testset "utilities" begin - # @timeit_include("mcmc/utilities.jl") - # end + @testset "utilities" begin + @timeit_include("mcmc/utilities.jl") + end end show(TIMEROUTPUT; compact=true, sortby=:firstexec) From c7b6cf4c18ae1c730e2b1e5a58b80d1f3dc8360e Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 25 Jan 2024 14:57:43 -0500 Subject: [PATCH 25/47] disable Type unstable getfield --- test/mcmc/Inference.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index d8c05136f4..374c48e1e3 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -337,6 +337,8 @@ alg = Gibbs(HMC(0.2, 3, :m; adtype=adbackend), PG(10, :s)) chn = sample(gdemo_default, alg, 1000) end + # Type unstable getfield of tuple not supported in Enzyme yet + if adbackend != AutoEnzyme() @testset "vectorization @." begin # https://github.com/FluxML/Tracker.jl/issues/119 @model function vdemo1(x) @@ -519,6 +521,7 @@ vdemo3kw(; T) = vdemo3(T) sample(vdemo3kw(; T=Vector{Float64}), alg, 250) end + end @testset "names_values" begin ks, xs = Turing.Inference.names_values([ From efdd8e7537111591043366c28c0c67631e9a69b8 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 25 Jan 2024 15:00:58 -0500 Subject: [PATCH 26/47] use release --- test/Project.toml | 2 +- test/runtests.jl | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 355ded217f..72beb62b88 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -41,7 +41,7 @@ Clustering = "0.14, 0.15" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -Enzyme = "0.11.5" +Enzyme = "0.11.12" DynamicPPL = "0.24" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" diff --git a/test/runtests.jl b/test/runtests.jl index b4a6d37e80..2475f6a383 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,3 @@ -using Pkg -Pkg.add(Pkg.PackageSpec(; url="https://github.com/EnzymeAD/Enzyme.jl.git", rev="main")) - using AbstractMCMC using AdvancedMH using AdvancedPS From 2fdf5464165a1f013f91afd51fb879348662a16a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 25 Jan 2024 23:35:13 +0100 Subject: [PATCH 27/47] Remove seemingly unnecessary definition --- src/essential/ad.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/essential/ad.jl b/src/essential/ad.jl index 289bf55fd9..da5f827e92 100644 --- a/src/essential/ad.jl +++ b/src/essential/ad.jl @@ -40,10 +40,6 @@ function LogDensityProblemsAD.ADgradient(ad::AutoForwardDiff, ℓ::Turing.LogDen return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x = θ) end -function LogDensityProblemsAD.ADgradient(::AutoEnzyme, ℓ::Turing.LogDensityFunction) - return LogDensityProblemsAD.ADgradient(Val(:Enzyme), ℓ) -end - function LogDensityProblemsAD.ADgradient(ad::AutoReverseDiff, ℓ::Turing.LogDensityFunction) return LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), ℓ; compile=Val(ad.compile), x=DynamicPPL.getparams(ℓ)) end From 4d8cd2313cfd3c93665df0dcc463246812e328e3 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 26 Jan 2024 02:02:11 +0100 Subject: [PATCH 28/47] Run tests on Enzyme#main again --- test/runtests.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 2475f6a383..b1f66287d6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,6 @@ +import Pkg +Pkg.add(Pkg.PackageSpec(; url="https://github.com/EnzymeAD/Enzyme.jl.git", rev="main")) + using AbstractMCMC using AdvancedMH using AdvancedPS From b8296bed1cd59532ad120e9ea14f3cc0e2b274d2 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 13 Mar 2024 01:18:49 +0100 Subject: [PATCH 29/47] Test with cholesky fixes --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index b1f66287d6..2101bebf33 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ import Pkg -Pkg.add(Pkg.PackageSpec(; url="https://github.com/EnzymeAD/Enzyme.jl.git", rev="main")) +Pkg.add(Pkg.PackageSpec(; url="https://github.com/simsurace/Enzyme.jl.git", rev="fix-cholesky")) using AbstractMCMC using AdvancedMH From 2b54d69a7118e7d129248f3c24080ad550c563e5 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Wed, 29 May 2024 23:31:11 +0100 Subject: [PATCH 30/47] Update Project.toml --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index c70d2009ba..ccaf715cb0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -42,7 +42,7 @@ Clustering = "0.14, 0.15" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -Enzyme = "0.11.12" +Enzyme = "0.12" DynamicPPL = "0.27" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" From 2823a41efc85986ec3785852f0ff648269b97e93 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Thu, 30 May 2024 00:19:30 +0100 Subject: [PATCH 31/47] Update Turing.jl --- src/Turing.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Turing.jl b/src/Turing.jl index 99e9880d2d..5ef60aca55 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -106,6 +106,7 @@ export @model, # modelling AutoForwardDiff, # ADTypes AutoReverseDiff, AutoZygote, + AutoEnzyme, AutoTracker, AutoTapir, From 0c376a66d5c54459815e81ffdf00c92291bfd17f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 1 Jul 2024 19:57:32 +0100 Subject: [PATCH 32/47] Attempt at fix for `bnn` tests as outlined in #2277 --- test/mcmc/hmc.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 6ac96a551b..855d729e93 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -111,7 +111,7 @@ Enzyme.API.runtimeActivity!(true) alpha = 0.16 # regularizatin term var_prior = sqrt(1.0 / alpha) # variance of the Gaussian prior - @model function bnn(ts) + @model function bnn(ts, var_prior) b1 ~ MvNormal([0. ;0.; 0.], [var_prior 0. 0.; 0. var_prior 0.; 0. 0. var_prior]) w11 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior]) @@ -129,7 +129,7 @@ Enzyme.API.runtimeActivity!(true) end # Sampling - chain = sample(rng, bnn(ts), HMC(0.1, 5; adtype=adbackend), 10) + chain = sample(rng, bnn(ts, var_prior), HMC(0.1, 5; adtype=adbackend), 10) end @testset "hmcda inference" begin From 76b5e48532277879fd4f7b98d1b27778dee8ec7a Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Tue, 9 Jul 2024 17:49:03 +0100 Subject: [PATCH 33/47] Update test/runtests.jl --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 9294b34871..6fa0b20eaf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ import Pkg -Pkg.add(Pkg.PackageSpec(; url="https://github.com/simsurace/Enzyme.jl.git", rev="fix-cholesky")) +Pkg.add(Pkg.PackageSpec(; url="https://github.com/simsurace/Enzyme.jl.git")) include("test_utils/SelectiveTests.jl") using .SelectiveTests: isincluded, parse_args From 784b8cb66e149f9d1829f7aaa86e5213c55e954f Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Tue, 9 Jul 2024 17:50:21 +0100 Subject: [PATCH 34/47] Update runtests.jl --- test/runtests.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 6fa0b20eaf..48a00122d4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,3 @@ -import Pkg -Pkg.add(Pkg.PackageSpec(; url="https://github.com/simsurace/Enzyme.jl.git")) - include("test_utils/SelectiveTests.jl") using .SelectiveTests: isincluded, parse_args using Pkg From 5bfd06dca74c0d1bab2e3627b9a9606aec558ac1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 9 Jul 2024 20:09:38 +0100 Subject: [PATCH 35/47] remove implicit usage of `hvcat` --- test/mcmc/hmc.jl | 94 +++++++++++++++++++++++------------------------- 1 file changed, 45 insertions(+), 49 deletions(-) diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 855d729e93..d8f9e00498 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -28,52 +28,51 @@ Enzyme.API.runtimeActivity!(true) # Set a seed rng = StableRNG(123) @testset "constrained bounded" begin - obs = [0,1,0,1,1,1,1,1,1,1] + obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1] @model function constrained_test(obs) - p ~ Beta(2,2) - for i = 1:length(obs) + p ~ Beta(2, 2) + for i in 1:length(obs) obs[i] ~ Bernoulli(p) end - p + return p end chain = sample( rng, constrained_test(obs), HMC(1.5, 3; adtype=adbackend),# using a large step size (1.5) - 1000) + 1000, + ) - check_numerical(chain, [:p], [10/14], atol=0.1) + check_numerical(chain, [:p], [10 / 14]; atol=0.1) end @testset "constrained simplex" begin - obs12 = [1,2,1,2,2,2,2,2,2,2] + obs12 = [1, 2, 1, 2, 2, 2, 2, 2, 2, 2] @model function constrained_simplex_test(obs12) ps ~ Dirichlet(2, 3) pd ~ Dirichlet(4, 1) - for i = 1:length(obs12) + for i in 1:length(obs12) obs12[i] ~ Categorical(ps) end return ps end chain = sample( - rng, - constrained_simplex_test(obs12), - HMC(0.75, 2; adtype=adbackend), - 1000) + rng, constrained_simplex_test(obs12), HMC(0.75, 2; adtype=adbackend), 1000 + ) - check_numerical(chain, ["ps[1]", "ps[2]"], [5/16, 11/16], atol=0.015) + check_numerical(chain, ["ps[1]", "ps[2]"], [5 / 16, 11 / 16]; atol=0.015) end @testset "hmc reverse diff" begin alg = HMC(0.1, 10; adtype=adbackend) res = sample(rng, gdemo_default, alg, 4000) - check_gdemo(res, rtol=0.1) + check_gdemo(res; rtol=0.1) end @testset "matrix support" begin @model function hmcmatrixsup() - v ~ Wishart(7, [1 0.5; 0.5 1]) + return v ~ Wishart(7, [1 0.5; 0.5 1]) end model_f = hmcmatrixsup() @@ -81,7 +80,7 @@ Enzyme.API.runtimeActivity!(true) vs = map(1:3) do _ chain = sample(rng, model_f, HMC(0.15, 7; adtype=adbackend), n_samples) r = reshape(Array(group(chain, :v)), n_samples, 2, 2) - reshape(mean(r; dims = 1), 2, 2) + reshape(mean(r; dims=1), 2, 2) end @test maximum(abs, mean(vs) - (7 * [1 0.5; 0.5 1])) <= 0.5 @@ -98,10 +97,10 @@ Enzyme.API.runtimeActivity!(true) M = N ÷ 4 x1s = rand(M) * 5 x2s = rand(M) * 5 - xt1s = Array([[x1s[i]; x2s[i]] for i = 1:M]) - append!(xt1s, Array([[x1s[i] - 6; x2s[i] - 6] for i = 1:M])) - xt0s = Array([[x1s[i]; x2s[i] - 6] for i = 1:M]) - append!(xt0s, Array([[x1s[i] - 6; x2s[i]] for i = 1:M])) + xt1s = Array([[x1s[i]; x2s[i]] for i in 1:M]) + append!(xt1s, Array([[x1s[i] - 6; x2s[i] - 6] for i in 1:M])) + xt0s = Array([[x1s[i]; x2s[i] - 6] for i in 1:M]) + append!(xt0s, Array([[x1s[i] - 6; x2s[i]] for i in 1:M])) xs = [xt1s; xt0s] ts = [ones(M); ones(M); zeros(M); zeros(M)] @@ -112,20 +111,18 @@ Enzyme.API.runtimeActivity!(true) var_prior = sqrt(1.0 / alpha) # variance of the Gaussian prior @model function bnn(ts, var_prior) - b1 ~ MvNormal([0. ;0.; 0.], - [var_prior 0. 0.; 0. var_prior 0.; 0. 0. var_prior]) - w11 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior]) - w12 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior]) - w13 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior]) + b1 ~ MvNormal(zeros(3), var_prior * I) + w11 ~ MvNormal(zeros(2), var_prior * I) + w12 ~ MvNormal(zeros(2), var_prior * I) + w13 ~ MvNormal(zeros(2), var_prior * I) bo ~ Normal(0, var_prior) - wo ~ MvNormal([0.; 0; 0], - [var_prior 0. 0.; 0. var_prior 0.; 0. 0. var_prior]) - for i = rand(1:N, 10) + wo ~ MvNormal(zeros(3), var_prior * I) + for i in rand(1:N, 10) y = nn(xs[i], b1, w11, w12, w13, bo, wo) ts[i] ~ Bernoulli(y) end - b1, w11, w12, w13, bo, wo + return b1, w11, w12, w13, bo, wo end # Sampling @@ -153,7 +150,7 @@ Enzyme.API.runtimeActivity!(true) Random.seed!(12345) # particle samplers do not support user-provided `rng` yet alg3 = Gibbs(PG(20, :s), HMCDA(500, 0.8, 0.25, :m; init_ϵ=0.05, adtype=adbackend)) - res3 = sample(rng, gdemo_default, alg3, 3000, discard_initial=1000) + res3 = sample(rng, gdemo_default, alg3, 3000; discard_initial=1000) check_gdemo(res3) end @@ -197,8 +194,8 @@ Enzyme.API.runtimeActivity!(true) @testset "check discard" begin alg = NUTS(100, 0.8; adtype=adbackend) - c1 = sample(rng, gdemo_default, alg, 500, discard_adapt=true) - c2 = sample(rng, gdemo_default, alg, 500, discard_adapt=false) + c1 = sample(rng, gdemo_default, alg, 500; discard_adapt=true) + c2 = sample(rng, gdemo_default, alg, 500; discard_adapt=false) @test size(c1, 1) == 500 @test size(c2, 1) == 500 @@ -216,20 +213,20 @@ Enzyme.API.runtimeActivity!(true) # https://github.com/TuringLang/DynamicPPL.jl/issues/27 @model function mwe1(::Type{T}=Float64) where {T<:Real} m = Matrix{T}(undef, 2, 3) - m .~ MvNormal(zeros(2), I) + return m .~ MvNormal(zeros(2), I) end @test sample(rng, mwe1(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains @model function mwe2(::Type{T}=Matrix{Float64}) where {T} m = T(undef, 2, 3) - m .~ MvNormal(zeros(2), I) + return m .~ MvNormal(zeros(2), I) end @test sample(rng, mwe2(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains # https://github.com/TuringLang/Turing.jl/issues/1308 @model function mwe3(::Type{T}=Array{Float64}) where {T} m = T(undef, 2, 3) - m .~ MvNormal(zeros(2), I) + return m .~ MvNormal(zeros(2), I) end @test sample(rng, mwe3(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains end @@ -247,13 +244,17 @@ Enzyme.API.runtimeActivity!(true) @model function demo_hmc_prior() # NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance # which means that it's _very_ difficult to find a good tolerance in the test below:) - s ~ truncated(Normal(3, 1), lower=0) - m ~ Normal(0, sqrt(s)) + s ~ truncated(Normal(3, 1); lower=0) + return m ~ Normal(0, sqrt(s)) end alg = NUTS(1000, 0.8; adtype=adbackend) - gdemo_default_prior = DynamicPPL.contextualize(demo_hmc_prior(), DynamicPPL.PriorContext()) + gdemo_default_prior = DynamicPPL.contextualize( + demo_hmc_prior(), DynamicPPL.PriorContext() + ) chain = sample(gdemo_default_prior, alg, 10_000; initial_params=[3.0, 0.0]) - check_numerical(chain, [:s, :m], [mean(truncated(Normal(3, 1); lower=0)), 0], atol=0.2) + check_numerical( + chain, [:s, :m], [mean(truncated(Normal(3, 1); lower=0)), 0]; atol=0.2 + ) end @testset "warning for difficult init params" begin @@ -268,7 +269,7 @@ Enzyme.API.runtimeActivity!(true) @test_logs ( :warn, "failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `initial_params` keyword", - ) (:info,) match_mode=:any begin + ) (:info,) match_mode = :any begin sample(demo_warn_initial_params(), NUTS(; adtype=adbackend), 5) end end @@ -280,7 +281,7 @@ Enzyme.API.runtimeActivity!(true) @model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV} xs = Vector{TV}(undef, 2) xs[1] ~ Dirichlet(ones(5)) - xs[2] ~ Dirichlet(ones(5)) + return xs[2] ~ Dirichlet(ones(5)) end model = vector_of_dirichlet() chain = sample(model, NUTS(), 1000) @@ -306,15 +307,10 @@ Enzyme.API.runtimeActivity!(true) end end - model = buggy_model(); - num_samples = 1_000; + model = buggy_model() + num_samples = 1_000 - chain = sample( - model, - NUTS(), - num_samples; - initial_params=[0.5, 1.75, 1.0] - ) + chain = sample(model, NUTS(), num_samples; initial_params=[0.5, 1.75, 1.0]) chain_prior = sample(model, Prior(), num_samples) # Extract the `x` like this because running `generated_quantities` was how From ce13e032d6a5901daac026d9f1c1085ddc9a465a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 25 Jul 2024 22:24:43 +0200 Subject: [PATCH 36/47] Re-activate CIs disabled for Enzyme testing --- .github/workflows/Tests.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index 174b4bd8a7..8de296e5ee 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -26,15 +26,15 @@ jobs: - "mcmc/ess.jl" - "--skip essential/ad.jl mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl experimental/gibbs.jl mcmc/ess.jl" version: - #- '1.7' TODO(mhauru): Temporarily disabled for Enzyme + - '1.7' - '1' os: - ubuntu-latest - #- windows-latest TODO(mhauru): Temporarily disabled for Enzyme - #- macOS-latest TODO(mhauru): Temporarily disabled for Enzyme + - windows-latest + - macOS-latest arch: - x64 - #- x86 TODO(mhauru): Temporarily disabled for Enzyme + - x86 num_threads: - 1 - 2 From e2c069345184af87f345cadc7699c5f4149a1bca Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 15 Aug 2024 11:05:16 +0200 Subject: [PATCH 37/47] Re-enable tests with other AD backends --- Project.toml | 2 +- test/mcmc/Inference.jl | 12 +++++++----- test/mcmc/hmc.jl | 11 +++++++---- test/mcmc/sghmc.jl | 6 ++---- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index 33d4be908a..8845ebbe23 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.33.3" +version = "0.33.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 15468669fd..a6d9998ef0 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -20,8 +20,7 @@ Enzyme.API.typeWarning!(false) # Enable runtime activity (workaround) Enzyme.API.runtimeActivity!(true) -# @testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)) -@testset "Testing inference.jl with $adbackend" for adbackend in (AutoEnzyme(),) +@testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false), AutoEnzyme()) # Only test threading if 1.3+. if VERSION > v"1.2" @testset "threaded sampling" begin @@ -374,8 +373,6 @@ Enzyme.API.runtimeActivity!(true) alg = Gibbs(HMC(0.2, 3, :m; adtype=adbackend), PG(10, :s)) chn = sample(gdemo_default, alg, 1000) end - # Type unstable getfield of tuple not supported in Enzyme yet - if adbackend != AutoEnzyme() @testset "vectorization @." begin # https://github.com/FluxML/Tracker.jl/issues/119 @model function vdemo1(x) @@ -407,6 +404,8 @@ Enzyme.API.runtimeActivity!(true) alg = HMC(0.01, 5; adtype=adbackend) res = sample(vdemo2(randn(D, 100)), alg, 250) + # Type unstable getfield of tuple not supported in Enzyme yet + if !(adbackend isa AutoEnzyme) # Vector assumptions N = 10 alg = HMC(0.2, 4; adtype=adbackend) @@ -452,6 +451,7 @@ Enzyme.API.runtimeActivity!(true) end sample(vdemo7(), alg, 1000) + end end @testset "vectorization .~" begin @model function vdemo1(x) @@ -474,6 +474,8 @@ Enzyme.API.runtimeActivity!(true) alg = HMC(0.01, 5; adtype=adbackend) res = sample(vdemo2(randn(D, 100)), alg, 250) + # Type unstable getfield of tuple not supported in Enzyme yet + if !(adbackend isa AutoEnzyme) # Vector assumptions N = 10 alg = HMC(0.2, 4; adtype=adbackend) @@ -518,6 +520,7 @@ Enzyme.API.runtimeActivity!(true) end sample(vdemo7(), alg, 1000) + end end @testset "Type parameters" begin N = 10 @@ -558,7 +561,6 @@ Enzyme.API.runtimeActivity!(true) vdemo3kw(; T) = vdemo3(T) sample(vdemo3kw(; T=DynamicPPL.TypeWrap{Vector{Float64}}()), alg, 250) end - end @testset "names_values" begin ks, xs = Turing.Inference.names_values([ diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index d8f9e00498..e0e8da4692 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -23,8 +23,7 @@ Enzyme.API.typeWarning!(false) # Enable runtime activity (workaround) Enzyme.API.runtimeActivity!(true) -# @testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)) -@testset "Testing hmc.jl with $adbackend" for adbackend in (AutoEnzyme(),) +@testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false), AutoEnzyme()) # Set a seed rng = StableRNG(123) @testset "constrained bounded" begin @@ -76,14 +75,18 @@ Enzyme.API.runtimeActivity!(true) end model_f = hmcmatrixsup() - n_samples = 1_000 + n_samples = 5_000 vs = map(1:3) do _ chain = sample(rng, model_f, HMC(0.15, 7; adtype=adbackend), n_samples) r = reshape(Array(group(chain, :v)), n_samples, 2, 2) reshape(mean(r; dims=1), 2, 2) end - @test maximum(abs, mean(vs) - (7 * [1 0.5; 0.5 1])) <= 0.5 + if VERSION > v"1.7" || !(adbackend isa AutoEnzyme) + @test maximum(abs, mean(vs) - (7 * [1 0.5; 0.5 1])) <= 0.5 + else + @test_broken maximum(abs, mean(vs) - (7 * [1 0.5; 0.5 1])) <= 0.5 + end end @testset "multivariate support" begin # Define NN flow diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index 955995570c..411278f906 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -17,8 +17,7 @@ Enzyme.API.typeWarning!(false) # Enable runtime activity (workaround) Enzyme.API.runtimeActivity!(true) -# @testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)) -@testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoEnzyme(),) +@testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false), AutoEnzyme()) @testset "sghmc constructor" begin alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend) @test alg isa SGHMC @@ -44,8 +43,7 @@ Enzyme.API.runtimeActivity!(true) end end -# @testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)) -@testset "Testing sgld.jl with $adbackend" for adbackend in (AutoEnzyme(),) +@testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false), AutoEnzyme()) @testset "sgld constructor" begin alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adbackend) @test alg isa SGLD From 2115d524ff6f3825debfaba874c4a693243a9337 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 15 Aug 2024 13:03:12 +0200 Subject: [PATCH 38/47] Load `@test_broken` --- test/mcmc/hmc.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 212fced399..0e85fd16bf 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -15,7 +15,7 @@ using LinearAlgebra: I, dot, vec import Random using StableRNGs: StableRNG using StatsFuns: logistic -using Test: @test, @test_logs, @testset +using Test: @test, @test_broken, @test_logs, @testset using Turing # Disable Enzyme warnings From 79d057c37fe2d0202edb5162454480cb89292c4f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Oct 2024 11:11:05 +0100 Subject: [PATCH 39/47] Bump Enzyme to 0.13 in tests --- test/Project.toml | 2 +- test/mcmc/Inference.jl | 4 ++-- test/mcmc/hmc.jl | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 9eb5b4fdc7..07d5e01361 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -46,7 +46,7 @@ Clustering = "0.14, 0.15" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -Enzyme = "0.12" +Enzyme = "0.13" DynamicPPL = "0.29" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 02a00766cc..1ad807ef39 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -412,7 +412,7 @@ using Turing alg = HMC(0.01, 5; adtype=adbackend) res = sample(vdemo2(randn(D, 100)), alg, 250) - # Type unstable getfield of tuple not supported in Enzyme yet + # TODO(mhauru) Type unstable getfield of tuple not supported in Enzyme yet if !(adbackend isa AutoEnzyme) # Vector assumptions N = 10 @@ -482,7 +482,7 @@ using Turing alg = HMC(0.01, 5; adtype=adbackend) res = sample(vdemo2(randn(D, 100)), alg, 250) - # Type unstable getfield of tuple not supported in Enzyme yet + # TODO(mhauru) Type unstable getfield of tuple not supported in Enzyme yet if !(adbackend isa AutoEnzyme) # Vector assumptions N = 10 diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 295f2703ba..2a9a897615 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -78,7 +78,8 @@ using Turing reshape(mean(r; dims=1), 2, 2) end - if VERSION > v"1.7" || !(adbackend isa AutoEnzyme) + # TODO(mhauru) The below remains broken for Enzyme. Need to investigate why. + if !(adbackend isa AutoEnzyme) @test maximum(abs, mean(vs) - (7 * [1 0.5; 0.5 1])) <= 0.5 else @test_broken maximum(abs, mean(vs) - (7 * [1 0.5; 0.5 1])) <= 0.5 From c98bbc96af637df114deed81d346544ddda6bc41 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Oct 2024 11:13:25 +0100 Subject: [PATCH 40/47] Run JuliaFormatter on more files, remove trailing whitespace --- .JuliaFormatter.toml | 3 -- .github/workflows/DocsNav.yml | 6 +-- src/mcmc/mh.jl | 2 +- test/mcmc/hmc.jl | 98 +++++++++++++++++------------------ test/mcmc/sghmc.jl | 10 ++-- 5 files changed, 58 insertions(+), 61 deletions(-) diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index 15ecbc5c35..2772de28bf 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -8,7 +8,4 @@ ignore = [ # https://github.com/TuringLang/Turing.jl/pull/2328/files "src/experimental/gibbs.jl", "test/experimental/gibbs.jl", - # https://github.com/TuringLang/Turing.jl/pull/1887 # Enzyme PR - "test/mcmc/hmc.jl", - "test/mcmc/sghmc.jl", ] diff --git a/.github/workflows/DocsNav.yml b/.github/workflows/DocsNav.yml index 14614d1fd9..301ee7393c 100644 --- a/.github/workflows/DocsNav.yml +++ b/.github/workflows/DocsNav.yml @@ -32,13 +32,13 @@ jobs: # Define the URL of the navbar to be used NAVBAR_URL="https://raw.githubusercontent.com/TuringLang/turinglang.github.io/main/assets/scripts/TuringNavbar.html" - + # Update all HTML files in the current directory (gh-pages root) ./insert_navbar.sh . $NAVBAR_URL - + # Remove the insert_navbar.sh file rm insert_navbar.sh - + # Check if there are any changes if [[ -n $(git status -s) ]]; then git add . diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index 433add6b59..bc2519d71e 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -54,7 +54,7 @@ Specifying a single distribution implies the use of static MH: ```julia # Use a static proposal for s² (which happens to be the same -# as the prior) and a static proposal for m (note that this +# as the prior) and a static proposal for m (note that this # isn't a random walk proposal). chain = sample( gdemo(1.5, 2.0), diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 7404dbf43e..27c055394f 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -22,52 +22,51 @@ using Turing # Set a seed rng = StableRNG(123) @testset "constrained bounded" begin - obs = [0,1,0,1,1,1,1,1,1,1] + obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1] @model function constrained_test(obs) - p ~ Beta(2,2) - for i = 1:length(obs) + p ~ Beta(2, 2) + for i in 1:length(obs) obs[i] ~ Bernoulli(p) end - p + return p end chain = sample( rng, constrained_test(obs), HMC(1.5, 3; adtype=adbackend),# using a large step size (1.5) - 1000) + 1000, + ) - check_numerical(chain, [:p], [10/14], atol=0.1) + check_numerical(chain, [:p], [10 / 14]; atol=0.1) end @testset "constrained simplex" begin - obs12 = [1,2,1,2,2,2,2,2,2,2] + obs12 = [1, 2, 1, 2, 2, 2, 2, 2, 2, 2] @model function constrained_simplex_test(obs12) ps ~ Dirichlet(2, 3) pd ~ Dirichlet(4, 1) - for i = 1:length(obs12) + for i in 1:length(obs12) obs12[i] ~ Categorical(ps) end return ps end chain = sample( - rng, - constrained_simplex_test(obs12), - HMC(0.75, 2; adtype=adbackend), - 1000) + rng, constrained_simplex_test(obs12), HMC(0.75, 2; adtype=adbackend), 1000 + ) - check_numerical(chain, ["ps[1]", "ps[2]"], [5/16, 11/16], atol=0.015) + check_numerical(chain, ["ps[1]", "ps[2]"], [5 / 16, 11 / 16]; atol=0.015) end @testset "hmc reverse diff" begin alg = HMC(0.1, 10; adtype=adbackend) res = sample(rng, gdemo_default, alg, 4000) - check_gdemo(res, rtol=0.1) + check_gdemo(res; rtol=0.1) end @testset "matrix support" begin @model function hmcmatrixsup() - v ~ Wishart(7, [1 0.5; 0.5 1]) + return v ~ Wishart(7, [1 0.5; 0.5 1]) end model_f = hmcmatrixsup() @@ -75,7 +74,7 @@ using Turing vs = map(1:3) do _ chain = sample(rng, model_f, HMC(0.15, 7; adtype=adbackend), n_samples) r = reshape(Array(group(chain, :v)), n_samples, 2, 2) - reshape(mean(r; dims = 1), 2, 2) + reshape(mean(r; dims=1), 2, 2) end @test maximum(abs, mean(vs) - (7 * [1 0.5; 0.5 1])) <= 0.5 @@ -92,10 +91,10 @@ using Turing M = N ÷ 4 x1s = rand(M) * 5 x2s = rand(M) * 5 - xt1s = Array([[x1s[i]; x2s[i]] for i = 1:M]) - append!(xt1s, Array([[x1s[i] - 6; x2s[i] - 6] for i = 1:M])) - xt0s = Array([[x1s[i]; x2s[i] - 6] for i = 1:M]) - append!(xt0s, Array([[x1s[i] - 6; x2s[i]] for i = 1:M])) + xt1s = Array([[x1s[i]; x2s[i]] for i in 1:M]) + append!(xt1s, Array([[x1s[i] - 6; x2s[i] - 6] for i in 1:M])) + xt0s = Array([[x1s[i]; x2s[i] - 6] for i in 1:M]) + append!(xt0s, Array([[x1s[i] - 6; x2s[i]] for i in 1:M])) xs = [xt1s; xt0s] ts = [ones(M); ones(M); zeros(M); zeros(M)] @@ -106,20 +105,22 @@ using Turing var_prior = sqrt(1.0 / alpha) # variance of the Gaussian prior @model function bnn(ts) - b1 ~ MvNormal([0. ;0.; 0.], - [var_prior 0. 0.; 0. var_prior 0.; 0. 0. var_prior]) - w11 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior]) - w12 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior]) - w13 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior]) + b1 ~ MvNormal( + [0.0; 0.0; 0.0], [var_prior 0.0 0.0; 0.0 var_prior 0.0; 0.0 0.0 var_prior] + ) + w11 ~ MvNormal([0.0; 0.0], [var_prior 0.0; 0.0 var_prior]) + w12 ~ MvNormal([0.0; 0.0], [var_prior 0.0; 0.0 var_prior]) + w13 ~ MvNormal([0.0; 0.0], [var_prior 0.0; 0.0 var_prior]) bo ~ Normal(0, var_prior) - wo ~ MvNormal([0.; 0; 0], - [var_prior 0. 0.; 0. var_prior 0.; 0. 0. var_prior]) - for i = rand(1:N, 10) + wo ~ MvNormal( + [0.0; 0; 0], [var_prior 0.0 0.0; 0.0 var_prior 0.0; 0.0 0.0 var_prior] + ) + for i in rand(1:N, 10) y = nn(xs[i], b1, w11, w12, w13, bo, wo) ts[i] ~ Bernoulli(y) end - b1, w11, w12, w13, bo, wo + return b1, w11, w12, w13, bo, wo end # Sampling @@ -147,7 +148,7 @@ using Turing Random.seed!(12345) # particle samplers do not support user-provided `rng` yet alg3 = Gibbs(PG(20, :s), HMCDA(500, 0.8, 0.25, :m; init_ϵ=0.05, adtype=adbackend)) - res3 = sample(rng, gdemo_default, alg3, 3000, discard_initial=1000) + res3 = sample(rng, gdemo_default, alg3, 3000; discard_initial=1000) check_gdemo(res3) end @@ -191,8 +192,8 @@ using Turing @testset "check discard" begin alg = NUTS(100, 0.8; adtype=adbackend) - c1 = sample(rng, gdemo_default, alg, 500, discard_adapt=true) - c2 = sample(rng, gdemo_default, alg, 500, discard_adapt=false) + c1 = sample(rng, gdemo_default, alg, 500; discard_adapt=true) + c2 = sample(rng, gdemo_default, alg, 500; discard_adapt=false) @test size(c1, 1) == 500 @test size(c2, 1) == 500 @@ -210,20 +211,20 @@ using Turing # https://github.com/TuringLang/DynamicPPL.jl/issues/27 @model function mwe1(::Type{T}=Float64) where {T<:Real} m = Matrix{T}(undef, 2, 3) - m .~ MvNormal(zeros(2), I) + return m .~ MvNormal(zeros(2), I) end @test sample(rng, mwe1(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains @model function mwe2(::Type{T}=Matrix{Float64}) where {T} m = T(undef, 2, 3) - m .~ MvNormal(zeros(2), I) + return m .~ MvNormal(zeros(2), I) end @test sample(rng, mwe2(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains # https://github.com/TuringLang/Turing.jl/issues/1308 @model function mwe3(::Type{T}=Array{Float64}) where {T} m = T(undef, 2, 3) - m .~ MvNormal(zeros(2), I) + return m .~ MvNormal(zeros(2), I) end @test sample(rng, mwe3(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains end @@ -241,13 +242,17 @@ using Turing @model function demo_hmc_prior() # NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance # which means that it's _very_ difficult to find a good tolerance in the test below:) - s ~ truncated(Normal(3, 1), lower=0) - m ~ Normal(0, sqrt(s)) + s ~ truncated(Normal(3, 1); lower=0) + return m ~ Normal(0, sqrt(s)) end alg = NUTS(1000, 0.8; adtype=adbackend) - gdemo_default_prior = DynamicPPL.contextualize(demo_hmc_prior(), DynamicPPL.PriorContext()) + gdemo_default_prior = DynamicPPL.contextualize( + demo_hmc_prior(), DynamicPPL.PriorContext() + ) chain = sample(gdemo_default_prior, alg, 10_000; initial_params=[3.0, 0.0]) - check_numerical(chain, [:s, :m], [mean(truncated(Normal(3, 1); lower=0)), 0], atol=0.2) + check_numerical( + chain, [:s, :m], [mean(truncated(Normal(3, 1); lower=0)), 0]; atol=0.2 + ) end @testset "warning for difficult init params" begin @@ -262,7 +267,7 @@ using Turing @test_logs ( :warn, "failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `initial_params` keyword", - ) (:info,) match_mode=:any begin + ) (:info,) match_mode = :any begin sample(demo_warn_initial_params(), NUTS(; adtype=adbackend), 5) end end @@ -271,7 +276,7 @@ using Turing @model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV} xs = Vector{TV}(undef, 2) xs[1] ~ Dirichlet(ones(5)) - xs[2] ~ Dirichlet(ones(5)) + return xs[2] ~ Dirichlet(ones(5)) end model = vector_of_dirichlet() chain = sample(model, NUTS(), 1000) @@ -296,15 +301,10 @@ using Turing end end - model = buggy_model(); - num_samples = 1_000; + model = buggy_model() + num_samples = 1_000 - chain = sample( - model, - NUTS(), - num_samples; - initial_params=[0.5, 1.75, 1.0] - ) + chain = sample(model, NUTS(), num_samples; initial_params=[0.5, 1.75, 1.0]) chain_prior = sample(model, Prior(), num_samples) # Extract the `x` like this because running `generated_quantities` was how diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index 1f81795034..c1d07d2ced 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -34,7 +34,7 @@ using Turing alg = SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adbackend) chain = sample(rng, gdemo_default, alg, 10_000) - check_gdemo(chain, atol=0.1) + check_gdemo(chain; atol=0.1) end end @@ -58,15 +58,15 @@ end @testset "sgld inference" begin rng = StableRNG(1) - chain = sample(rng, gdemo_default, SGLD(; stepsize = PolynomialStepsize(0.5)), 20_000) - check_gdemo(chain, atol = 0.2) + chain = sample(rng, gdemo_default, SGLD(; stepsize=PolynomialStepsize(0.5)), 20_000) + check_gdemo(chain; atol=0.2) # Weight samples by step sizes (cf section 4.2 in the paper by Welling and Teh) v = get(chain, [:SGLD_stepsize, :s, :m]) s_weighted = dot(v.SGLD_stepsize, v.s) / sum(v.SGLD_stepsize) m_weighted = dot(v.SGLD_stepsize, v.m) / sum(v.SGLD_stepsize) - @test s_weighted ≈ 49/24 atol=0.2 - @test m_weighted ≈ 7/6 atol=0.2 + @test s_weighted ≈ 49 / 24 atol = 0.2 + @test m_weighted ≈ 7 / 6 atol = 0.2 end end From 9d391938f2b04a8254ad2a23077bb07abfe77a7e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Oct 2024 11:44:48 +0100 Subject: [PATCH 41/47] Restore compat with Enzyme v0.12 --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 07d5e01361..0841eb27e6 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -46,7 +46,7 @@ Clustering = "0.14, 0.15" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -Enzyme = "0.13" +Enzyme = "0.12, 0.13" DynamicPPL = "0.29" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" From 66cd80a1ffeb9a4e233a247d04f596d1ec8ccc2d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Oct 2024 13:11:34 +0100 Subject: [PATCH 42/47] Import Enzyme in abstractmcmc and gibbs tests --- test/mcmc/abstractmcmc.jl | 1 + test/mcmc/gibbs.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/test/mcmc/abstractmcmc.jl b/test/mcmc/abstractmcmc.jl index 449b43b712..a113f1b7c1 100644 --- a/test/mcmc/abstractmcmc.jl +++ b/test/mcmc/abstractmcmc.jl @@ -5,6 +5,7 @@ using AdvancedMH: AdvancedMH using Distributions: sample using Distributions.FillArrays: Zeros using DynamicPPL: DynamicPPL +import Enzyme using ForwardDiff: ForwardDiff using LinearAlgebra: I using LogDensityProblems: LogDensityProblems diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index cd044910b8..0121687cd3 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -5,6 +5,7 @@ using ..NumericalTests: check_MoGtest_default, check_gdemo, check_numerical import ..ADUtils using Distributions: InverseGamma, Normal using Distributions: sample +import Enzyme using ForwardDiff: ForwardDiff using Random: Random using ReverseDiff: ReverseDiff From ec34e4120b40fc10613d1c650fdaee1b47e98de2 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Oct 2024 13:37:25 +0100 Subject: [PATCH 43/47] Add Enzyme imports to a couple of other tests files --- test/mcmc/gibbs_conditional.jl | 1 + test/optimisation/Optimisation.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/test/mcmc/gibbs_conditional.jl b/test/mcmc/gibbs_conditional.jl index d6d81cbe09..abbdc03c59 100644 --- a/test/mcmc/gibbs_conditional.jl +++ b/test/mcmc/gibbs_conditional.jl @@ -5,6 +5,7 @@ using ..NumericalTests: check_gdemo, check_numerical import ..ADUtils using Clustering: Clustering using Distributions: Categorical, InverseGamma, Normal, sample +import Enzyme using ForwardDiff: ForwardDiff using LinearAlgebra: Diagonal, I using Random: Random diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl index d8afd83dbb..8758e946fd 100644 --- a/test/optimisation/Optimisation.jl +++ b/test/optimisation/Optimisation.jl @@ -5,6 +5,7 @@ using ..ADUtils: ADUtils using Distributions using Distributions.FillArrays: Zeros using DynamicPPL: DynamicPPL +using Enzyme: Enzyme using ForwardDiff: ForwardDiff using LinearAlgebra: Diagonal, I using Mooncake: Mooncake From 120230a65727898d1cf9708be47f1d66cfdcc1e6 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 24 Oct 2024 14:21:28 +0100 Subject: [PATCH 44/47] Remove unnecessary version conditions in tests --- test/mcmc/Inference.jl | 95 ++++++++++++++++++------------------------ 1 file changed, 40 insertions(+), 55 deletions(-) diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 1ad807ef39..bf53af36e9 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -18,70 +18,55 @@ using Test: @test, @test_throws, @testset using Turing @testset "Testing inference.jl with $adbackend" for adbackend in ADUtils.adbackends - # Only test threading if 1.3+. - if VERSION > v"1.2" - @testset "threaded sampling" begin - # Test that chains with the same seed will sample identically. - @testset "rng" begin - model = gdemo_default - - # multithreaded sampling with PG causes segfaults on Julia 1.5.4 - # https://github.com/TuringLang/Turing.jl/issues/1571 - samplers = @static if VERSION <= v"1.5.3" || VERSION >= v"1.6.0" - ( - HMC(0.1, 7; adtype=adbackend), - PG(10), - IS(), - MH(), - Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)), - Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)), - ) - else - ( - HMC(0.1, 7; adtype=adbackend), - IS(), - MH(), - Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)), - ) - end - for sampler in samplers - Random.seed!(5) - chain1 = sample(model, sampler, MCMCThreads(), 1000, 4) + @testset "threaded sampling" begin + # Test that chains with the same seed will sample identically. + @testset "rng" begin + model = gdemo_default + + samplers = ( + HMC(0.1, 7; adtype=adbackend), + PG(10), + IS(), + MH(), + Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)), + Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)), + ) + for sampler in samplers + Random.seed!(5) + chain1 = sample(model, sampler, MCMCThreads(), 1000, 4) - Random.seed!(5) - chain2 = sample(model, sampler, MCMCThreads(), 1000, 4) + Random.seed!(5) + chain2 = sample(model, sampler, MCMCThreads(), 1000, 4) - @test chain1.value == chain2.value - end + @test chain1.value == chain2.value + end - # Should also be stable with am explicit RNG - seed = 5 - rng = Random.MersenneTwister(seed) - for sampler in samplers - Random.seed!(rng, seed) - chain1 = sample(rng, model, sampler, MCMCThreads(), 1000, 4) + # Should also be stable with am explicit RNG + seed = 5 + rng = Random.MersenneTwister(seed) + for sampler in samplers + Random.seed!(rng, seed) + chain1 = sample(rng, model, sampler, MCMCThreads(), 1000, 4) - Random.seed!(rng, seed) - chain2 = sample(rng, model, sampler, MCMCThreads(), 1000, 4) + Random.seed!(rng, seed) + chain2 = sample(rng, model, sampler, MCMCThreads(), 1000, 4) - @test chain1.value == chain2.value - end + @test chain1.value == chain2.value end + end - # Smoke test for default sample call. - Random.seed!(100) - chain = sample( - gdemo_default, HMC(0.1, 7; adtype=adbackend), MCMCThreads(), 1000, 4 - ) - check_gdemo(chain) + # Smoke test for default sample call. + Random.seed!(100) + chain = sample(gdemo_default, HMC(0.1, 7; adtype=adbackend), MCMCThreads(), 1000, 4) + check_gdemo(chain) - # run sampler: progress logging should be disabled and - # it should return a Chains object - sampler = Sampler(HMC(0.1, 7; adtype=adbackend), gdemo_default) - chains = sample(gdemo_default, sampler, MCMCThreads(), 1000, 4) - @test chains isa MCMCChains.Chains - end + # run sampler: progress logging should be disabled and + # it should return a Chains object + sampler = Sampler(HMC(0.1, 7; adtype=adbackend), gdemo_default) + chains = sample(gdemo_default, sampler, MCMCThreads(), 1000, 4) + @test chains isa MCMCChains.Chains end + @testset "chain save/resume" begin Random.seed!(1234) From f9165fa80ea746156d4913b0477c0600e26a5bc9 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 29 Nov 2024 09:31:49 +0100 Subject: [PATCH 45/47] Update Project.toml --- test/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 45676b9cdb..feaf2b49ec 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -48,7 +48,6 @@ DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" Enzyme = "0.13" DynamicPPL = "0.29, 0.30.2" -DynamicPPL = "0.29, 0.30.2" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" HypothesisTests = "0.11" From 5698212b773cac802492281201f8783a6ee4402b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 2 Dec 2024 15:25:34 +0000 Subject: [PATCH 46/47] Dump DPPL to 0.31 --- Project.toml | 2 +- test/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index c6c1db99cb..cea6b3655c 100644 --- a/Project.toml +++ b/Project.toml @@ -63,7 +63,7 @@ Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.29, 0.30.4" +DynamicPPL = "0.29, 0.30.4, 0.31" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3" Libtask = "0.8.8" diff --git a/test/Project.toml b/test/Project.toml index feaf2b49ec..7935495a48 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -47,7 +47,7 @@ Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" Enzyme = "0.13" -DynamicPPL = "0.29, 0.30.2" +DynamicPPL = "0.29, 0.30.2, 0.31" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" HypothesisTests = "0.11" From 20b055e4705ccefc5a904b85613379beabd4a164 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 5 Dec 2024 14:30:45 +0000 Subject: [PATCH 47/47] Fix ADTypeCheck tests for Enzyme, add testing both Reverse and Forward Enzyme --- test/mcmc/hmc.jl | 17 +++++++++++------ test/test_utils/ad_utils.jl | 7 +++++-- test/test_utils/test_utils.jl | 13 ++++++------- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index cbd896337e..bb894b8a0f 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -20,6 +20,7 @@ using Test: @test, @test_broken, @test_logs, @testset, @test_throws using Turing @testset "Testing hmc.jl with $adbackend" for adbackend in ADUtils.adbackends + @info "Running HMC tests with $adbackend" # Set a seed rng = StableRNG(123) @testset "constrained bounded" begin @@ -332,12 +333,16 @@ using Turing end @testset "Check ADType" begin - alg = HMC(0.1, 10; adtype=adbackend) - m = DynamicPPL.contextualize( - gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context) - ) - # These will error if the adbackend being used is not the one set. - sample(rng, m, alg, 10) + # These tests don't make sense for Enzyme, since it does not use a particular element + # type. + if !(adbackend isa AutoEnzyme) + alg = HMC(0.1, 10; adtype=adbackend) + m = DynamicPPL.contextualize( + gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context) + ) + # These will error if the adbackend being used is not the one set. + @test (sample(rng, m, alg, 10); true) + end end end diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl index 3ba6a73d62..231c9ec883 100644 --- a/test/test_utils/ad_utils.jl +++ b/test/test_utils/ad_utils.jl @@ -1,11 +1,11 @@ module ADUtils +using Enzyme: Enzyme using ForwardDiff: ForwardDiff using Pkg: Pkg using Random: Random using ReverseDiff: ReverseDiff using Mooncake: Mooncake -using Test: Test using Turing: Turing using Turing: DynamicPPL using Zygote: Zygote @@ -239,7 +239,10 @@ adbackends = [ Turing.AutoForwardDiff(; chunksize=0), Turing.AutoReverseDiff(; compile=false), Turing.AutoMooncake(; config=nothing), - Turing.AutoEnzyme(), + # TODO(mhauru) Do we want to run both? For now yes, while building up Enzyme + # integration, but in the long term maybe not? + Turing.AutoEnzyme(; mode=Enzyme.Forward), + Turing.AutoEnzyme(; mode=Enzyme.Reverse), ] end diff --git a/test/test_utils/test_utils.jl b/test/test_utils/test_utils.jl index bf9f2b9b8d..b29ae6226b 100644 --- a/test/test_utils/test_utils.jl +++ b/test/test_utils/test_utils.jl @@ -1,7 +1,7 @@ """Module for testing the test utils themselves.""" module TestUtilsTests -using ..ADUtils: ADTypeCheckContext, AbstractWrongADBackendError +using ..ADUtils: ADTypeCheckContext, AbstractWrongADBackendError, adbackends using ForwardDiff: ForwardDiff using ReverseDiff: ReverseDiff using Test: @test, @testset, @test_throws @@ -13,12 +13,11 @@ using Zygote: Zygote @testset "ADTypeCheckContext" begin Turing.@model test_model() = x ~ Turing.Normal(0, 1) tm = test_model() - adtypes = ( - Turing.AutoForwardDiff(), - Turing.AutoReverseDiff(), - Turing.AutoZygote(), - # TODO: Mooncake - # Turing.AutoMooncake(config=nothing), + # These tests don't make sense for Enzyme, since it doesn't have its own element type. + # TODO(mhauru): Make these tests work for more Mooncake. + adtypes = filter( + adtype -> !(adtype isa Turing.AutoMooncake || adtype isa Turing.AutoEnzyme), + adbackends, ) for actual_adtype in adtypes sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)