From 43697859d2d6b5c316b019a320f13e55e8e13f72 Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Tue, 6 Jun 2023 22:01:18 -0600 Subject: [PATCH] Support a nil return as a stream option --- sfstreams.go | 15 ++++++++++++-- sfstreams_test.go | 53 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/sfstreams.go b/sfstreams.go index 8fe5de2..8f39039 100644 --- a/sfstreams.go +++ b/sfstreams.go @@ -99,8 +99,17 @@ func (g *Group) doWork(key string, fn func() (io.ReadCloser, error)) func() (int if chans, ok := g.calls[key]; !ok { return nil, errors.New(fmt.Sprintf("expected to find singleflight key \"%s\", but didn't", key)) } else { + skipStream := fnRes == nil writers := make([]io.Writer, 0) // they're actually PipeWriters, but the MultiWriter doesn't like that... for _, ch := range chans { + if skipStream { + // This needs to be async to prevent a deadlock + go func(ch chan<- io.ReadCloser) { + ch <- nil + }(ch) + continue + } + r, w := io.Pipe() writers = append(writers, w) // if `w` becomes a non-PipeWriter, fix `writers` array usage. @@ -111,8 +120,10 @@ func (g *Group) doWork(key string, fn func() (io.ReadCloser, error)) func() (int } delete(g.calls, key) // we've done all we can for this call: clear it before we unlock - // Do the io copy async to prevent holding up other singleflight calls - go finishCopy(writers, fnRes) + if !skipStream { + // Do the io copy async to prevent holding up other singleflight calls + go finishCopy(writers, fnRes) + } return nil, nil // we discard the return value } diff --git a/sfstreams_test.go b/sfstreams_test.go index 2cd2448..dce0346 100644 --- a/sfstreams_test.go +++ b/sfstreams_test.go @@ -127,6 +127,32 @@ func TestDoDuplicates(t *testing.T) { } } +func TestDoNilReturn(t *testing.T) { + key, _, _ := makeStream() + + callCount := 0 + workFn := func() (io.ReadCloser, error) { + callCount++ + return nil, nil + } + + g := new(Group) + r, err, shared := g.Do(key, workFn) + if err != nil { + t.Fatal(err) + } + if shared { + t.Error("Expected a non-shared result") + } + if r != nil { + t.Error("Expected a nil result") + } + + if callCount != 1 { + t.Errorf("Expected 1 call, got %d", callCount) + } +} + func TestDoChan(t *testing.T) { key, expectedBytes, src := makeStream() @@ -182,3 +208,30 @@ func TestDoChanError(t *testing.T) { t.Error("Expected an error; Expected no reader") } } + +func TestDoChanNilReturn(t *testing.T) { + key, _, _ := makeStream() + + callCount := 0 + workFn := func() (io.ReadCloser, error) { + callCount++ + return nil, nil + } + + g := new(Group) + ch := g.DoChan(key, workFn) + res := <-ch + if res.Err != nil { + t.Fatal(res.Err) + } + if res.Shared { + t.Error("Expected a non-shared result") + } + if res.Reader != nil { + t.Error("Expected a nil result") + } + + if callCount != 1 { + t.Errorf("Expected 1 call, got %d", callCount) + } +}