Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WaitGroup synchronization primitive #14167

Merged
merged 17 commits into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 185 additions & 0 deletions spec/std/wait_group_spec.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
require "spec"
require "wait_group"

# Yields to other fibers until at least one fiber is registered in the wait
# group's waiting list, i.e. until some fiber is suspended inside `#wait`.
private def block_until_pending_waiter(wg)
  # `wg.@waiting` reads the private instance variable directly
  while wg.@waiting.empty?
    Fiber.yield
  end
end

# Overwrites the wait group's internal counter with *value*, bypassing the
# checks done by `#add`, so a spec can start from a state that is otherwise
# unreachable (e.g. a negative counter).
private def forge_counter(wg, value)
  wg.@counter.set(value)
end

describe WaitGroup do
describe "#add" do
  it "can't decrement to a negative counter" do
    group = WaitGroup.new
    group.add(5)
    group.add(-3)

    # 2 remain at this point, so removing 5 more must be rejected
    expect_raises(RuntimeError, "Negative WaitGroup counter") do
      group.add(-5)
    end
  end

  it "resumes waiters when reaching negative counter" do
    group = WaitGroup.new(1)

    spawn do
      block_until_pending_waiter(group)
      group.add(-2)
    rescue RuntimeError
      # the error raised in this fiber is expected; the assertion below
      # checks what the waiter observes
    end

    expect_raises(RuntimeError, "Negative WaitGroup counter") do
      group.wait
    end
  end

  it "can't increment after reaching negative counter" do
    group = WaitGroup.new
    forge_counter(group, -1)

    # asserted twice: the first failed call must not have pushed the counter
    # back to a positive value
    expect_raises(RuntimeError, "Negative WaitGroup counter") do
      group.add(5)
    end
    expect_raises(RuntimeError, "Negative WaitGroup counter") do
      group.add(3)
    end
  end
end

describe "#done" do
  it "can't decrement to a negative counter" do
    wg = WaitGroup.new
    wg.add(1)
    wg.done
    expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.done }
  end

  it "resumes waiters when reaching negative counter" do
    wg = WaitGroup.new(1)
    spawn do
      block_until_pending_waiter(wg)
      # force the counter to zero so the following #done drives it negative
      forge_counter(wg, 0)
      wg.done
    rescue RuntimeError
      # the error raised in this fiber is expected; the assertion below checks
      # what the waiter observes
    end
    expect_raises(RuntimeError, "Negative WaitGroup counter") { wg.wait }
  end
end

describe "#wait" do
  it "immediately returns when counter is zero" do
    finished = Channel(Nil).new(1)

    spawn do
      group = WaitGroup.new(0)
      group.wait
      finished.send(nil)
    end

    select
    when finished.receive
      # success
    when timeout(1.second)
      fail "expected #wait to not block the fiber"
    end
  end

  it "immediately raises when counter is negative" do
    group = WaitGroup.new(0)

    # drive the counter below zero first
    expect_raises(RuntimeError) do
      group.done
    end

    expect_raises(RuntimeError, "Negative WaitGroup counter") do
      group.wait
    end
  end

  it "raises when counter is positive after wake up" do
    group = WaitGroup.new(1)
    main_fiber = Fiber.current

    # resume the waiter manually while the counter is still 1
    spawn do
      block_until_pending_waiter(group)
      main_fiber.enqueue
    end

    expect_raises(RuntimeError, "Positive WaitGroup counter (early wake up?)") do
      group.wait
    end
  end
end

it "waits until concurrent executions are finished" do
  wg1 = WaitGroup.new
  wg2 = WaitGroup.new

  # both wait groups are reused across the 8 rounds
  8.times do
    wg1.add(16)
    wg2.add(16)
    exited = Channel(Bool).new(16)

    16.times do
      spawn do
        wg1.done
        wg2.wait
        exited.send(true)
      end
    end

    # returns once all 16 fibers have called wg1.done
    wg1.wait

    16.times do
      # no fiber may get past wg2.wait before wg2 was decremented 16 times
      select
      when exited.receive
        fail "WaitGroup released group too soon"
      else
      end
      wg2.done
    end

    # wg2 reached zero: every fiber must now report its exit
    16.times do
      select
      when x = exited.receive
        x.should eq(true)
      when timeout(1.millisecond)
        fail "Expected channel to receive value"
      end
    end
  end
end

it "increments the counter from executing fibers" do
  group = WaitGroup.new(16)
  nested_runs = Atomic(Int32).new(0)

  16.times do
    spawn do
      # register the two nested fibers before this fiber reports completion
      group.add(2)

      2.times do
        spawn do
          nested_runs.add(1)
          group.done
        end
      end

      group.done
    end
  end

  group.wait

  # 16 outer fibers, each spawning 2 nested ones
  nested_runs.get.should eq(32)
end

# skipped under the interpreter, where this test takes far too long to complete
{% unless flag?(:interpreted) %}
  it "stress add/done/wait" do
    group = WaitGroup.new

    # the same wait group is reused for all 1000 rounds
    1000.times do
      hits = Atomic(Int32).new(0)

      2.times do
        group.add(1)

        spawn do
          hits.add(1)
          group.done
        end
      end

      group.wait
      hits.get.should eq(2)
    end
  end
{% end %}
end
6 changes: 6 additions & 0 deletions src/crystal/pointer_linked_list.cr
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,10 @@ struct Crystal::PointerLinkedList(T)
node = _next
end
end

# Yields every node of the list to the block, then empties the list by
# resetting its head to a null pointer.
def consume_each(&) : Nil
  each do |node|
    yield node
  end
  @head = Pointer(T).null
end
end
120 changes: 120 additions & 0 deletions src/wait_group.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
require "fiber"
require "crystal/spin_lock"
require "crystal/pointer_linked_list"

# Suspend execution until a collection of fibers has finished.
#
# The wait group is a declarative counter of how many concurrent fibers have
# been started. Each such fiber is expected to call `#done` to report that it
# has finished its work. Whenever the counter reaches zero the waiters
# will be resumed.
#
# This is a simpler and more efficient alternative to using a `Channel(Nil)`
# then looping a number of times until we received N messages to resume
# execution.
#
# Basic example:
#
# ```
# require "wait_group"
# wg = WaitGroup.new(5)
#
# 5.times do
#   spawn do
#     do_something
#   ensure
#     wg.done # the fiber has finished
#   end
# end
#
# # suspend the current fiber until the 5 fibers are done
# wg.wait
# ```
class WaitGroup
# Linked-list node (via `Crystal::PointerLinkedList::Node`) holding a fiber
# that is suspended in `#wait`.
private struct Waiting
  include Crystal::PointerLinkedList::Node

  def initialize(fiber : Fiber)
    @fiber = fiber
  end

  # Enqueues the held fiber (delegates to `Fiber#enqueue`).
  def enqueue : Nil
    @fiber.enqueue
  end
end

# Creates a wait group whose counter starts at *n* (zero by default) and whose
# list of waiting fibers is empty.
def initialize(n : Int32 = 0)
  @counter = Atomic(Int32).new(n)
  @lock = Crystal::SpinLock.new
  @waiting = Crystal::PointerLinkedList(Waiting).new
end

# Increments the counter by how many fibers we want to wait for.
#
# A negative value decrements the counter. When the counter reaches zero,
# all waiting fibers will be resumed.
# Raises `RuntimeError` if the counter reaches a negative value.
#
# Can be called at any time, allowing concurrent fibers to add more fibers to
# wait for, but they must always do so before calling `#done` that would
# decrement the counter, to make sure that the counter may never inadvertently
# reach zero before all fibers are done.
def add(n : Int32 = 1) : Nil
  counter = @counter.get(:acquire)

  # retry until the update is applied on top of the value we actually observed
  loop do
    # once the counter is negative it must never be modified again
    raise RuntimeError.new("Negative WaitGroup counter") if counter < 0

    # on failure `counter` is refreshed with the current value; on success it
    # still holds the value that was replaced
    counter, success = @counter.compare_and_set(counter, counter + n, :acquire_release, :acquire)
    break if success
  end

  new_counter = counter + n
  return if new_counter > 0

  # zero or negative: resume every waiting fiber and clear the list
  @lock.sync do
    @waiting.consume_each do |node|
      node.value.enqueue
    end
  end

  # raised only after the waiters were resumed, so they are not left suspended
  raise RuntimeError.new("Negative WaitGroup counter") if new_counter < 0
end

# Decrements the counter by one. Must be called by concurrent fibers once they
# have finished processing. When the counter reaches zero, all waiting fibers
# will be resumed.
#
# Equivalent to `add(-1)`: raises `RuntimeError` if the counter reaches a
# negative value.
def done : Nil
  add(-1)
end

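# Suspends the current fiber until the counter reaches zero, at which point
# the fiber will be resumed. Returns immediately if the counter is already
# zero.
#
# Raises `RuntimeError` if the counter is negative, or if the fiber is resumed
# while the counter is still positive.
#
# Can be called from different fibers.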
def wait : Nil
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this donation very much! It will be very useful in many cases.

One proposal that probably can be done later on as a separate improvement - is to make the wait method compatible with select to support the following snippet:

select
  when wg.wait
    puts "All fibers done"
  when timeout(X.seconds)
    puts "Some fiber stuck"
end

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or maybe just have wg.wait(timeout: Time::Span | Nil) ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bararchy We're missing a generic mechanism for timeouts... but we could abstract how it's implemented for select so that could be doable.

That doesn't mean we can't also integrate with select: we could wait on channel(s) + waitgroup(s) + timeout. Now, I'm not sure how to do that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ysbaddaden I think that @alexkutsan's idea is better, because then we don't need to handle a Timeout Exception in case that the Timeout happen, and instead handle it in select context which seems cleaner, like how channel works when calling "receive" etc..

So I think my idea is less clean tbh 😅

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a commit to support WaitGroup in select expressions.

The integration wasn't complex after I understood how SelectAction and SelectContext are working, but the current implementation is very isolated to Channel (on purpose). Maybe the integration is not a good idea, but if proves to be a good idea, we might want to extract the select logic from Channel to the Crystal namespace.

I'll open a pull request after this one is merged, so we can have a proper discussion.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ysbaddaden now that it's in and merged, are you planning to make the followup PR? 👁️

return if done?

waiting = Waiting.new(Fiber.current)

@lock.sync do
# must check again to avoid a race condition where #done may have
# decremented the counter to zero between the above check and #wait
# acquiring the lock; we'd push the current fiber to the wait list that
# would never be resumed (oops)
return if done?

@waiting.push(pointerof(waiting))
end

Crystal::Scheduler.reschedule

return if done?
raise RuntimeError.new("Positive WaitGroup counter (early wake up?)")
end

# Returns whether the counter is exactly zero.
# Raises `RuntimeError` if the counter is negative.
private def done?
  current = @counter.get(:acquire)
  if current < 0
    raise RuntimeError.new("Negative WaitGroup counter")
  end
  current.zero?
end
end
Loading