Switch to DifferentiationInterface #29

Closed · wants to merge 9 commits
23 changes: 6 additions & 17 deletions Project.toml
@@ -1,42 +1,31 @@
name = "LogDensityProblemsAD"
uuid = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
authors = ["Tamás K. Papp <[email protected]>"]
version = "1.9.0"
version = "2.0.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"

[weakdeps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
LogDensityProblemsADADTypesExt = "ADTypes"
LogDensityProblemsADEnzymeExt = "Enzyme"
LogDensityProblemsADFiniteDifferencesExt = "FiniteDifferences"
LogDensityProblemsADForwardDiffBenchmarkToolsExt = ["BenchmarkTools", "ForwardDiff"]
LogDensityProblemsADForwardDiffExt = "ForwardDiff"
LogDensityProblemsADReverseDiffExt = "ReverseDiff"
LogDensityProblemsADTrackerExt = "Tracker"
LogDensityProblemsADZygoteExt = "Zygote"

[compat]
ADTypes = "0.1.7, 0.2, 1"
ADTypes = "1.1"
DifferentiationInterface = "0.3.4"
DocStringExtensions = "0.8, 0.9"
Enzyme = "0.11, 0.12"
FiniteDifferences = "0.12"
LogDensityProblems = "1, 2"
Requires = "0.5, 1"
SimpleUnPack = "1"
Requires = "1.3"
julia = "1.6"

[extras]
28 changes: 0 additions & 28 deletions ext/DiffResults_helpers.jl

This file was deleted.

53 changes: 0 additions & 53 deletions ext/LogDensityProblemsADADTypesExt.jl

This file was deleted.

84 changes: 16 additions & 68 deletions ext/LogDensityProblemsADEnzymeExt.jl
@@ -1,78 +1,26 @@
"""
Gradient AD implementation using Enzyme.
"""
module LogDensityProblemsADEnzymeExt

if isdefined(Base, :get_extension)
using LogDensityProblemsAD: ADGradientWrapper, logdensity
using LogDensityProblemsAD.SimpleUnPack: @unpack

import LogDensityProblemsAD: ADgradient, logdensity_and_gradient
import Enzyme
using ADTypes: AutoEnzyme
using Enzyme: Reverse
using LogDensityProblemsAD: LogDensityProblemsAD, ADgradient, logdensity
else
using ..LogDensityProblemsAD: ADGradientWrapper, logdensity
using ..LogDensityProblemsAD.SimpleUnPack: @unpack

import ..LogDensityProblemsAD: ADgradient, logdensity_and_gradient
import ..Enzyme
end

struct EnzymeGradientLogDensity{L,M<:Union{Enzyme.ForwardMode,Enzyme.ReverseMode},S} <: ADGradientWrapper
ℓ::L
mode::M
shadow::S # only used in forward mode
using ..ADTypes: AutoEnzyme
using ..Enzyme: Reverse
using ..LogDensityProblemsAD: LogDensityProblemsAD, ADgradient, logdensity
end

"""
ADgradient(:Enzyme, ℓ; kwargs...)
ADgradient(Val(:Enzyme), ℓ; kwargs...)

Gradient using algorithmic/automatic differentiation via Enzyme.

# Keyword arguments

- `mode::Enzyme.Mode`: Differentiation mode (default: `Enzyme.Reverse`).
Currently only `Enzyme.Reverse` and `Enzyme.Forward` are supported.

- `shadow`: Collection of one-hot vectors for each entry of the inputs `x` to the log density
`ℓ`, or `nothing` (default: `nothing`). This keyword argument is only used in forward
mode. By default, it will be recomputed in every call of `logdensity_and_gradient(ℓ, x)`.
For performance reasons it is recommended to compute it only once when calling `ADgradient`.
The one-hot vectors can be constructed, e.g., with `Enzyme.onehot(x)`.
"""
function ADgradient(::Val{:Enzyme}, ℓ; mode::Enzyme.Mode = Enzyme.Reverse, shadow = nothing)
mode isa Union{Enzyme.ForwardMode,Enzyme.ReverseMode} ||
throw(ArgumentError("currently automatic differentiation via Enzyme only supports " *
"`Enzyme.Forward` and `Enzyme.Reverse` modes"))
if mode isa Enzyme.ReverseMode && shadow !== nothing
@info "keyword argument `shadow` is ignored in reverse mode"
shadow = nothing
function LogDensityProblemsAD.ADgradient(
::Val{:Enzyme},
ℓ;
mode=Reverse,
shadow=nothing,
)
if !isnothing(shadow)
@warn "keyword argument `shadow` is now ignored"
Collaborator

So the PR will lead to worse performance in downstream packages that work with pre-allocated/cached shadows?

Contributor Author

The idea is that DifferentiationInterface has its own mechanism for preparing a gradient, which is triggered when you supply the constructor with x.
As you can see here, it does the exact same shadow construction:
https://github.com/gdalle/DifferentiationInterface.jl/blob/16d93ef4111e1c196912a3dd53ffb31e04445324/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L40-L48
Again, the idea is to have a single source of code for all of this boilerplate, so that if there is anything to improve, it can be improved in DI.
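For reference, a minimal sketch of the preparation mechanism being referred to. This follows the current DifferentiationInterface API (the argument order of the prepared calls has varied across DI releases, including the 0.3 series pinned in this PR), and `f` and `x` are hypothetical stand-ins for `Base.Fix1(logdensity, ℓ)` and an example input:

```julia
using DifferentiationInterface
using ADTypes: AutoEnzyme
import Enzyme

# Hypothetical stand-in for Base.Fix1(logdensity, ℓ):
f(x) = -sum(abs2, x) / 2
x = randn(5)
backend = AutoEnzyme(; mode = Enzyme.Forward)

# Preparation plays the role of the old pre-computed `shadow` ...
prep = prepare_gradient(f, backend, x)

# ... and is reused for later gradient evaluations at inputs of the same size.
y, g = value_and_gradient(f, prep, backend, x)
```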

Collaborator

I assumed such an internal caching/preparation mechanism exists, but this will still break use cases, or at least lead to worse performance, in cases where a shadow is currently pre-allocated and passed around to multiple ADgradient calls. So from the perspective of ADgradient, it would ideally still be possible to forward a shadow to the construction in DI.

It seems there are more major issues with the Enzyme backend though: #29 (review)
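For context, the pre-allocation pattern being discussed looks roughly like this under the 1.x API that the removed docstring above documents. The toy problem is an assumption for illustration, not part of the PR:

```julia
import Enzyme
using LogDensityProblems, LogDensityProblemsAD

# Minimal toy problem standing in for a real log density (assumption, not from the PR).
struct ToyProblem end
LogDensityProblems.logdensity(::ToyProblem, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::ToyProblem) = 5
LogDensityProblems.capabilities(::Type{ToyProblem}) = LogDensityProblems.LogDensityOrder{0}()

# 1.x pattern: compute the one-hot shadow once and pass it to several wrappers,
# so it is not rebuilt inside every logdensity_and_gradient call.
x = randn(5)
shadow = Enzyme.onehot(x)
∇ℓ1 = ADgradient(Val(:Enzyme), ToyProblem(); mode = Enzyme.Forward, shadow = shadow)
∇ℓ2 = ADgradient(Val(:Enzyme), ToyProblem(); mode = Enzyme.Forward, shadow = shadow)
```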

end
Collaborator

If we put this in a breaking release as suggested above, IMO we should just remove the keyword argument. Alternatively, in a non-breaking release I'd prefer to make it a deprecation warning (possibly one that is forced to be displayed).
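A sketch of the deprecation-warning alternative, assuming the keyword is kept; this is not part of the diff and simply reuses the names from the extension above. `force = true` makes `Base.depwarn` display even when depwarns are disabled:

```julia
# Hypothetical non-breaking variant: keep the keyword but warn instead of silently ignoring it.
function LogDensityProblemsAD.ADgradient(::Val{:Enzyme}, ℓ; mode = Reverse, shadow = nothing)
    if shadow !== nothing
        Base.depwarn(
            "the `shadow` keyword argument of `ADgradient` is ignored; " *
            "construct the wrapper with an example input `x` to let " *
            "DifferentiationInterface prepare the gradient instead",
            :ADgradient;
            force = true,
        )
    end
    backend = AutoEnzyme(; mode)
    return ADgradient(backend, ℓ)
end
```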

return EnzymeGradientLogDensity(ℓ, mode, shadow)
end

function Base.show(io::IO, ∇ℓ::EnzymeGradientLogDensity)
print(io, "Enzyme AD wrapper for ", ∇ℓ.ℓ, " with ",
∇ℓ.mode isa Enzyme.ForwardMode ? "forward" : "reverse", " mode")
end

function logdensity_and_gradient(∇ℓ::EnzymeGradientLogDensity{<:Any,<:Enzyme.ForwardMode},
x::AbstractVector)
@unpack ℓ, mode, shadow = ∇ℓ
_shadow = shadow === nothing ? Enzyme.onehot(x) : shadow
y, ∂ℓ_∂x = Enzyme.autodiff(mode, logdensity, Enzyme.BatchDuplicated,
Enzyme.Const(ℓ),
Enzyme.BatchDuplicated(x, _shadow))
return y, collect(∂ℓ_∂x)
end

function logdensity_and_gradient(∇ℓ::EnzymeGradientLogDensity{<:Any,<:Enzyme.ReverseMode},
x::AbstractVector)
@unpack ℓ = ∇ℓ
∂ℓ_∂x = zero(x)
_, y = Enzyme.autodiff(Enzyme.ReverseWithPrimal, logdensity, Enzyme.Active,
Enzyme.Const(ℓ), Enzyme.Duplicated(x, ∂ℓ_∂x))
y, ∂ℓ_∂x
backend = AutoEnzyme(; mode)
return ADgradient(backend, ℓ)
end

end # module
52 changes: 10 additions & 42 deletions ext/LogDensityProblemsADFiniteDifferencesExt.jl
@@ -1,51 +1,19 @@
"""
Gradient implementation using FiniteDifferences.
"""
module LogDensityProblemsADFiniteDifferencesExt

if isdefined(Base, :get_extension)
using LogDensityProblemsAD: ADGradientWrapper, logdensity
using LogDensityProblemsAD.SimpleUnPack: @unpack

import LogDensityProblemsAD: ADgradient, logdensity_and_gradient
import FiniteDifferences
using ADTypes: AutoFiniteDifferences
import FiniteDifferences: central_fdm
using LogDensityProblemsAD: LogDensityProblemsAD, ADgradient
else
using ..LogDensityProblemsAD: ADGradientWrapper, logdensity
using ..LogDensityProblemsAD.SimpleUnPack: @unpack

import ..LogDensityProblemsAD: ADgradient, logdensity_and_gradient
import ..FiniteDifferences
end

struct FiniteDifferencesGradientLogDensity{L,M} <: ADGradientWrapper
ℓ::L
"finite difference method"
fdm::M
end

"""
ADgradient(:FiniteDifferences, ℓ; fdm = central_fdm(5, 1))
ADgradient(Val(:FiniteDifferences), ℓ; fdm = central_fdm(5, 1))

Gradient using FiniteDifferences, mainly intended for checking results from other algorithms.

# Keyword arguments

- `fdm`: the finite difference method. Defaults to `central_fdm(5, 1)`.
"""
function ADgradient(::Val{:FiniteDifferences}, ℓ; fdm = FiniteDifferences.central_fdm(5, 1))
FiniteDifferencesGradientLogDensity(ℓ, fdm)
end

function Base.show(io::IO, ∇ℓ::FiniteDifferencesGradientLogDensity)
print(io, "FiniteDifferences AD wrapper for ", ∇ℓ.ℓ, " with ", ∇ℓ.fdm)
using ..ADTypes: AutoFiniteDifferences
import ..FiniteDifferences: central_fdm
using ..LogDensityProblemsAD: LogDensityProblemsAD, ADgradient
end

function logdensity_and_gradient(∇ℓ::FiniteDifferencesGradientLogDensity, x::AbstractVector)
@unpack ℓ, fdm = ∇ℓ
y = logdensity(ℓ, x)
∇y = only(FiniteDifferences.grad(fdm, Base.Fix1(logdensity, ℓ), x))
y, ∇y
function LogDensityProblemsAD.ADgradient(::Val{:FiniteDifferences}, ℓ)
fdm = central_fdm(5, 1)
Collaborator

Please re-add the fdm keyword argument.

Contributor Author

Oh right, I had missed that one.
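A sketch of how the keyword could be restored on top of the DI-based constructor, mirroring the removed 1.x signature (not part of this diff):

```julia
function LogDensityProblemsAD.ADgradient(::Val{:FiniteDifferences}, ℓ; fdm = central_fdm(5, 1))
    backend = AutoFiniteDifferences(; fdm)
    return ADgradient(backend, ℓ)
end
```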

backend = AutoFiniteDifferences(; fdm)
ADgradient(backend, ℓ)
end

end # module
65 changes: 0 additions & 65 deletions ext/LogDensityProblemsADForwardDiffBenchmarkToolsExt.jl

This file was deleted.
