diff --git a/bz2.go b/bz2.go index b62192d4..57a278f4 100644 --- a/bz2.go +++ b/bz2.go @@ -28,8 +28,8 @@ func (bz Bz2) Match(filename string, stream io.Reader) (MatchResult, error) { } // match file header - buf := make([]byte, len(bzip2Header)) - if _, err := io.ReadFull(stream, buf); err != nil { + buf, err := readAtMost(stream, len(bzip2Header)) + if err != nil { return mr, err } mr.ByStream = bytes.Equal(buf, bzip2Header) diff --git a/formats.go b/formats.go index c80b31e4..5140e218 100644 --- a/formats.go +++ b/formats.go @@ -2,6 +2,7 @@ package archiver import ( "context" + "errors" "fmt" "io" "strings" @@ -119,6 +120,31 @@ func identifyOne(format Format, filename string, stream io.ReadSeeker, comp Comp return format.Match(filename, stream) } +// readAtMost reads at most n bytes from the stream. A nil, empty, or short +// stream is not an error. The returned slice of bytes may have length < n +// without an error. +func readAtMost(stream io.Reader, n int) ([]byte, error) { + if stream == nil || n <= 0 { + return []byte{}, nil + } + + buf := make([]byte, n) + nr, err := io.ReadFull(stream, buf) + + // Return the bytes read if there was no error OR if the + // error was EOF (stream was empty) or UnexpectedEOF (stream + // had less than n). We ignore those errors because we aren't + // required to read the full n bytes; so an empty or short + // stream is not actually an error. + if err == nil || + errors.Is(err, io.EOF) || + errors.Is(err, io.ErrUnexpectedEOF) { + return buf[:nr], nil + } + + return nil, err +} + // CompressedArchive combines a compression format on top of an archive // format (e.g. "tar.gz") and provides both functionalities in a single // type. It ensures that archive functions are wrapped by compressors and diff --git a/formats_test.go b/formats_test.go new file mode 100644 index 00000000..f3021644 --- /dev/null +++ b/formats_test.go @@ -0,0 +1,289 @@ +package archiver + +import ( + "bytes" + "context" + "io" + "io/fs" + "os" + "strings" + "testing" +) + +func TestIdentifyDoesNotMatchContentFromTrimmedKnownHeaderHaving0Suffix(t *testing.T) { + // Using the outcome of `n, err := io.ReadFull(stream, buf)` without minding n + // may lead to a mis-characterization for cases with known header ending with 0x0 + // because the default byte value in a declared array is 0. + // This test guards against those cases. + tests := []struct { + name string + header []byte + }{ + { + name: "rar_v5.0", + header: rarHeaderV5_0, + }, + { + name: "rar_v1.5", + header: rarHeaderV1_5, + }, + { + name: "xz", + header: xzHeader, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + headerLen := len(tt.header) + if headerLen == 0 || tt.header[headerLen-1] != 0 { + t.Errorf("header expected to end with 0: header=%v", tt.header) + return + } + headerTrimmed := tt.header[:headerLen-1] + stream := bytes.NewReader(headerTrimmed) + got, err := Identify("", stream) + if got != nil { + t.Errorf("no Format expected for trimmed know %s header: found Format= %v", tt.name, got.Name()) + return + } + if ErrNoMatch != err { + t.Fatalf("ErrNoMatch expected for for trimmed know %s header: err :=%#v", tt.name, err) + return + } + + }) + } +} + +func TestIdentifyCanAssessSmallOrNoContent(t *testing.T) { + type args struct { + stream io.ReadSeeker + } + tests := []struct { + name string + args args + }{ + { + name: "should return nomatch for an empty stream", + args: args{ + stream: bytes.NewReader([]byte{}), + }, + }, + { + name: "should return nomatch for a stream with content size less than known header", + args: args{ + stream: bytes.NewReader([]byte{'a'}), + }, + }, + { + name: "should return nomatch for a stream with content size greater then known header size and not supported format", + args: args{ + stream: bytes.NewReader([]byte(strings.Repeat("this is a txt content", 2))), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Identify("", tt.args.stream) + if got != nil { + t.Errorf("no Format expected for non archive and not compressed stream: found Format= %v", got.Name()) + return + } + if ErrNoMatch != err { + t.Fatalf("ErrNoMatch expected for non archive and not compressed stream: err :=%#v", err) + return + } + + }) + } +} + +func compress( + t *testing.T, compName string, content []byte, + openwriter func(w io.Writer) (io.WriteCloser, error), +) []byte { + buf := bytes.NewBuffer(make([]byte, 0, 128)) + cwriter, err := openwriter(buf) + if err != nil { + t.Fatalf("fail to open compression writer: compression-name=%s, err=%#v", compName, err) + return nil + } + _, err = cwriter.Write(content) + if err != nil { + cerr := cwriter.Close() + t.Fatalf( + "fail to write using compression writer: compression-name=%s, err=%#v, close-err=%#v", + compName, err, cerr) + return nil + } + err = cwriter.Close() + if err != nil { + t.Fatalf("fail to close compression writer: compression-name=%s, err=%#v", compName, err) + return nil + } + return buf.Bytes() +} + +func archive(t *testing.T, arch Archiver, fname string, fileInfo fs.FileInfo) []byte { + files := []File{ + {FileInfo: fileInfo, NameInArchive: "tmp.txt", + Open: func() (io.ReadCloser, error) { + return os.Open(fname) + }}, + } + buf := bytes.NewBuffer(make([]byte, 0, 128)) + err := arch.Archive(context.TODO(), buf, files) + if err != nil { + t.Fatalf("fail to create archive: err=%#v", err) + return nil + } + return buf.Bytes() + +} + +type writeNopCloser struct{ io.Writer } + +func (wnc writeNopCloser) Close() error { return nil } + +func newWriteNopCloser(w io.Writer) (io.WriteCloser, error) { + return writeNopCloser{w}, nil +} + +func newTmpTextFile(t *testing.T, content string) (string, fs.FileInfo) { + + tmpTxtFile, err := os.CreateTemp("", "TestIdentifyFindFormatByStreamContent-tmp-*.txt") + if err != nil { + t.Fatalf("fail to create tmp test file for archive tests: err=%v", err) + return "", nil + } + fname := tmpTxtFile.Name() + + if _, err = tmpTxtFile.Write([]byte(content)); err != nil { + tmpTxtFile.Close() + os.Remove(fname) + t.Fatalf("fail to write content to tmp-txt-file: err=%#v", err) + return "", nil + } + if err = tmpTxtFile.Close(); err != nil { + os.Remove(fname) + t.Fatalf("fail to close tmp-txt-file: err=%#v", err) + return "", nil + } + fi, err := os.Stat(fname) + if err != nil { + os.Remove(fname) + t.Fatalf("fail to get tmp-txt-file stats: err=%v", err) + return "", nil + } + + return fname, fi +} + +func TestIdentifyFindFormatByStreamContent(t *testing.T) { + tmpTxtFileName, tmpTxtFileInfo := newTmpTextFile(t, "this is text") + t.Cleanup(func() { + os.Remove(tmpTxtFileName) + }) + + tests := []struct { + name string + content []byte + openCompressionWriter func(w io.Writer) (io.WriteCloser, error) + compressorName string + wantFormatName string + }{ + //TODO add test case for brotli when Brotli.Match() by stream content is implemented + { + name: "should recognize bz2", + openCompressionWriter: Bz2{}.OpenWriter, + content: []byte("this is text"), + compressorName: ".bz2", + wantFormatName: ".bz2", + }, + { + name: "should recognize gz", + openCompressionWriter: Gz{}.OpenWriter, + content: []byte("this is text"), + compressorName: ".gz", + wantFormatName: ".gz", + }, + { + name: "should recognize lz4", + openCompressionWriter: Lz4{}.OpenWriter, + content: []byte("this is text"), + compressorName: ".lz4", + wantFormatName: ".lz4", + }, + { + name: "should recognize sz", + openCompressionWriter: Sz{}.OpenWriter, + content: []byte("this is text"), + compressorName: ".sz", + wantFormatName: ".sz", + }, + { + name: "should recognize xz", + openCompressionWriter: Xz{}.OpenWriter, + content: []byte("this is text"), + compressorName: ".xz", + wantFormatName: ".xz", + }, + { + name: "should recognize zst", + openCompressionWriter: Zstd{}.OpenWriter, + content: []byte("this is text"), + compressorName: ".zst", + wantFormatName: ".zst", + }, + { + name: "should recognize tar", + openCompressionWriter: newWriteNopCloser, + content: archive(t, Tar{}, tmpTxtFileName, tmpTxtFileInfo), + compressorName: "", + wantFormatName: ".tar", + }, + { + name: "should recognize tar.gz", + openCompressionWriter: Gz{}.OpenWriter, + content: archive(t, Tar{}, tmpTxtFileName, tmpTxtFileInfo), + compressorName: ".gz", + wantFormatName: ".tar.gz", + }, + { + name: "should recognize zip", + openCompressionWriter: newWriteNopCloser, + content: archive(t, Zip{}, tmpTxtFileName, tmpTxtFileInfo), + compressorName: "", + wantFormatName: ".zip", + }, + { + name: "should recognize rar by v5.0 header", + openCompressionWriter: newWriteNopCloser, + content: rarHeaderV5_0[:], + compressorName: "", + wantFormatName: ".rar", + }, + { + name: "should recognize rar by v1.5 header", + openCompressionWriter: newWriteNopCloser, + content: rarHeaderV1_5[:], + compressorName: "", + wantFormatName: ".rar", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stream := bytes.NewReader(compress(t, tt.compressorName, tt.content, tt.openCompressionWriter)) + got, err := Identify("", stream) + if err != nil { + t.Fatalf("should have found a corresponding Format: err :=%+v", err) + return + } + if tt.wantFormatName != got.Name() { + t.Errorf("unexpected format found: expected=%s actual:%s", tt.wantFormatName, got.Name()) + return + } + + }) + } +} diff --git a/gz.go b/gz.go index ae6b4ef5..e747d030 100644 --- a/gz.go +++ b/gz.go @@ -36,8 +36,8 @@ func (gz Gz) Match(filename string, stream io.Reader) (MatchResult, error) { } // match file header - buf := make([]byte, len(gzHeader)) - if _, err := io.ReadFull(stream, buf); err != nil { + buf, err := readAtMost(stream, len(gzHeader)) + if err != nil { return mr, err } mr.ByStream = bytes.Equal(buf, gzHeader) diff --git a/lz4.go b/lz4.go index 659c3975..aaa22a54 100644 --- a/lz4.go +++ b/lz4.go @@ -28,8 +28,8 @@ func (lz Lz4) Match(filename string, stream io.Reader) (MatchResult, error) { } // match file header - buf := make([]byte, len(lz4Header)) - if _, err := io.ReadFull(stream, buf); err != nil { + buf, err := readAtMost(stream, len(lz4Header)) + if err != nil { return mr, err } mr.ByStream = bytes.Equal(buf, lz4Header) diff --git a/rar.go b/rar.go index e41a192e..61e55e57 100644 --- a/rar.go +++ b/rar.go @@ -41,11 +41,17 @@ func (r Rar) Match(filename string, stream io.Reader) (MatchResult, error) { } // match file header (there are two versions; allocate buffer for larger one) - buf := make([]byte, len(rarHeaderV5_0)) - if _, err := io.ReadFull(stream, buf); err != nil { + buf, err := readAtMost(stream, len(rarHeaderV5_0)) + if err != nil { return mr, err } - mr.ByStream = bytes.Equal(buf[:len(rarHeaderV1_5)], rarHeaderV1_5) || bytes.Equal(buf, rarHeaderV5_0) + + matchedV1_5 := len(buf) >= len(rarHeaderV1_5) && + bytes.Equal(rarHeaderV1_5, buf[:len(rarHeaderV1_5)]) + matchedV5_0 := len(buf) >= len(rarHeaderV5_0) && + bytes.Equal(rarHeaderV5_0, buf[:len(rarHeaderV5_0)]) + + mr.ByStream = matchedV1_5 || matchedV5_0 return mr, nil } diff --git a/sz.go b/sz.go index 577e331a..9d10604a 100644 --- a/sz.go +++ b/sz.go @@ -26,8 +26,8 @@ func (sz Sz) Match(filename string, stream io.Reader) (MatchResult, error) { } // match file header - buf := make([]byte, len(snappyHeader)) - if _, err := io.ReadFull(stream, buf); err != nil { + buf, err := readAtMost(stream, len(snappyHeader)) + if err != nil { return mr, err } mr.ByStream = bytes.Equal(buf, snappyHeader) diff --git a/xz.go b/xz.go index eaf43471..4e1b6b41 100644 --- a/xz.go +++ b/xz.go @@ -27,8 +27,8 @@ func (x Xz) Match(filename string, stream io.Reader) (MatchResult, error) { } // match file header - buf := make([]byte, len(xzHeader)) - if _, err := io.ReadFull(stream, buf); err != nil { + buf, err := readAtMost(stream, len(xzHeader)) + if err != nil { return mr, err } mr.ByStream = bytes.Equal(buf, xzHeader) diff --git a/zip.go b/zip.go index 20fa4250..d334538a 100644 --- a/zip.go +++ b/zip.go @@ -92,8 +92,8 @@ func (z Zip) Match(filename string, stream io.Reader) (MatchResult, error) { } // match file header - buf := make([]byte, len(zipHeader)) - if _, err := io.ReadFull(stream, buf); err != nil { + buf, err := readAtMost(stream, len(zipHeader)) + if err != nil { return mr, err } mr.ByStream = bytes.Equal(buf, zipHeader) diff --git a/zstd.go b/zstd.go index cd310a31..fe07b76f 100644 --- a/zstd.go +++ b/zstd.go @@ -29,8 +29,8 @@ func (zs Zstd) Match(filename string, stream io.Reader) (MatchResult, error) { } // match file header - buf := make([]byte, len(zstdHeader)) - if _, err := io.ReadFull(stream, buf); err != nil { + buf, err := readAtMost(stream, len(zstdHeader)) + if err != nil { return mr, err } mr.ByStream = bytes.Equal(buf, zstdHeader)