Switch to DifferentiationInterface #29
Changes from all commits: aa4aa38, a3c49f1, 96b25ea, 910ee4e, 02caeb4, 707803b, 33df6f7, 2557e19, 48e231d
Project.toml
@@ -1,42 +1,31 @@
name = "LogDensityProblemsAD" | ||
uuid = "996a588d-648d-4e1f-a8f0-a84b347e47b1" | ||
authors = ["Tamás K. Papp <[email protected]>"] | ||
version = "1.9.0" | ||
version = "2.0.0" | ||
|
||
[deps] | ||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" | ||
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" | ||
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" | ||
Requires = "ae029012-a4dd-5104-9daa-d747884805df" | ||
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" | ||
|
||
[weakdeps] | ||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" | ||
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" | ||
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" | ||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" | ||
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
||
[extensions] | ||
LogDensityProblemsADADTypesExt = "ADTypes" | ||
LogDensityProblemsADEnzymeExt = "Enzyme" | ||
LogDensityProblemsADFiniteDifferencesExt = "FiniteDifferences" | ||
LogDensityProblemsADForwardDiffBenchmarkToolsExt = ["BenchmarkTools", "ForwardDiff"] | ||
LogDensityProblemsADForwardDiffExt = "ForwardDiff" | ||
LogDensityProblemsADReverseDiffExt = "ReverseDiff" | ||
LogDensityProblemsADTrackerExt = "Tracker" | ||
LogDensityProblemsADZygoteExt = "Zygote" | ||
|
||
[compat] | ||
ADTypes = "0.1.7, 0.2, 1" | ||
ADTypes = "1.1" | ||
DifferentiationInterface = "0.3.4" | ||
DocStringExtensions = "0.8, 0.9" | ||
Enzyme = "0.11, 0.12" | ||
FiniteDifferences = "0.12" | ||
LogDensityProblems = "1, 2" | ||
Requires = "0.5, 1" | ||
SimpleUnPack = "1" | ||
Requires = "1.3" | ||
julia = "1.6" | ||
|
||
[extras] | ||
|
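With ADTypes and DifferentiationInterface listed as direct dependencies, the natural entry point is to hand an ADTypes backend object straight to `ADgradient`. A minimal usage sketch under that assumption; the `ToyProblem` type is hypothetical and only serves to implement the `LogDensityProblems` interface:

```julia
using LogDensityProblems, LogDensityProblemsAD, ADTypes
import ForwardDiff

# Hypothetical toy log density implementing the LogDensityProblems interface.
struct ToyProblem end
LogDensityProblems.logdensity(::ToyProblem, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::ToyProblem) = 3
LogDensityProblems.capabilities(::Type{ToyProblem}) =
    LogDensityProblems.LogDensityOrder{0}()

# The backend is an ADTypes object; with this PR the gradient computation is
# expected to be delegated to DifferentiationInterface under the hood.
∇ℓ = ADgradient(AutoForwardDiff(), ToyProblem())
LogDensityProblems.logdensity_and_gradient(∇ℓ, zeros(3))
```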
This file was deleted.
This file was deleted.
ext/LogDensityProblemsADEnzymeExt.jl
@@ -1,78 +1,26 @@
""" | ||
Gradient AD implementation using Enzyme. | ||
""" | ||
module LogDensityProblemsADEnzymeExt | ||
|
||
if isdefined(Base, :get_extension) | ||
using LogDensityProblemsAD: ADGradientWrapper, logdensity | ||
using LogDensityProblemsAD.SimpleUnPack: @unpack | ||
|
||
import LogDensityProblemsAD: ADgradient, logdensity_and_gradient | ||
import Enzyme | ||
using ADTypes: AutoEnzyme | ||
using Enzyme: Reverse | ||
using LogDensityProblemsAD: LogDensityProblemsAD, ADgradient, logdensity | ||
else | ||
using ..LogDensityProblemsAD: ADGradientWrapper, logdensity | ||
using ..LogDensityProblemsAD.SimpleUnPack: @unpack | ||
|
||
import ..LogDensityProblemsAD: ADgradient, logdensity_and_gradient | ||
import ..Enzyme | ||
end | ||
|
||
struct EnzymeGradientLogDensity{L,M<:Union{Enzyme.ForwardMode,Enzyme.ReverseMode},S} <: ADGradientWrapper | ||
ℓ::L | ||
mode::M | ||
shadow::S # only used in forward mode | ||
using ..ADTypes: AutoEnzyme | ||
using ..Enzyme: Reverse | ||
using ..LogDensityProblemsAD: LogDensityProblemsAD, ADgradient, logdensity | ||
end | ||
|
||
""" | ||
ADgradient(:Enzyme, ℓ; kwargs...) | ||
ADgradient(Val(:Enzyme), ℓ; kwargs...) | ||
|
||
Gradient using algorithmic/automatic differentiation via Enzyme. | ||
|
||
# Keyword arguments | ||
|
||
- `mode::Enzyme.Mode`: Differentiation mode (default: `Enzyme.Reverse`). | ||
Currently only `Enzyme.Reverse` and `Enzyme.Forward` are supported. | ||
|
||
- `shadow`: Collection of one-hot vectors for each entry of the inputs `x` to the log density | ||
`ℓ`, or `nothing` (default: `nothing`). This keyword argument is only used in forward | ||
mode. By default, it will be recomputed in every call of `logdensity_and_gradient(ℓ, x)`. | ||
For performance reasons it is recommended to compute it only once when calling `ADgradient`. | ||
The one-hot vectors can be constructed, e.g., with `Enzyme.onehot(x)`. | ||
""" | ||
function ADgradient(::Val{:Enzyme}, ℓ; mode::Enzyme.Mode = Enzyme.Reverse, shadow = nothing) | ||
mode isa Union{Enzyme.ForwardMode,Enzyme.ReverseMode} || | ||
throw(ArgumentError("currently automatic differentiation via Enzyme only supports " * | ||
"`Enzyme.Forward` and `Enzyme.Reverse` modes")) | ||
if mode isa Enzyme.ReverseMode && shadow !== nothing | ||
@info "keyword argument `shadow` is ignored in reverse mode" | ||
shadow = nothing | ||
function LogDensityProblemsAD.ADgradient( | ||
::Val{:Enzyme}, | ||
ℓ; | ||
mode=Reverse, | ||
shadow=nothing, | ||
) | ||
if !isnothing(shadow) | ||
Review comment: So the PR will lead to worse performance in downstream packages that work with pre-allocated/cached shadows?
Review comment: The idea is that DifferentiationInterface has its own mechanism for preparing a gradient, which is triggered when you supply the constructor with […]
Review comment: I assumed there exists such an internal caching/preparation - but this will still break use cases, or at least lead to worse performance, in cases where currently a […]. It seems there are more major issues with the Enzyme backend though: #29 (review)
@warn "keyword argument `shadow` is now ignored"
Review comment: If we put this in a breaking release as suggested above, IMO we should just remove the keyword argument. Alternatively, in a non-breaking release I'd prefer to make it a deprecation warning (possibly one that is forced to be displayed).
end
return EnzymeGradientLogDensity(ℓ, mode, shadow)
end
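The thread above also mentions a non-breaking alternative to the `@warn`: a deprecation warning, possibly forced to display. A sketch of what that branch could look like inside the rewritten constructor (placement and wording assumed, not part of this diff):

```julia
if shadow !== nothing
    # Sketch: deprecate rather than silently ignore; `force = true` matches the
    # suggestion that the message should always be displayed.
    Base.depwarn(
        "the keyword argument `shadow` is ignored and will be removed; " *
        "gradient preparation is now handled by DifferentiationInterface",
        :ADgradient;
        force = true,
    )
end
```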

function Base.show(io::IO, ∇ℓ::EnzymeGradientLogDensity)
print(io, "Enzyme AD wrapper for ", ∇ℓ.ℓ, " with ",
∇ℓ.mode isa Enzyme.ForwardMode ? "forward" : "reverse", " mode")
end

function logdensity_and_gradient(∇ℓ::EnzymeGradientLogDensity{<:Any,<:Enzyme.ForwardMode},
x::AbstractVector)
@unpack ℓ, mode, shadow = ∇ℓ
_shadow = shadow === nothing ? Enzyme.onehot(x) : shadow
y, ∂ℓ_∂x = Enzyme.autodiff(mode, logdensity, Enzyme.BatchDuplicated,
Enzyme.Const(ℓ),
Enzyme.BatchDuplicated(x, _shadow))
return y, collect(∂ℓ_∂x)
end

function logdensity_and_gradient(∇ℓ::EnzymeGradientLogDensity{<:Any,<:Enzyme.ReverseMode},
x::AbstractVector)
@unpack ℓ = ∇ℓ
∂ℓ_∂x = zero(x)
_, y = Enzyme.autodiff(Enzyme.ReverseWithPrimal, logdensity, Enzyme.Active,
Enzyme.Const(ℓ), Enzyme.Duplicated(x, ∂ℓ_∂x))
y, ∂ℓ_∂x
backend = AutoEnzyme(; mode)
return ADgradient(backend, ℓ)
end

end # module
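To make the `shadow` discussion concrete, here is a small usage sketch of the constructor as rewritten in this diff. The `ToyProblem` type is hypothetical, and the assumption that DifferentiationInterface takes over the preparation work previously done by a cached shadow reflects the reviewer exchange above rather than code shown here:

```julia
using LogDensityProblems, LogDensityProblemsAD
import Enzyme

# Hypothetical toy log density implementing the LogDensityProblems interface.
struct ToyProblem end
LogDensityProblems.logdensity(::ToyProblem, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::ToyProblem) = 2
LogDensityProblems.capabilities(::Type{ToyProblem}) =
    LogDensityProblems.LogDensityOrder{0}()

ℓ, x = ToyProblem(), [0.5, -1.0]

# Before this PR (forward mode, shadow computed once and reused across calls):
#   ∇ℓ = ADgradient(Val(:Enzyme), ℓ; mode = Enzyme.Forward, shadow = Enzyme.onehot(x))
# After this PR only `mode` is honored; the constructor wraps it in
# `AutoEnzyme(; mode)` and delegates to the DifferentiationInterface-based method.
∇ℓ = ADgradient(Val(:Enzyme), ℓ; mode = Enzyme.Forward)
LogDensityProblems.logdensity_and_gradient(∇ℓ, x)
```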
ext/LogDensityProblemsADFiniteDifferencesExt.jl
@@ -1,51 +1,19 @@
""" | ||
Gradient implementation using FiniteDifferences. | ||
""" | ||
Comment on lines
-1
to
-3
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Revert? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above, the information content is rather low |
||
module LogDensityProblemsADFiniteDifferencesExt

if isdefined(Base, :get_extension)
using LogDensityProblemsAD: ADGradientWrapper, logdensity
using LogDensityProblemsAD.SimpleUnPack: @unpack

import LogDensityProblemsAD: ADgradient, logdensity_and_gradient
import FiniteDifferences
using ADTypes: AutoFiniteDifferences
import FiniteDifferences: central_fdm
using LogDensityProblemsAD: LogDensityProblemsAD, ADgradient
else
using ..LogDensityProblemsAD: ADGradientWrapper, logdensity
using ..LogDensityProblemsAD.SimpleUnPack: @unpack

import ..LogDensityProblemsAD: ADgradient, logdensity_and_gradient
import ..FiniteDifferences
end

struct FiniteDifferencesGradientLogDensity{L,M} <: ADGradientWrapper
ℓ::L
"finite difference method"
fdm::M
end

"""
ADgradient(:FiniteDifferences, ℓ; fdm = central_fdm(5, 1))
ADgradient(Val(:FiniteDifferences), ℓ; fdm = central_fdm(5, 1))

Gradient using FiniteDifferences, mainly intended for checking results from other algorithms.

# Keyword arguments

- `fdm`: the finite difference method. Defaults to `central_fdm(5, 1)`.
"""
function ADgradient(::Val{:FiniteDifferences}, ℓ; fdm = FiniteDifferences.central_fdm(5, 1))
FiniteDifferencesGradientLogDensity(ℓ, fdm)
end

function Base.show(io::IO, ∇ℓ::FiniteDifferencesGradientLogDensity)
print(io, "FiniteDifferences AD wrapper for ", ∇ℓ.ℓ, " with ", ∇ℓ.fdm)
using ..ADTypes: AutoFiniteDifferences
import ..FiniteDifferences: central_fdm
using ..LogDensityProblemsAD: LogDensityProblemsAD, ADgradient
end

function logdensity_and_gradient(∇ℓ::FiniteDifferencesGradientLogDensity, x::AbstractVector)
@unpack ℓ, fdm = ∇ℓ
y = logdensity(ℓ, x)
∇y = only(FiniteDifferences.grad(fdm, Base.Fix1(logdensity, ℓ), x))
y, ∇y
function LogDensityProblemsAD.ADgradient(::Val{:FiniteDifferences}, ℓ)
Review comment: Please re-add the […]
Review comment: oh right I had missed that one
fdm = central_fdm(5, 1)
backend = AutoFiniteDifferences(; fdm)
ADgradient(backend, ℓ)
end

end # module
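The review request above presumably refers to the dropped `fdm` keyword argument; restoring it is straightforward because it can be forwarded into the backend. A possible sketch (not part of this diff), assuming the extension's imports shown above and keeping `central_fdm(5, 1)` as the default:

```julia
# Sketch: restore the `fdm` keyword on top of the DifferentiationInterface-based
# constructor by forwarding it to the AutoFiniteDifferences backend.
function LogDensityProblemsAD.ADgradient(
    ::Val{:FiniteDifferences},
    ℓ;
    fdm = central_fdm(5, 1),
)
    backend = AutoFiniteDifferences(; fdm)
    return ADgradient(backend, ℓ)
end
```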
This file was deleted.
Review comment: This can be kept?
Review comment: It's very uninformative, no one will ever see it, and this file no longer contains the actual gradient implementation.