diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 6923e95b..840ee8f3 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -22,7 +22,6 @@ import pytest from trino.client import TrinoQuery, TrinoRequest, ClientSession from trino.constants import DEFAULT_PORT -from trino.exceptions import TimeoutError logger = trino.logging.get_logger(__name__) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 313a57d3..22630811 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -28,7 +28,8 @@ SERVER_ADDRESS from trino import constants from trino.auth import KerberosAuthentication, _OAuth2TokenBearer -from trino.client import TrinoQuery, TrinoRequest, TrinoResult, ClientSession +from trino.client import TrinoQuery, TrinoRequest, TrinoResult, ClientSession, _DelayExponential, _retry_with, \ + _RetryWithExponentialBackoff @mock.patch("trino.client.TrinoRequest.http") @@ -947,3 +948,57 @@ def json(self): # Validate the result is an instance of TrinoResult assert isinstance(result, TrinoResult) + + +def test_delay_exponential_without_jitter(): + max_delay = 1200.0 + get_delay = _DelayExponential(base=5, jitter=False, max_delay=max_delay) + results = [ + 10.0, + 20.0, + 40.0, + 80.0, + 160.0, + 320.0, + 640.0, + max_delay, # rather than 1280.0 + max_delay, # rather than 2560.0 + ] + for i, result in enumerate(results, start=1): + assert get_delay(i) == result + + +def test_delay_exponential_with_jitter(): + max_delay = 120.0 + get_delay = _DelayExponential(base=10, jitter=False, max_delay=max_delay) + for i in range(10): + assert get_delay(i) <= max_delay + + +class SomeException(Exception): + pass + + +def test_retry_with(): + max_attempts = 3 + with_retry = _retry_with( + handle_retry=_RetryWithExponentialBackoff(), + handled_exceptions=[SomeException], + conditions={}, + max_attempts=max_attempts, + ) + + class FailerUntil(object): + def __init__(self, until=1): + self.attempt = 0 + self._until = until + + def __call__(self): + self.attempt += 1 + if self.attempt > self._until: + return + raise SomeException(self.attempt) + + with_retry(FailerUntil(2).__call__)() + with pytest.raises(SomeException): + with_retry(FailerUntil(3).__call__)() diff --git a/tests/unit/test_exceptions.py b/tests/unit/test_exceptions.py deleted file mode 100644 index 512faa4e..00000000 --- a/tests/unit/test_exceptions.py +++ /dev/null @@ -1,74 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" - -This module defines exceptions for Trino operations. It follows the structure -defined in pep-0249. -""" - - -from trino import exceptions -import pytest - - -def test_delay_exponential_without_jitter(): - max_delay = 1200.0 - get_delay = exceptions.DelayExponential(base=5, jitter=False, max_delay=max_delay) - results = [ - 10.0, - 20.0, - 40.0, - 80.0, - 160.0, - 320.0, - 640.0, - max_delay, # rather than 1280.0 - max_delay, # rather than 2560.0 - ] - for i, result in enumerate(results, start=1): - assert get_delay(i) == result - - -def test_delay_exponential_with_jitter(): - max_delay = 120.0 - get_delay = exceptions.DelayExponential(base=10, jitter=False, max_delay=max_delay) - for i in range(10): - assert get_delay(i) <= max_delay - - -class SomeException(Exception): - pass - - -def test_retry_with(): - max_attempts = 3 - with_retry = exceptions.retry_with( - handle_retry=exceptions.RetryWithExponentialBackoff(), - exceptions=[SomeException], - conditions={}, - max_attempts=max_attempts, - ) - - class FailerUntil(object): - def __init__(self, until=1): - self.attempt = 0 - self._until = until - - def __call__(self): - self.attempt += 1 - if self.attempt > self._until: - return - raise SomeException(self.attempt) - - with_retry(FailerUntil(2).__call__)() - with pytest.raises(SomeException): - with_retry(FailerUntil(3).__call__)() diff --git a/trino/client.py b/trino/client.py index 3e198598..3dc75b70 100644 --- a/trino/client.py +++ b/trino/client.py @@ -34,9 +34,12 @@ """ import copy +import functools import os +import random import re import threading +import time import urllib.parse from datetime import datetime, timedelta, timezone from decimal import Decimal @@ -227,6 +230,34 @@ def __repr__(self): ) +class _DelayExponential(object): + def __init__( + self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours + ): + self._base = base + self._exponent = exponent + self._jitter = jitter + self._max_delay = max_delay + + def __call__(self, attempt): + delay = float(self._base) * (self._exponent ** attempt) + if self._jitter: + delay *= random.random() + delay = min(float(self._max_delay), delay) + return delay + + +class _RetryWithExponentialBackoff(object): + def __init__( + self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours + ): + self._get_delay = _DelayExponential(base, exponent, jitter, max_delay) + + def retry(self, func, args, kwargs, err, attempt): + delay = self._get_delay(attempt) + time.sleep(delay) + + class TrinoRequest(object): """ Manage the HTTP requests of a Trino query. @@ -286,7 +317,7 @@ def __init__( redirect_handler: Any = None, max_attempts: int = MAX_ATTEMPTS, request_timeout: Union[float, Tuple[float, float]] = constants.DEFAULT_REQUEST_TIMEOUT, - handle_retry=exceptions.RetryWithExponentialBackoff(), + handle_retry=_RetryWithExponentialBackoff(), verify: bool = True, ) -> None: self._client_session = client_session @@ -383,9 +414,9 @@ def max_attempts(self, value) -> None: self._delete = self._http_session.delete return - with_retry = exceptions.retry_with( + with_retry = _retry_with( self._handle_retry, - exceptions=self._exceptions, + handled_exceptions=self._exceptions, conditions=( # need retry when there is no exception but the status code is 502, 503, or 504 lambda response: getattr(response, "status_code", None) @@ -779,3 +810,32 @@ def cancelled(self) -> bool: @property def response_headers(self): return self._response_headers + + +def _retry_with(handle_retry, handled_exceptions, conditions, max_attempts): + def wrapper(func): + @functools.wraps(func) + def decorated(*args, **kwargs): + error = None + result = None + for attempt in range(1, max_attempts + 1): + try: + result = func(*args, **kwargs) + if any(guard(result) for guard in conditions): + handle_retry.retry(func, args, kwargs, None, attempt) + continue + return result + except Exception as err: + error = err + if any(isinstance(err, exc) for exc in handled_exceptions): + handle_retry.retry(func, args, kwargs, err, attempt) + continue + break + logger.info("failed after %s attempts", attempt) + if error is not None: + raise error + return result + + return decorated + + return wrapper diff --git a/trino/exceptions.py b/trino/exceptions.py index cd311ab7..86708fd0 100644 --- a/trino/exceptions.py +++ b/trino/exceptions.py @@ -16,48 +16,62 @@ """ -import functools -import random -import time - import trino.logging logger = trino.logging.get_logger(__name__) -class HttpError(Exception): +# PEP 249 Errors +class Error(Exception): pass -class Http502Error(Exception): +class Warning(Exception): pass -class Http503Error(HttpError): +class InterfaceError(Error): pass -class Http504Error(HttpError): +class DatabaseError(Error): + pass + + +class InternalError(DatabaseError): + pass + + +class OperationalError(DatabaseError): + pass + + +class ProgrammingError(DatabaseError): + pass + + +class IntegrityError(DatabaseError): pass -class TrinoError(Exception): +class DataError(DatabaseError): pass -class TimeoutError(Exception): +class NotSupportedError(DatabaseError): pass -class TrinoAuthError(Exception): +# dbapi module errors (extending PEP 249 errors) +class TrinoAuthError(OperationalError): pass -class TrinoDataError(Exception): +class TrinoDataError(NotSupportedError): pass -class TrinoQueryError(Exception): +class TrinoQueryError(Error): def __init__(self, error, query_id=None): self._error = error self._query_id = query_id @@ -108,127 +122,46 @@ def __str__(self): return repr(self) -class TrinoExternalError(TrinoQueryError): +class TrinoExternalError(TrinoQueryError, OperationalError): pass -class TrinoInternalError(TrinoQueryError): - pass - - -class TrinoUserError(TrinoQueryError): - pass - - -def retry_with(handle_retry, exceptions, conditions, max_attempts): - def wrapper(func): - @functools.wraps(func) - def decorated(*args, **kwargs): - error = None - result = None - for attempt in range(1, max_attempts + 1): - try: - result = func(*args, **kwargs) - if any(guard(result) for guard in conditions): - handle_retry.retry(func, args, kwargs, None, attempt) - continue - return result - except Exception as err: - error = err - if any(isinstance(err, exc) for exc in exceptions): - handle_retry.retry(func, args, kwargs, err, attempt) - continue - break - logger.info("failed after %s attempts", attempt) - if error is not None: - raise error - return result - - return decorated - - return wrapper - - -class DelayExponential(object): - def __init__( - self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours - ): - self._base = base - self._exponent = exponent - self._jitter = jitter - self._max_delay = max_delay - - def __call__(self, attempt): - delay = float(self._base) * (self._exponent ** attempt) - if self._jitter: - delay *= random.random() - delay = min(float(self._max_delay), delay) - return delay - - -class RetryWithExponentialBackoff(object): - def __init__( - self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours - ): - self._get_delay = DelayExponential(base, exponent, jitter, max_delay) - - def retry(self, func, args, kwargs, err, attempt): - delay = self._get_delay(attempt) - time.sleep(delay) - - -# PEP 249 -class Error(Exception): +class TrinoInternalError(TrinoQueryError, InternalError): pass -class Warning(Exception): +class TrinoUserError(TrinoQueryError, ProgrammingError): pass -class InterfaceError(Error): - pass - - -class DatabaseError(Error): - pass - - -class InternalError(DatabaseError): - pass - - -class OperationalError(DatabaseError): - pass - - -class ProgrammingError(DatabaseError): +class FailedToObtainAddedPrepareHeader(Error): + """ + Raise this exception when unable to find the 'X-Trino-Added-Prepare' + header in the response of a PREPARE statement request. + """ pass -class IntegrityError(DatabaseError): +class FailedToObtainDeallocatedPrepareHeader(Error): + """ + Raise this exception when unable to find the 'X-Trino-Deallocated-Prepare' + header in the response of a DEALLOCATED statement request. + """ pass -class DataError(DatabaseError): +# client module errors +class HttpError(Exception): pass -class NotSupportedError(DatabaseError): +class Http502Error(HttpError): pass -class FailedToObtainAddedPrepareHeader(Error): - """ - Raise this exception when unable to find the 'X-Trino-Added-Prepare' - header in the response of a PREPARE statement request. - """ +class Http503Error(HttpError): pass -class FailedToObtainDeallocatedPrepareHeader(Error): - """ - Raise this exception when unable to find the 'X-Trino-Deallocated-Prepare' - header in the response of a DEALLOCATED statement request. - """ +class Http504Error(HttpError): pass