From 16e5ba39c3b0f7e299a4f0a13c693cfa40904c60 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 12 Aug 2020 18:44:38 +0200 Subject: [PATCH] zstd: Fix ReadFrom with small blocks Two 'last' blocks was added on small payloads when using ReadFrom. Fixes #277 --- zstd/blockenc.go | 8 ++++---- zstd/encoder.go | 1 + zstd/encoder_test.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/zstd/blockenc.go b/zstd/blockenc.go index c584f6aabc..be718afd43 100644 --- a/zstd/blockenc.go +++ b/zstd/blockenc.go @@ -295,7 +295,7 @@ func (b *blockEnc) encodeRaw(a []byte) { b.output = bh.appendTo(b.output[:0]) b.output = append(b.output, a...) if debug { - println("Adding RAW block, length", len(a)) + println("Adding RAW block, length", len(a), "last:", b.last) } } @@ -308,7 +308,7 @@ func (b *blockEnc) encodeRawTo(dst, src []byte) []byte { dst = bh.appendTo(dst) dst = append(dst, src...) if debug { - println("Adding RAW block, length", len(src)) + println("Adding RAW block, length", len(src), "last:", b.last) } return dst } @@ -322,7 +322,7 @@ func (b *blockEnc) encodeLits(raw bool) error { // Don't compress extremely small blocks if len(b.literals) < 32 || raw { if debug { - println("Adding RAW block, length", len(b.literals)) + println("Adding RAW block, length", len(b.literals), "last:", b.last) } bh.setType(blockTypeRaw) b.output = bh.appendTo(b.output) @@ -349,7 +349,7 @@ func (b *blockEnc) encodeLits(raw bool) error { switch err { case huff0.ErrIncompressible: if debug { - println("Adding RAW block, length", len(b.literals)) + println("Adding RAW block, length", len(b.literals), "last:", b.last) } bh.setType(blockTypeRaw) b.output = bh.appendTo(b.output) diff --git a/zstd/encoder.go b/zstd/encoder.go index c56d2241f7..95ebc3d84e 100644 --- a/zstd/encoder.go +++ b/zstd/encoder.go @@ -190,6 +190,7 @@ func (e *Encoder) nextBlock(final bool) error { s.filling = s.filling[:0] s.headerWritten = true s.fullFrameWritten = true + s.eofWritten = true return nil } diff --git a/zstd/encoder_test.go b/zstd/encoder_test.go index 7240608eed..7cbaca3ec2 100644 --- a/zstd/encoder_test.go +++ b/zstd/encoder_test.go @@ -667,6 +667,34 @@ func TestEncoder_EncodeAllSilesia(t *testing.T) { t.Log("Encoded content matched") } +func TestEncoderReadFrom(t *testing.T) { + buffer := bytes.NewBuffer(nil) + encoder, err := NewWriter(buffer) + if err != nil { + t.Fatal(err) + } + if _, err := encoder.ReadFrom(strings.NewReader("0")); err != nil { + t.Fatal(err) + } + if err := encoder.Close(); err != nil { + t.Fatal(err) + } + + dec, _ := NewReader(nil) + toDec := buffer.Bytes() + toDec = append(toDec, toDec...) + decoded, err := dec.DecodeAll(toDec, nil) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal([]byte("00"), decoded) { + t.Logf("encoded: % x\n", buffer.Bytes()) + t.Fatalf("output mismatch, got %s", string(decoded)) + } + dec.Close() +} + func TestEncoder_EncodeAllEmpty(t *testing.T) { if testing.Short() { t.SkipNow()