From 39fc4eec54fc97efda268c9975348121d0d225f5 Mon Sep 17 00:00:00 2001 From: Jarrett Revels Date: Thu, 28 May 2020 12:58:41 -0400 Subject: [PATCH] add `Threads.foreach` for convenient multithreaded Channel consumption (#34543) Co-authored-by: Takafumi Arakaki Co-authored-by: Alex Arslan Co-authored-by: Valentin Churavy --- NEWS.md | 4 ++- base/Base.jl | 1 + base/threadingconstructs.jl | 8 ++++++ base/threads_overloads.jl | 51 +++++++++++++++++++++++++++++++++++++ test/threads_exec.jl | 35 +++++++++++++++++++++++++ 5 files changed, 98 insertions(+), 1 deletion(-) create mode 100644 base/threads_overloads.jl diff --git a/NEWS.md b/NEWS.md index f95768f41a5a0..9966811f8e1c7 100644 --- a/NEWS.md +++ b/NEWS.md @@ -14,7 +14,6 @@ Language changes Compiler/Runtime improvements ----------------------------- - * All platforms can now use `@executable_path` within `jl_load_dynamic_library()`. This allows executable-relative paths to be embedded within executables on all platforms, not just MacOS, which the syntax is borrowed from. ([#35627]) @@ -33,7 +32,9 @@ Build system changes New library functions --------------------- + * New function `Base.kron!` and corresponding overloads for various matrix types for performing Kronecker product in-place. ([#31069]). +* New function `Base.Threads.foreach(f, channel::Channel)` for multithreaded `Channel` consumption. ([#34543]). New library features -------------------- @@ -41,6 +42,7 @@ New library features Standard library changes ------------------------ + * The `nextprod` function now accepts tuples and other array types for its first argument ([#35791]). * The function `isapprox(x,y)` now accepts the `norm` keyword argument also for numeric (i.e., non-array) arguments `x` and `y` ([#35883]). * `view`, `@view`, and `@views` now work on `AbstractString`s, returning a `SubString` when appropriate ([#35879]). diff --git a/base/Base.jl b/base/Base.jl index 9c1cbe735e4fd..777e07bd30715 100644 --- a/base/Base.jl +++ b/base/Base.jl @@ -223,6 +223,7 @@ include("threads.jl") include("lock.jl") include("channels.jl") include("task.jl") +include("threads_overloads.jl") include("weakkeydict.jl") # Logging diff --git a/base/threadingconstructs.jl b/base/threadingconstructs.jl index aa122044a7881..56c4cbb13db5c 100644 --- a/base/threadingconstructs.jl +++ b/base/threadingconstructs.jl @@ -180,3 +180,11 @@ macro spawn(expr) end end end + +# This is a stub that can be overloaded for downstream structures like `Channel` +function foreach end + +# Scheduling traits that can be employed for downstream overloads +abstract type AbstractSchedule end +struct StaticSchedule <: AbstractSchedule end +struct FairSchedule <: AbstractSchedule end diff --git a/base/threads_overloads.jl b/base/threads_overloads.jl new file mode 100644 index 0000000000000..3e6ad06760747 --- /dev/null +++ b/base/threads_overloads.jl @@ -0,0 +1,51 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +""" + Threads.foreach(f, channel::Channel; + schedule::Threads.AbstractSchedule=Threads.FairSchedule(), + ntasks=Threads.nthreads()) + +Similar to `foreach(f, channel)`, but iteration over `channel` and calls to +`f` are split across `ntasks` tasks spawned by `Threads.@spawn`. This function +will wait for all internally spawned tasks to complete before returning. + +If `schedule isa FairSchedule`, `Threads.foreach` will attempt to spawn tasks in a +manner that enables Julia's scheduler to more freely load-balance work items across +threads. This approach generally has higher per-item overhead, but may perform +better than `StaticSchedule` in concurrence with other multithreaded workloads. + +If `schedule isa StaticSchedule`, `Threads.foreach` will spawn tasks in a manner +that incurs lower per-item overhead than `FairSchedule`, but is less amenable +to load-balancing. This approach thus may be more suitable for fine-grained, +uniform workloads, but may perform worse than `FairSchedule` in concurrence +with other multithreaded workloads. + +!!! compat "Julia 1.6" + This function requires Julia 1.6 or later. +""" +function Threads.foreach(f, channel::Channel; + schedule::Threads.AbstractSchedule=Threads.FairSchedule(), + ntasks=Threads.nthreads()) + apply = _apply_for_schedule(schedule) + stop = Threads.Atomic{Bool}(false) + @sync for _ in 1:ntasks + Threads.@spawn try + for item in channel + $apply(f, item) + # do `stop[] && break` after `f(item)` to avoid losing `item`. + # this isn't super comprehensive since a task could still get + # stuck on `take!` at `for item in channel`. We should think + # about a more robust mechanism to avoid dropping items. See also: + # https://github.com/JuliaLang/julia/pull/34543#discussion_r422695217 + stop[] && break + end + catch + stop[] = true + rethrow() + end + end + return nothing +end + +_apply_for_schedule(::Threads.StaticSchedule) = (f, x) -> f(x) +_apply_for_schedule(::Threads.FairSchedule) = (f, x) -> wait(Threads.@spawn f(x)) diff --git a/test/threads_exec.jl b/test/threads_exec.jl index 9022ce9f05ba0..691fca2fb2afa 100644 --- a/test/threads_exec.jl +++ b/test/threads_exec.jl @@ -845,3 +845,38 @@ fib34666(x) = f(x) end @test fib34666(25) == 75025 + +function jitter_channel(f, k, delay, ntasks, schedule) + x = Channel(ch -> foreach(i -> put!(ch, i), 1:k), 1) + y = Channel(k) do ch + g = i -> begin + iseven(i) && sleep(delay) + put!(ch, f(i)) + end + Threads.foreach(g, x; schedule=schedule, ntasks=ntasks) + end + return y +end + +@testset "Threads.foreach(f, ::Channel)" begin + k = 50 + delay = 0.01 + expected = sin.(1:k) + ordered_fair = collect(jitter_channel(sin, k, delay, 1, Threads.FairSchedule())) + ordered_static = collect(jitter_channel(sin, k, delay, 1, Threads.StaticSchedule())) + @test expected == ordered_fair + @test expected == ordered_static + + unordered_fair = collect(jitter_channel(sin, k, delay, 10, Threads.FairSchedule())) + unordered_static = collect(jitter_channel(sin, k, delay, 10, Threads.StaticSchedule())) + @test expected != unordered_fair + @test expected != unordered_static + @test Set(expected) == Set(unordered_fair) + @test Set(expected) == Set(unordered_static) + + ys = Channel() do ys + inner = Channel(xs -> foreach(i -> put!(xs, i), 1:3)) + Threads.foreach(x -> put!(ys, x), inner) + end + @test sort!(collect(ys)) == 1:3 +end