From 32630dfa2d0edc7875788e7203cd7362bb48e54f Mon Sep 17 00:00:00 2001
From: Antony Dovgal
Date: Mon, 22 Mar 2021 22:55:57 +0300
Subject: [PATCH] POC: optimize small-value encoding

... using virtual table 'shards' and avoiding a costly copy()
---
 zstd/dict_test.go | 43 +++++++++++++++-------
 zstd/enc_dfast.go | 90 +++++++++++++++++++++++++++++++++++++++++------
 zstd/enc_fast.go  | 58 ++++++++++++++++++++++++++----
 3 files changed, 161 insertions(+), 30 deletions(-)

diff --git a/zstd/dict_test.go b/zstd/dict_test.go
index 433349c5d4..1d63c93886 100644
--- a/zstd/dict_test.go
+++ b/zstd/dict_test.go
@@ -218,7 +218,7 @@ func TestEncoder_SmallDict(t *testing.T) {
 	}
 }
 
-func BenchmarkEncodeAllDict(b *testing.B) {
+func benchmarkEncodeAllLimitedBySize(b *testing.B, lowerLimit int, upperLimit int) {
 	fn := "testdata/dict-tests-small.zip"
 	data, err := ioutil.ReadFile(fn)
 	t := testing.TB(b)
@@ -232,7 +232,6 @@ func BenchmarkEncodeAllDict(b *testing.B) {
 	}
 	var dicts [][]byte
 	var encs []*Encoder
-	var noDictEncs []*Encoder
 	var encNames []string
 
 	for _, tt := range zr.File {
@@ -257,12 +256,6 @@ func BenchmarkEncodeAllDict(b *testing.B) {
 				}
 				encs = append(encs, enc)
 				encNames = append(encNames, fmt.Sprint("level-", level.String(), "-dict-", len(dicts)))
-
-				enc, err = NewWriter(nil, WithEncoderConcurrency(1), WithEncoderLevel(level), WithWindowSize(1<<17))
-				if err != nil {
-					t.Fatal(err)
-				}
-				noDictEncs = append(noDictEncs, enc)
 			}
 		}()
 	}
@@ -273,10 +266,7 @@ func BenchmarkEncodeAllDict(b *testing.B) {
 	}
 	defer dec.Close()
 
-	for i, tt := range zr.File {
-		if i == 5 {
-			break
-		}
+	for _, tt := range zr.File {
 		if !strings.HasSuffix(tt.Name, ".zst") {
 			continue
 		}
@@ -293,6 +283,15 @@ func BenchmarkEncodeAllDict(b *testing.B) {
 		if err != nil {
 			t.Fatal(err)
 		}
+
+		if len(decoded) < lowerLimit {
+			continue
+		}
+
+		if upperLimit > 0 && len(decoded) > upperLimit {
+			continue
+		}
+
 		for i := range encs {
 			// Only do 1 dict (3 encoders) for now.
 			if i == 3 {
@@ -313,6 +312,26 @@ func BenchmarkEncodeAllDict(b *testing.B) {
 	}
 }
 
+func BenchmarkEncodeAllDict0_1024(b *testing.B) {
+	benchmarkEncodeAllLimitedBySize(b, 0, 1024)
+}
+
+func BenchmarkEncodeAllDict1024_8192(b *testing.B) {
+	benchmarkEncodeAllLimitedBySize(b, 1024, 8192)
+}
+
+func BenchmarkEncodeAllDict8192_16384(b *testing.B) {
+	benchmarkEncodeAllLimitedBySize(b, 8192, 16384)
+}
+
+func BenchmarkEncodeAllDict16384_65536(b *testing.B) {
+	benchmarkEncodeAllLimitedBySize(b, 16384, 65536)
+}
+
+func BenchmarkEncodeAllDict65536_0(b *testing.B) {
+	benchmarkEncodeAllLimitedBySize(b, 65536, 0)
+}
+
 func TestDecoder_MoreDicts(t *testing.T) {
 	// All files have CRC
 	// https://files.klauspost.com/compress/zstd-dict-tests.zip
diff --git a/zstd/enc_dfast.go b/zstd/enc_dfast.go
index 19eebf66e5..2e7d6a1238 100644
--- a/zstd/enc_dfast.go
+++ b/zstd/enc_dfast.go
@@ -11,6 +11,9 @@ const (
 	dFastLongTableSize = 1 << dFastLongTableBits // Size of the table
 	dFastLongTableMask = dFastLongTableSize - 1  // Mask for table indices. Redundant, but can eliminate bounds checks.
 
+	dLongTableShardCnt  = 1 << 9                                  // Number of shards in the table
+	dLongTableShardSize = dFastLongTableSize / dLongTableShardCnt // Size of an individual shard
+
 	dFastShortTableBits = tableBits                // Bits used in the short match table
 	dFastShortTableSize = 1 << dFastShortTableBits // Size of the table
 	dFastShortTableMask = dFastShortTableSize - 1  // Mask for table indices. Redundant, but can eliminate bounds checks.
@@ -18,8 +21,9 @@ const (
 
 type doubleFastEncoder struct {
 	fastEncoder
-	longTable     [dFastLongTableSize]tableEntry
-	dictLongTable []tableEntry
+	longTable           [dFastLongTableSize]tableEntry
+	dictLongTable       []tableEntry
+	longTableShardClean [dLongTableShardCnt]bool
 }
 
 // Encode mimmics functionality in zstd_dfast.c
@@ -40,6 +44,7 @@ func (e *doubleFastEncoder) Encode(blk *blockEnc, src []byte) {
 		for i := range e.longTable[:] {
 			e.longTable[i] = tableEntry{}
 		}
+		e.markAllLongShardsDirty()
 		e.cur = e.maxMatchOff
 		break
 	}
@@ -63,6 +68,7 @@ func (e *doubleFastEncoder) Encode(blk *blockEnc, src []byte) {
 			}
 			e.longTable[i].offset = v
 		}
+		e.markAllLongShardsDirty()
 		e.cur = e.maxMatchOff
 		break
 	}
@@ -124,7 +130,9 @@ encodeLoop:
 		repIndex := s - offset1 + repOff
 		entry := tableEntry{offset: s + e.cur, val: uint32(cv)}
 		e.longTable[nextHashL] = entry
+		e.markLongShardDirty(int(nextHashL))
 		e.table[nextHashS] = entry
+		e.markShardDirty(int(nextHashS))
 
 		if canRepeat {
 			if repIndex >= 0 && load3232(src, repIndex) == uint32(cv>>(repOff*8)) {
@@ -205,6 +213,7 @@ encodeLoop:
 
 			// We can store it, since we have at least a 4 byte match.
 			e.longTable[nextHashL] = tableEntry{offset: s + checkAt + e.cur, val: uint32(cv)}
+			e.markLongShardDirty(int(nextHashL))
 			if coffsetL < e.maxMatchOff && uint32(cv) == candidateL.val {
 				// Found a long match, likely at least 8 bytes.
 				// Reference encoder checks all 8 bytes, we only check 4,
@@ -295,16 +304,24 @@ encodeLoop:
 			cv1 := load6432(src, index1)
 			te0 := tableEntry{offset: index0 + e.cur, val: uint32(cv0)}
 			te1 := tableEntry{offset: index1 + e.cur, val: uint32(cv1)}
-			e.longTable[hash8(cv0, dFastLongTableBits)] = te0
-			e.longTable[hash8(cv1, dFastLongTableBits)] = te1
+			longHash1 := hash8(cv0, dFastLongTableBits)
+			longHash2 := hash8(cv1, dFastLongTableBits)
+			e.longTable[longHash1] = te0
+			e.longTable[longHash2] = te1
+			e.markLongShardDirty(int(longHash1))
+			e.markLongShardDirty(int(longHash2))
 			cv0 >>= 8
 			cv1 >>= 8
 			te0.offset++
 			te1.offset++
 			te0.val = uint32(cv0)
 			te1.val = uint32(cv1)
-			e.table[hash5(cv0, dFastShortTableBits)] = te0
-			e.table[hash5(cv1, dFastShortTableBits)] = te1
+			hashVal1 := hash5(cv0, dFastShortTableBits)
+			hashVal2 := hash5(cv1, dFastShortTableBits)
+			e.table[hashVal1] = te0
+			e.markShardDirty(int(hashVal1))
+			e.table[hashVal2] = te1
+			e.markShardDirty(int(hashVal2))
 
 			cv = load6432(src, s)
 
@@ -330,7 +347,9 @@ encodeLoop:
 
 			entry := tableEntry{offset: s + e.cur, val: uint32(cv)}
 			e.longTable[nextHashL] = entry
+			e.markLongShardDirty(int(nextHashL))
 			e.table[nextHashS] = entry
+			e.markShardDirty(int(nextHashS))
 			seq.matchLen = uint32(l) - zstdMinMatch
 			seq.litLen = 0
 
@@ -383,6 +402,7 @@ func (e *doubleFastEncoder) EncodeNoHist(blk *blockEnc, src []byte) {
 		for i := range e.longTable[:] {
 			e.longTable[i] = tableEntry{}
 		}
+		e.markAllLongShardsDirty()
 		e.cur = e.maxMatchOff
 	}
 
@@ -436,7 +456,9 @@ encodeLoop:
 		repIndex := s - offset1 + repOff
 		entry := tableEntry{offset: s + e.cur, val: uint32(cv)}
 		e.longTable[nextHashL] = entry
+		e.markLongShardDirty(int(nextHashL))
 		e.table[nextHashS] = entry
+		e.markShardDirty(int(nextHashS))
 
 		if len(blk.sequences) > 2 {
 			if load3232(src, repIndex) == uint32(cv>>(repOff*8)) {
@@ -518,6 +540,7 @@ encodeLoop:
 
 			// We can store it, since we have at least a 4 byte match.
 			e.longTable[nextHashL] = tableEntry{offset: s + checkAt + e.cur, val: uint32(cv)}
+			e.markLongShardDirty(int(nextHashL))
 			if coffsetL < e.maxMatchOff && uint32(cv) == candidateL.val {
 				// Found a long match, likely at least 8 bytes.
 				// Reference encoder checks all 8 bytes, we only check 4,
@@ -605,16 +628,24 @@ encodeLoop:
 			cv1 := load6432(src, index1)
 			te0 := tableEntry{offset: index0 + e.cur, val: uint32(cv0)}
 			te1 := tableEntry{offset: index1 + e.cur, val: uint32(cv1)}
-			e.longTable[hash8(cv0, dFastLongTableBits)] = te0
-			e.longTable[hash8(cv1, dFastLongTableBits)] = te1
+			longHash1 := hash8(cv0, dFastLongTableBits)
+			longHash2 := hash8(cv1, dFastLongTableBits)
+			e.longTable[longHash1] = te0
+			e.longTable[longHash2] = te1
+			e.markLongShardDirty(int(longHash1))
+			e.markLongShardDirty(int(longHash2))
 			cv0 >>= 8
 			cv1 >>= 8
 			te0.offset++
 			te1.offset++
 			te0.val = uint32(cv0)
 			te1.val = uint32(cv1)
-			e.table[hash5(cv0, dFastShortTableBits)] = te0
-			e.table[hash5(cv1, dFastShortTableBits)] = te1
+			hashVal1 := hash5(cv0, dFastShortTableBits)
+			hashVal2 := hash5(cv1, dFastShortTableBits)
+			e.table[hashVal1] = te0
+			e.markShardDirty(int(hashVal1))
+			e.table[hashVal2] = te1
+			e.markShardDirty(int(hashVal2))
 
 			cv = load6432(src, s)
 
@@ -641,7 +672,9 @@ encodeLoop:
 
 			entry := tableEntry{offset: s + e.cur, val: uint32(cv)}
 			e.longTable[nextHashL] = entry
+			e.markLongShardDirty(int(nextHashL))
 			e.table[nextHashS] = entry
+			e.markShardDirty(int(nextHashS))
 			seq.matchLen = uint32(l) - zstdMinMatch
 			seq.litLen = 0
 
@@ -709,5 +742,40 @@ func (e *doubleFastEncoder) Reset(d *dict, singleBlock bool) {
 	}
 	// Reset table to initial state
 	e.cur = e.maxMatchOff
-	copy(e.longTable[:], e.dictLongTable)
+
+	dirtyShardCnt := 0
+
+	for i := range e.longTableShardClean {
+		if !e.longTableShardClean[i] {
+			dirtyShardCnt++
+		}
+	}
+
+	minDirty := dLongTableShardCnt / 2
+	if dirtyShardCnt > minDirty {
+		copy(e.longTable[:], e.dictLongTable)
+		for i := range e.longTableShardClean {
+			e.longTableShardClean[i] = true
+		}
+	} else {
+		for i := range e.longTableShardClean {
+			if e.longTableShardClean[i] {
+				continue
+			}
+
+			copy(e.longTable[i*dLongTableShardSize:(i+1)*dLongTableShardSize], e.dictLongTable[i*dLongTableShardSize:(i+1)*dLongTableShardSize])
+			e.longTableShardClean[i] = true
+		}
+	}
+}
+
+func (e *doubleFastEncoder) markAllLongShardsDirty() {
+	for i := range e.longTableShardClean {
+		e.longTableShardClean[i] = false
+	}
+	e.markAllShardsDirty()
+}
+
+func (e *doubleFastEncoder) markLongShardDirty(entryNum int) {
+	e.longTableShardClean[entryNum/dLongTableShardSize] = false
 }
diff --git a/zstd/enc_fast.go b/zstd/enc_fast.go
index 0045016d94..b1dd0638f8 100644
--- a/zstd/enc_fast.go
+++ b/zstd/enc_fast.go
@@ -11,9 +11,11 @@ import (
 )
 
 const (
-	tableBits = 15                 // Bits used in the table
-	tableSize = 1 << tableBits     // Size of the table
-	tableMask = tableSize - 1      // Mask for table indices. Redundant, but can eliminate bounds checks.
+	tableBits      = 15                        // Bits used in the table
+	tableSize      = 1 << tableBits            // Size of the table
+	tableShardCnt  = 1 << 9                    // Number of shards in the table
+	tableShardSize = tableSize / tableShardCnt // Size of an individual shard
+	tableMask      = tableSize - 1             // Mask for table indices. Redundant, but can eliminate bounds checks.
 
 	maxMatchLength = 131074
 )
@@ -24,8 +26,9 @@ type tableEntry struct {
 
 type fastEncoder struct {
 	fastBase
-	table     [tableSize]tableEntry
-	dictTable []tableEntry
+	table           [tableSize]tableEntry
+	dictTable       []tableEntry
+	tableShardClean [tableShardCnt]bool
 }
 
 // Encode mimmics functionality in zstd_fast.c
@@ -41,6 +44,7 @@ func (e *fastEncoder) Encode(blk *blockEnc, src []byte) {
 		for i := range e.table[:] {
 			e.table[i] = tableEntry{}
 		}
+		e.markAllShardsDirty()
 		e.cur = e.maxMatchOff
 		break
 	}
@@ -55,6 +59,7 @@ func (e *fastEncoder) Encode(blk *blockEnc, src []byte) {
 			}
 			e.table[i].offset = v
 		}
+		e.markAllShardsDirty()
 		e.cur = e.maxMatchOff
 		break
 	}
@@ -121,7 +126,9 @@ encodeLoop:
 		repIndex := s - offset1 + 2
 
 		e.table[nextHash] = tableEntry{offset: s + e.cur, val: uint32(cv)}
+		e.markShardDirty(int(nextHash))
 		e.table[nextHash2] = tableEntry{offset: s + e.cur + 1, val: uint32(cv >> 8)}
+		e.markShardDirty(int(nextHash2))
 
 		if canRepeat && repIndex >= 0 && load3232(src, repIndex) == uint32(cv>>16) {
 			// Consider history as well.
@@ -295,6 +302,7 @@ encodeLoop:
 			// Store this, since we have it.
 			nextHash := hash6(cv, hashLog)
 			e.table[nextHash] = tableEntry{offset: s + e.cur, val: uint32(cv)}
+			e.markShardDirty(int(nextHash))
 			seq.matchLen = uint32(l) - zstdMinMatch
 			seq.litLen = 0
 			// Since litlen is always 0, this is offset 1.
@@ -346,6 +354,7 @@ func (e *fastEncoder) EncodeNoHist(blk *blockEnc, src []byte) {
 		for i := range e.table[:] {
 			e.table[i] = tableEntry{}
 		}
+		e.markAllShardsDirty()
 		e.cur = e.maxMatchOff
 	}
 
@@ -404,7 +413,9 @@ encodeLoop:
 		repIndex := s - offset1 + 2
 
 		e.table[nextHash] = tableEntry{offset: s + e.cur, val: uint32(cv)}
+		e.markShardDirty(int(nextHash))
 		e.table[nextHash2] = tableEntry{offset: s + e.cur + 1, val: uint32(cv >> 8)}
+		e.markShardDirty(int(nextHash2))
 
 		if len(blk.sequences) > 2 && load3232(src, repIndex) == uint32(cv>>16) {
 			// Consider history as well.
@@ -583,6 +594,7 @@ encodeLoop:
 			// Store this, since we have it.
 			nextHash := hash6(cv, hashLog)
 			e.table[nextHash] = tableEntry{offset: s + e.cur, val: uint32(cv)}
+			e.markShardDirty(int(nextHash))
 			seq.matchLen = uint32(l) - zstdMinMatch
 			seq.litLen = 0
 			// Since litlen is always 0, this is offset 1.
@@ -656,6 +668,38 @@ func (e *fastEncoder) Reset(d *dict, singleBlock bool) {
 	}
 
 	e.cur = e.maxMatchOff
-	// Reset table to initial state
-	copy(e.table[:], e.dictTable)
+
+	dirtyShardCnt := 0
+	for i := range e.tableShardClean {
+		if !e.tableShardClean[i] {
+			dirtyShardCnt++
+		}
+	}
+
+	minDirty := tableShardCnt * 4 / 5
+	if dirtyShardCnt > minDirty {
+		copy(e.table[:], e.dictTable)
+		for i := range e.tableShardClean {
+			e.tableShardClean[i] = true
+		}
+	} else {
+		for i := range e.tableShardClean {
+			if e.tableShardClean[i] {
+				continue
+			}
+
+			copy(e.table[i*tableShardSize:(i+1)*tableShardSize], e.dictTable[i*tableShardSize:(i+1)*tableShardSize])
+			e.tableShardClean[i] = true
+		}
+	}
+}
+
+func (e *fastEncoder) markAllShardsDirty() {
+	for i := range e.tableShardClean {
+		e.tableShardClean[i] = false
+	}
+}
+
+func (e *fastEncoder) markShardDirty(entryNum int) {
+	e.tableShardClean[entryNum/tableShardSize] = false
 }
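
Note (editor's sketch, not part of the patch): the idea above is to stop paying for a full-table copy on every dictionary Reset. The hash table is viewed as 512 virtual shards; every table write marks its shard dirty, and Reset copies back from the pristine dictionary table only the dirty shards, falling back to one big copy when most shards are dirty anyway. The standalone program below reproduces just that bookkeeping; shardedTable, markDirty, and reset are illustrative names invented for this note, not identifiers from the patch.

package main

import "fmt"

const (
	tableSize = 1 << 15              // entries in the hash table
	shardCnt  = 1 << 9               // number of virtual shards
	shardSize = tableSize / shardCnt // entries per shard (64)
)

type shardedTable struct {
	table      [tableSize]uint32
	dictTable  []uint32       // pristine dictionary state to reset from
	shardClean [shardCnt]bool // shardClean[i]: shard i still matches dictTable
}

// markDirty flags the shard containing table index i as out of sync.
func (t *shardedTable) markDirty(i int) {
	t.shardClean[i/shardSize] = false
}

// reset restores the table from dictTable, copying only dirty shards,
// or everything at once when most shards are dirty anyway.
func (t *shardedTable) reset() {
	dirty := 0
	for _, clean := range t.shardClean {
		if !clean {
			dirty++
		}
	}
	if dirty > shardCnt*4/5 {
		copy(t.table[:], t.dictTable) // one big copy is cheaper here
		for i := range t.shardClean {
			t.shardClean[i] = true
		}
		return
	}
	for i := range t.shardClean {
		if t.shardClean[i] {
			continue
		}
		copy(t.table[i*shardSize:(i+1)*shardSize], t.dictTable[i*shardSize:(i+1)*shardSize])
		t.shardClean[i] = true
	}
}

func main() {
	t := &shardedTable{dictTable: make([]uint32, tableSize)}
	for i := range t.shardClean {
		t.shardClean[i] = true // start in sync with dictTable
	}
	t.table[100] = 42 // index 100 lives in shard 100/shardSize == 1
	t.markDirty(100)
	t.reset()                 // copies back only shard 1
	fmt.Println(t.table[100]) // 0 again
}

For small inputs the encoder touches only a handful of shards between resets, so Reset moves a few hundred bytes instead of the whole table; that is the win the benchmarks added in dict_test.go are meant to measure, bucketed by decoded input size.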