diff --git a/core/vm/contracts.go b/core/vm/contracts.go index de64292baca8..7b5f1dd5b13f 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -34,6 +34,7 @@ import ( "github.com/ethereum/go-ethereum/crypto/kzg4844" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/precompile/contract" + "github.com/ethereum/go-ethereum/precompile/modules" "golang.org/x/crypto/ripemd160" ) @@ -151,19 +152,29 @@ func init() { } // ActivePrecompiles returns the precompiles enabled with the current configuration. -func ActivePrecompiles(rules params.Rules) []common.Address { +func ActivePrecompiles(rules params.Rules) (precompiles []common.Address) { switch { case rules.IsCancun: - return PrecompiledAddressesCancun + precompiles = PrecompiledAddressesCancun case rules.IsBerlin: - return PrecompiledAddressesBerlin + precompiles = PrecompiledAddressesBerlin + case rules.IsIstanbul: - return PrecompiledAddressesIstanbul + precompiles = PrecompiledAddressesIstanbul case rules.IsByzantium: - return PrecompiledAddressesByzantium + precompiles = PrecompiledAddressesByzantium default: - return PrecompiledAddressesHomestead + precompiles = PrecompiledAddressesHomestead + } + + // TODO: Consider performance improvements here to prevent iteration & allocations on every call. + // NOTE: If using init to cache addresses, then some care should be taken to ensure all precompiles are + // registered before being cached here. + for _, precompile := range modules.RegisteredModules() { + precompiles = append(precompiles, precompile.Address) } + + return precompiles } // RunPrecompiledContract runs and evaluates the output of a precompiled contract. diff --git a/core/vm/precompile_test.go b/core/vm/precompile_test.go index 7f1a35bd9190..916ad63f0661 100644 --- a/core/vm/precompile_test.go +++ b/core/vm/precompile_test.go @@ -1,8 +1,11 @@ package vm import ( + "math/big" "testing" + "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ethereum/go-ethereum/common" @@ -54,3 +57,71 @@ func TestEvmIsPrecompileMethod(t *testing.T) { require.False(t, ok) require.Nil(t, precompiledContract) } + +func TestActivePrecompiles(t *testing.T) { + genesisTime := time.Now() + getBlockTime := func(height *big.Int) *uint64 { + if !height.IsInt64() { + t.Fatalf("expected height bounded to int64") + } + totalBlockSeconds := time.Duration(10*height.Int64()) * time.Second + blockTimeUnix := uint64(genesisTime.Add(totalBlockSeconds).Unix()) + + return &blockTimeUnix + } + chainConfig := params.TestChainConfig + chainConfig.HomesteadBlock = big.NewInt(1) + chainConfig.ByzantiumBlock = big.NewInt(2) + chainConfig.IstanbulBlock = big.NewInt(3) + chainConfig.BerlinBlock = big.NewInt(4) + chainConfig.CancunTime = getBlockTime(big.NewInt(5)) + + testCases := []struct { + name string + block *big.Int + }{ + {"homestead", chainConfig.HomesteadBlock}, + {"byzantium", chainConfig.ByzantiumBlock}, + {"istanbul", chainConfig.IstanbulBlock}, + {"berlin", chainConfig.BerlinBlock}, + {"cancun", new(big.Int).Add(chainConfig.BerlinBlock, big.NewInt(1))}, + } + + // custom precompile address used for test + contractAddress := common.HexToAddress("0x0400000000000000000000000000000000000000") + + // ensure we are not being shadowed by a core preompile address + for _, tc := range testCases { + rules := chainConfig.Rules(tc.block, false, *getBlockTime(tc.block)) + + for _, precompileAddr := range ActivePrecompiles(rules) { + if precompileAddr == contractAddress { + t.Fatalf("expected precompile %s to not be returned in %s block", contractAddress, tc.name) + } + } + } + + // register the precompile + module := modules.Module{ + Address: contractAddress, + Contract: new(mockStatefulPrecompiledContract), + } + + // TODO: should we allow dynamic registration to update ActivePrecompiles? + // Or should we enforce registration only at init? + err := modules.RegisterModule(module) + require.NoError(t, err, "could not register precompile for test") + + for _, tc := range testCases { + rules := chainConfig.Rules(tc.block, false, *getBlockTime(tc.block)) + + exists := false + for _, precompileAddr := range ActivePrecompiles(rules) { + if precompileAddr == contractAddress { + exists = true + } + } + + assert.True(t, exists, "expected %s block to include active stateful precompile %s", tc.name, contractAddress) + } +}