Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support derivs of abs; add eps and clamp #12

Merged
merged 3 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion ext/ThickNumbersForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
module ThickNumbersForwardDiffExt

using ThickNumbers
using ForwardDiff: ForwardDiff, Dual, Partials, Tag
using ForwardDiff: ForwardDiff, Dual, Partials, Tag, value, partials, npartials, tagtype
using ForwardDiff.DiffRules: DiffRules, @define_diffrule
# Supports up to third-order derivatives
const ThickDual = Union{Dual{T,<:ThickNumber} where T,
Dual{T1,<:Dual{T2,<:ThickNumber}} where {T1,T2},
Dual{T1,<:Dual{T2,<:Dual{T3,<:ThickNumber}}} where {T1,T2,T3},
}
const ThickLike = Union{ThickNumber, ThickDual}

function ForwardDiff.derivative(f::F, x::TN) where {F,TN<:ThickNumber}
T = typeof(Tag(f, TN))
Expand All @@ -27,4 +34,24 @@ Base.promote_rule(::Type{TN}, ::Type{Dual{T,V,N}}) where {TN<:ThickNumber,T,V<:N
promote_dual(::Type{TN}, ::Type{V}) where {TN<:ThickNumber,V} = promote_type(TN, V)
promote_dual(::Type{TN}, ::Type{Dual{T,V,N}}) where {TN<:ThickNumber,T,V,N} = Dual{T, promote_dual(TN, V), N}

### Special functions

## First and higher-order derivatives of `abs`
function DiffRules._abs_deriv(x::ThickNumber)
sb = signbit(x)
lv, hv = loval(x), hival(x)
return lohi(basetype(typeof(x)), true ∈ sb ? -one(lv) : one(lv), false ∈ sb ? one(hv) : -one(hv))
end
# Second derivative of abs spans from 0 to either 0 or Inf (if 0 is included in the range)
_abs_deriv2(x::ThickNumber) = iszero(mig(x)) ? lohi(typeof(x), 0, typemax(valuetype(x))) : zero(x)
@define_diffrule DiffRules._abs_deriv(x) = :($(_abs_deriv2)($x))
eval(ForwardDiff.unary_dual_definition(:DiffRules, :_abs_deriv))
# Third and higher derivatives of abs span from -Inf to Inf, or is zero if 0 is not included in the range
_abs_deriv3(x::ThickNumber{T}) where T = iszero(mig(x)) ? lohi(typeof(x), typemin(T), typemax(T)) : zero(x)
@define_diffrule ThickNumbersForwardDiffExt._abs_deriv2(x) = :($(_abs_deriv3)($x))
@define_diffrule ThickNumbersForwardDiffExt._abs_deriv3(x) = :($(_abs_deriv3)($x))
eval(ForwardDiff.unary_dual_definition(:ThickNumbersForwardDiffExt, :_abs_deriv2))
eval(ForwardDiff.unary_dual_definition(:ThickNumbersForwardDiffExt, :_abs_deriv3))


end
12 changes: 12 additions & 0 deletions src/ThickNumbers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ Interval{$Int}(1, 2)
```
"""
basetype(::Type{TN}) where TN<:ThickNumber = error("basetype not defined for $TN")
basetype(x::ThickNumber) = basetype(typeof(x))

# Optional specializations

Expand Down Expand Up @@ -526,5 +527,16 @@ Base.:(-)(a::TN) where TN<:ThickNumber{T} where T<:Integer = lohi(TN,
Base.Checked.checked_neg(loval(a))
)

# Functions

Base.eps(::Type{TN}) where TN<:ThickNumber{T} where T = (e = eps(T); lohi(TN, e, e))
Base.eps(x::ThickNumber) = lohi(typeof(x), eps(mig(x)), eps(mag(x)))

Base.signbit(x::ThickNumber) = lohi(basetype(typeof(x)), signbit(hival(x)), signbit(loval(x)))

Base.abs(x::ThickNumber) = lohi(typeof(x), mig(x), mag(x))

Base.clamp(x::ThickNumber, lo::Real, hi::Real) = lohi(basetype(x), max(lo, loval(x)), min(hi, hival(x)))


end # module
12 changes: 12 additions & 0 deletions test/extensions/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,16 @@ using IntervalArith
@test df2(b) ⩪ b
ddf2(x) = ForwardDiff.derivative(df2, x)
@test ddf2(b) ≐ 1

# abs
dabs(x) = ForwardDiff.derivative(abs, x)
ddabs(x) = ForwardDiff.derivative(dabs, x)
dddabs(x) = ForwardDiff.derivative(ddabs, x)
@test dabs(Interval(1.0, 2.0)) === Interval(1.0, 1.0)
@test ddabs(Interval(1.0, 2.0)) === Interval(0.0, 0.0)
@test dddabs(Interval(1.0, 2.0)) === Interval(0.0, 0.0)
@test dabs(Interval(-1.0, 2.0)) === Interval(-1.0, 1.0)
@test ddabs(Interval(-1.0, 2.0)) === Interval(0.0, Inf)
abs3 = dddabs(Interval(-1.0, 2.0))
@test abs3 === Interval(-Inf, Inf) || isnan_tn(abs3)
end
7 changes: 7 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ using IntervalArith
@test [Interval(1, 2), Interval(0, 1+eps())] ⩪ [Interval(1, 2*(1+eps())), Interval(0, 1)]
end

@testset "Specializations" begin
@test loval(signbit(Interval(-1, 1))) === false
@test hival(signbit(Interval(-1, 1))) === true
@test eps(Interval(1.0, 1000.0)) === Interval(eps(1.0), eps(1000.0))
@test clamp(Interval(-1, 4), 2, 3) === Interval(2, 3)
end

include(joinpath("extensions", "runtests.jl"))

cleanup()