Skip to content

Commit

Permalink
HTTPClient refactor (#1170)
Browse files Browse the repository at this point in the history
  • Loading branch information
richardm-stripe authored Dec 13, 2023
1 parent 4464b6a commit 4435524
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 143 deletions.
2 changes: 2 additions & 0 deletions flake8_stripe/flake8_stripe.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class TypingImportsChecker:
"NotRequired",
"Self",
"Unpack",
"Awaitable",
"Never",
]

allowed_typing_imports = [
Expand Down
205 changes: 107 additions & 98 deletions stripe/_http_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import sys
import textwrap
import warnings
import email
import time
import random
Expand All @@ -13,8 +12,20 @@
from stripe._request_metrics import RequestMetrics
from stripe._error import APIConnectionError

from typing import Any, Dict, List, Optional, Tuple, ClassVar, Union, cast
from typing_extensions import NoReturn, TypedDict
from typing import (
Any,
Dict,
List,
Optional,
Tuple,
ClassVar,
Union,
cast,
)
from typing_extensions import (
NoReturn,
TypedDict,
)

# - Requests is the preferred HTTP library
# - Google App Engine has urlfetch
Expand Down Expand Up @@ -85,19 +96,12 @@ def new_default_http_client(*args: Any, **kwargs: Any) -> "HTTPClient":
impl = PycurlClient
else:
impl = Urllib2Client
if sys.version_info < (2, 7, 9):
warnings.warn(
"Warning: the Stripe library is falling back to urllib2 "
"because neither requests nor pycurl are installed. "
"urllib2's SSL implementation doesn't verify server "
"certificates. For improved security, we suggest installing "
"requests."
)

return impl(*args, **kwargs)


class HTTPClient(object):
class HTTPClientBase(object):

name: ClassVar[str]

class _Proxy(TypedDict):
Expand Down Expand Up @@ -135,92 +139,6 @@ def __init__(

self._thread_local = threading.local()

# TODO: more specific types here would be helpful
def request_with_retries(
self,
method,
url,
headers,
post_data=None,
*,
_usage: Optional[List[str]] = None,
) -> Tuple[Any, int, Any]:
return self._request_with_retries_internal(
method, url, headers, post_data, is_streaming=False, _usage=_usage
)

def request_stream_with_retries(
self,
method,
url,
headers,
post_data=None,
*,
_usage: Optional[List[str]] = None
) -> Tuple[Any, int, Any]:
return self._request_with_retries_internal(
method, url, headers, post_data, is_streaming=True, _usage=_usage
)

def _request_with_retries_internal(
self, method, url, headers, post_data, is_streaming, *, _usage=None
):
self._add_telemetry_header(headers)

num_retries = 0

while True:
request_start = _now_ms()

try:
if is_streaming:
response = self.request_stream(
method, url, headers, post_data
)
else:
response = self.request(method, url, headers, post_data)
connection_error = None
except APIConnectionError as e:
connection_error = e
response = None

if self._should_retry(response, connection_error, num_retries):
if connection_error:
_util.log_info(
"Encountered a retryable error %s"
% connection_error.user_message
)
num_retries += 1
sleep_time = self._sleep_time_seconds(num_retries, response)
_util.log_info(
(
"Initiating retry %i for request %s %s after "
"sleeping %.2f seconds."
% (num_retries, method, url, sleep_time)
)
)
time.sleep(sleep_time)
else:
if response is not None:
self._record_request_metrics(
response, request_start, _usage
)

return response
else:
assert connection_error is not None
raise connection_error

def request(self, method, url, headers, post_data=None):
raise NotImplementedError(
"HTTPClient subclasses must implement `request`"
)

def request_stream(self, method, url, headers, post_data=None):
raise NotImplementedError(
"HTTPClient subclasses must implement `request_stream`"
)

def _should_retry(self, response, api_connection_error, num_retries):
if num_retries >= self._max_network_retries():
return False
Expand Down Expand Up @@ -320,6 +238,96 @@ def _record_request_metrics(self, response, request_start, usage):
request_id, request_duration_ms, usage=usage
)


class HTTPClient(HTTPClientBase):
# TODO: more specific types here would be helpful
def request_with_retries(
self,
method,
url,
headers,
post_data=None,
*,
_usage: Optional[List[str]] = None
) -> Tuple[Any, int, Any]:
return self._request_with_retries_internal(
method, url, headers, post_data, is_streaming=False, _usage=_usage
)

def request_stream_with_retries(
self,
method,
url,
headers,
post_data=None,
*,
_usage: Optional[List[str]] = None
) -> Tuple[Any, int, Any]:
return self._request_with_retries_internal(
method, url, headers, post_data, is_streaming=True, _usage=_usage
)

def _request_with_retries_internal(
self, method, url, headers, post_data, is_streaming, *, _usage=None
):
self._add_telemetry_header(headers)

num_retries = 0

while True:
request_start = _now_ms()

try:
if is_streaming:
response = self.request_stream(
method, url, headers, post_data
)
else:
response = self.request(method, url, headers, post_data)
connection_error = None
except APIConnectionError as e:
connection_error = e
response = None

if self._should_retry(response, connection_error, num_retries):
if connection_error:
_util.log_info(
"Encountered a retryable error %s"
% connection_error.user_message
)
num_retries += 1
sleep_time = self._sleep_time_seconds(num_retries, response)
_util.log_info(
(
"Initiating retry %i for request %s %s after "
"sleeping %.2f seconds."
% (num_retries, method, url, sleep_time)
)
)
time.sleep(sleep_time)
else:
if response is not None:
self._record_request_metrics(
response, request_start, usage=_usage
)

return response
else:
assert connection_error is not None
raise connection_error

def request(self, method, url, headers, post_data=None, *, _usage=None):
raise NotImplementedError(
"HTTPClient subclasses must implement `request`"
)

def request_stream(
self, method, url, headers, post_data=None, *, _usage=None
):
raise NotImplementedError(
"HTTPClient subclasses must implement `request_stream`"
)

def close(self):
raise NotImplementedError(
"HTTPClient subclasses must implement `close`"
Expand All @@ -335,6 +343,7 @@ def __init__(
session: Optional["Session"] = None,
verify_ssl_certs: bool = True,
proxy: Optional[Union[str, HTTPClient._Proxy]] = None,
**kwargs
):
super(RequestsClient, self).__init__(
verify_ssl_certs=verify_ssl_certs, proxy=proxy
Expand Down
Loading

0 comments on commit 4435524

Please sign in to comment.