Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add spawn_background #29

Merged
merged 2 commits into from
Sep 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/ThreadPools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export tmap, bmap, qmap, qbmap
export logtmap, logbmap, logqmap, logqbmap
export tforeach, bforeach, qforeach, qbforeach
export logtforeach, logbforeach, logqforeach, logqbforeach
export spawn_background, checked_fetch

include("interface.jl")
include("macros.jl")
Expand All @@ -17,6 +18,7 @@ include("qpool.jl")
include("logstaticpool.jl")
include("logqpool.jl")
include("simplefuncs.jl")
include("spawn_background.jl")


export @pthreads, pwith
Expand Down
61 changes: 61 additions & 0 deletions src/spawn_background.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
const AVAILABLE_THREADS = Base.RefValue{Channel{Int}}()

# Somehow, fetch doesn't do a very good job at preserving
# stacktraces. So, we catch any error in spawn_background
# And return it as a CapturedException, and then use checked_fetch to
# rethrow any exception in that case
function checked_fetch(future)
value = fetch(future)
value isa Exception && throw(value)
return value
end

"""
spawnbg(f)

Spawn work on any available background thread.
Captures any exception thrown in the thread, to give better stacktraces.

You can use `checked_fetch(spawnbg(f))` to rethrow any exception.

** Warning ** this doesn't compose with other ways of scheduling threads
So, one should use `spawn_background` exclusively in each Julia process.
"""
function spawnbg(f)
# -1, because we don't spawn on foreground thread 1
nbackground = Threads.nthreads() - 1
if nbackground == 0
# we don't run in threaded mode, so we just run things async
# to not block forever
@warn("No threads available, running in foreground thread")
return @async try
return f()
catch e
# If we don't do this, we get pretty bad stack traces... not sure why!?
return CapturedException(e, Base.catch_backtrace())
end
end
# Initialize dynamically, could also do this in __init__ but it's nice to keep things in one place
if !isassigned(AVAILABLE_THREADS)
# Allocate a Channel with n background threads
c = Channel{Int}(nbackground)
AVAILABLE_THREADS[] = c
# fill queue with available threads
foreach(i -> put!(c, i + 1), 1:nbackground)
end
# take the next free thread... Will block/wait until a thread becomes free
thread_id = take!(AVAILABLE_THREADS[])

return ThreadPools.@tspawnat thread_id begin
try
return f()
catch e
# If we don't do this, we get pretty bad stack traces...
# not sure why something so basic just doesn't work well \_(ツ)_/¯
return CapturedException(e, Base.catch_backtrace())
finally
# Make thread available again after work is done!
put!(AVAILABLE_THREADS[], thread_id)
end
end
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ThreadPools = "b189fb0b-2eb5-4ed4-bc0c-d34c51242431"
3 changes: 2 additions & 1 deletion test/runtests_exec.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

using Test, ThreadPools
include("teststatic.jl")
include("testlogstatic.jl")
include("testq.jl")
Expand All @@ -14,3 +14,4 @@ include("testmultiarg.jl")
include("testmisc.jl")
include("testplots.jl")
include("errorhandling.jl")
include("spawn_background.jl")
97 changes: 97 additions & 0 deletions test/spawn_background.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
module TestSpawnBackground

using Test
using ThreadPools
using Statistics

# Put in the function to test this in compiled form
# to make sure there is no yield etc introduced from running interpreted
function uses_all_threads()
bg_nthreads = Threads.nthreads() - 1
bg_threads = zeros(bg_nthreads)
futures = map(1:bg_nthreads) do i
return spawnbg() do
id = Threads.threadid()
bg_threads[id - 1] = id
return id
end
end
foreach(wait, futures)
return sum(bg_threads) == sum(2:(bg_nthreads + 1))
end

function busy_wait(time_s)
t = time()
while time() - t < time_s
end
return
end

function count_occurence(list)
occurences = Dict{Int,Int}()
for elem in list
i = get!(occurences, elem, 0)
occurences[elem] = i + 1
end
return occurences
end

function spam_threads(f, spam_factor)
bg_nthreads = Threads.nthreads() - 1
n_executions = bg_nthreads * spam_factor
thread_ids = []
time_spent = @elapsed begin
futures = map(1:n_executions) do i
return spawnbg() do
f()
return Threads.threadid()
end
end
thread_ids = map(fetch, futures)
end
return time_spent, thread_ids
end

function spam_threads_busy(time_waiting, spam_factor)
return spam_threads(spam_factor) do
return busy_wait(time_waiting)
end
end

@testset "threading" begin
nthreads = Threads.nthreads()
bg_nthreads = nthreads - 1
if bg_nthreads == 0
@test fetch(spawnbg(()-> Threads.threadid())) == 1
else
@testset "scheduling" begin
# When we quickly schedule nthreads work items, the implementation should use all threads
@test uses_all_threads()

spam_factor = 5
time_spent, thread_ids = spam_threads(() -> nothing, spam_factor)
occurences = count_occurence(thread_ids)
# We should spread out work to all threads when spamming lots of tasks
@test all(x -> x in keys(occurences), 2:bg_nthreads)
# a few threads may get more work items, but the mean should be equal to the spamfactor
@test spam_factor == mean(values(occurences))

time_spent, thread_ids = spam_threads_busy(0.5, spam_factor)
occurences = count_occurence(thread_ids)
@test all(x -> x in keys(occurences), 2:bg_nthreads)
@test spam_factor == mean(values(occurences))
# I'm not sure how stable this will be on the CI, we may need to tweak the atol
@test time_spent ≈ 0.5 * spam_factor atol = 0.1
end
@testset "Queue contains all threads, after work is done" begin
@test length(unique(ThreadPools.AVAILABLE_THREADS[].data)) == bg_nthreads
end
end

@testset "error handling" begin
@test_throws CapturedException checked_fetch(spawnbg(() -> error("hey")))
end

end

end