Skip to content

Commit

Permalink
Merge pull request #1107 from mattsignorelli/addgtpsa
Browse files Browse the repository at this point in the history
Add GTPSA extension
  • Loading branch information
ChrisRackauckas authored Jan 3, 2025
2 parents ceec4c0 + c2d5509 commit a9f73bb
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 0 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
Expand All @@ -55,6 +56,7 @@ DiffEqBaseChainRulesCoreExt = "ChainRulesCore"
DiffEqBaseDistributionsExt = "Distributions"
DiffEqBaseEnzymeExt = ["ChainRulesCore", "Enzyme"]
DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated"
DiffEqBaseGTPSAExt = "GTPSA"
DiffEqBaseMPIExt = "MPI"
DiffEqBaseMeasurementsExt = "Measurements"
DiffEqBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements"
Expand All @@ -81,6 +83,7 @@ ForwardDiff = "0.10"
FunctionWrappers = "1.0"
FunctionWrappersWrappers = "0.1"
GeneralizedGenerated = "0.3"
GTPSA = "1.3"
LinearAlgebra = "1.9"
Logging = "1.9"
MPI = "0.20"
Expand Down
17 changes: 17 additions & 0 deletions ext/DiffEqBaseGTPSAExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module DiffEqBaseGTPSAExt

if isdefined(Base, :get_extension)
using DiffEqBase
import DiffEqBase: value
using GTPSA
else
using ..DiffEqBase
import ..DiffEqBase: value
using ..GTPSA
end

value(x::TPS) = scalar(x);
value(::Type{TPS{T}}) where {T} = T


end
1 change: 1 addition & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
MultiScaleArrays = "f9640e96-87f6-5992-9c3b-0743c6a49ffa"
Expand Down
39 changes: 39 additions & 0 deletions test/downstream/gtpsa.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using OrdinaryDiffEq, ForwardDiff, GTPSA, Test

f!(du, u, p, t) = du .= p .* u

# Initial variables and parameters
x = [1.0, 2.0, 3.0]
p = [4.0, 5.0, 6.0]

prob = ODEProblem(f!, x, (0.0, 1.0), p)
sol = solve(prob, Tsit5(), reltol=1e-16, abstol=1e-16)

# Parametric GTPSA map
desc = Descriptor(3, 2, 3, 2) # 3 variables 3 parameters, both to 2nd order
dx = vars(desc)
dp = params(desc)
prob_GTPSA = ODEProblem(f!, x .+ dx, (0.0, 1.0), p .+ dp)
sol_GTPSA = solve(prob_GTPSA, Tsit5(), reltol=1e-16, abstol=1e-16)

@test sol.u[end] scalar.(sol_GTPSA.u[end]) # scalar gets 0th order part

# Compare Jacobian against ForwardDiff
J_FD = ForwardDiff.jacobian([x..., p...]) do t
prob = ODEProblem(f!, t[1:3], (0.0, 1.0), t[4:6])
sol = solve(prob, Tsit5(), reltol=1e-16, abstol=1e-16)
sol.u[end]
end

@test J_FD GTPSA.jacobian(sol_GTPSA.u[end], include_params=true)

# Compare Hessians against ForwardDiff
for i in 1:3
Hi_FD = ForwardDiff.hessian([x..., p...]) do t
prob = ODEProblem(f!, t[1:3], (0.0, 1.0), t[4:6])
sol = solve(prob, Tsit5(), reltol=1e-16, abstol=1e-16)
sol.u[end][i]
end
@test Hi_FD GTPSA.hessian(sol_GTPSA.u[end][i], include_params=true)
end

1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ end
@time @safetestset "Default linsolve with structure" include("downstream/default_linsolve_structure.jl")
@time @safetestset "Callback Merging Tests" include("downstream/callback_merging.jl")
@time @safetestset "LabelledArrays Tests" include("downstream/labelledarrays.jl")
@time @safetestset "GTPSA Tests" include("downstream/gtpsa.jl")
end

if !is_APPVEYOR && GROUP == "Downstream2"
Expand Down

0 comments on commit a9f73bb

Please sign in to comment.