Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: always build initialization problem #3347

Open
wants to merge 36 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
a16368a
fix: handle case when parameter object is `nothing` in `GetUpdatedMTK…
AayushSabharwal Jan 21, 2025
7c3cda3
feat: update `u0map` and `pmap` in `build_operating_point`
AayushSabharwal Jan 21, 2025
e0288b6
feat: add dummy initialization parameters instead of hard constraints
AayushSabharwal Jan 21, 2025
da867ec
feat: always build initialization problem, handle dummy parameters
AayushSabharwal Jan 21, 2025
3156e81
refactor: allow disabling `gensym` in `similar_variable`
AayushSabharwal Jan 21, 2025
fd376e4
feat: count shift equations as differential equations
AayushSabharwal Jan 21, 2025
41a2400
test: provide missing parameters to `NonlinearProblem` test
AayushSabharwal Jan 21, 2025
88d85a7
feat: add `force_time_independent` keyword to `InitializationProblem`
AayushSabharwal Jan 21, 2025
17e9fe0
feat: expose `force_time_independent` through `process_SciMLProblem`
AayushSabharwal Jan 21, 2025
4dfd9d3
fix: remove delayed terms in initialization system of DDEs
AayushSabharwal Jan 21, 2025
c443665
fix: only add dummy parameters for unknowns of time-dependent systems
AayushSabharwal Jan 21, 2025
fb845d0
fix: change differential operator when building initialization for `D…
AayushSabharwal Jan 21, 2025
0b130e9
fix: handle recursive default edge case in initializesystem
AayushSabharwal Jan 21, 2025
496410b
fix: remove occurrences of dummy parameters when remaking initializeprob
AayushSabharwal Jan 21, 2025
aa922b2
fix: only set solvable parameter values to zero in limited cases
AayushSabharwal Jan 21, 2025
c9a1f01
fix: build initialization for `DiscreteProblem(::DiscreteSystem)`
AayushSabharwal Jan 21, 2025
4244e30
fix: ensure initialization systems for `SteadyStateProblem` do not in…
AayushSabharwal Jan 21, 2025
55cfe18
test: disambiguate `observed` in initializesystem tests
AayushSabharwal Jan 21, 2025
5b73c33
test: remove unnecessary test
AayushSabharwal Jan 21, 2025
e72fd2b
test: refactor `initialization_data === nothing` tests
AayushSabharwal Jan 21, 2025
cdde9aa
test: update initializeprob type promotion test
AayushSabharwal Jan 21, 2025
52c3a53
fix: don't build initialization for `JumpProblem`
AayushSabharwal Jan 23, 2025
83a8e1c
fix: don't use observed equations for initialization of `DiscreteSystem`
AayushSabharwal Jan 23, 2025
25339a6
build: bump SciMLBase version
AayushSabharwal Jan 23, 2025
7ebeb78
fix: populate observed into `u0map` for time-independent systems
AayushSabharwal Jan 23, 2025
2844170
fix: remove incorrect fix to dummy derivatives in `u0map` added in #3337
AayushSabharwal Jan 24, 2025
cf67f77
fix: fix `_eq_unordered` for multiple identical values in both `a` an…
AayushSabharwal Jan 24, 2025
456c056
fix: do not build initialization for `DiscreteProblem`
AayushSabharwal Jan 24, 2025
779864c
test: make `NonlinearSystem` initialization tests less temperamental
AayushSabharwal Jan 24, 2025
e1e53f5
test: mark initialization of `JumpSystem` as broken
AayushSabharwal Jan 24, 2025
80d1fa0
fix: fix `ReconstructInitializeprob` overriding guesses
AayushSabharwal Jan 24, 2025
fc9b773
feat: propagate `algebraic_only` kwarg in `InitializationProblem`
AayushSabharwal Jan 20, 2025
8a16286
fix: filter `nothing` values in `op`/`u0map`/`pmap` in `process_SciML…
AayushSabharwal Jan 20, 2025
f9704e6
feat: propagate `algebraic_only`, `allow_incomplete` kwargs to initia…
AayushSabharwal Jan 20, 2025
10784ba
fix: use `SciMLBase.get_initial_values` in `linearization_function`
AayushSabharwal Jan 20, 2025
93efab1
fix: do not forward keywords to `linearize` in `AnalysisPoint` linear…
AayushSabharwal Jan 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ RecursiveArrayTools = "3.26"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SCCNonlinearSolve = "1.0.0"
SciMLBase = "2.71.1"
SciMLBase = "2.72"
SciMLStructures = "1.0"
Serialization = "1"
Setfield = "0.7, 0.8, 1"
Expand Down
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ using Compat
using AbstractTrees
using DiffEqBase, SciMLBase, ForwardDiff
using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap, TimeDomain,
PeriodicClock, Clock, SolverStepClock, Continuous
PeriodicClock, Clock, SolverStepClock, Continuous, OverrideInit, NoInit
using Distributed
import JuliaFormatter
using MLStyle
Expand Down
141 changes: 60 additions & 81 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2377,6 +2377,9 @@ See also [`linearize`](@ref) which provides a higher-level interface.
function linearization_function(sys::AbstractSystem, inputs,
outputs; simplify = false,
initialize = true,
initializealg = nothing,
initialization_abstol = 1e-6,
initialization_reltol = 1e-3,
op = Dict(),
p = DiffEqBase.NullParameters(),
zero_dummy_der = false,
Expand All @@ -2403,88 +2406,29 @@ function linearization_function(sys::AbstractSystem, inputs,
op = merge(defs, op)
end
sys = ssys
u0map = Dict(k => v for (k, v) in op if is_variable(ssys, k))
initsys = structural_simplify(
generate_initializesystem(
sys, u0map = u0map, guesses = guesses(sys), algebraic_only = true),
fully_determined = false)

# HACK: some unknowns may not be involved in any initialization equations, and are
# thus removed from the system during `structural_simplify`.
# This causes `getu(initsys, unknowns(sys))` to fail, so we add them back as parameters
# for now.
missing_unknowns = setdiff(unknowns(sys), all_symbols(initsys))
if !isempty(missing_unknowns)
if warn_initialize_determined
@warn "Initialization system is underdetermined. No equations for $(missing_unknowns). Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false."
end
new_parameters = [parameters(initsys); missing_unknowns]
@set! initsys.ps = new_parameters
initsys = complete(initsys)
end

if p isa SciMLBase.NullParameters
p = Dict()
else
p = todict(p)
end
x0 = merge(defaults_and_guesses(sys), op)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
sys_ps = MTKParameters(sys, p, x0)
else
sys_ps = varmap_to_vars(p, parameters(sys); defaults = x0)
end
p[get_iv(sys)] = NaN
if has_index_cache(initsys) && get_index_cache(initsys) !== nothing
oldps = MTKParameters(initsys, p, merge(guesses(sys), defaults(sys), op))
initsys_ps = parameters(initsys)
p_getter = build_explicit_observed_function(
sys, initsys_ps; eval_expression, eval_module)

u_getter = isempty(unknowns(initsys)) ? (_...) -> nothing :
build_explicit_observed_function(
sys, unknowns(initsys); eval_expression, eval_module)
get_initprob_u_p = let p_getter = p_getter,
p_setter! = setp(initsys, initsys_ps),
u_getter = u_getter

function (u, p, t)
p_setter!(oldps, p_getter(u, p, t))
newu = u_getter(u, p, t)
return newu, oldps
end
end
else
get_initprob_u_p = let p_getter = getu(sys, parameters(initsys)),
u_getter = build_explicit_observed_function(
sys, unknowns(initsys); eval_expression, eval_module)

function (u, p, t)
state = ProblemState(; u, p, t)
return u_getter(
state_values(state), parameter_values(state), current_time(state)),
p_getter(state)
end
end

if initializealg === nothing
initializealg = initialize ? OverrideInit() : NoInit()
end
initfn = NonlinearFunction(initsys; eval_expression, eval_module)
initprobmap = build_explicit_observed_function(
initsys, unknowns(sys); eval_expression, eval_module)

fun, u0, p = process_SciMLProblem(ODEFunction{true, SciMLBase.FullSpecialize}, sys, op, p; t = 0.0, build_initializeprob = initializealg isa OverrideInit, allow_incomplete = true, algebraic_only = true)
prob = ODEProblem(fun, u0, (nothing, nothing), p)

ps = parameters(sys)
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
lin_fun = let diff_idxs = diff_idxs,
alge_idxs = alge_idxs,
input_idxs = input_idxs,
sts = unknowns(sys),
get_initprob_u_p = get_initprob_u_p,
fun = ODEFunction{true, SciMLBase.FullSpecialize}(
sys, unknowns(sys), ps; eval_expression, eval_module),
initfn = initfn,
initprobmap = initprobmap,
fun = fun,
prob = prob,
sys_ps = p,
h = h,
integ_cache = (similar(u0)),
chunk = ForwardDiff.Chunk(input_idxs),
sys_ps = sys_ps,
initialize = initialize,
initializealg = initializealg,
initialization_abstol = initialization_abstol,
initialization_reltol = initialization_reltol,
initialization_solver_alg = initialization_solver_alg,
sys = sys

Expand All @@ -2504,14 +2448,11 @@ function linearization_function(sys::AbstractSystem, inputs,
if u !== nothing # Handle systems without unknowns
length(sts) == length(u) ||
error("Number of unknown variables ($(length(sts))) does not match the number of input unknowns ($(length(u)))")
if initialize && !isempty(alge_idxs) # This is expensive and can be omitted if the user knows that the system is already initialized
residual = fun(u, p, t)
if norm(residual[alge_idxs]) > √(eps(eltype(residual)))
initu0, initp = get_initprob_u_p(u, p, t)
initprob = NonlinearLeastSquaresProblem(initfn, initu0, initp)
nlsol = solve(initprob, initialization_solver_alg)
u = initprobmap(state_values(nlsol), parameter_values(nlsol))
end

integ = MockIntegrator{true}(u, p, t, integ_cache)
u, p, success = SciMLBase.get_initial_values(prob, integ, fun, initializealg, Val(true); abstol = initialization_abstol, reltol = initialization_reltol, nlsolve_alg = initialization_solver_alg)
if !success
error("Initialization algorithm $(initializealg) failed with `u = $u` and `p = $p`.")
end
uf = SciMLBase.UJacobianWrapper(fun, t, p)
fg_xz = ForwardDiff.jacobian(uf, u)
Expand Down Expand Up @@ -2546,6 +2487,44 @@ function linearization_function(sys::AbstractSystem, inputs,
return lin_fun, sys
end

"""
$(TYPEDEF)

Mock `DEIntegrator` to allow using `CheckInit` without having to create a new integrator
(and consequently depend on `OrdinaryDiffEq`).

# Fields

$(TYPEDFIELDS)
"""
struct MockIntegrator{iip, U, P, T, C} <: SciMLBase.DEIntegrator{Nothing, iip, U, T}
"""
The state vector.
"""
u::U
"""
The parameter object.
"""
p::P
"""
The current time.
"""
t::T
"""
The integrator cache.
"""
cache::C
end

function MockIntegrator{iip}(u::U, p::P, t::T, cache::C) where {iip, U, P, T, C}
return MockIntegrator{iip, U, P, T, C}(u, p, t, cache)
end

SymbolicIndexingInterface.state_values(integ::MockIntegrator) = integ.u
SymbolicIndexingInterface.parameter_values(integ::MockIntegrator) = integ.p
SymbolicIndexingInterface.current_time(integ::MockIntegrator) = integ.t
SciMLBase.get_tmp_cache(integ::MockIntegrator) = integ.cache

"""
(; A, B, C, D), simplified_sys = linearize_symbolic(sys::AbstractSystem, inputs, outputs; simplify = false, allow_input_derivatives = false, kwargs...)

Expand Down
2 changes: 1 addition & 1 deletion src/systems/analysis_points.jl
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@
ap_idx = analysis_point_index(ap_sys, tf.ap)
ap_idx === nothing &&
error("Analysis point $(nameof(tf.ap)) not found in system $(nameof(sys)).")
# get the anlysis point

Check warning on line 513 in src/systems/analysis_points.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"anlysis" should be "analysis".
ap_sys_eqs = copy(get_eqs(ap_sys))
ap = ap_sys_eqs[ap_idx].rhs

Expand Down Expand Up @@ -564,7 +564,7 @@
ap_idx = analysis_point_index(ap_sys, tf.ap)
ap_idx === nothing &&
error("Analysis point $(nameof(tf.ap)) not found in system $(nameof(sys)).")
# modified quations

Check warning on line 567 in src/systems/analysis_points.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"quations" should be "equations".
ap_sys_eqs = copy(get_eqs(ap_sys))
@set! ap_sys.eqs = ap_sys_eqs
ap = ap_sys_eqs[ap_idx].rhs
Expand Down Expand Up @@ -863,7 +863,7 @@
sys, ap, args...; loop_openings = [], system_modifier = identity, kwargs...)
lin_fun, ssys = $(utility_fun)(
sys, ap, args...; loop_openings, system_modifier, kwargs...)
ModelingToolkit.linearize(ssys, lin_fun; kwargs...), ssys
ModelingToolkit.linearize(ssys, lin_fun), ssys
end
end

Expand All @@ -876,7 +876,7 @@
# Keyword Arguments

- `system_modifier`: a function which takes the modified system and returns a new system
with any required further modifications peformed.

Check warning on line 879 in src/systems/analysis_points.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"peformed" should be "performed".
"""
function open_loop(sys, ap::Union{Symbol, AnalysisPoint}; system_modifier = identity)
ap = only(canonicalize_ap(sys, ap))
Expand Down
44 changes: 22 additions & 22 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1120,7 +1120,7 @@ function DiffEqBase.SteadyStateProblem{iip}(sys::AbstractODESystem, u0map,
end
f, u0, p = process_SciMLProblem(ODEFunction{iip}, sys, u0map, parammap;
steady_state = true,
check_length, kwargs...)
check_length, force_initialization_time_independent = true, kwargs...)
kwargs = filter_kwargs(kwargs)
SteadyStateProblem{iip}(f, u0, p; kwargs...)
end
Expand Down Expand Up @@ -1295,27 +1295,41 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
check_units = true,
use_scc = true,
allow_incomplete = false,
force_time_independent = false,
algebraic_only = false,
kwargs...) where {iip, specialize}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
end
if isempty(u0map) && get_initializesystem(sys) !== nothing
isys = get_initializesystem(sys; initialization_eqs, check_units)
simplify_system = false
elseif isempty(u0map) && get_initializesystem(sys) === nothing
isys = structural_simplify(
generate_initializesystem(
isys = generate_initializesystem(
sys; initialization_eqs, check_units, pmap = parammap,
guesses, extra_metadata = (; use_scc)); fully_determined)
guesses, extra_metadata = (; use_scc), algebraic_only)
simplify_system = true
else
isys = structural_simplify(
generate_initializesystem(
isys = generate_initializesystem(
sys; u0map, initialization_eqs, check_units,
pmap = parammap, guesses, extra_metadata = (; use_scc)); fully_determined)
pmap = parammap, guesses, extra_metadata = (; use_scc), algebraic_only)
simplify_system = true
end

# useful for `SteadyStateProblem` since `f` has to be autonomous and the
# initialization should be too
if force_time_independent
idx = findfirst(isequal(get_iv(sys)), get_ps(isys))
idx === nothing || deleteat!(get_ps(isys), idx)
end

if simplify_system
isys = structural_simplify(isys; fully_determined)
end

meta = get_metadata(isys)
if meta isa InitializationSystemMetadata
@set! isys.metadata.oop_reconstruct_u0_p = ReconstructInitializeprob(sys, isys)
@set! isys.metadata.oop_reconstruct_u0_p = ReconstructInitializeprob(sys, isys; remap = meta.new_params)
end

ts = get_tearing_state(isys)
Expand Down Expand Up @@ -1374,20 +1388,6 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,

u0map = merge(ModelingToolkit.guesses(sys), todict(guesses), todict(u0map))

# Replace dummy derivatives in u0map: D(x) -> x_t etc.
if has_schedule(sys)
schedule = get_schedule(sys)
if !isnothing(schedule)
for (var, val) in u0map
dvar = get(schedule.dummy_sub, var, var) # with dummy derivatives
if dvar !== var # then replace it
delete!(u0map, var)
push!(u0map, dvar => val)
end
end
end
end

fullmap = merge(u0map, parammap)
u0T = Union{}
for sym in unknowns(isys)
Expand Down
6 changes: 5 additions & 1 deletion src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -687,13 +687,17 @@ function populate_delays(delays::Set, obsexprs, histfn, sys, sym)
end

function _eq_unordered(a, b)
a = vec(a)
b = vec(b)
length(a) === length(b) || return false
n = length(a)
idxs = Set(1:n)
for x in a
idx = findfirst(isequal(x), b)
while idx !== nothing && !(idx in idxs)
idx = findnext(isequal(x), b, idx + 1)
end
idx === nothing && return false
idx ∈ idxs || return false
delete!(idxs, idx)
end
return true
Expand Down
6 changes: 4 additions & 2 deletions src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ function SciMLBase.DiscreteProblem(
u0map = to_varmap(u0map, dvs)
u0map = shift_u0map_forward(sys, u0map, defaults(sys))
f, u0, p = process_SciMLProblem(
DiscreteFunction, sys, u0map, parammap; eval_expression, eval_module)
DiscreteFunction, sys, u0map, parammap; eval_expression, eval_module, build_initializeprob = false)
u0 = f(u0, p, tspan[1])
DiscreteProblem(f, u0, tspan, p; kwargs...)
end
Expand All @@ -336,6 +336,7 @@ function SciMLBase.DiscreteFunction{iip, specialize}(
eval_expression = false,
eval_module = @__MODULE__,
analytic = nothing,
initialization_data = nothing,
kwargs...) where {iip, specialize}
if !iscomplete(sys)
error("A completed `DiscreteSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
Expand All @@ -359,7 +360,8 @@ function SciMLBase.DiscreteFunction{iip, specialize}(
DiscreteFunction{iip, specialize}(f;
sys = sys,
observed = observedfun,
analytic = analytic)
analytic = analytic,
initialization_data = initialization_data)
end

"""
Expand Down
4 changes: 2 additions & 2 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
end

_f, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false, build_initializeprob = false)
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT

observedfun = ObservedFunctionCache(
Expand Down Expand Up @@ -523,7 +523,7 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
parameter_dependencies = parameter_dependencies(sys),
metadata = get_metadata(sys), gui_metadata = get_gui_metadata(sys))
osys = complete(osys)
return ODEProblem(osys, u0map, tspan, parammap; check_length = false, kwargs...)
return ODEProblem(osys, u0map, tspan, parammap; check_length = false, build_initializeprob = false, kwargs...)
else
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false,
Expand Down
Loading
Loading