Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use dbapi PEP 249 errors #204

Merged
merged 3 commits into from
Jul 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import pytest
from trino.client import TrinoQuery, TrinoRequest, ClientSession
from trino.constants import DEFAULT_PORT
from trino.exceptions import TimeoutError


logger = trino.logging.get_logger(__name__)
Expand Down
57 changes: 56 additions & 1 deletion tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
SERVER_ADDRESS
from trino import constants
from trino.auth import KerberosAuthentication, _OAuth2TokenBearer
from trino.client import TrinoQuery, TrinoRequest, TrinoResult, ClientSession
from trino.client import TrinoQuery, TrinoRequest, TrinoResult, ClientSession, _DelayExponential, _retry_with, \
_RetryWithExponentialBackoff


@mock.patch("trino.client.TrinoRequest.http")
Expand Down Expand Up @@ -947,3 +948,57 @@ def json(self):

# Validate the result is an instance of TrinoResult
assert isinstance(result, TrinoResult)


def test_delay_exponential_without_jitter():
max_delay = 1200.0
get_delay = _DelayExponential(base=5, jitter=False, max_delay=max_delay)
results = [
10.0,
20.0,
40.0,
80.0,
160.0,
320.0,
640.0,
max_delay, # rather than 1280.0
max_delay, # rather than 2560.0
]
for i, result in enumerate(results, start=1):
assert get_delay(i) == result


def test_delay_exponential_with_jitter():
max_delay = 120.0
get_delay = _DelayExponential(base=10, jitter=False, max_delay=max_delay)
for i in range(10):
assert get_delay(i) <= max_delay


class SomeException(Exception):
pass


def test_retry_with():
max_attempts = 3
with_retry = _retry_with(
handle_retry=_RetryWithExponentialBackoff(),
handled_exceptions=[SomeException],
conditions={},
max_attempts=max_attempts,
)

class FailerUntil(object):
def __init__(self, until=1):
self.attempt = 0
self._until = until

def __call__(self):
self.attempt += 1
if self.attempt > self._until:
return
raise SomeException(self.attempt)

with_retry(FailerUntil(2).__call__)()
with pytest.raises(SomeException):
with_retry(FailerUntil(3).__call__)()
74 changes: 0 additions & 74 deletions tests/unit/test_exceptions.py

This file was deleted.

66 changes: 63 additions & 3 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,12 @@
"""

import copy
import functools
import os
import random
import re
import threading
import time
import urllib.parse
from datetime import datetime, timedelta, timezone
from decimal import Decimal
Expand Down Expand Up @@ -227,6 +230,34 @@ def __repr__(self):
)


class _DelayExponential(object):
def __init__(
self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours
):
self._base = base
self._exponent = exponent
self._jitter = jitter
self._max_delay = max_delay

def __call__(self, attempt):
delay = float(self._base) * (self._exponent ** attempt)
if self._jitter:
delay *= random.random()
delay = min(float(self._max_delay), delay)
return delay


class _RetryWithExponentialBackoff(object):
def __init__(
self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours
):
self._get_delay = _DelayExponential(base, exponent, jitter, max_delay)

def retry(self, func, args, kwargs, err, attempt):
delay = self._get_delay(attempt)
time.sleep(delay)


class TrinoRequest(object):
"""
Manage the HTTP requests of a Trino query.
Expand Down Expand Up @@ -286,7 +317,7 @@ def __init__(
redirect_handler: Any = None,
max_attempts: int = MAX_ATTEMPTS,
request_timeout: Union[float, Tuple[float, float]] = constants.DEFAULT_REQUEST_TIMEOUT,
handle_retry=exceptions.RetryWithExponentialBackoff(),
handle_retry=_RetryWithExponentialBackoff(),
verify: bool = True,
) -> None:
self._client_session = client_session
Expand Down Expand Up @@ -383,9 +414,9 @@ def max_attempts(self, value) -> None:
self._delete = self._http_session.delete
return

with_retry = exceptions.retry_with(
with_retry = _retry_with(
self._handle_retry,
exceptions=self._exceptions,
handled_exceptions=self._exceptions,
conditions=(
# need retry when there is no exception but the status code is 502, 503, or 504
lambda response: getattr(response, "status_code", None)
Expand Down Expand Up @@ -779,3 +810,32 @@ def cancelled(self) -> bool:
@property
def response_headers(self):
return self._response_headers


def _retry_with(handle_retry, handled_exceptions, conditions, max_attempts):
def wrapper(func):
@functools.wraps(func)
def decorated(*args, **kwargs):
error = None
result = None
for attempt in range(1, max_attempts + 1):
try:
result = func(*args, **kwargs)
if any(guard(result) for guard in conditions):
handle_retry.retry(func, args, kwargs, None, attempt)
continue
return result
except Exception as err:
error = err
if any(isinstance(err, exc) for exc in handled_exceptions):
handle_retry.retry(func, args, kwargs, err, attempt)
continue
break
logger.info("failed after %s attempts", attempt)
if error is not None:
raise error
return result

return decorated

return wrapper
Loading