Skip to content

Commit

Permalink
Closes #17 - adding v1.4+ @Spawn interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
tro3 committed Jul 26, 2020
1 parent 69a6b09 commit 1623a8e
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 18 deletions.
3 changes: 0 additions & 3 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@ abstract type AbstractThreadPool end
@deprecate pmap(fn::Function, pool, itr) tmap(fn::Function, pool, itr)



_detect_type(fn, itr) = Core.Compiler.return_type(fn, Tuple{eltype(itr)})
#_detect_type(fn, itrs::Tuple) = Compiler.Core.return_type(fn, Tuple{eltype(itr)})
#_detect_type(fn, itrs::Tuple) = eltype(map(fn, [empty(x) for x in itrs]...))


"""
Expand Down
49 changes: 36 additions & 13 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,21 +253,44 @@ julia> fetch(t)
```
"""
macro tspawnat(thrdid, expr)
thunk = esc(:(()->($expr)))
var = esc(Base.sync_varname)
tid = esc(thrdid)
quote
if $tid < 1 || $tid > Threads.nthreads()
throw(AssertionError("@tspawnat thread assignment ($($tid)) must be between 1 and Threads.nthreads() (1:$(Threads.nthreads()))"))
if VERSION >= v"1.4"
letargs = Base._lift_one_interp!(expr)

thunk = esc(:(()->($expr)))
var = esc(Base.sync_varname)
tid = esc(thrdid)
quote
if $tid < 1 || $tid > Threads.nthreads()
throw(AssertionError("@tspawnat thread assignment ($($tid)) must be between 1 and Threads.nthreads() (1:$(Threads.nthreads()))"))
end
let $(letargs...)
local task = Task($thunk)
task.sticky = false
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), task, $tid-1)
if $(Expr(:islocal, var))
put!($var, task)
end
schedule(task)
task
end
end
local task = Task($thunk)
task.sticky = false
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), task, $tid-1)
if $(Expr(:isdefined, var))
push!($var, task)
else
thunk = esc(:(()->($expr)))
var = esc(Base.sync_varname)
tid = esc(thrdid)
quote
if $tid < 1 || $tid > Threads.nthreads()
throw(AssertionError("@tspawnat thread assignment ($($tid)) must be between 1 and Threads.nthreads() (1:$(Threads.nthreads()))"))
end
local task = Task($thunk)
task.sticky = false
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), task, $tid-1)
if $(Expr(:isdefined, var))
push!($var, task)
end
schedule(task)
task
end
schedule(task)
task
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ include("testforeach.jl")
include("testlogfcns.jl")
include("stacktests.jl")
include("testspawnat.jl")
include("testplots.jl")
include("testmultidim.jl")
include("testmultiarg.jl")
include("testplots.jl")
31 changes: 30 additions & 1 deletion test/testspawnat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@ using ThreadPools

include("util.jl")


macro ifv1p4(expr)
if VERSION >= v"1.4"
thunk = esc(:(()->($expr)))
quote
$thunk()
end
end
end


@testset "@tspawnat" begin

@testset "@normal operation" begin
Expand All @@ -19,10 +30,28 @@ include("util.jl")
@test obj.data == Threads.nthreads()
end

@ifv1p4 begin
@testset "interpolation" begin
function foo(x)
sleep(0.01)
return x
end

x = 1
expect_sum = 3
t1 = @tspawnat max(1, Threads.nthreads()) foo($x)
x += 1
t2 = @tspawnat max(1, Threads.nthreads()-1) foo($x)

test_sum = fetch(t1) + fetch(t2)
@test expect_sum == test_sum
end
end

@testset "@out of bounds" begin
@test_throws AssertionError task = @tspawnat Threads.nthreads()+1 randn()
@test_throws AssertionError task = @tspawnat 0 randn()
end

end
end
end # module

0 comments on commit 1623a8e

Please sign in to comment.