diff --git a/seeker.go b/seeker.go
new file mode 100644
index 0000000..3c164e3
--- /dev/null
+++ b/seeker.go
@@ -0,0 +1,92 @@
+package sfstreams
+
+import (
+	"io"
+	"sync"
+)
+
+// parentSeeker wraps the work function's io.ReadSeekCloser, which is shared by
+// every downstream reader. It closes the underlying reader once all downstream
+// readers have been closed.
+type parentSeeker struct {
+	io.ReadSeekCloser
+	underlying io.ReadSeekCloser
+	mutex      *sync.Mutex
+	closeWg    *sync.WaitGroup
+}
+
+func newParentSeeker(src io.ReadSeekCloser, downstreamReaders int) *parentSeeker {
+	wg := new(sync.WaitGroup)
+	wg.Add(downstreamReaders)
+	go func() {
+		wg.Wait()
+		_ = src.Close()
+	}()
+	return &parentSeeker{
+		underlying: src,
+		mutex:      new(sync.Mutex),
+		closeWg:    wg,
+	}
+}
+
+func (p *parentSeeker) Read(b []byte) (int, error) {
+	return p.underlying.Read(b)
+}
+
+func (p *parentSeeker) Seek(offset int64, whence int) (int64, error) {
+	return p.underlying.Seek(offset, whence)
+}
+
+func (p *parentSeeker) Close() error {
+	return p.underlying.Close()
+}
+
+// downstreamSeeker is the proxy reader handed to each caller. It tracks its own
+// position and, while holding the parent's mutex, re-seeks the parent before
+// every read so concurrent readers don't corrupt each other's view of the stream.
+type downstreamSeeker struct {
+	io.ReadSeekCloser
+	parent *parentSeeker
+	pos    int64
+}
+
+func newSyncSeeker(parent *parentSeeker) *downstreamSeeker {
+	return &downstreamSeeker{
+		parent: parent,
+		pos:    0,
+	}
+}
+
+func (s *downstreamSeeker) Read(b []byte) (int, error) {
+	s.parent.mutex.Lock()
+	defer s.parent.mutex.Unlock()
+	offset, err := s.parent.Seek(s.pos, io.SeekStart)
+	if err != nil {
+		return 0, err
+	}
+	i, err := s.parent.Read(b)
+	s.pos = offset + int64(i)
+	return i, err
+}
+
+func (s *downstreamSeeker) Seek(offset int64, whence int) (int64, error) {
+	s.parent.mutex.Lock()
+	defer s.parent.mutex.Unlock()
+	if whence == io.SeekCurrent {
+		// The parent's offset is shared by all downstream readers, so "current"
+		// must be interpreted relative to this reader's tracked position.
+		offset += s.pos
+		whence = io.SeekStart
+	}
+	offset, err := s.parent.Seek(offset, whence)
+	if err != nil {
+		return s.pos, err
+	}
+	s.pos = offset
+	return s.pos, nil
+}
+
+func (s *downstreamSeeker) Close() error {
+	s.parent.closeWg.Done()
+	return nil
+}
diff --git a/sfstreams.go b/sfstreams.go
index df4d083..4a560b9 100644
--- a/sfstreams.go
+++ b/sfstreams.go
@@ -22,6 +22,17 @@ type Group struct {
 	sf    singleflight.Group
 	mu    sync.Mutex
 	calls map[string][]chan<- io.ReadCloser
+
+	// Normally, the Group copies the work function's returned reader, but in some cases it is
+	// desirable to keep the io.Seeker interface. When this option is set to true, the Group
+	// no longer copies but instead returns proxy io.ReadSeekCloser readers that track their own
+	// read state. Whenever one of those readers seeks or reads, it synchronously does so on the
+	// work function's returned io.ReadSeekCloser. This can become a performance bottleneck when
+	// several call sites read or seek at the same time.
+	//
+	// If this is set to true but the work function doesn't return an io.ReadSeekCloser, the copy
+	// behaviour is used. When false (the default), the copy behaviour is always used.
+	UseSeekers bool
 }
 
 // Do behaves just like singleflight.Group, with the added guarantee that the returned io.ReadCloser
