Skip to content

Commit

Permalink
Add type hints to web3.providers
Browse files Browse the repository at this point in the history
  • Loading branch information
njgheorghita committed Nov 22, 2019
1 parent 0c79503 commit be7acaf
Show file tree
Hide file tree
Showing 12 changed files with 175 additions and 79 deletions.
3 changes: 2 additions & 1 deletion ens/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from eth_typing import (
HexAddress,
HexStr,
)
from hexbytes import (
HexBytes,
Expand All @@ -11,6 +12,6 @@
AUCTION_START_GAS_MARGINAL = 39000

EMPTY_SHA3_BYTES = HexBytes(b'\0' * 32)
EMPTY_ADDR_HEX = HexAddress('0x' + '00' * 20)
EMPTY_ADDR_HEX = HexAddress(HexStr('0x' + '00' * 20))

REVERSE_REGISTRAR_DOMAIN = 'addr.reverse'
5 changes: 4 additions & 1 deletion ens/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ChecksumAddress,
Hash32,
HexAddress,
HexStr,
)
from eth_utils import (
is_binary_address,
Expand Down Expand Up @@ -56,7 +57,9 @@
)


ENS_MAINNET_ADDR = ChecksumAddress(HexAddress('0x314159265dD8dbb310642f98f50C066173C1259b'))
ENS_MAINNET_ADDR = ChecksumAddress(
HexAddress(HexStr('0x314159265dD8dbb310642f98f50C066173C1259b'))
)


