Skip to content

Commit

Permalink
Merge pull request #42 from tpapp/dw/error_hints
Browse files Browse the repository at this point in the history
Add error hints
  • Loading branch information
tpapp authored Nov 18, 2024
2 parents e3401f2 + ef18c48 commit ee6f9fe
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LogDensityProblemsAD"
uuid = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
authors = ["Tamás K. Papp <[email protected]>"]
version = "1.12.0"
version = "1.13.0"

[deps]
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand Down
11 changes: 11 additions & 0 deletions ext/LogDensityProblemsADADTypesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,15 @@ function LogDensityProblemsAD.ADgradient(::ADTypes.AutoZygote, ℓ; x::Union{Not
return LogDensityProblemsAD.ADgradient(Val(:Zygote), ℓ)
end

# Better error message if users forget to load DifferentiationInterface
# Better error message if users forget to load DifferentiationInterface
if isdefined(Base.Experimental, :register_error_hint)
    function __init__()
        Base.Experimental.register_error_hint(MethodError) do io, err, types, _kw
            # Fire only for two-argument `ADgradient(backend, ℓ)` calls where the
            # backend is an ADTypes backend with no loaded implementation.
            err.f === LogDensityProblemsAD.ADgradient || return
            length(types) == 2 || return
            backend = first(types)
            backend <: ADTypes.AbstractADType || return
            print(io, "\nDon't know how to AD with $(nameof(backend)). Did you forget to load DifferentiationInterface?")
        end
    end
end

end # module
14 changes: 11 additions & 3 deletions src/LogDensityProblemsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,17 @@ The function `parent` can be used to retrieve the original argument.
"""
# Convenience form: lift the `Symbol` backend name into the type domain and
# delegate to the `Val`-based dispatch.
ADgradient(kind::Symbol, P; kwargs...) = ADgradient(Val(kind), P; kwargs...)

# Better error message if users forget to load the AD package.
# NOTE(review): the old fallback method (`@info` then `throw(MethodError(...))`)
# was replaced by an error hint, so unknown backends now throw a plain
# `MethodError` and the hint text is appended to its display.
if isdefined(Base.Experimental, :register_error_hint)
    # Extract `T` from the type `Val{T}` (type-level unwrap of the backend name).
    _unval(::Type{Val{T}}) where {T} = T
    function __init__()
        Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
            # Fire only for two-argument `ADgradient(::Val{kind}, P)` misses.
            if exc.f === ADgradient && length(argtypes) == 2 && first(argtypes) <: Val
                kind = _unval(first(argtypes))
                print(io, "\nDon't know how to AD with $(kind), consider `import $(kind)` if there is such a package.")
            end
        end
    end
end

#####
Expand Down
15 changes: 13 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import FiniteDifferences, ForwardDiff, Enzyme, Tracker, Zygote, ReverseDiff # ba
import ADTypes # load support for AD types with options
import BenchmarkTools # load the heuristic chunks code
using ComponentArrays: ComponentVector # test with other vector types
import DifferentiationInterface

struct EnzymeTestMode <: Enzyme.Mode{Enzyme.DefaultABI, false, false} end

Expand Down Expand Up @@ -71,6 +70,15 @@ struct TestTag end
# Allow tag type in gradient etc. calls of the log density function
ForwardDiff.checktag(::Type{ForwardDiff.Tag{TestTag, V}}, ::Base.Fix1{typeof(logdensity),typeof(TestLogDensity())}, ::AbstractArray{V}) where {V} = true

@testset "Missing DI for unsupported ADType" begin
    # With DifferentiationInterface not yet loaded, an ADTypes backend without
    # a dedicated extension should throw, and the error display should carry
    # the registered hint text below.
    expected = "Don't know how to AD with AutoFiniteDifferences. Did you forget to load DifferentiationInterface?"
    backend = ADTypes.AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(5, 1))
    @test_throws expected ADgradient(backend, TestLogDensity2())
    @test_throws expected ADgradient(backend, TestLogDensity2(); x=zeros(20))
end

import DifferentiationInterface

@testset "AD via ReverseDiff" begin
ℓ = TestLogDensity()

Expand Down Expand Up @@ -296,7 +304,10 @@ end

@testset "ADgradient missing method" begin
    # NOTE(review): the pre-image `@test_logs((:info, msg), ...)` line from the
    # diff was retained by the scrape; it is removed here because the `@info`
    # message it expects no longer exists — unknown backends now throw a plain
    # `MethodError` whose display includes the registered hint text.
    msg = "Don't know how to AD with Foo, consider `import Foo` if there is such a package."
    @test_throws msg ADgradient(:Foo, TestLogDensity2())
    @test_throws msg ADgradient(:Foo, TestLogDensity2(); x=zeros(20))
    @test_throws msg ADgradient(Val(:Foo), TestLogDensity2())
    @test_throws msg ADgradient(Val(:Foo), TestLogDensity2(); x=zeros(20))
end

@testset "benchmark ForwardDiff chunk size" begin
Expand Down

2 comments on commit ee6f9fe

@tpapp
Copy link
Owner Author

@tpapp tpapp commented on ee6f9fe Nov 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

Improve error message when the backend is not found.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/119677

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.13.0 -m "<description of version>" ee6f9fef6ee550cb35a3bd3abdbd860e342e7b60
git push origin v1.13.0

Please sign in to comment.