Skip to content

Commit

Permalink
refactor: consolidate once-only registration of extras (#85)
Browse files Browse the repository at this point in the history
## Why this should be merged

Consolidates duplicated logic. Similar rationale to #84.

## How this works

New `register.AtMostOnce[T]` type is responsible for limiting calls to
`Register()`.

## How this was tested

Existing unit tests of `params`. Note that the equivalent functionality
in `types` wasn't tested but now is.
  • Loading branch information
ARR4N authored Dec 9, 2024
1 parent 25e5ca3 commit d71677f
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 48 deletions.
34 changes: 15 additions & 19 deletions core/types/rlp_payload.libevm.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"io"

"github.com/ava-labs/libevm/libevm/pseudo"
"github.com/ava-labs/libevm/libevm/register"
"github.com/ava-labs/libevm/libevm/testonly"
"github.com/ava-labs/libevm/rlp"
)
Expand All @@ -37,18 +38,15 @@ import (
// The payload can be accessed via the [ExtraPayloads.FromPayloadCarrier] method
// of the accessor returned by RegisterExtras.
func RegisterExtras[SA any]() ExtraPayloads[SA] {
if registeredExtras != nil {
panic("re-registration of Extras")
}
var extra ExtraPayloads[SA]
registeredExtras = &extraConstructors{
registeredExtras.MustRegister(&extraConstructors{
stateAccountType: func() string {
var x SA
return fmt.Sprintf("%T", x)
}(),
newStateAccount: pseudo.NewConstructor[SA]().Zero,
cloneStateAccount: extra.cloneStateAccount,
}
})
return extra
}

Expand All @@ -59,12 +57,10 @@ func RegisterExtras[SA any]() ExtraPayloads[SA] {
// defer-called afterwards, either directly or via testing.TB.Cleanup(). This is
// a workaround for the single-call limitation on [RegisterExtras].
func TestOnlyClearRegisteredExtras() {
testonly.OrPanic(func() {
registeredExtras = nil
})
registeredExtras.TestOnlyClear()
}

var registeredExtras *extraConstructors
var registeredExtras register.AtMostOnce[*extraConstructors]

type extraConstructors struct {
stateAccountType string
Expand All @@ -74,10 +70,10 @@ type extraConstructors struct {

func (e *StateAccountExtra) clone() *StateAccountExtra {
switch r := registeredExtras; {
case r == nil, e == nil:
case !r.Registered(), e == nil:
return nil
default:
return r.cloneStateAccount(e)
return r.Get().cloneStateAccount(e)
}
}

Expand Down Expand Up @@ -146,15 +142,15 @@ func (a *SlimAccount) extra() *StateAccountExtra {
func getOrSetNewStateAccountExtra(curr **StateAccountExtra) *StateAccountExtra {
if *curr == nil {
*curr = &StateAccountExtra{
t: registeredExtras.newStateAccount(),
t: registeredExtras.Get().newStateAccount(),
}
}
return *curr
}

func (e *StateAccountExtra) payload() *pseudo.Type {
if e.t == nil {
e.t = registeredExtras.newStateAccount()
e.t = registeredExtras.Get().newStateAccount()
}
return e.t
}
Expand Down Expand Up @@ -196,24 +192,24 @@ var _ interface {
// EncodeRLP implements the [rlp.Encoder] interface.
func (e *StateAccountExtra) EncodeRLP(w io.Writer) error {
switch r := registeredExtras; {
case r == nil:
case !r.Registered():
return nil
case e == nil:
e = &StateAccountExtra{}
fallthrough
case e.t == nil:
e.t = r.newStateAccount()
e.t = r.Get().newStateAccount()
}
return e.t.EncodeRLP(w)
}

// DecodeRLP implements the [rlp.Decoder] interface.
func (e *StateAccountExtra) DecodeRLP(s *rlp.Stream) error {
switch r := registeredExtras; {
case r == nil:
case !r.Registered():
return nil
case e.t == nil:
e.t = r.newStateAccount()
e.t = r.Get().newStateAccount()
fallthrough
default:
return s.Decode(e.t)
Expand All @@ -224,10 +220,10 @@ func (e *StateAccountExtra) DecodeRLP(s *rlp.Stream) error {
func (e *StateAccountExtra) Format(s fmt.State, verb rune) {
var out string
switch r := registeredExtras; {
case r == nil:
case !r.Registered():
out = "<nil>"
case e == nil, e.t == nil:
out = fmt.Sprintf("<nil>[*StateAccountExtra[%s]]", r.stateAccountType)
out = fmt.Sprintf("<nil>[*StateAccountExtra[%s]]", r.Get().stateAccountType)
default:
e.t.Format(s, verb)
return
Expand Down
68 changes: 68 additions & 0 deletions libevm/register/register.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2024 the libevm authors.
//
// The libevm additions to go-ethereum are free software: you can redistribute
// them and/or modify them under the terms of the GNU Lesser General Public License
// as published by the Free Software Foundation, either version 3 of the License,
// or (at your option) any later version.
//
// The libevm additions are distributed in the hope that they will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser
// General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see
// <http://www.gnu.org/licenses/>.

// Package register provides functionality for optional registration of types.
package register

import (
"errors"

"github.com/ava-labs/libevm/libevm/testonly"
)

// An AtMostOnce allows zero or one registration of a T.
type AtMostOnce[T any] struct {
v *T
}

// ErrReRegistration is returned on all but the first of calls to
// [AtMostOnce.Register].
var ErrReRegistration = errors.New("re-registration")

// Register registers `v` or returns [ErrReRegistration] if already called.
func (o *AtMostOnce[T]) Register(v T) error {
if o.Registered() {
return ErrReRegistration
}
o.v = &v
return nil
}

// MustRegister is equivalent to [AtMostOnce.Register], panicking on error.
func (o *AtMostOnce[T]) MustRegister(v T) {
if err := o.Register(v); err != nil {
panic(err)
}
}

// Registered reports whether [AtMostOnce.Register] has been called.
func (o *AtMostOnce[T]) Registered() bool {
return o.v != nil
}

// Get returns the registered value. It MUST NOT be called before
// [AtMostOnce.Register].
func (o *AtMostOnce[T]) Get() T {
return *o.v
}

// TestOnlyClear clears any previously registered value, returning `o` to its
// default state. It panics if called from a non-testing call stack.
func (o *AtMostOnce[T]) TestOnlyClear() {
testonly.OrPanic(func() {
o.v = nil
})
}
29 changes: 12 additions & 17 deletions params/config.libevm.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
"reflect"

"github.com/ava-labs/libevm/libevm/pseudo"
"github.com/ava-labs/libevm/libevm/testonly"
"github.com/ava-labs/libevm/libevm/register"
)

// Extras are arbitrary payloads to be added as extra fields in [ChainConfig]
Expand Down Expand Up @@ -68,20 +68,17 @@ type Extras[C ChainConfigHooks, R RulesHooks] struct {
// alter Ethereum behaviour; if this isn't desired then they can embed
// [NOOPHooks] to satisfy either interface.
func RegisterExtras[C ChainConfigHooks, R RulesHooks](e Extras[C, R]) ExtraPayloads[C, R] {
if registeredExtras != nil {
panic("re-registration of Extras")
}
mustBeStructOrPointerToOne[C]()
mustBeStructOrPointerToOne[R]()

payloads := e.payloads()
registeredExtras = &extraConstructors{
registeredExtras.MustRegister(&extraConstructors{
newChainConfig: pseudo.NewConstructor[C]().Zero,
newRules: pseudo.NewConstructor[R]().Zero,
reuseJSONRoot: e.ReuseJSONRoot,
newForRules: e.newForRules,
payloads: payloads,
}
})
return payloads
}

Expand All @@ -92,14 +89,12 @@ func RegisterExtras[C ChainConfigHooks, R RulesHooks](e Extras[C, R]) ExtraPaylo
// defer-called afterwards, either directly or via testing.TB.Cleanup(). This is
// a workaround for the single-call limitation on [RegisterExtras].
func TestOnlyClearRegisteredExtras() {
testonly.OrPanic(func() {
registeredExtras = nil
})
registeredExtras.TestOnlyClear()
}

// registeredExtras holds non-generic constructors for the [Extras] types
// registered via [RegisterExtras].
var registeredExtras *extraConstructors
var registeredExtras register.AtMostOnce[*extraConstructors]

type extraConstructors struct {
newChainConfig, newRules func() *pseudo.Type
Expand All @@ -115,7 +110,7 @@ type extraConstructors struct {

func (e *Extras[C, R]) newForRules(c *ChainConfig, r *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type {
if e.NewRules == nil {
return registeredExtras.newRules()
return registeredExtras.Get().newRules()
}
rExtra := e.NewRules(c, r, e.payloads().FromChainConfig(c), blockNum, isMerge, timestamp)
return pseudo.From(rExtra).Type
Expand Down Expand Up @@ -209,36 +204,36 @@ func (e ExtraPayloads[C, R]) hooksFromRules(r *Rules) RulesHooks {
// abstract the libevm-specific behaviour outside of original geth code.
func (c *ChainConfig) addRulesExtra(r *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) {
r.extra = nil
if registeredExtras != nil {
r.extra = registeredExtras.newForRules(c, r, blockNum, isMerge, timestamp)
if registeredExtras.Registered() {
r.extra = registeredExtras.Get().newForRules(c, r, blockNum, isMerge, timestamp)
}
}

// extraPayload returns the ChainConfig's extra payload iff [RegisterExtras] has
// already been called. If the payload hasn't been populated (typically via
// unmarshalling of JSON), a nil value is constructed and returned.
func (c *ChainConfig) extraPayload() *pseudo.Type {
if registeredExtras == nil {
if !registeredExtras.Registered() {
// This will only happen if someone constructs an [ExtraPayloads]
// directly, without a call to [RegisterExtras].
//
// See https://google.github.io/styleguide/go/best-practices#when-to-panic
panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", c))
}
if c.extra == nil {
c.extra = registeredExtras.newChainConfig()
c.extra = registeredExtras.Get().newChainConfig()
}
return c.extra
}

// extraPayload is equivalent to [ChainConfig.extraPayload].
func (r *Rules) extraPayload() *pseudo.Type {
if registeredExtras == nil {
if !registeredExtras.Registered() {
// See ChainConfig.extraPayload() equivalent.
panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", r))
}
if r.extra == nil {
r.extra = registeredExtras.newRules()
r.extra = registeredExtras.Get().newRules()
}
return r.extra
}
8 changes: 6 additions & 2 deletions params/config.libevm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/ava-labs/libevm/libevm/pseudo"
"github.com/ava-labs/libevm/libevm/register"
)

type rawJSON struct {
Expand Down Expand Up @@ -255,18 +256,21 @@ func TestExtrasPanic(t *testing.T) {
t, func() {
RegisterExtras(Extras[struct{ ChainConfigHooks }, struct{ RulesHooks }]{})
},
"re-registration",
register.ErrReRegistration.Error(),
)
}

func assertPanics(t *testing.T, fn func(), wantContains string) {
t.Helper()
defer func() {
t.Helper()
switch r := recover().(type) {
case nil:
t.Error("function did not panic as expected")
t.Error("function did not panic when panic expected")
case string:
assert.Contains(t, r, wantContains)
case error:
assert.Contains(t, r.Error(), wantContains)
default:
t.Fatalf("BAD TEST SETUP: recover() got unsupported type %T", r)
}
Expand Down
8 changes: 4 additions & 4 deletions params/hooks.libevm.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,17 @@ type RulesAllowlistHooks interface {
// Hooks returns the hooks registered with [RegisterExtras], or [NOOPHooks] if
// none were registered.
func (c *ChainConfig) Hooks() ChainConfigHooks {
if e := registeredExtras; e != nil {
return e.payloads.hooksFromChainConfig(c)
if e := registeredExtras; e.Registered() {
return e.Get().payloads.hooksFromChainConfig(c)
}
return NOOPHooks{}
}

// Hooks returns the hooks registered with [RegisterExtras], or [NOOPHooks] if
// none were registered.
func (r *Rules) Hooks() RulesHooks {
if e := registeredExtras; e != nil {
return e.payloads.hooksFromRules(r)
if e := registeredExtras; e.Registered() {
return e.Get().payloads.hooksFromRules(r)
}
return NOOPHooks{}
}
Expand Down
12 changes: 6 additions & 6 deletions params/json.libevm.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ type chainConfigWithExportedExtra struct {
// UnmarshalJSON implements the [json.Unmarshaler] interface.
func (c *ChainConfig) UnmarshalJSON(data []byte) error {
switch reg := registeredExtras; {
case reg != nil && !reg.reuseJSONRoot:
case reg.Registered() && !reg.Get().reuseJSONRoot:
return c.unmarshalJSONWithExtra(data)

case reg != nil && reg.reuseJSONRoot: // although the latter is redundant, it's clearer
c.extra = reg.newChainConfig()
case reg.Registered() && reg.Get().reuseJSONRoot: // although the latter is redundant, it's clearer
c.extra = reg.Get().newChainConfig()
if err := json.Unmarshal(data, c.extra); err != nil {
c.extra = nil
return err
Expand All @@ -63,7 +63,7 @@ func (c *ChainConfig) UnmarshalJSON(data []byte) error {
func (c *ChainConfig) unmarshalJSONWithExtra(data []byte) error {
cc := &chainConfigWithExportedExtra{
chainConfigWithoutMethods: (*chainConfigWithoutMethods)(c),
Extra: registeredExtras.newChainConfig(),
Extra: registeredExtras.Get().newChainConfig(),
}
if err := json.Unmarshal(data, cc); err != nil {
return err
Expand All @@ -75,10 +75,10 @@ func (c *ChainConfig) unmarshalJSONWithExtra(data []byte) error {
// MarshalJSON implements the [json.Marshaler] interface.
func (c *ChainConfig) MarshalJSON() ([]byte, error) {
switch reg := registeredExtras; {
case reg == nil:
case !reg.Registered():
return json.Marshal((*chainConfigWithoutMethods)(c))

case !reg.reuseJSONRoot:
case !reg.Get().reuseJSONRoot:
return c.marshalJSONWithExtra()

default: // reg.reuseJSONRoot == true
Expand Down

0 comments on commit d71677f

Please sign in to comment.