Skip to content

Commit

Permalink
RLP encoding and decoding generator for simple structures (#13341)
Browse files Browse the repository at this point in the history
  • Loading branch information
racytech authored Jan 23, 2025
1 parent 683381a commit 6c2b596
Show file tree
Hide file tree
Showing 9 changed files with 2,002 additions and 6 deletions.
734 changes: 734 additions & 0 deletions cmd/rlpgen/handlers.go

Large diffs are not rendered by default.

286 changes: 286 additions & 0 deletions cmd/rlpgen/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
// Copyright 2025 The Erigon Authors
// This file is part of Erigon.
//
// Erigon is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Erigon is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with Erigon. If not, see <http://www.gnu.org/licenses/>.

// NOTE: This generator works only on struct types; if the type is a slice of types (e.g. []MyType) it will fail.
// Not all field types are supported yet; see `matcher.go` for the currently supported set.
// This will be fixed in the future.

package main

import (
	"bytes"
	"errors"
	"flag"
	"fmt"
	"go/types"
	"os"
	"sort"
	"strings"

	"golang.org/x/tools/go/packages"
)

const (
// rlpPackagePath is the import path of the rlp package. main loads it next to
// the target package, and addImports always lists it in the generated imports.
rlpPackagePath = "github.com/erigontech/erigon-lib/rlp"
)

// headerMsg is the first thing addImports puts into the generated output.
const headerMsg = "// Code generated by rlpgen. DO NOT EDIT.\n\n"

var (
// _imports is the set of import paths written into the generated file's
// import block; addImports emits every key of the map.
_imports = map[string]bool{}
// pkgSrc is the package searched for the target type. main assigns it, and
// addImports uses its name for the generated package clause.
pkgSrc *types.Package
)

// main is the entry point of rlpgen.
//
// It loads the rlp package together with the package found in -dir, looks up
// the type named by -type in the latter and generates EncodingSize, EncodeRLP
// and DecodeRLP methods for it. The generated source is echoed to stdout
// exactly once and, unless -wfile=false is given, is also written to
// <dir>/gen_<lowercased type name>_rlp.go. Any failure terminates the process
// through _exit.
func main() {
	var (
		pkgdir    = flag.String("dir", ".", "input package")
		typename  = flag.String("type", "", "type to generate methods for")
		writefile = flag.Bool("wfile", true, "set to false if no need to write to the file")
	)
	flag.Parse()

	pcfg := &packages.Config{
		Mode: packages.NeedName | packages.NeedTypes,
		Dir:  *pkgdir,
	}
	ps, err := packages.Load(pcfg, rlpPackagePath, ".")
	if err != nil {
		_exit(fmt.Sprint("error loading package: ", err))
	}
	if len(ps) != 2 {
		_exit(fmt.Sprintf("expected to load 2 packages: 1) %v, 2) %v\n \tgot %v", rlpPackagePath, *pkgdir, len(ps)))
	}

	if err := checkPackageErrors(ps[0]); err != nil {
		_exit(err.Error())
	}
	if err := checkPackageErrors(ps[1]); err != nil {
		_exit(err.Error())
	}
	if ps[0].PkgPath != rlpPackagePath {
		_exit(fmt.Sprintf("expected first package to be %s\n", rlpPackagePath))
	}

	// ps[0].Types - rlp package
	// ps[1].Types - package where to search for the to-be generated struct
	pkgSrc = ps[1].Types
	fmt.Println("pkgSrc: ", pkgSrc.Name())
	fmt.Println("typename: ", *typename)

	// 1. search for a struct
	typ, err := findType(pkgSrc.Scope(), *typename)
	if err != nil {
		_exit(err.Error())
	}

	// TODO(racytech): add error checks for the possible unhandled errors

	// One buffer per generated method; process fills all three.
	var encodingSize bytes.Buffer
	var encodeRLP bytes.Buffer
	var decodeRLP bytes.Buffer
	if err := process(typ, &encodingSize, &encodeRLP, &decodeRLP); err != nil {
		_exit(err.Error())
	}

	// addImports must run after process so that it sees every import the
	// generation step registered.
	result := addImports()

	result = append(result, encodingSize.Bytes()...)
	result = append(result, encodeRLP.Bytes()...)
	result = append(result, decodeRLP.Bytes()...)

	// Echo the generated source once. The previous code repeated this write
	// when -wfile=false, which printed the whole output twice.
	os.Stdout.Write(result)
	if *writefile {
		outfile := fmt.Sprintf("%s/gen_%s_rlp.go", *pkgdir, strings.ToLower(typ.Obj().Name()))
		fmt.Println("outfile: ", outfile)
		if err := os.WriteFile(outfile, result, 0600); err != nil {
			_exit(err.Error())
		}
	}
}

// _exit reports msg and terminates the process with exit status 1.
//
// The message is written to standard error rather than standard output so
// that diagnostics never get mixed into the generated source, which this tool
// prints on stdout.
func _exit(msg string) {
	fmt.Fprintln(os.Stderr, msg)
	os.Exit(1)
}

// checkPackageErrors collapses the errors recorded on pkg into a single error.
//
// It returns nil when pkg carries no errors. Otherwise the returned error's
// text is a header line naming the package path followed by one line per
// recorded error message.
func checkPackageErrors(pkg *packages.Package) error {
	if len(pkg.Errors) == 0 {
		return nil
	}

	var report bytes.Buffer
	fmt.Fprintf(&report, "package %s has errors: \n", pkg.PkgPath)
	for i := range pkg.Errors {
		report.WriteString(pkg.Errors[i].Msg)
		report.WriteByte('\n')
	}
	return errors.New(report.String())
}

// addImports renders the preamble of the generated file: the "DO NOT EDIT"
// header, the package clause (named after pkgSrc) and the import block.
//
// "fmt", "io" and the rlp package are always registered in _imports before
// rendering. Import paths are emitted in sorted order: ranging over the map
// directly would make the generated file differ from run to run, because Go
// does not define map iteration order.
func addImports() []byte {
	_imports["fmt"] = true
	_imports["io"] = true
	_imports[rlpPackagePath] = true

	paths := make([]string, 0, len(_imports))
	for path := range _imports {
		paths = append(paths, path)
	}
	sort.Strings(paths)

	var out bytes.Buffer
	out.WriteString(headerMsg)
	out.WriteString("package " + pkgSrc.Name() + "\n\n")
	out.WriteString("import (\n")
	for _, path := range paths {
		out.WriteString(" ")
		out.WriteByte('"')
		out.WriteString(path)
		out.WriteString("\"\n")
	}
	out.WriteString(")\n\n")
	return out.Bytes()
}

// process emits the three generated methods for typ.
//
// b1 receives EncodingSize, b2 receives EncodeRLP and b3 receives DecodeRLP.
// Each method's prologue is written first, then addEncodeLogic appends the
// per-field code to all three buffers, and finally each method is closed.
// If addEncodeLogic fails its error is returned as is and the buffers are
// left holding only the prologues plus whatever was appended before the
// failure.
func process(typ *types.Named, b1, b2, b3 *bytes.Buffer) error {
	// TODO(racytech): handle all possible errors

	typename := typ.Obj().Name()

	// EncodingSize prologue.
	fmt.Fprintf(b1, "func (obj *%s) EncodingSize() (size int) {\n", typename)

	// EncodeRLP prologue: write the struct size prefix.
	fmt.Fprintf(b2, "func (obj *%s) EncodeRLP(w io.Writer) error {\n", typename)
	b2.WriteString(" var b [32]byte\n" +
		" if err := rlp.EncodeStructSizePrefix(obj.EncodingSize(), w, b[:]); err != nil {\n" +
		" return err\n" +
		" }\n")

	// DecodeRLP prologue: open the list.
	fmt.Fprintf(b3, "func (obj *%s) DecodeRLP(s *rlp.Stream) error {\n", typename)
	b3.WriteString(" _, err := s.List()\n" +
		" if err != nil {\n" +
		" return err\n" +
		" }\n")

	// Per-field encoding/decoding logic.
	if err := addEncodeLogic(b1, b2, b3, typ); err != nil {
		return err
	}

	// EncodingSize epilogue.
	b1.WriteString(" return\n}\n\n")

	// EncodeRLP epilogue.
	b2.WriteString(" return nil\n}\n\n")

	// DecodeRLP epilogue: close the list and report a failure to do so.
	b3.WriteString(" if err = s.ListEnd(); err != nil {\n")
	fmt.Fprintf(b3, " return fmt.Errorf(\"error closing %s, err: %%w\", err)\n", typename)
	b3.WriteString(" }\n" +
		" return nil\n}\n")

	return nil
}

// findType resolves typename in scope and returns it as a named type.
//
// It fails with "no such identifier: <typename>" when the scope has no such
// object, with "not a type" when the object is not a type name, and with
// "not a named type" when the type name does not denote a *types.Named.
func findType(scope *types.Scope, typename string) (*types.Named, error) {
	found := scope.Lookup(typename)
	if found == nil {
		return nil, fmt.Errorf("no such identifier: %s", typename)
	}

	tn, isTypeName := found.(*types.TypeName)
	if !isTypeName {
		return nil, errors.New("not a type")
	}

	switch t := tn.Type().(type) {
	case *types.Named:
		return t, nil
	default:
		return nil, errors.New("not a named type")
	}
}

// addEncodeLogic appends the per-field size, encode and decode code for the
// struct behind named to b1, b2 and b3 respectively.
//
// Each field's type is turned into its string form by matchTypeToString and
// dispatched through matchStrTypeToFunc; both panic on field types the
// generator does not support.
//
// An error is returned when the underlying type of named is not a struct.
// Previously such types were silently accepted, which produced methods with
// empty bodies instead of a failure.
func addEncodeLogic(b1, b2, b3 *bytes.Buffer, named *types.Named) error {
	st, ok := named.Underlying().(*types.Struct)
	if !ok {
		// TODO(racytech): see handleType
		return fmt.Errorf("type %s is not a struct (underlying type is %s)", named.Obj().Name(), named.Underlying())
	}

	for i := 0; i < st.NumFields(); i++ {
		field := st.Field(i)
		strTyp := matchTypeToString(field.Type(), "")
		matchStrTypeToFunc(strTyp)(b1, b2, b3, field.Type(), field.Name())
	}
	return nil
}

// handleType is a skeleton for a future recursive type walker. At present it
// only traverses t and produces no output: pointers, named types, slices and
// arrays recurse into their element/underlying type, structs recurse into
// every field, basic types end the walk, and anything else panics with
// "unhandled".
//
// caller is the type the walk arrived from, depth is the nesting level and
// ptr records whether the immediate parent was a pointer; none of them
// influences the traversal yet. The notes inside each case record the planned
// behaviour.
func handleType(t types.Type, caller types.Type, depth int, ptr bool) {
	next := depth + 1

	switch e := t.(type) {
	case *types.Basic:
		// Planned:
		//  - caller is Named (e.g. type MyInt int) -> TODO
		//  - caller is Slice or Array: byte slice / byte array -> encode
		//  - caller is Pointer -> encode
		//  - caller is nil -> call the rlp encoding function for basic
		//    types (bool, uint, int, ...)
		return

	case *types.Pointer:
		// Planned: reject double pointers (depth > 0 and ptr == true).
		// A pointer is simply passed on to the next level.
		handleType(e.Elem(), e, next, true)

	case *types.Named:
		// Planned for user named types:
		//  - big.Int / uint256.Int: encode/decode depending on ptr
		//  - if rlp was already generated for this type, remove the file
		//  - if the underlying type is a struct with generated rlp, call
		//    its RLP encode/decode methods
		// Otherwise continue with the underlying type.
		handleType(e.Underlying(), e, next, false)

	case *types.Slice:
		// Planned:
		//  - plain byte slice: call RLP encode/decode on it
		//  - slice of named types, e.g. []common.Hash -> [][32]byte, or
		//    a slice of structs: TODO, needs more thought
		handleType(e.Elem(), e, next, false)

	case *types.Array:
		// Planned:
		//  - plain byte array: call RLP encode/decode on it
		//  - array of named types, e.g. [10]common.Hash -> [10][32]byte,
		//    or an array of structs: TODO, needs more thought
		handleType(e.Elem(), e, next, false)

	case *types.Struct:
		// Planned for nested structs: remove previously generated rlp for
		// the type, try process(t) and, on success, add its
		// encoding/decoding logic to the buffers.
		// Top-level call: walk the fields. An unsupported field type is
		// meant to panic and abort the whole generation.
		for i, n := 0, e.NumFields(); i < n; i++ {
			handleType(e.Field(i).Type(), e, next, false)
		}

	default:
		panic("unhandled")
	}
}
89 changes: 89 additions & 0 deletions cmd/rlpgen/matcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright 2025 The Erigon Authors
// This file is part of Erigon.
//
// Erigon is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Erigon is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with Erigon. If not, see <http://www.gnu.org/licenses/>.

package main

import (
"bytes"
"go/types"
)

// handle should write encoding size of type as well as encoding and decoding logic for the type.
// b1, b2 and b3 are the buffers holding the generated EncodingSize, EncodeRLP and DecodeRLP
// methods respectively; fieldType and fieldName describe the struct field being processed.
type handle func(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName string)

// func foofunc(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName string) {}

// handlers lists all types that this generator can handle for the time being.
// Keys are type strings in the form produced by matchTypeToString: "*" marks a pointer,
// "[]" a slice and "[n]" an array of any length.
// To add a new type, add a string representation of the type here and write the handle
// function for it in `handlers.go`.
var handlers = map[string]handle{
"uint64": uintHandle,
"*uint64": uintPtrHandle,
"big.Int": bigIntHandle,
"*big.Int": bigIntPtrHandle,
"uint256.Int": uint256Handle,
"*uint256.Int": uint256PtrHandle,
"types.BlockNonce": blockNonceHandle,
"*types.BlockNonce": blockNoncePtrHandle,
"common.Address": addressHandle,
"*common.Address": addressPtrHandle,
"common.Hash": hashHandle,
"*common.Hash": hashPtrHandle,
"types.Bloom": bloomHandle,
"*types.Bloom": bloomPtrHandle,
"[]byte": byteSliceHandle,
"*[]byte": byteSlicePtrHandle,
"[][]byte": byteSliceSliceHandle,
"[]types.BlockNonce": blockNonceSliceHandle,
"[]*types.BlockNonce": blockNoncePtrSliceHandle,
"[]common.Address": addressSliceHandle,
"[]*common.Address": addressPtrSliceHandle,
"[]common.Hash": hashSliceHandle,
"[]*common.Hash": hashPtrSliceHandle,
"[n]byte": byteArrayHandle,
"*[n]byte": byteArrayPtrHandle,
}

// matchTypeToString recursively builds the string representation of
// fieldType and appends it to the prefix in.
//
// Pointers contribute "*", slices "[]" and arrays "[n]" (the length is
// deliberately dropped). A named type ends the recursion as
// "<package name>.<type name>"; a basic type ends it with its own name.
// Predeclared named types such as `error` belong to no package and are
// rendered by their bare name; they used to cause a nil pointer dereference.
//
// It panics on any other kind of type (maps, channels, functions, ...).
func matchTypeToString(fieldType types.Type, in string) string {
	switch t := fieldType.(type) {
	case *types.Named:
		obj := t.Obj()
		if pkg := obj.Pkg(); pkg != nil {
			return in + pkg.Name() + "." + obj.Name()
		}
		// Universe-scope types have no package to qualify them with.
		return in + obj.Name()
	case *types.Pointer:
		return matchTypeToString(t.Elem(), in+"*")
	case *types.Slice:
		return matchTypeToString(t.Elem(), in+"[]")
	case *types.Array:
		return matchTypeToString(t.Elem(), in+"[n]")
	case *types.Basic:
		return in + t.Name()
	default:
		panic("_matchTypeToString: unhandled match")
	}
}

// matchStrTypeToFunc returns the handle registered for the given string
// representation of a type.
//
// The listed integer type names share the "uint64" handler and pointers to
// them share the "*uint64" handler. Every other string is looked up in
// handlers as is; the function panics when no entry exists.
func matchStrTypeToFunc(strType string) handle {
	switch strType {
	case "int16", "int32", "int", "int64", "uint16", "uint32", "uint", "uint64":
		return handlers["uint64"]
	case "*int16", "*int32", "*int", "*int64", "*uint16", "*uint32", "*uint", "*uint64":
		return handlers["*uint64"]
	}

	fn, registered := handlers[strType]
	if !registered {
		panic("no handle added for type: " + strType)
	}
	return fn
}
Loading

0 comments on commit 6c2b596

Please sign in to comment.