diff --git a/libevm/pseudo/constructor.go b/libevm/pseudo/constructor.go new file mode 100644 index 000000000000..91429240340a --- /dev/null +++ b/libevm/pseudo/constructor.go @@ -0,0 +1,24 @@ +package pseudo + +// A Constructor returns newly constructed [Type] instances for a pre-registered +// concrete type. +type Constructor interface { + Zero() *Type + NewPointer() *Type + NilPointer() *Type +} + +// NewConstructor returns a [Constructor] that builds `T` [Type] instances. +func NewConstructor[T any]() Constructor { + return ctor[T]{} +} + +type ctor[T any] struct{} + +func (ctor[T]) Zero() *Type { return Zero[T]().Type } +func (ctor[T]) NilPointer() *Type { return Zero[*T]().Type } + +func (ctor[T]) NewPointer() *Type { + var x T + return From(&x).Type +} diff --git a/libevm/pseudo/constructor_test.go b/libevm/pseudo/constructor_test.go new file mode 100644 index 000000000000..a28f3dd42de6 --- /dev/null +++ b/libevm/pseudo/constructor_test.go @@ -0,0 +1,45 @@ +package pseudo + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConstructor(t *testing.T) { + testConstructor[uint](t) + testConstructor[string](t) + testConstructor[struct{ x string }](t) +} + +func testConstructor[T any](t *testing.T) { + var zero T + t.Run(fmt.Sprintf("%T", zero), func(t *testing.T) { + ctor := NewConstructor[T]() + + t.Run("NilPointer()", func(t *testing.T) { + got := get[*T](t, ctor.NilPointer()) + assert.Nil(t, got) + }) + + t.Run("NewPointer()", func(t *testing.T) { + got := get[*T](t, ctor.NewPointer()) + require.NotNil(t, got) + assert.Equal(t, zero, *got) + }) + + t.Run("Zero()", func(t *testing.T) { + got := get[T](t, ctor.Zero()) + assert.Equal(t, zero, got) + }) + }) +} + +func get[T any](t *testing.T, typ *Type) (x T) { + t.Helper() + val, err := NewValue[T](typ) + require.NoError(t, err, "NewValue[%T]()", x) + return val.Get() +} diff --git a/params/config.libevm.go b/params/config.libevm.go index 38acbb7a3b89..66e995f9a6a8 100644 --- a/params/config.libevm.go +++ b/params/config.libevm.go @@ -44,17 +44,45 @@ func RegisterExtras[C any, R any](e Extras[C, R]) ExtraPayloadGetter[C, R] { } mustBeStruct[C]() mustBeStruct[R]() - registeredExtras = &e - return ExtraPayloadGetter[C, R]{} + registeredExtras = &extraConstructors{ + chainConfig: pseudo.NewConstructor[C](), + rules: pseudo.NewConstructor[R](), + newForRules: e.newForRules, + } + return e.getter() +} + +// registeredExtras holds non-generic constructors for the [Extras] types +// registered via [RegisterExtras]. +var registeredExtras *extraConstructors + +type extraConstructors struct { + chainConfig, rules pseudo.Constructor + newForRules func(_ *ChainConfig, _ *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type +} + +func (e *Extras[C, R]) newForRules(c *ChainConfig, r *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type { + if e.NewRules == nil { + return registeredExtras.rules.NilPointer() + } + rExtra := e.NewRules(c, r, e.getter().FromChainConfig(c), blockNum, isMerge, timestamp) + return pseudo.From(rExtra).Type +} + +func (*Extras[C, R]) getter() (g ExtraPayloadGetter[C, R]) { return } + +// mustBeStruct panics if `T` isn't a struct. +func mustBeStruct[T any]() { + if k := reflect.TypeFor[T]().Kind(); k != reflect.Struct { + panic(notStructMessage[T]()) + } } -// registeredExtras holds the [Extras] registered via [RegisterExtras]. As we -// don't know `C` and `R` at compile time, it must be an interface. -var registeredExtras interface { - nilForChainConfig() *pseudo.Type - nilForRules() *pseudo.Type - newForChainConfig() *pseudo.Type - newForRules(_ *ChainConfig, _ *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type +// notStructMessage returns the message with which [mustBeStruct] might panic. +// It exists to avoid change-detector tests should the message contents change. +func notStructMessage[T any]() string { + var x T + return fmt.Sprintf("%T is not a struct", x) } // An ExtraPayloadGettter provides strongly typed access to the extra payloads @@ -74,20 +102,6 @@ func (ExtraPayloadGetter[C, R]) FromRules(r *Rules) *R { return pseudo.MustNewValue[*R](r.extraPayload()).Get() } -func mustBeStruct[T any]() { - var x T - if k := reflect.TypeOf(x).Kind(); k != reflect.Struct { - panic(notStructMessage[T]()) - } -} - -// notStructMessage returns the message with which [mustBeStruct] might panic. -// It exists to avoid change-detector tests should the message contents change. -func notStructMessage[T any]() string { - var x T - return fmt.Sprintf("%T is not a struct", x) -} - // UnmarshalJSON implements the [json.Unmarshaler] interface. func (c *ChainConfig) UnmarshalJSON(data []byte) error { type raw ChainConfig // doesn't inherit methods so avoids recursing back here (infinitely) @@ -95,8 +109,8 @@ func (c *ChainConfig) UnmarshalJSON(data []byte) error { *raw Extra *pseudo.Type `json:"extra"` }{ - raw: (*raw)(c), // embedded to achieve regular JSON unmarshalling - Extra: registeredExtras.nilForChainConfig(), // `c.extra` is otherwise unexported + raw: (*raw)(c), // embedded to achieve regular JSON unmarshalling + Extra: registeredExtras.chainConfig.NilPointer(), // `c.extra` is otherwise unexported } if err := json.Unmarshal(data, cc); err != nil { @@ -143,7 +157,7 @@ func (c *ChainConfig) extraPayload() *pseudo.Type { panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", c)) } if c.extra == nil { - c.extra = registeredExtras.nilForChainConfig() + c.extra = registeredExtras.chainConfig.NilPointer() } return c.extra } @@ -155,30 +169,7 @@ func (r *Rules) extraPayload() *pseudo.Type { panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", r)) } if r.extra == nil { - r.extra = registeredExtras.nilForRules() + r.extra = registeredExtras.rules.NilPointer() } return r.extra } - -/** - * Start of Extras implementing the registeredExtras interface. - */ - -func (Extras[C, R]) nilForChainConfig() *pseudo.Type { return pseudo.Zero[*C]().Type } -func (Extras[C, R]) nilForRules() *pseudo.Type { return pseudo.Zero[*R]().Type } - -func (*Extras[C, R]) newForChainConfig() *pseudo.Type { - var x C - return pseudo.From(&x).Type -} - -func (e *Extras[C, R]) newForRules(c *ChainConfig, r *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type { - if e.NewRules == nil { - return e.nilForRules() - } - return pseudo.From(e.NewRules(c, r, c.extra.Interface().(*C), blockNum, isMerge, timestamp)).Type -} - -/** - * End of Extras implementing the registeredExtras interface. - */