@@ -92,32 +103,46 @@ func (g *Group) doWork(key string, fn func() (io.ReadCloser, error)) func() (interface{}, error) {
 	} else {
 		var zero io.ReadCloser
 		canStream := fnRes != nil && fnRes != zero
-			writers := make([]*io.PipeWriter, 0)
-			for _, ch := range chans {
-				if !canStream {
+
+			delete(g.calls, key) // we've done all we can for this call: clear it before we unlock
+
+			if !canStream {
+				for _, ch := range chans {
 					// This needs to be async to prevent a deadlock
 					go func(ch chan<- io.ReadCloser) {
 						ch <- nil
 					}(ch)
-					continue
 				}
+				return nil, fnErr // we intentionally discard the return value
+			}
 
-				r, w := io.Pipe()
-				writers = append(writers, w)
+			// Fall back to the copy behaviour when seekers are disabled, or when the
+			// work function's reader can't seek (as documented on UseSeekers).
+			if rsc, ok := fnRes.(io.ReadSeekCloser); g.UseSeekers && ok {
+				parent := newParentSeeker(rsc, len(chans))
+				for _, ch := range chans {
+					// This needs to be async to prevent a deadlock
+					go func(ch chan<- io.ReadCloser) {
+						ch <- newSyncSeeker(parent)
+					}(ch)
+				}
+			} else {
+				writers := make([]*io.PipeWriter, len(chans))
+				for i, ch := range chans {
+					r, w := io.Pipe()
+					writers[i] = w
 
-				// This needs to be async to prevent a deadlock
-				go func(r io.ReadCloser, ch chan<- io.ReadCloser) {
-					ch <- newDiscardCloser(r)
-				}(r, ch)
-			}
-			delete(g.calls, key) // we've done all we can for this call: clear it before we unlock
+					// This needs to be async to prevent a deadlock
+					go func(r io.ReadCloser, ch chan<- io.ReadCloser) {
+						ch <- newDiscardCloser(r)
+					}(r, ch)
+				}
 
-			if canStream {
 				// Do the io copy async to prevent holding up other singleflight calls
 				go finishCopy(writers, fnRes)
 			}
-			return nil, fnErr // we discard the return value
+			return nil, fnErr // we intentionally discard the return value
 		}
 	}
 }
diff --git a/sfstreams_test.go b/sfstreams_test.go
index a2a63a1..30e7fd8 100644
--- a/sfstreams_test.go
+++ b/sfstreams_test.go
@@ -10,12 +10,24 @@ import (
 	"time"
 )
 
