diff --git a/Project.toml b/Project.toml index 9edcf0afc..5a56f294f 100644 --- a/Project.toml +++ b/Project.toml @@ -41,6 +41,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb" +GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" @@ -55,6 +56,7 @@ DiffEqBaseChainRulesCoreExt = "ChainRulesCore" DiffEqBaseDistributionsExt = "Distributions" DiffEqBaseEnzymeExt = ["ChainRulesCore", "Enzyme"] DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated" +DiffEqBaseGTPSAExt = "GTPSA" DiffEqBaseMPIExt = "MPI" DiffEqBaseMeasurementsExt = "Measurements" DiffEqBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements" @@ -81,6 +83,7 @@ ForwardDiff = "0.10" FunctionWrappers = "1.0" FunctionWrappersWrappers = "0.1" GeneralizedGenerated = "0.3" +GTPSA = "1.3" LinearAlgebra = "1.9" Logging = "1.9" MPI = "0.20" diff --git a/ext/DiffEqBaseGTPSAExt.jl b/ext/DiffEqBaseGTPSAExt.jl new file mode 100644 index 000000000..655b5002e --- /dev/null +++ b/ext/DiffEqBaseGTPSAExt.jl @@ -0,0 +1,17 @@ +module DiffEqBaseGTPSAExt + +if isdefined(Base, :get_extension) + using DiffEqBase + import DiffEqBase: value + using GTPSA +else + using ..DiffEqBase + import ..DiffEqBase: value + using ..GTPSA +end + +value(x::TPS) = scalar(x); +value(::Type{TPS{T}}) where {T} = T + + +end \ No newline at end of file diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 49c49e10d..c4e010943 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -5,6 +5,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" MultiScaleArrays = "f9640e96-87f6-5992-9c3b-0743c6a49ffa" diff --git a/test/downstream/gtpsa.jl b/test/downstream/gtpsa.jl new file mode 100644 index 000000000..90f5b06c3 --- /dev/null +++ b/test/downstream/gtpsa.jl @@ -0,0 +1,39 @@ +using OrdinaryDiffEq, ForwardDiff, GTPSA, Test + +f!(du, u, p, t) = du .= p .* u + +# Initial variables and parameters +x = [1.0, 2.0, 3.0] +p = [4.0, 5.0, 6.0] + +prob = ODEProblem(f!, x, (0.0, 1.0), p) +sol = solve(prob, Tsit5(), reltol=1e-16, abstol=1e-16) + +# Parametric GTPSA map +desc = Descriptor(3, 2, 3, 2) # 3 variables 3 parameters, both to 2nd order +dx = vars(desc) +dp = params(desc) +prob_GTPSA = ODEProblem(f!, x .+ dx, (0.0, 1.0), p .+ dp) +sol_GTPSA = solve(prob_GTPSA, Tsit5(), reltol=1e-16, abstol=1e-16) + +@test sol.u[end] ≈ scalar.(sol_GTPSA.u[end]) # scalar gets 0th order part + +# Compare Jacobian against ForwardDiff +J_FD = ForwardDiff.jacobian([x..., p...]) do t + prob = ODEProblem(f!, t[1:3], (0.0, 1.0), t[4:6]) + sol = solve(prob, Tsit5(), reltol=1e-16, abstol=1e-16) + sol.u[end] +end + +@test J_FD ≈ GTPSA.jacobian(sol_GTPSA.u[end], include_params=true) + +# Compare Hessians against ForwardDiff +for i in 1:3 + Hi_FD = ForwardDiff.hessian([x..., p...]) do t + prob = ODEProblem(f!, t[1:3], (0.0, 1.0), t[4:6]) + sol = solve(prob, Tsit5(), reltol=1e-16, abstol=1e-16) + sol.u[end][i] + end + @test Hi_FD ≈ GTPSA.hessian(sol_GTPSA.u[end][i], include_params=true) +end + diff --git a/test/runtests.jl b/test/runtests.jl index c33fa8170..4cfa3d9a1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -55,6 +55,7 @@ end @time @safetestset "Default linsolve with structure" include("downstream/default_linsolve_structure.jl") @time @safetestset "Callback Merging Tests" include("downstream/callback_merging.jl") @time @safetestset "LabelledArrays Tests" include("downstream/labelledarrays.jl") + @time @safetestset "GTPSA Tests" include("downstream/gtpsa.jl") end if !is_APPVEYOR && GROUP == "Downstream2"