diff --git a/newsfragments/2098.feature.rst b/newsfragments/2098.feature.rst new file mode 100644 index 0000000000..b1dffae787 --- /dev/null +++ b/newsfragments/2098.feature.rst @@ -0,0 +1 @@ +async support for formatting, validation, and geth poa middlewares \ No newline at end of file diff --git a/tests/integration/go_ethereum/test_goethereum_http.py b/tests/integration/go_ethereum/test_goethereum_http.py index 90e0b7313e..c4c4db63a4 100644 --- a/tests/integration/go_ethereum/test_goethereum_http.py +++ b/tests/integration/go_ethereum/test_goethereum_http.py @@ -18,6 +18,7 @@ from web3.middleware import ( async_buffered_gas_estimate_middleware, async_gas_price_strategy_middleware, + async_validation_middleware, ) from web3.net import ( AsyncNet, @@ -95,8 +96,9 @@ async def async_w3(geth_process, endpoint_uri): _web3 = Web3( AsyncHTTPProvider(endpoint_uri), middlewares=[ + async_buffered_gas_estimate_middleware, async_gas_price_strategy_middleware, - async_buffered_gas_estimate_middleware + async_validation_middleware, ], modules={'eth': AsyncEth, 'async_net': AsyncNet, diff --git a/web3/_utils/module_testing/eth_module.py b/web3/_utils/module_testing/eth_module.py index 887201fa39..6d238b1f63 100644 --- a/web3/_utils/module_testing/eth_module.py +++ b/web3/_utils/module_testing/eth_module.py @@ -53,9 +53,14 @@ TimeExhausted, TransactionNotFound, TransactionTypeMismatch, + ValidationError, +) +from web3.middleware import ( + async_geth_poa_middleware, ) from web3.middleware.fixture import ( async_construct_error_generator_middleware, + async_construct_result_generator_middleware, construct_error_generator_middleware, ) from web3.types import ( # noqa: F401 @@ -290,6 +295,47 @@ async def test_eth_send_transaction_max_fee_less_than_tip( ): await async_w3.eth.send_transaction(txn_params) # type: ignore + @pytest.mark.asyncio + async def test_validation_middleware_chain_id_mismatch( + self, async_w3: "Web3", unlocked_account_dual_type: ChecksumAddress + ) -> None: + wrong_chain_id = 1234567890 + actual_chain_id = await async_w3.eth.chain_id # type: ignore + + txn_params: TxParams = { + 'from': unlocked_account_dual_type, + 'to': unlocked_account_dual_type, + 'value': Wei(1), + 'gas': Wei(21000), + 'maxFeePerGas': async_w3.toWei(2, 'gwei'), + 'maxPriorityFeePerGas': async_w3.toWei(1, 'gwei'), + 'chainId': wrong_chain_id, + + } + with pytest.raises( + ValidationError, + match=f'The transaction declared chain ID {wrong_chain_id}, ' + f'but the connected node is on {actual_chain_id}' + ): + await async_w3.eth.send_transaction(txn_params) # type: ignore + + @pytest.mark.asyncio + async def test_geth_poa_middleware(self, async_w3: "Web3") -> None: + return_block_with_long_extra_data = await async_construct_result_generator_middleware( + { + RPCEndpoint('eth_getBlockByNumber'): lambda *_: {'extraData': '0x' + 'ff' * 33}, + } + ) + async_w3.middleware_onion.inject(async_geth_poa_middleware, 'poa', layer=0) + async_w3.middleware_onion.inject(return_block_with_long_extra_data, 'extradata', layer=0) + block = await async_w3.eth.get_block('latest') # type: ignore + assert 'extraData' not in block + assert block.proofOfAuthorityData == b'\xff' * 33 + + # clean up + async_w3.middleware_onion.remove('poa') + async_w3.middleware_onion.remove('extradata') + @pytest.mark.asyncio async def test_eth_send_raw_transaction(self, async_w3: "Web3") -> None: # private key 0x3c2ab4e8f17a7dea191b8c991522660126d681039509dc3bb31af7c9bdb63518 @@ -1997,6 +2043,29 @@ def test_eth_send_transaction_max_fee_less_than_tip( ): web3.eth.send_transaction(txn_params) + def test_validation_middleware_chain_id_mismatch( + self, web3: "Web3", unlocked_account_dual_type: ChecksumAddress + ) -> None: + wrong_chain_id = 1234567890 + actual_chain_id = web3.eth.chain_id + + txn_params: TxParams = { + 'from': unlocked_account_dual_type, + 'to': unlocked_account_dual_type, + 'value': Wei(1), + 'gas': Wei(21000), + 'maxFeePerGas': web3.toWei(2, 'gwei'), + 'maxPriorityFeePerGas': web3.toWei(1, 'gwei'), + 'chainId': wrong_chain_id, + + } + with pytest.raises( + ValidationError, + match=f'The transaction declared chain ID {wrong_chain_id}, ' + f'but the connected node is on {actual_chain_id}' + ): + web3.eth.send_transaction(txn_params) + @pytest.mark.parametrize( "max_fee", (1000000000, None), diff --git a/web3/manager.py b/web3/manager.py index b731ee7a81..1cd4e1c241 100644 --- a/web3/manager.py +++ b/web3/manager.py @@ -128,11 +128,11 @@ def default_middlewares( """ return [ (request_parameter_normalizer, 'request_param_normalizer'), # Delete - (gas_price_strategy_middleware, 'gas_price_strategy'), # Add Async + (gas_price_strategy_middleware, 'gas_price_strategy'), (name_to_address_middleware(web3), 'name_to_address'), # Add Async (attrdict_middleware, 'attrdict'), # Delete (pythonic_middleware, 'pythonic'), # Delete - (validation_middleware, 'validation'), # Add async + (validation_middleware, 'validation'), (abi_middleware, 'abi'), # Delete (buffered_gas_estimate_middleware, 'gas_estimate'), ] @@ -159,8 +159,8 @@ async def _coro_make_request( self.logger.debug("Making request. Method: %s", method) return await request_func(method, params) + @staticmethod def formatted_response( - self, response: RPCResponse, params: Any, error_formatters: Optional[Callable[..., Any]] = None, diff --git a/web3/middleware/__init__.py b/web3/middleware/__init__.py index baad91b1b6..ea2f00c6ed 100644 --- a/web3/middleware/__init__.py +++ b/web3/middleware/__init__.py @@ -51,6 +51,7 @@ gas_price_strategy_middleware, ) from .geth_poa import ( # noqa: F401 + async_geth_poa_middleware, geth_poa_middleware, ) from .names import ( # noqa: F401 @@ -69,6 +70,7 @@ make_stalecheck_middleware, ) from .validation import ( # noqa: F401 + async_validation_middleware, validation_middleware, ) diff --git a/web3/middleware/buffered_gas_estimate.py b/web3/middleware/buffered_gas_estimate.py index 8f75f8ae86..194c4178e1 100644 --- a/web3/middleware/buffered_gas_estimate.py +++ b/web3/middleware/buffered_gas_estimate.py @@ -2,7 +2,6 @@ TYPE_CHECKING, Any, Callable, - Coroutine, ) from eth_utils.toolz import ( @@ -16,6 +15,7 @@ get_buffered_gas_estimate, ) from web3.types import ( + AsyncMiddleware, RPCEndpoint, RPCResponse, ) @@ -43,7 +43,7 @@ def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: async def async_buffered_gas_estimate_middleware( make_request: Callable[[RPCEndpoint, Any], Any], web3: "Web3" -) -> Callable[[RPCEndpoint, Any], Coroutine[Any, Any, RPCResponse]]: +) -> AsyncMiddleware: async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: if method == 'eth_sendTransaction': transaction = params[0] diff --git a/web3/middleware/fixture.py b/web3/middleware/fixture.py index 1565ab50a7..cba6372f57 100644 --- a/web3/middleware/fixture.py +++ b/web3/middleware/fixture.py @@ -2,11 +2,11 @@ TYPE_CHECKING, Any, Callable, - Coroutine, Dict, ) from web3.types import ( + AsyncMiddleware, Middleware, RPCEndpoint, RPCResponse, @@ -78,6 +78,28 @@ def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: return error_generator_middleware +# --- async --- # + +async def async_construct_result_generator_middleware( + result_generators: Dict[RPCEndpoint, Any] +) -> Middleware: + """ + Constructs a middleware which returns a static response for any method + which is found in the provided fixtures. + """ + async def result_generator_middleware( + make_request: Callable[[RPCEndpoint, Any], Any], _: "Web3" + ) -> AsyncMiddleware: + async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: + if method in result_generators: + result = result_generators[method](method, params) + return {'result': result} + else: + return await make_request(method, params) + return middleware + return result_generator_middleware + + async def async_construct_error_generator_middleware( error_generators: Dict[RPCEndpoint, Any] ) -> Middleware: @@ -89,7 +111,7 @@ async def async_construct_error_generator_middleware( """ async def error_generator_middleware( make_request: Callable[[RPCEndpoint, Any], Any], _: "Web3" - ) -> Callable[[RPCEndpoint, Any], Coroutine[Any, Any, RPCResponse]]: + ) -> AsyncMiddleware: async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: if method in error_generators: error_msg = error_generators[method](method, params) diff --git a/web3/middleware/formatting.py b/web3/middleware/formatting.py index 3542583a9a..04602be0ae 100644 --- a/web3/middleware/formatting.py +++ b/web3/middleware/formatting.py @@ -2,18 +2,20 @@ TYPE_CHECKING, Any, Callable, + Coroutine, Optional, ) from eth_utils.toolz import ( assoc, - curry, merge, ) from web3.types import ( + AsyncMiddleware, Formatters, FormattersDict, + Literal, Middleware, RPCEndpoint, RPCResponse, @@ -22,6 +24,38 @@ if TYPE_CHECKING: from web3 import Web3 # noqa: F401 +FORMATTER_DEFAULTS: FormattersDict = { + "request_formatters": {}, + "result_formatters": {}, + "error_formatters": {}, +} + + +def _apply_response_formatters( + method: RPCEndpoint, + response: RPCResponse, + result_formatters: Formatters, + error_formatters: Formatters, +) -> RPCResponse: + + def _format_response( + response_type: Literal["result", "error"], + method_response_formatter: Callable[..., Any] + ) -> RPCResponse: + appropriate_response = response[response_type] + return assoc( + response, response_type, method_response_formatter(appropriate_response) + ) + + if "result" in response and method in result_formatters: + return _format_response("result", result_formatters[method]) + elif "error" in response and method in error_formatters: + return _format_response("error", error_formatters[method]) + else: + return response + + +# --- sync -- # def construct_formatting_middleware( request_formatters: Optional[Formatters] = None, @@ -29,7 +63,7 @@ def construct_formatting_middleware( error_formatters: Optional[Formatters] = None ) -> Middleware: def ignore_web3_in_standard_formatters( - w3: "Web3", + _w3: "Web3", _method: RPCEndpoint, ) -> FormattersDict: return dict( request_formatters=request_formatters or {}, @@ -41,55 +75,67 @@ def ignore_web3_in_standard_formatters( def construct_web3_formatting_middleware( - web3_formatters_builder: Callable[["Web3"], FormattersDict] + web3_formatters_builder: Callable[["Web3", RPCEndpoint], FormattersDict], ) -> Middleware: def formatter_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" + make_request: Callable[[RPCEndpoint, Any], Any], + w3: "Web3", ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - formatters = merge( - { - "request_formatters": {}, - "result_formatters": {}, - "error_formatters": {}, - }, - web3_formatters_builder(w3), - ) - return apply_formatters(make_request=make_request, **formatters) + def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: + formatters = merge( + FORMATTER_DEFAULTS, + web3_formatters_builder(w3, method), + ) + request_formatters = formatters.pop('request_formatters') + + if method in request_formatters: + formatter = request_formatters[method] + params = formatter(params) + response = make_request(method, params) + return _apply_response_formatters(method=method, response=response, **formatters) + return middleware return formatter_middleware -@curry -def apply_formatters( - method: RPCEndpoint, - params: Any, - make_request: Callable[[RPCEndpoint, Any], RPCResponse], - request_formatters: Formatters, - result_formatters: Formatters, - error_formatters: Formatters, -) -> RPCResponse: - if method in request_formatters: - formatter = request_formatters[method] - formatted_params = formatter(params) - response = make_request(method, formatted_params) - else: - response = make_request(method, params) +# --- async --- # - if "result" in response and method in result_formatters: - formatter = result_formatters[method] - formatted_response = assoc( - response, - "result", - formatter(response["result"]), - ) - return formatted_response - elif "error" in response and method in error_formatters: - formatter = error_formatters[method] - formatted_response = assoc( - response, - "error", - formatter(response["error"]), +async def async_construct_formatting_middleware( + request_formatters: Optional[Formatters] = None, + result_formatters: Optional[Formatters] = None, + error_formatters: Optional[Formatters] = None +) -> Middleware: + async def ignore_web3_in_standard_formatters( + _w3: "Web3", _method: RPCEndpoint, + ) -> FormattersDict: + return dict( + request_formatters=request_formatters or {}, + result_formatters=result_formatters or {}, + error_formatters=error_formatters or {}, ) - return formatted_response - else: - return response + return await async_construct_web3_formatting_middleware(ignore_web3_in_standard_formatters) + + +async def async_construct_web3_formatting_middleware( + async_web3_formatters_builder: + Callable[["Web3", RPCEndpoint], Coroutine[Any, Any, FormattersDict]] +) -> Callable[[Callable[[RPCEndpoint, Any], Any], "Web3"], Coroutine[Any, Any, AsyncMiddleware]]: + async def formatter_middleware( + make_request: Callable[[RPCEndpoint, Any], Any], + async_w3: "Web3", + ) -> AsyncMiddleware: + async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: + formatters = merge( + FORMATTER_DEFAULTS, + await async_web3_formatters_builder(async_w3, method), + ) + request_formatters = formatters.pop('request_formatters') + + if method in request_formatters: + formatter = request_formatters[method] + params = formatter(params) + response = await make_request(method, params) + + return _apply_response_formatters(method=method, response=response, **formatters) + return middleware + return formatter_middleware diff --git a/web3/middleware/gas_price_strategy.py b/web3/middleware/gas_price_strategy.py index f0a7254fcc..9d3fcacc6d 100644 --- a/web3/middleware/gas_price_strategy.py +++ b/web3/middleware/gas_price_strategy.py @@ -2,7 +2,6 @@ TYPE_CHECKING, Any, Callable, - Coroutine, ) from eth_utils.toolz import ( @@ -22,6 +21,7 @@ TransactionTypeMismatch, ) from web3.types import ( + AsyncMiddleware, BlockData, RPCEndpoint, RPCResponse, @@ -94,7 +94,7 @@ def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: async def async_gas_price_strategy_middleware( make_request: Callable[[RPCEndpoint, Any], Any], web3: "Web3" -) -> Callable[[RPCEndpoint, Any], Coroutine[Any, Any, RPCResponse]]: +) -> AsyncMiddleware: """ - Uses a gas price strategy if one is set. This is only supported for legacy transactions. It is recommended to send dynamic fee transactions (EIP-1559) whenever possible. diff --git a/web3/middleware/geth_poa.py b/web3/middleware/geth_poa.py index f941b69a01..6048cc70f0 100644 --- a/web3/middleware/geth_poa.py +++ b/web3/middleware/geth_poa.py @@ -1,3 +1,9 @@ +from typing import ( + TYPE_CHECKING, + Any, + Callable, +) + from eth_utils.curried import ( apply_formatter_if, apply_formatters_to_dict, @@ -16,8 +22,16 @@ RPC, ) from web3.middleware.formatting import ( + async_construct_formatting_middleware, construct_formatting_middleware, ) +from web3.types import ( + AsyncMiddleware, + RPCEndpoint, +) + +if TYPE_CHECKING: + from web3 import Web3 # noqa: F401 is_not_null = complement(is_null) @@ -31,9 +45,22 @@ geth_poa_cleanup = compose(pythonic_geth_poa, remap_geth_poa_fields) + geth_poa_middleware = construct_formatting_middleware( result_formatters={ RPC.eth_getBlockByHash: apply_formatter_if(is_not_null, geth_poa_cleanup), RPC.eth_getBlockByNumber: apply_formatter_if(is_not_null, geth_poa_cleanup), }, ) + + +async def async_geth_poa_middleware( + make_request: Callable[[RPCEndpoint, Any], Any], web3: "Web3" +) -> AsyncMiddleware: + middleware = await async_construct_formatting_middleware( + result_formatters={ + RPC.eth_getBlockByHash: apply_formatter_if(is_not_null, geth_poa_cleanup), + RPC.eth_getBlockByNumber: apply_formatter_if(is_not_null, geth_poa_cleanup), + }, + ) + return await middleware(make_request, web3) diff --git a/web3/middleware/validation.py b/web3/middleware/validation.py index 2e956655f4..44d79825de 100644 --- a/web3/middleware/validation.py +++ b/web3/middleware/validation.py @@ -2,6 +2,7 @@ TYPE_CHECKING, Any, Callable, + Dict, ) from eth_utils.curried import ( @@ -32,10 +33,14 @@ ValidationError, ) from web3.middleware.formatting import ( + async_construct_web3_formatting_middleware, construct_web3_formatting_middleware, ) from web3.types import ( + AsyncMiddleware, + Formatters, FormattersDict, + RPCEndpoint, TxParams, ) @@ -45,25 +50,24 @@ MAX_EXTRADATA_LENGTH = 32 is_not_null = complement(is_null) - to_integer_if_hex = apply_formatter_if(is_string, hex_to_integer) @curry -def validate_chain_id(web3: "Web3", chain_id: int) -> int: - if to_integer_if_hex(chain_id) == web3.eth.chain_id: +def _validate_chain_id(web3_chain_id: int, chain_id: int) -> int: + if to_integer_if_hex(chain_id) == web3_chain_id: return chain_id else: raise ValidationError( "The transaction declared chain ID %r, " "but the connected node is on %r" % ( chain_id, - web3.eth.chain_id, + web3_chain_id, ) ) -def check_extradata_length(val: Any) -> Any: +def _check_extradata_length(val: Any) -> Any: if not isinstance(val, (str, int, bytes)): return val result = HexBytes(val) @@ -80,16 +84,16 @@ def check_extradata_length(val: Any) -> Any: return val -def transaction_normalizer(transaction: TxParams) -> TxParams: +def _transaction_normalizer(transaction: TxParams) -> TxParams: return dissoc(transaction, 'chainId') -def transaction_param_validator(web3: "Web3") -> Callable[..., Any]: +def _transaction_param_validator(web3_chain_id: int) -> Callable[..., Any]: transactions_params_validators = { "chainId": apply_formatter_if( # Bypass `validate_chain_id` if chainId can't be determined - lambda _: is_not_null(web3.eth.chain_id), - validate_chain_id(web3), + lambda _: is_not_null(web3_chain_id), + _validate_chain_id(web3_chain_id), ), } return apply_formatter_at_index( @@ -99,36 +103,70 @@ def transaction_param_validator(web3: "Web3") -> Callable[..., Any]: BLOCK_VALIDATORS = { - 'extraData': check_extradata_length, + 'extraData': _check_extradata_length, } - - block_validator = apply_formatter_if( is_not_null, apply_formatters_to_dict(BLOCK_VALIDATORS) ) +METHODS_TO_VALIDATE = [ + RPC.eth_sendTransaction, + RPC.eth_estimateGas, + RPC.eth_call +] -@curry -def chain_id_validator(web3: "Web3") -> Callable[..., Any]: + +def _chain_id_validator(web3_chain_id: int) -> Callable[..., Any]: return compose( - apply_formatter_at_index(transaction_normalizer, 0), - transaction_param_validator(web3) + apply_formatter_at_index(_transaction_normalizer, 0), + _transaction_param_validator(web3_chain_id) ) -def build_validators_with_web3(w3: "Web3") -> FormattersDict: +def _build_formatters_dict(request_formatters: Dict[RPCEndpoint, Any]) -> FormattersDict: return dict( - request_formatters={ - RPC.eth_sendTransaction: chain_id_validator(w3), - RPC.eth_estimateGas: chain_id_validator(w3), - RPC.eth_call: chain_id_validator(w3), - }, + request_formatters=request_formatters, result_formatters={ RPC.eth_getBlockByHash: block_validator, RPC.eth_getBlockByNumber: block_validator, - }, + } ) +# -- sync -- # + + +def build_method_validators(w3: "Web3", method: RPCEndpoint) -> FormattersDict: + request_formatters = {} + if RPCEndpoint(method) in METHODS_TO_VALIDATE: + w3_chain_id = w3.eth.chain_id + for method in METHODS_TO_VALIDATE: + request_formatters[method] = _chain_id_validator(w3_chain_id) + + return _build_formatters_dict(request_formatters) + -validation_middleware = construct_web3_formatting_middleware(build_validators_with_web3) +validation_middleware = construct_web3_formatting_middleware( + build_method_validators +) + + +# -- async --- # + +async def async_build_method_validators(async_w3: "Web3", method: RPCEndpoint) -> FormattersDict: + request_formatters: Formatters = {} + if RPCEndpoint(method) in METHODS_TO_VALIDATE: + w3_chain_id = await async_w3.eth.chain_id # type: ignore + for method in METHODS_TO_VALIDATE: + request_formatters[method] = _chain_id_validator(w3_chain_id) + + return _build_formatters_dict(request_formatters) + + +async def async_validation_middleware( + make_request: Callable[[RPCEndpoint, Any], Any], web3: "Web3" +) -> AsyncMiddleware: + middleware = await async_construct_web3_formatting_middleware( + async_build_method_validators + ) + return await middleware(make_request, web3) diff --git a/web3/types.py b/web3/types.py index f475fc14d3..fa9ddac284 100644 --- a/web3/types.py +++ b/web3/types.py @@ -2,6 +2,7 @@ TYPE_CHECKING, Any, Callable, + Coroutine, Dict, List, NewType, @@ -135,13 +136,14 @@ class RPCResponse(TypedDict, total=False): Middleware = Callable[[Callable[[RPCEndpoint, Any], RPCResponse], "Web3"], Any] +AsyncMiddleware = Callable[[RPCEndpoint, Any], Coroutine[Any, Any, RPCResponse]] MiddlewareOnion = NamedElementOnion[str, Middleware] class FormattersDict(TypedDict, total=False): - error_formatters: Formatters - request_formatters: Formatters - result_formatters: Formatters + error_formatters: Optional[Formatters] + request_formatters: Optional[Formatters] + result_formatters: Optional[Formatters] class FilterParams(TypedDict, total=False):