LogDensityProblemsADADTypesExt.jl
module LogDensityProblemsADADTypesExt

if isdefined(Base, :get_extension)
    import LogDensityProblemsAD
    import ADTypes
else
    import ..LogDensityProblemsAD
    import ..ADTypes
end
"""
ADgradient(ad::ADTypes.AbstractADType, ℓ; x::Union{Nothing,AbstractVector}=nothing)
Wrap log density `ℓ` using automatic differentiation (AD) of type `ad` to obtain a gradient.
Currently,
- `ad::ADTypes.AutoEnzyme`
- `ad::ADTypes.AutoForwardDiff`
- `ad::ADTypes.AutoReverseDiff`
- `ad::ADTypes.AutoTracker`
- `ad::ADTypes.AutoZygote`
are supported with custom implementations.
The AD configuration specified by `ad` is forwarded to the corresponding calls of `ADgradient(Val(...), ℓ)`.
Passing `x` as a keyword argument means that the gradient operator will be "prepared" for the specific type and size of the array `x`. This can speed up further evaluations on similar inputs, but will likely cause errors if the new inputs have a different type or size. With `AutoReverseDiff`, it can also yield incorrect results if the logdensity contains value-dependent control flow.
If you want to use another backend from [ADTypes.jl](https://github.com/SciML/ADTypes.jl) which is not in the list above, you need to load [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl) first.
"""
LogDensityProblemsAD.ADgradient(::ADTypes.AbstractADType, ℓ)
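# Illustrative usage sketch: a minimal log density implementing the LogDensityProblems
# interface, differentiated via `AutoForwardDiff`. `StandardNormalDensity` is a
# hypothetical example type, and the snippet assumes ForwardDiff is installed:
#
#     using LogDensityProblems, LogDensityProblemsAD, ADTypes, ForwardDiff
#
#     struct StandardNormalDensity end
#     LogDensityProblems.logdensity(::StandardNormalDensity, x) = -sum(abs2, x) / 2
#     LogDensityProblems.dimension(::StandardNormalDensity) = 2
#     LogDensityProblems.capabilities(::Type{StandardNormalDensity}) =
#         LogDensityProblems.LogDensityOrder{0}()
#
#     # Prepare the gradient operator for length-2 `Vector{Float64}` inputs.
#     ∇ℓ = ADgradient(ADTypes.AutoForwardDiff(), StandardNormalDensity(); x=zeros(2))
#     LogDensityProblems.logdensity_and_gradient(∇ℓ, [1.0, 2.0])  # (-2.5, [-1.0, -2.0])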
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoEnzyme, ℓ; x::Union{Nothing,AbstractVector}=nothing)
    if x !== nothing
        @warn "`ADgradient`: Keyword argument `x` is ignored"
    end
    if ad.mode === nothing
        # Use default mode (Enzyme.Reverse)
        return LogDensityProblemsAD.ADgradient(Val(:Enzyme), ℓ)
    else
        return LogDensityProblemsAD.ADgradient(Val(:Enzyme), ℓ; mode=ad.mode)
    end
end
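# For example (assuming Enzyme is loaded), `ADgradient(ADTypes.AutoEnzyme(; mode=Enzyme.Forward), ℓ)`
# is equivalent to `ADgradient(Val(:Enzyme), ℓ; mode=Enzyme.Forward)`, whereas a plain
# `ADTypes.AutoEnzyme()` falls back to the default reverse mode.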
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoForwardDiff{C}, ℓ; x::Union{Nothing,AbstractVector}=nothing) where {C}
    if C === nothing
        # Use default chunk size
        return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; tag=ad.tag, x=x)
    else
        return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk=C, tag=ad.tag, x=x)
    end
end

function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff{T}, ℓ; x::Union{Nothing,AbstractVector}=nothing) where {T}
    return LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), ℓ; compile=Val(T), x=x)
end

function LogDensityProblemsAD.ADgradient(::ADTypes.AutoTracker, ℓ; x::Union{Nothing,AbstractVector}=nothing)
    if x !== nothing
        @warn "`ADgradient`: Keyword argument `x` is ignored"
    end
    return LogDensityProblemsAD.ADgradient(Val(:Tracker), ℓ)
end

function LogDensityProblemsAD.ADgradient(::ADTypes.AutoZygote, ℓ; x::Union{Nothing,AbstractVector}=nothing)
    if x !== nothing
        @warn "`ADgradient`: Keyword argument `x` is ignored"
    end
    return LogDensityProblemsAD.ADgradient(Val(:Zygote), ℓ)
end
# Better error message if users forget to load DifferentiationInterface
if isdefined(Base.Experimental, :register_error_hint)
    function __init__()
        Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
            if exc.f === LogDensityProblemsAD.ADgradient && length(argtypes) == 2 && first(argtypes) <: ADTypes.AbstractADType
                print(io, "\nDon't know how to AD with $(nameof(first(argtypes))). Did you forget to load DifferentiationInterface?")
            end
        end
    end
end
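# Sketch of the fallback mentioned in the docstring above: for an ADTypes backend
# without a custom implementation here (e.g. `ADTypes.AutoFiniteDiff`), `ADgradient`
# only gains a method once DifferentiationInterface is loaded; otherwise the
# `MethodError` hint registered above is printed. Assuming DifferentiationInterface
# and FiniteDiff are installed, and reusing the hypothetical `StandardNormalDensity`
# from the sketch near the top of this file:
#
#     using DifferentiationInterface, FiniteDiff, ADTypes
#
#     ∇ℓ = ADgradient(ADTypes.AutoFiniteDiff(), StandardNormalDensity())
#     LogDensityProblems.logdensity_and_gradient(∇ℓ, [1.0, 2.0])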
end # module