Skip to content

Commit

Permalink
Support derivs of abs; add eps and clamp (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy authored Oct 19, 2023
1 parent 45cf56e commit c432354
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 1 deletion.
29 changes: 28 additions & 1 deletion ext/ThickNumbersForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
module ThickNumbersForwardDiffExt

using ThickNumbers
using ForwardDiff: ForwardDiff, Dual, Partials, Tag
using ForwardDiff: ForwardDiff, Dual, Partials, Tag, value, partials, npartials, tagtype
using ForwardDiff.DiffRules: DiffRules, @define_diffrule
# Supports up to third-order derivatives
const ThickDual = Union{Dual{T,<:ThickNumber} where T,
Dual{T1,<:Dual{T2,<:ThickNumber}} where {T1,T2},
Dual{T1,<:Dual{T2,<:Dual{T3,<:ThickNumber}}} where {T1,T2,T3},
}
const ThickLike = Union{ThickNumber, ThickDual}

function ForwardDiff.derivative(f::F, x::TN) where {F,TN<:ThickNumber}
T = typeof(Tag(f, TN))
Expand All @@ -27,4 +34,24 @@ Base.promote_rule(::Type{TN}, ::Type{Dual{T,V,N}}) where {TN<:ThickNumber,T,V<:N
promote_dual(::Type{TN}, ::Type{V}) where {TN<:ThickNumber,V} = promote_type(TN, V)
promote_dual(::Type{TN}, ::Type{Dual{T,V,N}}) where {TN<:ThickNumber,T,V,N} = Dual{T, promote_dual(TN, V), N}

### Special functions

## First and higher-order derivatives of `abs`
function DiffRules._abs_deriv(x::ThickNumber)
sb = signbit(x)
lv, hv = loval(x), hival(x)
return lohi(basetype(typeof(x)), true sb ? -one(lv) : one(lv), false sb ? one(hv) : -one(hv))
end
# Second derivative of abs spans from 0 to either 0 or Inf (if 0 is included in the range)
_abs_deriv2(x::ThickNumber) = iszero(mig(x)) ? lohi(typeof(x), 0, typemax(valuetype(x))) : zero(x)
@define_diffrule DiffRules._abs_deriv(x) = :($(_abs_deriv2)($x))
eval(ForwardDiff.unary_dual_definition(:DiffRules, :_abs_deriv))
# Third and higher derivatives of abs span from -Inf to Inf, or is zero if 0 is not included in the range
_abs_deriv3(x::ThickNumber{T}) where T = iszero(mig(x)) ? lohi(typeof(x), typemin(T), typemax(T)) : zero(x)
@define_diffrule ThickNumbersForwardDiffExt._abs_deriv2(x) = :($(_abs_deriv3)($x))
@define_diffrule ThickNumbersForwardDiffExt._abs_deriv3(x) = :($(_abs_deriv3)($x))
eval(ForwardDiff.unary_dual_definition(:ThickNumbersForwardDiffExt, :_abs_deriv2))
eval(ForwardDiff.unary_dual_definition(:ThickNumbersForwardDiffExt, :_abs_deriv3))


end
12 changes: 12 additions & 0 deletions src/ThickNumbers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ Interval{$Int}(1, 2)
```
"""
basetype(::Type{TN}) where TN<:ThickNumber = error("basetype not defined for $TN")
basetype(x::ThickNumber) = basetype(typeof(x))

# Optional specializations

Expand Down Expand Up @@ -526,5 +527,16 @@ Base.:(-)(a::TN) where TN<:ThickNumber{T} where T<:Integer = lohi(TN,
Base.Checked.checked_neg(loval(a))
)

# Functions

Base.eps(::Type{TN}) where TN<:ThickNumber{T} where T = (e = eps(T); lohi(TN, e, e))
Base.eps(x::ThickNumber) = lohi(typeof(x), eps(mig(x)), eps(mag(x)))

Base.signbit(x::ThickNumber) = lohi(basetype(typeof(x)), signbit(hival(x)), signbit(loval(x)))

Base.abs(x::ThickNumber) = lohi(typeof(x), mig(x), mag(x))

Base.clamp(x::ThickNumber, lo::Real, hi::Real) = lohi(basetype(x), max(lo, loval(x)), min(hi, hival(x)))


end # module
12 changes: 12 additions & 0 deletions test/extensions/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,16 @@ using IntervalArith
@test df2(b) b
ddf2(x) = ForwardDiff.derivative(df2, x)
@test ddf2(b) 1

# abs
dabs(x) = ForwardDiff.derivative(abs, x)
ddabs(x) = ForwardDiff.derivative(dabs, x)
dddabs(x) = ForwardDiff.derivative(ddabs, x)
@test dabs(Interval(1.0, 2.0)) === Interval(1.0, 1.0)
@test ddabs(Interval(1.0, 2.0)) === Interval(0.0, 0.0)
@test dddabs(Interval(1.0, 2.0)) === Interval(0.0, 0.0)
@test dabs(Interval(-1.0, 2.0)) === Interval(-1.0, 1.0)
@test ddabs(Interval(-1.0, 2.0)) === Interval(0.0, Inf)
abs3 = dddabs(Interval(-1.0, 2.0))
@test abs3 === Interval(-Inf, Inf) || isnan_tn(abs3)
end
7 changes: 7 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ using IntervalArith
@test [Interval(1, 2), Interval(0, 1+eps())] [Interval(1, 2*(1+eps())), Interval(0, 1)]
end

@testset "Specializations" begin
@test loval(signbit(Interval(-1, 1))) === false
@test hival(signbit(Interval(-1, 1))) === true
@test eps(Interval(1.0, 1000.0)) === Interval(eps(1.0), eps(1000.0))
@test clamp(Interval(-1, 4), 2, 3) === Interval(2, 3)
end

include(joinpath("extensions", "runtests.jl"))

cleanup()

0 comments on commit c432354

Please sign in to comment.