From c4323540a65cec47ac6b16915301a817dcf120b4 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Thu, 19 Oct 2023 05:48:27 -0500 Subject: [PATCH] Support derivs of `abs`; add `eps` and `clamp` (#12) --- ext/ThickNumbersForwardDiffExt.jl | 29 ++++++++++++++++++++++++++++- src/ThickNumbers.jl | 12 ++++++++++++ test/extensions/forwarddiff.jl | 12 ++++++++++++ test/runtests.jl | 7 +++++++ 4 files changed, 59 insertions(+), 1 deletion(-) diff --git a/ext/ThickNumbersForwardDiffExt.jl b/ext/ThickNumbersForwardDiffExt.jl index 469d9d3..09b0920 100644 --- a/ext/ThickNumbersForwardDiffExt.jl +++ b/ext/ThickNumbersForwardDiffExt.jl @@ -1,7 +1,14 @@ module ThickNumbersForwardDiffExt using ThickNumbers -using ForwardDiff: ForwardDiff, Dual, Partials, Tag +using ForwardDiff: ForwardDiff, Dual, Partials, Tag, value, partials, npartials, tagtype +using ForwardDiff.DiffRules: DiffRules, @define_diffrule +# Supports up to third-order derivatives +const ThickDual = Union{Dual{T,<:ThickNumber} where T, + Dual{T1,<:Dual{T2,<:ThickNumber}} where {T1,T2}, + Dual{T1,<:Dual{T2,<:Dual{T3,<:ThickNumber}}} where {T1,T2,T3}, +} +const ThickLike = Union{ThickNumber, ThickDual} function ForwardDiff.derivative(f::F, x::TN) where {F,TN<:ThickNumber} T = typeof(Tag(f, TN)) @@ -27,4 +34,24 @@ Base.promote_rule(::Type{TN}, ::Type{Dual{T,V,N}}) where {TN<:ThickNumber,T,V<:N promote_dual(::Type{TN}, ::Type{V}) where {TN<:ThickNumber,V} = promote_type(TN, V) promote_dual(::Type{TN}, ::Type{Dual{T,V,N}}) where {TN<:ThickNumber,T,V,N} = Dual{T, promote_dual(TN, V), N} +### Special functions + +## First and higher-order derivatives of `abs` +function DiffRules._abs_deriv(x::ThickNumber) + sb = signbit(x) + lv, hv = loval(x), hival(x) + return lohi(basetype(typeof(x)), true ∈ sb ? -one(lv) : one(lv), false ∈ sb ? one(hv) : -one(hv)) +end +# Second derivative of abs spans from 0 to either 0 or Inf (if 0 is included in the range) +_abs_deriv2(x::ThickNumber) = iszero(mig(x)) ? lohi(typeof(x), 0, typemax(valuetype(x))) : zero(x) +@define_diffrule DiffRules._abs_deriv(x) = :($(_abs_deriv2)($x)) +eval(ForwardDiff.unary_dual_definition(:DiffRules, :_abs_deriv)) +# Third and higher derivatives of abs span from -Inf to Inf, or is zero if 0 is not included in the range +_abs_deriv3(x::ThickNumber{T}) where T = iszero(mig(x)) ? lohi(typeof(x), typemin(T), typemax(T)) : zero(x) +@define_diffrule ThickNumbersForwardDiffExt._abs_deriv2(x) = :($(_abs_deriv3)($x)) +@define_diffrule ThickNumbersForwardDiffExt._abs_deriv3(x) = :($(_abs_deriv3)($x)) +eval(ForwardDiff.unary_dual_definition(:ThickNumbersForwardDiffExt, :_abs_deriv2)) +eval(ForwardDiff.unary_dual_definition(:ThickNumbersForwardDiffExt, :_abs_deriv3)) + + end diff --git a/src/ThickNumbers.jl b/src/ThickNumbers.jl index 5aa19af..08f638e 100644 --- a/src/ThickNumbers.jl +++ b/src/ThickNumbers.jl @@ -157,6 +157,7 @@ Interval{$Int}(1, 2) ``` """ basetype(::Type{TN}) where TN<:ThickNumber = error("basetype not defined for $TN") +basetype(x::ThickNumber) = basetype(typeof(x)) # Optional specializations @@ -526,5 +527,16 @@ Base.:(-)(a::TN) where TN<:ThickNumber{T} where T<:Integer = lohi(TN, Base.Checked.checked_neg(loval(a)) ) +# Functions + +Base.eps(::Type{TN}) where TN<:ThickNumber{T} where T = (e = eps(T); lohi(TN, e, e)) +Base.eps(x::ThickNumber) = lohi(typeof(x), eps(mig(x)), eps(mag(x))) + +Base.signbit(x::ThickNumber) = lohi(basetype(typeof(x)), signbit(hival(x)), signbit(loval(x))) + +Base.abs(x::ThickNumber) = lohi(typeof(x), mig(x), mag(x)) + +Base.clamp(x::ThickNumber, lo::Real, hi::Real) = lohi(basetype(x), max(lo, loval(x)), min(hi, hival(x))) + end # module diff --git a/test/extensions/forwarddiff.jl b/test/extensions/forwarddiff.jl index e84074e..6419109 100644 --- a/test/extensions/forwarddiff.jl +++ b/test/extensions/forwarddiff.jl @@ -20,4 +20,16 @@ using IntervalArith @test df2(b) ⩪ b ddf2(x) = ForwardDiff.derivative(df2, x) @test ddf2(b) ≐ 1 + + # abs + dabs(x) = ForwardDiff.derivative(abs, x) + ddabs(x) = ForwardDiff.derivative(dabs, x) + dddabs(x) = ForwardDiff.derivative(ddabs, x) + @test dabs(Interval(1.0, 2.0)) === Interval(1.0, 1.0) + @test ddabs(Interval(1.0, 2.0)) === Interval(0.0, 0.0) + @test dddabs(Interval(1.0, 2.0)) === Interval(0.0, 0.0) + @test dabs(Interval(-1.0, 2.0)) === Interval(-1.0, 1.0) + @test ddabs(Interval(-1.0, 2.0)) === Interval(0.0, Inf) + abs3 = dddabs(Interval(-1.0, 2.0)) + @test abs3 === Interval(-Inf, Inf) || isnan_tn(abs3) end diff --git a/test/runtests.jl b/test/runtests.jl index bd6fad7..ddab14c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -121,6 +121,13 @@ using IntervalArith @test [Interval(1, 2), Interval(0, 1+eps())] ⩪ [Interval(1, 2*(1+eps())), Interval(0, 1)] end +@testset "Specializations" begin + @test loval(signbit(Interval(-1, 1))) === false + @test hival(signbit(Interval(-1, 1))) === true + @test eps(Interval(1.0, 1000.0)) === Interval(eps(1.0), eps(1000.0)) + @test clamp(Interval(-1, 4), 2, 3) === Interval(2, 3) +end + include(joinpath("extensions", "runtests.jl")) cleanup()