Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix initialization of DiscreteProblem(::JumpSystem) #3329

Merged
merged 5 commits into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ NonlinearSolve = "4.3"
OffsetArrays = "1"
OrderedCollections = "1"
OrdinaryDiffEq = "6.82.0"
OrdinaryDiffEqCore = "1.13.0"
OrdinaryDiffEqCore = "1.15.0"
OrdinaryDiffEqDefault = "1.2"
OrdinaryDiffEqNonlinearSolve = "1.3.0"
PrecompileTools = "1"
Expand All @@ -132,7 +132,7 @@ RecursiveArrayTools = "3.26"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SCCNonlinearSolve = "1.0.0"
SciMLBase = "2.68.1"
SciMLBase = "2.71"
SciMLStructures = "1.0"
Serialization = "1"
Setfield = "0.7, 0.8, 1"
Expand Down
5 changes: 3 additions & 2 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -425,14 +425,15 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
error("The passed in JumpSystem contains `Equation`s or continuous events, please use a problem type that supports these features, such as ODEProblem.")
end

_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
_f, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT

observedfun = ObservedFunctionCache(
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

df = DiscreteFunction{true, true}(f; sys = sys, observed = observedfun)
df = DiscreteFunction{true, true}(f; sys = sys, observed = observedfun,
initialization_data = get(_f.kwargs, :initialization_data, nothing))
DiscreteProblem(df, u0, tspan, p; kwargs...)
end

Expand Down
4 changes: 3 additions & 1 deletion src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ function generate_initializesystem(sys::AbstractSystem;
check_units = true, check_defguess = false,
name = nameof(sys), extra_metadata = (;), kwargs...)
eqs = equations(sys)
eqs = filter(x -> x isa Equation, eqs)
if !(eqs isa Vector{Equation})
eqs = Equation[x for x in eqs if x isa Equation]
end
trueobs, eqs = unhack_observed(observed(sys), eqs)
vars = unique([unknowns(sys); getfield.(trueobs, :lhs)])
vars_set = Set(vars) # for efficient in-lookup
Expand Down
28 changes: 27 additions & 1 deletion test/initializationsystem.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using ModelingToolkit, OrdinaryDiffEq, NonlinearSolve, Test
using StochasticDiffEq, DelayDiffEq, StochasticDelayDiffEq
using StochasticDiffEq, DelayDiffEq, StochasticDelayDiffEq, JumpProcesses
using ForwardDiff
using SymbolicIndexingInterface, SciMLStructures
using SciMLStructures: Tunable
Expand Down Expand Up @@ -1306,3 +1306,29 @@ end
@test integ[X] ≈ 4.0
@test integ[Y] ≈ 7.0
end

@testset "Issue#3297: `generate_initializesystem(::JumpSystem)`" begin
@parameters β γ S0
@variables S(t)=S0 I(t) R(t)
rate₁ = β * S * I
affect₁ = [S ~ S - 1, I ~ I + 1]
rate₂ = γ * I
affect₂ = [I ~ I - 1, R ~ R + 1]
j₁ = ConstantRateJump(rate₁, affect₁)
j₂ = ConstantRateJump(rate₂, affect₂)
j₃ = MassActionJump(2 * β + γ, [R => 1], [S => 1, R => -1])
@mtkbuild js = JumpSystem([j₁, j₂, j₃], t, [S, I, R], [β, γ, S0])

u0s = [I => 1, R => 0]
ps = [S0 => 999, β => 0.01, γ => 0.001]
dprob = DiscreteProblem(js, u0s, (0.0, 10.0), ps)
@test dprob.f.initialization_data !== nothing
sol = solve(dprob, FunctionMap())
isaacsas marked this conversation as resolved.
Show resolved Hide resolved
@test sol[S, 1] ≈ 999
@test SciMLBase.successful_retcode(sol)

jprob = JumpProblem(js, dprob)
sol = solve(jprob, SSAStepper())
@test sol[S, 1] ≈ 999
@test SciMLBase.successful_retcode(sol)
end
Loading