Skip to content

Commit

Permalink
formatting and validation middleware async support
Browse files Browse the repository at this point in the history
  • Loading branch information
fselmo committed Dec 15, 2021
1 parent 869101e commit 53f1de3
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 64 deletions.
4 changes: 3 additions & 1 deletion tests/integration/go_ethereum/test_goethereum_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from web3.middleware import (
async_buffered_gas_estimate_middleware,
async_gas_price_strategy_middleware,
async_validation_middleware,
)
from web3.net import (
AsyncNet,
Expand Down Expand Up @@ -85,8 +86,9 @@ async def async_w3(geth_process, endpoint_uri):
_web3 = Web3(
AsyncHTTPProvider(endpoint_uri),
middlewares=[
async_buffered_gas_estimate_middleware,
async_gas_price_strategy_middleware,
async_buffered_gas_estimate_middleware
await async_validation_middleware,
],
modules={'eth': (AsyncEth,), 'async_net': (AsyncNet,)})
return _web3
Expand Down
48 changes: 48 additions & 0 deletions web3/_utils/module_testing/eth_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
NameNotFound,
TransactionNotFound,
TransactionTypeMismatch,
ValidationError,
)
from web3.types import ( # noqa: F401
BlockData,
Expand Down Expand Up @@ -265,6 +266,30 @@ async def test_eth_send_transaction_max_fee_less_than_tip(
):
await async_w3.eth.send_transaction(txn_params) # type: ignore

@pytest.mark.asyncio
async def test_validation_middleware_chain_id_mismatch(
self, async_w3: "Web3", unlocked_account_dual_type: ChecksumAddress
) -> None:
wrong_chain_id = 1234567890
actual_chain_id = await async_w3.eth.chain_id # type: ignore

txn_params: TxParams = {
'from': unlocked_account_dual_type,
'to': unlocked_account_dual_type,
'value': Wei(1),
'gas': Wei(21000),
'maxFeePerGas': async_w3.toWei(2, 'gwei'),
'maxPriorityFeePerGas': async_w3.toWei(1, 'gwei'),
'chainId': wrong_chain_id,

}
with pytest.raises(
ValidationError,
match=f'The transaction declared chain ID {wrong_chain_id}, '
f'but the connected node is on {actual_chain_id}'
):
await async_w3.eth.send_transaction(txn_params) # type: ignore

@pytest.mark.asyncio
async def test_eth_send_raw_transaction(self, async_w3: "Web3") -> None:
# private key 0x3c2ab4e8f17a7dea191b8c991522660126d681039509dc3bb31af7c9bdb63518
Expand Down Expand Up @@ -1600,6 +1625,29 @@ def test_eth_send_transaction_max_fee_less_than_tip(
):
web3.eth.send_transaction(txn_params)

def test_validation_middleware_chain_id_mismatch(
self, web3: "Web3", unlocked_account_dual_type: ChecksumAddress
) -> None:
wrong_chain_id = 1234567890
actual_chain_id = web3.eth.chain_id

txn_params: TxParams = {
'from': unlocked_account_dual_type,
'to': unlocked_account_dual_type,
'value': Wei(1),
'gas': Wei(21000),
'maxFeePerGas': web3.toWei(2, 'gwei'),
'maxPriorityFeePerGas': web3.toWei(1, 'gwei'),
'chainId': wrong_chain_id,

}
with pytest.raises(
ValidationError,
match=f'The transaction declared chain ID {wrong_chain_id}, '
f'but the connected node is on {actual_chain_id}'
):
web3.eth.send_transaction(txn_params)

@pytest.mark.parametrize(
"max_fee",
(1000000000, None),
Expand Down
10 changes: 10 additions & 0 deletions web3/eth.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ class BaseEth(Module):
mungers=None,
)

_chain_id: Method[Callable[[], int]] = Method(
RPC.eth_chainId,
mungers=None,
)

""" property default_block """
@property
def default_block(self) -> BlockIdentifier:
Expand Down Expand Up @@ -276,6 +281,11 @@ def call_munger(
class AsyncEth(BaseEth):
is_async = True

@property
async def chain_id(self) -> int:
# types ignored b/c mypy conflict with BlockingEth properties
return await self._chain_id() # type: ignore

@property
async def gas_price(self) -> Wei:
# types ignored b/c mypy conflict with BlockingEth properties
Expand Down
6 changes: 3 additions & 3 deletions web3/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,11 @@ def default_middlewares(
"""
return [
(request_parameter_normalizer, 'request_param_normalizer'), # Delete
(gas_price_strategy_middleware, 'gas_price_strategy'), # Add Async
(gas_price_strategy_middleware, 'gas_price_strategy'),
(name_to_address_middleware(web3), 'name_to_address'), # Add Async
(attrdict_middleware, 'attrdict'), # Delete
(pythonic_middleware, 'pythonic'), # Delete
(validation_middleware, 'validation'), # Add async
(validation_middleware, 'validation'),
(abi_middleware, 'abi'), # Delete
(buffered_gas_estimate_middleware, 'gas_estimate'),
]
Expand All @@ -159,8 +159,8 @@ async def _coro_make_request(
self.logger.debug("Making request. Method: %s", method)
return await request_func(method, params)

@staticmethod
def formatted_response(
self,
response: RPCResponse,
params: Any,
error_formatters: Optional[Callable[..., Any]] = None,
Expand Down
1 change: 1 addition & 0 deletions web3/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
make_stalecheck_middleware,
)
from .validation import ( # noqa: F401
async_validation_middleware,
validation_middleware,
)

Expand Down
118 changes: 79 additions & 39 deletions web3/middleware/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Literal,
Optional,
)

from eth_utils.toolz import (
assoc,
curry,
merge,
)

Expand All @@ -22,14 +23,20 @@
if TYPE_CHECKING:
from web3 import Web3 # noqa: F401

FORMATTER_DEFAULTS = {
"request_formatters": {},
"result_formatters": {},
"error_formatters": {},
}


def construct_formatting_middleware(
request_formatters: Optional[Formatters] = None,
result_formatters: Optional[Formatters] = None,
error_formatters: Optional[Formatters] = None
) -> Middleware:
def ignore_web3_in_standard_formatters(
w3: "Web3",
w3: "Web3", method: RPCEndpoint,
) -> FormattersDict:
return dict(
request_formatters=request_formatters or {},
Expand All @@ -41,55 +48,88 @@ def ignore_web3_in_standard_formatters(


def construct_web3_formatting_middleware(
web3_formatters_builder: Callable[["Web3"], FormattersDict]
web3_formatters_builder: Callable[["Web3", RPCEndpoint], FormattersDict],
) -> Middleware:
def formatter_middleware(
make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3"
make_request: Callable[[RPCEndpoint, Any], Any],
w3: "Web3",
) -> Callable[[RPCEndpoint, Any], RPCResponse]:
formatters = merge(
{
"request_formatters": {},
"result_formatters": {},
"error_formatters": {},
},
web3_formatters_builder(w3),
)
return apply_formatters(make_request=make_request, **formatters)
def middleware(method: RPCEndpoint, params: Any) -> RPCResponse:
formatters = merge(
FORMATTER_DEFAULTS,
web3_formatters_builder(w3, method),
)
response = _make_request_with_formatters(
method=method,
params=params,
request_formatters=formatters.pop('request_formatters'),
)
return _apply_response_formatters(method=method, response=response, **formatters)

def _make_request_with_formatters(
method: RPCEndpoint, params: Any, request_formatters: Formatters
) -> RPCResponse:
if method in request_formatters:
formatter = request_formatters[method]
formatted_params = formatter(params)
return make_request(method, formatted_params)
return make_request(method, params)

return middleware
return formatter_middleware


@curry
def apply_formatters(
async def async_construct_web3_formatting_middleware(
async_web3_formatters_builder:
Callable[["Web3", RPCEndpoint], Coroutine[Any, Any, FormattersDict]]
) -> Callable[[Callable[[RPCEndpoint, Any], Any], "Web3"],
Coroutine[Any, Any, Callable[[RPCEndpoint, Any], Coroutine[Any, Any, RPCResponse]]]]:
async def formatter_middleware(
make_request: Callable[[RPCEndpoint, Any], Any],
async_w3: "Web3",
) -> Callable[[RPCEndpoint, Any], Coroutine[Any, Any, RPCResponse]]:
async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse:
formatters = merge(
FORMATTER_DEFAULTS,
await async_web3_formatters_builder(async_w3, method),
)
response = await _make_async_request_with_formatters(
method=method,
params=params,
request_formatters=formatters.pop('request_formatters'),
)
return _apply_response_formatters(method=method, response=response, **formatters)

async def _make_async_request_with_formatters(
method: RPCEndpoint, params: Any, request_formatters: Formatters
) -> RPCResponse:
if method in request_formatters:
formatter = request_formatters[method]
formatted_params = formatter(params)
return await make_request(method, formatted_params)
return await make_request(method, params)

return middleware
return formatter_middleware


def _apply_response_formatters(
method: RPCEndpoint,
params: Any,
make_request: Callable[[RPCEndpoint, Any], RPCResponse],
request_formatters: Formatters,
response: RPCResponse,
result_formatters: Formatters,
error_formatters: Formatters,
) -> RPCResponse:
if method in request_formatters:
formatter = request_formatters[method]
formatted_params = formatter(params)
response = make_request(method, formatted_params)
else:
response = make_request(method, params)

if "result" in response and method in result_formatters:
formatter = result_formatters[method]
formatted_response = assoc(
response,
"result",
formatter(response["result"]),
def _format_response(
response_type: Literal["result", "error"],
method_response_formatter: Callable[..., Any]
) -> RPCResponse:
appropriate_response = response[response_type]
return assoc(
response, response_type, method_response_formatter(appropriate_response)
)
return formatted_response
if "result" in response and method in result_formatters:
return _format_response("result", result_formatters[method])
elif "error" in response and method in error_formatters:
formatter = error_formatters[method]
formatted_response = assoc(
response,
"error",
formatter(response["error"]),
)
return formatted_response
return _format_response("error", error_formatters[method])
else:
return response
Loading

0 comments on commit 53f1de3

Please sign in to comment.