Skip to content

Commit

Permalink
Added caching to the async session in request.py and AsyncHTTPProvider (
Browse files Browse the repository at this point in the history
#2254)

* Added caching to the async session in request.py and AsyncHTTPProvider
  • Loading branch information
dbfreem authored Feb 24, 2022
1 parent 313d919 commit eed1ae4
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 7 deletions.
7 changes: 6 additions & 1 deletion docs/providers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,11 @@ AsyncHTTPProvider
be omitted from the URI.
* ``request_kwargs`` should be a dictionary of keyword arguments which
will be passed onto each http/https POST request made to your node.
* the ``cache_async_session()`` method allows you to use your own ``aiohttp.ClientSession`` object. This is an async method and not part of the constructor

.. code-block:: python
>>> from aiohttp import ClientSession
>>> from web3 import Web3, AsyncHTTPProvider
>>> from web3.eth import AsyncEth
>>> from web3.net import AsyncNet
Expand All @@ -396,7 +398,10 @@ AsyncHTTPProvider
... 'personal': (AsyncGethPersonal,),
... 'admin' : (AsyncGethAdmin,)})
... },
... middlewares=[]) # See supported middleware section below for middleware options
... middlewares=[] # See supported middleware section below for middleware options
... )
>>> custom_session = ClientSession() # If you want to pass in your own session
>>> await w3.provider.cache_async_session(custom_session) # This method is an async method so it needs to be handled accordingly
Under the hood, the ``AsyncHTTPProvider`` uses the python
`aiohttp <https://docs.aiohttp.org/en/stable/>`_ library for making requests.
Expand Down
1 change: 1 addition & 0 deletions newsfragments/2016.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added session caching to the AsyncHTTPProvider
22 changes: 22 additions & 0 deletions tests/core/providers/test_async_http_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

import pytest

from aiohttp import (
ClientSession,
)

from web3._utils import (
request,
)
from web3.providers.async_rpc import (
AsyncHTTPProvider,
)


@pytest.mark.asyncio
async def test_user_provided_session() -> None:

session = ClientSession()
provider = AsyncHTTPProvider(endpoint_uri="http://mynode.local:8545")
await provider.cache_async_session(session)
assert len(request._async_session_cache) == 1
52 changes: 52 additions & 0 deletions tests/core/utilities/test_request.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import pytest

