From 3bd7653ccd5c50c7916a13d6159801ddae774e86 Mon Sep 17 00:00:00 2001 From: Jeromy Date: Sat, 2 May 2020 15:37:42 -0700 Subject: [PATCH 1/5] attempt to allocate less by using shared buffers --- gen.go | 5 ++- testing/cbor_gen.go | 16 +++++-- utils.go | 100 ++++++++++++++++++++++++++++++++++---------- 3 files changed, 94 insertions(+), 27 deletions(-) diff --git a/gen.go b/gen.go index 1acd51f..ff3b51a 100644 --- a/gen.go +++ b/gen.go @@ -427,12 +427,13 @@ func emitCborMarshalSliceField(w io.Writer, f Field) error { } func emitCborMarshalStructTuple(w io.Writer, gti *GenTypeInfo) error { - err := doTemplate(w, gti, `func (t *{{ .Name }}) MarshalCBOR(w io.Writer) error { + err := doTemplate(w, gti, `var lengthBuf{{ .Name }} = {{ .TupleHeaderAsByteString }} +func (t *{{ .Name }}) MarshalCBOR(w io.Writer) error { if t == nil { _, err := w.Write(cbg.CborNull) return err } - if _, err := w.Write({{ .TupleHeaderAsByteString }}); err != nil { + if _, err := w.Write(lengthBuf{{ .Name }}); err != nil { return err } `) diff --git a/testing/cbor_gen.go b/testing/cbor_gen.go index 19bfb2d..349414f 100644 --- a/testing/cbor_gen.go +++ b/testing/cbor_gen.go @@ -12,12 +12,14 @@ import ( var _ = xerrors.Errorf +var lengthBufSignedArray = []byte{129} + func (t *SignedArray) MarshalCBOR(w io.Writer) error { if t == nil { _, err := w.Write(cbg.CborNull) return err } - if _, err := w.Write([]byte{129}); err != nil { + if _, err := w.Write(lengthBufSignedArray); err != nil { return err } @@ -88,12 +90,14 @@ func (t *SignedArray) UnmarshalCBOR(r io.Reader) error { return nil } +var lengthBufSimpleTypeOne = []byte{132} + func (t *SimpleTypeOne) MarshalCBOR(w io.Writer) error { if t == nil { _, err := w.Write(cbg.CborNull) return err } - if _, err := w.Write([]byte{132}); err != nil { + if _, err := w.Write(lengthBufSimpleTypeOne); err != nil { return err } @@ -225,12 +229,14 @@ func (t *SimpleTypeOne) UnmarshalCBOR(r io.Reader) error { return nil } +var lengthBufSimpleTypeTwo = []byte{137} + func (t *SimpleTypeTwo) MarshalCBOR(w io.Writer) error { if t == nil { _, err := w.Write(cbg.CborNull) return err } - if _, err := w.Write([]byte{137}); err != nil { + if _, err := w.Write(lengthBufSimpleTypeTwo); err != nil { return err } @@ -648,12 +654,14 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { return nil } +var lengthBufDeferredContainer = []byte{131} + func (t *DeferredContainer) MarshalCBOR(w io.Writer) error { if t == nil { _, err := w.Write(cbg.CborNull) return err } - if _, err := w.Write([]byte{131}); err != nil { + if _, err := w.Write(lengthBufDeferredContainer); err != nil { return err } diff --git a/utils.go b/utils.go index 0932b04..a6ff081 100644 --- a/utils.go +++ b/utils.go @@ -219,9 +219,9 @@ func readByte(r io.Reader) (byte, error) { if br, ok := r.(io.ByteReader); ok { return br.ReadByte() } - var b [1]byte - _, err := io.ReadFull(r, b[:]) - return b[0], err + var buf [1]byte + _, err := io.ReadFull(r, buf[:1]) + return buf[0], err } func CborReadHeader(br io.Reader) (byte, uint64, error) { @@ -246,31 +246,31 @@ func CborReadHeader(br io.Reader) (byte, uint64, error) { } return maj, uint64(next), nil case low == 25: - buf := make([]byte, 2) - if _, err := io.ReadFull(br, buf); err != nil { + scratch := make([]byte, 2) + if _, err := io.ReadFull(br, scratch[:2]); err != nil { return 0, 0, err } - val := uint64(binary.BigEndian.Uint16(buf)) + val := uint64(binary.BigEndian.Uint16(scratch[:2])) if val <= math.MaxUint8 { return 0, 0, fmt.Errorf("cbor input was not canonical (lval 25 with value <= MaxUint8)") } return maj, val, nil case low == 26: - buf := make([]byte, 4) - if _, err := io.ReadFull(br, buf); err != nil { + scratch := make([]byte, 4) + if _, err := io.ReadFull(br, scratch[:4]); err != nil { return 0, 0, err } - val := uint64(binary.BigEndian.Uint32(buf)) + val := uint64(binary.BigEndian.Uint32(scratch[:4])) if val <= math.MaxUint16 { return 0, 0, fmt.Errorf("cbor input was not canonical (lval 26 with value <= MaxUint16)") } return maj, val, nil case low == 27: - buf := make([]byte, 8) - if _, err := io.ReadFull(br, buf); err != nil { + scratch := make([]byte, 8) + if _, err := io.ReadFull(br, scratch); err != nil { return 0, 0, err } - val := binary.BigEndian.Uint64(buf) + val := binary.BigEndian.Uint64(scratch) if val <= math.MaxUint32 { return 0, 0, fmt.Errorf("cbor input was not canonical (lval 27 with value <= MaxUint32)") } @@ -280,11 +280,12 @@ func CborReadHeader(br io.Reader) (byte, uint64, error) { } } -func CborWriteHeader(w io.Writer, t byte, val uint64) error { - return WriteMajorTypeHeader(w, t, val) +func CborWriteHeader(w io.Writer, t byte, l uint64) error { + return WriteMajorTypeHeader(w, t, l) } // TODO: No matter what I do, this function *still* allocates. Its super frustrating. +// See issue: https://github.com/golang/go/issues/33160 func WriteMajorTypeHeader(w io.Writer, t byte, l uint64) error { switch { case l < 24: @@ -314,6 +315,36 @@ func WriteMajorTypeHeader(w io.Writer, t byte, l uint64) error { } } +// Same as the above, but uses a passed in buffer to avoid allocations +func WriteMajorTypeHeaderBuf(buf []byte, w io.Writer, t byte, l uint64) error { + switch { + case l < 24: + buf[0] = (t << 5) | byte(l) + _, err := w.Write(buf[:1]) + return err + case l < (1 << 8): + buf[0] = (t << 5) | 24 + buf[1] = byte(l) + _, err := w.Write(buf[:2]) + return err + case l < (1 << 16): + buf[0] = (t << 5) | 25 + binary.BigEndian.PutUint16(buf[1:3], uint16(l)) + _, err := w.Write(buf[:3]) + return err + case l < (1 << 32): + buf[0] = (t << 5) | 26 + binary.BigEndian.PutUint32(buf[1:5], uint32(l)) + _, err := w.Write(buf[:5]) + return err + default: + buf[0] = (t << 5) | 27 + binary.BigEndian.PutUint64(buf[1:9], uint64(l)) + _, err := w.Write(buf[:9]) + return err + } +} + func CborEncodeMajorType(t byte, l uint64) []byte { switch { case l < 24: @@ -390,9 +421,9 @@ var ( func EncodeBool(b bool) []byte { if b { - return []byte{0xf5} + return CborBoolTrue } - return []byte{0xf4} + return CborBoolFalse } func WriteBool(w io.Writer, b bool) error { @@ -449,8 +480,35 @@ func bufToCid(buf []byte) (cid.Cid, error) { return cid.Cast(buf[1:]) } +var byteArrZero = []byte{0} + func WriteCid(w io.Writer, c cid.Cid) error { - if err := CborWriteHeader(w, MajTag, 42); err != nil { + if err := WriteMajorTypeHeader(w, MajTag, 42); err != nil { + return err + } + if c == cid.Undef { + return fmt.Errorf("undefined cid") + //return CborWriteHeader(w, MajByteString, 0) + } + + if err := WriteMajorTypeHeader(w, MajByteString, uint64(c.ByteLen()+1)); err != nil { + return err + } + + // that binary multibase prefix... + if _, err := w.Write(byteArrZero); err != nil { + return err + } + + if _, err := c.WriteBytes(w); err != nil { + return err + } + + return nil +} + +func WriteCidBuf(buf []byte, w io.Writer, c cid.Cid) error { + if err := WriteMajorTypeHeaderBuf(buf, w, MajTag, 42); err != nil { return err } if c == cid.Undef { @@ -458,12 +516,12 @@ func WriteCid(w io.Writer, c cid.Cid) error { //return CborWriteHeader(w, MajByteString, 0) } - if err := CborWriteHeader(w, MajByteString, uint64(c.ByteLen()+1)); err != nil { + if err := WriteMajorTypeHeaderBuf(buf, w, MajByteString, uint64(c.ByteLen()+1)); err != nil { return err } // that binary multibase prefix... - if _, err := w.Write([]byte{0}); err != nil { + if _, err := w.Write(byteArrZero); err != nil { return err } @@ -506,11 +564,11 @@ type CborInt int64 func (ci *CborInt) MarshalCBOR(w io.Writer) error { v := int64(*ci) if v >= 0 { - if _, err := w.Write(CborEncodeMajorType(MajUnsignedInt, uint64(v))); err != nil { + if err := WriteMajorTypeHeader(w, MajUnsignedInt, uint64(v)); err != nil { return err } } else { - if _, err := w.Write(CborEncodeMajorType(MajNegativeInt, uint64(-v)-1)); err != nil { + if err := WriteMajorTypeHeader(w, MajNegativeInt, uint64(-v)-1); err != nil { return err } } From 2d9e79ab4308937a91b9d23182253d0015689f6b Mon Sep 17 00:00:00 2001 From: Jeromy Date: Sun, 3 May 2020 12:15:34 -0700 Subject: [PATCH 2/5] use scratch buffers to avoid even more allocations --- gen.go | 44 ++++++++++++++++++++++++++++++++----- testing/cbor_gen.go | 48 ++++++++++++++++++++++++----------------- testing/cbor_map_gen.go | 46 ++++++++++++++++++++------------------- 3 files changed, 91 insertions(+), 47 deletions(-) diff --git a/gen.go b/gen.go index ff3b51a..90ec89b 100644 --- a/gen.go +++ b/gen.go @@ -16,7 +16,7 @@ func doTemplate(w io.Writer, info interface{}, templ string) error { t := template.Must(template.New(""). Funcs(template.FuncMap{ "MajorType": func(wname string, tname string, val string) string { - return fmt.Sprintf(`if err := cbg.WriteMajorTypeHeader(%s, %s, uint64(%s)); err != nil { + return fmt.Sprintf(`if err := cbg.WriteMajorTypeHeaderBuf(scratch, %s, %s, uint64(%s)); err != nil { return err }`, wname, tname, val) }, @@ -94,6 +94,34 @@ type GenTypeInfo struct { Fields []Field } +func (gti *GenTypeInfo) NeedsScratch() bool { + for _, f := range gti.Fields { + switch f.Type.Kind() { + case reflect.String, + reflect.Uint64, + reflect.Int64, + reflect.Uint8, + reflect.Array, + reflect.Slice, + reflect.Map: + return true + + case reflect.Struct: + fname := f.Type.PkgPath() + "." + f.Type.Name() + switch fname { + case "math/big.Int": + return true + case "github.com/ipfs/go-cid.Cid": + return true + } + // nope + case reflect.Bool: + // nope + } + } + return false +} + func nameIsExported(name string) bool { return strings.ToUpper(name[0:1]) == name[0:1] } @@ -168,7 +196,7 @@ func emitCborMarshalStringField(w io.Writer, f Field) error { } {{ MajorType "w" "cbg.MajTextString" (print "len(" .Name ")") }} - if _, err := w.Write([]byte({{ .Name }})); err != nil { + if _, err := io.WriteString(w, {{ .Name }}); err != nil { return err } `) @@ -204,12 +232,12 @@ func emitCborMarshalStructField(w io.Writer, f Field) error { return err } } else { - if err := cbg.WriteCid(w, *{{ .Name }}); err != nil { + if err := cbg.WriteCidBuf(scratch, w, *{{ .Name }}); err != nil { return xerrors.Errorf("failed to write cid field {{ .Name }}: %w", err) } } {{ else }} - if err := cbg.WriteCid(w, {{ .Name }}); err != nil { + if err := cbg.WriteCidBuf(scratch, w, {{ .Name }}); err != nil { return xerrors.Errorf("failed to write cid field {{ .Name }}: %w", err) } {{ end }} @@ -372,7 +400,7 @@ func emitCborMarshalSliceField(w io.Writer, f Field) error { switch fname { case "github.com/ipfs/go-cid.Cid": err := doTemplate(w, f, ` - if err := cbg.WriteCid(w, v); err != nil { + if err := cbg.WriteCidBuf(scratch, w, v); err != nil { return xerrors.Errorf("failed writing cid field {{ .Name }}: %w", err) } `) @@ -436,6 +464,9 @@ func (t *{{ .Name }}) MarshalCBOR(w io.Writer) error { if _, err := w.Write(lengthBuf{{ .Name }}); err != nil { return err } +{{ if .NeedsScratch }} + scratch := make([]byte, 9) +{{ end }} `) if err != nil { return err @@ -1039,6 +1070,9 @@ func emitCborMarshalStructMap(w io.Writer, gti *GenTypeInfo) error { if _, err := w.Write({{ .MapHeaderAsByteString }}); err != nil { return err } +{{ if .NeedsScratch }} + scratch := make([]byte, 9) +{{ end }} `) if err != nil { return err diff --git a/testing/cbor_gen.go b/testing/cbor_gen.go index 349414f..611405b 100644 --- a/testing/cbor_gen.go +++ b/testing/cbor_gen.go @@ -23,12 +23,14 @@ func (t *SignedArray) MarshalCBOR(w io.Writer) error { return err } + scratch := make([]byte, 9) + // t.Signed ([]uint64) (slice) if len(t.Signed) > cbg.MaxLength { return xerrors.Errorf("Slice value in field t.Signed was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajArray, uint64(len(t.Signed))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajArray, uint64(len(t.Signed))); err != nil { return err } for _, v := range t.Signed { @@ -101,21 +103,23 @@ func (t *SimpleTypeOne) MarshalCBOR(w io.Writer) error { return err } + scratch := make([]byte, 9) + // t.Foo (string) (string) if len(t.Foo) > cbg.MaxLength { return xerrors.Errorf("Value in field t.Foo was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajTextString, uint64(len(t.Foo))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Foo))); err != nil { return err } - if _, err := w.Write([]byte(t.Foo)); err != nil { + if _, err := io.WriteString(w, t.Foo); err != nil { return err } // t.Value (uint64) (uint64) - if err := cbg.WriteMajorTypeHeader(w, cbg.MajUnsignedInt, uint64(t.Value)); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.Value)); err != nil { return err } @@ -124,7 +128,7 @@ func (t *SimpleTypeOne) MarshalCBOR(w io.Writer) error { return xerrors.Errorf("Byte array in field t.Binary was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajByteString, uint64(len(t.Binary))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajByteString, uint64(len(t.Binary))); err != nil { return err } @@ -134,11 +138,11 @@ func (t *SimpleTypeOne) MarshalCBOR(w io.Writer) error { // t.Signed (int64) (int64) if t.Signed >= 0 { - if err := cbg.WriteMajorTypeHeader(w, cbg.MajUnsignedInt, uint64(t.Signed)); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.Signed)); err != nil { return err } } else { - if err := cbg.WriteMajorTypeHeader(w, cbg.MajNegativeInt, uint64(-t.Signed-1)); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajNegativeInt, uint64(-t.Signed-1)); err != nil { return err } } @@ -240,6 +244,8 @@ func (t *SimpleTypeTwo) MarshalCBOR(w io.Writer) error { return err } + scratch := make([]byte, 9) + // t.Stuff (testing.SimpleTypeTwo) (struct) if err := t.Stuff.MarshalCBOR(w); err != nil { return err @@ -250,7 +256,7 @@ func (t *SimpleTypeTwo) MarshalCBOR(w io.Writer) error { return xerrors.Errorf("Slice value in field t.Others was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajArray, uint64(len(t.Others))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajArray, uint64(len(t.Others))); err != nil { return err } for _, v := range t.Others { @@ -264,16 +270,16 @@ func (t *SimpleTypeTwo) MarshalCBOR(w io.Writer) error { return xerrors.Errorf("Slice value in field t.SignedOthers was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajArray, uint64(len(t.SignedOthers))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajArray, uint64(len(t.SignedOthers))); err != nil { return err } for _, v := range t.SignedOthers { if v >= 0 { - if err := cbg.WriteMajorTypeHeader(w, cbg.MajUnsignedInt, uint64(v)); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(v)); err != nil { return err } } else { - if err := cbg.WriteMajorTypeHeader(w, cbg.MajNegativeInt, uint64(-v-1)); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajNegativeInt, uint64(-v-1)); err != nil { return err } } @@ -284,7 +290,7 @@ func (t *SimpleTypeTwo) MarshalCBOR(w io.Writer) error { return xerrors.Errorf("Slice value in field t.Test was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajArray, uint64(len(t.Test))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajArray, uint64(len(t.Test))); err != nil { return err } for _, v := range t.Test { @@ -292,7 +298,7 @@ func (t *SimpleTypeTwo) MarshalCBOR(w io.Writer) error { return xerrors.Errorf("Byte array in field v was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajByteString, uint64(len(v))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajByteString, uint64(len(v))); err != nil { return err } @@ -306,10 +312,10 @@ func (t *SimpleTypeTwo) MarshalCBOR(w io.Writer) error { return xerrors.Errorf("Value in field t.Dog was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajTextString, uint64(len(t.Dog))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Dog))); err != nil { return err } - if _, err := w.Write([]byte(t.Dog)); err != nil { + if _, err := io.WriteString(w, t.Dog); err != nil { return err } @@ -318,7 +324,7 @@ func (t *SimpleTypeTwo) MarshalCBOR(w io.Writer) error { return xerrors.Errorf("Slice value in field t.Numbers was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajArray, uint64(len(t.Numbers))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajArray, uint64(len(t.Numbers))); err != nil { return err } for _, v := range t.Numbers { @@ -334,7 +340,7 @@ func (t *SimpleTypeTwo) MarshalCBOR(w io.Writer) error { return err } } else { - if err := cbg.WriteMajorTypeHeader(w, cbg.MajUnsignedInt, uint64(*t.Pizza)); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(*t.Pizza)); err != nil { return err } } @@ -346,7 +352,7 @@ func (t *SimpleTypeTwo) MarshalCBOR(w io.Writer) error { return err } } else { - if err := cbg.WriteMajorTypeHeader(w, cbg.MajUnsignedInt, uint64(*t.PointyPizza)); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(*t.PointyPizza)); err != nil { return err } } @@ -356,7 +362,7 @@ func (t *SimpleTypeTwo) MarshalCBOR(w io.Writer) error { return xerrors.Errorf("Slice value in field t.Arrrrrghay was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajArray, uint64(len(t.Arrrrrghay))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajArray, uint64(len(t.Arrrrrghay))); err != nil { return err } for _, v := range t.Arrrrrghay { @@ -665,6 +671,8 @@ func (t *DeferredContainer) MarshalCBOR(w io.Writer) error { return err } + scratch := make([]byte, 9) + // t.Stuff (testing.SimpleTypeOne) (struct) if err := t.Stuff.MarshalCBOR(w); err != nil { return err @@ -677,7 +685,7 @@ func (t *DeferredContainer) MarshalCBOR(w io.Writer) error { // t.Value (uint64) (uint64) - if err := cbg.WriteMajorTypeHeader(w, cbg.MajUnsignedInt, uint64(t.Value)); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.Value)); err != nil { return err } diff --git a/testing/cbor_map_gen.go b/testing/cbor_map_gen.go index 36312ae..a3f66e7 100644 --- a/testing/cbor_map_gen.go +++ b/testing/cbor_map_gen.go @@ -21,15 +21,17 @@ func (t *SimpleTypeTree) MarshalCBOR(w io.Writer) error { return err } + scratch := make([]byte, 9) + // t.Stuff (testing.SimpleTypeTree) (struct) if len("Stuff") > cbg.MaxLength { return xerrors.Errorf("Value in field \"Stuff\" was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajTextString, uint64(len("Stuff"))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Stuff"))); err != nil { return err } - if _, err := w.Write([]byte("Stuff")); err != nil { + if _, err := io.WriteString(w, "Stuff"); err != nil { return err } @@ -42,10 +44,10 @@ func (t *SimpleTypeTree) MarshalCBOR(w io.Writer) error { return xerrors.Errorf("Value in field \"Stufff\" was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajTextString, uint64(len("Stufff"))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Stufff"))); err != nil { return err } - if _, err := w.Write([]byte("Stufff")); err != nil { + if _, err := io.WriteString(w, "Stufff"); err != nil { return err } @@ -58,10 +60,10 @@ func (t *SimpleTypeTree) MarshalCBOR(w io.Writer) error { return xerrors.Errorf("Value in field \"Others\" was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajTextString, uint64(len("Others"))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Others"))); err != nil { return err } - if _, err := w.Write([]byte("Others")); err != nil { + if _, err := io.WriteString(w, "Others"); err != nil { return err } @@ -69,7 +71,7 @@ func (t *SimpleTypeTree) MarshalCBOR(w io.Writer) error { return xerrors.Errorf("Slice value in field t.Others was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajArray, uint64(len(t.Others))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajArray, uint64(len(t.Others))); err != nil { return err } for _, v := range t.Others { @@ -83,10 +85,10 @@ func (t *SimpleTypeTree) MarshalCBOR(w io.Writer) error { return xerrors.Errorf("Value in field \"Test\" was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajTextString, uint64(len("Test"))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Test"))); err != nil { return err } - if _, err := w.Write([]byte("Test")); err != nil { + if _, err := io.WriteString(w, "Test"); err != nil { return err } @@ -94,7 +96,7 @@ func (t *SimpleTypeTree) MarshalCBOR(w io.Writer) error { return xerrors.Errorf("Slice value in field t.Test was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajArray, uint64(len(t.Test))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajArray, uint64(len(t.Test))); err != nil { return err } for _, v := range t.Test { @@ -102,7 +104,7 @@ func (t *SimpleTypeTree) MarshalCBOR(w io.Writer) error { return xerrors.Errorf("Byte array in field v was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajByteString, uint64(len(v))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajByteString, uint64(len(v))); err != nil { return err } @@ -116,10 +118,10 @@ func (t *SimpleTypeTree) MarshalCBOR(w io.Writer) error { return xerrors.Errorf("Value in field \"Dog\" was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajTextString, uint64(len("Dog"))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Dog"))); err != nil { return err } - if _, err := w.Write([]byte("Dog")); err != nil { + if _, err := io.WriteString(w, "Dog"); err != nil { return err } @@ -127,10 +129,10 @@ func (t *SimpleTypeTree) MarshalCBOR(w io.Writer) error { return xerrors.Errorf("Value in field t.Dog was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajTextString, uint64(len(t.Dog))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.Dog))); err != nil { return err } - if _, err := w.Write([]byte(t.Dog)); err != nil { + if _, err := io.WriteString(w, t.Dog); err != nil { return err } @@ -139,19 +141,19 @@ func (t *SimpleTypeTree) MarshalCBOR(w io.Writer) error { return xerrors.Errorf("Value in field \"SixtyThreeBitIntegerWithASignBit\" was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajTextString, uint64(len("SixtyThreeBitIntegerWithASignBit"))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("SixtyThreeBitIntegerWithASignBit"))); err != nil { return err } - if _, err := w.Write([]byte("SixtyThreeBitIntegerWithASignBit")); err != nil { + if _, err := io.WriteString(w, "SixtyThreeBitIntegerWithASignBit"); err != nil { return err } if t.SixtyThreeBitIntegerWithASignBit >= 0 { - if err := cbg.WriteMajorTypeHeader(w, cbg.MajUnsignedInt, uint64(t.SixtyThreeBitIntegerWithASignBit)); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.SixtyThreeBitIntegerWithASignBit)); err != nil { return err } } else { - if err := cbg.WriteMajorTypeHeader(w, cbg.MajNegativeInt, uint64(-t.SixtyThreeBitIntegerWithASignBit-1)); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajNegativeInt, uint64(-t.SixtyThreeBitIntegerWithASignBit-1)); err != nil { return err } } @@ -161,10 +163,10 @@ func (t *SimpleTypeTree) MarshalCBOR(w io.Writer) error { return xerrors.Errorf("Value in field \"NotPizza\" was too long") } - if err := cbg.WriteMajorTypeHeader(w, cbg.MajTextString, uint64(len("NotPizza"))); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("NotPizza"))); err != nil { return err } - if _, err := w.Write([]byte("NotPizza")); err != nil { + if _, err := io.WriteString(w, "NotPizza"); err != nil { return err } @@ -173,7 +175,7 @@ func (t *SimpleTypeTree) MarshalCBOR(w io.Writer) error { return err } } else { - if err := cbg.WriteMajorTypeHeader(w, cbg.MajUnsignedInt, uint64(*t.NotPizza)); err != nil { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(*t.NotPizza)); err != nil { return err } } From fe75e5c9712941d64d46a05adfe555dc3f93f51e Mon Sep 17 00:00:00 2001 From: Jeromy Date: Mon, 4 May 2020 09:40:55 -0700 Subject: [PATCH 3/5] add benchmark for unmarshaling --- testing/bench_test.go | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/testing/bench_test.go b/testing/bench_test.go index b237c94..f6a193a 100644 --- a/testing/bench_test.go +++ b/testing/bench_test.go @@ -1,6 +1,9 @@ package testing import ( + "bytes" + "encoding/hex" + "io" "io/ioutil" "math/rand" "reflect" @@ -26,3 +29,25 @@ func BenchmarkMarshaling(b *testing.B) { } } } + +func BenchmarkUnmarshaling(b *testing.B) { + hx := "8989f68080807859f099a586f0908093f1af9fb6f3a0ad82e8aaa0efbfbdf1b88688f29d8aaeecacabf0a4be94f19295a0f19b9081f0b6bf8ff3ad83a6f09e9ca7f2be8a8bf187a8a0f280a8b3f4899a9bf181afb1f0bca2b0f1b5ab9bf2b0a3ac80f6f68384787af2b1a082f3b99d98f1adb9b6f3b9868df29fbbb0f3858791f3b5b39df2b68e92f2a9bb9af282b4b2f18fba9cf294b8b2f3a0a1a9f0aab1aaf28cb994f09796aef195bc90f488be81e0a59af2928183f1a0a4abf393bbbae39d8df28fb287f0bf8fa6f0b79a89f188babcf395b0b1f29ebab7f2b0a091f29db48f1b527ef13ee4f5321a403b2130370299eeb937847848f287b8b9f1ad9e90f1b1b9bbf18ebc91f3908583f0be9ab3f2aca8abf0a8acadf380a7abf293a8aaf1b2b6a6f3b89587f3809fadf3a39f97f3a8b48cf3b299bff19cab9df28399a01b374797708d2015d3401b6bfb7066c509754c8478a1f0ab8f96f287988df297aea7f3afa699f3859788f2a2b2b8f2b681a6f29a95a4f382978cf396b183e2acbdf39cbdb5f0b99b94f1a2baaaf1ba89b0f3a8a7bbf397bdabf3af8c83f1b38ebef0beb1a0f3939f83f0b9ad90f1acb597f0b49eb0f29ab3a3f480808ef39b878ae5989ff0a7b789f48981b6f281aba6f2a9ad88f09fb395f0aa95adf0a1a59ff38a8d97f397b7b0eebfa9f2a5ab87f2afa7b8f0b992b81bab566703ac0b139c401b463b0320db277de1841b73dd7cd1861ff4561bc0256739761d28dd1b39c9019ac37c08721b2f08fbf368bf7f94813b075c40eb7f66e0488078b4f288898bf486bc90e9b8bdf180ab8ff39db1a7f0b1afb8f3ab9fa6f1b19182f189bdaff3bf80a5f1a18fb4f39c99a9f0ba839af2adb88fe4bd9df39bba8bf28ebf9ef2b3a783f0b6b395f197be84f3a1998af1b0898bf3b0b08ef1b49b94f094b59df19dbfa6f2aa8494f48ba0b2e28181f1a08999f2b3ac81eaa689f1bb80bbf2ae918bf0a19397f1a19d9cf3b095b5f1b4baa2f0b7ad92f3ab8c8ef38fab92f489b499f18d9899f0b5bcb5f2a3b6a5f2a1acb1831be5bdbd1384238b4b1b8a95991fbf9ca8d11baf61be2ac6477c7d1ba1d9dac0cecd182d1b4175138c8c7fbb4e8384785ef0a7adaef39b8a8ff1b79bacc693f1948ebef0938383f48aa6abf09ab684f1ba8c89f188b091f18ab2b5f1ac8484f2b7b089f18b97bdf1838aacf397ad98f0b9a8aff394a2a3f39eb6bff09ab8bef39189bef18f89aaf3aca982f29381901b7f0bf8763d569f3b403b75c40e5c6163108084786af1879fa8f2b4af9ef3a8b3b6f3b0be93f0aba9a8f0a1a698f3b6a7a7e6adbef1a8849bf28087a3f3b89f82f38caab6f0b7b09ff1bf938ff0a0b1aff2b79691f0a5b29bf4858896f484a5abf393bbbbf3a2b8bdf29393a6eba180f1a1b3b0f29da098f1b09ca7f3bda2901b6f088c64a0854512401b564b5898ca46ac958467f29c9188eea5941bb7b58825b1edf1ee403b6dcbd95c52f6ca33" + + d, err := hex.DecodeString(hx) + if err != nil { + b.Fatal(err) + } + + buf := bytes.NewReader(d) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + buf.Seek(0, io.SeekStart) + var tt SimpleTypeTwo + if err := tt.UnmarshalCBOR(buf); err != nil { + b.Fatal(err) + } + } + +} From 7a77837bec3632bd567a51cf88cb3f41b3b0bc33 Mon Sep 17 00:00:00 2001 From: Jeromy Date: Mon, 4 May 2020 09:41:30 -0700 Subject: [PATCH 4/5] improve allocation counts in unmarshalers by using shared buffers --- gen.go | 31 +++++++------ testing/cbor_gen.go | 50 +++++++++++---------- testing/cbor_map_gen.go | 19 ++++---- utils.go | 98 ++++++++++++++++++++++++++++++++++++++--- 4 files changed, 148 insertions(+), 50 deletions(-) diff --git a/gen.go b/gen.go index 90ec89b..e504a8e 100644 --- a/gen.go +++ b/gen.go @@ -20,6 +20,9 @@ func doTemplate(w io.Writer, info interface{}, templ string) error { return err }`, wname, tname, val) }, + "ReadHeader": func(rdr string) string { + return fmt.Sprintf(`cbg.CborReadHeaderBuf(%s, scratch)`, rdr) + }, }).Parse(templ)) return t.Execute(w, info) @@ -529,7 +532,7 @@ func emitCborUnmarshalStringField(w io.Writer, f Field) error { } return doTemplate(w, f, ` { - sval, err := cbg.ReadString(br) + sval, err := cbg.ReadStringBuf(br, scratch) if err != nil { return err } @@ -545,7 +548,7 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error { switch fname { case "math/big.Int": return doTemplate(w, f, ` - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = {{ ReadHeader "br" }} if err != nil { return err } @@ -554,7 +557,7 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error { return fmt.Errorf("big ints should be cbor bignums") } - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = {{ ReadHeader "br" }} if err != nil { return err } @@ -647,7 +650,7 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error { func emitCborUnmarshalInt64Field(w io.Writer, f Field) error { return doTemplate(w, f, `{ - maj, extra, err := cbg.CborReadHeader(br) + maj, extra, err := {{ ReadHeader "br" }} var extraI int64 if err != nil { return err @@ -687,7 +690,7 @@ func emitCborUnmarshalUint64Field(w io.Writer, f Field) error { return err } } else { - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = {{ ReadHeader "br" }} if err != nil { return err } @@ -698,7 +701,7 @@ func emitCborUnmarshalUint64Field(w io.Writer, f Field) error { {{ .Name }} = &typed } {{ else }} - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = {{ ReadHeader "br" }} if err != nil { return err } @@ -713,7 +716,7 @@ func emitCborUnmarshalUint64Field(w io.Writer, f Field) error { func emitCborUnmarshalUint8Field(w io.Writer, f Field) error { return doTemplate(w, f, ` - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = {{ ReadHeader "br" }} if err != nil { return err } @@ -729,7 +732,7 @@ func emitCborUnmarshalUint8Field(w io.Writer, f Field) error { func emitCborUnmarshalBoolField(w io.Writer, f Field) error { return doTemplate(w, f, ` - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = {{ ReadHeader "br" }} if err != nil { return err } @@ -749,7 +752,7 @@ func emitCborUnmarshalBoolField(w io.Writer, f Field) error { func emitCborUnmarshalMapField(w io.Writer, f Field) error { err := doTemplate(w, f, ` - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = {{ ReadHeader "br" }} if err != nil { return err } @@ -834,7 +837,7 @@ func emitCborUnmarshalSliceField(w io.Writer, f Field) error { } err := doTemplate(w, f, ` - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = {{ ReadHeader "br" }} if err != nil { return err } @@ -924,7 +927,7 @@ func emitCborUnmarshalSliceField(w io.Writer, f Field) error { } case reflect.Uint64: err := doTemplate(w, f, ` - maj, val, err := cbg.CborReadHeader(br) + maj, val, err := {{ ReadHeader "br" }} if err != nil { return xerrors.Errorf("failed to read uint64 for {{ .Name }} slice: %w", err) } @@ -976,8 +979,9 @@ func emitCborUnmarshalStructTuple(w io.Writer, gti *GenTypeInfo) error { err := doTemplate(w, gti, ` func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error { br := cbg.GetPeeker(r) + scratch := make([]byte, 8) - maj, extra, err := cbg.CborReadHeader(br) + maj, extra, err := {{ ReadHeader "br" }} if err != nil { return err } @@ -1137,8 +1141,9 @@ func emitCborUnmarshalStructMap(w io.Writer, gti *GenTypeInfo) error { err := doTemplate(w, gti, ` func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error { br := cbg.GetPeeker(r) + scratch := make([]byte, 8) - maj, extra, err := cbg.CborReadHeader(br) + maj, extra, err := {{ ReadHeader "br" }} if err != nil { return err } diff --git a/testing/cbor_gen.go b/testing/cbor_gen.go index 611405b..8b88fcd 100644 --- a/testing/cbor_gen.go +++ b/testing/cbor_gen.go @@ -43,8 +43,9 @@ func (t *SignedArray) MarshalCBOR(w io.Writer) error { func (t *SignedArray) UnmarshalCBOR(r io.Reader) error { br := cbg.GetPeeker(r) + scratch := make([]byte, 8) - maj, extra, err := cbg.CborReadHeader(br) + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } @@ -58,7 +59,7 @@ func (t *SignedArray) UnmarshalCBOR(r io.Reader) error { // t.Signed ([]uint64) (slice) - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } @@ -77,7 +78,7 @@ func (t *SignedArray) UnmarshalCBOR(r io.Reader) error { for i := 0; i < int(extra); i++ { - maj, val, err := cbg.CborReadHeader(br) + maj, val, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return xerrors.Errorf("failed to read uint64 for t.Signed slice: %w", err) } @@ -151,8 +152,9 @@ func (t *SimpleTypeOne) MarshalCBOR(w io.Writer) error { func (t *SimpleTypeOne) UnmarshalCBOR(r io.Reader) error { br := cbg.GetPeeker(r) + scratch := make([]byte, 8) - maj, extra, err := cbg.CborReadHeader(br) + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } @@ -167,7 +169,7 @@ func (t *SimpleTypeOne) UnmarshalCBOR(r io.Reader) error { // t.Foo (string) (string) { - sval, err := cbg.ReadString(br) + sval, err := cbg.ReadStringBuf(br, scratch) if err != nil { return err } @@ -178,7 +180,7 @@ func (t *SimpleTypeOne) UnmarshalCBOR(r io.Reader) error { { - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } @@ -190,7 +192,7 @@ func (t *SimpleTypeOne) UnmarshalCBOR(r io.Reader) error { } // t.Binary ([]uint8) (slice) - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } @@ -207,7 +209,7 @@ func (t *SimpleTypeOne) UnmarshalCBOR(r io.Reader) error { } // t.Signed (int64) (int64) { - maj, extra, err := cbg.CborReadHeader(br) + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) var extraI int64 if err != nil { return err @@ -375,8 +377,9 @@ func (t *SimpleTypeTwo) MarshalCBOR(w io.Writer) error { func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { br := cbg.GetPeeker(r) + scratch := make([]byte, 8) - maj, extra, err := cbg.CborReadHeader(br) + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } @@ -411,7 +414,7 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { } // t.Others ([]uint64) (slice) - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } @@ -430,7 +433,7 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { for i := 0; i < int(extra); i++ { - maj, val, err := cbg.CborReadHeader(br) + maj, val, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return xerrors.Errorf("failed to read uint64 for t.Others slice: %w", err) } @@ -444,7 +447,7 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { // t.SignedOthers ([]int64) (slice) - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } @@ -463,7 +466,7 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { for i := 0; i < int(extra); i++ { { - maj, extra, err := cbg.CborReadHeader(br) + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) var extraI int64 if err != nil { return err @@ -490,7 +493,7 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { // t.Test ([][]uint8) (slice) - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } @@ -513,7 +516,7 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { var extra uint64 var err error - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } @@ -534,7 +537,7 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { // t.Dog (string) (string) { - sval, err := cbg.ReadString(br) + sval, err := cbg.ReadStringBuf(br, scratch) if err != nil { return err } @@ -543,7 +546,7 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { } // t.Numbers ([]testing.NaturalNumber) (slice) - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } @@ -562,7 +565,7 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { for i := 0; i < int(extra); i++ { - maj, val, err := cbg.CborReadHeader(br) + maj, val, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return xerrors.Errorf("failed to read uint64 for t.Numbers slice: %w", err) } @@ -588,7 +591,7 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { return err } } else { - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } @@ -614,7 +617,7 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { return err } } else { - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } @@ -628,7 +631,7 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { } // t.Arrrrrghay ([3]testing.SimpleTypeOne) (array) - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } @@ -694,8 +697,9 @@ func (t *DeferredContainer) MarshalCBOR(w io.Writer) error { func (t *DeferredContainer) UnmarshalCBOR(r io.Reader) error { br := cbg.GetPeeker(r) + scratch := make([]byte, 8) - maj, extra, err := cbg.CborReadHeader(br) + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } @@ -742,7 +746,7 @@ func (t *DeferredContainer) UnmarshalCBOR(r io.Reader) error { { - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } diff --git a/testing/cbor_map_gen.go b/testing/cbor_map_gen.go index a3f66e7..a1cea2b 100644 --- a/testing/cbor_map_gen.go +++ b/testing/cbor_map_gen.go @@ -185,8 +185,9 @@ func (t *SimpleTypeTree) MarshalCBOR(w io.Writer) error { func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { br := cbg.GetPeeker(r) + scratch := make([]byte, 8) - maj, extra, err := cbg.CborReadHeader(br) + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } @@ -204,7 +205,7 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { for i := uint64(0); i < n; i++ { { - sval, err := cbg.ReadString(br) + sval, err := cbg.ReadStringBuf(br, scratch) if err != nil { return err } @@ -260,7 +261,7 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { // t.Others ([]uint64) (slice) case "Others": - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } @@ -279,7 +280,7 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { for i := 0; i < int(extra); i++ { - maj, val, err := cbg.CborReadHeader(br) + maj, val, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return xerrors.Errorf("failed to read uint64 for t.Others slice: %w", err) } @@ -294,7 +295,7 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { // t.Test ([][]uint8) (slice) case "Test": - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } @@ -317,7 +318,7 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { var extra uint64 var err error - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } @@ -339,7 +340,7 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { case "Dog": { - sval, err := cbg.ReadString(br) + sval, err := cbg.ReadStringBuf(br, scratch) if err != nil { return err } @@ -349,7 +350,7 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { // t.SixtyThreeBitIntegerWithASignBit (int64) (int64) case "SixtyThreeBitIntegerWithASignBit": { - maj, extra, err := cbg.CborReadHeader(br) + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) var extraI int64 if err != nil { return err @@ -387,7 +388,7 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { return err } } else { - maj, extra, err = cbg.CborReadHeader(br) + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } diff --git a/utils.go b/utils.go index a6ff081..5b55666 100644 --- a/utils.go +++ b/utils.go @@ -247,7 +247,7 @@ func CborReadHeader(br io.Reader) (byte, uint64, error) { return maj, uint64(next), nil case low == 25: scratch := make([]byte, 2) - if _, err := io.ReadFull(br, scratch[:2]); err != nil { + if _, err := io.ReadAtLeast(br, scratch[:2], 2); err != nil { return 0, 0, err } val := uint64(binary.BigEndian.Uint16(scratch[:2])) @@ -257,7 +257,7 @@ func CborReadHeader(br io.Reader) (byte, uint64, error) { return maj, val, nil case low == 26: scratch := make([]byte, 4) - if _, err := io.ReadFull(br, scratch[:4]); err != nil { + if _, err := io.ReadAtLeast(br, scratch[:4], 4); err != nil { return 0, 0, err } val := uint64(binary.BigEndian.Uint32(scratch[:4])) @@ -267,7 +267,7 @@ func CborReadHeader(br io.Reader) (byte, uint64, error) { return maj, val, nil case low == 27: scratch := make([]byte, 8) - if _, err := io.ReadFull(br, scratch); err != nil { + if _, err := io.ReadAtLeast(br, scratch, 8); err != nil { return 0, 0, err } val := binary.BigEndian.Uint64(scratch) @@ -280,6 +280,71 @@ func CborReadHeader(br io.Reader) (byte, uint64, error) { } } +func readByteBuf(r io.Reader, scratch []byte) (byte, error) { + n, err := r.Read(scratch[:1]) + if err != nil { + return 0, err + } + if n != 1 { + return 0, fmt.Errorf("failed to read a byte") + } + return scratch[0], err +} + +// same as the above, just tries to allocate less by using a passed in scratch buffer +func CborReadHeaderBuf(br io.Reader, scratch []byte) (byte, uint64, error) { + first, err := readByteBuf(br, scratch) + if err != nil { + return 0, 0, err + } + + maj := (first & 0xe0) >> 5 + low := first & 0x1f + + switch { + case low < 24: + return maj, uint64(low), nil + case low == 24: + next, err := readByteBuf(br, scratch) + if err != nil { + return 0, 0, err + } + if next < 24 { + return 0, 0, fmt.Errorf("cbor input was not canonical (lval 24 with value < 24)") + } + return maj, uint64(next), nil + case low == 25: + if _, err := io.ReadAtLeast(br, scratch[:2], 2); err != nil { + return 0, 0, err + } + val := uint64(binary.BigEndian.Uint16(scratch[:2])) + if val <= math.MaxUint8 { + return 0, 0, fmt.Errorf("cbor input was not canonical (lval 25 with value <= MaxUint8)") + } + return maj, val, nil + case low == 26: + if _, err := io.ReadAtLeast(br, scratch[:4], 4); err != nil { + return 0, 0, err + } + val := uint64(binary.BigEndian.Uint32(scratch[:4])) + if val <= math.MaxUint16 { + return 0, 0, fmt.Errorf("cbor input was not canonical (lval 26 with value <= MaxUint16)") + } + return maj, val, nil + case low == 27: + if _, err := io.ReadAtLeast(br, scratch[:8], 8); err != nil { + return 0, 0, err + } + val := binary.BigEndian.Uint64(scratch[:8]) + if val <= math.MaxUint32 { + return 0, 0, fmt.Errorf("cbor input was not canonical (lval 27 with value <= MaxUint32)") + } + return maj, val, nil + default: + return 0, 0, fmt.Errorf("invalid header: (%x)", first) + } +} + func CborWriteHeader(w io.Writer, t byte, l uint64) error { return WriteMajorTypeHeader(w, t, l) } @@ -406,7 +471,7 @@ func ReadByteArray(br io.Reader, maxlen uint64) ([]byte, error) { } buf := make([]byte, extra) - if _, err := io.ReadFull(br, buf); err != nil { + if _, err := io.ReadAtLeast(br, buf, int(extra)); err != nil { return nil, err } @@ -446,7 +511,30 @@ func ReadString(r io.Reader) (string, error) { } buf := make([]byte, l) - _, err = io.ReadFull(r, buf) + _, err = io.ReadAtLeast(r, buf, int(l)) + if err != nil { + return "", err + } + + return string(buf), nil +} + +func ReadStringBuf(r io.Reader, scratch []byte) (string, error) { + maj, l, err := CborReadHeaderBuf(r, scratch) + if err != nil { + return "", err + } + + if maj != MajTextString { + return "", fmt.Errorf("got tag %d while reading string value (l = %d)", maj, l) + } + + if l > MaxLength { + return "", fmt.Errorf("string in input was too long") + } + + buf := make([]byte, l) + _, err = io.ReadAtLeast(r, buf, int(l)) if err != nil { return "", err } From 4837c7b2a928e012152f3b0794028ca73a9228bf Mon Sep 17 00:00:00 2001 From: Jeromy Date: Mon, 4 May 2020 13:41:59 -0700 Subject: [PATCH 5/5] add comment on 9 byte buffer --- gen.go | 1 + 1 file changed, 1 insertion(+) diff --git a/gen.go b/gen.go index e504a8e..743f58e 100644 --- a/gen.go +++ b/gen.go @@ -458,6 +458,7 @@ func emitCborMarshalSliceField(w io.Writer, f Field) error { } func emitCborMarshalStructTuple(w io.Writer, gti *GenTypeInfo) error { + // 9 byte buffer to accomodate for the maximum header length (cbor varints are maximum 9 bytes_ err := doTemplate(w, gti, `var lengthBuf{{ .Name }} = {{ .TupleHeaderAsByteString }} func (t *{{ .Name }}) MarshalCBOR(w io.Writer) error { if t == nil {