From 3b8e44c3f6c634c27ff935d7975af7c517a1493e Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Tue, 10 Oct 2023 07:23:35 -0500 Subject: [PATCH] Add a ForwardDiff extension (#6) --- .github/workflows/CI.yml | 2 +- Project.toml | 11 +++++-- .../IntervalArith/src/IntervalArith.jl | 19 ++++++++++-- ext/ThickNumbersForwardDiffExt.jl | 30 +++++++++++++++++++ test/extensions/forwarddiff.jl | 23 ++++++++++++++ test/extensions/runtests.jl | 1 + test/runtests.jl | 14 +++------ test/setpath.jl | 15 ++++++++++ 8 files changed, 100 insertions(+), 15 deletions(-) create mode 100644 ext/ThickNumbersForwardDiffExt.jl create mode 100644 test/extensions/forwarddiff.jl create mode 100644 test/extensions/runtests.jl create mode 100644 test/setpath.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 1e4ecc4..ee19124 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -20,7 +20,7 @@ jobs: matrix: version: - '1' - - '1.6' + - '1.9' os: - ubuntu-latest arch: diff --git a/Project.toml b/Project.toml index f17e29b..e2ca394 100644 --- a/Project.toml +++ b/Project.toml @@ -6,11 +6,18 @@ version = "1.0.0-DEV" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +[weakdeps] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + +[extensions] +ThickNumbersForwardDiffExt = "ForwardDiff" + [compat] -julia = "1.6" +julia = "1.9" [extras] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["ForwardDiff", "Test"] diff --git a/ThickNumbersInterfaceTests/test/testpackages/IntervalArith/src/IntervalArith.jl b/ThickNumbersInterfaceTests/test/testpackages/IntervalArith/src/IntervalArith.jl index a2b6b6d..e4deeb1 100644 --- a/ThickNumbersInterfaceTests/test/testpackages/IntervalArith/src/IntervalArith.jl +++ b/ThickNumbersInterfaceTests/test/testpackages/IntervalArith/src/IntervalArith.jl @@ -4,12 +4,17 @@ using ThickNumbers export Interval -struct Interval{T} <: ThickNumber{T} +# To avoid ambiguity with the ForwardDiff extension, it's easiest to be specific about the promotion of +# other Numbers against `Interval` +const BaseReals = Union{AbstractFloat, Integer, AbstractIrrational, Rational} + +struct Interval{T<:Number} <: ThickNumber{T} lo::T hi::T end Interval(lo, hi) = Interval(promote(lo, hi)...) Interval{T}(iv::Interval) where T = Interval{T}(iv.lo, iv.hi) +Interval{T}(x::Number) where T = Interval{T}(x, x) ThickNumbers.loval(x::Interval) = x.lo ThickNumbers.hival(x::Interval) = x.hi @@ -20,10 +25,20 @@ ThickNumbers.midrad(::Type{Interval}, mid::T, rad::T) where T = midrad(Interval{ ThickNumbers.midrad(::Type{Interval}, mid, rad) = midrad(Interval, promote(mid, rad)...) # Promotion of `valuetype` -Base.promote_rule(::Type{Interval{T}}, ::Type{Interval{S}}) where {T, S} = Interval{promote_type(T, S)} +Base.promote_rule(::Type{Interval{S}}, ::Type{Interval{T}}) where {S<:Number, T<:Number} = Interval{promote_type(T, S)} +Base.promote_rule(::Type{Interval{S}}, ::Type{T}) where {S<:Number, T<:BaseReals} = Interval{promote_type(T, S)} # Very basic arithmetic needed for `norm` (this would be fleshed out in real applications) Base.:+(x::Interval, y::Interval) = Interval(x.lo + y.lo, x.hi + y.hi) +Base.:/(x::Interval, y::Real) = Interval(x.lo / y, x.hi / y) +function Base.:*(x::Interval, y::Interval) + T = typeof(zero(valuetype(x))*zero(valuetype(y))) + (isempty(x) || isempty(y)) && return emptyset(Interval{T}) + v1, v2, v3, v4 = x.lo*y.lo, x.hi*y.lo, x.lo*y.hi, x.hi*y.hi + v1, v2 = v1 > v2 ? (v2, v1) : (v1, v2) + v3, v4 = v3 > v4 ? (v4, v3) : (v3, v4) + return Interval(min(v1, v3), max(v2, v4)) +end Base.abs2(x::Interval) = Interval(mig(x)^2, mag(x)^2) Base.sqrt(x::Interval) = Interval(sqrt(loval(x)), sqrt(hival(x))) diff --git a/ext/ThickNumbersForwardDiffExt.jl b/ext/ThickNumbersForwardDiffExt.jl new file mode 100644 index 0000000..469d9d3 --- /dev/null +++ b/ext/ThickNumbersForwardDiffExt.jl @@ -0,0 +1,30 @@ +module ThickNumbersForwardDiffExt + +using ThickNumbers +using ForwardDiff: ForwardDiff, Dual, Partials, Tag + +function ForwardDiff.derivative(f::F, x::TN) where {F,TN<:ThickNumber} + T = typeof(Tag(f, TN)) + return ForwardDiff.extract_derivative(T, f(Dual{T}(x, one(x)))) +end + +ForwardDiff.can_dual(::Type{<:ThickNumber}) = true + +function ForwardDiff.dual_definition_retval(::Val{T}, val::ThickNumber, deriv::ThickNumber, partial::Partials) where {T} + return Dual{T}(val, deriv * partial) +end +function ForwardDiff.dual_definition_retval(::Val{T}, val::ThickNumber, deriv1::ThickNumber, partial1::Partials, deriv2::ThickNumber, partial2::Partials) where {T} + return Dual{T}(val, ForwardDiff._mul_partials(partial1, partial2, deriv1, deriv2)) +end + +Base.:*(x::ThickNumber, partials::Partials) = partials * x +function Base.:*(partials::Partials, x::ThickNumber) + return Partials(ForwardDiff.scale_tuple(partials.values, x)) +end + +Base.promote_rule(::Type{TN}, ::Type{Dual{T,V,N}}) where {TN<:ThickNumber,T,V<:Number,N} = Dual{T, promote_dual(TN, V),N} + +promote_dual(::Type{TN}, ::Type{V}) where {TN<:ThickNumber,V} = promote_type(TN, V) +promote_dual(::Type{TN}, ::Type{Dual{T,V,N}}) where {TN<:ThickNumber,T,V,N} = Dual{T, promote_dual(TN, V), N} + +end diff --git a/test/extensions/forwarddiff.jl b/test/extensions/forwarddiff.jl new file mode 100644 index 0000000..e84074e --- /dev/null +++ b/test/extensions/forwarddiff.jl @@ -0,0 +1,23 @@ +using ThickNumbers +using ForwardDiff +using Test + +include(joinpath(dirname(@__DIR__), "setpath.jl")) + +using IntervalArith + +@testset "ForwardDiff extension" begin + @test isempty(detect_ambiguities(ThickNumbers)) + @test isempty(detect_ambiguities(IntervalArith)) + + a, b = Interval(1, 2), Interval(0, 0.1) + f1(t) = a + t*b + f2(x) = a + abs2(x)/2 + + df1(t) = ForwardDiff.derivative(f1, t) + df2(x) = ForwardDiff.derivative(f2, x) + @test df1(0.5) ≐ b + @test df2(b) ⩪ b + ddf2(x) = ForwardDiff.derivative(df2, x) + @test ddf2(b) ≐ 1 +end diff --git a/test/extensions/runtests.jl b/test/extensions/runtests.jl new file mode 100644 index 0000000..07629ba --- /dev/null +++ b/test/extensions/runtests.jl @@ -0,0 +1 @@ +include("forwarddiff.jl") diff --git a/test/runtests.jl b/test/runtests.jl index 01859e9..23d7fbf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,12 +1,7 @@ using ThickNumbers using Test -const interfacetestsdir = abspath(joinpath(dirname(@__DIR__), "ThickNumbersInterfaceTests")) -const testpackagesdir = joinpath(interfacetestsdir, "test", "testpackages") - -if testpackagesdir ∉ LOAD_PATH - push!(LOAD_PATH, testpackagesdir) -end +include("setpath.jl") using IntervalArith @@ -118,7 +113,6 @@ using IntervalArith @test [Interval(1, 2), Interval(0, 1+eps())] ⩪ [Interval(1, 2*(1+eps())), Interval(0, 1)] end -filter!(LOAD_PATH) do path - path != testpackagesdir && path != interfacetestsdir -end -nothing +include(joinpath("extensions", "runtests.jl")) + +cleanup() diff --git a/test/setpath.jl b/test/setpath.jl new file mode 100644 index 0000000..fde3ca7 --- /dev/null +++ b/test/setpath.jl @@ -0,0 +1,15 @@ +if !isdefined(@__MODULE__, :testpackagesdir) + const interfacetestsdir = abspath(joinpath(dirname(@__DIR__), "ThickNumbersInterfaceTests")) + const testpackagesdir = joinpath(interfacetestsdir, "test", "testpackages") + + if testpackagesdir ∉ LOAD_PATH + push!(LOAD_PATH, testpackagesdir) + end + + function cleanup() + filter!(LOAD_PATH) do path + path != testpackagesdir && path != interfacetestsdir + end + return nothing + end +end