Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch find_first_continuous_callback to use a generated implementation. #920

Merged
merged 1 commit into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 34 additions & 37 deletions src/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,45 +116,42 @@ function get_condition(integrator::DEIntegrator, callback, abst)
end
end

# Use Recursion to find the first callback for type-stability

# Base Case: Only one callback
function find_first_continuous_callback(integrator, callback::AbstractContinuousCallback)
(find_callback_time(integrator, callback, 1)..., 1, 1)
end

# Starting Case: Compute on the first callback
function find_first_continuous_callback(integrator, callback::AbstractContinuousCallback,
args...)
find_first_continuous_callback(integrator,
find_callback_time(integrator, callback, 1)..., 1, 1,
args...)
# Use a generated function for type stability even when many callbacks are given
@inline function find_first_continuous_callback(integrator,
callbacks::Vararg{
AbstractContinuousCallback,
N}) where {N}
find_first_continuous_callback(integrator, tuple(callbacks...))
end

function find_first_continuous_callback(integrator, tmin::Number, upcrossing::Number,
event_occurred::Bool, event_idx::Int, idx::Int,
counter::Int,
callback2)
counter += 1 # counter is idx for callback2.
tmin2, upcrossing2, event_occurred2, event_idx2 = find_callback_time(integrator,
callback2, counter)

if event_occurred2 && (tmin2 < tmin || !event_occurred)
return tmin2, upcrossing2, true, event_idx2, counter, counter
else
return tmin, upcrossing, event_occurred, event_idx, idx, counter
@generated function find_first_continuous_callback(integrator,
callbacks::NTuple{N,
AbstractContinuousCallback
}) where {N}
ex = quote
tmin, upcrossing, event_occurred, event_idx = find_callback_time(integrator,
callbacks[1], 1)
identified_idx = 1
end
end

function find_first_continuous_callback(integrator, tmin::Number, upcrossing::Number,
event_occurred::Bool, event_idx::Int, idx::Int,
counter::Int, callback2, args...)
find_first_continuous_callback(integrator,
find_first_continuous_callback(integrator, tmin,
upcrossing,
event_occurred,
event_idx, idx, counter,
callback2)..., args...)
for i in 2:N
ex = quote
$ex
tmin2, upcrossing2, event_occurred2, event_idx2 = find_callback_time(integrator,
callbacks[$i],
$i)
if event_occurred2 && (tmin2 < tmin || !event_occurred)
tmin = tmin2
upcrossing = upcrossing2
event_occurred = true
event_idx = event_idx2
identified_idx = $i
end
end
end
ex = quote
$ex
return tmin, upcrossing, event_occurred, event_idx, identified_idx, $N
end
ex
end

@inline function determine_event_occurance(integrator, callback::VectorContinuousCallback,
Expand Down
52 changes: 52 additions & 0 deletions test/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,55 @@ cbs5 = CallbackSet(cbs1, cbs2)

@test length(cbs5.discrete_callbacks) == 1
@test length(cbs5.continuous_callbacks) == 2

# For the purposes of this test, create a empty integrator type and
# override find_callback_time, since we don't actually care about testing
# the find callback time aspect, just the inference failure
struct EmptyIntegrator
u::Vector{Float64}
end
function DiffEqBase.find_callback_time(integrator::EmptyIntegrator,
callback::ContinuousCallback, counter)
1.0 + counter, 0.9 + counter, true, counter
end
function DiffEqBase.find_callback_time(integrator::EmptyIntegrator,
callback::VectorContinuousCallback, counter)
1.0 + counter, 0.9 + counter, true, counter
end
find_first_integrator = EmptyIntegrator([1.0, 2.0])
vector_affect! = function (integrator, idx)
integrator.u = integrator.u + idx
end

cond_1(u, t, integrator) = t - 1.0
cond_2(u, t, integrator) = t - 1.1
cond_3(u, t, integrator) = t - 1.2
cond_4(u, t, integrator) = t - 1.3
cond_5(u, t, integrator) = t - 1.4
cond_6(u, t, integrator) = t - 1.5
cond_7(u, t, integrator) = t - 1.6
cond_8(u, t, integrator) = t - 1.7
cond_9(u, t, integrator) = t - 1.8
cond_10(u, t, integrator) = t - 1.9
# Setup a lot of callbacks so the recursive inference failure happens
callbacks = (ContinuousCallback(cond_1, affect!),
ContinuousCallback(cond_2, affect!),
ContinuousCallback(cond_3, affect!),
ContinuousCallback(cond_4, affect!),
ContinuousCallback(cond_5, affect!),
ContinuousCallback(cond_6, affect!),
ContinuousCallback(cond_7, affect!),
ContinuousCallback(cond_8, affect!),
ContinuousCallback(cond_9, affect!),
ContinuousCallback(cond_10, affect!),
VectorContinuousCallback(cond_1, vector_affect!, 2),
VectorContinuousCallback(cond_2, vector_affect!, 2),
VectorContinuousCallback(cond_3, vector_affect!, 2),
VectorContinuousCallback(cond_4, vector_affect!, 2),
VectorContinuousCallback(cond_5, vector_affect!, 2),
VectorContinuousCallback(cond_6, vector_affect!, 2));
function test_find_first_callback(callbacks, int)
@timed(DiffEqBase.find_first_continuous_callback(int, callbacks...))
end
test_find_first_callback(callbacks, find_first_integrator);
@test test_find_first_callback(callbacks, find_first_integrator).bytes == 0