Skip to content

Commit

Permalink
Merge pull request #4 from berachain/new-precompile
Browse files Browse the repository at this point in the history
  • Loading branch information
Devon Bear authored Jan 19, 2023
2 parents 37748e5 + d1c5ff3 commit 981b007
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 91 deletions.
62 changes: 24 additions & 38 deletions core/vm/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package vm

import (
"context"
"crypto/sha256"
"encoding/binary"
"errors"
Expand All @@ -36,8 +37,8 @@ import (
// requires a deterministic gas count based on the input size of the Run method of the
// contract.
type PrecompiledContract interface {
RequiredGas(input []byte) uint64 // RequiredPrice calculates the contract gas use
Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) // Run runs the precompiled contract
RequiredGas(input []byte) uint64 // RequiredPrice calculates the contract gas use
Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) // Run runs the precompiled contract
}

// PrecompiledContractsHomestead contains the default set of pre-compiled Ethereum
Expand Down Expand Up @@ -140,29 +141,14 @@ func ActivePrecompiles(rules params.Rules) []common.Address {
}
}

// RunPrecompiledContract runs and evaluates the output of a precompiled contract.
// It returns
// - the returned bytes,
// - the _remaining_ gas,
// - any error that occurred
func RunPrecompiledContract(p PrecompiledContract, input []byte, suppliedGas uint64) (ret []byte, remainingGas uint64, err error) {
gasCost := p.RequiredGas(input)
if suppliedGas < gasCost {
return nil, 0, ErrOutOfGas
}
suppliedGas -= gasCost
output, err := p.Run(nil, input, common.Address{}, nil, true)
return output, suppliedGas, err
}

// ECRECOVER implemented as a native contract.
type ecrecover struct{}

func (c *ecrecover) RequiredGas(input []byte) uint64 {
return params.EcrecoverGas
}

func (c *ecrecover) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *ecrecover) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
const ecRecoverInputLength = 128

input = common.RightPadBytes(input, ecRecoverInputLength)
Expand Down Expand Up @@ -203,7 +189,7 @@ type sha256hash struct{}
func (c *sha256hash) RequiredGas(input []byte) uint64 {
return uint64(len(input)+31)/32*params.Sha256PerWordGas + params.Sha256BaseGas
}
func (c *sha256hash) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *sha256hash) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
h := sha256.Sum256(input)
return h[:], nil
}
Expand All @@ -218,7 +204,7 @@ type ripemd160hash struct{}
func (c *ripemd160hash) RequiredGas(input []byte) uint64 {
return uint64(len(input)+31)/32*params.Ripemd160PerWordGas + params.Ripemd160BaseGas
}
func (c *ripemd160hash) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *ripemd160hash) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
ripemd := ripemd160.New()
ripemd.Write(input)
return common.LeftPadBytes(ripemd.Sum(nil), 32), nil
Expand All @@ -234,7 +220,7 @@ type dataCopy struct{}
func (c *dataCopy) RequiredGas(input []byte) uint64 {
return uint64(len(input)+31)/32*params.IdentityPerWordGas + params.IdentityBaseGas
}
func (c *dataCopy) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *dataCopy) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
return input, nil
}

Expand Down Expand Up @@ -361,7 +347,7 @@ func (c *bigModExp) RequiredGas(input []byte) uint64 {
return gas.Uint64()
}

func (c *bigModExp) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *bigModExp) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
var (
baseLen = new(big.Int).SetBytes(getData(input, 0, 32)).Uint64()
expLen = new(big.Int).SetBytes(getData(input, 32, 32)).Uint64()
Expand Down Expand Up @@ -434,7 +420,7 @@ func (c *bn256AddIstanbul) RequiredGas(input []byte) uint64 {
return params.Bn256AddGasIstanbul
}

func (c *bn256AddIstanbul) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *bn256AddIstanbul) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
return runBn256Add(input)
}

Expand All @@ -447,7 +433,7 @@ func (c *bn256AddByzantium) RequiredGas(input []byte) uint64 {
return params.Bn256AddGasByzantium
}

func (c *bn256AddByzantium) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *bn256AddByzantium) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
return runBn256Add(input)
}

Expand All @@ -472,7 +458,7 @@ func (c *bn256ScalarMulIstanbul) RequiredGas(input []byte) uint64 {
return params.Bn256ScalarMulGasIstanbul
}

func (c *bn256ScalarMulIstanbul) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *bn256ScalarMulIstanbul) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
return runBn256ScalarMul(input)
}

