From 677636b5daf2c615531c99232594d030f76fac5d Mon Sep 17 00:00:00 2001 From: Christopher Johnstone Date: Tue, 22 Aug 2023 13:16:51 -0400 Subject: [PATCH] Switch find_first_continuous_callback to use a generated implementation. As mentioned in SciML/DifferentialEquations.jl#971, the current recursive method for identifying the first continuous callback can cause the compiler to give up on type inference, especially when there are many callbacks. The fallback then allocates. This switches this function to using a generated function (along with an inline function that takes splatted tuples). Because this generated function explicitly unrolls the tuple, there are no type inference problems. I added a test that allocates using the old implementation (about 19kb allocations!) but does not with the new system. --- src/callbacks.jl | 71 +++++++++++++++++++++++------------------------ test/callbacks.jl | 52 ++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 37 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index e5a09045c..a4333d2cd 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -116,45 +116,42 @@ function get_condition(integrator::DEIntegrator, callback, abst) end end -# Use Recursion to find the first callback for type-stability - -# Base Case: Only one callback -function find_first_continuous_callback(integrator, callback::AbstractContinuousCallback) - (find_callback_time(integrator, callback, 1)..., 1, 1) -end - -# Starting Case: Compute on the first callback -function find_first_continuous_callback(integrator, callback::AbstractContinuousCallback, - args...) - find_first_continuous_callback(integrator, - find_callback_time(integrator, callback, 1)..., 1, 1, - args...) +# Use a generated function for type stability even when many callbacks are given +@inline function find_first_continuous_callback(integrator, + callbacks::Vararg{ + AbstractContinuousCallback, + N}) where {N} + find_first_continuous_callback(integrator, tuple(callbacks...)) end - -function find_first_continuous_callback(integrator, tmin::Number, upcrossing::Number, - event_occurred::Bool, event_idx::Int, idx::Int, - counter::Int, - callback2) - counter += 1 # counter is idx for callback2. - tmin2, upcrossing2, event_occurred2, event_idx2 = find_callback_time(integrator, - callback2, counter) - - if event_occurred2 && (tmin2 < tmin || !event_occurred) - return tmin2, upcrossing2, true, event_idx2, counter, counter - else - return tmin, upcrossing, event_occurred, event_idx, idx, counter +@generated function find_first_continuous_callback(integrator, + callbacks::NTuple{N, + AbstractContinuousCallback + }) where {N} + ex = quote + tmin, upcrossing, event_occurred, event_idx = find_callback_time(integrator, + callbacks[1], 1) + identified_idx = 1 end -end - -function find_first_continuous_callback(integrator, tmin::Number, upcrossing::Number, - event_occurred::Bool, event_idx::Int, idx::Int, - counter::Int, callback2, args...) - find_first_continuous_callback(integrator, - find_first_continuous_callback(integrator, tmin, - upcrossing, - event_occurred, - event_idx, idx, counter, - callback2)..., args...) + for i in 2:N + ex = quote + $ex + tmin2, upcrossing2, event_occurred2, event_idx2 = find_callback_time(integrator, + callbacks[$i], + $i) + if event_occurred2 && (tmin2 < tmin || !event_occurred) + tmin = tmin2 + upcrossing = upcrossing2 + event_occurred = true + event_idx = event_idx2 + identified_idx = $i + end + end + end + ex = quote + $ex + return tmin, upcrossing, event_occurred, event_idx, identified_idx, $N + end + ex end @inline function determine_event_occurance(integrator, callback::VectorContinuousCallback, diff --git a/test/callbacks.jl b/test/callbacks.jl index 9260e980e..252af55bd 100644 --- a/test/callbacks.jl +++ b/test/callbacks.jl @@ -50,3 +50,55 @@ cbs5 = CallbackSet(cbs1, cbs2) @test length(cbs5.discrete_callbacks) == 1 @test length(cbs5.continuous_callbacks) == 2 + +# For the purposes of this test, create a empty integrator type and +# override find_callback_time, since we don't actually care about testing +# the find callback time aspect, just the inference failure +struct EmptyIntegrator + u::Vector{Float64} +end +function DiffEqBase.find_callback_time(integrator::EmptyIntegrator, + callback::ContinuousCallback, counter) + 1.0 + counter, 0.9 + counter, true, counter +end +function DiffEqBase.find_callback_time(integrator::EmptyIntegrator, + callback::VectorContinuousCallback, counter) + 1.0 + counter, 0.9 + counter, true, counter +end +find_first_integrator = EmptyIntegrator([1.0, 2.0]) +vector_affect! = function (integrator, idx) + integrator.u = integrator.u + idx +end + +cond_1(u, t, integrator) = t - 1.0 +cond_2(u, t, integrator) = t - 1.1 +cond_3(u, t, integrator) = t - 1.2 +cond_4(u, t, integrator) = t - 1.3 +cond_5(u, t, integrator) = t - 1.4 +cond_6(u, t, integrator) = t - 1.5 +cond_7(u, t, integrator) = t - 1.6 +cond_8(u, t, integrator) = t - 1.7 +cond_9(u, t, integrator) = t - 1.8 +cond_10(u, t, integrator) = t - 1.9 +# Setup a lot of callbacks so the recursive inference failure happens +callbacks = (ContinuousCallback(cond_1, affect!), + ContinuousCallback(cond_2, affect!), + ContinuousCallback(cond_3, affect!), + ContinuousCallback(cond_4, affect!), + ContinuousCallback(cond_5, affect!), + ContinuousCallback(cond_6, affect!), + ContinuousCallback(cond_7, affect!), + ContinuousCallback(cond_8, affect!), + ContinuousCallback(cond_9, affect!), + ContinuousCallback(cond_10, affect!), + VectorContinuousCallback(cond_1, vector_affect!, 2), + VectorContinuousCallback(cond_2, vector_affect!, 2), + VectorContinuousCallback(cond_3, vector_affect!, 2), + VectorContinuousCallback(cond_4, vector_affect!, 2), + VectorContinuousCallback(cond_5, vector_affect!, 2), + VectorContinuousCallback(cond_6, vector_affect!, 2)); +function test_find_first_callback(callbacks, int) + @timed(DiffEqBase.find_first_continuous_callback(int, callbacks...)) +end +test_find_first_callback(callbacks, find_first_integrator); +@test test_find_first_callback(callbacks, find_first_integrator).bytes == 0