diff --git a/ens/constants.py b/ens/constants.py index 348da81b2e..b7a600e3c9 100644 --- a/ens/constants.py +++ b/ens/constants.py @@ -1,5 +1,6 @@ from eth_typing import ( HexAddress, + HexStr, ) from hexbytes import ( HexBytes, @@ -11,6 +12,6 @@ AUCTION_START_GAS_MARGINAL = 39000 EMPTY_SHA3_BYTES = HexBytes(b'\0' * 32) -EMPTY_ADDR_HEX = HexAddress('0x' + '00' * 20) +EMPTY_ADDR_HEX = HexAddress(HexStr('0x' + '00' * 20)) REVERSE_REGISTRAR_DOMAIN = 'addr.reverse' diff --git a/ens/main.py b/ens/main.py index 58192b4939..bcdf297dde 100644 --- a/ens/main.py +++ b/ens/main.py @@ -12,6 +12,7 @@ ChecksumAddress, Hash32, HexAddress, + HexStr, ) from eth_utils import ( is_binary_address, @@ -56,7 +57,9 @@ ) -ENS_MAINNET_ADDR = ChecksumAddress(HexAddress('0x314159265dD8dbb310642f98f50C066173C1259b')) +ENS_MAINNET_ADDR = ChecksumAddress( + HexAddress(HexStr('0x314159265dD8dbb310642f98f50C066173C1259b')) +) class ENS: diff --git a/tox.ini b/tox.ini index faac162d5c..ad80291087 100644 --- a/tox.ini +++ b/tox.ini @@ -63,4 +63,4 @@ extras=linter commands= flake8 {toxinidir}/web3 {toxinidir}/ens {toxinidir}/ethpm {toxinidir}/tests isort --recursive --check-only --diff {toxinidir}/web3/ {toxinidir}/ens/ {toxinidir}/ethpm/ {toxinidir}/tests/ - mypy -p web3.providers.eth_tester -p web3.main -p web3.contract -p web3.datastructures -p web3.eth -p web3.exceptions -p web3.geth -p web3.iban -p web3.logs -p web3.manager -p web3.module -p web3.net -p web3.parity -p web3.middleware -p web3.pm -p web3.auto -p web3.gas_strategies -p web3.testing -p web3.tools -p web3.version -p ethpm -p ens --config-file {toxinidir}/mypy.ini + mypy -p web3.providers -p web3.main -p web3.contract -p web3.datastructures -p web3.eth -p web3.exceptions -p web3.geth -p web3.iban -p web3.logs -p web3.manager -p web3.module -p web3.net -p web3.parity -p web3.middleware -p web3.pm -p web3.auto -p web3.gas_strategies -p web3.testing -p web3.tools -p web3.version -p ethpm -p ens --config-file {toxinidir}/mypy.ini diff --git a/web3/main.py b/web3/main.py index a811cfb77b..5146272e18 100644 --- a/web3/main.py +++ b/web3/main.py @@ -47,9 +47,6 @@ from web3._utils.normalizers import ( abi_ens_resolver, ) -from web3.datastructures import ( - NamedElementOnion, -) from web3.eth import ( Eth, ) @@ -93,8 +90,9 @@ from web3.testing import ( Testing, ) -from web3.types import ( +from web3.types import ( # noqa: F401 Middleware, + MiddlewareOnion, ) from web3.version import ( Version, @@ -175,7 +173,7 @@ def __init__( self.ens = ens @property - def middleware_onion(self) -> NamedElementOnion[str, Middleware]: + def middleware_onion(self) -> MiddlewareOnion: return self.manager.middleware_onion @property diff --git a/web3/manager.py b/web3/manager.py index 7629277a55..3dcf5f06a0 100644 --- a/web3/manager.py +++ b/web3/manager.py @@ -6,6 +6,7 @@ Dict, List, NoReturn, + Optional, Sequence, Tuple, ) @@ -40,8 +41,9 @@ AutoProvider, BaseProvider, ) -from web3.types import ( +from web3.types import ( # noqa: F401 Middleware, + MiddlewareOnion, RPCResponse, ) @@ -74,7 +76,7 @@ def __init__( if middlewares is None: middlewares = self.default_middlewares(web3) - self.middleware_onion: NamedElementOnion[str, Middleware] = NamedElementOnion(middlewares) + self.middleware_onion: MiddlewareOnion = NamedElementOnion(middlewares) if provider is None: self.provider = AutoProvider() @@ -117,14 +119,14 @@ def default_middlewares( def _make_request(self, method: str, params: Any) -> RPCResponse: request_func = self.provider.request_func( self.web3, - tuple(self.middleware_onion)) + self.middleware_onion) self.logger.debug("Making request. Method: %s", method) return request_func(method, params) async def _coro_make_request(self, method: str, params: Any) -> RPCResponse: request_func = self.provider.request_func( self.web3, - tuple(self.middleware_onion)) + self.middleware_onion) self.logger.debug("Making request. Method: %s", method) return await request_func(method, params) diff --git a/web3/providers/auto.py b/web3/providers/auto.py index 26c9ca3ab6..9424469bc2 100644 --- a/web3/providers/auto.py +++ b/web3/providers/auto.py @@ -1,8 +1,22 @@ import os +from typing import ( + Any, + Callable, + Dict, + Optional, + Sequence, + Tuple, + Type, + Union, +) from urllib.parse import ( urlparse, ) +from eth_typing import ( + URI, +) + from web3.exceptions import ( CannotHandleRequest, ) @@ -12,20 +26,26 @@ IPCProvider, WebsocketProvider, ) +from web3.types import ( + RPCEndpoint, + RPCResponse, +) HTTP_SCHEMES = {'http', 'https'} WS_SCHEMES = {'ws', 'wss'} -def load_provider_from_environment(): - uri_string = os.environ.get('WEB3_PROVIDER_URI', '') +def load_provider_from_environment() -> BaseProvider: + uri_string = URI(os.environ.get('WEB3_PROVIDER_URI', '')) if not uri_string: return None return load_provider_from_uri(uri_string) -def load_provider_from_uri(uri_string, headers=None): +def load_provider_from_uri( + uri_string: URI, headers: Dict[str, Tuple[str, str]]=None +) -> BaseProvider: uri = urlparse(uri_string) if uri.scheme == 'file': return IPCProvider(uri.path) @@ -52,7 +72,10 @@ class AutoProvider(BaseProvider): ) _active_provider = None - def __init__(self, potential_providers=None): + def __init__( + self, + potential_providers: Sequence[Union[Callable[..., BaseProvider], Type[BaseProvider]]]=None + ) -> None: """ :param iterable potential_providers: ordered series of provider classes to attempt with @@ -65,17 +88,17 @@ def __init__(self, potential_providers=None): else: self._potential_providers = self.default_providers - def make_request(self, method, params): + def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: try: return self._proxy_request(method, params) except IOError as exc: return self._proxy_request(method, params, use_cache=False) - def isConnected(self): + def isConnected(self) -> bool: provider = self._get_active_provider(use_cache=True) return provider is not None and provider.isConnected() - def _proxy_request(self, method, params, use_cache=True): + def _proxy_request(self, method: RPCEndpoint, params: Any, use_cache: bool=True) -> RPCResponse: provider = self._get_active_provider(use_cache) if provider is None: raise CannotHandleRequest( @@ -87,7 +110,7 @@ def _proxy_request(self, method, params, use_cache=True): return provider.make_request(method, params) - def _get_active_provider(self, use_cache): + def _get_active_provider(self, use_cache: bool) -> Optional[BaseProvider]: if use_cache and self._active_provider is not None: return self._active_provider diff --git a/web3/providers/base.py b/web3/providers/base.py index 0101e8c48d..c00c0462f3 100644 --- a/web3/providers/base.py +++ b/web3/providers/base.py @@ -1,4 +1,11 @@ import itertools +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Sequence, + Tuple, +) from eth_utils import ( to_bytes, @@ -11,26 +18,42 @@ from web3.middleware import ( combine_middlewares, ) +from web3.types import ( + Middleware, + MiddlewareOnion, + RPCEndpoint, + RPCResponse, +) + +if TYPE_CHECKING: + from web3 import Web3 # noqa: F401 class BaseProvider: - _middlewares = () - _request_func_cache = (None, None) # a tuple of (all_middlewares, request_func) + _middlewares: Tuple[Middleware, ...] = () + # a tuple of (all_middlewares, request_func) + _request_func_cache: Tuple[Tuple[Middleware, ...], Callable[..., RPCResponse]] = (None, None) @property - def middlewares(self): + def middlewares(self) -> Tuple[Middleware, ...]: return self._middlewares @middlewares.setter - def middlewares(self, values): - self._middlewares = tuple(values) - - def request_func(self, web3, outer_middlewares): + def middlewares( + self, values: MiddlewareOnion + ) -> None: + # tuple(values) converts to MiddlewareOnion -> Tuple[Middleware, ...] + self._middlewares = tuple(values) # type: ignore + + def request_func( + self, web3: "Web3", outer_middlewares: MiddlewareOnion + ) -> Callable[..., RPCResponse]: """ @param outer_middlewares is an iterable of middlewares, ordered by first to execute @returns a function that calls all the middleware and eventually self.make_request() """ - all_middlewares = tuple(outer_middlewares) + tuple(self.middlewares) + # type ignored b/c tuple(MiddlewareOnion) converts to tuple of middlewares + all_middlewares: Tuple[Middleware] = tuple(outer_middlewares) + tuple(self.middlewares) # type: ignore # noqa: E501 cache_key = self._request_func_cache[0] if cache_key is None or cache_key != all_middlewares: @@ -40,29 +63,31 @@ def request_func(self, web3, outer_middlewares): ) return self._request_func_cache[-1] - def _generate_request_func(self, web3, middlewares): + def _generate_request_func( + self, web3: "Web3", middlewares: Sequence[Middleware] + ) -> Callable[..., RPCResponse]: return combine_middlewares( middlewares=middlewares, web3=web3, provider_request_fn=self.make_request, ) - def make_request(self, method, params): + def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: raise NotImplementedError("Providers must implement this method") - def isConnected(self): + def isConnected(self) -> bool: raise NotImplementedError("Providers must implement this method") class JSONBaseProvider(BaseProvider): - def __init__(self): + def __init__(self) -> None: self.request_counter = itertools.count() - def decode_rpc_response(self, response): - text_response = to_text(response) + def decode_rpc_response(self, raw_response: bytes) -> RPCResponse: + text_response = to_text(raw_response) return FriendlyJsonSerde().json_decode(text_response) - def encode_rpc_request(self, method, params): + def encode_rpc_request(self, method: RPCEndpoint, params: Any) -> bytes: rpc_dict = { "jsonrpc": "2.0", "method": method, @@ -72,9 +97,9 @@ def encode_rpc_request(self, method, params): encoded = FriendlyJsonSerde().json_encode(rpc_dict) return to_bytes(text=encoded) - def isConnected(self): + def isConnected(self) -> bool: try: - response = self.make_request('web3_clientVersion', []) + response = self.make_request(RPCEndpoint('web3_clientVersion'), []) except IOError: return False diff --git a/web3/providers/eth_tester/main.py b/web3/providers/eth_tester/main.py index 8337d9045c..da8665c8ae 100644 --- a/web3/providers/eth_tester/main.py +++ b/web3/providers/eth_tester/main.py @@ -44,10 +44,10 @@ async def make_request( class EthereumTesterProvider(BaseProvider): - middlewares = [ + middlewares = ( default_transaction_fields_middleware, ethereum_tester_middleware, - ] + ) ethereum_tester = None api_endpoints = None diff --git a/web3/providers/ipc.py b/web3/providers/ipc.py index f77623e9ce..6fc45ebfd7 100644 --- a/web3/providers/ipc.py +++ b/web3/providers/ipc.py @@ -1,3 +1,6 @@ +from json import ( + JSONDecodeError, +) import logging import os from pathlib import ( @@ -6,22 +9,28 @@ import socket import sys import threading +from types import ( + TracebackType, +) +from typing import ( + Any, + Type, +) from web3._utils.threads import ( Timeout, ) +from web3.types import ( + RPCEndpoint, + RPCResponse, +) from .base import ( JSONBaseProvider, ) -try: - from json import JSONDecodeError -except ImportError: - JSONDecodeError = ValueError - -def get_ipc_socket(ipc_path, timeout=0.1): +def get_ipc_socket(ipc_path: str, timeout: float=0.1) -> socket.socket: if sys.platform == 'win32': # On Windows named pipe is used. Simulate socket with it. from web3._utils.windows import NamedPipe @@ -37,10 +46,10 @@ def get_ipc_socket(ipc_path, timeout=0.1): class PersistantSocket: sock = None - def __init__(self, ipc_path): + def __init__(self, ipc_path: str) -> None: self.ipc_path = ipc_path - def __enter__(self): + def __enter__(self) -> socket.socket: if not self.ipc_path: raise FileNotFoundError("cannot connect to IPC socket at path: %r" % self.ipc_path) @@ -48,7 +57,9 @@ def __enter__(self): self.sock = self._open() return self.sock - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, exc_type: Type[BaseException], exc_value: BaseException, traceback: TracebackType + ) -> None: # only close the socket if there was an error if exc_value is not None: try: @@ -57,16 +68,17 @@ def __exit__(self, exc_type, exc_value, traceback): pass self.sock = None - def _open(self): + def _open(self) -> socket.socket: return get_ipc_socket(self.ipc_path) - def reset(self): + def reset(self) -> socket.socket: self.sock.close() self.sock = self._open() return self.sock -def get_default_ipc_path(): +# type ignored b/c missing return statement is by design here +def get_default_ipc_path() -> str: # type: ignore if sys.platform == 'darwin': ipc_path = os.path.expanduser(os.path.join( "~", @@ -88,8 +100,8 @@ def get_default_ipc_path(): return ipc_path base_trinity_path = Path('~').expanduser() / '.local' / 'share' / 'trinity' - ipc_path = base_trinity_path / 'mainnet' / 'jsonrpc.ipc' - if ipc_path.exists(): + ipc_path = str(base_trinity_path / 'mainnet' / 'jsonrpc.ipc') + if Path(ipc_path).exists(): return str(ipc_path) elif sys.platform.startswith('linux') or sys.platform.startswith('freebsd'): @@ -112,8 +124,9 @@ def get_default_ipc_path(): return ipc_path base_trinity_path = Path('~').expanduser() / '.local' / 'share' / 'trinity' - ipc_path = base_trinity_path / 'mainnet' / 'jsonrpc.ipc' - if ipc_path.exists(): + # type ignored b/c ipc_path is already defined as a str above + ipc_path = base_trinity_path / 'mainnet' / 'jsonrpc.ipc' # type: ignore + if ipc_path.exists(): # type: ignore return str(ipc_path) elif sys.platform == 'win32': @@ -142,7 +155,8 @@ def get_default_ipc_path(): ) -def get_dev_ipc_path(): +# type ignored b/c missing return statement is by design here +def get_dev_ipc_path() -> str: # type: ignore if sys.platform == 'darwin': tmpdir = os.environ.get('TMPDIR', '') ipc_path = os.path.expanduser(os.path.join( @@ -190,7 +204,7 @@ class IPCProvider(JSONBaseProvider): logger = logging.getLogger("web3.providers.IPCProvider") _socket = None - def __init__(self, ipc_path=None, timeout=10, *args, **kwargs): + def __init__(self, ipc_path: str=None, timeout: int=10, *args: Any, **kwargs: Any) -> None: if ipc_path is None: self.ipc_path = get_default_ipc_path() elif isinstance(ipc_path, str) or isinstance(ipc_path, Path): @@ -201,9 +215,9 @@ def __init__(self, ipc_path=None, timeout=10, *args, **kwargs): self.timeout = timeout self._lock = threading.Lock() self._socket = PersistantSocket(self.ipc_path) - super().__init__(*args, **kwargs) + super().__init__() - def make_request(self, method, params): + def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: self.logger.debug("Making request IPC. Path: %s, Method: %s", self.ipc_path, method) request = self.encode_rpc_request(method, params) @@ -240,7 +254,7 @@ def make_request(self, method, params): # A valid JSON RPC response can only end in } or ] http://www.jsonrpc.org/specification -def has_valid_json_rpc_ending(raw_response): +def has_valid_json_rpc_ending(raw_response: bytes) -> bool: stripped_raw_response = raw_response.rstrip() for valid_ending in [b"}", b"]"]: if stripped_raw_response.endswith(valid_ending): diff --git a/web3/providers/rpc.py b/web3/providers/rpc.py index b939f2a9a0..8ce7f2e612 100644 --- a/web3/providers/rpc.py +++ b/web3/providers/rpc.py @@ -1,6 +1,15 @@ import logging import os +from typing import ( + Any, + Dict, + Iterable, + Tuple, +) +from eth_typing import ( + URI, +) from eth_utils import ( to_dict, ) @@ -17,14 +26,19 @@ from web3.middleware import ( http_retry_request_middleware, ) +from web3.types import ( + Middleware, + RPCEndpoint, + RPCResponse, +) from .base import ( JSONBaseProvider, ) -def get_default_endpoint(): - return os.environ.get('WEB3_HTTP_PROVIDER_URI', 'http://localhost:8545') +def get_default_endpoint() -> URI: + return URI(os.environ.get('WEB3_HTTP_PROVIDER_URI', 'http://localhost:8545')) class HTTPProvider(JSONBaseProvider): @@ -32,9 +46,10 @@ class HTTPProvider(JSONBaseProvider): endpoint_uri = None _request_args = None _request_kwargs = None - _middlewares = NamedElementOnion([(http_retry_request_middleware, 'http_retry_request')]) + # type ignored b/c conflict with _middlewares attr on BaseProvider + _middlewares: Tuple[Middleware, ...] = NamedElementOnion([(http_retry_request_middleware, 'http_retry_request')]) # type: ignore # noqa: E501 - def __init__(self, endpoint_uri=None, request_kwargs=None): + def __init__(self, endpoint_uri: URI=None, request_kwargs: Any=None) -> None: if endpoint_uri is None: self.endpoint_uri = get_default_endpoint() else: @@ -42,23 +57,23 @@ def __init__(self, endpoint_uri=None, request_kwargs=None): self._request_kwargs = request_kwargs or {} super().__init__() - def __str__(self): + def __str__(self) -> str: return "RPC connection {0}".format(self.endpoint_uri) @to_dict - def get_request_kwargs(self): + def get_request_kwargs(self) -> Iterable[Tuple[str, Any]]: if 'headers' not in self._request_kwargs: yield 'headers', self.get_request_headers() for key, value in self._request_kwargs.items(): yield key, value - def get_request_headers(self): + def get_request_headers(self) -> Dict[str, str]: return { 'Content-Type': 'application/json', 'User-Agent': construct_user_agent(str(type(self))), } - def make_request(self, method, params): + def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: self.logger.debug("Making request HTTP. URI: %s, Method: %s", self.endpoint_uri, method) request_data = self.encode_rpc_request(method, params) diff --git a/web3/providers/websocket.py b/web3/providers/websocket.py index a5df377e38..8d4b061bbc 100644 --- a/web3/providers/websocket.py +++ b/web3/providers/websocket.py @@ -5,7 +5,17 @@ from threading import ( Thread, ) +from types import ( + TracebackType, +) +from typing import ( + Any, + Type, +) +from eth_typing import ( + URI, +) import websockets from web3.exceptions import ( @@ -14,44 +24,52 @@ from web3.providers.base import ( JSONBaseProvider, ) +from web3.types import ( + RPCEndpoint, + RPCResponse, +) RESTRICTED_WEBSOCKET_KWARGS = {'uri', 'loop'} DEFAULT_WEBSOCKET_TIMEOUT = 10 -def _start_event_loop(loop): +def _start_event_loop(loop: asyncio.AbstractEventLoop) -> None: asyncio.set_event_loop(loop) loop.run_forever() loop.close() -def _get_threaded_loop(): +def _get_threaded_loop() -> asyncio.AbstractEventLoop: new_loop = asyncio.new_event_loop() thread_loop = Thread(target=_start_event_loop, args=(new_loop,), daemon=True) thread_loop.start() return new_loop -def get_default_endpoint(): - return os.environ.get('WEB3_WS_PROVIDER_URI', 'ws://127.0.0.1:8546') +def get_default_endpoint() -> URI: + return URI(os.environ.get('WEB3_WS_PROVIDER_URI', 'ws://127.0.0.1:8546')) class PersistentWebSocket: - def __init__(self, endpoint_uri, loop, websocket_kwargs): - self.ws = None + def __init__( + self, endpoint_uri: URI, loop: asyncio.AbstractEventLoop, websocket_kwargs: Any + ) -> None: + self.ws: websockets.WebSocketClientProtocol = None self.endpoint_uri = endpoint_uri self.loop = loop self.websocket_kwargs = websocket_kwargs - async def __aenter__(self): + async def __aenter__(self) -> websockets.WebSocketClientProtocol: if self.ws is None: self.ws = await websockets.connect( uri=self.endpoint_uri, loop=self.loop, **self.websocket_kwargs ) return self.ws - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__( + self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType + ) -> None: if exc_val is not None: try: await self.ws.close() @@ -66,10 +84,10 @@ class WebsocketProvider(JSONBaseProvider): def __init__( self, - endpoint_uri=None, - websocket_kwargs=None, - websocket_timeout=DEFAULT_WEBSOCKET_TIMEOUT - ): + endpoint_uri: URI=None, + websocket_kwargs: Any=None, + websocket_timeout: int=DEFAULT_WEBSOCKET_TIMEOUT, + ) -> None: self.endpoint_uri = endpoint_uri self.websocket_timeout = websocket_timeout if self.endpoint_uri is None: @@ -92,10 +110,10 @@ def __init__( ) super().__init__() - def __str__(self): + def __str__(self) -> str: return "WS connection {0}".format(self.endpoint_uri) - async def coro_make_request(self, request_data): + async def coro_make_request(self, request_data: bytes) -> RPCResponse: async with self.conn as conn: await asyncio.wait_for( conn.send(request_data), @@ -108,7 +126,7 @@ async def coro_make_request(self, request_data): ) ) - def make_request(self, method, params): + def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: self.logger.debug("Making request WebSocket. URI: %s, " "Method: %s", self.endpoint_uri, method) request_data = self.encode_rpc_request(method, params) diff --git a/web3/types.py b/web3/types.py index 3756ac341a..4dba5c18c9 100644 --- a/web3/types.py +++ b/web3/types.py @@ -26,6 +26,9 @@ from web3._utils.compat import ( TypedDict, ) +from web3.datastructures import ( + NamedElementOnion, +) Wei = NewType('Wei', int) @@ -147,6 +150,7 @@ GasPriceStrategy = Callable[[Any, TxParams], Wei] # 2 input to parent callable Any should be updated to Web3 once all type hints land Middleware = Callable[[Callable[[RPCEndpoint, Any], RPCResponse], Any], Any] +MiddlewareOnion = NamedElementOnion[str, Middleware] LogReceipt = TypedDict("LogReceipt", {