diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
new file mode 100644
index 0000000..738c75c
--- /dev/null
+++ b/.github/workflows/main.yml
@@ -0,0 +1,17 @@
+name: Go package
+on: [push]
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - name: Setup Go
+        uses: actions/setup-go@v4
+        with:
+          go-version: '1.19'
+      - name: Install dependencies
+        run: go get .
+      - name: Build
+        run: go build
+      - name: Test
+        run: go test
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 3b735ec..7b6fd0f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,3 +19,6 @@
 
 # Go workspace file
 go.work
+
+# Custom
+/.idea
\ No newline at end of file
diff --git a/README.md b/README.md
index 1f8c3e6..612d90d 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,35 @@
 # go-singleflight-streams
-Go library to return a dedicated reader to each singleflight consumer. Useful for reading a source once, but sharing the result with many consumers.
+Go library to return a dedicated reader to each singleflight consumer. Useful for reading a source once,
+but sharing the result with many consumers.
+
+Example usage:
+
+```go
+package main
+
+import (
+	"io"
+	"strings"
+
+	sfstreams "github.com/t2bot/go-singleflight-streams"
+)
+
+func main() {
+	g := new(sfstreams.Group)
+
+	workFn := func() (io.ReadCloser, error) {
+		// do your file download, thumbnailing, whatever here
+		return io.NopCloser(strings.NewReader("some expensive result")), nil
+	}
+
+	// in your various goroutines...
+	r, err, shared := g.Do("string key", workFn)
+	if err != nil {
+		panic(err)
+	}
+	defer r.Close()
+	_ = shared
+
+	// do something with r (it'll be a unique instance)
+}
+```
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..e294ea9
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,5 @@
+module github.com/t2bot/go-singleflight-streams
+
+go 1.19
+
+require golang.org/x/sync v0.2.0
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..4051f4f
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,2 @@
+golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
+golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
diff --git a/sfstreams.go b/sfstreams.go
new file mode 100644
index 0000000..8fe5de2
--- /dev/null
+++ b/sfstreams.go
@@ -0,0 +1,157 @@
+package sfstreams
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"sync"
+
+	"golang.org/x/sync/singleflight"
+)
+
+// ReaderResult carries the return values of Group.Do over the Group.DoChan channel.
+type ReaderResult struct {
+	Err    error
+	Reader io.ReadCloser
+	Shared bool
+}
+
+// Group represents a singleflight stream group. This behaves just like a normal singleflight.Group,
+// but guarantees that a usable (distinct) io.ReadCloser is returned for each call.
+type Group struct {
+	sf    singleflight.Group
+	mu    sync.Mutex
+	calls map[string][]chan<- io.ReadCloser
+}
+
+// Do behaves just like singleflight.Group.Do, with the added guarantee that the returned io.ReadCloser
+// is unique to the caller. Note that this uses io.MultiWriter and io.Pipe instances internally, meaning
+// that if one reader fails then all readers generated by the call will fail. The returned readers
+// discard any unread data upon being closed, preventing a single closed stream from ultimately closing
+// all of the other streams.
+//
+// The io.ReadCloser generated by fn is closed internally.
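+//
+// An illustrative call, where fetch is any caller-supplied func() (io.ReadCloser, error):
+//
+//	r, err, shared := g.Do("cache key", fetch)
+//	if err != nil {
+//		return err
+//	}
+//	defer r.Close() // discards this caller's unread data; other callers are unaffected
+//	_ = shared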
+func (g *Group) Do(key string, fn func() (io.ReadCloser, error)) (reader io.ReadCloser, err error, shared bool) {
+	g.mu.Lock()
+	if g.calls == nil {
+		g.calls = make(map[string][]chan<- io.ReadCloser)
+	}
+	if _, ok := g.calls[key]; !ok {
+		g.calls[key] = make([]chan<- io.ReadCloser, 0)
+	}
+	resCh := make(chan io.ReadCloser)
+	defer close(resCh)
+	g.calls[key] = append(g.calls[key], resCh)
+
+	valCh := g.sf.DoChan(key, g.doWork(key, fn))
+	g.mu.Unlock()
+
+	res := <-valCh
+	if res.Err != nil {
+		return nil, res.Err, res.Shared
+	}
+	return <-resCh, nil, res.Shared
+}
+
+// DoChan runs Group.Do, but returns a channel that will receive the result/stream when ready.
+//
+// The returned channel is not closed.
+func (g *Group) DoChan(key string, fn func() (io.ReadCloser, error)) <-chan ReaderResult {
+	ch := make(chan ReaderResult)
+	go func(ch chan ReaderResult, g *Group) {
+		r, err, shared := g.Do(key, fn)
+		ch <- ReaderResult{
+			Err:    err,
+			Reader: r,
+			Shared: shared,
+		}
+	}(ch, g)
+	return ch
+}
+
+// Forget acts just like singleflight.Group.Forget.
+func (g *Group) Forget(key string) {
+	g.mu.Lock()
+	if chans, ok := g.calls[key]; ok {
+		for _, ch := range chans {
+			close(ch)
+		}
+	}
+	delete(g.calls, key)
+	g.sf.Forget(key)
+	g.mu.Unlock()
+}
+
+func (g *Group) doWork(key string, fn func() (io.ReadCloser, error)) func() (interface{}, error) {
+	return func() (interface{}, error) {
+		fnRes, fnErr := fn()
+
+		if fnErr != nil {
+			g.mu.Lock()
+			delete(g.calls, key)
+			g.mu.Unlock()
+			return nil, fnErr
+		}
+
+		g.mu.Lock()
+		defer g.mu.Unlock()
+		g.sf.Forget(key) // we won't be processing future calls, so wrap it up
+		if chans, ok := g.calls[key]; !ok {
+			return nil, errors.New(fmt.Sprintf("expected to find singleflight key \"%s\", but didn't", key))
+		} else {
+			writers := make([]io.Writer, 0) // they're actually PipeWriters, but the MultiWriter doesn't like that...
+			for _, ch := range chans {
+				r, w := io.Pipe()
+				writers = append(writers, w) // if `w` becomes a non-PipeWriter, fix the `writers` slice usage.
+
+				// This needs to be async to prevent a deadlock
+				go func(r io.ReadCloser, ch chan<- io.ReadCloser) {
+					ch <- NewDiscardCloser(r)
+				}(r, ch)
+			}
+			delete(g.calls, key) // we've done all we can for this call: clear it before we unlock
+
+			// Do the io.Copy async to prevent holding up other singleflight calls
+			go finishCopy(writers, fnRes)
+
+			return nil, nil // we discard the return value
+		}
+	}
+}
+
+func finishCopy(writers []io.Writer, fnRes io.ReadCloser) {
+	//goland:noinspection GoUnhandledErrorResult
+	defer fnRes.Close()
+	mw := io.MultiWriter(writers...)
+	_, copyErr := io.Copy(mw, fnRes)
+	for _, w := range writers {
+		cw := w.(*io.PipeWriter) // guaranteed with above code in doWork
+		if copyErr != nil {
+			_ = cw.CloseWithError(copyErr)
+		} else {
+			_ = cw.Close()
+		}
+	}
+}
+
+// DiscardCloser discards any remaining data on the underlying reader on close.
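+// Draining the remaining data, rather than closing the pipe reader outright, lets the shared copy in
+// finishCopy run to completion when one consumer stops reading early; closing the pipe directly would
+// surface an error through the io.MultiWriter and break every other consumer of the same key.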
+type DiscardCloser struct {
+	io.ReadCloser
+	r io.ReadCloser
+}
+
+// NewDiscardCloser creates a new DiscardCloser from an input io.ReadCloser.
+func NewDiscardCloser(r io.ReadCloser) *DiscardCloser {
+	return &DiscardCloser{r: r}
+}
+
+func (d *DiscardCloser) Read(p []byte) (int, error) {
+	return d.r.Read(p)
+}
+
+func (d *DiscardCloser) Close() error {
+	if _, err := io.Copy(io.Discard, d.r); err != nil {
+		return err
+	}
+	return d.r.Close()
+}
diff --git a/sfstreams_test.go b/sfstreams_test.go
new file mode 100644
index 0000000..2cd2448
--- /dev/null
+++ b/sfstreams_test.go
@@ -0,0 +1,184 @@
+package sfstreams
+
+import (
+	"bytes"
+	"crypto/rand"
+	"errors"
+	"io"
+	"sync"
+	"testing"
+	"time"
+)
+
+func makeStream() (key string, expectedBytes int64, src io.ReadCloser) {
+	key = "fake file"
+
+	b := make([]byte, 16*1024) // 16 KiB
+	_, _ = rand.Read(b)
+	src = io.NopCloser(bytes.NewBuffer(b))
+
+	expectedBytes = int64(len(b))
+	return
+}
+
+func TestDo(t *testing.T) {
+	key, expectedBytes, src := makeStream()
+
+	callCount := 0
+	workFn := func() (io.ReadCloser, error) {
+		callCount++
+		return src, nil
+	}
+
+	g := new(Group)
+	r, err, shared := g.Do(key, workFn)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if shared {
+		t.Error("Expected a non-shared result")
+	}
+	if r == src {
+		t.Error("Reader and source are the same")
+	}
+
+	//goland:noinspection GoUnhandledErrorResult
+	defer r.Close()
+	c, _ := io.Copy(io.Discard, r)
+	if c != expectedBytes {
+		t.Errorf("Read %d bytes but expected %d", c, expectedBytes)
+	}
+
+	if callCount != 1 {
+		t.Errorf("Expected 1 call, got %d", callCount)
+	}
+}
+
+func TestDoError(t *testing.T) {
+	expectedErr := errors.New("this is expected")
+	callCount := 0
+	workFn := func() (io.ReadCloser, error) {
+		callCount++
+		return nil, expectedErr
+	}
+
+	g := new(Group)
+	r, err, shared := g.Do("test", workFn)
+	if err != nil && err != expectedErr {
+		t.Fatal(err)
+	}
+	if shared {
+		t.Error("Expected a non-shared result")
+	}
+	if err == nil || r != nil {
+		t.Error("Expected an error and no reader")
+	}
+}
+
+func TestDoDuplicates(t *testing.T) {
+	key, expectedBytes, src := makeStream()
+
+	workWg1 := new(sync.WaitGroup)
+	workWg2 := new(sync.WaitGroup)
+	workCh := make(chan int, 1)
+	callCount := 0
+	workFn := func() (io.ReadCloser, error) {
+		callCount++
+		if callCount == 1 {
+			workWg1.Done()
+		}
+		v := <-workCh
+		workCh <- v
+		time.Sleep(10 * time.Millisecond)
+		return src, nil
+	}
+
+	g := &Group{}
+	readFn := func() {
+		defer workWg2.Done()
+		workWg1.Done()
+		r, err, _ := g.Do(key, workFn)
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		c, err := io.Copy(io.Discard, r)
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		if c != expectedBytes {
+			t.Errorf("Read %d bytes instead of %d", c, expectedBytes)
+		}
+	}
+
+	const max = 10
+	workWg1.Add(1)
+	for i := 0; i < max; i++ {
+		workWg1.Add(1)
+		workWg2.Add(1)
+		go readFn()
+	}
+	workWg1.Wait()
+	workCh <- 1
+	workWg2.Wait()
+	if callCount <= 0 || callCount >= max {
+		t.Errorf("Expected between 1 and %d calls, got %d", max-1, callCount)
+	}
+}
+
+func TestDoChan(t *testing.T) {
+	key, expectedBytes, src := makeStream()
+
+	callCount := 0
+	workFn := func() (io.ReadCloser, error) {
+		callCount++
+		return src, nil
+	}
+
+	g := new(Group)
+	ch := g.DoChan(key, workFn)
+	res := <-ch
+	if res.Err != nil {
+		t.Fatal(res.Err)
+	}
+	if res.Shared {
+		t.Error("Expected a non-shared result")
+	}
+	if res.Reader == src {
+		t.Error("Reader and source are the same")
+	}
+
+	//goland:noinspection GoUnhandledErrorResult
+	defer res.Reader.Close()
+	c, _ := io.Copy(io.Discard, res.Reader)
+	if c != expectedBytes {
+		t.Errorf("Read %d bytes but expected %d", c, expectedBytes)
+	}
+
+	if callCount != 1 {
+		t.Errorf("Expected 1 call, got %d", callCount)
+	}
+}
+
+func TestDoChanError(t *testing.T) {
+	expectedErr := errors.New("this is expected")
+	callCount := 0
+	workFn := func() (io.ReadCloser, error) {
+		callCount++
+		return nil, expectedErr
+	}
+
+	g := new(Group)
+	ch := g.DoChan("key", workFn)
+	res := <-ch
+	if res.Err != nil && res.Err != expectedErr {
+		t.Fatal(res.Err)
+	}
+	if res.Shared {
+		t.Error("Expected a non-shared result")
+	}
+	if res.Err == nil || res.Reader != nil {
+		t.Error("Expected an error and no reader")
+	}
+}