diff --git a/sfstreams.go b/sfstreams.go index 4a560b9..bdfe9da 100644 --- a/sfstreams.go +++ b/sfstreams.go @@ -125,23 +125,24 @@ func (g *Group) doWork(key string, fn func() (io.ReadCloser, error)) func() (int ch <- newSyncSeeker(parent) }(ch) } + return nil, fnErr // we intentionally discard the return value } - } else { - writers := make([]*io.PipeWriter, len(chans)) - for i, ch := range chans { - r, w := io.Pipe() - writers[i] = w + } - // This needs to be async to prevent a deadlock - go func(r io.ReadCloser, ch chan<- io.ReadCloser) { - ch <- newDiscardCloser(r) - }(r, ch) - } + writers := make([]*io.PipeWriter, len(chans)) + for i, ch := range chans { + r, w := io.Pipe() + writers[i] = w - // Do the io copy async to prevent holding up other singleflight calls - go finishCopy(writers, fnRes) + // This needs to be async to prevent a deadlock + go func(r io.ReadCloser, ch chan<- io.ReadCloser) { + ch <- newDiscardCloser(r) + }(r, ch) } + // Do the io copy async to prevent holding up other singleflight calls + go finishCopy(writers, fnRes) + return nil, fnErr // we intentionally discard the return value } } diff --git a/sfstreams_test.go b/sfstreams_test.go index 30e7fd8..407069b 100644 --- a/sfstreams_test.go +++ b/sfstreams_test.go @@ -486,6 +486,44 @@ func TestReturnsNoSeekerDefault(t *testing.T) { } } +func TestReturnsNoSeekerIfNotGiven(t *testing.T) { + key, expectedBytes, src := makeStream() + src = io.NopCloser(src) // lose the Seek interface + + callCount := 0 + workFn := func() (io.ReadCloser, error) { + callCount++ + return src, nil + } + + g := new(Group) + g.UseSeekers = true + r, err, shared := g.Do(key, workFn) + if err != nil { + t.Fatal(err) + } + if shared { + t.Error("Expected a non-shared result") + } + if r == src { + t.Error("Reader and source are the same") + } + if _, ok := r.(io.ReadSeekCloser); ok { + t.Error("Expected reader to *not* be a ReadSeekCloser") + } + + //goland:noinspection GoUnhandledErrorResult + defer r.Close() + c, _ := io.Copy(io.Discard, r) + if c != expectedBytes { + t.Errorf("Read %d bytes but expected %d", c, expectedBytes) + } + + if callCount != 1 { + t.Errorf("Expected 1 call, got %d", callCount) + } +} + func TestReturnsSeekerWhenEnabled(t *testing.T) { key, expectedBytes, src := makeStream()