+type nopReadSeekCloser struct {
+	io.ReadSeeker
+}
+
+func (nopReadSeekCloser) Close() error {
+	return nil
+}
+
+func nopSeekCloser(rs io.ReadSeeker) io.ReadSeekCloser {
+	return nopReadSeekCloser{ReadSeeker: rs}
+}
+
 func makeStream() (key string, expectedBytes int64, src io.ReadCloser) {
 	key = "fake file"
 	b := make([]byte, 16*1024) // 16kb
 	_, _ = rand.Read(b)
-	src = io.NopCloser(bytes.NewBuffer(b))
+	src = nopSeekCloser(bytes.NewReader(b))
 	expectedBytes = int64(len(b))
 	return
@@ -93,7 +105,7 @@ func TestDoDuplicates(t *testing.T) {
 		return src, nil
 	}
 
-	g := &Group{}
+	g := new(Group)
 	readFn := func() {
 		defer workWg2.Done()
 		workWg1.Done()
@@ -323,7 +335,7 @@ func TestStallOnRead(t *testing.T) {
 		return src, nil
 	}
 
-	g := &Group{}
+	g := new(Group)
 	readFn := func(i int) {
 		defer workWg2.Done()
 		workWg1.Done()
@@ -386,7 +398,7 @@ func TestFasterRead(t *testing.T) {
 		return src, nil
 	}
 
-	g := &Group{}
+	g := new(Group)
 	readFn := func(i int) {
 		defer workWg2.Done()
 		workWg1.Done()
@@ -437,3 +449,143 @@
 		t.Errorf("Expected between 1 and %d calls, got %d", max-1, callCount)
 	}
 }
+
+func TestReturnsNoSeekerDefault(t *testing.T) {
+	key, expectedBytes, src := makeStream()
+
+	callCount := 0
+	workFn := func() (io.ReadCloser, error) {
+		callCount++
+		return src, nil
+	}
+
+	g := new(Group)
+	r, err, shared := g.Do(key, workFn)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if shared {
+		t.Error("Expected a non-shared result")
+	}
+	if r == src {
+		t.Error("Reader and source are the same")
+	}
+	if _, ok := r.(io.ReadSeekCloser); ok {
+		t.Error("Expected reader to *not* be a ReadSeekCloser")
*not* be a ReadSeekCloser") + } + + //goland:noinspection GoUnhandledErrorResult + defer r.Close() + c, _ := io.Copy(io.Discard, r) + if c != expectedBytes { + t.Errorf("Read %d bytes but expected %d", c, expectedBytes) + } + + if callCount != 1 { + t.Errorf("Expected 1 call, got %d", callCount) + } +} + +func TestReturnsSeekerWhenEnabled(t *testing.T) { + key, expectedBytes, src := makeStream() + + callCount := 0 + workFn := func() (io.ReadCloser, error) { + callCount++ + return src, nil + } + + g := new(Group) + g.UseSeekers = true + r, err, shared := g.Do(key, workFn) + if err != nil { + t.Fatal(err) + } + if shared { + t.Error("Expected a non-shared result") + } + if r == src { + t.Error("Reader and source are the same") + } + if _, ok := r.(io.ReadSeekCloser); !ok { + t.Error("Expected reader to be a ReadSeekCloser") + } + + //goland:noinspection GoUnhandledErrorResult + defer r.Close() + c, _ := io.Copy(io.Discard, r) + if c != expectedBytes { + t.Errorf("Read %d bytes but expected %d", c, expectedBytes) + } + + if callCount != 1 { + t.Errorf("Expected 1 call, got %d", callCount) + } +} + +func TestSeekerUsesParent(t *testing.T) { + key, expectedBytes, src := makeStream() + + skip := int64(10) + + workWg1 := new(sync.WaitGroup) + workWg2 := new(sync.WaitGroup) + workCh := make(chan int, 1) + callCount := 0 + workFn := func() (io.ReadCloser, error) { + callCount++ + if callCount == 1 { + workWg1.Done() + } + v := <-workCh + workCh <- v + time.Sleep(10 * time.Millisecond) + return src, nil + } + + g := new(Group) + g.UseSeekers = true + readFn := func(i int) { + localSkip := skip + int64(i) + defer workWg2.Done() + workWg1.Done() + r, err, _ := g.Do(key, workFn) + if err != nil { + t.Error(err) + return + } + rsc := r.(io.ReadSeekCloser) + offset, err := rsc.Seek(localSkip, io.SeekStart) + if err != nil { + t.Error(err) + return + } + if offset != localSkip { + t.Errorf("Expected seek to %d instead of %d", localSkip, offset) + } + c, err := io.Copy(io.Discard, rsc) + if err != nil { + t.Error(err) + return + } + if c != (expectedBytes - localSkip) { + t.Errorf("Read %d bytes instead of %d", c, expectedBytes) + } + } + + const max = 10 + workWg1.Add(1) + for i := 0; i < max; i++ { + workWg1.Add(1) + workWg2.Add(1) + go readFn(i) + } + workWg1.Wait() + workCh <- 1 + workWg2.Wait() + if callCount <= 0 || callCount >= max { + t.Errorf("Expected between 1 and %d calls, got %d", max-1, callCount) + } + +}