diff --git a/README.md b/README.md index e323d08..bd5f2fc 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ func main() { } else { // This shouldn't happen in this example fmt.Println("WARN: Response was not shared!") - } + } // We discard here, but a more realistic handling might be to stream // the response to a user. diff --git a/seeker.go b/seeker.go index 3c164e3..46ccc25 100644 --- a/seeker.go +++ b/seeker.go @@ -1,6 +1,7 @@ package sfstreams import ( + "errors" "io" "sync" ) @@ -42,28 +43,47 @@ type downstreamSeeker struct { io.ReadSeekCloser parent *parentSeeker pos int64 + eof bool + eofPos int64 + closed bool } func newSyncSeeker(parent *parentSeeker) *downstreamSeeker { return &downstreamSeeker{ parent: parent, pos: 0, + eof: false, + eofPos: 0, + closed: false, } } func (s *downstreamSeeker) Read(b []byte) (int, error) { + if s.closed { + return 0, io.ErrClosedPipe + } s.parent.mutex.Lock() defer s.parent.mutex.Unlock() + if s.eof && s.pos == s.eofPos { + return 0, io.EOF + } offset, err := s.parent.Seek(s.pos, io.SeekStart) if err != nil { return 0, err } i, err := s.parent.Read(b) s.pos = offset + int64(i) + if err != nil && errors.Is(err, io.EOF) { + s.eof = true + s.eofPos = s.pos + } return i, err } func (s *downstreamSeeker) Seek(offset int64, whence int) (int64, error) { + if s.closed { + return 0, io.ErrClosedPipe + } s.parent.mutex.Lock() defer s.parent.mutex.Unlock() offset, err := s.parent.Seek(offset, whence) @@ -76,5 +96,6 @@ func (s *downstreamSeeker) Seek(offset int64, whence int) (int64, error) { func (s *downstreamSeeker) Close() error { s.parent.closeWg.Done() + s.closed = true return nil } diff --git a/seeker_test.go b/seeker_test.go new file mode 100644 index 0000000..17faa35 --- /dev/null +++ b/seeker_test.go @@ -0,0 +1,166 @@ +package sfstreams + +import ( + "bytes" + "crypto/rand" + "errors" + "io" + "testing" +) + +type bufCloser struct { + io.ReadSeeker + io.Closer +} + +func (b *bufCloser) Close() error { + return nil // no-op +} + +func createSource(length int64, t *testing.T) (io.ReadSeekCloser, []byte) { + buf := make([]byte, length) + i, err := rand.Read(buf) + if err != nil { + panic(err) + } + if int64(i) != length { + t.Fatal("did not read enough random bytes") + } + return &bufCloser{ReadSeeker: bytes.NewReader(buf)}, buf +} + +func TestDuplicateReads(t *testing.T) { + rsc, b := createSource(1024, t) + ps := newParentSeeker(rsc, 2) + s1 := newSyncSeeker(ps) + s2 := newSyncSeeker(ps) + + // Seek to different places + _, err := s1.Seek(512, io.SeekStart) + if err != nil { + t.Fatal(err) + } + _, err = s2.Seek(128, io.SeekStart) + if err != nil { + t.Fatal(err) + } + + // Read a segment of bytes from each + b1 := make([]byte, 128) + b2 := make([]byte, 128) + _, err = s1.Read(b1) + if err != nil { + t.Fatal(err) + } + _, err = s2.Read(b2) + if err != nil { + t.Fatal(err) + } + + // Compare each segment to ensure the correct thing was read + for i := 512; i < 640; i++ { + if b1[i-512] != b[i] { + t.Fatalf("byte %d in segment 1 is incorrect", i) + } + } + for i := 128; i < 256; i++ { + if b2[i-128] != b[i] { + t.Fatalf("byte %d in segment 2 is incorrect", i) + } + } +} + +func TestOverRead(t *testing.T) { + rsc, _ := createSource(1024, t) + ps := newParentSeeker(rsc, 1) + s1 := newSyncSeeker(ps) + + // Discard the whole stream + _, err := io.Copy(io.Discard, s1) + if err != nil { + t.Fatal(err) + } + + // Read from it again + b := make([]byte, 128) + i, err := s1.Read(b) + if !errors.Is(err, io.EOF) { + t.Fatal(err) + } + if i > 0 { + t.Fatalf("expected to read zero bytes, got %d", i) + } +} + +type badStream struct { + io.ReadSeekCloser + pos int64 + source *bytes.Reader +} + +func (b *badStream) Read(buf []byte) (int, error) { + if b.pos >= int64(b.source.Len()) { + return 0, errors.New("the requested range cannot be satisfied") + } + i, err := b.source.Read(buf) + if !errors.Is(err, io.EOF) && (int64(i)+b.pos) >= int64(b.source.Len()) { + return i, io.EOF + } + return i, err +} + +func (b *badStream) Seek(offset int64, whence int) (int64, error) { + p, err := b.source.Seek(offset, whence) + b.pos = p + return p, err +} + +func (b *badStream) Close() error { + return nil // no-op +} + +func TestImproperSourceOverRead(t *testing.T) { + _, b := createSource(1024, t) + bs := &badStream{source: bytes.NewReader(b)} + ps := newParentSeeker(bs, 1) + s1 := newSyncSeeker(ps) + + // Discard the whole stream + _, err := io.Copy(io.Discard, s1) + if err != nil { + t.Fatal(err) + } + + // Read from it again + rb := make([]byte, 128) + i, err := s1.Read(rb) + if !errors.Is(err, io.EOF) { + t.Fatal(err) + } + if i > 0 { + t.Fatalf("expected to read zero bytes, got %d", i) + } +} + +func TestUseAfterClose(t *testing.T) { + rsc, _ := createSource(1024, t) + ps := newParentSeeker(rsc, 1) + s1 := newSyncSeeker(ps) + + // Close the whole thing + err := s1.Close() + if err != nil { + t.Fatal(err) + } + + // Now try to read/seek from it + b := make([]byte, 128) + _, err = s1.Read(b) + if !errors.Is(err, io.ErrClosedPipe) { + t.Fatal(err) + } + _, err = s1.Seek(12, io.SeekStart) + if !errors.Is(err, io.ErrClosedPipe) { + t.Fatal(err) + } +}