Skip to content

Commit

Permalink
add response as param to get_timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
inyutin committed Aug 4, 2022
1 parent 0b4c116 commit 013cfbe
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 22 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class RetryOptionsBase:
...

@abc.abstractmethod
def get_timeout(self, attempt: int) -> float:
def get_timeout(self, attempt: int, response: Optional[Response] = None) -> float:
raise NotImplementedError

```
Expand All @@ -140,6 +140,10 @@ You can define your own timeouts logic or use:
- ```FibonacciRetry``` with backoff that looks like fibonacci sequence
- ```JitterRetry``` exponential retry with a bit of randomness

**Important**: you can server response as an parameter for calculating next timeout.
However this response can be None, server didn't make a response or you have set up ```raise_for_status=True```
Look here for an example: https://github.com/inyutin/aiohttp_retry/issues/59

#### Request Trace Context
`RetryClient` add *current attempt number* to `request_trace_ctx` (see examples,
for more info see [aiohttp doc](https://docs.aiohttp.org/en/stable/client_advanced.html#aiohttp-client-tracing)).
Expand Down
32 changes: 17 additions & 15 deletions aiohttp_retry/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ async def _do_request(self) -> ClientResponse:
current_attempt = 0
while True:
self._logger.debug(f"Attempt {current_attempt+1} out of {self._retry_options.attempts}")
if current_attempt > 0:
retry_wait = self._retry_options.get_timeout(current_attempt)
await asyncio.sleep(retry_wait)

current_attempt += 1
try:
Expand All @@ -80,22 +77,27 @@ async def _do_request(self) -> ClientResponse:
**self._trace_request_ctx,
},
)

if self._is_status_code_ok(response.status) or current_attempt == self._retry_options.attempts:
if self._raise_for_status:
response.raise_for_status()
self._response = response
return response

self._logger.debug(f"Retrying after response code: {response.status}")
retry_wait = self._retry_options.get_timeout(attempt=current_attempt, response=response)
except Exception as e:
if current_attempt < self._retry_options.attempts:
is_exc_valid = any([isinstance(e, exc) for exc in self._retry_options.exceptions])
if is_exc_valid:
self._logger.debug(f"Retrying after exception: {repr(e)}")
continue
if current_attempt >= self._retry_options.attempts:
raise e

raise e
is_exc_valid = any([isinstance(e, exc) for exc in self._retry_options.exceptions])
if not is_exc_valid:
raise e

if self._is_status_code_ok(response.status) or current_attempt == self._retry_options.attempts:
if self._raise_for_status:
response.raise_for_status()
self._response = response
return response
self._logger.debug(f"Retrying after exception: {repr(e)}")
retry_wait = self._retry_options.get_timeout(attempt=current_attempt, response=None)

self._logger.debug(f"Retrying after response code: {response.status}")
await asyncio.sleep(retry_wait)

def __await__(self) -> Generator[Any, None, ClientResponse]:
return self.__aenter__().__await__()
Expand Down
14 changes: 8 additions & 6 deletions aiohttp_retry/retry_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import Any, Callable, Iterable, List, Optional, Set, Type
from warnings import warn

from aiohttp import ClientResponse


class RetryOptionsBase:
def __init__(
Expand All @@ -24,7 +26,7 @@ def __init__(
self.retry_all_server_errors = retry_all_server_errors

@abc.abstractmethod
def get_timeout(self, attempt: int) -> float:
def get_timeout(self, attempt: int, response: Optional[ClientResponse] = None) -> float:
raise NotImplementedError


Expand All @@ -45,7 +47,7 @@ def __init__(
self._max_timeout: float = max_timeout
self._factor: float = factor

def get_timeout(self, attempt: int) -> float:
def get_timeout(self, attempt: int, response: Optional[ClientResponse] = None) -> float:
"""Return timeout with exponential backoff."""
timeout = self._start_timeout * (self._factor ** attempt)
return min(timeout, self._max_timeout)
Expand Down Expand Up @@ -73,7 +75,7 @@ def __init__(
self.max_timeout: float = max_timeout
self.random = random_func

def get_timeout(self, attempt: int) -> float:
def get_timeout(self, attempt: int, response: Optional[ClientResponse] = None) -> float:
"""Generate random timeouts."""
return self.min_timeout + self.random() * (self.max_timeout - self.min_timeout)

Expand All @@ -89,7 +91,7 @@ def __init__(
super().__init__(len(timeouts), statuses, exceptions, retry_all_server_errors)
self.timeouts = timeouts

def get_timeout(self, attempt: int) -> float:
def get_timeout(self, attempt: int, response: Optional[ClientResponse] = None) -> float:
"""timeouts from a defined list."""
return self.timeouts[attempt]

Expand All @@ -111,7 +113,7 @@ def __init__(
self.prev_step = 1.0
self.current_step = 1.0

def get_timeout(self, attempt: int) -> float:
def get_timeout(self, attempt: int, response: Optional[ClientResponse] = None) -> float:
new_current_step = self.prev_step + self.current_step
self.prev_step = self.current_step
self.current_step = new_current_step
Expand Down Expand Up @@ -148,6 +150,6 @@ def __init__(
self._factor: float = factor
self._random_interval_size = random_interval_size

def get_timeout(self, attempt: int) -> float:
def get_timeout(self, attempt: int, response: Optional[ClientResponse] = None) -> float:
timeout: float = super().get_timeout(attempt) + random.uniform(0, self._random_interval_size) ** self._factor
return timeout

0 comments on commit 013cfbe

Please sign in to comment.