From 1623a8e59144ae8ff2eaaab728ef3234b3978b82 Mon Sep 17 00:00:00 2001 From: Trey Roessig <roessig.trey@gmail.com> Date: Sun, 26 Jul 2020 09:07:54 -0700 Subject: [PATCH] Closes #17 - adding v1.4+ @spawn interpolation --- src/interface.jl | 3 --- src/macros.jl | 49 +++++++++++++++++++++++++++++++++------------ test/runtests.jl | 2 +- test/testspawnat.jl | 31 +++++++++++++++++++++++++++- 4 files changed, 67 insertions(+), 18 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 36c258b..fea5ff1 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -8,10 +8,7 @@ abstract type AbstractThreadPool end @deprecate pmap(fn::Function, pool, itr) tmap(fn::Function, pool, itr) - _detect_type(fn, itr) = Core.Compiler.return_type(fn, Tuple{eltype(itr)}) -#_detect_type(fn, itrs::Tuple) = Compiler.Core.return_type(fn, Tuple{eltype(itr)}) -#_detect_type(fn, itrs::Tuple) = eltype(map(fn, [empty(x) for x in itrs]...)) """ diff --git a/src/macros.jl b/src/macros.jl index 012f0ac..e2a9f47 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -253,21 +253,44 @@ julia> fetch(t) ``` """ macro tspawnat(thrdid, expr) - thunk = esc(:(()->($expr))) - var = esc(Base.sync_varname) - tid = esc(thrdid) - quote - if $tid < 1 || $tid > Threads.nthreads() - throw(AssertionError("@tspawnat thread assignment ($($tid)) must be between 1 and Threads.nthreads() (1:$(Threads.nthreads()))")) + if VERSION >= v"1.4" + letargs = Base._lift_one_interp!(expr) + + thunk = esc(:(()->($expr))) + var = esc(Base.sync_varname) + tid = esc(thrdid) + quote + if $tid < 1 || $tid > Threads.nthreads() + throw(AssertionError("@tspawnat thread assignment ($($tid)) must be between 1 and Threads.nthreads() (1:$(Threads.nthreads()))")) + end + let $(letargs...) + local task = Task($thunk) + task.sticky = false + ccall(:jl_set_task_tid, Cvoid, (Any, Cint), task, $tid-1) + if $(Expr(:islocal, var)) + put!($var, task) + end + schedule(task) + task + end end - local task = Task($thunk) - task.sticky = false - ccall(:jl_set_task_tid, Cvoid, (Any, Cint), task, $tid-1) - if $(Expr(:isdefined, var)) - push!($var, task) + else + thunk = esc(:(()->($expr))) + var = esc(Base.sync_varname) + tid = esc(thrdid) + quote + if $tid < 1 || $tid > Threads.nthreads() + throw(AssertionError("@tspawnat thread assignment ($($tid)) must be between 1 and Threads.nthreads() (1:$(Threads.nthreads()))")) + end + local task = Task($thunk) + task.sticky = false + ccall(:jl_set_task_tid, Cvoid, (Any, Cint), task, $tid-1) + if $(Expr(:isdefined, var)) + push!($var, task) + end + schedule(task) + task end - schedule(task) - task end end diff --git a/test/runtests.jl b/test/runtests.jl index 03ea821..fb506a0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,6 +9,6 @@ include("testforeach.jl") include("testlogfcns.jl") include("stacktests.jl") include("testspawnat.jl") -include("testplots.jl") include("testmultidim.jl") include("testmultiarg.jl") +include("testplots.jl") diff --git a/test/testspawnat.jl b/test/testspawnat.jl index 6dc6a97..2f187a5 100644 --- a/test/testspawnat.jl +++ b/test/testspawnat.jl @@ -5,6 +5,17 @@ using ThreadPools include("util.jl") + +macro ifv1p4(expr) + if VERSION >= v"1.4" + thunk = esc(:(()->($expr))) + quote + $thunk() + end + end +end + + @testset "@tspawnat" begin @testset "@normal operation" begin @@ -19,10 +30,28 @@ include("util.jl") @test obj.data == Threads.nthreads() end + @ifv1p4 begin + @testset "interpolation" begin + function foo(x) + sleep(0.01) + return x + end + + x = 1 + expect_sum = 3 + t1 = @tspawnat max(1, Threads.nthreads()) foo($x) + x += 1 + t2 = @tspawnat max(1, Threads.nthreads()-1) foo($x) + + test_sum = fetch(t1) + fetch(t2) + @test expect_sum == test_sum + end + end + @testset "@out of bounds" begin @test_throws AssertionError task = @tspawnat Threads.nthreads()+1 randn() @test_throws AssertionError task = @tspawnat 0 randn() end -end + end end # module \ No newline at end of file