Commit

Remove LogDensityProblemsAD
penelopeysm committed Feb 10, 2025
1 parent 7613dbb commit 5aafaf0
Showing 10 changed files with 82 additions and 142 deletions.
10 changes: 3 additions & 7 deletions Project.toml
@@ -12,15 +12,13 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
# TODO(penelopeysm,mhauru) KernelAbstractions is only a dependency so that we can pin its version, see
# https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -56,17 +54,15 @@ Bijectors = "0.13.18, 0.14, 0.15"
ChainRulesCore = "1"
Compat = "4"
ConstructionBase = "1.5.4"
DifferentiationInterface = "0.6.39"
Distributions = "0.25"
DocStringExtensions = "0.9"
# TODO(penelopeysm,mhauru) See https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
# for why KernelAbstractions is pinned like this.
KernelAbstractions = "< 0.9.32"
EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10"
JET = "0.9"
KernelAbstractions = "< 0.9.32"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
MCMCChains = "6"
MacroTools = "0.5.6"
Mooncake = "0.4.59"
31 changes: 2 additions & 29 deletions ext/DynamicPPLForwardDiffExt.jl
@@ -1,40 +1,13 @@
module DynamicPPLForwardDiffExt

if isdefined(Base, :get_extension)
using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
using ForwardDiff
else
using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
using ..ForwardDiff
end
using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems
using ForwardDiff

getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk

standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true
standardtag(::ADTypes.AutoForwardDiff) = false

function LogDensityProblemsAD.ADgradient(
ad::ADTypes.AutoForwardDiff, ℓ::DynamicPPL.LogDensityFunction
)
θ = DynamicPPL.getparams(ℓ)
f = Base.Fix1(LogDensityProblems.logdensity, ℓ)

# Define configuration for ForwardDiff.
tag = if standardtag(ad)
ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(θ))
else
ForwardDiff.Tag(f, eltype(θ))
end
chunk_size = getchunksize(ad)
chunk = if chunk_size == 0 || chunk_size === nothing
ForwardDiff.Chunk(θ)
else
ForwardDiff.Chunk(length(θ), chunk_size)
end

return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x=θ)
end

# Allow Turing tag in gradient etc. calls of the log density function
function ForwardDiff.checktag(
::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}},
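With the extension's `ADgradient` method gone, ForwardDiff-specific options such as the chunk size are no longer assembled here; they travel inside the `ADTypes.AutoForwardDiff` object and are consumed downstream by DifferentiationInterface. A minimal sketch of how a caller might now request a ForwardDiff gradient — the `demo` model and parameter vector are made up for illustration, and the three-argument `logdensity_and_gradient` method is the one introduced in `src/logdensityfunction.jl` below:

```julia
using ADTypes, Distributions, DynamicPPL, ForwardDiff, LogDensityProblems

@model function demo()
    x ~ Normal()
end

f = DynamicPPL.LogDensityFunction(demo())
# A chunk size could be requested via AutoForwardDiff(; chunksize=N);
# leaving it unset lets the backend choose one from the input length.
adtype = ADTypes.AutoForwardDiff()
logp, grad = LogDensityProblems.logdensity_and_gradient(f, [0.5], adtype)
```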
1 change: 0 additions & 1 deletion src/DynamicPPL.jl
@@ -14,7 +14,6 @@ using MacroTools: MacroTools
using ConstructionBase: ConstructionBase
using Accessors: Accessors
using LogDensityProblems: LogDensityProblems
using LogDensityProblemsAD: LogDensityProblemsAD

using LinearAlgebra: LinearAlgebra, Cholesky

1 change: 1 addition & 0 deletions src/contexts.jl
@@ -184,6 +184,7 @@ at which point it will return the sampler of that context.
getsampler(context::SamplingContext) = context.sampler
getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context)
getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context))
getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context")

"""
struct DefaultContext <: AbstractContext end
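The added `IsLeaf` method turns what was previously a bare `MethodError` into an explicit failure when the context stack contains no sampler. A rough sketch of the traversal, assuming the exported `SamplingContext(rng, sampler, context)`, `SampleFromPrior`, and `DefaultContext`:

```julia
using DynamicPPL, Random

spl = SampleFromPrior()
ctx = SamplingContext(Random.default_rng(), spl, DefaultContext())

DynamicPPL.getsampler(ctx)               # SamplingContext stores the sampler directly
DynamicPPL.getsampler(DefaultContext())  # leaf context, no sampler -> throws the new error
```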
56 changes: 20 additions & 36 deletions src/logdensityfunction.jl
@@ -1,3 +1,5 @@
import DifferentiationInterface as DI

"""
LogDensityFunction
@@ -81,37 +83,13 @@ end
Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
"""
getmodel(f::LogDensityProblemsAD.ADGradientWrapper) =
getmodel(LogDensityProblemsAD.parent(f))
getmodel(f::DynamicPPL.LogDensityFunction) = f.model

"""
setmodel(f, model[, adtype])
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
!!! warning
Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
`DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
might require recompilation of the gradient tape, depending on the AD backend.
"""
function setmodel(
f::LogDensityProblemsAD.ADGradientWrapper,
model::DynamicPPL.Model,
adtype::ADTypes.AbstractADType,
)
# TODO: Should we handle `SciMLBase.NoAD`?
# For an `ADGradientWrapper` we do the following:
# 1. Update the `Model` in the underlying `LogDensityFunction`.
# 2. Re-construct the `ADGradientWrapper` using `ADgradient` using the provided `adtype`
# to ensure that the recompilation of gradient tapes, etc. also occur. For example,
# ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just
# replacing the corresponding field with the new model won't be sufficient to obtain
# the correct gradients.
return LogDensityProblemsAD.ADgradient(
adtype, setmodel(LogDensityProblemsAD.parent(f), model)
)
end
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
return Accessors.@set f.model = model
end
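For reference, `getmodel` and `setmodel` on a plain `LogDensityFunction` now act as a simple field accessor and updater. A small sketch using the demo models from `DynamicPPL.TestUtils` (the particular models chosen are arbitrary):

```julia
using DynamicPPL

model_a = DynamicPPL.TestUtils.DEMO_MODELS[1]
model_b = DynamicPPL.TestUtils.DEMO_MODELS[2]

ℓ = DynamicPPL.LogDensityFunction(model_a)
DynamicPPL.getmodel(ℓ) == model_a            # true
ℓ_new = DynamicPPL.setmodel(ℓ, model_b)      # returns an updated copy via Accessors.@set
DynamicPPL.getmodel(ℓ_new) == model_b        # true
```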
@@ -140,18 +118,24 @@ end
# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))

# This is important for performance -- one needs to provide `ADGradient` with a vector of
# parameters, or DifferentiationInterface will not have sufficient information to e.g.
# compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate
# a tape when using ReverseDiff.jl.
function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction)
x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params
return LogDensityProblemsAD.ADgradient(ad, ℓ; x)
end
_flipped_logdensity(θ, f) = LogDensityProblems.logdensity(f, θ)

function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction)
return _make_ad_gradient(ad, f)
# By default, the AD backend to use is inferred from the context, which would
# typically be a SamplingContext which contains a sampler.
function LogDensityProblems.logdensity_and_gradient(
    f::LogDensityFunction, θ::AbstractVector
)
    adtype = getadtype(getsampler(getcontext(f)))
    return LogDensityProblems.logdensity_and_gradient(f, θ, adtype)
end
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction)
return _make_ad_gradient(ad, f)

# Extra method allowing one to manually specify the AD backend to use, thus
# overriding the default AD backend inferred from the sampler.
function LogDensityProblems.logdensity_and_gradient(
    f::LogDensityFunction, θ::AbstractVector, adtype::ADTypes.AbstractADType
)
    # Ensure we concretise the elements of the params.
    θ = map(identity, θ)
    prep = DI.prepare_gradient(_flipped_logdensity, adtype, θ, DI.Constant(f))
    return DI.value_and_gradient(_flipped_logdensity, prep, adtype, θ, DI.Constant(f))
end
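Taken together, these methods replace the old `ADgradient` wrappers: callers either rely on the AD type stored in the sampler's algorithm (two-argument form) or pass a backend explicitly (three-argument form). A minimal sketch of the explicit form with a made-up one-parameter conditioned model:

```julia
using ADTypes, Distributions, DynamicPPL, LogDensityProblems, ReverseDiff

@model function gauss()
    μ ~ Normal()
    x ~ Normal(μ, 1.0)
end

model = gauss() | (; x=0.3)
vi = VarInfo(model)
f = DynamicPPL.LogDensityFunction(model, vi)
θ = convert(Vector{Float64}, vi[:])

# DifferentiationInterface prepares the gradient for the chosen backend
# (here: recording a ReverseDiff tape) and then evaluates value + gradient.
logp, grad = LogDensityProblems.logdensity_and_gradient(f, θ, ADTypes.AutoReverseDiff())
```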
3 changes: 3 additions & 0 deletions src/sampler.jl
@@ -54,6 +54,9 @@ Sampler(alg) = Sampler(alg, Selector())
Sampler(alg, model::Model) = Sampler(alg, model, Selector())
Sampler(alg, model::Model, s::Selector) = Sampler(alg, s)

# Extract the AD type from the underlying algorithm
getadtype(s::Sampler) = getadtype(s.alg)

# AbstractMCMC interface for SampleFromUniform and SampleFromPrior
function AbstractMCMC.step(
rng::Random.AbstractRNG,
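The new `getadtype(s::Sampler) = getadtype(s.alg)` hook is what the two-argument `logdensity_and_gradient` above relies on: the wrapped algorithm is expected to report its AD backend. A sketch with a hypothetical algorithm type (`MyADAlg` is made up for illustration):

```julia
using ADTypes, DynamicPPL

# Hypothetical algorithm that carries its own AD configuration.
struct MyADAlg
    adtype::ADTypes.AbstractADType
end
DynamicPPL.getadtype(alg::MyADAlg) = alg.adtype

spl = Sampler(MyADAlg(ADTypes.AutoForwardDiff()))
DynamicPPL.getadtype(spl)  # forwards to `spl.alg` and returns AutoForwardDiff()
```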
2 changes: 0 additions & 2 deletions test/Project.toml
@@ -16,7 +16,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
@@ -46,7 +45,6 @@ EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10.12"
JET = "0.9"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
MCMCChains = "6.0.4"
MacroTools = "0.5.6"
Mooncake = "0.4.59"
17 changes: 7 additions & 10 deletions test/ad.jl
@@ -7,15 +7,12 @@

@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
f = DynamicPPL.LogDensityFunction(m, varinfo)

# use ForwardDiff result as reference
ad_forwarddiff_f = LogDensityProblemsAD.ADgradient(
ADTypes.AutoForwardDiff(; chunksize=0), f
)
# convert to `Vector{Float64}` to avoid `ReverseDiff` initializing the gradients to Integer 0
# reference: https://github.com/TuringLang/DynamicPPL.jl/pull/571#issuecomment-1924304489
θ = convert(Vector{Float64}, varinfo[:])
logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ)
# Calculate reference logp + gradient of logp using ForwardDiff
default_adtype = ADTypes.AutoForwardDiff(; chunksize=0)
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(f, θ, default_adtype)

@testset "$adtype" for adtype in [
ADTypes.AutoReverseDiff(; compile=false),
@@ -27,9 +24,9 @@
if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo
@test_broken 1 == 0
else
ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
_, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
logp, grad = LogDensityProblems.logdensity_and_gradient(f, θ, adtype)
@test grad ≈ ref_grad
@test logp ≈ ref_logp
end
end
end
@@ -50,7 +47,7 @@
x = Vector{T}(undef, TT)
x[1] = α
for t in 2:TT
x[t] = x[t - 1] + η[t - 1] * τ
x[t] = x[t-1] + η[t-1] * τ
end
# measurement model
y ~ MvNormal(x, σ^2 * I)
@@ -71,6 +68,6 @@
spl = Sampler(MyEmptyAlg())
vi = VarInfo(model)
ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl))
@test LogDensityProblemsAD.ADgradient(AutoReverseDiff(; compile=true), ldf) isa Any
@test LogDensityProblems.logdensity_and_gradient(ldf, vi[:], AutoReverseDiff(; compile=true)) isa Any
end
end
13 changes: 1 addition & 12 deletions test/logdensityfunction.jl
@@ -1,22 +1,11 @@
using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, ReverseDiff
using Test, DynamicPPL, ADTypes, LogDensityProblems, ReverseDiff

@testset "`getmodel` and `setmodel`" begin
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
model = DynamicPPL.TestUtils.DEMO_MODELS[1]
ℓ = DynamicPPL.LogDensityFunction(model)
@test DynamicPPL.getmodel(ℓ) == model
@test DynamicPPL.setmodel(ℓ, model).model == model

# ReverseDiff related
∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(false))
@test DynamicPPL.getmodel(∇ℓ) == model
@test DynamicPPL.getmodel(DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff())) ==
model
∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(true))
new_∇ℓ = DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff())
@test DynamicPPL.getmodel(new_∇ℓ) == model
# HACK(sunxd): rely on internal implementation detail, i.e., naming of `compiledtape`
@test new_∇ℓ.compiledtape != ∇ℓ.compiledtape
end
end

90 changes: 45 additions & 45 deletions test/runtests.jl
@@ -9,7 +9,7 @@ using Distributions
using DistributionsAD
using Documenter
using ForwardDiff
using LogDensityProblems, LogDensityProblemsAD
using LogDensityProblems
using MacroTools
using MCMCChains
using Mooncake: Mooncake
@@ -45,57 +45,57 @@ include("test_util.jl")
# groups are chosen to make both groups take roughly the same amount of
# time, but beyond that there is no particular reason for the split.
if GROUP == "All" || GROUP == "Group1"
include("utils.jl")
include("compiler.jl")
include("varnamedvector.jl")
include("varinfo.jl")
include("simple_varinfo.jl")
include("model.jl")
include("sampler.jl")
include("independence.jl")
include("distribution_wrappers.jl")
include("logdensityfunction.jl")
include("linking.jl")
include("serialization.jl")
include("pointwise_logdensities.jl")
include("lkj.jl")
include("deprecated.jl")
# include("utils.jl")
# include("compiler.jl")
# include("varnamedvector.jl")
# include("varinfo.jl")
# include("simple_varinfo.jl")
# include("model.jl")
# include("sampler.jl")
# include("independence.jl")
# include("distribution_wrappers.jl")
# include("logdensityfunction.jl")
# include("linking.jl")
# include("serialization.jl")
# include("pointwise_logdensities.jl")
# include("lkj.jl")
# include("deprecated.jl")
end

if GROUP == "All" || GROUP == "Group2"
include("contexts.jl")
include("context_implementations.jl")
include("threadsafe.jl")
include("debug_utils.jl")
@testset "compat" begin
include(joinpath("compat", "ad.jl"))
end
@testset "extensions" begin
include("ext/DynamicPPLMCMCChainsExt.jl")
include("ext/DynamicPPLJETExt.jl")
end
# include("contexts.jl")
# include("context_implementations.jl")
# include("threadsafe.jl")
# include("debug_utils.jl")
# @testset "compat" begin
# include(joinpath("compat", "ad.jl"))
# end
# @testset "extensions" begin
# include("ext/DynamicPPLMCMCChainsExt.jl")
# include("ext/DynamicPPLJETExt.jl")
# end
@testset "ad" begin
include("ext/DynamicPPLForwardDiffExt.jl")
include("ext/DynamicPPLMooncakeExt.jl")
include("ad.jl")
end
@testset "prob and logprob macro" begin
@test_throws ErrorException prob"..."
@test_throws ErrorException logprob"..."
end
@testset "doctests" begin
DocMeta.setdocmeta!(
DynamicPPL,
:DocTestSetup,
:(using DynamicPPL, Distributions);
recursive=true,
)
doctestfilters = [
# Ignore the source of a warning in the doctest output, since this is dependent on host.
# This is a line that starts with "└ @ " and ends with the line number.
r"└ @ .+:[0-9]+",
]
doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters)
end
# @testset "prob and logprob macro" begin
# @test_throws ErrorException prob"..."
# @test_throws ErrorException logprob"..."
# end
# @testset "doctests" begin
# DocMeta.setdocmeta!(
# DynamicPPL,
# :DocTestSetup,
# :(using DynamicPPL, Distributions);
# recursive=true,
# )
# doctestfilters = [
# # Ignore the source of a warning in the doctest output, since this is dependent on host.
# # This is a line that starts with "└ @ " and ends with the line number.
# r"└ @ .+:[0-9]+",
# ]
# doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters)
# end
end
end
