-
-
Notifications
You must be signed in to change notification settings - Fork 117
/
Copy pathcallbacks.jl
104 lines (90 loc) · 3.61 KB
/
callbacks.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
using DiffEqBase, Test
# Continuous event condition: the event is located where this returns zero,
# i.e. when `t` crosses 2.95.
condition = function (u, t, integrator)
    t - 2.95
end
# Affect that shifts the state by 2; in this file it is only passed to the
# ContinuousCallback constructor below, never executed.
affect! = function (integrator)
    integrator.u = integrator.u + 2
end
save_positions = (true, true)
callback = ContinuousCallback(condition, affect!; save_positions = save_positions)

# `nothing` entries must be accepted by CallbackSet, and both callback
# collections must still be stored as tuples.
cbs = CallbackSet(nothing)
@test cbs.discrete_callbacks isa Tuple
@test cbs.continuous_callbacks isa Tuple
cbs = CallbackSet(callback, nothing)
@test cbs.discrete_callbacks isa Tuple
@test cbs.continuous_callbacks isa Tuple
# An empty CallbackSet argument must likewise leave tuple-typed fields.
cbs = CallbackSet(callback, CallbackSet())
@test cbs.discrete_callbacks isa Tuple
@test cbs.continuous_callbacks isa Tuple
# Discrete callback conditions are called with `(u, t, integrator)`, the same
# positional arguments as continuous conditions. This one always returns true,
# so the callback triggers whenever it is checked.
condition = function (u, t, integrator)
    true
end
# No-op affect. The binding is reused by the callbacks built for the
# inference test further down in this file.
affect! = function (integrator) end
save_positions = (true, false)
saving_callback = DiscreteCallback(condition, affect!; save_positions = save_positions)

# CallbackSet must sort each argument into the discrete or continuous tuple.
cbs1 = CallbackSet(callback, saving_callback)
@test length(cbs1.discrete_callbacks) == 1
@test length(cbs1.continuous_callbacks) == 1
cbs2 = CallbackSet(callback)
@test length(cbs2.continuous_callbacks) == 1
@test length(cbs2.discrete_callbacks) == 0
cbs3 = CallbackSet(saving_callback)
@test length(cbs3.discrete_callbacks) == 1
@test length(cbs3.continuous_callbacks) == 0
cbs4 = CallbackSet()
@test length(cbs4.discrete_callbacks) == 0
@test length(cbs4.continuous_callbacks) == 0
# Nested CallbackSets are flattened: cbs1 (1 discrete, 1 continuous) plus
# cbs2 (1 continuous) yields 1 discrete and 2 continuous callbacks.
cbs5 = CallbackSet(cbs1, cbs2)
@test length(cbs5.discrete_callbacks) == 1
@test length(cbs5.continuous_callbacks) == 2
"""
    EmptyIntegrator(u)

Minimal stand-in integrator holding only a state vector `u`.

The `find_callback_time` methods defined for it in this file return fixed
values. This lets the test target the inference/allocation behaviour of
`find_first_continuous_callback` without depending on any real event-time
search.
"""
struct EmptyIntegrator
    u::Vector{Float64}
end
# Stub for ContinuousCallback: ignores the integrator and callback and returns
# a deterministic 4-tuple that depends only on `counter`.
function DiffEqBase.find_callback_time(integrator::EmptyIntegrator,
        callback::ContinuousCallback, counter)
    return (1.0 + counter, 0.9 + counter, true, counter)
end
# Stub for VectorContinuousCallback: same fixed 4-tuple as the scalar stub,
# derived only from `counter`.
function DiffEqBase.find_callback_time(integrator::EmptyIntegrator,
        callback::VectorContinuousCallback, counter)
    return (1.0 + counter, 0.9 + counter, true, counter)
end
# Integrator instance handed to `find_first_continuous_callback` in the test.
find_first_integrator = EmptyIntegrator(Float64[1, 2])
# Affect used by the VectorContinuousCallbacks: shifts the state by the event
# index. NOTE(review): with the immutable EmptyIntegrator (and `Vector + Int`)
# this would throw if invoked; presumably the find-first path never calls it.
vector_affect! = (integrator, idx) -> (integrator.u = integrator.u + idx)
# Ten scalar event conditions whose zero crossings sit at t = 1.0, 1.1, ..., 1.9.
# Each is a separate named function (and therefore a distinct type), which keeps
# the callback tuple built from them heterogeneous.
function cond_1(u, t, integrator)
    return t - 1.0
end
function cond_2(u, t, integrator)
    return t - 1.1
end
function cond_3(u, t, integrator)
    return t - 1.2
end
function cond_4(u, t, integrator)
    return t - 1.3
end
function cond_5(u, t, integrator)
    return t - 1.4
end
function cond_6(u, t, integrator)
    return t - 1.5
end
function cond_7(u, t, integrator)
    return t - 1.6
end
function cond_8(u, t, integrator)
    return t - 1.7
end
function cond_9(u, t, integrator)
    return t - 1.8
end
function cond_10(u, t, integrator)
    return t - 1.9
end
# Set up many callbacks (ten ContinuousCallbacks followed by six length-2
# VectorContinuousCallbacks, each over a distinct condition function) so that a
# recursive inference failure in find_first_continuous_callback would surface.
callbacks = (map(c -> ContinuousCallback(c, affect!),
                 (cond_1, cond_2, cond_3, cond_4, cond_5,
                  cond_6, cond_7, cond_8, cond_9, cond_10))...,
             map(c -> VectorContinuousCallback(c, vector_affect!, 2),
                 (cond_1, cond_2, cond_3, cond_4, cond_5, cond_6))...)
"""
    test_find_first_callback(callbacks, int)

Call `DiffEqBase.find_first_continuous_callback(int, callbacks...)` under
`@timed` and return the resulting timing record (value, time, bytes, ...).

The callbacks are taken as an argument rather than read from the global,
presumably so the call is compiled for their concrete types.
"""
function test_find_first_callback(callbacks, int)
    timing = @timed DiffEqBase.find_first_continuous_callback(int, callbacks...)
    return timing
end
# The first call includes compilation, so it is run once and discarded; only
# the second, warmed-up call must allocate nothing.
test_find_first_callback(callbacks, find_first_integrator)
@test iszero(test_find_first_callback(callbacks, find_first_integrator).bytes)