diff --git a/go.mod b/go.mod index b6d447e..5895b4a 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/bodgit/plumbing v1.3.0 github.com/bodgit/windows v1.0.1 github.com/hashicorp/go-multierror v1.1.1 + github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/klauspost/compress v1.17.7 github.com/pierrec/lz4/v4 v4.1.21 github.com/stretchr/testify v1.9.0 diff --git a/go.sum b/go.sum index 07f3f9d..13955a0 100644 --- a/go.sum +++ b/go.sum @@ -66,6 +66,8 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= diff --git a/internal/aes7z/key.go b/internal/aes7z/key.go index 2ff37e7..cdcf151 100644 --- a/internal/aes7z/key.go +++ b/internal/aes7z/key.go @@ -4,12 +4,47 @@ import ( "bytes" "crypto/sha256" "encoding/binary" + "encoding/hex" + lru "github.com/hashicorp/golang-lru/v2" + "go4.org/syncutil" "golang.org/x/text/encoding/unicode" "golang.org/x/text/transform" ) -func calculateKey(password string, cycles int, salt []byte) []byte { +type cacheKey struct { + password string + cycles int + salt string // []byte isn't comparable +} + +const cacheSize = 10 + +//nolint:gochecknoglobals +var ( + once syncutil.Once + cache *lru.Cache[cacheKey, []byte] +) + +func calculateKey(password string, cycles int, salt []byte) ([]byte, error) { + if err := once.Do(func() (err error) { + cache, err = lru.New[cacheKey, []byte](cacheSize) + + return + }); err != nil { + return nil, err + } + + ck := cacheKey{ + password: password, + cycles: cycles, + salt: hex.EncodeToString(salt), + } + + if key, ok := cache.Get(ck); ok { + return key, nil + } + b := bytes.NewBuffer(salt) // Convert password to UTF-16LE @@ -30,5 +65,7 @@ func calculateKey(password string, cycles int, salt []byte) []byte { copy(key, h.Sum(nil)) } - return key + _ = cache.Add(ck, key) + + return key, nil } diff --git a/internal/aes7z/reader.go b/internal/aes7z/reader.go index 4d0d30c..760db29 100644 --- a/internal/aes7z/reader.go +++ b/internal/aes7z/reader.go @@ -29,7 +29,12 @@ func (rc *readCloser) Close() error { } func (rc *readCloser) Password(p string) error { - block, err := aes.NewCipher(calculateKey(p, rc.cycles, rc.salt)) + key, err := calculateKey(p, rc.cycles, rc.salt) + if err != nil { + return err + } + + block, err := aes.NewCipher(key) if err != nil { return err } diff --git a/internal/zstd/reader.go b/internal/zstd/reader.go index 89f3993..0d68a3c 100644 --- a/internal/zstd/reader.go +++ b/internal/zstd/reader.go @@ -52,6 +52,7 @@ func NewReader(_ []byte, _ uint64, readers []io.ReadCloser) (io.ReadCloser, erro if r, err = zstd.NewReader(readers[0]); err != nil { return nil, err } + runtime.SetFinalizer(r, (*zstd.Decoder).Close) } diff --git a/reader.go b/reader.go index 8f586dd..6a2fdf8 100644 --- a/reader.go +++ b/reader.go @@ -544,7 +544,7 @@ func toValidName(name string) string { return p } -//nolint:cyclop +//nolint:cyclop,funlen func (z *Reader) initFileList() { z.fileListOnce.Do(func() { files := make(map[string]int) @@ -583,12 +583,14 @@ func (z *Reader) initFileList() { isDir: isDir, } z.fileList = append(z.fileList, entry) + if isDir { knownDirs[name] = idx } else { files[name] = idx } } + for dir := range dirs { if _, ok := knownDirs[dir]; !ok { if idx, ok := files[dir]; ok { diff --git a/reader_test.go b/reader_test.go index 8ba9436..5b259d4 100644 --- a/reader_test.go +++ b/reader_test.go @@ -188,6 +188,7 @@ func TestOpenReader(t *testing.T) { t.Run(table.name, func(t *testing.T) { t.Parallel() + r, err := sevenzip.OpenReader(filepath.Join("testdata", table.file)) if err != nil { assert.ErrorIs(t, err, table.err) @@ -243,6 +244,7 @@ func TestOpenReaderWithPassword(t *testing.T) { t.Run(table.name, func(t *testing.T) { t.Parallel() + r, err := sevenzip.OpenReaderWithPassword(filepath.Join("testdata", table.file), table.password) if err != nil { t.Fatal(err) @@ -362,13 +364,13 @@ func benchmarkArchiveNaiveParallel(b *testing.B, file string, workers int) { } } -func benchmarkArchive(b *testing.B, file string, optimised bool) { +func benchmarkArchive(b *testing.B, file, password string, optimised bool) { b.Helper() h := crc32.NewIEEE() for n := 0; n < b.N; n++ { - r, err := sevenzip.OpenReader(filepath.Join("testdata", file)) + r, err := sevenzip.OpenReaderWithPassword(filepath.Join("testdata", file), password) if err != nil { b.Fatal(err) } @@ -382,56 +384,60 @@ func benchmarkArchive(b *testing.B, file string, optimised bool) { } } +func BenchmarkAES7z(b *testing.B) { + benchmarkArchive(b, "aes7z.7z", "password", true) +} + func BenchmarkBzip2(b *testing.B) { - benchmarkArchive(b, "bzip2.7z", true) + benchmarkArchive(b, "bzip2.7z", "", true) } func BenchmarkCopy(b *testing.B) { - benchmarkArchive(b, "copy.7z", true) + benchmarkArchive(b, "copy.7z", "", true) } func BenchmarkDeflate(b *testing.B) { - benchmarkArchive(b, "deflate.7z", true) + benchmarkArchive(b, "deflate.7z", "", true) } func BenchmarkDelta(b *testing.B) { - benchmarkArchive(b, "delta.7z", true) + benchmarkArchive(b, "delta.7z", "", true) } func BenchmarkLZMA(b *testing.B) { - benchmarkArchive(b, "lzma.7z", true) + benchmarkArchive(b, "lzma.7z", "", true) } func BenchmarkLZMA2(b *testing.B) { - benchmarkArchive(b, "lzma2.7z", true) + benchmarkArchive(b, "lzma2.7z", "", true) } func BenchmarkBCJ2(b *testing.B) { - benchmarkArchive(b, "bcj2.7z", true) + benchmarkArchive(b, "bcj2.7z", "", true) } func BenchmarkComplex(b *testing.B) { - benchmarkArchive(b, "lzma1900.7z", true) + benchmarkArchive(b, "lzma1900.7z", "", true) } func BenchmarkLZ4(b *testing.B) { - benchmarkArchive(b, "lz4.7z", true) + benchmarkArchive(b, "lz4.7z", "", true) } func BenchmarkBrotli(b *testing.B) { - benchmarkArchive(b, "brotli.7z", true) + benchmarkArchive(b, "brotli.7z", "", true) } func BenchmarkZstandard(b *testing.B) { - benchmarkArchive(b, "zstd.7z", true) + benchmarkArchive(b, "zstd.7z", "", true) } func BenchmarkNaiveReader(b *testing.B) { - benchmarkArchive(b, "lzma1900.7z", false) + benchmarkArchive(b, "lzma1900.7z", "", false) } func BenchmarkOptimisedReader(b *testing.B) { - benchmarkArchive(b, "lzma1900.7z", true) + benchmarkArchive(b, "lzma1900.7z", "", true) } func BenchmarkNaiveParallelReader(b *testing.B) { @@ -447,17 +453,17 @@ func BenchmarkParallelReader(b *testing.B) { } func BenchmarkBCJ(b *testing.B) { - benchmarkArchive(b, "bcj.7z", true) + benchmarkArchive(b, "bcj.7z", "", true) } func BenchmarkPPC(b *testing.B) { - benchmarkArchive(b, "ppc.7z", true) + benchmarkArchive(b, "ppc.7z", "", true) } func BenchmarkARM(b *testing.B) { - benchmarkArchive(b, "arm.7z", true) + benchmarkArchive(b, "arm.7z", "", true) } func BenchmarkSPARC(b *testing.B) { - benchmarkArchive(b, "sparc.7z", true) + benchmarkArchive(b, "sparc.7z", "", true) } diff --git a/testdata/aes7z.7z b/testdata/aes7z.7z new file mode 100644 index 0000000..3f4b3e9 Binary files /dev/null and b/testdata/aes7z.7z differ