Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: stateful precompiles + allowlist hooks #3

Merged
merged 9 commits into from
Sep 10, 2024
3 changes: 3 additions & 0 deletions core/state_transition.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,9 @@ func (st *StateTransition) preCheck() error {
// However if any consensus issue encountered, return the error directly with
// nil evm execution result.
func (st *StateTransition) TransitionDb() (*ExecutionResult, error) {
if err := st.canExecuteTransaction(); err != nil {
return nil, err
}
// First check this message satisfies all consensus rules before
// applying the message. The rules include these clauses
//
Expand Down
9 changes: 9 additions & 0 deletions core/state_transition.libevm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package core

// canExecuteTransaction is a convenience wrapper for calling the
// [params.RulesHooks.CanExecuteTransaction] hook.
func (st *StateTransition) canExecuteTransaction() error {
bCtx := st.evm.Context
rules := st.evm.ChainConfig().Rules(bCtx.BlockNumber, bCtx.Random != nil, bCtx.Time)
return rules.Hooks().CanExecuteTransaction(st.msg.From, st.msg.To, st.state)
}
40 changes: 40 additions & 0 deletions core/state_transition.libevm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package core_test

import (
"fmt"
"testing"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/libevm"
"github.com/ethereum/go-ethereum/libevm/ethtest"
"github.com/ethereum/go-ethereum/libevm/hookstest"
"github.com/stretchr/testify/require"
)

func TestCanExecuteTransaction(t *testing.T) {
rng := ethtest.NewPseudoRand(42)
account := rng.Address()
slot := rng.Hash()

makeErr := func(from common.Address, to *common.Address, val common.Hash) error {
return fmt.Errorf("From: %v To: %v State: %v", from, to, val)
}
hooks := &hookstest.Stub{
CanExecuteTransactionFn: func(from common.Address, to *common.Address, s libevm.StateReader) error {
return makeErr(from, to, s.GetState(account, slot))
},
}
hooks.RegisterForRules(t)

value := rng.Hash()

state, evm := ethtest.NewZeroEVM(t)
state.SetState(account, slot, value)
msg := &core.Message{
From: rng.Address(),
To: rng.AddressPtr(),
}
_, err := core.ApplyMessage(evm, msg, new(core.GasPool).AddGas(30e6))
require.EqualError(t, err, makeErr(msg.From, msg.To, value).Error())
}
4 changes: 2 additions & 2 deletions core/vm/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,13 @@ func ActivePrecompiles(rules params.Rules) []common.Address {
// - the returned bytes,
// - the _remaining_ gas,
// - any error that occurred
func RunPrecompiledContract(p PrecompiledContract, input []byte, suppliedGas uint64) (ret []byte, remainingGas uint64, err error) {
func (args *evmCallArgs) RunPrecompiledContract(p PrecompiledContract, input []byte, suppliedGas uint64) (ret []byte, remainingGas uint64, err error) {
gasCost := p.RequiredGas(input)
if suppliedGas < gasCost {
return nil, 0, ErrOutOfGas
}
suppliedGas -= gasCost
output, err := p.Run(input)
output, err := args.run(p, input)
return output, suppliedGas, err
}

Expand Down
83 changes: 83 additions & 0 deletions core/vm/contracts.libevm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package vm

import (
"fmt"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/params"
"github.com/holiman/uint256"
)

// evmCallArgs mirrors the parameters of the [EVM] methods Call(), CallCode(),
// DelegateCall() and StaticCall(). Its fields are identical to those of the
// parameters, prepended with the receiver name. As {Delegate,Static}Call don't
// accept a value, they MUST set the respective field to nil.
//
// Instantiation can be achieved by merely copying the parameter names, in
// order, which is trivially achieved with AST manipulation:
//
// func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas uint64, value *uint256.Int) ... {
// ...
// args := &evmCallArgs{evm, caller, addr, input, gas, value}
type evmCallArgs struct {
evm *EVM
caller ContractRef
addr common.Address
input []byte
gas uint64
value *uint256.Int
}

// run runs the [PrecompiledContract], differentiating between stateful and
// regular types.
func (args *evmCallArgs) run(p PrecompiledContract, input []byte) (ret []byte, err error) {
if p, ok := p.(statefulPrecompile); ok {
return p.run(args.evm.StateDB, &args.evm.chainRules, args.caller.Address(), args.addr, input)
}
return p.Run(input)
}

// PrecompiledStatefulRun is the stateful equivalent of the Run() method of a
// [PrecompiledContract].
type PrecompiledStatefulRun func(_ StateDB, _ *params.Rules, caller, self common.Address, input []byte) ([]byte, error)

// NewStatefulPrecompile constructs a new PrecompiledContract that can be used
// via an [EVM] instance but MUST NOT be called directly; a direct call to Run()
// reserves the right to panic. See other requirements defined in the comments
// on [PrecompiledContract].
func NewStatefulPrecompile(run PrecompiledStatefulRun, requiredGas func([]byte) uint64) PrecompiledContract {
return statefulPrecompile{
gas: requiredGas,
run: run,
}
}

type statefulPrecompile struct {
gas func([]byte) uint64
run PrecompiledStatefulRun
}

func (p statefulPrecompile) RequiredGas(input []byte) uint64 {
return p.gas(input)
}

func (p statefulPrecompile) Run([]byte) ([]byte, error) {
// https://google.github.io/styleguide/go/best-practices.html#when-to-panic
// This would indicate an API misuse and would occur in tests, not in
// production.
panic(fmt.Sprintf("BUG: call to %T.Run(); MUST call %T", p, p.run))
}

var (
// These lock in the assumptions made when implementing [evmCallArgs]. If
// these break then the struct fields SHOULD be changed to match these
// signatures.
_ = [](func(ContractRef, common.Address, []byte, uint64, *uint256.Int) ([]byte, uint64, error)){
(*EVM)(nil).Call,
(*EVM)(nil).CallCode,
}
_ = [](func(ContractRef, common.Address, []byte, uint64) ([]byte, uint64, error)){
(*EVM)(nil).DelegateCall,
(*EVM)(nil).StaticCall,
}
)
159 changes: 112 additions & 47 deletions core/vm/contracts.libevm_test.go
Original file line number Diff line number Diff line change
@@ -1,35 +1,22 @@
package vm
package vm_test

import (
"fmt"
"math/big"
"testing"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/libevm"
"github.com/ethereum/go-ethereum/libevm/ethtest"
"github.com/ethereum/go-ethereum/libevm/hookstest"
"github.com/ethereum/go-ethereum/params"
"github.com/holiman/uint256"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/rand"
)

// precompileOverrides is a [params.RulesHooks] that overrides precompiles from
// a map of predefined addresses.
type precompileOverrides struct {
contracts map[common.Address]PrecompiledContract
params.NOOPHooks // all other hooks
}

func (o precompileOverrides) PrecompileOverride(a common.Address) (libevm.PrecompiledContract, bool) {
c, ok := o.contracts[a]
return c, ok
}

// A precompileStub is a [PrecompiledContract] that always returns the same
// values.
type precompileStub struct {
requiredGas uint64
returnData []byte
Expand Down Expand Up @@ -58,7 +45,7 @@ func TestPrecompileOverride(t *testing.T) {
}

rng := rand.New(rand.NewSource(42))
for _, addr := range PrecompiledAddressesCancun {
for _, addr := range vm.PrecompiledAddressesCancun {
tests = append(tests, test{
name: fmt.Sprintf("existing precompile %v", addr),
addr: addr,
Expand All @@ -69,24 +56,19 @@ func TestPrecompileOverride(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
precompile := &precompileStub{
requiredGas: tt.requiredGas,
returnData: tt.stubData,
}

params.TestOnlyClearRegisteredExtras()
params.RegisterExtras(params.Extras[params.NOOPHooks, precompileOverrides]{
NewRules: func(_ *params.ChainConfig, _ *params.Rules, _ *params.NOOPHooks, blockNum *big.Int, isMerge bool, timestamp uint64) *precompileOverrides {
return &precompileOverrides{
contracts: map[common.Address]PrecompiledContract{
tt.addr: precompile,
},
}
hooks := &hookstest.Stub{
PrecompileOverrides: map[common.Address]libevm.PrecompiledContract{
tt.addr: &precompileStub{
requiredGas: tt.requiredGas,
returnData: tt.stubData,
},
},
})
}
hooks.RegisterForRules(t)

t.Run(fmt.Sprintf("%T.Call([overridden precompile address = %v])", &EVM{}, tt.addr), func(t *testing.T) {
gotData, gotGasLeft, err := newEVM(t).Call(AccountRef{}, tt.addr, nil, gasLimit, uint256.NewInt(0))
t.Run(fmt.Sprintf("%T.Call([overridden precompile address = %v])", &vm.EVM{}, tt.addr), func(t *testing.T) {
_, evm := ethtest.NewZeroEVM(t)
gotData, gotGasLeft, err := evm.Call(vm.AccountRef{}, tt.addr, nil, gasLimit, uint256.NewInt(0))
require.NoError(t, err)
assert.Equal(t, tt.stubData, gotData, "contract's return data")
assert.Equal(t, gasLimit-tt.requiredGas, gotGasLeft, "gas left")
Expand All @@ -95,19 +77,102 @@ func TestPrecompileOverride(t *testing.T) {
}
}

func newEVM(t *testing.T) *EVM {
t.Helper()
func TestNewStatefulPrecompile(t *testing.T) {
rng := ethtest.NewPseudoRand(314159)
precompile := rng.Address()
slot := rng.Hash()

const gasLimit = 1e6
gasCost := rng.Uint64n(gasLimit)

makeOutput := func(caller, self common.Address, input []byte, stateVal common.Hash) []byte {
return []byte(fmt.Sprintf(
"Caller: %v Precompile: %v State: %v Input: %#x",
caller, self, stateVal, input,
))
}
hooks := &hookstest.Stub{
PrecompileOverrides: map[common.Address]libevm.PrecompiledContract{
precompile: vm.NewStatefulPrecompile(
func(state vm.StateDB, _ *params.Rules, caller, self common.Address, input []byte) ([]byte, error) {
return makeOutput(caller, self, input, state.GetState(precompile, slot)), nil
},
func(b []byte) uint64 {
return gasCost
},
),
},
}
hooks.RegisterForRules(t)

caller := rng.Address()
input := rng.Bytes(8)
value := rng.Hash()

sdb, err := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()), nil)
require.NoError(t, err, "state.New()")
state, evm := ethtest.NewZeroEVM(t)
state.SetState(precompile, slot, value)
wantReturnData := makeOutput(caller, precompile, input, value)
wantGasLeft := gasLimit - gasCost

return NewEVM(
BlockContext{
Transfer: func(_ StateDB, _, _ common.Address, _ *uint256.Int) {},
gotReturnData, gotGasLeft, err := evm.Call(vm.AccountRef(caller), precompile, input, gasLimit, uint256.NewInt(0))
require.NoError(t, err)
assert.Equal(t, wantReturnData, gotReturnData)
assert.Equal(t, wantGasLeft, gotGasLeft)
}

func TestCanCreateContract(t *testing.T) {
rng := ethtest.NewPseudoRand(142857)
account := rng.Address()
slot := rng.Hash()

makeErr := func(cc *libevm.AddressContext, stateVal common.Hash) error {
return fmt.Errorf("Origin: %v Caller: %v Contract: %v State: %v", cc.Origin, cc.Caller, cc.Self, stateVal)
}
hooks := &hookstest.Stub{
CanCreateContractFn: func(cc *libevm.AddressContext, s libevm.StateReader) error {
return makeErr(cc, s.GetState(account, slot))
},
}
hooks.RegisterForRules(t)

origin := rng.Address()
caller := rng.Address()
value := rng.Hash()
code := rng.Bytes(8)
salt := rng.Hash()

create := crypto.CreateAddress(caller, 0)
create2 := crypto.CreateAddress2(caller, salt, crypto.Keccak256(code))

tests := []struct {
name string
create func(*vm.EVM) ([]byte, common.Address, uint64, error)
wantErr error
}{
{
name: "Create",
create: func(evm *vm.EVM) ([]byte, common.Address, uint64, error) {
return evm.Create(vm.AccountRef(caller), code, 1e6, uint256.NewInt(0))
},
wantErr: makeErr(&libevm.AddressContext{Origin: origin, Caller: caller, Self: create}, value),
},
{
name: "Create2",
create: func(evm *vm.EVM) ([]byte, common.Address, uint64, error) {
return evm.Create2(vm.AccountRef(caller), code, 1e6, uint256.NewInt(0), new(uint256.Int).SetBytes(salt[:]))
},
wantErr: makeErr(&libevm.AddressContext{Origin: origin, Caller: caller, Self: create2}, value),
},
TxContext{},
sdb,
&params.ChainConfig{},
Config{},
)
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
state, evm := ethtest.NewZeroEVM(t)
state.SetState(account, slot, value)
evm.TxContext.Origin = origin

_, _, _, err := tt.create(evm)
require.EqualError(t, err, tt.wantErr.Error())
})
}
}
Loading