diff --git a/docs/src/interface.md b/docs/src/interface.md index 74cf0d0..c0b4afd 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -13,6 +13,7 @@ enable! disable! next reset! +copy! ``` ## Query a Sampler diff --git a/src/prefixsearch/binarytreeprefixsearch.jl b/src/prefixsearch/binarytreeprefixsearch.jl index c45034c..8fcbd4a 100644 --- a/src/prefixsearch/binarytreeprefixsearch.jl +++ b/src/prefixsearch/binarytreeprefixsearch.jl @@ -44,6 +44,14 @@ function Base.empty!(ps::BinaryTreePrefixSearch) ps.cnt = 0 end +function Base.copy!(dst::BinaryTreePrefixSearch{T}, src::BinaryTreePrefixSearch{T}) where {T} + copy!(dst.array, src.array) + dst.depth = src.depth + dst.offset = src.offset + dst.cnt = src.cnt + dst.initial_allocation = src.initial_allocation +end + time_type(ps::BinaryTreePrefixSearch{T}) where {T} = T time_type(::Type{BinaryTreePrefixSearch{T}}) where {T} = T diff --git a/src/prefixsearch/cumsumprefixsearch.jl b/src/prefixsearch/cumsumprefixsearch.jl index de5ad31..3dc9dbb 100644 --- a/src/prefixsearch/cumsumprefixsearch.jl +++ b/src/prefixsearch/cumsumprefixsearch.jl @@ -28,6 +28,10 @@ function Base.empty!(ps::CumSumPrefixSearch) empty!(ps.cumulant) end +function Base.copy!(dst::CumSumPrefixSearch{T}, src::CumSumPrefixSearch{T}) where {T} + copy!(dst.array, src.array) + copy!(dst.cumulant, src.cumulant) +end Base.length(ps::CumSumPrefixSearch) = length(ps.array) time_type(ps::CumSumPrefixSearch{T}) where {T} = T diff --git a/src/prefixsearch/keyedprefixsearch.jl b/src/prefixsearch/keyedprefixsearch.jl index 691d114..2408b65 100644 --- a/src/prefixsearch/keyedprefixsearch.jl +++ b/src/prefixsearch/keyedprefixsearch.jl @@ -28,6 +28,13 @@ function Base.empty!(kp::KeyedKeepPrefixSearch) empty!(kp.prefix) end +function Base.copy!(dst::KeyedKeepPrefixSearch{T,P}, src::KeyedKeepPrefixSearch{T,P}) where {T,P} + copy!(dst.index, src.index) + copy!(dst.key, src.key) + copy!(dst.prefix, src.prefix) + dst +end + Base.length(kp::KeyedKeepPrefixSearch) = length(kp.index) time_type(kp::KeyedKeepPrefixSearch{T,P}) where {T,P} = time_type(P) @@ -93,6 +100,13 @@ function Base.empty!(kp::KeyedRemovalPrefixSearch) empty!(kp.prefix) end +function Base.copy!(dst::KeyedRemovalPrefixSearch{T,P}, src::KeyedRemovalPrefixSearch{T,P}) where {T,P} + copy!(dst.index, src.index) + copy!(dst.key, src.key) + copy!(dst.free, src.free) + copy!(dst.prefix, src.prefix) + dst +end Base.length(kp::KeyedRemovalPrefixSearch) = length(kp.index) diff --git a/src/sample/combinednr.jl b/src/sample/combinednr.jl index 44d0632..825b91c 100644 --- a/src/sample/combinednr.jl +++ b/src/sample/combinednr.jl @@ -132,7 +132,7 @@ sampling_space(::LinearGamma) = LinearSampling If you want to test a distribution, look at `tests/nrmetric.jl` to see how distributions are timed. """ -struct CombinedNextReaction{K,T} <: SSA{K,T} +mutable struct CombinedNextReaction{K,T} <: SSA{K,T} firing_queue::MutableBinaryMinHeap{OrderedSample{K,T}} transition_entry::Dict{K,NRTransition{T}} end @@ -153,6 +153,12 @@ function reset!(nr::CombinedNextReaction) nothing end +function Base.copy!(dst::CombinedNextReaction{K,T}, src::CombinedNextReaction{K,T}) where {K,T} + dst.firing_queue = deepcopy(src.firing_queue) + copy!(dst.transition_entry, src.transition_entry) +end + + @doc raw""" For the first reaction sampler, you can call next() multiple times and get different, valid, answers. That isn't the case here. When you call next() diff --git a/src/sample/direct.jl b/src/sample/direct.jl index c05de41..2326e4d 100644 --- a/src/sample/direct.jl +++ b/src/sample/direct.jl @@ -47,6 +47,8 @@ end reset!(dc::DirectCall) = (empty!(dc.prefix_tree); nothing) +Base.copy!(dst::DirectCall{K,T,P}, src::DirectCall{K,T,P}) where {K,T,P} = copy!(dst.prefix_tree, src.prefix_tree) + """ enable!(dc::DirectCall, clock::T, distribution::Exponential, when, rng) diff --git a/src/sample/firstreaction.jl b/src/sample/firstreaction.jl index 7713bca..e0eea5a 100644 --- a/src/sample/firstreaction.jl +++ b/src/sample/firstreaction.jl @@ -25,6 +25,7 @@ end reset!(fr::FirstReaction) = reset!(fr.core_matrix) +Base.copy!(dst::FirstReaction{K,T}, src::FirstReaction{K,T}) where {K,T} = (copy!(dst.core_matrix, src.core_matrix); dst) function enable!(fr::FirstReaction{K,T}, clock::K, distribution::UnivariateDistribution, diff --git a/src/sample/firsttofire.jl b/src/sample/firsttofire.jl index 166dc70..a5dd055 100644 --- a/src/sample/firsttofire.jl +++ b/src/sample/firsttofire.jl @@ -11,7 +11,7 @@ fire and saves that time in a sorted heap of future times. Then it works through the heap, one by one. When a clock is disabled, its future firing time is removed from the list. There is no memory of previous firing times. """ -struct FirstToFire{K,T} <: SSA{K,T} +mutable struct FirstToFire{K,T} <: SSA{K,T} firing_queue::MutableBinaryMinHeap{OrderedSample{K,T}} # This maps from transition to entry in the firing queue. transition_entry::Dict{K,Int} @@ -31,6 +31,12 @@ function reset!(propagator::FirstToFire{K,T}) where {K,T} empty!(propagator.transition_entry) end +function Base.copy!(dst::FirstToFire{K,T}, src::FirstToFire{K,T}) where {K,T} + dst.firing_queue = deepcopy(src.firing_queue) + copy!(dst.transition_entry, src.transition_entry) + dst +end + # Finds the next one without removing it from the queue. function next(propagator::FirstToFire{K,T}, when::T, rng::AbstractRNG) where {K,T} diff --git a/src/sample/interface.jl b/src/sample/interface.jl index a319df9..2b480ae 100644 --- a/src/sample/interface.jl +++ b/src/sample/interface.jl @@ -39,6 +39,17 @@ for another sample run. function reset!(sampler::SSA{K,T}) where {K,T} end +""" + copy!(destination_sampler, source_sampler) + +This copies the state of the source sampler to the destination sampler, replacing +the current state of the destination sampler. This is useful for splitting +techniques where you make copies of a simulation and restart it with different +random number generators. +""" +function Base.copy!(sampler::SSA{K,T}) where {K,T} end + + """ disable!(sampler, clock, when) diff --git a/src/sample/multiple_direct.jl b/src/sample/multiple_direct.jl index 53b9e35..9b18458 100644 --- a/src/sample/multiple_direct.jl +++ b/src/sample/multiple_direct.jl @@ -34,6 +34,19 @@ function reset!(md::MultipleDirect) end +function Base.copy!( + dst::MultipleDirect{SamplerKey,K,Time,Chooser}, + src::MultipleDirect{SamplerKey,K,Time,Chooser} + ) where {SamplerKey,K,Time,Chooser} + copy!(dst.scan, src.scan) + copy!(dst.totals, src.totals) + dst.chooser = deepcopy(src.chooser) + copy!(dst.chosen, src.chosen) + copy!(dst.scanmap, src.scanmap) + dst +end + + function Base.setindex!( md::MultipleDirect{SamplerKey,K,Time,Chooser}, keyed_prefix_search, key ) where {SamplerKey,K,Time,Chooser} diff --git a/src/sample/sampler.jl b/src/sample/sampler.jl index a86342e..6652353 100644 --- a/src/sample/sampler.jl +++ b/src/sample/sampler.jl @@ -33,6 +33,11 @@ function SingleSampler(propagator::SSA{Key,Time}) where {Key,Time} SingleSampler{SSA{Key,Time},Time}(propagator, zero(Time)) end +function Base.copy!(dst::SingleSampler{Algorithm,Time}, src::SingleSampler{Algorithm,Time}) where {Algorithm,Time} + copy!(dst.propagator, src.propagator) + dst.when = src.when + dst +end function sample!(sampler::SingleSampler, rng::AbstractRNG) when, transition = next(sampler.propagator, sampler.when, rng) @@ -143,6 +148,17 @@ function reset!(sampler::MultiSampler) end +function Base.copy!( + dst::MultiSampler{SamplerKey,Key,Time,Chooser}, + src::MultiSampler{SamplerKey,Key,Time,Chooser} + ) where {SamplerKey,Key,Time,Chooser} + + copy!(dst.propagator, src.propagator) + dst.when = src.when + dst +end + + function Base.setindex!( sampler::MultiSampler{SamplerKey,Key,Time}, algorithm::SSA{Key,Time}, sampler_key::SamplerKey ) where {SamplerKey,Key,Time} diff --git a/src/sample/track.jl b/src/sample/track.jl index 3f6a06b..0cd4db2 100644 --- a/src/sample/track.jl +++ b/src/sample/track.jl @@ -47,6 +47,10 @@ end reset!(ts::TrackWatcher) = (empty!(ts.enabled); nothing) +function Base.copy!(dst::TrackWatcher{K,T}, src::TrackWatcher{K,T}) where {K,T} + copy!(dst.enabled, src.enabled) +end + function Base.iterate(ts::TrackWatcher) return iterate(values(ts.enabled)) end @@ -99,6 +103,10 @@ end reset!(ts::DebugWatcher) = (empty!(ts.enabled); empty!(ts.disabled); nothing) +function Base.copy!(dst::DebugWatcher{K,T}, src::DebugWatcher{K,T}) where {K,T} + copy!(dst.enabled, src.enabled) + copy!(dst.disabled, src.disabled) +end function enable!(ts::DebugWatcher{K,T}, clock::K, dist::UnivariateDistribution, te, when, rng) where {K,T} push!(ts.enabled, EnablingEntry(clock, dist, te, when)) diff --git a/test/runtests.jl b/test/runtests.jl index 81abeea..4af0f6d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,6 +20,9 @@ end include("test_keyedprefixsearch.jl") end +@testset "test_track.jl" begin + include("test_track.jl") +end @testset "test_combinednr.jl" begin include("test_combinednr.jl") diff --git a/test/test_combinednr.jl b/test/test_combinednr.jl index 89d798c..8c28101 100644 --- a/test/test_combinednr.jl +++ b/test/test_combinednr.jl @@ -1,7 +1,7 @@ using SafeTestsets -@safetestset CombinedNextReactionSmoke = "combinednext reaction does basic things" begin +@safetestset CombinedNextReactionSmoke = "CombinedNextReaction reaction does basic things" begin using Distributions using Random using CompetingClocks: CombinedNextReaction, next, enable!, disable!, reset! @@ -49,3 +49,31 @@ end @test sampler[2] == 12.3 end + + +@safetestset CombinedNextReaction_copy = "CombinedNextReaction copy" begin + using CompetingClocks + using Distributions + using Random: Xoshiro + + src = CombinedNextReaction{Int64,Float64}() + dst = clone(src) + rng = Xoshiro(123) + + enable!(src, 37, Exponential(), 0.0, 0.0, rng) + enable!(src, 38, Exponential(), 0.0, 0.0, rng) + enable!(dst, 29, Exponential(), 0.0, 0.0, rng) + @test length(src) == 2 + @test length(dst) == 1 + copy!(dst, src) + @test length(src) == 2 + @test length(dst) == 2 + # Changing src doesn't change dst. + enable!(src, 48, Exponential(), 0.0, 0.0, rng) + @test length(src) == 3 + @test length(dst) == 2 + # Changing dst doesn't change src. + enable!(dst, 49, Exponential(), 0.0, 0.0, rng) + @test length(src) == 3 + @test length(dst) == 3 +end diff --git a/test/test_direct.jl b/test/test_direct.jl index a3a42a0..4e3951c 100644 --- a/test/test_direct.jl +++ b/test/test_direct.jl @@ -99,3 +99,27 @@ end md = DirectCall{Int,Float64}() test_exponential_binomial(md, rng) end + + +@safetestset direct_call_copy = "DirectCall copy" begin + using CompetingClocks: DirectCall, enable!, next + using Random: MersenneTwister + using Distributions: Exponential + + src = DirectCall{Int,Float64}() + dst = DirectCall{Int,Float64}() + rng = MersenneTwister(90422342) + enable!(src, 1, Exponential(), 0.0, 0.0, rng) + enable!(src, 2, Exponential(), 0.0, 0.0, rng) + enable!(dst, 3, Exponential(), 0.0, 0.0, rng) + @test length(src) == 2 + @test length(dst) == 1 + copy!(dst, src) + @test length(dst) == 2 + enable!(src, 5, Exponential(), 0.0, 0.0, rng) + @test length(src) == 3 + @test length(dst) == 2 + enable!(dst, 6, Exponential(), 0.0, 0.0, rng) + @test length(src) == 3 + @test length(dst) == 3 +end diff --git a/test/test_firstreaction.jl b/test/test_firstreaction.jl index b774a0c..4845347 100644 --- a/test/test_firstreaction.jl +++ b/test/test_firstreaction.jl @@ -154,3 +154,28 @@ end ks2_test = ExactOneSampleKSTest(shifted, dist) @test pvalue(ks2_test; tail = :both) > 0.04 end + + + +@safetestset first_reaction_copy = "FirstReaction copy" begin + using CompetingClocks: FirstReaction, enable!, next + using Random: MersenneTwister + using Distributions: Exponential + + src = FirstReaction{Int,Float64}() + dst = FirstReaction{Int,Float64}() + rng = MersenneTwister(90422342) + enable!(src, 1, Exponential(), 0.0, 0.0, rng) + enable!(src, 2, Exponential(), 0.0, 0.0, rng) + enable!(dst, 3, Exponential(), 0.0, 0.0, rng) + @test length(src) == 2 + @test length(dst) == 1 + copy!(dst, src) + @test length(dst) == 2 + enable!(src, 5, Exponential(), 0.0, 0.0, rng) + @test length(src) == 3 + @test length(dst) == 2 + enable!(dst, 6, Exponential(), 0.0, 0.0, rng) + @test length(src) == 3 + @test length(dst) == 3 +end diff --git a/test/test_firsttofire.jl b/test/test_firsttofire.jl index c4bcc6a..9bae373 100644 --- a/test/test_firsttofire.jl +++ b/test/test_firsttofire.jl @@ -81,3 +81,27 @@ end @test propagator[2] == 12.3 end + + +@safetestset FirstToFire_copy = "FirstToFire copy" begin + using CompetingClocks: FirstToFire, enable!, next + using Random: MersenneTwister + using Distributions: Exponential + + src = FirstToFire{Int,Float64}() + dst = FirstToFire{Int,Float64}() + rng = MersenneTwister(90422342) + enable!(src, 1, Exponential(), 0.0, 0.0, rng) + enable!(src, 2, Exponential(), 0.0, 0.0, rng) + enable!(dst, 3, Exponential(), 0.0, 0.0, rng) + @test length(src) == 2 + @test length(dst) == 1 + copy!(dst, src) + @test length(dst) == 2 + enable!(src, 5, Exponential(), 0.0, 0.0, rng) + @test length(src) == 3 + @test length(dst) == 2 + enable!(dst, 6, Exponential(), 0.0, 0.0, rng) + @test length(src) == 3 + @test length(dst) == 3 +end diff --git a/test/test_track.jl b/test/test_track.jl new file mode 100644 index 0000000..9085960 --- /dev/null +++ b/test/test_track.jl @@ -0,0 +1,45 @@ +using SafeTestsets + + +@safetestset track_trackwatcher_smoke = "TrackWatcher smoke" begin + using Distributions + using CompetingClocks + using Random + rng = Xoshiro(3242234) + tw = TrackWatcher{Int,Float64}() + enable!(tw, 3, Exponential(), 0.0, 0.0, rng) + @test length(tw.enabled) == 1 && 3 ∈ keys(tw.enabled) + enable!(tw, 4, Exponential(), 0.0, 3.0, rng) + @test length(tw.enabled) == 2 && 4 ∈ keys(tw.enabled) + enable!(tw, 7, Exponential(), 5.0, 5.0, rng) + @test length(tw.enabled) == 3 && 7 ∈ keys(tw.enabled) + disable!(tw, 4, 9.0) + @test length(tw.enabled) == 2 && 4 ∉ keys(tw.enabled) + + dst = TrackWatcher{Int,Float64}() + enable!(dst, 11, Exponential(), 5.0, 5.0, rng) + copy!(dst, tw) + @test length(tw.enabled) == 2 && 11 ∉ keys(tw.enabled) +end + + +@safetestset track_debugwatcher_smoke = "DebugWatcher smoke" begin + using Distributions + using CompetingClocks + using Random + rng = Xoshiro(3242234) + dw = DebugWatcher{Int,Float64}() + enable!(dw, 3, Exponential(), 0.0, 0.0, rng) + @test dw.enabled[1].clock == 3 + enable!(dw, 4, Exponential(), 0.0, 3.0, rng) + @test dw.enabled[2].clock == 4 + enable!(dw, 7, Exponential(), 5.0, 5.0, rng) + @test dw.enabled[3].clock == 7 + disable!(dw, 4, 9.0) + @test dw.disabled[1].clock == 4 + + dst = DebugWatcher{Int,Float64}() + enable!(dst, 11, Exponential(), 5.0, 5.0, rng) + copy!(dst, dw) + @test length(dw.enabled) == 3 && length(dw.disabled) == 1 +end