diff --git a/modules/zstd/zstd.go b/modules/zstd/zstd.go index 4b7cdea2b1dd3..be5c01be63f58 100644 --- a/modules/zstd/zstd.go +++ b/modules/zstd/zstd.go @@ -10,73 +10,49 @@ import ( "github.com/klauspost/compress/zstd" ) -type Writer zstd.Encoder +type Writer struct { + enc *zstd.Encoder -var _ io.WriteCloser = (*Writer)(nil) - -func NewWriter(w io.Writer, opts ...WriterOption) (*Writer, error) { - zstdW, err := zstd.NewWriter(w, opts...) - if err != nil { - return nil, err - } - return (*Writer)(zstdW), nil -} - -func (w *Writer) Write(p []byte) (int, error) { - return (*zstd.Encoder)(w).Write(p) -} - -func (w *Writer) Close() error { - return (*zstd.Encoder)(w).Close() + skw seekable.Writer + buf []byte + n int } -type Reader zstd.Decoder - -var _ io.ReadCloser = (*Reader)(nil) +var _ io.WriteCloser = (*Writer)(nil) -func NewReader(r io.Reader, opts ...ReaderOption) (*Reader, error) { - zstdR, err := zstd.NewReader(r, opts...) +func NewWriter(w io.Writer, opts ...WriterOption) (*Writer, error) { + enc, err := zstd.NewWriter(w, opts...) if err != nil { return nil, err } - return (*Reader)(zstdR), nil -} - -func (r *Reader) Read(p []byte) (int, error) { - return (*zstd.Decoder)(r).Read(p) -} - -func (r *Reader) Close() error { - (*zstd.Decoder)(r).Close() // no error returned - return nil -} - -type SeekableWriter struct { - buf []byte - n int - w seekable.Writer + return &Writer{ + enc: enc, + }, nil } -var _ io.WriteCloser = (*SeekableWriter)(nil) - -func NewSeekableWriter(w io.Writer, blockSize int, opts ...WriterOption) (*SeekableWriter, error) { - zstdW, err := zstd.NewWriter(nil, opts...) +func NewSeekableWriter(w io.Writer, blockSize int, opts ...WriterOption) (*Writer, error) { + enc, err := zstd.NewWriter(nil, opts...) if err != nil { return nil, err } - seekableW, err := seekable.NewWriter(w, zstdW) + skw, err := seekable.NewWriter(w, enc) if err != nil { return nil, err } - return &SeekableWriter{ + return &Writer{ + enc: enc, + skw: skw, buf: make([]byte, blockSize), - w: seekableW, }, nil } -func (w *SeekableWriter) Write(p []byte) (int, error) { +func (w *Writer) Write(p []byte) (int, error) { + if w.skw != nil { + return w.enc.Write(p) + } + written := 0 for len(p) > 0 { n := copy(w.buf[w.n:], p) @@ -85,7 +61,7 @@ func (w *SeekableWriter) Write(p []byte) (int, error) { p = p[n:] if w.n == len(w.buf) { - if _, err := w.w.Write(w.buf); err != nil { + if _, err := w.skw.Write(w.buf); err != nil { return written, err } w.n = 0 @@ -94,13 +70,48 @@ func (w *SeekableWriter) Write(p []byte) (int, error) { return written, nil } -func (w *SeekableWriter) Close() error { - if w.n > 0 { - if _, err := w.w.Write(w.buf[:w.n]); err != nil { +func (w *Writer) Close() error { + if w.skw != nil { + if w.n > 0 { + if _, err := w.skw.Write(w.buf[:w.n]); err != nil { + return err + } + } + if err := w.skw.Close(); err != nil { return err } } - return w.w.Close() + return w.enc.Close() +} + +type Reader struct { + dec *zstd.Decoder + skr seekable.Reader +} + +var _ io.ReadCloser = (*Reader)(nil) + +func NewReader(r io.Reader, opts ...ReaderOption) (*Reader, error) { + dec, err := zstd.NewReader(r, opts...) + if err != nil { + return nil, err + } + return &Reader{ + dec: dec, + }, nil +} + +func (r *Reader) Read(p []byte) (int, error) { + return r.dec.Read(p) +} + +func (r *Reader) Close() error { + r.dec.Close() // no error returned + return nil +} + +func (r *Reader) SeekReader() (seekable.Reader, error) { + return r.skr } type SeekableReader seekable.Reader