Skip to content

Commit

Permalink
Ensure streams are copied even if an error is returned
Browse files Browse the repository at this point in the history
  • Loading branch information
turt2live committed Jun 10, 2023
1 parent fb08c6a commit e70afe4
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 12 deletions.
14 changes: 2 additions & 12 deletions sfstreams.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ func (g *Group) Do(key string, fn func() (io.ReadCloser, error)) (reader io.Read
g.mu.Unlock()

res := <-valCh
if res.Err != nil {
return nil, res.Err, res.Shared
}
return <-resCh, nil, res.Shared
return <-resCh, res.Err, res.Shared
}

// DoChan runs Group.Do, but returns a channel that will receive the results/stream when ready.
Expand Down Expand Up @@ -86,13 +83,6 @@ func (g *Group) doWork(key string, fn func() (io.ReadCloser, error)) func() (int
return func() (interface{}, error) {
fnRes, fnErr := fn()

if fnErr != nil {
g.mu.Lock()
delete(g.calls, key)
g.mu.Unlock()
return nil, fnErr
}

g.mu.Lock()
defer g.mu.Unlock()
g.sf.Forget(key) // we won't be processing future calls, so wrap it up
Expand Down Expand Up @@ -125,7 +115,7 @@ func (g *Group) doWork(key string, fn func() (io.ReadCloser, error)) func() (int
go finishCopy(writers, fnRes)
}

return nil, nil // we discard the return value
return nil, fnErr // we discard the return value
}
}
}
Expand Down
69 changes: 69 additions & 0 deletions sfstreams_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,40 @@ func TestDoNilReturn(t *testing.T) {
}
}

func TestDoErrorAndStream(t *testing.T) {
key, expectedBytes, src := makeStream()
expectedErr := errors.New("this is an error")

callCount := 0
workFn := func() (io.ReadCloser, error) {
callCount++
return src, expectedErr
}

g := new(Group)
r, err, shared := g.Do(key, workFn)
if err != expectedErr {
t.Error("Expected a different error")
}
if shared {
t.Error("Expected a non-shared result")
}
if r == src {
t.Error("Reader and source are the same")
}

//goland:noinspection GoUnhandledErrorResult
defer r.Close()
c, _ := io.Copy(io.Discard, r)
if c != expectedBytes {
t.Errorf("Read %d bytes but expected %d", c, expectedBytes)
}

if callCount != 1 {
t.Errorf("Expected 1 call, got %d", callCount)
}
}

func TestDoChan(t *testing.T) {
key, expectedBytes, src := makeStream()

Expand Down Expand Up @@ -235,3 +269,38 @@ func TestDoChanNilReturn(t *testing.T) {
t.Errorf("Expected 1 call, got %d", callCount)
}
}

func TestDoChanErrorAndStream(t *testing.T) {
key, expectedBytes, src := makeStream()
expectedError := errors.New("this is an error")

callCount := 0
workFn := func() (io.ReadCloser, error) {
callCount++
return src, expectedError
}

g := new(Group)
ch := g.DoChan(key, workFn)
res := <-ch
if res.Err != expectedError {
t.Error("Expected a different error")
}
if res.Shared {
t.Error("Expected a non-shared result")
}
if res.Reader == src {
t.Error("Reader and source are the same")
}

//goland:noinspection GoUnhandledErrorResult
defer res.Reader.Close()
c, _ := io.Copy(io.Discard, res.Reader)
if c != expectedBytes {
t.Errorf("Read %d bytes but expected %d", c, expectedBytes)
}

if callCount != 1 {
t.Errorf("Expected 1 call, got %d", callCount)
}
}

0 comments on commit e70afe4

Please sign in to comment.