
Commit

Support maintaining io.ReadSeekClosers through singleflight
turt2live committed Sep 7, 2023
1 parent 3269fb2 commit 025e46b
Showing 3 changed files with 274 additions and 18 deletions.
80 changes: 80 additions & 0 deletions seeker.go
@@ -0,0 +1,80 @@
package sfstreams

import (
"io"
"sync"
)

type parentSeeker struct {
io.ReadSeekCloser
underlying io.ReadSeekCloser
mutex *sync.Mutex
closeWg *sync.WaitGroup
}

func newParentSeeker(src io.ReadSeekCloser, downstreamReaders int) *parentSeeker {
wg := new(sync.WaitGroup)
wg.Add(downstreamReaders)
go func() {
wg.Wait()
_ = src.Close()
}()
return &parentSeeker{
underlying: src,
mutex: new(sync.Mutex),
closeWg: wg,
}
}

func (p *parentSeeker) Read(b []byte) (int, error) {
return p.underlying.Read(b)
}

func (p *parentSeeker) Seek(offset int64, whence int) (int64, error) {
return p.underlying.Seek(offset, whence)
}

func (p *parentSeeker) Close() error {
return p.underlying.Close()
}

type downstreamSeeker struct {
io.ReadSeekCloser
parent *parentSeeker
pos int64
}

func newSyncSeeker(parent *parentSeeker) *downstreamSeeker {
return &downstreamSeeker{
parent: parent,
pos: 0,
}
}

func (s *downstreamSeeker) Read(b []byte) (int, error) {
s.parent.mutex.Lock()
defer s.parent.mutex.Unlock()
offset, err := s.parent.Seek(s.pos, io.SeekStart)
if err != nil {
return 0, err
}
i, err := s.parent.Read(b)
s.pos = offset + int64(i)
return i, err
}

func (s *downstreamSeeker) Seek(offset int64, whence int) (int64, error) {
s.parent.mutex.Lock()
defer s.parent.mutex.Unlock()
offset, err := s.parent.Seek(offset, whence)
if err != nil {
return s.pos, err
}
s.pos = offset
return s.pos, nil
}

func (s *downstreamSeeker) Close() error {
s.parent.closeWg.Done()
return nil
}
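Aside (not part of the commit): a minimal in-package sketch of how the new seeker types compose. Everything here is hypothetical illustration — sketchNopCloser is a local stand-in (the commit's own test file adds a similar nopSeekCloser helper), and the example assumes it lives in a _test.go file inside package sfstreams so it can reach the unexported constructors.

package sfstreams

import (
	"bytes"
	"fmt"
	"io"
)

// sketchNopCloser is a hypothetical stand-in that turns an io.ReadSeeker
// into an io.ReadSeekCloser with a no-op Close.
type sketchNopCloser struct{ io.ReadSeeker }

func (sketchNopCloser) Close() error { return nil }

func Example_downstreamSeekers() {
	src := sketchNopCloser{bytes.NewReader([]byte("hello world"))}

	// One parent fans the single source out to two downstream readers; the
	// source is closed only after both downstream readers call Close.
	parent := newParentSeeker(src, 2)
	a := newSyncSeeker(parent)
	b := newSyncSeeker(parent)

	// Each downstream re-seeks the parent (under the shared mutex) before
	// every read, so interleaved use still honours each reader's own position.
	_, _ = a.Seek(6, io.SeekStart)
	fromA, _ := io.ReadAll(a)
	fromB, _ := io.ReadAll(b)
	fmt.Println(string(fromA), "/", string(fromB))

	_ = a.Close()
	_ = b.Close() // closeWg reaches zero here and the underlying source is closed

	// Output: world / hello world
}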
53 changes: 39 additions & 14 deletions sfstreams.go
@@ -22,6 +22,17 @@ type Group struct {
sf singleflight.Group
mu sync.Mutex
calls map[string][]chan<- io.ReadCloser

// Normally, the Group will copy the work function's returned reader, but in some cases it is
// desirable to maintain the io.Seeker interface. When this option is set to true, the Group
// no longer copies but instead returns proxy io.ReadSeekCloser readers that track their own
// read state. Whenever one of those readers seeks/reads, it synchronously does so on the work
// function's returned io.ReadSeekCloser. This can lead to performance bottlenecks if several
// call sites are attempting to read/seek at the same time.
//
// If this is set to true, but the work function doesn't return an io.ReadSeekCloser, the copy
// behaviour is used. When false (the default), the copy behaviour is always used.
UseSeekers bool
}

// Do behaves just like singleflight.Group, with the added guarantee that the returned io.ReadCloser
@@ -92,32 +103,46 @@ func (g *Group) doWork(key string, fn func() (io.ReadCloser, error)) func() (int
} else {
var zero io.ReadCloser
canStream := fnRes != nil && fnRes != zero
writers := make([]*io.PipeWriter, 0)
for _, ch := range chans {
if !canStream {

delete(g.calls, key) // we've done all we can for this call: clear it before we unlock

if !canStream {
for _, ch := range chans {
// This needs to be async to prevent a deadlock
go func(ch chan<- io.ReadCloser) {
ch <- nil
}(ch)
continue
}
return nil, fnErr // we intentionally discard the return value
}

r, w := io.Pipe()
writers = append(writers, w)
if g.UseSeekers {
if rsc, ok := fnRes.(io.ReadSeekCloser); ok {
parent := newParentSeeker(rsc, len(chans))
for _, ch := range chans {
// This needs to be async to prevent a deadlock
go func(ch chan<- io.ReadCloser) {
ch <- newSyncSeeker(parent)
}(ch)
}
}
} else {
writers := make([]*io.PipeWriter, len(chans))
for i, ch := range chans {
r, w := io.Pipe()
writers[i] = w

// This needs to be async to prevent a deadlock
go func(r io.ReadCloser, ch chan<- io.ReadCloser) {
ch <- newDiscardCloser(r)
}(r, ch)
}
delete(g.calls, key) // we've done all we can for this call: clear it before we unlock
// This needs to be async to prevent a deadlock
go func(r io.ReadCloser, ch chan<- io.ReadCloser) {
ch <- newDiscardCloser(r)
}(r, ch)
}

if canStream {
// Do the io copy async to prevent holding up other singleflight calls
go finishCopy(writers, fnRes)
}

return nil, fnErr // we discard the return value
return nil, fnErr // we intentionally discard the return value
}
}
}
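Aside (not part of the commit): a hedged caller-side sketch of what UseSeekers enables. The import path, file name, and workFn below are placeholders; only Group, UseSeekers, and Do come from the package changes above.

package main

import (
	"io"
	"log"
	"os"

	sfstreams "example.com/sfstreams" // placeholder import path for this package
)

func main() {
	g := new(sfstreams.Group)
	g.UseSeekers = true

	// *os.File satisfies io.ReadSeekCloser, so the proxy-seeker path applies;
	// if the work function returned a plain io.ReadCloser, the usual copy
	// behaviour would be used instead.
	workFn := func() (io.ReadCloser, error) {
		return os.Open("example.bin") // placeholder file name
	}

	r, err, shared := g.Do("example.bin", workFn)
	if err != nil {
		log.Fatal(err)
	}
	defer r.Close()

	if rsc, ok := r.(io.ReadSeekCloser); ok {
		// Seeks and reads are forwarded (under a shared mutex) to the single
		// reader returned by workFn, so each caller keeps its own position
		// without the stream being copied.
		if _, err := rsc.Seek(128, io.SeekStart); err != nil {
			log.Fatal(err)
		}
	}
	_, _ = io.Copy(io.Discard, r)
	_ = shared // true when this call joined another in-flight call for the same key
}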
159 changes: 155 additions & 4 deletions sfstreams_test.go
@@ -10,12 +10,24 @@ import (
"time"
)

type nopReadSeekCloser struct {
io.ReadSeeker
}

func (nopReadSeekCloser) Close() error {
return nil
}

func nopSeekCloser(rs io.ReadSeeker) io.ReadSeekCloser {
return nopReadSeekCloser{ReadSeeker: rs}
}

func makeStream() (key string, expectedBytes int64, src io.ReadCloser) {
key = "fake file"

b := make([]byte, 16*1024) // 16kb
_, _ = rand.Read(b)
src = io.NopCloser(bytes.NewBuffer(b))
src = nopSeekCloser(bytes.NewReader(b))

expectedBytes = int64(len(b))
return
@@ -93,7 +105,7 @@ func TestDoDuplicates(t *testing.T) {
return src, nil
}

g := &Group{}
g := new(Group)
readFn := func() {
defer workWg2.Done()
workWg1.Done()
@@ -323,7 +335,7 @@ func TestStallOnRead(t *testing.T) {
return src, nil
}

g := &Group{}
g := new(Group)
readFn := func(i int) {
defer workWg2.Done()
workWg1.Done()
@@ -386,7 +398,7 @@ func TestFasterRead(t *testing.T) {
return src, nil
}

g := &Group{}
g := new(Group)
readFn := func(i int) {
defer workWg2.Done()
workWg1.Done()
@@ -437,3 +449,142 @@
t.Errorf("Expected between 1 and %d calls, got %d", max-1, callCount)
}
}

func TestReturnsNoSeekerDefault(t *testing.T) {
key, expectedBytes, src := makeStream()

callCount := 0
workFn := func() (io.ReadCloser, error) {
callCount++
return src, nil
}

g := new(Group)
r, err, shared := g.Do(key, workFn)
if err != nil {
t.Fatal(err)
}
if shared {
t.Error("Expected a non-shared result")
}
if r == src {
t.Error("Reader and source are the same")
}
if _, ok := r.(io.ReadSeekCloser); ok {
t.Error("Expected reader to *not* be a ReadSeekCloser")
}

//goland:noinspection GoUnhandledErrorResult
defer r.Close()
c, _ := io.Copy(io.Discard, r)
if c != expectedBytes {
t.Errorf("Read %d bytes but expected %d", c, expectedBytes)
}

if callCount != 1 {
t.Errorf("Expected 1 call, got %d", callCount)
}
}

func TestReturnsSeekerWhenEnabled(t *testing.T) {
key, expectedBytes, src := makeStream()

callCount := 0
workFn := func() (io.ReadCloser, error) {
callCount++
return src, nil
}

g := new(Group)
g.UseSeekers = true
r, err, shared := g.Do(key, workFn)
if err != nil {
t.Fatal(err)
}
if shared {
t.Error("Expected a non-shared result")
}
if r == src {
t.Error("Reader and source are the same")
}
if _, ok := r.(io.ReadSeekCloser); !ok {
t.Error("Expected reader to be a ReadSeekCloser")
}

//goland:noinspection GoUnhandledErrorResult
defer r.Close()
c, _ := io.Copy(io.Discard, r)
if c != expectedBytes {
t.Errorf("Read %d bytes but expected %d", c, expectedBytes)
}

if callCount != 1 {
t.Errorf("Expected 1 call, got %d", callCount)
}
}

func TestSeekerUsesParent(t *testing.T) {
key, expectedBytes, src := makeStream()

skip := int64(10)

workWg1 := new(sync.WaitGroup)
workWg2 := new(sync.WaitGroup)
workCh := make(chan int, 1)
callCount := 0
workFn := func() (io.ReadCloser, error) {
callCount++
if callCount == 1 {
workWg1.Done()
}
v := <-workCh
workCh <- v
time.Sleep(10 * time.Millisecond)
return src, nil
}

g := new(Group)
g.UseSeekers = true
readFn := func(i int) {
localSkip := skip + int64(i)
defer workWg2.Done()
workWg1.Done()
r, err, _ := g.Do(key, workFn)
if err != nil {
t.Error(err)
return
}
rsc := r.(io.ReadSeekCloser)
offset, err := rsc.Seek(localSkip, io.SeekStart)
if err != nil {
t.Error(err)
return
}
if offset != localSkip {
t.Errorf("Expected seek to %d instead of %d", localSkip, offset)
}
c, err := io.Copy(io.Discard, rsc)
if err != nil {
t.Error(err)
return
}
if c != (expectedBytes - localSkip) {
t.Errorf("Read %d bytes instead of %d", c, expectedBytes)
}
}

const max = 10
workWg1.Add(1)
for i := 0; i < max; i++ {
workWg1.Add(1)
workWg2.Add(1)
go readFn(i)
}
workWg1.Wait()
workCh <- 1
workWg2.Wait()
if callCount <= 0 || callCount >= max {
t.Errorf("Expected between 1 and %d calls, got %d", max-1, callCount)
}

}