Expand All @@ -485,7 +471,7 @@ func (c *bn256ScalarMulByzantium) RequiredGas(input []byte) uint64 {
return params.Bn256ScalarMulGasByzantium
}

func (c *bn256ScalarMulByzantium) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *bn256ScalarMulByzantium) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
return runBn256ScalarMul(input)
}

Expand Down Expand Up @@ -540,7 +526,7 @@ func (c *bn256PairingIstanbul) RequiredGas(input []byte) uint64 {
return params.Bn256PairingBaseGasIstanbul + uint64(len(input)/192)*params.Bn256PairingPerPointGasIstanbul
}

func (c *bn256PairingIstanbul) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *bn256PairingIstanbul) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
return runBn256Pairing(input)
}

Expand All @@ -553,7 +539,7 @@ func (c *bn256PairingByzantium) RequiredGas(input []byte) uint64 {
return params.Bn256PairingBaseGasByzantium + uint64(len(input)/192)*params.Bn256PairingPerPointGasByzantium
}

func (c *bn256PairingByzantium) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *bn256PairingByzantium) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
return runBn256Pairing(input)
}

Expand All @@ -579,7 +565,7 @@ var (
errBlake2FInvalidFinalFlag = errors.New("invalid final flag")
)

func (c *blake2F) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *blake2F) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
// Make sure the input is valid (correct length and final flag)
if len(input) != blake2FInputLength {
return nil, errBlake2FInvalidInputLength
Expand Down Expand Up @@ -633,7 +619,7 @@ func (c *bls12381G1Add) RequiredGas(input []byte) uint64 {
return params.Bls12381G1AddGas
}

func (c *bls12381G1Add) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *bls12381G1Add) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
// Implements EIP-2537 G1Add precompile.
// > G1 addition call expects `256` bytes as an input that is interpreted as byte concatenation of two G1 points (`128` bytes each).
// > Output is an encoding of addition operation result - single G1 point (`128` bytes).
Expand Down Expand Up @@ -671,7 +657,7 @@ func (c *bls12381G1Mul) RequiredGas(input []byte) uint64 {
return params.Bls12381G1MulGas
}

func (c *bls12381G1Mul) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *bls12381G1Mul) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
// Implements EIP-2537 G1Mul precompile.
// > G1 multiplication call expects `160` bytes as an input that is interpreted as byte concatenation of encoding of G1 point (`128` bytes) and encoding of a scalar value (`32` bytes).
// > Output is an encoding of multiplication operation result - single G1 point (`128` bytes).
Expand Down Expand Up @@ -721,7 +707,7 @@ func (c *bls12381G1MultiExp) RequiredGas(input []byte) uint64 {
return (uint64(k) * params.Bls12381G1MulGas * discount) / 1000
}

func (c *bls12381G1MultiExp) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *bls12381G1MultiExp) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
// Implements EIP-2537 G1MultiExp precompile.
// G1 multiplication call expects `160*k` bytes as an input that is interpreted as byte concatenation of `k` slices each of them being a byte concatenation of encoding of G1 point (`128` bytes) and encoding of a scalar value (`32` bytes).
// Output is an encoding of multiexponentiation operation result - single G1 point (`128` bytes).
Expand Down Expand Up @@ -764,7 +750,7 @@ func (c *bls12381G2Add) RequiredGas(input []byte) uint64 {
return params.Bls12381G2AddGas
}

func (c *bls12381G2Add) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *bls12381G2Add) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
// Implements EIP-2537 G2Add precompile.
// > G2 addition call expects `512` bytes as an input that is interpreted as byte concatenation of two G2 points (`256` bytes each).
// > Output is an encoding of addition operation result - single G2 point (`256` bytes).
Expand Down Expand Up @@ -802,7 +788,7 @@ func (c *bls12381G2Mul) RequiredGas(input []byte) uint64 {
return params.Bls12381G2MulGas
}

func (c *bls12381G2Mul) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *bls12381G2Mul) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
// Implements EIP-2537 G2MUL precompile logic.
// > G2 multiplication call expects `288` bytes as an input that is interpreted as byte concatenation of encoding of G2 point (`256` bytes) and encoding of a scalar value (`32` bytes).
// > Output is an encoding of multiplication operation result - single G2 point (`256` bytes).
Expand Down Expand Up @@ -852,7 +838,7 @@ func (c *bls12381G2MultiExp) RequiredGas(input []byte) uint64 {
return (uint64(k) * params.Bls12381G2MulGas * discount) / 1000
}

