From 1623a8e59144ae8ff2eaaab728ef3234b3978b82 Mon Sep 17 00:00:00 2001
From: Trey Roessig <roessig.trey@gmail.com>
Date: Sun, 26 Jul 2020 09:07:54 -0700
Subject: [PATCH] Closes #17 - adding v1.4+ @spawn interpolation

---
 src/interface.jl    |  3 ---
 src/macros.jl       | 49 +++++++++++++++++++++++++++++++++------------
 test/runtests.jl    |  2 +-
 test/testspawnat.jl | 31 +++++++++++++++++++++++++++-
 4 files changed, 67 insertions(+), 18 deletions(-)

diff --git a/src/interface.jl b/src/interface.jl
index 36c258b..fea5ff1 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -8,10 +8,7 @@ abstract type AbstractThreadPool end
 @deprecate pmap(fn::Function, pool, itr) tmap(fn::Function, pool, itr)
 
 
-
 _detect_type(fn, itr) = Core.Compiler.return_type(fn, Tuple{eltype(itr)})
-#_detect_type(fn, itrs::Tuple) = Compiler.Core.return_type(fn, Tuple{eltype(itr)})
-#_detect_type(fn, itrs::Tuple) = eltype(map(fn, [empty(x) for x in itrs]...))
 
 
 """
diff --git a/src/macros.jl b/src/macros.jl
index 012f0ac..e2a9f47 100644
--- a/src/macros.jl
+++ b/src/macros.jl
@@ -253,21 +253,44 @@ julia> fetch(t)
 ```
 """
 macro tspawnat(thrdid, expr)
-    thunk = esc(:(()->($expr)))
-    var = esc(Base.sync_varname)
-    tid = esc(thrdid)
-    quote
-        if $tid < 1 || $tid > Threads.nthreads()
-            throw(AssertionError("@tspawnat thread assignment ($($tid)) must be between 1 and Threads.nthreads() (1:$(Threads.nthreads()))"))
+    if VERSION >= v"1.4"
+        letargs = Base._lift_one_interp!(expr)
+    
+        thunk = esc(:(()->($expr)))
+        var = esc(Base.sync_varname)
+        tid = esc(thrdid)
+        quote
+            if $tid < 1 || $tid > Threads.nthreads()
+                throw(AssertionError("@tspawnat thread assignment ($($tid)) must be between 1 and Threads.nthreads() (1:$(Threads.nthreads()))"))
+            end
+            let $(letargs...)
+                local task = Task($thunk)
+                task.sticky = false
+                ccall(:jl_set_task_tid, Cvoid, (Any, Cint), task, $tid-1)
+                if $(Expr(:islocal, var))
+                    put!($var, task)
+                end
+                schedule(task)
+                task
+            end
         end
-        local task = Task($thunk)
-        task.sticky = false
-        ccall(:jl_set_task_tid, Cvoid, (Any, Cint), task, $tid-1)
-        if $(Expr(:isdefined, var))
-            push!($var, task)
+    else
+        thunk = esc(:(()->($expr)))
+        var = esc(Base.sync_varname)
+        tid = esc(thrdid)
+        quote
+            if $tid < 1 || $tid > Threads.nthreads()
+                throw(AssertionError("@tspawnat thread assignment ($($tid)) must be between 1 and Threads.nthreads() (1:$(Threads.nthreads()))"))
+            end
+            local task = Task($thunk)
+            task.sticky = false
+            ccall(:jl_set_task_tid, Cvoid, (Any, Cint), task, $tid-1)
+            if $(Expr(:isdefined, var))
+                push!($var, task)
+            end
+            schedule(task)
+            task
         end
-        schedule(task)
-        task
     end
 end
 
diff --git a/test/runtests.jl b/test/runtests.jl
index 03ea821..fb506a0 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -9,6 +9,6 @@ include("testforeach.jl")
 include("testlogfcns.jl")
 include("stacktests.jl")
 include("testspawnat.jl")
-include("testplots.jl")
 include("testmultidim.jl")
 include("testmultiarg.jl")
+include("testplots.jl")
diff --git a/test/testspawnat.jl b/test/testspawnat.jl
index 6dc6a97..2f187a5 100644
--- a/test/testspawnat.jl
+++ b/test/testspawnat.jl
@@ -5,6 +5,17 @@ using ThreadPools
 
 include("util.jl")
 
+
+macro ifv1p4(expr)
+    if VERSION >= v"1.4"
+        thunk = esc(:(()->($expr)))
+        quote
+            $thunk()
+        end
+    end
+end
+
+
 @testset "@tspawnat" begin
 
     @testset "@normal operation" begin
@@ -19,10 +30,28 @@ include("util.jl")
         @test obj.data == Threads.nthreads()
     end
 
+    @ifv1p4 begin
+        @testset "interpolation" begin
+            function foo(x)
+                sleep(0.01)
+                return x
+            end
+
+            x = 1
+            expect_sum = 3
+            t1 = @tspawnat max(1, Threads.nthreads()) foo($x)
+            x += 1
+            t2 = @tspawnat max(1, Threads.nthreads()-1) foo($x)
+            
+            test_sum = fetch(t1) + fetch(t2)
+            @test expect_sum == test_sum
+        end
+    end
+
     @testset "@out of bounds" begin
         @test_throws AssertionError task = @tspawnat Threads.nthreads()+1 randn()
         @test_throws AssertionError task = @tspawnat 0 randn()
     end
 
-end
+    end
 end # module
\ No newline at end of file