class ENS:
Expand Down
8 changes: 3 additions & 5 deletions web3/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@
from web3._utils.normalizers import (
abi_ens_resolver,
)
from web3.datastructures import (
NamedElementOnion,
)
from web3.eth import (
Eth,
)
Expand Down Expand Up @@ -93,8 +90,9 @@
from web3.testing import (
Testing,
)
from web3.types import (
from web3.types import ( # noqa: F401
Middleware,
MiddlewareOnion,
)
from web3.version import (
Version,
Expand Down Expand Up @@ -175,7 +173,7 @@ def __init__(
self.ens = ens

@property
def middleware_onion(self) -> NamedElementOnion[str, Middleware]:
def middleware_onion(self) -> MiddlewareOnion:
return self.manager.middleware_onion

@property
Expand Down
10 changes: 6 additions & 4 deletions web3/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Dict,
List,
NoReturn,
Optional,
Sequence,
Tuple,
)
Expand Down Expand Up @@ -40,8 +41,9 @@
AutoProvider,
BaseProvider,
)
from web3.types import (
from web3.types import ( # noqa: F401
Middleware,
MiddlewareOnion,
RPCResponse,
)

Expand Down Expand Up @@ -74,7 +76,7 @@ def __init__(
if middlewares is None:
middlewares = self.default_middlewares(web3)

self.middleware_onion: NamedElementOnion[str, Middleware] = NamedElementOnion(middlewares)
self.middleware_onion: MiddlewareOnion = NamedElementOnion(middlewares)

if provider is None:
self.provider = AutoProvider()
Expand Down Expand Up @@ -117,14 +119,14 @@ def default_middlewares(
def _make_request(self, method: str, params: Any) -> RPCResponse:
request_func = self.provider.request_func(
self.web3,
tuple(self.middleware_onion))
self.middleware_onion)
self.logger.debug("Making request. Method: %s", method)
return request_func(method, params)

async def _coro_make_request(self, method: str, params: Any) -> RPCResponse:
request_func = self.provider.request_func(
self.web3,
tuple(self.middleware_onion))
self.middleware_onion)
self.logger.debug("Making request. Method: %s", method)
return await request_func(method, params)

Expand Down
39 changes: 31 additions & 8 deletions web3/providers/auto.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
import os
from typing import (
Any,
Callable,
Dict,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from urllib.parse import (
urlparse,
)

from eth_typing import (
URI,
)

from web3.exceptions import (
CannotHandleRequest,
)
Expand All @@ -12,20 +26,26 @@
IPCProvider,
WebsocketProvider,
)
from web3.types import (
RPCEndpoint,
RPCResponse,
)

HTTP_SCHEMES = {'http', 'https'}
WS_SCHEMES = {'ws', 'wss'}


def load_provider_from_environment():
uri_string = os.environ.get('WEB3_PROVIDER_URI', '')
def load_provider_from_environment() -> BaseProvider:
uri_string = URI(os.environ.get('WEB3_PROVIDER_URI', ''))
if not uri_string:
return None

return load_provider_from_uri(uri_string)


def load_provider_from_uri(uri_string, headers=None):
def load_provider_from_uri(
uri_string: URI, headers: Dict[str, Tuple[str, str]]=None
) -> BaseProvider:
uri = urlparse(uri_string)
if uri.scheme == 'file':
return IPCProvider(uri.path)
Expand All @@ -52,7 +72,10 @@ class AutoProvider(BaseProvider):
)
_active_provider = None

def __init__(self, potential_providers=None):
def __init__(
self,
potential_providers: Sequence[Union[Callable[..., BaseProvider], Type[BaseProvider]]]=None
) -> None:
"""
:param iterable potential_providers: ordered series of provider classes to attempt with
Expand All @@ -65,17 +88,17 @@ def __init__(self, potential_providers=None):
else:
self._potential_providers = self.default_providers

def make_request(self, method, params):
def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:
try:
return self._proxy_request(method, params)
except IOError as exc:
return self._proxy_request(method, params, use_cache=False)

def isConnected(self):
def isConnected(self) -> bool:
provider = self._get_active_provider(use_cache=True)
return provider is not None and provider.isConnected()

def _proxy_request(self, method, params, use_cache=True):
def _proxy_request(self, method: RPCEndpoint, params: Any, use_cache: bool=True) -> RPCResponse:
provider = self._get_active_provider(use_cache)
if provider is None:
raise CannotHandleRequest(
Expand All @@ -87,7 +110,7 @@ def _proxy_request(self, method, params, use_cache=True):

return provider.make_request(method, params)

def _get_active_provider(self, use_cache):
def _get_active_provider(self, use_cache: bool) -> Optional[BaseProvider]:
if use_cache and self._active_provider is not None:
return self._active_provider

Expand Down
59 changes: 42 additions & 17 deletions web3/providers/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import itertools
from typing import (
TYPE_CHECKING,
Any,
Callable,
Sequence,
Tuple,
)

from eth_utils import (
to_bytes,
Expand All @@ -11,26 +18,42 @@
from web3.middleware import (
combine_middlewares,
)
from web3.types import (
Middleware,
MiddlewareOnion,
RPCEndpoint,
RPCResponse,
)

if TYPE_CHECKING:
from web3 import Web3 # noqa: F401


class BaseProvider:
_middlewares = ()
_request_func_cache = (None, None) # a tuple of (all_middlewares, request_func)
_middlewares: Tuple[Middleware, ...] = ()
# a tuple of (all_middlewares, request_func)
_request_func_cache: Tuple[Tuple[Middleware, ...], Callable[..., RPCResponse]] = (None, None)

@property
def middlewares(self):
def middlewares(self) -> Tuple[Middleware, ...]:
return self._middlewares

@middlewares.setter
def middlewares(self, values):
self._middlewares = tuple(values)

def request_func(self, web3, outer_middlewares):
def middlewares(
self, values: MiddlewareOnion
) -> None:
# tuple(values) converts to MiddlewareOnion -> Tuple[Middleware, ...]
self._middlewares = tuple(values) # type: ignore

def request_func(
self, web3: "Web3", outer_middlewares: MiddlewareOnion
) -> Callable[..., RPCResponse]:
"""
@param outer_middlewares is an iterable of middlewares, ordered by first to execute
@returns a function that calls all the middleware and eventually self.make_request()
"""
all_middlewares = tuple(outer_middlewares) + tuple(self.middlewares)
# type ignored b/c tuple(MiddlewareOnion) converts to tuple of middlewares
all_middlewares: Tuple[Middleware] = tuple(outer_middlewares) + tuple(self.middlewares) # type: ignore # noqa: E501

cache_key = self._request_func_cache[0]
if cache_key is None or cache_key != all_middlewares:
Expand All @@ -40,29 +63,31 @@ def request_func(self, web3, outer_middlewares):
)
return self._request_func_cache[-1]

def _generate_request_func(self, web3, middlewares):
def _generate_request_func(
self, web3: "Web3", middlewares: Sequence[Middleware]
) -> Callable[..., RPCResponse]:
return combine_middlewares(
middlewares=middlewares,
web3=web3,
provider_request_fn=self.make_request,
)

def make_request(self, method, params):
def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:
raise NotImplementedError("Providers must implement this method")

def isConnected(self):
def isConnected(self) -> bool:
raise NotImplementedError("Providers must implement this method")


class JSONBaseProvider(BaseProvider):
def __init__(self):
def __init__(self) -> None:
self.request_counter = itertools.count()

def decode_rpc_response(self, response):
text_response = to_text(response)
def decode_rpc_response(self, raw_response: bytes) -> RPCResponse:
text_response = to_text(raw_response)
return FriendlyJsonSerde().json_decode(text_response)

def encode_rpc_request(self, method, params):
def encode_rpc_request(self, method: RPCEndpoint, params: Any) -> bytes:
rpc_dict = {
"jsonrpc": "2.0",
"method": method,
Expand All @@ -72,9 +97,9 @@ def encode_rpc_request(self, method, params):
encoded = FriendlyJsonSerde().json_encode(rpc_dict)
return to_bytes(text=encoded)

def isConnected(self):
def isConnected(self) -> bool:
try:
response = self.make_request('web3_clientVersion', [])
response = self.make_request(RPCEndpoint('web3_clientVersion'), [])
except IOError:
return False

Expand Down
1 change: 0 additions & 1 deletion web3/providers/eth_tester/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def not_implemented(*args: Any, **kwargs: Any) -> NoReturn:
raise NotImplementedError("RPC method not implemented")


# double check RPCResponse
@curry
def call_eth_tester(
fn_name: str, eth_tester: "EthereumTester", fn_args: Any, fn_kwargs: Any=None
Expand Down
4 changes: 2 additions & 2 deletions web3/providers/eth_tester/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ async def make_request(


class EthereumTesterProvider(BaseProvider):
middlewares = [
middlewares = (
default_transaction_fields_middleware,
ethereum_tester_middleware,
]
)
ethereum_tester = None
api_endpoints = None

Expand Down
Loading

0 comments on commit be7acaf

Please sign in to comment.