func (c *bls12381G2MultiExp) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *bls12381G2MultiExp) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
// Implements EIP-2537 G2MultiExp precompile logic
// > G2 multiplication call expects `288*k` bytes as an input that is interpreted as byte concatenation of `k` slices each of them being a byte concatenation of encoding of G2 point (`256` bytes) and encoding of a scalar value (`32` bytes).
// > Output is an encoding of multiexponentiation operation result - single G2 point (`256` bytes).
Expand Down Expand Up @@ -895,7 +881,7 @@ func (c *bls12381Pairing) RequiredGas(input []byte) uint64 {
return params.Bls12381PairingBaseGas + uint64(len(input)/384)*params.Bls12381PairingPerPairGas
}

func (c *bls12381Pairing) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *bls12381Pairing) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
// Implements EIP-2537 Pairing precompile logic.
// > Pairing call expects `384*k` bytes as an inputs that is interpreted as byte concatenation of `k` slices. Each slice has the following structure:
// > - `128` bytes of G1 point encoding
Expand Down Expand Up @@ -974,7 +960,7 @@ func (c *bls12381MapG1) RequiredGas(input []byte) uint64 {
return params.Bls12381MapG1Gas
}

func (c *bls12381MapG1) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *bls12381MapG1) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
// Implements EIP-2537 Map_To_G1 precompile.
// > Field-to-curve call expects `64` bytes an an input that is interpreted as a an element of the base field.
// > Output of this call is `128` bytes and is G1 point following respective encoding rules.
Expand Down Expand Up @@ -1009,7 +995,7 @@ func (c *bls12381MapG2) RequiredGas(input []byte) uint64 {
return params.Bls12381MapG2Gas
}

