Skip to content

Commit

Permalink
Add a ForwardDiff extension (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy authored Oct 10, 2023
1 parent f303001 commit 3b8e44c
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
matrix:
version:
- '1'
- '1.6'
- '1.9'
os:
- ubuntu-latest
arch:
Expand Down
11 changes: 9 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@ version = "1.0.0-DEV"
[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[extensions]
ThickNumbersForwardDiffExt = "ForwardDiff"

[compat]
julia = "1.6"
julia = "1.9"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["ForwardDiff", "Test"]
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@ using ThickNumbers

export Interval

struct Interval{T} <: ThickNumber{T}
# To avoid ambiguity with the ForwardDiff extension, it's easiest to be specific about the promotion of
# other Numbers against `Interval`
const BaseReals = Union{AbstractFloat, Integer, AbstractIrrational, Rational}

struct Interval{T<:Number} <: ThickNumber{T}
lo::T
hi::T
end
Interval(lo, hi) = Interval(promote(lo, hi)...)
Interval{T}(iv::Interval) where T = Interval{T}(iv.lo, iv.hi)
Interval{T}(x::Number) where T = Interval{T}(x, x)

ThickNumbers.loval(x::Interval) = x.lo
ThickNumbers.hival(x::Interval) = x.hi
Expand All @@ -20,10 +25,20 @@ ThickNumbers.midrad(::Type{Interval}, mid::T, rad::T) where T = midrad(Interval{
ThickNumbers.midrad(::Type{Interval}, mid, rad) = midrad(Interval, promote(mid, rad)...)

# Promotion of `valuetype`
Base.promote_rule(::Type{Interval{T}}, ::Type{Interval{S}}) where {T, S} = Interval{promote_type(T, S)}
Base.promote_rule(::Type{Interval{S}}, ::Type{Interval{T}}) where {S<:Number, T<:Number} = Interval{promote_type(T, S)}
Base.promote_rule(::Type{Interval{S}}, ::Type{T}) where {S<:Number, T<:BaseReals} = Interval{promote_type(T, S)}

# Very basic arithmetic needed for `norm` (this would be fleshed out in real applications)
Base.:+(x::Interval, y::Interval) = Interval(x.lo + y.lo, x.hi + y.hi)
Base.:/(x::Interval, y::Real) = Interval(x.lo / y, x.hi / y)
function Base.:*(x::Interval, y::Interval)
T = typeof(zero(valuetype(x))*zero(valuetype(y)))
(isempty(x) || isempty(y)) && return emptyset(Interval{T})
v1, v2, v3, v4 = x.lo*y.lo, x.hi*y.lo, x.lo*y.hi, x.hi*y.hi
v1, v2 = v1 > v2 ? (v2, v1) : (v1, v2)
v3, v4 = v3 > v4 ? (v4, v3) : (v3, v4)
return Interval(min(v1, v3), max(v2, v4))
end
Base.abs2(x::Interval) = Interval(mig(x)^2, mag(x)^2)
Base.sqrt(x::Interval) = Interval(sqrt(loval(x)), sqrt(hival(x)))

Expand Down
30 changes: 30 additions & 0 deletions ext/ThickNumbersForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
module ThickNumbersForwardDiffExt

using ThickNumbers
using ForwardDiff: ForwardDiff, Dual, Partials, Tag

function ForwardDiff.derivative(f::F, x::TN) where {F,TN<:ThickNumber}
T = typeof(Tag(f, TN))
return ForwardDiff.extract_derivative(T, f(Dual{T}(x, one(x))))
end

ForwardDiff.can_dual(::Type{<:ThickNumber}) = true

function ForwardDiff.dual_definition_retval(::Val{T}, val::ThickNumber, deriv::ThickNumber, partial::Partials) where {T}
return Dual{T}(val, deriv * partial)
end
function ForwardDiff.dual_definition_retval(::Val{T}, val::ThickNumber, deriv1::ThickNumber, partial1::Partials, deriv2::ThickNumber, partial2::Partials) where {T}
return Dual{T}(val, ForwardDiff._mul_partials(partial1, partial2, deriv1, deriv2))
end

Base.:*(x::ThickNumber, partials::Partials) = partials * x
function Base.:*(partials::Partials, x::ThickNumber)
return Partials(ForwardDiff.scale_tuple(partials.values, x))
end

Base.promote_rule(::Type{TN}, ::Type{Dual{T,V,N}}) where {TN<:ThickNumber,T,V<:Number,N} = Dual{T, promote_dual(TN, V),N}

promote_dual(::Type{TN}, ::Type{V}) where {TN<:ThickNumber,V} = promote_type(TN, V)
promote_dual(::Type{TN}, ::Type{Dual{T,V,N}}) where {TN<:ThickNumber,T,V,N} = Dual{T, promote_dual(TN, V), N}

end
23 changes: 23 additions & 0 deletions test/extensions/forwarddiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using ThickNumbers
using ForwardDiff
using Test

include(joinpath(dirname(@__DIR__), "setpath.jl"))

using IntervalArith

@testset "ForwardDiff extension" begin
@test isempty(detect_ambiguities(ThickNumbers))
@test isempty(detect_ambiguities(IntervalArith))

a, b = Interval(1, 2), Interval(0, 0.1)
f1(t) = a + t*b
f2(x) = a + abs2(x)/2

df1(t) = ForwardDiff.derivative(f1, t)
df2(x) = ForwardDiff.derivative(f2, x)
@test df1(0.5) b
@test df2(b) b
ddf2(x) = ForwardDiff.derivative(df2, x)
@test ddf2(b) 1
end
1 change: 1 addition & 0 deletions test/extensions/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include("forwarddiff.jl")
14 changes: 4 additions & 10 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
using ThickNumbers
using Test

const interfacetestsdir = abspath(joinpath(dirname(@__DIR__), "ThickNumbersInterfaceTests"))
const testpackagesdir = joinpath(interfacetestsdir, "test", "testpackages")

if testpackagesdir LOAD_PATH
push!(LOAD_PATH, testpackagesdir)
end
include("setpath.jl")

using IntervalArith

Expand Down Expand Up @@ -118,7 +113,6 @@ using IntervalArith
@test [Interval(1, 2), Interval(0, 1+eps())] [Interval(1, 2*(1+eps())), Interval(0, 1)]
end

filter!(LOAD_PATH) do path
path != testpackagesdir && path != interfacetestsdir
end
nothing
include(joinpath("extensions", "runtests.jl"))

cleanup()
15 changes: 15 additions & 0 deletions test/setpath.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
if !isdefined(@__MODULE__, :testpackagesdir)
const interfacetestsdir = abspath(joinpath(dirname(@__DIR__), "ThickNumbersInterfaceTests"))
const testpackagesdir = joinpath(interfacetestsdir, "test", "testpackages")

if testpackagesdir LOAD_PATH
push!(LOAD_PATH, testpackagesdir)
end

function cleanup()
filter!(LOAD_PATH) do path
path != testpackagesdir && path != interfacetestsdir
end
return nothing
end
end

0 comments on commit 3b8e44c

Please sign in to comment.