
Remove LogDensityProblemsAD #806

Draft
wants to merge 12 commits into base: release-0.35
9 changes: 9 additions & 0 deletions HISTORY.md
@@ -49,6 +49,15 @@ This release removes the feature of `VarInfo` where it kept track of which varia

This change also affects sampling in Turing.jl.

**Other changes**

LogDensityProblemsAD is now removed as a dependency.
Instead of constructing a `LogDensityProblemsAD.ADgradient` object, we now use `DifferentiationInterface` directly to calculate the gradient of the log density with respect to the model parameters.

In practice, this means that if you want to calculate the gradient for a model, you can do:

TODO(penelopeysm): Finish this
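Below is an editorial sketch of what that might look like, based on the API introduced in this PR; the `demo` model and the `AutoForwardDiff` backend are illustrative assumptions, not part of the diff:

```julia
using DynamicPPL, Distributions, LogDensityProblems
using ADTypes: AutoForwardDiff
import ForwardDiff

@model function demo()
    x ~ Normal()
end

# Wrap the model, then pair the wrapper with an AD backend of your choice.
ldf = DynamicPPL.LogDensityFunction(demo())
ldf_grad = DynamicPPL.LogDensityFunctionWithGrad(ldf, AutoForwardDiff())

# Log density and its gradient at the parameter vector [0.5]:
LogDensityProblems.logdensity_and_gradient(ldf_grad, [0.5])
```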

## 0.34.2

- Fixed bugs in ValuesAsInModelContext as well as DebugContext where underlying PrefixContexts were not being applied.
13 changes: 3 additions & 10 deletions Project.toml
@@ -12,15 +12,13 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
# TODO(penelopeysm,mhauru) KernelAbstractions is only a dependency so that we can pin its version, see
# https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -31,7 +29,6 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
@@ -40,7 +37,6 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[extensions]
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLMooncakeExt = ["Mooncake"]
@@ -56,17 +52,14 @@ Bijectors = "0.13.18, 0.14, 0.15"
ChainRulesCore = "1"
Compat = "4"
ConstructionBase = "1.5.4"
DifferentiationInterface = "0.6.39"
Distributions = "0.25"
DocStringExtensions = "0.9"
# TODO(penelopeysm,mhauru) See https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
# for why KernelAbstractions is pinned like this.
KernelAbstractions = "< 0.9.32"
EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10"
JET = "0.9"
KernelAbstractions = "< 0.9.32"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
MCMCChains = "6"
MacroTools = "0.5.6"
Mooncake = "0.4.59"
3 changes: 2 additions & 1 deletion docs/src/api.md
@@ -54,10 +54,11 @@ logjoint

### LogDensityProblems.jl interface

The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by simply wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`:
The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction` or `DynamicPPL.LogDensityFunctionWithGrad`.

```@docs
DynamicPPL.LogDensityFunction
DynamicPPL.LogDensityFunctionWithGrad
```
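
As a rough usage sketch (the `demo` model below is an assumed example, not part of the diff): wrapping a model gives you the zeroth-order interface, while `LogDensityFunctionWithGrad` additionally provides the gradient.

```julia
using DynamicPPL, Distributions, LogDensityProblems

@model function demo()
    x ~ Normal()
end

ldf = DynamicPPL.LogDensityFunction(demo())
LogDensityProblems.logdensity(ldf, [0.5])  # log density only (0th order)
LogDensityProblems.dimension(ldf)          # number of parameters, here 1
```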

## Condition and decondition
54 changes: 0 additions & 54 deletions ext/DynamicPPLForwardDiffExt.jl

This file was deleted.

1 change: 0 additions & 1 deletion src/DynamicPPL.jl
@@ -14,7 +14,6 @@ using MacroTools: MacroTools
using ConstructionBase: ConstructionBase
using Accessors: Accessors
using LogDensityProblems: LogDensityProblems
using LogDensityProblemsAD: LogDensityProblemsAD

using LinearAlgebra: LinearAlgebra, Cholesky

1 change: 1 addition & 0 deletions src/contexts.jl
@@ -184,6 +184,7 @@
getsampler(context::SamplingContext) = context.sampler
getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context)
getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context))
getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context")

"""
struct DefaultContext <: AbstractContext end
162 changes: 109 additions & 53 deletions src/logdensityfunction.jl
@@ -1,7 +1,13 @@
import DifferentiationInterface as DI

"""
LogDensityFunction

A callable representing a log density function of a `model`.
`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface,
but only to 0th-order, i.e. it is only possible to calculate the log density,
and not its gradient. If you need to calculate the gradient as well, you have
to construct a [`DynamicPPL.LogDensityFunctionWithGrad`](@ref) object.

# Fields
$(FIELDS)
Expand Down Expand Up @@ -53,16 +59,6 @@
context::C
end

# TODO: Deprecate.
function LogDensityFunction(
varinfo::AbstractVarInfo,
model::Model,
sampler::AbstractSampler,
context::AbstractContext,
)
return LogDensityFunction(varinfo, model, SamplingContext(sampler, context))
end

function LogDensityFunction(
model::Model,
varinfo::AbstractVarInfo=VarInfo(model),
@@ -81,57 +77,28 @@

Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
"""
getmodel(f::LogDensityProblemsAD.ADGradientWrapper) =
getmodel(LogDensityProblemsAD.parent(f))
getmodel(f::DynamicPPL.LogDensityFunction) = f.model

"""
setmodel(f, model[, adtype])

Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.

!!! warning
Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
`DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
might require recompilation of the gradient tape, depending on the AD backend.
"""
function setmodel(
f::LogDensityProblemsAD.ADGradientWrapper,
model::DynamicPPL.Model,
adtype::ADTypes.AbstractADType,
)
# TODO: Should we handle `SciMLBase.NoAD`?
# For an `ADGradientWrapper` we do the following:
# 1. Update the `Model` in the underlying `LogDensityFunction`.
# 2. Re-construct the `ADGradientWrapper` using `ADgradient` using the provided `adtype`
# to ensure that the recompilation of gradient tapes, etc. also occur. For example,
# ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just
# replacing the corresponding field with the new model won't be sufficient to obtain
# the correct gradients.
return LogDensityProblemsAD.ADgradient(
adtype, setmodel(LogDensityProblemsAD.parent(f), model)
)
end
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
return Accessors.@set f.model = model
end

# HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time
# we need to define these annoying methods to ensure that we stay compatible with everything.
getsampler(f::LogDensityFunction) = getsampler(getcontext(f))
hassampler(f::LogDensityFunction) = hassampler(getcontext(f))

"""
getparams(f::LogDensityFunction)

Return the parameters of the wrapped varinfo as a vector.
"""
getparams(f::LogDensityFunction) = f.varinfo[:]

# LogDensityProblems interface
function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector)
# LogDensityProblems interface: logp (0th order)
function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)

context = getcontext(f)
vi_new = unflatten(f.varinfo, θ)
vi_new = unflatten(f.varinfo, x)

return getlogp(last(evaluate!!(f.model, vi_new, context)))
end
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
@@ -140,18 +107,107 @@
# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))

# This is important for performance -- one needs to provide `ADGradient` with a vector of
# parameters, or DifferentiationInterface will not have sufficient information to e.g.
# compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate
# a tape when using ReverseDiff.jl.
function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction)
x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params
return LogDensityProblemsAD.ADgradient(ad, ℓ; x)
# LogDensityProblems interface: gradient (1st order)
"""
use_closure(adtype::ADTypes.AbstractADType)

In LogDensityProblems, we want to calculate the derivative of logdensity(f, x)
with respect to x, where f is the model (in our case LogDensityFunction) and is
a constant. However, DifferentiationInterface generally expects a
single-argument function g(x) to differentiate.

There are two ways of dealing with this:

1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f)

2. Use a constant context. This lets us pass a two-argument function to
DifferentiationInterface, as long as we also give it the 'inactive argument'
(i.e. the model) wrapped in `DI.Constant`.

The relative performance of the two approaches, however, depends on the AD
backend used. Some benchmarks are provided here:
https://github.com/TuringLang/DynamicPPL.jl/pull/806#issuecomment-2658061480

This function is used to determine whether a given AD backend should use a
closure or a constant. If `use_closure(adtype)` returns `true`, then the
closure approach will be used. By default, this function returns `false`, i.e.
the constant approach will be used.
"""
use_closure(::ADTypes.AbstractADType) = false
use_closure(::ADTypes.AutoForwardDiff) = false
use_closure(::ADTypes.AutoMooncake) = false
use_closure(::ADTypes.AutoReverseDiff) = true
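# NOTE (editorial sketch, not part of this diff): concretely, the two approaches
# described above correspond to DifferentiationInterface calls of roughly this
# shape, assuming `ldf::LogDensityFunction`, a parameter vector `x`, and an
# `adtype` such as ADTypes.AutoForwardDiff():
#
#     # 1. Closure: bake the constant `ldf` into a one-argument function.
#     DI.gradient(Base.Fix1(LogDensityProblems.logdensity, ldf), adtype, x)
#
#     # 2. Constant: keep a two-argument function and mark `ldf` as inactive.
#     DI.gradient(_flipped_logdensity, adtype, x, DI.Constant(ldf))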

"""
_flipped_logdensity(x::AbstractVector, f::LogDensityFunction)

This function is the same as `LogDensityProblems.logdensity(f, x)` but with the
arguments flipped. It is used in the 'constant' approach to DifferentiationInterface
(see `use_closure` for more information).
"""
function _flipped_logdensity(x::AbstractVector, f::LogDensityFunction)
return LogDensityProblems.logdensity(f, x)

end

function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction)
return _make_ad_gradient(ad, f)
"""
LogDensityFunctionWithGrad(ldf::DynamicPPL.LogDensityFunction, adtype::ADTypes.AbstractADType)

A callable representing a log density function of a `model`.
`DynamicPPL.LogDensityFunctionWithGrad` implements the LogDensityProblems.jl
interface to 1st order, meaning that you can calculate both the log density
using

LogDensityProblems.logdensity(f, x)

and its gradient using

LogDensityProblems.logdensity_and_gradient(f, x)

where `f` is a `LogDensityFunctionWithGrad` object and `x` is a vector of parameters.

# Fields
$(FIELDS)
"""
struct LogDensityFunctionWithGrad{V,M,C,TAD<:ADTypes.AbstractADType}
ldf::LogDensityFunction{V,M,C}
adtype::TAD
prep::DI.GradientPrep
with_closure::Bool

function LogDensityFunctionWithGrad(
ldf::LogDensityFunction{V,M,C}, adtype::TAD
) where {V,M,C,TAD}
# Get a set of dummy params to use for prep
x = map(identity, getparams(ldf))
with_closure = use_closure(adtype)
if with_closure
prep = DI.prepare_gradient(
Base.Fix1(LogDensityProblems.logdensity, ldf), adtype, x
)
else
prep = DI.prepare_gradient(_flipped_logdensity, adtype, x, DI.Constant(ldf))

end
# Store the prep with the struct. We also store whether a closure was used because
# we need to know this when calling `DI.value_and_gradient`. In practice we could
# recalculate it, but this runs the risk of introducing inconsistencies.
return new{V,M,C,TAD}(ldf, adtype, prep, with_closure)

end
end
function LogDensityProblems.logdensity(f::LogDensityFunctionWithGrad, x::AbstractVector)
return LogDensityProblems.logdensity(f.ldf, x)

end
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction)
return _make_ad_gradient(ad, f)
function LogDensityProblems.capabilities(::Type{<:LogDensityFunctionWithGrad})
return LogDensityProblems.LogDensityOrder{1}()

end
function LogDensityProblems.logdensity_and_gradient(
f::LogDensityFunctionWithGrad, x::AbstractVector
)
x = map(identity, x) # Concretise type
return if f.with_closure
DI.value_and_gradient(
Base.Fix1(LogDensityProblems.logdensity, f.ldf), f.prep, f.adtype, x
)
else
DI.value_and_gradient(_flipped_logdensity, f.prep, f.adtype, x, DI.Constant(f.ldf))

end
end
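
Putting the new pieces in this file together, usage might look roughly like the sketch below; the `demo_grad` model and the choice of `AutoReverseDiff` are assumptions for illustration, not anything mandated by the diff:

```julia
using DynamicPPL, Distributions, LogDensityProblems
using ADTypes: AutoReverseDiff
import ReverseDiff

@model function demo_grad()
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

ldf = DynamicPPL.LogDensityFunction(demo_grad())

# AutoReverseDiff takes the closure code path (`use_closure` returns `true` for it);
# most other backends go through the `DI.Constant` path instead.
ldf_grad = DynamicPPL.LogDensityFunctionWithGrad(ldf, AutoReverseDiff())

LogDensityProblems.capabilities(typeof(ldf_grad))  # LogDensityOrder{1}()
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf_grad, [0.1, -0.3])
```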
2 changes: 0 additions & 2 deletions test/Project.toml
@@ -16,7 +16,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
@@ -46,7 +45,6 @@ EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10.12"
JET = "0.9"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
MCMCChains = "6.0.4"
MacroTools = "0.5.6"
Mooncake = "0.4.59"