Skip to content

Commit

Permalink
Merge pull request #1078 from AayushSabharwal/as/reversediff
Browse files Browse the repository at this point in the history
fix: fix usage of ReverseDiff in parameters
  • Loading branch information
ChrisRackauckas authored Sep 4, 2024
2 parents 13ac2da + 4e60a87 commit 2627f2b
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 1 deletion.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Expand Down Expand Up @@ -95,6 +96,7 @@ Reexport = "1.0"
ReverseDiff = "1"
SciMLBase = "2.28.0"
SciMLOperators = "0.3"
SciMLStructures = "1.5"
Setfield = "1"
SparseArrays = "1.9"
Static = "1"
Expand Down Expand Up @@ -131,4 +133,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "SymbolicIndexingInterface", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"]
test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "SymbolicIndexingInterface", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "ReverseDiff", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"]
6 changes: 6 additions & 0 deletions ext/DiffEqBaseReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ using DiffEqBase
import DiffEqBase: value
import ReverseDiff
import DiffEqBase.ArrayInterface
import DiffEqBase.ForwardDiff

function DiffEqBase.anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {V, D, N, VA, DA, T <: ReverseDiff.TrackedArray{V, D, N, VA, DA}}
DiffEqBase.anyeltypedual(V, Val{counter})
end

DiffEqBase.value(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V
function DiffEqBase.value(x::Type{
Expand Down Expand Up @@ -33,6 +38,7 @@ function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
u0
end
DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0)
DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray{T}, t0) where {T <: ForwardDiff.Dual} = ReverseDiff.track(T.(u0))
DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0)

# Support adaptive with non-tracked time
Expand Down
2 changes: 2 additions & 0 deletions src/DiffEqBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ import SciMLBase: solve, init, step!, solve!, __init, __solve, update_coefficien

import SciMLBase: AbstractDiffEqLinearOperator # deprecation path

import SciMLStructures

import Tricks

using Reexport
Expand Down
12 changes: 12 additions & 0 deletions src/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,12 @@ DiffEqBase.anyeltypedual(f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter}
@inline promote_u0(::Nothing, p, t0) = nothing

@inline function promote_u0(u0, p, t0)
if SciMLStructures.isscimlstructure(p)
_p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]
if _p != p
return promote_u0(u0, _p, t0)
end
end
Tu = eltype(u0)
if Tu <: ForwardDiff.Dual
return u0
Expand All @@ -373,6 +379,12 @@ DiffEqBase.anyeltypedual(f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter}
end

@inline function promote_u0(u0::AbstractArray{<:Complex}, p, t0)
if SciMLStructures.isscimlstructure(p)
_p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]
if _p != p
return promote_u0(u0, _p, t0)
end
end
Tu = real(eltype(u0))
if Tu <: ForwardDiff.Dual
return u0
Expand Down
18 changes: 18 additions & 0 deletions test/forwarddiff_dual_detection.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using DiffEqBase, ForwardDiff, Test, InteractiveUtils
using ReverseDiff, SciMLStructures
using Plots

u0 = 2.0
Expand Down Expand Up @@ -348,3 +349,20 @@ foo = SciMLBase.build_solution(
prob, DiffEqBase.InternalEuler.FwdEulerAlg(), [u0, u0], [0.0, 1.0])
DiffEqBase.anyeltypedual((; x = foo))
DiffEqBase.anyeltypedual((; x = foo, y = prob.f))

@test DiffEqBase.anyeltypedual(ReverseDiff.track(ones(3))) == Any
@test DiffEqBase.anyeltypedual(typeof(ReverseDiff.track(ones(3)))) == Any
@test DiffEqBase.anyeltypedual(ReverseDiff.track(ones(ForwardDiff.Dual, 3))) == eltype(ones(ForwardDiff.Dual, 3))
@test DiffEqBase.anyeltypedual(typeof(ReverseDiff.track(ones(ForwardDiff.Dual, 3)))) == eltype(ones(ForwardDiff.Dual, 3))

struct FakeParameterObject{T}
tunables::T
end

SciMLStructures.isscimlstructure(::FakeParameterObject) = true
SciMLStructures.canonicalize(::SciMLStructures.Tunable, f::FakeParameterObject) = f.tunables, x -> FakeParameterObject(x), true

@test DiffEqBase.promote_u0(ones(3), FakeParameterObject(ReverseDiff.track(ones(3))), 0.0) isa ReverseDiff.TrackedArray
@test DiffEqBase.promote_u0(1.0, FakeParameterObject(ReverseDiff.track(ones(3))), 0.0) isa ReverseDiff.TrackedReal
@test DiffEqBase.promote_u0(ones(3), FakeParameterObject(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa ReverseDiff.TrackedArray{<:ForwardDiff.Dual}
@test DiffEqBase.promote_u0(1.0, FakeParameterObject(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa ReverseDiff.TrackedReal{<:ForwardDiff.Dual}

0 comments on commit 2627f2b

Please sign in to comment.