Skip to content

Commit

Permalink
Support a nil return as a stream option
Browse files Browse the repository at this point in the history
  • Loading branch information
turt2live committed Jun 7, 2023
1 parent 1bb8bb3 commit 4369785
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 2 deletions.
15 changes: 13 additions & 2 deletions sfstreams.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,17 @@ func (g *Group) doWork(key string, fn func() (io.ReadCloser, error)) func() (int
if chans, ok := g.calls[key]; !ok {
return nil, errors.New(fmt.Sprintf("expected to find singleflight key \"%s\", but didn't", key))
} else {
skipStream := fnRes == nil
writers := make([]io.Writer, 0) // they're actually PipeWriters, but the MultiWriter doesn't like that...
for _, ch := range chans {
if skipStream {
// This needs to be async to prevent a deadlock
go func(ch chan<- io.ReadCloser) {
ch <- nil
}(ch)
continue
}

r, w := io.Pipe()
writers = append(writers, w) // if `w` becomes a non-PipeWriter, fix `writers` array usage.

Expand All @@ -111,8 +120,10 @@ func (g *Group) doWork(key string, fn func() (io.ReadCloser, error)) func() (int
}
delete(g.calls, key) // we've done all we can for this call: clear it before we unlock

// Do the io copy async to prevent holding up other singleflight calls
go finishCopy(writers, fnRes)
if !skipStream {
// Do the io copy async to prevent holding up other singleflight calls
go finishCopy(writers, fnRes)
}

return nil, nil // we discard the return value
}
Expand Down
53 changes: 53 additions & 0 deletions sfstreams_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,32 @@ func TestDoDuplicates(t *testing.T) {
}
}

func TestDoNilReturn(t *testing.T) {
key, _, _ := makeStream()

callCount := 0
workFn := func() (io.ReadCloser, error) {
callCount++
return nil, nil
}

g := new(Group)
r, err, shared := g.Do(key, workFn)
if err != nil {
t.Fatal(err)
}
if shared {
t.Error("Expected a non-shared result")
}
if r != nil {
t.Error("Expected a nil result")
}

if callCount != 1 {
t.Errorf("Expected 1 call, got %d", callCount)
}
}

func TestDoChan(t *testing.T) {
key, expectedBytes, src := makeStream()

Expand Down Expand Up @@ -182,3 +208,30 @@ func TestDoChanError(t *testing.T) {
t.Error("Expected an error; Expected no reader")
}
}

func TestDoChanNilReturn(t *testing.T) {
key, _, _ := makeStream()

callCount := 0
workFn := func() (io.ReadCloser, error) {
callCount++
return nil, nil
}

g := new(Group)
ch := g.DoChan(key, workFn)
res := <-ch
if res.Err != nil {
t.Fatal(res.Err)
}
if res.Shared {
t.Error("Expected a non-shared result")
}
if res.Reader != nil {
t.Error("Expected a nil result")
}

if callCount != 1 {
t.Errorf("Expected 1 call, got %d", callCount)
}
}

0 comments on commit 4369785

Please sign in to comment.