func (c *bls12381MapG2) Run(sdb StateDB, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
func (c *bls12381MapG2) Run(ctx context.Context, input []byte, caller common.Address, value *big.Int, readonly bool) ([]byte, error) {
// Implements EIP-2537 Map_FP2_TO_G2 precompile logic.
// > Field-to-curve call expects `128` bytes an an input that is interpreted as a an element of the quadratic extension field.
// > Output of this call is `256` bytes and is G2 point following respective encoding rules.
Expand Down
80 changes: 27 additions & 53 deletions core/vm/evm.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,39 +41,15 @@ type (
GetHashFunc func(uint64) common.Hash
)

type PrecompileRegistry interface {
GetContract(addr common.Address) (PrecompiledContract, bool)
}

type BasePrecompileRegistry struct {
chainRules params.Rules
}

func (bpm *BasePrecompileRegistry) GetContract(addr common.Address) (PrecompiledContract, bool) {
var precompiles map[common.Address]PrecompiledContract
switch {
case bpm.chainRules.IsBerlin:
precompiles = PrecompiledContractsBerlin
case bpm.chainRules.IsIstanbul:
precompiles = PrecompiledContractsIstanbul
case bpm.chainRules.IsByzantium:
precompiles = PrecompiledContractsByzantium
default:
precompiles = PrecompiledContractsHomestead
}
p, ok := precompiles[addr]
return p, ok
}

func RunContract(sdb StateDB, p PrecompiledContract, input []byte, caller common.Address,
value *big.Int, suppliedGas uint64, readonly bool) (ret []byte, remainingGas uint64, err error) {
gasCost := p.RequiredGas(input)
if suppliedGas < gasCost {
return nil, 0, ErrOutOfGas
}
suppliedGas -= gasCost
output, err := p.Run(sdb, input, caller, value, readonly)
return output, suppliedGas, err
// `PrecompileHost` is allows the EVM to execute a precompiled contract.
type PrecompileHost interface {
// `Exists` returns if a precompiled contract was found at `addr`.
Exists(addr common.Address) (PrecompiledContract, bool)

// `Run` runs a precompiled contract and returns the leftover gas.
Run(p PrecompiledContract, sdb StateDB, input []byte, caller common.Address,
value *big.Int, suppliedGas uint64, readonly bool,
) (ret []byte, remainingGas uint64, err error)
}

// BlockContext provides the EVM with auxiliary information. Once provided
Expand Down Expand Up @@ -120,8 +96,8 @@ type EVM struct {
TxContext
// StateDB gives access to the underlying state
StateDB StateDB
// PrecompileRegistry gives access to the precompiled contracts
PrecompileRegistry PrecompileRegistry
// precompileHost gives access to the precompiled contracts
precompileHost PrecompileHost
// Depth is the current call stack
depth int

Expand All @@ -146,17 +122,15 @@ type EVM struct {

// NewEVM returns a new EVM. The returned EVM is not thread safe and should
// only ever be used *once*.
func NewEVM(blockCtx BlockContext, txCtx TxContext, statedb StateDB, chainConfig *params.ChainConfig, config Config) *EVM {
func NewEVM(blockCtx BlockContext, txCtx TxContext, statedb StateDB, chainConfig *params.ChainConfig, config Config, precompileHost PrecompileHost) *EVM {
evm := &EVM{
Context: blockCtx,
TxContext: txCtx,
StateDB: statedb,
PrecompileRegistry: &BasePrecompileRegistry{
chainRules: chainConfig.Rules(blockCtx.BlockNumber, blockCtx.Random != nil),
},
Config: config,
chainConfig: chainConfig,
chainRules: chainConfig.Rules(blockCtx.BlockNumber, blockCtx.Random != nil),
Context: blockCtx,
TxContext: txCtx,
StateDB: statedb,
precompileHost: precompileHost,
Config: config,
chainConfig: chainConfig,
chainRules: chainConfig.Rules(blockCtx.BlockNumber, blockCtx.Random != nil),
}
evm.interpreter = NewEVMInterpreter(evm, config)
return evm
Expand Down Expand Up @@ -199,7 +173,7 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas
return nil, gas, ErrInsufficientBalance
}
snapshot := evm.StateDB.Snapshot()
p, isPrecompile := evm.PrecompileRegistry.GetContract(addr)
p, isPrecompile := evm.precompileHost.Exists(addr)

if !evm.StateDB.Exist(addr) {
if !isPrecompile && evm.chainRules.IsEIP158 && value.Sign() == 0 {
Expand Down Expand Up @@ -236,7 +210,7 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas
}

if isPrecompile {
ret, gas, err = RunContract(evm.StateDB, p, input, caller.Address(), value, gas, false)
ret, gas, err = evm.precompileHost.Run(p, evm.StateDB, input, caller.Address(), value, gas, false)
} else {
// Initialise a new contract and set the code that is to be used by the EVM.
// The contract is a scoped environment for this execution context only.
Expand Down Expand Up @@ -298,8 +272,8 @@ func (evm *EVM) CallCode(caller ContractRef, addr common.Address, input []byte,
}

// It is allowed to call precompiles, even via delegatecall
if p, isPrecompile := evm.PrecompileRegistry.GetContract(addr); isPrecompile {
ret, gas, err = RunContract(evm.StateDB, p, input, caller.Address(), value, gas, true)
if p, isPrecompile := evm.precompileHost.Exists(addr); isPrecompile {
ret, gas, err = evm.precompileHost.Run(p, evm.StateDB, input, caller.Address(), value, gas, true)
} else {
addrCopy := addr
// Initialise a new contract and set the code that is to be used by the EVM.
Expand Down Expand Up @@ -339,9 +313,9 @@ func (evm *EVM) DelegateCall(caller ContractRef, addr common.Address, input []by
}

// It is allowed to call precompiles, even via delegatecall
if p, isPrecompile := evm.PrecompileRegistry.GetContract(addr); isPrecompile {
if p, isPrecompile := evm.precompileHost.Exists(addr); isPrecompile {
parent := caller.(*Contract)
ret, gas, err = RunContract(evm.StateDB, p, input, parent.CallerAddress, parent.value, gas, false)
ret, gas, err = evm.precompileHost.Run(p, evm.StateDB, input, parent.CallerAddress, parent.value, gas, false)
} else {
addrCopy := addr
// Initialise a new contract and make initialise the delegate values
Expand Down Expand Up @@ -389,8 +363,8 @@ func (evm *EVM) StaticCall(caller ContractRef, addr common.Address, input []byte
}(gas)
}

if p, isPrecompile := evm.PrecompileRegistry.GetContract(addr); isPrecompile {
ret, gas, err = RunContract(evm.StateDB, p, input, caller.Address(), new(big.Int), gas, true)
if p, isPrecompile := evm.precompileHost.Exists(addr); isPrecompile {
ret, gas, err = evm.precompileHost.Run(p, evm.StateDB, input, caller.Address(), new(big.Int), gas, true)
} else {
// At this point, we use a copy of address. If we don't, the go compiler will
// leak the 'contract' to the outer scope, and make allocation for 'contract'
Expand Down

0 comments on commit 981b007

Please sign in to comment.