Skip to content

Commit

Permalink
Make read notify use sync.Cond instead of chan
Browse files Browse the repository at this point in the history
  • Loading branch information
edaniels committed Jul 22, 2024
1 parent d55a60c commit e358b7f
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 24 deletions.
90 changes: 66 additions & 24 deletions packetio/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ type Buffer struct {
data []byte
head, tail int

notify chan struct{}
waiting bool
closed bool
closed bool

count int
limitCount, limitSize int

readDeadline *deadline.Deadline
readDeadline *deadline.Deadline
nextDeadline chan struct{}
readNotifier *sync.Cond
readChannelWatcherRunning sync.WaitGroup
}

const (
Expand All @@ -56,9 +57,35 @@ const (

// NewBuffer creates a new Buffer.
func NewBuffer() *Buffer {
return &Buffer{
notify: make(chan struct{}, 1),
buffer := &Buffer{
readDeadline: deadline.New(),
nextDeadline: make(chan struct{}, 1),
}
buffer.readNotifier = sync.NewCond(&buffer.mutex)
buffer.readChannelWatcherRunning.Add(1)
go buffer.readDeadlineWatcher()
return buffer
}

func (b *Buffer) readDeadlineWatcher() {
defer b.readChannelWatcherRunning.Done()
for {
select {
case <-b.readDeadline.Done():
b.mutex.Lock()
b.readNotifier.Broadcast()
b.mutex.Unlock()
case _, ok := <-b.nextDeadline:
if ok {
continue
}
return
}

_, ok := <-b.nextDeadline
if !ok {
return
}
}
}

Expand Down Expand Up @@ -173,15 +200,7 @@ func (b *Buffer) Write(packet []byte) (int, error) {
}
b.count++

waiting := b.waiting
b.waiting = false

if waiting {
select {
case b.notify <- struct{}{}:
default:
}
}
b.readNotifier.Signal()
b.mutex.Unlock()

return len(packet), nil
Expand All @@ -199,9 +218,8 @@ func (b *Buffer) Read(packet []byte) (n int, err error) { //nolint:gocognit
default:
}

b.mutex.Lock()
for {
b.mutex.Lock()

if b.head != b.tail {
// decode the packet size
n1 := b.data[b.head]
Expand Down Expand Up @@ -244,7 +262,6 @@ func (b *Buffer) Read(packet []byte) (n int, err error) { //nolint:gocognit
}

b.count--
b.waiting = false
b.mutex.Unlock()

if copied < count {
Expand All @@ -258,32 +275,47 @@ func (b *Buffer) Read(packet []byte) (n int, err error) { //nolint:gocognit
return 0, io.EOF
}

b.waiting = true
b.mutex.Unlock()

b.readNotifier.Wait()
select {
case <-b.readDeadline.Done():
b.mutex.Unlock()
return 0, &netError{ErrTimeout, true, true}
case <-b.notify:
default:
}
}
}

// Close the buffer, unblocking any pending reads.
// Data in the buffer can still be read, Read will return io.EOF only when empty.
func (b *Buffer) Close() (err error) {
return b.close(false)
}

// GracefulClose closes the buffer, unblocking any pending reads.
// Data in the buffer can still be read, Read will return io.EOF only when empty.
// It returns when any goroutines Buffer started have completed. This should not be called
// in any callbacks that may own a buffer unless a goroutine is spawned in that callback
// to call GracefulClose.
func (b *Buffer) GracefulClose() (err error) {
return b.close(true)
}

func (b *Buffer) close(graceful bool) error {
b.mutex.Lock()

if b.closed {
b.mutex.Unlock()
return nil
}

b.waiting = false
b.closed = true
close(b.notify)
close(b.nextDeadline)
b.readNotifier.Broadcast()
b.mutex.Unlock()

if graceful {
b.readChannelWatcherRunning.Wait()
}
return nil
}

Expand Down Expand Up @@ -338,6 +370,16 @@ func (b *Buffer) SetLimitSize(limit int) {
// SetReadDeadline sets the deadline for the Read operation.
// Setting to zero means no deadline.
func (b *Buffer) SetReadDeadline(t time.Time) error {
b.mutex.Lock()
defer b.mutex.Unlock()

b.readDeadline.Set(t)
select {
case b.nextDeadline <- struct{}{}:
default:
// if there is no receiver, then we know that readDeadlineWatcher
// is about to receive the buffered value in the channel. otherwise
// we communicated the next deadline to the receiver directly.
}
return nil
}
61 changes: 61 additions & 0 deletions packetio/buffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"testing"
"time"

"github.com/pion/transport/v3/test"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -584,6 +585,8 @@ func BenchmarkBuffer1400(b *testing.B) {
}

func TestBufferConcurrentRead(t *testing.T) {
defer test.TimeOut(time.Second * 5).Stop()

assert := assert.New(t)

buffer := NewBuffer()
Expand Down Expand Up @@ -626,3 +629,61 @@ func TestBufferConcurrentRead(t *testing.T) {
err = <-errCh
assert.Equal(io.EOF, err)
}

func TestBufferConcurrentReadWrite(t *testing.T) {
defer test.TimeOut(time.Second * 5).Stop()

assert := assert.New(t)

buffer := NewBuffer()
packet := make([]byte, 4)

errCh := make(chan error, 4)
readIntoErr := func() {
_, readErr := buffer.Read(packet)
errCh <- readErr
}
writeIntoErr := func() {
_, writeErr := buffer.Write([]byte{2, 3, 4})
errCh <- writeErr
}
go readIntoErr()
go readIntoErr()
go writeIntoErr()
go writeIntoErr()

// Close
err := buffer.Close()
assert.NoError(err)

// we just care that the reads and writes happen
for i := 0; i < 4; i++ {
<-errCh
}
}

func TestBufferReadDeadlineInSyncCond(t *testing.T) {
defer test.TimeOut(time.Second * 10).Stop()

assert := assert.New(t)

buffer := NewBuffer()

assert.NoError(buffer.SetReadDeadline(time.Now().Add(5 * time.Second))) // Set deadline to avoid test deadlock

// Start up a goroutine to start a blocking read.
readErr := make(chan error)
go func() {
packet := make([]byte, 4)
_, err := buffer.Read(packet)
readErr <- err
}()

err := <-readErr
var e net.Error
if !errors.As(err, &e) || !e.Timeout() {
t.Errorf("Unexpected error: %v", err)
}

assert.NoError(buffer.GracefulClose())
}

0 comments on commit e358b7f

Please sign in to comment.