Skip to content

Commit

Permalink
Apply callbacks with a type-stable generated function.
Browse files Browse the repository at this point in the history
Currently, as referenced in SciML/DifferentialEquations.jl#971, the old
implementation of `handle_callbacks!` directly calls. `apply_callback!`
on `continuous_callbacks[idx]`, which is inherently type-unstable
because `apply_callback!` is specialized on the callback type.

This commit adds a generated function `apply_ith_callback!` which
generates type-stable code to do the same thing, where for each callback
tuple type, the generated function unrolls the tuple by checking the
callback index against static indicies. As a nice bonus, this generated
function seems to often be converted into a switch statement at the LLVM
level:
```
   switch i64 %4, label %L46 [
    i64 9, label %L3
    i64 8, label %L8
    i64 7, label %L13
    i64 6, label %L18
    i64 5, label %L23
    i64 4, label %L28
    i64 3, label %L33
    i64 2, label %L38
    i64 1, label %L43
  ]
```

For testing, I added an allocation test which sets up a simple ODE
problem, steps the integrator manually to before the first callback,
then manipulates integrator state past the first callback point. This
way, we can directly call `handle_callbacks!` and write a test on the
allocation count. I confirm that (at least testing against commit
SciML/DiffEqBase.jl@1799fc3, the current master branch tip in
DiffEqBase.jl), the new method does not allocate, whereas the old one
allocates. This may not be the case until a new release is cut of
DiffEqBase.jl, because the old version of
`find_first_continuous_callback` might allocate.
  • Loading branch information
meson800 committed Aug 23, 2023
1 parent 8ef8ce9 commit 9c350f5
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 6 deletions.
39 changes: 33 additions & 6 deletions src/integrators/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,31 @@ function _loopfooter!(integrator)
nothing
end

# Use a generated function to call apply_callback! in a type-stable way
@generated function apply_ith_callback!(integrator,
time, upcrossing, event_idx, cb_idx,
callbacks::NTuple{N,
Union{ContinuousCallback,
VectorContinuousCallback}}) where {N}
ex = quote
throw(BoundsError(callbacks, cb_idx))
end
for i in 1:N
# N.B: doing this as an explicit if (return) else (rest of expression)
# means that LLVM compiles this into a switch.
# This seemingly isn't the case with just if (return) end (rest of expression)
ex = quote
if (cb_idx == $i)
return DiffEqBase.apply_callback!(integrator, callbacks[$i], time,
upcrossing, event_idx)
else
$ex
end
end
end
ex
end

function handle_callbacks!(integrator)
discrete_callbacks = integrator.opts.callback.discrete_callbacks
continuous_callbacks = integrator.opts.callback.continuous_callbacks
Expand All @@ -295,22 +320,23 @@ function handle_callbacks!(integrator)
saved_in_cb = false
if !(typeof(continuous_callbacks) <: Tuple{})
time, upcrossing, event_occurred, event_idx, idx, counter = DiffEqBase.find_first_continuous_callback(integrator,
continuous_callbacks...)
continuous_callbacks...)
if event_occurred
integrator.event_last_time = idx
integrator.vector_event_last_time = event_idx
continuous_modified, saved_in_cb = DiffEqBase.apply_callback!(integrator,
continuous_callbacks[idx],
time, upcrossing,
event_idx)
continuous_modified, saved_in_cb = apply_ith_callback!(integrator,
time, upcrossing,
event_idx,
idx,
continuous_callbacks)
else
integrator.event_last_time = 0
integrator.vector_event_last_time = 1
end
end
if !integrator.force_stepfail && !(typeof(discrete_callbacks) <: Tuple{})
discrete_modified, saved_in_cb = DiffEqBase.apply_discrete_callback!(integrator,
discrete_callbacks...)
discrete_callbacks...)
end
if !saved_in_cb
savevalues!(integrator)
Expand All @@ -321,6 +347,7 @@ function handle_callbacks!(integrator)
integrator.do_error_check = false
handle_callback_modifiers!(integrator)
end
nothing
end

function handle_callback_modifiers!(integrator::ODEIntegrator)
Expand Down
45 changes: 45 additions & 0 deletions test/integrators/callback_allocation_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using OrdinaryDiffEq, Test

# Setup a simple ODE problem with several callbacks (to test LLVM code gen)
# We will manually trigger the first callback and check its allocations.
function f!(du, u, p, t)
du .= .-u
end

cond_1(u, t, integrator) = u[1] - 0.5
cond_2(u, t, integrator) = u[2] + 0.5
cond_3(u, t, integrator) = u[2] + 0.6
cond_4(u, t, integrator) = u[2] + 0.7
cond_5(u, t, integrator) = u[2] + 0.8
cond_6(u, t, integrator) = u[2] + 0.9
cond_7(u, t, integrator) = u[2] + 0.1
cond_8(u, t, integrator) = u[2] + 0.11
cond_9(u, t, integrator) = u[2] + 0.12

function cb_affect!(integrator)
integrator.p[1] += 1
end

cbs = CallbackSet(ContinuousCallback(cond_1, cb_affect!),
ContinuousCallback(cond_2, cb_affect!),
ContinuousCallback(cond_3, cb_affect!),
ContinuousCallback(cond_4, cb_affect!),
ContinuousCallback(cond_5, cb_affect!),
ContinuousCallback(cond_6, cb_affect!),
ContinuousCallback(cond_7, cb_affect!),
ContinuousCallback(cond_8, cb_affect!),
ContinuousCallback(cond_9, cb_affect!))

integrator = init(ODEProblem(f!, [0.8, 1.0], (0.0, 100.0), [0, 0]), Tsit5(), callback = cbs,
save_on = false);
# Force a callback event to occur so we can call handle_callbacks! directly.
# Step to a point where u[1] is still > 0.5, so we can force it below 0.5 and
# call handle callbacks
step!(integrator, 0.1, true)

function handle_allocs(integrator)
integrator.u[1] = 0.4
@allocations OrdinaryDiffEq.handle_callbacks!(integrator)
end
handle_allocs(integrator);
@test handle_allocs(integrator) == 0
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ end
@time @safetestset "Events Tests" include("integrators/ode_event_tests.jl")
@time @safetestset "Alg Events Tests" include("integrators/alg_events_tests.jl")
@time @safetestset "Discrete Callback Dual Tests" include("integrators/discrete_callback_dual_test.jl")
@time @safetestset "Callback Allocation Tests" include("integrators/callback_allocation_tests.jl")
@time @safetestset "Iterator Tests" include("integrators/iterator_tests.jl")
@time @safetestset "Integrator Interface Tests" include("integrators/integrator_interface_tests.jl")
@time @safetestset "Error Check Tests" include("integrators/check_error.jl")
Expand Down

0 comments on commit 9c350f5

Please sign in to comment.