diff --git a/HISTORY.md b/HISTORY.md
index 6b7247c8d..23001d7d4 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -2,7 +2,7 @@
 
 ## 0.35.0
 
-**Breaking**
+**Breaking changes**
 
 ### Remove indexing by samplers
 
@@ -49,6 +49,51 @@ This release removes the feature of `VarInfo` where it kept track of which varia
 
 This change also affects sampling in Turing.jl.
 
+### `LogDensityFunction` argument order
+
+  - The method `LogDensityFunction(varinfo, model, context)` has been removed.
+    The only accepted order is `LogDensityFunction(model, varinfo, context; adtype)`.
+    (For an explanation of `adtype`, see below.)
+    The varinfo and context arguments are both still optional.
+
+**Other changes**
+
+### `LogDensityProblems` interface
+
+LogDensityProblemsAD has been removed as a dependency.
+Instead of constructing a `LogDensityProblemsAD.ADgradient` object, we now use `DifferentiationInterface` directly to calculate the gradient of the log density with respect to the model parameters.
+
+Note that, if you wish, you can still construct an `ADgradient` out of a `LogDensityFunction` object (there is nothing preventing this).
+
+However, in this version, `LogDensityFunction` also takes an extra, optional AD type argument.
+If this argument is not provided, the behaviour is exactly the same as before, i.e. you can calculate `logdensity` but not its gradient.
+If you do pass an AD type, that will allow you to calculate the gradient as well.
+You may thus find it easier to do this instead:
+
+```julia
+@model f() = ...
+
+ldf = LogDensityFunction(f(); adtype=AutoForwardDiff())
+```
+
+This returns an object which satisfies the `LogDensityProblems` interface up to first order, i.e. you can now directly call both
+
+```julia
+LogDensityProblems.logdensity(ldf, params)
+LogDensityProblems.logdensity_and_gradient(ldf, params)
+```
+
+without having to construct a separate `ADgradient` object.
+
+If you prefer, you can also use `setadtype` to add the AD type afterwards:
+
+```julia
+@model f() = ...
+
+ldf = LogDensityFunction(f()) # by default, no adtype set
+ldf_with_ad = setadtype(ldf, AutoForwardDiff())
+```
+
 ## 0.34.2
 
 - Fixed bugs in ValuesAsInModelContext as well as DebugContext where underlying PrefixContexts were not being applied.
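The note above that an `ADgradient` can still be constructed out of a `LogDensityFunction` can be illustrated with a minimal sketch. This is only an illustration: the `demo` model and the parameter vector `[0.0]` are made up for the example, and LogDensityProblemsAD and ForwardDiff have to be installed and loaded by the user, since neither is pulled in by DynamicPPL any more.

```julia
using DynamicPPL, Distributions, LogDensityProblems
using LogDensityProblemsAD: ADgradient   # must now be added by the user
using ADTypes: AutoForwardDiff
import ForwardDiff

@model demo() = x ~ Normal()

# Without an adtype, the LogDensityFunction only supports zeroth-order queries.
ldf = DynamicPPL.LogDensityFunction(demo())
LogDensityProblems.logdensity(ldf, [0.0])

# But since it satisfies the LogDensityProblems interface, it can still be
# wrapped in an ADgradient to obtain gradients the old way.
adg = ADgradient(AutoForwardDiff(), ldf)
LogDensityProblems.logdensity_and_gradient(adg, [0.0])
```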
diff --git a/Project.toml b/Project.toml index be4586246..81ba8b418 100644 --- a/Project.toml +++ b/Project.toml @@ -12,13 +12,13 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -29,7 +29,6 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" @@ -38,7 +37,6 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [extensions] DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] -DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLJETExt = ["JET"] DynamicPPLMCMCChainsExt = ["MCMCChains"] DynamicPPLMooncakeExt = ["Mooncake"] @@ -54,15 +52,14 @@ Bijectors = "0.13.18, 0.14, 0.15" ChainRulesCore = "1" Compat = "4" ConstructionBase = "1.5.4" +DifferentiationInterface = "0.6.39" Distributions = "0.25" DocStringExtensions = "0.9" -KernelAbstractions = "0.9.33" EnzymeCore = "0.6 - 0.8" -ForwardDiff = "0.10" JET = "0.9" +KernelAbstractions = "0.9.33" LinearAlgebra = "1.6" LogDensityProblems = "2" -LogDensityProblemsAD = "1.7.0" MCMCChains = "6" MacroTools = "0.5.6" Mooncake = "0.4.59" diff --git a/docs/src/api.md b/docs/src/api.md index 6c58264fe..f949453a3 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -54,7 +54,7 @@ logjoint ### LogDensityProblems.jl interface -The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by simply wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`: +The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`. 
```@docs DynamicPPL.LogDensityFunction diff --git a/ext/DynamicPPLForwardDiffExt.jl b/ext/DynamicPPLForwardDiffExt.jl deleted file mode 100644 index 4bc33e217..000000000 --- a/ext/DynamicPPLForwardDiffExt.jl +++ /dev/null @@ -1,54 +0,0 @@ -module DynamicPPLForwardDiffExt - -if isdefined(Base, :get_extension) - using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD - using ForwardDiff -else - using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD - using ..ForwardDiff -end - -getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk - -standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true -standardtag(::ADTypes.AutoForwardDiff) = false - -function LogDensityProblemsAD.ADgradient( - ad::ADTypes.AutoForwardDiff, ℓ::DynamicPPL.LogDensityFunction -) - θ = DynamicPPL.getparams(ℓ) - f = Base.Fix1(LogDensityProblems.logdensity, ℓ) - - # Define configuration for ForwardDiff. - tag = if standardtag(ad) - ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(θ)) - else - ForwardDiff.Tag(f, eltype(θ)) - end - chunk_size = getchunksize(ad) - chunk = if chunk_size == 0 || chunk_size === nothing - ForwardDiff.Chunk(θ) - else - ForwardDiff.Chunk(length(θ), chunk_size) - end - - return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x=θ) -end - -# Allow Turing tag in gradient etc. calls of the log density function -function ForwardDiff.checktag( - ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}}, - ::DynamicPPL.LogDensityFunction, - ::AbstractArray{W}, -) where {V,W} - return true -end -function ForwardDiff.checktag( - ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}}, - ::Base.Fix1{typeof(LogDensityProblems.logdensity),<:DynamicPPL.LogDensityFunction}, - ::AbstractArray{W}, -) where {V,W} - return true -end - -end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 55e1f7e88..c844060d5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -14,7 +14,6 @@ using MacroTools: MacroTools using ConstructionBase: ConstructionBase using Accessors: Accessors using LogDensityProblems: LogDensityProblems -using LogDensityProblemsAD: LogDensityProblemsAD using LinearAlgebra: LinearAlgebra, Cholesky diff --git a/src/contexts.jl b/src/contexts.jl index 0b4633283..87ad8df0b 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -184,6 +184,7 @@ at which point it will return the sampler of that context. getsampler(context::SamplingContext) = context.sampler getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context) getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context)) +getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context") """ struct DefaultContext <: AbstractContext end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 29f591cc3..77cc21475 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -1,12 +1,53 @@ +import DifferentiationInterface as DI + """ - LogDensityFunction + is_supported(adtype::AbstractADType) + +Check if the given AD type is formally supported by DynamicPPL. + +AD backends that are not formally supported can still be used for gradient +calculation; it is just that the DynamicPPL developers do not commit to +maintaining compatibility with them. 
+"""
+is_supported(::ADTypes.AbstractADType) = false
+is_supported(::ADTypes.AutoForwardDiff) = true
+is_supported(::ADTypes.AutoMooncake) = true
+is_supported(::ADTypes.AutoReverseDiff) = true
+
+"""
+    LogDensityFunction(
+        model::Model,
+        varinfo::AbstractVarInfo=VarInfo(model),
+        context::AbstractContext=DefaultContext();
+        adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
+    )
+
+A struct which contains a model, along with all the information necessary to:
+
+  - calculate its log density at a given point;
+  - and if `adtype` is provided, calculate the gradient of the log density at
+    that point.
 
-A callable representing a log density function of a `model`.
+At its most basic level, a LogDensityFunction wraps the model together with
+the type of varinfo to be used, as well as the evaluation context. These must
+be known in order to calculate the log density (using
+[`DynamicPPL.evaluate!!`](@ref)).
+
+If the `adtype` keyword argument is provided, then this struct will also store
+the adtype along with other information for efficient calculation of the
+gradient of the log density. Note that preparing a `LogDensityFunction` with an
+AD type `AutoBackend()` requires the AD backend itself to have been loaded
+(e.g. with `import Backend`).
+
+`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface.
+If `adtype` is `nothing`, then only `logdensity` is implemented. If `adtype` is a
+concrete AD backend type, then `logdensity_and_gradient` is also implemented.
 
 # Fields
 $(FIELDS)
 
 # Examples
+
 ```jldoctest
 julia> using Distributions
 
@@ -42,116 +83,202 @@ julia> # This also respects the context in `model`.
 
 julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
 true
+
+julia> # If we also need to calculate the gradient, we can specify an AD backend.
+       import ForwardDiff, ADTypes
+
+julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff());
+
+julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
+(-2.3378770664093453, [1.0])
 ```
 """
-struct LogDensityFunction{V,M,C}
-    "varinfo used for evaluation"
-    varinfo::V
+struct LogDensityFunction{
+    M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType}
+}
     "model used for evaluation"
     model::M
+    "varinfo used for evaluation"
+    varinfo::V
     "context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
     context::C
+    "AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
+    adtype::AD
+    "(internal use only) gradient preparation object for the model"
+    prep::Union{Nothing,DI.GradientPrep}
+    "(internal use only) whether a closure was used for the gradient preparation"
+    with_closure::Bool
+
+    function LogDensityFunction(
+        model::Model,
+        varinfo::AbstractVarInfo=VarInfo(model),
+        context::AbstractContext=leafcontext(model.context);
+        adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
+    )
+        if adtype === nothing
+            prep = nothing
+            with_closure = false
+        else
+            # Check support
+            is_supported(adtype) ||
+                @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
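+            # The dummy parameter vector constructed below needs a concrete element
+            # type: DifferentiationInterface uses it during preparation to, e.g.,
+            # compile a rule for Mooncake (which must know the input type) or
+            # pre-allocate a tape for ReverseDiff.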
+            # Get a set of dummy params to use for prep
+            x = map(identity, varinfo[:])
+            with_closure = use_closure(adtype)
+            if with_closure
+                prep = DI.prepare_gradient(
+                    x -> logdensity_at(x, model, varinfo, context), adtype, x
+                )
+            else
+                prep = DI.prepare_gradient(
+                    logdensity_at,
+                    adtype,
+                    x,
+                    DI.Constant(model),
+                    DI.Constant(varinfo),
+                    DI.Constant(context),
+                )
+            end
+        end
+        return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
+            model, varinfo, context, adtype, prep, with_closure
+        )
+    end
 end
 
-# TODO: Deprecate.
-function LogDensityFunction(
-    varinfo::AbstractVarInfo,
-    model::Model,
-    sampler::AbstractSampler,
-    context::AbstractContext,
-)
-    return LogDensityFunction(varinfo, model, SamplingContext(sampler, context))
+"""
+    setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType})
+
+Set the AD type used for evaluation of log density gradient in the given LogDensityFunction.
+This function also performs preparation of the gradient, and sets the `prep`
+and `with_closure` fields of the LogDensityFunction.
+
+If `adtype` is `nothing`, the `prep` field will be set to `nothing` as well.
+
+This function returns a new LogDensityFunction with the updated AD type, i.e. it does
+not mutate the input LogDensityFunction.
+"""
+function setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType})
+    return if adtype === f.adtype
+        f # Avoid recomputing prep if not needed
+    else
+        LogDensityFunction(f.model, f.varinfo, f.context; adtype=adtype)
+    end
 end
-function LogDensityFunction(
-    model::Model,
-    varinfo::AbstractVarInfo=VarInfo(model),
-    context::Union{Nothing,AbstractContext}=nothing,
+"""
+    logdensity_at(
+        x::AbstractVector,
+        model::Model,
+        varinfo::AbstractVarInfo,
+        context::AbstractContext
+    )
+
+Evaluate the log density of the given `model` at the given parameter values `x`,
+using the given `varinfo` and `context`. Note that the `varinfo` argument is provided
+only for its structure, in the sense that the parameters from the vector `x` are inserted into
+it, and its own parameters are discarded.
+"""
+function logdensity_at(
+    x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext
 )
-    return LogDensityFunction(varinfo, model, context)
+    varinfo_new = unflatten(varinfo, x)
+    return getlogp(last(evaluate!!(model, varinfo_new, context)))
 end
 
-# If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`.
-function getcontext(f::LogDensityFunction)
-    return f.context === nothing ?
leafcontext(f.model.context) : f.context +### LogDensityProblems interface + +function LogDensityProblems.capabilities( + ::Type{<:LogDensityFunction{M,V,C,Nothing}} +) where {M,V,C} + return LogDensityProblems.LogDensityOrder{0}() end +function LogDensityProblems.capabilities( + ::Type{<:LogDensityFunction{M,V,C,AD}} +) where {M,V,C,AD<:ADTypes.AbstractADType} + return LogDensityProblems.LogDensityOrder{1}() +end +function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) + return logdensity_at(x, f.model, f.varinfo, f.context) +end +function LogDensityProblems.logdensity_and_gradient( + f::LogDensityFunction{M,V,C,AD}, x::AbstractVector +) where {M,V,C,AD<:ADTypes.AbstractADType} + f.prep === nothing && + error("Gradient preparation not available; this should not happen") + x = map(identity, x) # Concretise type + return if f.with_closure + DI.value_and_gradient( + x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x + ) + else + DI.value_and_gradient( + logdensity_at, + f.prep, + f.adtype, + x, + DI.Constant(f.model), + DI.Constant(f.varinfo), + DI.Constant(f.context), + ) + end +end + +# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)? +LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) + +### Utils + +""" + use_closure(adtype::ADTypes.AbstractADType) + +In LogDensityProblems, we want to calculate the derivative of logdensity(f, x) +with respect to x, where f is the model (in our case LogDensityFunction) and is +a constant. However, DifferentiationInterface generally expects a +single-argument function g(x) to differentiate. + +There are two ways of dealing with this: + +1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f) + +2. Use a constant context. This lets us pass a two-argument function to + DifferentiationInterface, as long as we also give it the 'inactive argument' + (i.e. the model) wrapped in `DI.Constant`. + +The relative performance of the two approaches, however, depends on the AD +backend used. Some benchmarks are provided here: +https://github.com/TuringLang/DynamicPPL.jl/pull/806#issuecomment-2658061480 + +This function is used to determine whether a given AD backend should use a +closure or a constant. If `use_closure(adtype)` returns `true`, then the +closure approach will be used. By default, this function returns `false`, i.e. +the constant approach will be used. +""" +use_closure(::ADTypes.AbstractADType) = false +use_closure(::ADTypes.AutoForwardDiff) = false +use_closure(::ADTypes.AutoMooncake) = false +use_closure(::ADTypes.AutoReverseDiff) = true """ getmodel(f) Return the `DynamicPPL.Model` wrapped in the given log-density function `f`. """ -getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = - getmodel(LogDensityProblemsAD.parent(f)) getmodel(f::DynamicPPL.LogDensityFunction) = f.model """ setmodel(f, model[, adtype]) Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. - -!!! warning - Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a - `DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f` - might require recompilation of the gradient tape, depending on the AD backend. """ -function setmodel( - f::LogDensityProblemsAD.ADGradientWrapper, - model::DynamicPPL.Model, - adtype::ADTypes.AbstractADType, -) - # TODO: Should we handle `SciMLBase.NoAD`? - # For an `ADGradientWrapper` we do the following: - # 1. 
Update the `Model` in the underlying `LogDensityFunction`. - # 2. Re-construct the `ADGradientWrapper` using `ADgradient` using the provided `adtype` - # to ensure that the recompilation of gradient tapes, etc. also occur. For example, - # ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just - # replacing the corresponding field with the new model won't be sufficient to obtain - # the correct gradients. - return LogDensityProblemsAD.ADgradient( - adtype, setmodel(LogDensityProblemsAD.parent(f), model) - ) -end function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) - return Accessors.@set f.model = model + return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype) end -# HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time -# we need to define these annoying methods to ensure that we stay compatible with everything. -getsampler(f::LogDensityFunction) = getsampler(getcontext(f)) -hassampler(f::LogDensityFunction) = hassampler(getcontext(f)) - """ getparams(f::LogDensityFunction) Return the parameters of the wrapped varinfo as a vector. """ getparams(f::LogDensityFunction) = f.varinfo[:] - -# LogDensityProblems interface -function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector) - context = getcontext(f) - vi_new = unflatten(f.varinfo, θ) - return getlogp(last(evaluate!!(f.model, vi_new, context))) -end -function LogDensityProblems.capabilities(::Type{<:LogDensityFunction}) - return LogDensityProblems.LogDensityOrder{0}() -end -# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)? -LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) - -# This is important for performance -- one needs to provide `ADGradient` with a vector of -# parameters, or DifferentiationInterface will not have sufficient information to e.g. -# compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate -# a tape when using ReverseDiff.jl. 
-function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction) - x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params - return LogDensityProblemsAD.ADgradient(ad, ℓ; x) -end - -function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction) - return _make_ad_gradient(ad, f) -end -function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction) - return _make_ad_gradient(ad, f) -end diff --git a/test/Project.toml b/test/Project.toml index c7583c672..420edba94 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -16,7 +16,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" @@ -46,7 +45,6 @@ EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10.12" JET = "0.9" LogDensityProblems = "2" -LogDensityProblemsAD = "1.7.0" MCMCChains = "6.0.4" MacroTools = "0.5.6" Mooncake = "0.4.59" diff --git a/test/ad.jl b/test/ad.jl index 17981cf2a..4bc0ef765 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,35 +1,59 @@ -@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - f = DynamicPPL.LogDensityFunction(m) - rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) - vns = DynamicPPL.TestUtils.varnames(m) - varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) +using DynamicPPL: LogDensityFunction - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - f = DynamicPPL.LogDensityFunction(m, varinfo) +@testset "Automatic differentiation" begin + @testset "Unsupported backends" begin + @model demo() = x ~ Normal() + @test_logs (:warn, r"not officially supported") LogDensityFunction( + demo(); adtype=AutoZygote() + ) + end + + @testset "Correctness: ForwardDiff, ReverseDiff, and Mooncake" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) + vns = DynamicPPL.TestUtils.varnames(m) + varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) + + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + f = LogDensityFunction(m, varinfo) + x = DynamicPPL.getparams(f) + # Calculate reference logp + gradient of logp using ForwardDiff + ref_adtype = ADTypes.AutoForwardDiff() + ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype) + ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x) + + @testset "$adtype" for adtype in [ + AutoReverseDiff(; compile=false), + AutoReverseDiff(; compile=true), + AutoMooncake(; config=nothing), + ] + @info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype" - # use ForwardDiff result as reference - ad_forwarddiff_f = LogDensityProblemsAD.ADgradient( - ADTypes.AutoForwardDiff(; chunksize=0), f - ) - # convert to `Vector{Float64}` to avoid `ReverseDiff` initializing the gradients to Integer 0 - # reference: https://github.com/TuringLang/DynamicPPL.jl/pull/571#issuecomment-1924304489 - θ = convert(Vector{Float64}, varinfo[:]) - logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ) + # Put predicates here to avoid long lines + is_mooncake = adtype 
isa AutoMooncake + is_1_10 = v"1.10" <= VERSION < v"1.11" + is_1_11 = v"1.11" <= VERSION < v"1.12" + is_svi_vnv = varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} + is_svi_od = varinfo isa SimpleVarInfo{<:OrderedDict} - @testset "$adtype" for adtype in [ - ADTypes.AutoReverseDiff(; compile=false), - ADTypes.AutoReverseDiff(; compile=true), - ADTypes.AutoMooncake(; config=nothing), - ] - # Mooncake can't currently handle something that is going on in - # SimpleVarInfo{<:VarNamedVector}. Disable all SimpleVarInfo tests for now. - if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo - @test_broken 1 == 0 - else - ad_f = LogDensityProblemsAD.ADgradient(adtype, f) - _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ) - @test grad ≈ ref_grad + # Mooncake doesn't work with several combinations of SimpleVarInfo. + if is_mooncake && is_1_11 && is_svi_vnv + # https://github.com/compintell/Mooncake.jl/issues/470 + @test_throws ArgumentError DynamicPPL.setadtype(ref_ldf, adtype) + elseif is_mooncake && is_1_10 && is_svi_vnv + # TODO: report upstream + @test_throws UndefRefError DynamicPPL.setadtype(ref_ldf, adtype) + elseif is_mooncake && is_1_10 && is_svi_od + # TODO: report upstream + @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.setadtype( + ref_ldf, adtype + ) + else + ldf = DynamicPPL.setadtype(ref_ldf, adtype) + logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x) + @test grad ≈ ref_grad + @test logp ≈ ref_logp + end end end end @@ -64,13 +88,16 @@ # of implementation struct MyEmptyAlg end DynamicPPL.getspace(::DynamicPPL.Sampler{MyEmptyAlg}) = () - DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi) = - DynamicPPL.assume(dist, vn, vi) + DynamicPPL.assume( + ::Random.AbstractRNG, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi + ) = DynamicPPL.assume(dist, vn, vi) # Compiling the ReverseDiff tape used to fail here spl = Sampler(MyEmptyAlg()) vi = VarInfo(model) - ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl)) - @test LogDensityProblemsAD.ADgradient(AutoReverseDiff(; compile=true), ldf) isa Any + ldf = LogDensityFunction( + model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true) + ) + @test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any end end diff --git a/test/ext/DynamicPPLForwardDiffExt.jl b/test/ext/DynamicPPLForwardDiffExt.jl deleted file mode 100644 index 8de28046b..000000000 --- a/test/ext/DynamicPPLForwardDiffExt.jl +++ /dev/null @@ -1,14 +0,0 @@ -@testset "tag" begin - for chunksize in (nothing, 0, 1, 10) - ad = ADTypes.AutoForwardDiff(; chunksize=chunksize) - standardtag = if !isdefined(Base, :get_extension) - DynamicPPL.DynamicPPLForwardDiffExt.standardtag - else - Base.get_extension(DynamicPPL, :DynamicPPLForwardDiffExt).standardtag - end - @test standardtag(ad) - for tag in (false, 0, 1) - @test !standardtag(AutoForwardDiff(; chunksize=chunksize, tag=tag)) - end - end -end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index beda767e6..319371609 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -1,4 +1,4 @@ -using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, ReverseDiff +using Test, DynamicPPL, ADTypes, LogDensityProblems, ReverseDiff @testset "`getmodel` and `setmodel`" begin @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS @@ -6,17 +6,6 @@ using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, Rever ℓ = 
DynamicPPL.LogDensityFunction(model) @test DynamicPPL.getmodel(ℓ) == model @test DynamicPPL.setmodel(ℓ, model).model == model - - # ReverseDiff related - ∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(false)) - @test DynamicPPL.getmodel(∇ℓ) == model - @test DynamicPPL.getmodel(DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff())) == - model - ∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(true)) - new_∇ℓ = DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff()) - @test DynamicPPL.getmodel(new_∇ℓ) == model - # HACK(sunxd): rely on internal implementation detail, i.e., naming of `compiledtape` - @test new_∇ℓ.compiledtape != ∇ℓ.compiledtape end end diff --git a/test/runtests.jl b/test/runtests.jl index 29a148789..2fff8adb6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,7 +9,7 @@ using Distributions using DistributionsAD using Documenter using ForwardDiff -using LogDensityProblems, LogDensityProblemsAD +using LogDensityProblems using MacroTools using MCMCChains using Mooncake: Mooncake @@ -75,7 +75,6 @@ include("test_util.jl") include("ext/DynamicPPLJETExt.jl") end @testset "ad" begin - include("ext/DynamicPPLForwardDiffExt.jl") include("ext/DynamicPPLMooncakeExt.jl") include("ad.jl") end diff --git a/test/test_util.jl b/test/test_util.jl index 27a68456c..d831a5ea6 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -56,6 +56,15 @@ function short_varinfo_name(vi::TypedVarInfo) end short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo" +function short_varinfo_name(::SimpleVarInfo{<:NamedTuple,<:Ref}) + return "SimpleVarInfo{<:NamedTuple,<:Ref}" +end +function short_varinfo_name(::SimpleVarInfo{<:OrderedDict,<:Ref}) + return "SimpleVarInfo{<:OrderedDict,<:Ref}" +end +function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector,<:Ref}) + return "SimpleVarInfo{<:VarNamedVector,<:Ref}" +end short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector})