From 1afef855a533807deefc8d50ea2a489b0dfdbcaf Mon Sep 17 00:00:00 2001 From: Joel Howse Date: Sat, 19 Feb 2022 15:00:06 +1300 Subject: [PATCH 1/7] [feat] don't require io.Seeker for identify --- formats.go | 110 +++++++++++++++++++++++------------------- formats_test.go | 76 +++++++++++++++++++++++++++++ fs.go | 2 +- header_reader.go | 96 ++++++++++++++++++++++++++++++++++++ header_reader_test.go | 40 +++++++++++++++ 5 files changed, 274 insertions(+), 50 deletions(-) create mode 100644 formats_test.go create mode 100644 header_reader.go create mode 100644 header_reader_test.go diff --git a/formats.go b/formats.go index c80b31e4..922e85d2 100644 --- a/formats.go +++ b/formats.go @@ -2,6 +2,7 @@ package archiver import ( "context" + "errors" "fmt" "io" "strings" @@ -24,10 +25,15 @@ func RegisterFormat(format Format) { // value can be type-asserted to ascertain its capabilities. // // If no matching formats were found, special error ErrNoMatch is returned. -func Identify(filename string, stream io.ReadSeeker) (Format, error) { +// +// The returned io.Reader will always be non-nil, and will read from the same point +// as the reader which was passed in. +func Identify(filename string, stream io.Reader) (Format, io.Reader, error) { var compression Compression var archival Archival + headerReader := newHeaderReader(stream) + // try compression format first, since that's the outer "layer" for name, format := range formats { cf, isCompression := format.(Compression) @@ -35,9 +41,9 @@ func Identify(filename string, stream io.ReadSeeker) (Format, error) { continue } - matchResult, err := identifyOne(format, filename, stream, nil) + matchResult, err := identifyOne(format, filename, headerReader, nil) if err != nil { - return nil, fmt.Errorf("matching %s: %w", name, err) + return nil, headerReader.Reader(), fmt.Errorf("matching %s: %w", name, err) } // if matched, wrap input stream with decompression @@ -49,52 +55,39 @@ func Identify(filename string, stream io.ReadSeeker) (Format, error) { } // try archive format next - for name, format := range formats { - af, isArchive := format.(Archival) - if !isArchive { - continue - } + // for name, format := range formats { + // af, isArchive := format.(Archival) + // if !isArchive { + // continue + // } - matchResult, err := identifyOne(format, filename, stream, compression) - if err != nil { - return nil, fmt.Errorf("matching %s: %w", name, err) - } + // matchResult, err := identifyOne(format, filename, headerReader, compression) + // if err != nil { + // return nil, headerReader.Reader(), fmt.Errorf("matching %s: %w", name, err) + // } - if matchResult.Matched() { - archival = af - break - } - } + // if matchResult.Matched() { + // archival = af + // break + // } + // } + // the stream should be rewound by identifyOne + streamOut := headerReader.Reader() switch { case compression != nil && archival == nil: - return compression, nil + return compression, streamOut, nil case compression == nil && archival != nil: - return archival, nil + return archival, streamOut, nil case compression != nil && archival != nil: - return CompressedArchive{compression, archival}, nil + return CompressedArchive{compression, archival}, streamOut, nil default: - return nil, ErrNoMatch + return nil, streamOut, ErrNoMatch } } -func identifyOne(format Format, filename string, stream io.ReadSeeker, comp Compression) (MatchResult, error) { - if stream == nil { - // shimming an empty stream is easier than hoping every format's - // implementation of Match() expects and handles a nil stream - stream = strings.NewReader("") - } - - // reset stream position to beginning, then restore current position when done - previousOffset, err := stream.Seek(0, io.SeekCurrent) - if err != nil { - return MatchResult{}, err - } - _, err = stream.Seek(0, io.SeekStart) - if err != nil { - return MatchResult{}, err - } - defer stream.Seek(previousOffset, io.SeekStart) +func identifyOne(format Format, filename string, stream *headerReader, comp Compression) (mr MatchResult, err error) { + defer stream.Rewind() // if looking within a compressed format, wrap the stream in a // reader that can decompress it so we can match the "inner" format @@ -102,21 +95,22 @@ func identifyOne(format Format, filename string, stream io.ReadSeeker, comp Comp // because we reset/seek the stream each time and that can mess up // the compression reader's state if we don't discard it also) if comp != nil { - decompressedStream, err := comp.OpenReader(stream) - if err != nil { - return MatchResult{}, err + decompressedStream, openErr := comp.OpenReader(stream) + if openErr != nil { + return MatchResult{}, openErr } defer decompressedStream.Close() - stream = struct { - io.Reader - io.Seeker - }{ - Reader: decompressedStream, - Seeker: stream, - } + mr, err = format.Match(filename, decompressedStream) + } else { + mr, err = format.Match(filename, stream) } - return format.Match(filename, stream) + // if the error is EOF, we can just ignore it. + // Just means we have a small input file. + if errors.Is(err, io.EOF) { + err = nil + } + return mr, err } // CompressedArchive combines a compression format on top of an archive @@ -230,6 +224,24 @@ type MatchResult struct { // Matched returns true if a match was made by either name or stream. func (mr MatchResult) Matched() bool { return mr.ByName || mr.ByStream } +// Compare match results returning 0 if the values are the same, 1 if this match +// ir stronger than the other, and -1 if the other match is stronger. +func (mr MatchResult) compare(other MatchResult) int { + if mr.ByStream && !other.ByStream { + return 1 + } + if other.ByStream && !mr.ByStream { + return -1 + } + if mr.ByName && !other.ByName { + return 1 + } + if other.ByName && !mr.ByName { + return -1 + } + return 0 +} + // ErrNoMatch is returned if there are no matching formats. var ErrNoMatch = fmt.Errorf("no formats matched") diff --git a/formats_test.go b/formats_test.go new file mode 100644 index 00000000..de02a378 --- /dev/null +++ b/formats_test.go @@ -0,0 +1,76 @@ +package archiver + +import ( + "bytes" + "io" + "math/rand" + "testing" + "time" +) + +func checkErr(t *testing.T, err error, msgFmt string, args ...interface{}) { + t.Helper() + if err == nil { + return + } + args = append(args, err) + t.Fatalf(msgFmt+": %s", args...) +} + +func TestCompression(t *testing.T) { + seed := time.Now().UnixNano() + t.Logf("seed: %d", seed) + r := rand.New(rand.NewSource(seed)) + + contents := make([]byte, 1024) + r.Read(contents) + + compressed := new(bytes.Buffer) + + testOK := func(t *testing.T, comp Compression, testFilename string) { + // compress into buffer + compressed.Reset() + wc, err := comp.OpenWriter(compressed) + checkErr(t, err, "opening writer") + _, err = wc.Write(contents) + checkErr(t, err, "writing contents") + checkErr(t, wc.Close(), "closing writer") + + format, stream, err := Identify(testFilename, compressed) + checkErr(t, err, "identifying") + + if format.Name() != comp.Name() { + t.Fatalf("expected format %s but got %s", comp.Name(), format.Name()) + } + + decompReader, err := format.(Decompressor).OpenReader(stream) + checkErr(t, err, "opening with decompressor '%s'", format.Name()) + + data, err := io.ReadAll(decompReader) + checkErr(t, err, "reading decompressed data") + checkErr(t, decompReader.Close(), "closing decompressor") + + if !bytes.Equal(data, contents) { + t.Fatalf("not equal to original") + } + } + + var cannotIdentifyFromStream = map[string]bool{Brotli{}.Name(): true} + + for _, f := range formats { + // only test compressors + comp, ok := f.(Compression) + if !ok { + continue + } + + t.Run(f.Name()+"_with_extension", func(t *testing.T) { + testOK(t, comp, "file"+f.Name()) + }) + if !cannotIdentifyFromStream[f.Name()] { + t.Run(f.Name()+"_without_extension", func(t *testing.T) { + testOK(t, comp, "") + }) + } + } +} diff --git a/fs.go b/fs.go index 274595af..eb819d75 100644 --- a/fs.go +++ b/fs.go @@ -46,7 +46,7 @@ func FileSystem(root string) (fs.FS, error) { return nil, err } defer file.Close() - format, err := Identify(filepath.Base(root), file) + format, _, err := Identify(filepath.Base(root), file) if err != nil && !errors.Is(err, ErrNoMatch) { return nil, err } diff --git a/header_reader.go b/header_reader.go new file mode 100644 index 00000000..23260937 --- /dev/null +++ b/header_reader.go @@ -0,0 +1,96 @@ +package archiver + +import ( + "bytes" + "errors" + "io" + "strings" +) + +var ( + errReaderFrozen = errors.New("Reader() has been called and reads are now frozen") +) + +// headerReader will read from an underlying reader but buffer all the calls +// to Read(). You are then able to reset the reader by calling Rewind() which is equivalent +// to Seek(0,0). This reader does not implement the io.Seeker interface because any other calls +// to Seek would be inefficient and would not be supported by this reader. +// +// Once the header has been read and rewound as much as you would like, call Reader() to +// get a reader that will no longer buffer calls to read. The internal buffer would +// be drained then calls would be redirected back to the underlying reader. +// When calling Reader(), the returned reader will read from the current cursor position. +// Call Rewind() first to reset the cursor to the start of the stream. +type headerReader struct { + pos int + buf []byte + + // sticky error + err error + + r io.Reader +} + +func newHeaderReader(r io.Reader) *headerReader { + const initialBufferSize = 128 + + // make sure the underlying reader is non-nil + if r == nil { + r = strings.NewReader("") + } + + return &headerReader{ + buf: make([]byte, 0, initialBufferSize), + r: r, + } +} + +func (s *headerReader) Read(data []byte) (n int, err error) { + if s.err != nil && s.err != io.EOF { + return 0, s.err + } + + // if this read is asking for more data than we have buffered + // then load more data from the underlying reader into the buffer + if s.pos+len(data) > len(s.buf) { + s.readUptoNMore(s.pos + len(data) - len(s.buf)) + } + + // copy whats in the buffer into the data slice + n = copy(data, s.buf[s.pos:]) + s.pos += n + + return n, s.err +} + +// Rewind sets the pointer back to the start of the stream. +// Any following calls to Read will come from the start of the stream again +func (s *headerReader) Rewind() { s.pos = 0 } + +// Reader returns a reader which will read from the current position in +// the buffer onwards. Use Rewind() first to reset to the start of the +// stream. +// +// Once this function has been called, any subsequent reads to the stream +// header reader will result in ErrReaderFrozen being returned. +func (s *headerReader) Reader() io.Reader { + s.err = errReaderFrozen + return io.MultiReader(bytes.NewReader(s.buf[s.pos:]), s.r) +} + +// readUptoNMore will read at most n more bytes from the underlying +// reader, storing them into the buffer. The position will not be +// updated but the buffer will be grown. +func (s *headerReader) readUptoNMore(n int) { + // grow the buffer by the amount of additional data we need + l := len(s.buf) + s.buf = append(s.buf, make([]byte, n)...) + + // We could call io.ReadFull here, but instead just let the + // behaviour of the underlying reader determine how the reads + // are handled. + n, s.err = s.r.Read(s.buf[l:]) + + // if we read less, make sure the buffer is trimmed + s.buf = s.buf[:l+n] +} diff --git a/header_reader_test.go b/header_reader_test.go new file mode 100644 index 00000000..4441ba4b --- /dev/null +++ b/header_reader_test.go @@ -0,0 +1,40 @@ +package archiver + +import ( + "bytes" + "io" + "testing" +) + +func TestStreamHeaderReader(t *testing.T) { + data := []byte("the header\nthe body\n") + + r := newHeaderReader(bytes.NewReader(data)) + + buf := make([]byte, 10) // enough for 'the header' + + // test rewinding reads + for i := 0; i < 10; i++ { + r.Rewind() + _, err := r.Read(buf) + if err != nil { + t.Fatalf("Read failed: %s", err) + } + if string(buf) != "the header" { + t.Fatalf("expected 'the header' but got '%s'", string(buf)) + } + } + + // get the reader from header reader and make sure we can read all of the data out + r.Rewind() + finalReader := r.Reader() + buf = make([]byte, len(data)) + _, err := io.ReadFull(finalReader, buf) + if err != nil { + t.Fatalf("ReadFull failed: %s", err) + } + + if string(buf) != string(data) { + t.Fatalf("expected '%s' but got '%s'", string(data), string(buf)) + } +} From b8c94ac953850151d90b5462e76ba3ef4fdd32f9 Mon Sep 17 00:00:00 2001 From: Joel Howse Date: Sat, 19 Feb 2022 15:04:28 +1300 Subject: [PATCH 2/7] tidy up --- formats.go | 46 ++++++++++++++-------------------------------- formats_test.go | 5 ++--- 2 files changed, 16 insertions(+), 35 deletions(-) diff --git a/formats.go b/formats.go index 922e85d2..233313bb 100644 --- a/formats.go +++ b/formats.go @@ -55,22 +55,22 @@ func Identify(filename string, stream io.Reader) (Format, io.Reader, error) { } // try archive format next - // for name, format := range formats { - // af, isArchive := format.(Archival) - // if !isArchive { - // continue - // } + for name, format := range formats { + af, isArchive := format.(Archival) + if !isArchive { + continue + } - // matchResult, err := identifyOne(format, filename, headerReader, compression) - // if err != nil { - // return nil, headerReader.Reader(), fmt.Errorf("matching %s: %w", name, err) - // } + matchResult, err := identifyOne(format, filename, headerReader, compression) + if err != nil { + return nil, headerReader.Reader(), fmt.Errorf("matching %s: %w", name, err) + } - // if matchResult.Matched() { - // archival = af - // break - // } - // } + if matchResult.Matched() { + archival = af + break + } + } // the stream should be rewound by identifyOne streamOut := headerReader.Reader() @@ -224,24 +224,6 @@ type MatchResult struct { // Matched returns true if a match was made by either name or stream. func (mr MatchResult) Matched() bool { return mr.ByName || mr.ByStream } -// Compare match results returning 0 if the values are the same, 1 if this match -// ir stronger than the other, and -1 if the other match is stronger. -func (mr MatchResult) compare(other MatchResult) int { - if mr.ByStream && !other.ByStream { - return 1 - } - if other.ByStream && !mr.ByStream { - return -1 - } - if mr.ByName && !other.ByName { - return 1 - } - if other.ByName && !mr.ByName { - return -1 - } - return 0 -} - // ErrNoMatch is returned if there are no matching formats. var ErrNoMatch = fmt.Errorf("no formats matched") diff --git a/formats_test.go b/formats_test.go index de02a378..8c0cb3e9 100644 --- a/formats_test.go +++ b/formats_test.go @@ -36,20 +36,19 @@ func TestCompression(t *testing.T) { checkErr(t, err, "writing contents") checkErr(t, wc.Close(), "closing writer") + // make sure Identify correctly chooses this compression method format, stream, err := Identify(testFilename, compressed) checkErr(t, err, "identifying") - if format.Name() != comp.Name() { t.Fatalf("expected format %s but got %s", comp.Name(), format.Name()) } + // read the contents back out and compare decompReader, err := format.(Decompressor).OpenReader(stream) checkErr(t, err, "opening with decompressor '%s'", format.Name()) - data, err := io.ReadAll(decompReader) checkErr(t, err, "reading decompressed data") checkErr(t, decompReader.Close(), "closing decompressor") - if !bytes.Equal(data, contents) { t.Fatalf("not equal to original") } From 1ed903c82f68fb6a0b58d5e2e2fc7af20aea88c7 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Thu, 17 Mar 2022 15:53:33 -0600 Subject: [PATCH 3/7] Refactor and simplify with some bug fixes --- formats.go | 95 ++++++++++++++++++++++++++++++++++++------ formats_test.go | 45 +++++++++++++++++--- header_reader.go | 96 ------------------------------------------- header_reader_test.go | 40 ------------------ 4 files changed, 122 insertions(+), 154 deletions(-) delete mode 100644 header_reader.go delete mode 100644 header_reader_test.go diff --git a/formats.go b/formats.go index 233313bb..0e4d893c 100644 --- a/formats.go +++ b/formats.go @@ -1,6 +1,7 @@ package archiver import ( + "bytes" "context" "errors" "fmt" @@ -32,7 +33,7 @@ func Identify(filename string, stream io.Reader) (Format, io.Reader, error) { var compression Compression var archival Archival - headerReader := newHeaderReader(stream) + rewindableStream := newRewindReader(stream) // try compression format first, since that's the outer "layer" for name, format := range formats { @@ -41,9 +42,9 @@ func Identify(filename string, stream io.Reader) (Format, io.Reader, error) { continue } - matchResult, err := identifyOne(format, filename, headerReader, nil) + matchResult, err := identifyOne(format, filename, rewindableStream, nil) if err != nil { - return nil, headerReader.Reader(), fmt.Errorf("matching %s: %w", name, err) + return nil, rewindableStream.reader(), fmt.Errorf("matching %s: %w", name, err) } // if matched, wrap input stream with decompression @@ -61,9 +62,9 @@ func Identify(filename string, stream io.Reader) (Format, io.Reader, error) { continue } - matchResult, err := identifyOne(format, filename, headerReader, compression) + matchResult, err := identifyOne(format, filename, rewindableStream, compression) if err != nil { - return nil, headerReader.Reader(), fmt.Errorf("matching %s: %w", name, err) + return nil, rewindableStream.reader(), fmt.Errorf("matching %s: %w", name, err) } if matchResult.Matched() { @@ -73,21 +74,21 @@ func Identify(filename string, stream io.Reader) (Format, io.Reader, error) { } // the stream should be rewound by identifyOne - streamOut := headerReader.Reader() + bufferedStream := rewindableStream.reader() switch { case compression != nil && archival == nil: - return compression, streamOut, nil + return compression, bufferedStream, nil case compression == nil && archival != nil: - return archival, streamOut, nil + return archival, bufferedStream, nil case compression != nil && archival != nil: - return CompressedArchive{compression, archival}, streamOut, nil + return CompressedArchive{compression, archival}, bufferedStream, nil default: - return nil, streamOut, ErrNoMatch + return nil, bufferedStream, ErrNoMatch } } -func identifyOne(format Format, filename string, stream *headerReader, comp Compression) (mr MatchResult, err error) { - defer stream.Rewind() +func identifyOne(format Format, filename string, stream *rewindReader, comp Compression) (mr MatchResult, err error) { + defer stream.rewind() // if looking within a compressed format, wrap the stream in a // reader that can decompress it so we can match the "inner" format @@ -224,6 +225,76 @@ type MatchResult struct { // Matched returns true if a match was made by either name or stream. func (mr MatchResult) Matched() bool { return mr.ByName || mr.ByStream } +// rewindReader is a Reader that can be rewound (reset) to re-read what +// was already read and then continue to read more from the underlying +// stream. When no more rewinding is necessary, call reader() to get a +// new reader that first reads the buffered bytes, then continues to +// read from the stream. This is useful for "peeking" a stream an +// arbitrary number of bytes. Loosely based on the Connection type +// from https://github.com/mholt/caddy-l4. +type rewindReader struct { + io.Reader + buf *bytes.Buffer + bufReader io.Reader +} + +func newRewindReader(r io.Reader) *rewindReader { + return &rewindReader{ + Reader: r, + buf: new(bytes.Buffer), + } +} + +func (rr *rewindReader) Read(p []byte) (n int, err error) { + // if there is a buffer we should read from, start + // with that; we only read from the underlying stream + // after the buffer has been "depleted" + if rr.bufReader != nil { + n, err = rr.bufReader.Read(p) + if err == io.EOF { + rr.bufReader = nil + err = nil + } + if n == len(p) { + return + } + } + + // buffer has been "depleted" so read from + // underlying connection + nr, err := rr.Reader.Read(p[n:]) + + // anything that was read needs to be written to + // the buffer, even if there was an error + if nr > 0 { + if nw, errw := rr.buf.Write(p[n : n+nr]); errw != nil { + return nw, errw + } + } + + // up to now, n was how many bytes were read from + // the buffer, and nr was how many bytes were read + // from the stream; add them to return total count + n += nr + + return +} + +// rewind resets the stream to the beginning by causing +// Read() to start reading from the beginning of the +// buffered bytes. +func (rr *rewindReader) rewind() { + rr.bufReader = bytes.NewReader(rr.buf.Bytes()) +} + +// reader returns a reader that reads first from the buffered +// bytes, then from the underlying stream. After calling this, +// no more rewinding is allowed since reads from the stream are +// not recorded, so rewinding properly is impossible. +func (rr *rewindReader) reader() io.Reader { + return io.MultiReader(bytes.NewReader(rr.buf.Bytes()), rr.Reader) +} + // ErrNoMatch is returned if there are no matching formats. var ErrNoMatch = fmt.Errorf("no formats matched") diff --git a/formats_test.go b/formats_test.go index 8c0cb3e9..1b1c3b2d 100644 --- a/formats_test.go +++ b/formats_test.go @@ -4,17 +4,41 @@ import ( "bytes" "io" "math/rand" + "strings" "testing" "time" ) -func checkErr(t *testing.T, err error, msgFmt string, args ...interface{}) { - t.Helper() - if err == nil { - return +func TestRewindReader(t *testing.T) { + data := "the header\nthe body\n" + + r := newRewindReader(strings.NewReader(data)) + + buf := make([]byte, 10) // enough for 'the header' + + // test rewinding reads + for i := 0; i < 10; i++ { + r.rewind() + n, err := r.Read(buf) + if err != nil { + t.Fatalf("Read failed: %s", err) + } + if string(buf[:n]) != "the header" { + t.Fatalf("iteration %d: expected 'the header' but got '%s' (n=%d)", i, string(buf[:n]), n) + } + } + + // get the reader from header reader and make sure we can read all of the data out + r.rewind() + finalReader := r.reader() + buf = make([]byte, len(data)) + n, err := io.ReadFull(finalReader, buf) + if err != nil { + t.Fatalf("ReadFull failed: %s (n=%d)", err, n) + } + if string(buf) != data { + t.Fatalf("expected '%s' but got '%s'", string(data), string(buf)) } - args = append(args, err) - t.Fatalf(msgFmt+": %s", args...) } func TestCompression(t *testing.T) { @@ -73,3 +97,12 @@ func TestCompression(t *testing.T) { } } } + +func checkErr(t *testing.T, err error, msgFmt string, args ...interface{}) { + t.Helper() + if err == nil { + return + } + args = append(args, err) + t.Fatalf(msgFmt+": %s", args...) +} diff --git a/header_reader.go b/header_reader.go deleted file mode 100644 index 23260937..00000000 --- a/header_reader.go +++ /dev/null @@ -1,96 +0,0 @@ -package archiver - -import ( - "bytes" - "errors" - "io" - "strings" -) - -var ( - errReaderFrozen = errors.New("Reader() has been called and reads are now frozen") -) - -// headerReader will read from an underlying reader but buffer all the calls -// to Read(). You are then able to reset the reader by calling Rewind() which is equivalent -// to Seek(0,0). This reader does not implement the io.Seeker interface because any other calls -// to Seek would be inefficient and would not be supported by this reader. -// -// Once the header has been read and rewound as much as you would like, call Reader() to -// get a reader that will no longer buffer calls to read. The internal buffer would -// be drained then calls would be redirected back to the underlying reader. -// When calling Reader(), the returned reader will read from the current cursor position. -// Call Rewind() first to reset the cursor to the start of the stream. -type headerReader struct { - pos int - buf []byte - - // sticky error - err error - - r io.Reader -} - -func newHeaderReader(r io.Reader) *headerReader { - const initialBufferSize = 128 - - // make sure the underlying reader is non-nil - if r == nil { - r = strings.NewReader("") - } - - return &headerReader{ - buf: make([]byte, 0, initialBufferSize), - r: r, - } -} - -func (s *headerReader) Read(data []byte) (n int, err error) { - if s.err != nil && s.err != io.EOF { - return 0, s.err - } - - // if this read is asking for more data than we have buffered - // then load more data from the underlying reader into the buffer - if s.pos+len(data) > len(s.buf) { - s.readUptoNMore(s.pos + len(data) - len(s.buf)) - } - - // copy whats in the buffer into the data slice - n = copy(data, s.buf[s.pos:]) - s.pos += n - - return n, s.err -} - -// Rewind sets the pointer back to the start of the stream. -// Any following calls to Read will come from the start of the stream again -func (s *headerReader) Rewind() { s.pos = 0 } - -// Reader returns a reader which will read from the current position in -// the buffer onwards. Use Rewind() first to reset to the start of the -// stream. -// -// Once this function has been called, any subsequent reads to the stream -// header reader will result in ErrReaderFrozen being returned. -func (s *headerReader) Reader() io.Reader { - s.err = errReaderFrozen - return io.MultiReader(bytes.NewReader(s.buf[s.pos:]), s.r) -} - -// readUptoNMore will read at most n more bytes from the underlying -// reader, storing them into the buffer. The position will not be -// updated but the buffer will be grown. -func (s *headerReader) readUptoNMore(n int) { - // grow the buffer by the amount of additional data we need - l := len(s.buf) - s.buf = append(s.buf, make([]byte, n)...) - - // We could call io.ReadFull here, but instead just let the - // behaviour of the underlying reader determine how the reads - // are handled. - n, s.err = s.r.Read(s.buf[l:]) - - // if we read less, make sure the buffer is trimmed - s.buf = s.buf[:l+n] -} diff --git a/header_reader_test.go b/header_reader_test.go deleted file mode 100644 index 4441ba4b..00000000 --- a/header_reader_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package archiver - -import ( - "bytes" - "io" - "testing" -) - -func TestStreamHeaderReader(t *testing.T) { - data := []byte("the header\nthe body\n") - - r := newHeaderReader(bytes.NewReader(data)) - - buf := make([]byte, 10) // enough for 'the header' - - // test rewinding reads - for i := 0; i < 10; i++ { - r.Rewind() - _, err := r.Read(buf) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - if string(buf) != "the header" { - t.Fatalf("expected 'the header' but got '%s'", string(buf)) - } - } - - // get the reader from header reader and make sure we can read all of the data out - r.Rewind() - finalReader := r.Reader() - buf = make([]byte, len(data)) - _, err := io.ReadFull(finalReader, buf) - if err != nil { - t.Fatalf("ReadFull failed: %s", err) - } - - if string(buf) != string(data) { - t.Fatalf("expected '%s' but got '%s'", string(data), string(buf)) - } -} From fdc991234b27c509e7113730c888dba2a3ab0645 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Thu, 17 Mar 2022 16:03:28 -0600 Subject: [PATCH 4/7] Clarify returned Reader in godoc comment --- formats.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/formats.go b/formats.go index 2c9ff18b..99ce196f 100644 --- a/formats.go +++ b/formats.go @@ -27,8 +27,11 @@ func RegisterFormat(format Format) { // // If no matching formats were found, special error ErrNoMatch is returned. // -// The returned io.Reader will always be non-nil, and will read from the same point -// as the reader which was passed in. +// The returned io.Reader will always be non-nil and will read from the +// same point as the reader which was passed in; it should be used in place +// of the input stream after calling Identify() because it preserves and +// re-reads the bytes that were already read during the identification +// process. func Identify(filename string, stream io.Reader) (Format, io.Reader, error) { var compression Compression var archival Archival From e93d465c1c49be3355cd80e64c1196616300d725 Mon Sep 17 00:00:00 2001 From: Joel Howse Date: Mon, 21 Mar 2022 23:46:28 +1300 Subject: [PATCH 5/7] if underlying reader supports seek use that --- formats.go | 5 +++++ formats_test.go | 23 +++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/formats.go b/formats.go index 99ce196f..fd2e77d3 100644 --- a/formats.go +++ b/formats.go @@ -320,6 +320,11 @@ func (rr *rewindReader) rewind() { // no more rewinding is allowed since reads from the stream are // not recorded, so rewinding properly is impossible. func (rr *rewindReader) reader() io.Reader { + if ras, ok := rr.Reader.(seekReaderAt); ok { + if _, err := ras.Seek(-int64(rr.buf.Len()), io.SeekCurrent); err == nil { + return rr.Reader + } + } return io.MultiReader(bytes.NewReader(rr.buf.Bytes()), rr.Reader) } diff --git a/formats_test.go b/formats_test.go index a1d0eb74..2531ceb1 100644 --- a/formats_test.go +++ b/formats_test.go @@ -387,3 +387,26 @@ func TestIdentifyFindFormatByStreamContent(t *testing.T) { }) } } + +func TestIdentifyAndOpenZip(t *testing.T) { + f, err := os.Open("testdata/test.zip") + checkErr(t, err, "opening zip") + defer f.Close() + + format, reader, err := Identify("test.zip", f) + checkErr(t, err, "identifying zip") + if format.Name() != ".zip" { + t.Fatalf("unexpected format found: expected=.zip actual:%s", format.Name()) + } + + err = format.(Extractor).Extract(context.Background(), reader, nil, func(ctx context.Context, f File) error { + rc, err := f.Open() + if err != nil { + return err + } + defer rc.Close() + _, err = io.ReadAll(rc) + return err + }) + checkErr(t, err, "extracting zip") +} From 7708ac2c64f3682290fd84f88c95f28b350c7995 Mon Sep 17 00:00:00 2001 From: Joel Howse Date: Tue, 22 Mar 2022 07:18:48 +1300 Subject: [PATCH 6/7] update comment --- formats.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/formats.go b/formats.go index fd2e77d3..9b44f40a 100644 --- a/formats.go +++ b/formats.go @@ -319,8 +319,10 @@ func (rr *rewindReader) rewind() { // bytes, then from the underlying stream. After calling this, // no more rewinding is allowed since reads from the stream are // not recorded, so rewinding properly is impossible. +// If the underlying reader implements io.Seeker, then the +// underlying reader will be used directly. func (rr *rewindReader) reader() io.Reader { - if ras, ok := rr.Reader.(seekReaderAt); ok { + if ras, ok := rr.Reader.(io.Seeker); ok { if _, err := ras.Seek(-int64(rr.buf.Len()), io.SeekCurrent); err == nil { return rr.Reader } From fd3004592aad6bf638e3730287729b0589f461a1 Mon Sep 17 00:00:00 2001 From: jhwz <52683873+jhwz@users.noreply.github.com> Date: Tue, 22 Mar 2022 07:59:02 +1300 Subject: [PATCH 7/7] Update formats.go Co-authored-by: Matt Holt --- formats.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/formats.go b/formats.go index 9b44f40a..a63a9691 100644 --- a/formats.go +++ b/formats.go @@ -323,7 +323,7 @@ func (rr *rewindReader) rewind() { // underlying reader will be used directly. func (rr *rewindReader) reader() io.Reader { if ras, ok := rr.Reader.(io.Seeker); ok { - if _, err := ras.Seek(-int64(rr.buf.Len()), io.SeekCurrent); err == nil { + if _, err := ras.Seek(0, io.SeekStart); err == nil { return rr.Reader } }