Skip to content

Commit

Permalink
fix bug in how eth_tester middleware filled default fields (ethereum#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul Robinson committed Sep 8, 2022
1 parent ce2793a commit 0ecffd7
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 9 deletions.
1 change: 1 addition & 0 deletions newsfragments/2600.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
fixed bug in how async_eth_tester_middleware fills default fields
175 changes: 175 additions & 0 deletions tests/core/middleware/test_eth_tester_middleware.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,23 @@
import pytest
from unittest.mock import (
Mock,
)

from web3.providers.eth_tester.middleware import (
async_default_transaction_fields_middleware,
default_transaction_fields_middleware,
)
from web3.types import (
BlockData,
)

SAMPLE_ADDRESS_LIST = [
"0x0000000000000000000000000000000000000001",
"0x0000000000000000000000000000000000000002",
"0x0000000000000000000000000000000000000003",
]
SAMPLE_ADDRESS = "0x0000000000000000000000000000000000000004"


@pytest.mark.parametrize("block_number", {0, "0x0", "earliest"})
def test_get_transaction_count_formatters(w3, block_number):
Expand All @@ -20,3 +34,164 @@ def test_get_block_formatters(w3):
keys_diff = all_block_keys.difference(latest_block_keys)
assert len(keys_diff) == 1
assert keys_diff.pop() == "mixHash" # mixHash is not implemented in eth-tester


@pytest.mark.parametrize(
"w3_accounts, w3_coinbase, method, from_field_added, from_field_value",
(
(SAMPLE_ADDRESS_LIST, SAMPLE_ADDRESS, "eth_call", True, SAMPLE_ADDRESS),
(
SAMPLE_ADDRESS_LIST,
SAMPLE_ADDRESS,
"eth_estimateGas",
True,
SAMPLE_ADDRESS,
),
(
SAMPLE_ADDRESS_LIST,
SAMPLE_ADDRESS,
"eth_sendTransaction",
True,
SAMPLE_ADDRESS,
),
(SAMPLE_ADDRESS_LIST, SAMPLE_ADDRESS, "eth_gasPrice", False, None),
(SAMPLE_ADDRESS_LIST, SAMPLE_ADDRESS, "eth_blockNumber", False, None),
(SAMPLE_ADDRESS_LIST, SAMPLE_ADDRESS, "meow", False, None),
(SAMPLE_ADDRESS_LIST, None, "eth_call", True, SAMPLE_ADDRESS_LIST[0]),
(SAMPLE_ADDRESS_LIST, None, "eth_estimateGas", True, SAMPLE_ADDRESS_LIST[0]),
(
SAMPLE_ADDRESS_LIST,
None,
"eth_sendTransaction",
True,
SAMPLE_ADDRESS_LIST[0],
),
(SAMPLE_ADDRESS_LIST, None, "eth_gasPrice", False, None),
(SAMPLE_ADDRESS_LIST, None, "eth_blockNumber", False, None),
(SAMPLE_ADDRESS_LIST, None, "meow", False, None),
(None, SAMPLE_ADDRESS, "eth_call", True, SAMPLE_ADDRESS),
(None, SAMPLE_ADDRESS, "eth_estimateGas", True, SAMPLE_ADDRESS),
(None, SAMPLE_ADDRESS, "eth_sendTransaction", True, SAMPLE_ADDRESS),
(None, SAMPLE_ADDRESS, "eth_gasPrice", False, SAMPLE_ADDRESS),
(None, SAMPLE_ADDRESS, "eth_blockNumber", False, SAMPLE_ADDRESS),
(None, SAMPLE_ADDRESS, "meow", False, SAMPLE_ADDRESS),
(None, None, "eth_call", True, None),
(None, None, "eth_estimateGas", True, None),
(None, None, "eth_sendTransaction", True, None),
(None, None, "eth_gasPrice", False, None),
(None, None, "eth_blockNumber", False, None),
(None, None, "meow", False, None),
),
)
def test_default_transaction_fields_middleware(
w3_accounts, w3_coinbase, method, from_field_added, from_field_value
):
def mock_request(_method, params):
return params

mock_w3 = Mock()
mock_w3.eth.accounts = w3_accounts
mock_w3.eth.coinbase = w3_coinbase

middleware = default_transaction_fields_middleware(mock_request, mock_w3)
base_params = {"chainId": 5}
filled_transaction = middleware(method, [base_params])

filled_params = filled_transaction[0]

assert ("from" in filled_params.keys()) == from_field_added
if "from" in filled_params.keys():
assert filled_params["from"] == from_field_value

filled_transaction[0].pop("from", None)
assert filled_transaction[0] == base_params


# -- async -- #


@pytest.mark.parametrize(
"w3_accounts, w3_coinbase, method, from_field_added, from_field_value",
(
(SAMPLE_ADDRESS_LIST, SAMPLE_ADDRESS, "eth_call", True, SAMPLE_ADDRESS),
(
SAMPLE_ADDRESS_LIST,
SAMPLE_ADDRESS,
"eth_estimateGas",
True,
SAMPLE_ADDRESS,
),
(
SAMPLE_ADDRESS_LIST,
SAMPLE_ADDRESS,
"eth_sendTransaction",
True,
SAMPLE_ADDRESS,
),
(SAMPLE_ADDRESS_LIST, SAMPLE_ADDRESS, "eth_gasPrice", False, None),
(SAMPLE_ADDRESS_LIST, SAMPLE_ADDRESS, "eth_blockNumber", False, None),
(SAMPLE_ADDRESS_LIST, SAMPLE_ADDRESS, "meow", False, None),
(SAMPLE_ADDRESS_LIST, None, "eth_call", True, SAMPLE_ADDRESS_LIST[0]),
(SAMPLE_ADDRESS_LIST, None, "eth_estimateGas", True, SAMPLE_ADDRESS_LIST[0]),
(
SAMPLE_ADDRESS_LIST,
None,
"eth_sendTransaction",
True,
SAMPLE_ADDRESS_LIST[0],
),
(SAMPLE_ADDRESS_LIST, None, "eth_gasPrice", False, None),
(SAMPLE_ADDRESS_LIST, None, "eth_blockNumber", False, None),
(SAMPLE_ADDRESS_LIST, None, "meow", False, None),
(None, SAMPLE_ADDRESS, "eth_call", True, SAMPLE_ADDRESS),
(None, SAMPLE_ADDRESS, "eth_estimateGas", True, SAMPLE_ADDRESS),
(None, SAMPLE_ADDRESS, "eth_sendTransaction", True, SAMPLE_ADDRESS),
(None, SAMPLE_ADDRESS, "eth_gasPrice", False, SAMPLE_ADDRESS),
(None, SAMPLE_ADDRESS, "eth_blockNumber", False, SAMPLE_ADDRESS),
(None, SAMPLE_ADDRESS, "meow", False, SAMPLE_ADDRESS),
(None, None, "eth_call", True, None),
(None, None, "eth_estimateGas", True, None),
(None, None, "eth_sendTransaction", True, None),
(None, None, "eth_gasPrice", False, None),
(None, None, "eth_blockNumber", False, None),
(None, None, "meow", False, None),
),
)
@pytest.mark.asyncio
async def test_async_default_transaction_fields_middleware(
w3_accounts,
w3_coinbase,
method,
from_field_added,
from_field_value,
):
async def mock_request(_method, params):
return params

async def mock_async_accounts():
return w3_accounts

async def mock_async_coinbase():
return w3_coinbase

mock_w3 = Mock()
mock_w3.eth.accounts = mock_async_accounts()
mock_w3.eth.coinbase = mock_async_coinbase()

middleware = await async_default_transaction_fields_middleware(
mock_request, mock_w3
)
base_params = {"chainId": 5}
filled_transaction = await middleware(method, [base_params])

filled_params = filled_transaction[0]
assert ("from" in filled_params.keys()) == from_field_added
if "from" in filled_params.keys():
assert filled_params["from"] == from_field_value

filled_transaction[0].pop("from", None)
assert filled_transaction[0] == base_params

# clean up
mock_w3.eth.accounts.close()
mock_w3.eth.coinbase.close()
38 changes: 29 additions & 9 deletions web3/providers/eth_tester/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,15 +315,21 @@ async def async_ethereum_tester_middleware( # type: ignore


def guess_from(w3: "Web3", _: TxParams) -> ChecksumAddress:
coinbase = w3.eth.coinbase
if w3.eth.coinbase:
return w3.eth.coinbase
elif w3.eth.accounts and len(w3.eth.accounts) > 0:
return w3.eth.accounts[0]

return None


async def async_guess_from(async_w3: "Web3", _: TxParams) -> ChecksumAddress:
coinbase = await async_w3.eth.coinbase # type: ignore
accounts = await async_w3.eth.accounts # type: ignore
if coinbase is not None:
return coinbase

try:
return w3.eth.accounts[0]
except KeyError:
# no accounts available to pre-fill, carry on
pass
elif accounts is not None and len(accounts) > 0:
return accounts[0]

return None

Expand All @@ -340,6 +346,18 @@ def fill_default(
return assoc(transaction, field, guess_val)


@curry
async def async_fill_default(
field: str, guess_func: Callable[..., Any], async_w3: "Web3", transaction: TxParams
) -> TxParams:
# type ignored b/c TxParams keys must be string literal types
if field in transaction and transaction[field] is not None: # type: ignore
return transaction
else:
guess_val = await guess_func(async_w3, transaction)
return assoc(transaction, field, guess_val)


def default_transaction_fields_middleware(
make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3"
) -> Callable[[RPCEndpoint, Any], RPCResponse]:
Expand All @@ -363,15 +381,17 @@ def middleware(method: RPCEndpoint, params: Any) -> RPCResponse:


async def async_default_transaction_fields_middleware(
make_request: Callable[[RPCEndpoint, Any], Any], web3: "Web3"
make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "Web3"
) -> Callable[[RPCEndpoint, Any], RPCResponse]:
async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse:
if method in (
"eth_call",
"eth_estimateGas",
"eth_sendTransaction",
):
filled_transaction = fill_default("from", guess_from, web3, params[0])
filled_transaction = await async_fill_default(
"from", async_guess_from, async_w3, params[0]
)
return await make_request(method, [filled_transaction] + list(params)[1:])
else:
return await make_request(method, params)
Expand Down

0 comments on commit 0ecffd7

Please sign in to comment.