Skip to content

Commit

Permalink
Add GTPSA integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mattsignorelli committed Jan 2, 2025
1 parent a26018c commit c2d5509
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 0 deletions.
1 change: 1 addition & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
MultiScaleArrays = "f9640e96-87f6-5992-9c3b-0743c6a49ffa"
Expand Down
39 changes: 39 additions & 0 deletions test/downstream/gtpsa.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using OrdinaryDiffEq, ForwardDiff, GTPSA, Test

f!(du, u, p, t) = du .= p .* u

# Initial variables and parameters
x = [1.0, 2.0, 3.0]
p = [4.0, 5.0, 6.0]

prob = ODEProblem(f!, x, (0.0, 1.0), p)
sol = solve(prob, Tsit5(), reltol=1e-16, abstol=1e-16)

# Parametric GTPSA map
desc = Descriptor(3, 2, 3, 2) # 3 variables 3 parameters, both to 2nd order
dx = vars(desc)
dp = params(desc)
prob_GTPSA = ODEProblem(f!, x .+ dx, (0.0, 1.0), p .+ dp)
sol_GTPSA = solve(prob_GTPSA, Tsit5(), reltol=1e-16, abstol=1e-16)

@test sol.u[end] scalar.(sol_GTPSA.u[end]) # scalar gets 0th order part

# Compare Jacobian against ForwardDiff
J_FD = ForwardDiff.jacobian([x..., p...]) do t
prob = ODEProblem(f!, t[1:3], (0.0, 1.0), t[4:6])
sol = solve(prob, Tsit5(), reltol=1e-16, abstol=1e-16)
sol.u[end]
end

@test J_FD GTPSA.jacobian(sol_GTPSA.u[end], include_params=true)

# Compare Hessians against ForwardDiff
for i in 1:3
Hi_FD = ForwardDiff.hessian([x..., p...]) do t
prob = ODEProblem(f!, t[1:3], (0.0, 1.0), t[4:6])
sol = solve(prob, Tsit5(), reltol=1e-16, abstol=1e-16)
sol.u[end][i]
end
@test Hi_FD GTPSA.hessian(sol_GTPSA.u[end][i], include_params=true)
end

1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ end
@time @safetestset "Default linsolve with structure" include("downstream/default_linsolve_structure.jl")
@time @safetestset "Callback Merging Tests" include("downstream/callback_merging.jl")
@time @safetestset "LabelledArrays Tests" include("downstream/labelledarrays.jl")
@time @safetestset "GTPSA Tests" include("downstream/gtpsa.jl")
end

if !is_APPVEYOR && GROUP == "Downstream2"
Expand Down

0 comments on commit c2d5509

Please sign in to comment.