from aiohttp import (
ClientSession,
)
from requests import (
Session,
adapters,
Expand All @@ -10,6 +15,9 @@
from web3._utils import (
request,
)
from web3._utils.request import (
SessionCache,
)


class MockedResponse:
Expand Down Expand Up @@ -80,3 +88,47 @@ def test_precached_session(mocker):
assert isinstance(adapter, HTTPAdapter)
assert adapter._pool_connections == 100
assert adapter._pool_maxsize == 100


@pytest.mark.asyncio
async def test_async_precached_session(mocker):
# Add a session
session = ClientSession()
await request.cache_async_session(URI, session)
assert len(request._async_session_cache) == 1

# Make sure the session isn't duplicated
await request.cache_async_session(URI, session)
assert len(request._async_session_cache) == 1

# Make sure a request with a different URI adds another cached session
await request.cache_async_session(f"{URI}/test", session)
assert len(request._async_session_cache) == 2


def test_cache_session_class():

cache = SessionCache(2)
evicted_items = cache.cache("1", "Hello1")
assert cache.get_cache_entry("1") == "Hello1"
assert evicted_items is None

evicted_items = cache.cache("2", "Hello2")
assert cache.get_cache_entry("2") == "Hello2"
assert evicted_items is None

# Changing what is stored at a given cache key should not cause the
# anything to be evicted
evicted_items = cache.cache("1", "HelloChanged")
assert cache.get_cache_entry("1") == "HelloChanged"
assert evicted_items is None

evicted_items = cache.cache("3", "Hello3")
assert "2" in cache
assert "3" in cache
assert "1" not in cache

with pytest.raises(KeyError):
# This should throw a KeyError since the cache size was 2 and 3 were inserted
# the first inserted cached item was removed and returned in evicted items
cache.get_cache_entry("1")
67 changes: 61 additions & 6 deletions web3/_utils/request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from collections import (
OrderedDict,
)
import os
import threading
from typing import (
Any,
Dict,
)

from aiohttp import (
Expand All @@ -18,6 +23,37 @@
)


class SessionCache:

def __init__(self, size: int):
self._size = size
self._data: OrderedDict[str, Any] = OrderedDict()

def cache(self, key: str, value: Any) -> Dict[str, Any]:
evicted_items = None
# If the key is already in the OrderedDict just update it
# and don't evict any values. Ideally, we could still check to see
# if there are too many items in the OrderedDict but that may rearrange
# the order it should be unlikely that the size could grow over the limit
if key not in self._data:
while len(self._data) >= self._size:
if evicted_items is None:
evicted_items = {}
k, v = self._data.popitem(last=False)
evicted_items[k] = v
self._data[key] = value
return evicted_items

def get_cache_entry(self, key: str) -> Any:
return self._data[key]

def __contains__(self, item: str) -> bool:
return item in self._data

def __len__(self) -> int:
return len(self._data)


def get_default_http_endpoint() -> URI:
return URI(os.environ.get('WEB3_HTTP_PROVIDER_URI', 'http://localhost:8545'))

Expand All @@ -27,11 +63,22 @@ def cache_session(endpoint_uri: URI, session: requests.Session) -> None:
_session_cache[cache_key] = session


async def cache_async_session(endpoint_uri: URI, session: ClientSession) -> None:
cache_key = generate_cache_key(endpoint_uri)
with _async_session_cache_lock:
evicted_items = _async_session_cache.cache(cache_key, session)
if evicted_items is not None:
for key, session in evicted_items.items():
await session.close()


def _remove_session(key: str, session: requests.Session) -> None:
session.close()


_session_cache = lru.LRU(8, callback=_remove_session)
_async_session_cache_lock = threading.Lock()
_async_session_cache = SessionCache(size=8)


def _get_session(endpoint_uri: URI) -> requests.Session:
Expand All @@ -41,6 +88,13 @@ def _get_session(endpoint_uri: URI) -> requests.Session:
return _session_cache[cache_key]


async def _get_async_session(endpoint_uri: URI) -> ClientSession:
cache_key = generate_cache_key(endpoint_uri)
if cache_key not in _async_session_cache:
await cache_async_session(endpoint_uri, ClientSession(raise_for_status=True))
return _async_session_cache.get_cache_entry(cache_key)


def make_post_request(endpoint_uri: URI, data: bytes, *args: Any, **kwargs: Any) -> bytes:
kwargs.setdefault('timeout', 10)
session = _get_session(endpoint_uri)
Expand All @@ -55,9 +109,10 @@ async def async_make_post_request(
endpoint_uri: URI, data: bytes, *args: Any, **kwargs: Any
) -> bytes:
kwargs.setdefault('timeout', ClientTimeout(10))
async with ClientSession(raise_for_status=True) as session:
async with session.post(endpoint_uri,
data=data,
*args,
**kwargs) as response:
return await response.read()
# https://github.com/ethereum/go-ethereum/issues/17069
session = await _get_async_session(endpoint_uri)
async with session.post(endpoint_uri,
data=data,
*args,
**kwargs) as response:
return await response.read()
7 changes: 7 additions & 0 deletions web3/providers/async_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
Union,
)

from aiohttp import (
ClientSession,
)
from eth_typing import (
URI,
)
Expand All @@ -20,6 +23,7 @@
)
from web3._utils.request import (
async_make_post_request,
cache_async_session as _cache_async_session,
get_default_http_endpoint,
)
from web3.types import (
Expand Down Expand Up @@ -50,6 +54,9 @@ def __init__(

super().__init__()

async def cache_async_session(self, session: ClientSession) -> None:
await _cache_async_session(self.endpoint_uri, session)

def __str__(self) -> str:
return "RPC connection {0}".format(self.endpoint_uri)

Expand Down

0 comments on commit eed1ae4

Please sign in to comment.