Commit

Use a custom async multiwriter
turt2live committed Aug 20, 2023
1 parent c15cc78 commit 063a7f1
Showing 1 changed file with 115 additions and 12 deletions.
sfstreams.go — 127 changes: 115 additions & 12 deletions
@@ -1,6 +1,7 @@
package sfstreams

import (
"errors"
"fmt"
"io"
"sync"
@@ -90,7 +91,7 @@ func (g *Group) doWork(key string, fn func() (io.ReadCloser, error)) func() (int
return nil, fmt.Errorf("expected to find singleflight key \"%s\", but didn't", key)
} else {
skipStream := fnRes == nil
writers := make([]io.Writer, 0) // they're actually PipeWriters, but the MultiWriter doesn't like that...
writers := make([]*io.PipeWriter, 0)
for _, ch := range chans {
if skipStream {
// This needs to be async to prevent a deadlock
@@ -101,7 +102,7 @@ func (g *Group) doWork(key string, fn func() (io.ReadCloser, error)) func() (int
}

r, w := io.Pipe()
writers = append(writers, w) // if `w` becomes a non-PipeWriter, fix `writers` array usage.
writers = append(writers, w)

// This needs to be async to prevent a deadlock
go func(r io.ReadCloser, ch chan<- io.ReadCloser) {
@@ -120,20 +121,16 @@ func (g *Group) doWork(key string, fn func() (io.ReadCloser, error)) func() (int
}
}

func finishCopy(writers []io.Writer, fnRes io.ReadCloser) {
func finishCopy(writers []*io.PipeWriter, fnRes io.ReadCloser) {
defer func(fnRes io.ReadCloser) {
_ = fnRes.Close()
}(fnRes)
mw := io.MultiWriter(writers...)

// Dev note: Errors are raised through the pipe writers using CloseWithError, which
// should make them available on the pipe readers. We can consume them here.
mw := newAsyncMultiWriter(writers...)
_, copyErr := io.Copy(mw, fnRes)
for _, w := range writers {
cw := w.(*io.PipeWriter) // guaranteed with above code in doWork
if copyErr != nil {
_ = cw.CloseWithError(copyErr)
} else {
_ = cw.Close()
}
}
_ = mw.CloseWithMaybeError(copyErr)
}

// discardCloser discards any remaining data on the underlying reader on close.
@@ -157,3 +154,109 @@ func (d *discardCloser) Close() error {
}
return d.r.Close()
}

type asyncMultiWriter struct {
io.WriteCloser
writers []*io.PipeWriter
skipFlags []bool
mu *sync.Mutex
}

func newAsyncMultiWriter(writers ...*io.PipeWriter) *asyncMultiWriter {
return &asyncMultiWriter{
writers: writers,
skipFlags: make([]bool, len(writers)),
mu: new(sync.Mutex),
}
}

type writeResponse struct {
i int
err error
n int
}

func (a *asyncMultiWriter) Write(p []byte) (int, error) {
a.mu.Lock()
defer a.mu.Unlock()

ch := make(chan *writeResponse, len(a.writers))
wg := new(sync.WaitGroup)
c := 0
for i := range a.writers {
if a.skipFlags[i] {
continue
}
wg.Add(1)
c += 1
go func(i int, p []byte, a *asyncMultiWriter, ch chan *writeResponse) {
defer wg.Done()
w := a.writers[i]
n, err := w.Write(p)
ch <- &writeResponse{
i: i,
err: err,
n: n,
}
}(i, p, a, ch)
}

wg.Wait()

maxRead := 0
for i := 0; i < c; i++ {
res := <-ch
if res.n > maxRead {
maxRead = res.n
}
if res.err != nil {
w := a.writers[res.i]
_ = w.CloseWithError(res.err)
a.skipFlags[res.i] = true // mark the writer that actually failed so later writes skip it
}
}

return maxRead, nil
}

func (a *asyncMultiWriter) Close() error {
a.mu.Lock()
defer a.mu.Unlock()

errs := make([]error, 0)
for i, w := range a.writers {
a.skipFlags[i] = true
err := w.Close()
if err != nil {
errs = append(errs, err)
}
}

if len(errs) > 0 {
return errors.Join(errs...)
}
return nil
}

func (a *asyncMultiWriter) CloseWithMaybeError(inError error) error {
if inError == nil {
return a.Close()
}

a.mu.Lock()
defer a.mu.Unlock()

errs := make([]error, 0)
for i, w := range a.writers {
a.skipFlags[i] = true
err := w.CloseWithError(inError)
if err != nil {
errs = append(errs, err)
}
}

if len(errs) > 0 {
return errors.Join(errs...)
}
return nil
}
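The dev note added to finishCopy leans on a property of io.Pipe: an error handed to PipeWriter.CloseWithError is what reads on the matching PipeReader return instead of io.EOF. A minimal standalone sketch of that mechanism, using only the standard library (not this package's API):

package main

import (
    "errors"
    "fmt"
    "io"
)

func main() {
    r, w := io.Pipe()

    // Producer fails partway through and closes its end with the error.
    go func() {
        _, _ = w.Write([]byte("partial data"))
        _ = w.CloseWithError(errors.New("upstream failed"))
    }()

    // Consumer receives the data that made it through, then the producer's
    // error in place of io.EOF.
    data, err := io.ReadAll(r)
    fmt.Printf("read %q, err: %v\n", data, err)
    // Prints: read "partial data", err: upstream failed
}

finishCopy applies the same idea per consumer: a copy error is pushed to every remaining pipe writer via CloseWithMaybeError, so each blocked reader unblocks with the original error rather than a silent EOF.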

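As an end-to-end sketch of the new writer's behavior, a hypothetical internal test (illustrative only; the test name and payload are not part of this commit) could fan a write out to two pipes, close one reader early, and confirm the surviving reader still gets the full stream while the failed pipe is skipped:

package sfstreams

import (
    "bytes"
    "io"
    "testing"
)

func TestAsyncMultiWriterSkipsFailedPipes(t *testing.T) {
    r1, w1 := io.Pipe()
    r2, w2 := io.Pipe()

    // Reader 1 gives up immediately; writes to w1 will fail with io.ErrClosedPipe.
    _ = r1.Close()

    // Reader 2 drains in the background; otherwise pipe writes would block forever.
    got := make(chan []byte, 1)
    go func() {
        b, _ := io.ReadAll(r2)
        got <- b
    }()

    mw := newAsyncMultiWriter(w1, w2)
    payload := []byte("hello")

    // The first write flags w1 as failed; the second write should skip it.
    for i := 0; i < 2; i++ {
        if _, err := mw.Write(payload); err != nil {
            t.Fatalf("write %d: %v", i, err)
        }
    }
    if err := mw.CloseWithMaybeError(nil); err != nil {
        t.Fatalf("close: %v", err)
    }

    if want := bytes.Repeat(payload, 2); !bytes.Equal(<-got, want) {
        t.Fatalf("reader 2 did not receive the full stream")
    }
}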