diff --git a/sdk/core/azure-core/azure/core/_enum_meta.py b/sdk/core/azure-core/azure/core/_enum_meta.py index 745f35b2c407..3015ce3faf24 100644 --- a/sdk/core/azure-core/azure/core/_enum_meta.py +++ b/sdk/core/azure-core/azure/core/_enum_meta.py @@ -23,8 +23,8 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- - -from enum import EnumMeta +from typing import Any +from enum import EnumMeta, Enum class CaseInsensitiveEnumMeta(EnumMeta): @@ -43,13 +43,13 @@ class MyCustomEnum(str, Enum, metaclass=CaseInsensitiveEnumMeta): """ - def __getitem__(cls, name): + def __getitem__(cls, name: str) -> Any: # disabling pylint bc of pylint bug https://github.com/PyCQA/astroid/issues/713 return super( # pylint: disable=no-value-for-parameter CaseInsensitiveEnumMeta, cls ).__getitem__(name.upper()) - def __getattr__(cls, name): + def __getattr__(cls, name: str) -> Enum: """Return the enum member matching `name` We use __getattr__ instead of descriptors or inserting into the enum class' __dict__ in order to support `name` and `value` being both diff --git a/sdk/core/azure-core/azure/core/_match_conditions.py b/sdk/core/azure-core/azure/core/_match_conditions.py index c3036faa8781..ee4a0c8278c9 100644 --- a/sdk/core/azure-core/azure/core/_match_conditions.py +++ b/sdk/core/azure-core/azure/core/_match_conditions.py @@ -30,8 +30,17 @@ class MatchConditions(Enum): """An enum to describe match conditions.""" - Unconditionally = 1 # Matches any condition - IfNotModified = 2 # If the target object is not modified. Usually it maps to etag= - IfModified = 3 # Only if the target object is modified. Usually it maps to etag!= - IfPresent = 4 # If the target object exists. Usually it maps to etag='*' - IfMissing = 5 # If the target object does not exist. Usually it maps to etag!='*' + Unconditionally = 1 + """Matches any condition""" + + IfNotModified = 2 + """If the target object is not modified. Usually it maps to etag=""" + + IfModified = 3 + """Only if the target object is modified. Usually it maps to etag!=""" + + IfPresent = 4 + """If the target object exists. Usually it maps to etag='*'""" + + IfMissing = 5 + """If the target object does not exist. Usually it maps to etag!='*'""" diff --git a/sdk/core/azure-core/azure/core/_pipeline_client.py b/sdk/core/azure-core/azure/core/_pipeline_client.py index 9c5ea5bbe91a..6460eb0a04f4 100644 --- a/sdk/core/azure-core/azure/core/_pipeline_client.py +++ b/sdk/core/azure-core/azure/core/_pipeline_client.py @@ -27,7 +27,6 @@ import logging from collections.abc import Iterable from typing import ( - Any, TypeVar, TYPE_CHECKING, ) @@ -169,8 +168,7 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use return Pipeline(transport, policies) - def send_request(self, request, **kwargs): - # type: (HTTPRequestType, Any) -> HTTPResponseType + def send_request(self, request: "HTTPRequestType", **kwargs) -> "HTTPResponseType": """Method that runs the network request through the client's chained policies. 
>>> from azure.core.rest import HttpRequest diff --git a/sdk/core/azure-core/azure/core/_pipeline_client_async.py b/sdk/core/azure-core/azure/core/_pipeline_client_async.py index d528b12ae7a0..8207e4d4216d 100644 --- a/sdk/core/azure-core/azure/core/_pipeline_client_async.py +++ b/sdk/core/azure-core/azure/core/_pipeline_client_async.py @@ -26,7 +26,16 @@ import logging import collections.abc -from typing import Any, Awaitable, TypeVar +from typing import ( + Any, + Awaitable, + TypeVar, + AsyncContextManager, + Generator, + cast, + TYPE_CHECKING, +) +from typing_extensions import Protocol from .configuration import Configuration from .pipeline import AsyncPipeline from .pipeline.transport._base import PipelineClientBase @@ -38,33 +47,88 @@ AsyncRetryPolicy, ) + +if TYPE_CHECKING: # Protocol and non-Protocol can't mix in Python 3.7 + + class _AsyncContextManagerCloseable(AsyncContextManager, Protocol): + """Defines a context manager that is closeable at the same time.""" + + async def close(self): + ... + + HTTPRequestType = TypeVar("HTTPRequestType") -AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType") +AsyncHTTPResponseType = TypeVar( + "AsyncHTTPResponseType", bound="_AsyncContextManagerCloseable" +) _LOGGER = logging.getLogger(__name__) -class _AsyncContextManager(collections.abc.Awaitable): - def __init__(self, wrapped: collections.abc.Awaitable): +class _Coroutine(Awaitable[AsyncHTTPResponseType]): + """Wrapper to get both context manager and awaitable in place. + + Naming it "_Coroutine" because if you don't await it makes the error message easier: + >>> result = client.send_request(request) + >>> result.text() + AttributeError: '_Coroutine' object has no attribute 'text' + + Indeed, the message for calling a coroutine without waiting would be: + AttributeError: 'coroutine' object has no attribute 'text' + + This allows the dev to either use the "async with" syntax, or simply the object directly. + It's also why "send_request" is not declared as async, since it couldn't be both easily. + + "wrapped" must be an awaitable that returns an object that: + - has an async "close()" + - has an "__aexit__" method (IOW, is an async context manager) + + This permits this code to work for both requests. 
+ + ```python + from azure.core import AsyncPipelineClient + from azure.core.rest import HttpRequest + + async def main(): + + request = HttpRequest("GET", "https://httpbin.org/user-agent") + async with AsyncPipelineClient("https://httpbin.org/") as client: + # Can be used directly + result = await client.send_request(request) + print(result.text()) + + # Can be used as an async context manager + async with client.send_request(request) as result: + print(result.text()) + ``` + + :param wrapped: Must be an awaitable the returns an async context manager that supports async "close()" + """ + + def __init__(self, wrapped: Awaitable[AsyncHTTPResponseType]) -> None: super().__init__() - self.wrapped = wrapped - self.response = None + self._wrapped = wrapped + # If someone tries to use the object without awaiting, they will get a + # AttributeError: '_Coroutine' object has no attribute 'text' + self._response: AsyncHTTPResponseType = cast(AsyncHTTPResponseType, None) - def __await__(self): - return self.wrapped.__await__() + def __await__(self) -> Generator[Any, None, AsyncHTTPResponseType]: + return self._wrapped.__await__() - async def __aenter__(self): - self.response = await self - return self.response + async def __aenter__(self) -> AsyncHTTPResponseType: + self._response = await self + return self._response - async def __aexit__(self, *args): - await self.response.__aexit__(*args) + async def __aexit__(self, *args) -> None: + await self._response.__aexit__(*args) - async def close(self): - await self.response.close() + async def close(self) -> None: + await self._response.close() -class AsyncPipelineClient(PipelineClientBase): +class AsyncPipelineClient( + PipelineClientBase, AsyncContextManager["AsyncPipelineClient"] +): """Service client core methods. Builds an AsyncPipeline client. @@ -212,4 +276,4 @@ def send_request( :rtype: ~azure.core.rest.AsyncHttpResponse """ wrapped = self._make_pipeline_call(request, stream=stream, **kwargs) - return _AsyncContextManager(wrapped=wrapped) + return _Coroutine(wrapped=wrapped) diff --git a/sdk/core/azure-core/azure/core/async_paging.py b/sdk/core/azure-core/azure/core/async_paging.py index 9f5540329b5c..62f7eda3ae6f 100644 --- a/sdk/core/azure-core/azure/core/async_paging.py +++ b/sdk/core/azure-core/azure/core/async_paging.py @@ -33,6 +33,7 @@ Tuple, Optional, Awaitable, + Any, ) from .exceptions import AzureError @@ -85,10 +86,10 @@ def __init__( self._extract_data = extract_data self.continuation_token = continuation_token self._did_a_call_already = False - self._response = None - self._current_page = None + self._response: Optional[ResponseType] = None + self._current_page: Optional[AsyncIterator[ReturnType]] = None - async def __anext__(self): + async def __anext__(self) -> AsyncIterator[ReturnType]: if self.continuation_token is None and self._did_a_call_already: raise StopAsyncIteration("End of paging") try: @@ -112,7 +113,7 @@ async def __anext__(self): class AsyncItemPaged(AsyncIterator[ReturnType]): - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: """Return an async iterator of items. 
args and kwargs will be passed to the AsyncPageIterator constructor directly, @@ -120,10 +121,8 @@ def __init__(self, *args, **kwargs) -> None: """ self._args = args self._kwargs = kwargs - self._page_iterator = ( - None - ) # type: Optional[AsyncIterator[AsyncIterator[ReturnType]]] - self._page = None # type: Optional[AsyncIterator[ReturnType]] + self._page_iterator: Optional[AsyncIterator[AsyncIterator[ReturnType]]] = None + self._page: Optional[AsyncIterator[ReturnType]] = None self._page_iterator_class = self._kwargs.pop( "page_iterator_class", AsyncPageIterator ) diff --git a/sdk/core/azure-core/azure/core/configuration.py b/sdk/core/azure-core/azure/core/configuration.py index b6f2f9f3c604..02d51a3f195d 100644 --- a/sdk/core/azure-core/azure/core/configuration.py +++ b/sdk/core/azure-core/azure/core/configuration.py @@ -23,9 +23,10 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- +from typing import Union, Optional -class Configuration(object): +class Configuration: """Provides the home for all of the configurable policies in the pipeline. A new Configuration object provides no default policies and does not specify in what @@ -88,16 +89,17 @@ def __init__(self, **kwargs): self.polling_interval = kwargs.get("polling_interval", 30) -class ConnectionConfiguration(object): +class ConnectionConfiguration: """HTTP transport connection configuration settings. Common properties that can be configured on all transports. Found in the Configuration object. - :keyword int connection_timeout: A single float in seconds for the connection timeout. Defaults to 300 seconds. - :keyword int read_timeout: A single float in seconds for the read timeout. Defaults to 300 seconds. - :keyword bool connection_verify: SSL certificate verification. Enabled by default. Set to False to disable, + :keyword float connection_timeout: A single float in seconds for the connection timeout. Defaults to 300 seconds. + :keyword float read_timeout: A single float in seconds for the read timeout. Defaults to 300 seconds. + :keyword connection_verify: SSL certificate verification. Enabled by default. Set to False to disable, alternatively can be set to the path to a CA_BUNDLE file or directory with certificates of trusted CAs. + :paramtype connection_verify: bool or str :keyword str connection_cert: Client-side certificates. You can specify a local cert to use as client side certificate, as a single file (containing the private key and the certificate) or as a tuple of both files' paths. :keyword int connection_data_block_size: The block size of data sent over the connection. Defaults to 4096 bytes. @@ -112,9 +114,18 @@ class ConnectionConfiguration(object): :caption: Configuring transport connection settings. 
""" - def __init__(self, **kwargs): - self.timeout = kwargs.pop("connection_timeout", 300) - self.read_timeout = kwargs.pop("read_timeout", 300) - self.verify = kwargs.pop("connection_verify", True) - self.cert = kwargs.pop("connection_cert", None) - self.data_block_size = kwargs.pop("connection_data_block_size", 4096) + def __init__( + self, # pylint: disable=unused-argument + *, + connection_timeout: float = 300, + read_timeout: float = 300, + connection_verify: Union[bool, str] = True, + connection_cert: Optional[str] = None, + connection_data_block_size: int = 4096, + **kwargs + ) -> None: + self.timeout = connection_timeout + self.read_timeout = read_timeout + self.verify = connection_verify + self.cert = connection_cert + self.data_block_size = connection_data_block_size diff --git a/sdk/core/azure-core/azure/core/credentials.py b/sdk/core/azure-core/azure/core/credentials.py index e1b0627ec83d..dbf09d7d2e8f 100644 --- a/sdk/core/azure-core/azure/core/credentials.py +++ b/sdk/core/azure-core/azure/core/credentials.py @@ -55,7 +55,7 @@ def get_token( ] -class AzureKeyCredential(object): +class AzureKeyCredential: """Credential type used for authenticating to an Azure service. It provides the ability to update the key without creating a new client. @@ -63,23 +63,20 @@ class AzureKeyCredential(object): :raises: TypeError """ - def __init__(self, key): - # type: (str) -> None + def __init__(self, key: str) -> None: if not isinstance(key, str): raise TypeError("key must be a string.") - self._key = key # type: str + self._key = key @property - def key(self): - # type () -> str + def key(self) -> str: """The value of the configured key. :rtype: str """ return self._key - def update(self, key): - # type: (str) -> None + def update(self, key: str) -> None: """Update the key. This can be used when you've regenerated your service key and want @@ -95,7 +92,7 @@ def update(self, key): self._key = key -class AzureSasCredential(object): +class AzureSasCredential: """Credential type used for authenticating to an Azure service. It provides the ability to update the shared access signature without creating a new client. @@ -103,23 +100,20 @@ class AzureSasCredential(object): :raises: TypeError """ - def __init__(self, signature): - # type: (str) -> None + def __init__(self, signature: str) -> None: if not isinstance(signature, str): raise TypeError("signature must be a string.") - self._signature = signature # type: str + self._signature = signature @property - def signature(self): - # type () -> str + def signature(self) -> str: """The value of the configured shared access signature. :rtype: str """ return self._signature - def update(self, signature): - # type: (str) -> None + def update(self, signature: str) -> None: """Update the shared access signature. This can be used when you've regenerated your shared access signature and want @@ -135,7 +129,7 @@ def update(self, signature): self._signature = signature -class AzureNamedKeyCredential(object): +class AzureNamedKeyCredential: """Credential type used for working with any service needing a named key that follows patterns established by the other credential types. 
@@ -144,23 +138,20 @@ class AzureNamedKeyCredential(object): :raises: TypeError """ - def __init__(self, name, key): - # type: (str, str) -> None + def __init__(self, name: str, key: str) -> None: if not isinstance(name, str) or not isinstance(key, str): raise TypeError("Both name and key must be strings.") self._credential = AzureNamedKey(name, key) @property - def named_key(self): - # type () -> AzureNamedKey + def named_key(self) -> AzureNamedKey: """The value of the configured name. :rtype: AzureNamedKey """ return self._credential - def update(self, name, key): - # type: (str, str) -> None + def update(self, name: str, key: str) -> None: """Update the named key credential. Both name and key must be provided in order to update the named key credential. diff --git a/sdk/core/azure-core/azure/core/exceptions.py b/sdk/core/azure-core/azure/core/exceptions.py index 0210456cd7fe..177c6070ab4c 100644 --- a/sdk/core/azure-core/azure/core/exceptions.py +++ b/sdk/core/azure-core/azure/core/exceptions.py @@ -28,7 +28,7 @@ import logging import sys -from typing import Callable, Any, Dict, Optional, List, Union, Type, TYPE_CHECKING +from typing import Callable, Any, Optional, Union, Type, List, Dict, TYPE_CHECKING _LOGGER = logging.getLogger(__name__) @@ -58,8 +58,7 @@ ] -def raise_with_traceback(exception, *args, **kwargs): - # type: (Callable, Any, Any) -> None +def raise_with_traceback(exception: Callable, *args, **kwargs) -> None: """Raise exception with a specified traceback. This MUST be called inside a "except" clause. @@ -81,7 +80,7 @@ def raise_with_traceback(exception, *args, **kwargs): raise error -class ErrorMap(object): +class ErrorMap: """Error Map class. To be used in map_error method, behaves like a dictionary. It returns the error type if it is found in custom_error_map. Or return default_error @@ -110,7 +109,7 @@ def map_error(status_code, response, error_map): raise error -class ODataV4Format(object): +class ODataV4Format: """Class to describe OData V4 error format. 
http://docs.oasis-open.org/odata/odata-json-format/v4.0/os/odata-json-format-v4.0-os.html#_Toc372793091 @@ -155,14 +154,14 @@ class ODataV4Format(object): DETAILS_LABEL = "details" INNERERROR_LABEL = "innererror" - def __init__(self, json_object): + def __init__(self, json_object: Dict[str, Any]): if "error" in json_object: json_object = json_object["error"] - cls = self.__class__ # type: Type[ODataV4Format] + cls: Type[ODataV4Format] = self.__class__ # Required fields, but assume they could be missing still to be robust - self.code = json_object.get(cls.CODE_LABEL) # type: Optional[str] - self.message = json_object.get(cls.MESSAGE_LABEL) # type: Optional[str] + self.code: Optional[str] = json_object.get(cls.CODE_LABEL) + self.message: Optional[str] = json_object.get(cls.MESSAGE_LABEL) if not (self.code or self.message): raise ValueError( @@ -171,19 +170,17 @@ def __init__(self, json_object): ) # Optional fields - self.target = json_object.get(cls.TARGET_LABEL) # type: Optional[str] + self.target: Optional[str] = json_object.get(cls.TARGET_LABEL) # details is recursive of this very format - self.details = [] # type: List[ODataV4Format] + self.details: List[ODataV4Format] = [] for detail_node in json_object.get(cls.DETAILS_LABEL) or []: try: self.details.append(self.__class__(detail_node)) except Exception: # pylint: disable=broad-except pass - self.innererror = json_object.get( - cls.INNERERROR_LABEL, {} - ) # type: Dict[str, Any] + self.innererror: Dict[str, Any] = json_object.get(cls.INNERERROR_LABEL, {}) @property def error(self): @@ -198,9 +195,8 @@ def error(self): def __str__(self): return "({}) {}\n{}".format(self.code, self.message, self.message_details()) - def message_details(self): + def message_details(self) -> str: """Return a detailled string of the error.""" - # () -> str error_str = "Code: {}".format(self.code) error_str += "\nMessage: {}".format(self.message) if self.target: @@ -309,16 +305,14 @@ def __init__(self, message=None, response=None, **kwargs): # old autorest are setting "error" before calling __init__, so it might be there already # transferring into self.model - model = kwargs.pop("model", None) # type: Optional[msrest.serialization.Model] + model: Optional[Any] = kwargs.pop("model", None) if model is not None: # autorest v5 self.model = model else: # autorest azure-core, for KV 1.0, Storage 12.0, etc. 
- self.model = getattr( - self, "error", None - ) # type: Optional[msrest.serialization.Model] - self.error = self._parse_odata_body( + self.model: Optional[Any] = getattr(self, "error", None) + self.error: Optional[ODataV4Format] = self._parse_odata_body( error_format, response - ) # type: Optional[ODataV4Format] + ) # By priority, message is: # - odatav4 message, OR @@ -334,8 +328,9 @@ def __init__(self, message=None, response=None, **kwargs): super(HttpResponseError, self).__init__(message=message, **kwargs) @staticmethod - def _parse_odata_body(error_format, response): - # type: (Type[ODataV4Format], _HttpResponseBase) -> Optional[ODataV4Format] + def _parse_odata_body( + error_format: Type[ODataV4Format], response: "_HttpResponseBase" + ) -> Optional[ODataV4Format]: try: odata_json = json.loads(response.text()) return error_format(odata_json) @@ -415,11 +410,9 @@ class ODataV4Error(HttpResponseError): _ERROR_FORMAT = ODataV4Format - def __init__(self, response, **kwargs): - # type: (_HttpResponseBase, Any) -> None - + def __init__(self, response: "_HttpResponseBase", **kwargs) -> None: # Ensure field are declared, whatever can happen afterwards - self.odata_json = None # type: Optional[Dict[str, Any]] + self.odata_json: Optional[Dict[str, Any]] = None try: self.odata_json = json.loads(response.text()) odata_message = self.odata_json.setdefault("error", {}).get("message") @@ -427,18 +420,18 @@ def __init__(self, response, **kwargs): # If the body is not JSON valid, just stop now odata_message = None - self.code = None # type: Optional[str] - self.message = kwargs.get("message", odata_message) # type: Optional[str] - self.target = None # type: Optional[str] - self.details = [] # type: Optional[List[Any]] - self.innererror = {} # type: Optional[Dict[str, Any]] + self.code: Optional[str] = None + self.message: Optional[str] = kwargs.get("message", odata_message) + self.target: Optional[str] = None + self.details: Optional[List[Any]] = [] + self.innererror: Optional[Dict[str, Any]] = {} if self.message and "message" not in kwargs: kwargs["message"] = self.message super(ODataV4Error, self).__init__(response=response, **kwargs) - self._error_format = None # type: Optional[Union[str, ODataV4Format]] + self._error_format: Optional[Union[str, ODataV4Format]] = None if self.odata_json: try: error_node = self.odata_json["error"] diff --git a/sdk/core/azure-core/azure/core/messaging.py b/sdk/core/azure-core/azure/core/messaging.py index 05c018d08981..5140558335e7 100644 --- a/sdk/core/azure-core/azure/core/messaging.py +++ b/sdk/core/azure-core/azure/core/messaging.py @@ -7,7 +7,7 @@ import uuid from base64 import b64decode from datetime import datetime -from typing import cast, Union, Any, Optional, Dict +from typing import cast, Union, Any, Optional, Dict, TypeVar, Generic from .utils._utils import _convert_to_isoformat, TZ_UTC from .utils._messaging_shared import _get_json_content from .serialization import NULL @@ -16,7 +16,12 @@ __all__ = ["CloudEvent"] -class CloudEvent(object): # pylint:disable=too-many-instance-attributes +_Unset: Any = object() + +DataType = TypeVar("DataType") + + +class CloudEvent(Generic[DataType]): # pylint:disable=too-many-instance-attributes """Properties of the CloudEvent 1.0 Schema. All required parameters must be populated in order to send to Azure. @@ -25,79 +30,103 @@ class CloudEvent(object): # pylint:disable=too-many-instance-attributes :type source: str :param type: Required. Type of event related to the originating occurrence. 
:type type: str + :keyword specversion: Optional. The version of the CloudEvent spec. Defaults to "1.0" + :paramtype specversion: str :keyword data: Optional. Event data specific to the event type. - :type data: object + :paramtype data: object :keyword time: Optional. The time (in UTC) the event was generated. - :type time: ~datetime.datetime + :paramtype time: ~datetime.datetime :keyword dataschema: Optional. Identifies the schema that data adheres to. - :type dataschema: str + :paramtype dataschema: str :keyword datacontenttype: Optional. Content type of data value. - :type datacontenttype: str + :paramtype datacontenttype: str :keyword subject: Optional. This describes the subject of the event in the context of the event producer (identified by source). - :type subject: str - :keyword specversion: Optional. The version of the CloudEvent spec. Defaults to "1.0" - :type specversion: str + :paramtype subject: str :keyword id: Optional. An identifier for the event. The combination of id and source must be unique for each distinct event. If not provided, a random UUID will be generated and used. - :type id: Optional[str] + :paramtype id: Optional[str] :keyword extensions: Optional. A CloudEvent MAY include any number of additional context attributes with distinct names represented as key - value pairs. Each extension must be alphanumeric, lower cased and must not exceed the length of 20 characters. - :type extensions: Optional[Dict] - :ivar source: Identifies the context in which an event happened. The combination of id and source must - be unique for each distinct event. If publishing to a domain topic, source must be the domain topic name. - :vartype source: str - :ivar data: Event data specific to the event type. - :vartype data: object - :ivar type: Type of event related to the originating occurrence. - :vartype type: str - :ivar time: The time (in UTC) the event was generated. - :vartype time: ~datetime.datetime - :ivar dataschema: Identifies the schema that data adheres to. - :vartype dataschema: str - :ivar datacontenttype: Content type of data value. - :vartype datacontenttype: str - :ivar subject: This describes the subject of the event in the context of the event producer - (identified by source). - :vartype subject: str - :ivar specversion: Optional. The version of the CloudEvent spec. Defaults to "1.0" - :vartype specversion: str - :ivar id: An identifier for the event. The combination of id and source must be - unique for each distinct event. If not provided, a random UUID will be generated and used. - :vartype id: str - :ivar extensions: A CloudEvent MAY include any number of additional context attributes - with distinct names represented as key - value pairs. Each extension must be alphanumeric, lower cased - and must not exceed the length of 20 characters. 
- :vartype extensions: Dict + :paramtype extensions: Optional[dict] """ - def __init__(self, source, type, **kwargs): # pylint: disable=redefined-builtin - # type: (str, str, **Any) -> None - self.source = source # type: str - self.type = type # type: str - self.specversion = kwargs.pop("specversion", "1.0") # type: Optional[str] - self.id = kwargs.pop("id", str(uuid.uuid4())) # type: Optional[str] - self.time = kwargs.pop("time", datetime.now(TZ_UTC)) # type: Optional[datetime] - - self.datacontenttype = kwargs.pop( - "datacontenttype", None - ) # type: Optional[str] - self.dataschema = kwargs.pop("dataschema", None) # type: Optional[str] - self.subject = kwargs.pop("subject", None) # type: Optional[str] - self.data = kwargs.pop("data", None) # type: Optional[object] - - try: - self.extensions = kwargs.pop("extensions") # type: Optional[Dict] - for ( - key - ) in self.extensions.keys(): # type:ignore # extensions won't be None here + source: str + """Identifies the context in which an event happened. The combination of id and source must + be unique for each distinct event. If publishing to a domain topic, source must be the domain topic name.""" + + type: str # pylint: disable=redefined-builtin + """Type of event related to the originating occurrence.""" + + specversion: str = "1.0" + """The version of the CloudEvent spec. Defaults to "1.0" """ + + id: str # pylint: disable=redefined-builtin + """An identifier for the event. The combination of id and source must be + unique for each distinct event. If not provided, a random UUID will be generated and used.""" + + data: Optional[DataType] + """Event data specific to the event type.""" + + time: Optional[datetime] + """The time (in UTC) the event was generated.""" + + dataschema: Optional[str] + """Identifies the schema that data adheres to.""" + + datacontenttype: Optional[str] + """Content type of data value.""" + + subject: Optional[str] + """This describes the subject of the event in the context of the event producer + (identified by source)""" + + extensions: Optional[Dict[str, Any]] + """A CloudEvent MAY include any number of additional context attributes + with distinct names represented as key - value pairs. Each extension must be alphanumeric, lower cased + and must not exceed the length of 20 characters.""" + + def __init__( + self, + source: str, + type: str, # pylint: disable=redefined-builtin + *, + specversion: Optional[str] = None, + id: Optional[str] = None, # pylint: disable=redefined-builtin + time: Optional[datetime] = _Unset, + datacontenttype: Optional[str] = None, + dataschema: Optional[str] = None, + subject: Optional[str] = None, + data: Optional[DataType] = None, + extensions: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + self.source: str = source + self.type: str = type + + if specversion: + self.specversion: str = specversion + self.id: str = id if id else str(uuid.uuid4()) + + self.time: Optional[datetime] + if time is _Unset: + self.time = datetime.now(TZ_UTC) + else: + self.time = time + + self.datacontenttype: Optional[str] = datacontenttype + self.dataschema: Optional[str] = dataschema + self.subject: Optional[str] = subject + self.data: Optional[DataType] = data + + self.extensions: Optional[Dict[str, Any]] = extensions + if self.extensions: + for key in self.extensions.keys(): if not key.islower() or not key.isalnum(): raise ValueError( "Extension attributes should be lower cased and alphanumeric." 
) - except KeyError: - self.extensions = None if kwargs: remaining = ", ".join(kwargs.keys()) @@ -106,21 +135,20 @@ def __init__(self, source, type, **kwargs): # pylint: disable=redefined-builtin + "Any extension attributes must be passed explicitly using extensions." ) - def __repr__(self): + def __repr__(self) -> str: return "CloudEvent(source={}, type={}, specversion={}, id={}, time={})".format( self.source, self.type, self.specversion, self.id, self.time )[:1024] @classmethod - def from_dict(cls, event): - # type: (Dict) -> CloudEvent + def from_dict(cls, event: Dict[str, Any]) -> "CloudEvent": """ Returns the deserialized CloudEvent object when a dict is provided. :param event: The dict representation of the event which needs to be deserialized. :type event: dict :rtype: CloudEvent """ - kwargs = {} # type: Dict[Any, Any] + kwargs: Dict[str, Any] = {} reserved_attr = [ "data", "data_base64", @@ -193,8 +221,7 @@ def from_dict(cls, event): return event_obj @classmethod - def from_json(cls, event): - # type: (Any) -> CloudEvent + def from_json(cls, event: Any) -> "CloudEvent": """ Returns the deserialized CloudEvent object when a json payload is provided. :param event: The json string that should be converted into a CloudEvent. This can also be diff --git a/sdk/core/azure-core/azure/core/paging.py b/sdk/core/azure-core/azure/core/paging.py index c06382555353..f0ee547039a3 100644 --- a/sdk/core/azure-core/azure/core/paging.py +++ b/sdk/core/azure-core/azure/core/paging.py @@ -31,6 +31,7 @@ Iterator, Iterable, Tuple, + Any, ) import logging @@ -46,9 +47,9 @@ class PageIterator(Iterator[Iterator[ReturnType]]): def __init__( self, - get_next, # type: Callable[[Optional[str]], ResponseType] - extract_data, # type: Callable[[ResponseType], Tuple[str, Iterable[ReturnType]]] - continuation_token=None, # type: Optional[str] + get_next: Callable[[Optional[str]], ResponseType], + extract_data: Callable[[ResponseType], Tuple[str, Iterable[ReturnType]]], + continuation_token: Optional[str] = None, ): """Return an iterator of pages. @@ -61,15 +62,14 @@ def __init__( self._extract_data = extract_data self.continuation_token = continuation_token self._did_a_call_already = False - self._response = None # type: Optional[ResponseType] - self._current_page = None # type: Optional[Iterable[ReturnType]] + self._response: Optional[ResponseType] = None + self._current_page: Optional[Iterable[ReturnType]] = None - def __iter__(self): + def __iter__(self) -> Iterator[Iterator[ReturnType]]: """Return 'self'.""" return self - def __next__(self): - # type: () -> Iterator[ReturnType] + def __next__(self) -> Iterator[ReturnType]: if self.continuation_token is None and self._did_a_call_already: raise StopIteration("End of paging") try: @@ -85,11 +85,11 @@ def __next__(self): return iter(self._current_page) - next = __next__ # Python 2 compatibility. + next = __next__ # Python 2 compatibility. Can't be removed as some people are using ".next()" even in Py3 class ItemPaged(Iterator[ReturnType]): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: """Return an iterator of items. 
args and kwargs will be passed to the PageIterator constructor directly, @@ -97,7 +97,7 @@ def __init__(self, *args, **kwargs): """ self._args = args self._kwargs = kwargs - self._page_iterator = None + self._page_iterator: Optional[Iterator[ReturnType]] = None self._page_iterator_class = self._kwargs.pop( "page_iterator_class", PageIterator ) @@ -117,12 +117,12 @@ def by_page( continuation_token=continuation_token, *self._args, **self._kwargs ) - def __repr__(self): + def __repr__(self) -> str: return "".format( hex(id(self)) ) - def __iter__(self): + def __iter__(self) -> Iterator[ReturnType]: """Return 'self'.""" return self @@ -131,4 +131,4 @@ def __next__(self) -> ReturnType: self._page_iterator = itertools.chain.from_iterable(self.by_page()) return next(self._page_iterator) - next = __next__ # Python 2 compatibility. + next = __next__ # Python 2 compatibility. Can't be removed as some people are using ".next()" even in Py3 diff --git a/sdk/core/azure-core/azure/core/pipeline/__init__.py b/sdk/core/azure-core/azure/core/pipeline/__init__.py index b060aae88413..fcadb2b3518d 100644 --- a/sdk/core/azure-core/azure/core/pipeline/__init__.py +++ b/sdk/core/azure-core/azure/core/pipeline/__init__.py @@ -24,17 +24,13 @@ # # -------------------------------------------------------------------------- -import abc -from typing import TypeVar, Generic -from contextlib import AbstractContextManager # pylint: disable=unused-import - -ABC = abc.ABC +from typing import TypeVar, Generic, Dict, Any HTTPResponseType = TypeVar("HTTPResponseType") HTTPRequestType = TypeVar("HTTPRequestType") -class PipelineContext(dict): +class PipelineContext(Dict[str, Any]): """A context object carried by the pipeline request and response containers. This is transport specific and can contain data persisted between @@ -124,8 +120,7 @@ class PipelineRequest(Generic[HTTPRequestType]): :type context: ~azure.core.pipeline.PipelineContext """ - def __init__(self, http_request, context): - # type: (HTTPRequestType, PipelineContext) -> None + def __init__(self, http_request: HTTPRequestType, context: PipelineContext) -> None: self.http_request = http_request self.context = context @@ -148,8 +143,12 @@ class PipelineResponse(Generic[HTTPRequestType, HTTPResponseType]): :type context: ~azure.core.pipeline.PipelineContext """ - def __init__(self, http_request, http_response, context): - # type: (HTTPRequestType, HTTPResponseType, PipelineContext) -> None + def __init__( + self, + http_request: HTTPRequestType, + http_response: HTTPResponseType, + context: PipelineContext, + ) -> None: self.http_request = http_request self.http_response = http_response self.context = context diff --git a/sdk/core/azure-core/azure/core/pipeline/_base.py b/sdk/core/azure-core/azure/core/pipeline/_base.py index 81149840dbe8..ecb62e9677ed 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_base.py +++ b/sdk/core/azure-core/azure/core/pipeline/_base.py @@ -25,7 +25,7 @@ # -------------------------------------------------------------------------- import logging -from typing import Generic, TypeVar, List, Union, Any, Dict +from typing import Generic, TypeVar, Union, Any, List, Dict, Optional from contextlib import AbstractContextManager from azure.core.pipeline import ( PipelineRequest, @@ -52,13 +52,11 @@ class _SansIOHTTPPolicyRunner(HTTPPolicy): :type policy: ~azure.core.pipeline.policies.SansIOHTTPPolicy """ - def __init__(self, policy): - # type: (SansIOHTTPPolicy) -> None + def __init__(self, policy: SansIOHTTPPolicy) -> None: 
super(_SansIOHTTPPolicyRunner, self).__init__() self._policy = policy - def send(self, request): - # type: (PipelineRequest) -> PipelineResponse + def send(self, request: PipelineRequest) -> PipelineResponse: """Modifies the request and sends to the next policy in the chain. :param request: The PipelineRequest object. @@ -85,8 +83,7 @@ class _TransportRunner(HTTPPolicy): :param sender: The Http Transport instance. """ - def __init__(self, sender): - # type: (HttpTransportType) -> None + def __init__(self, sender: HttpTransportType) -> None: super(_TransportRunner, self).__init__() self._sender = sender @@ -124,9 +121,10 @@ class Pipeline(AbstractContextManager, Generic[HTTPRequestType, HTTPResponseType :caption: Builds the pipeline for synchronous transport. """ - def __init__(self, transport, policies=None): - # type: (HttpTransportType, PoliciesType) -> None - self._impl_policies = [] # type: List[HTTPPolicy] + def __init__( + self, transport: HttpTransportType, policies: Optional[PoliciesType] = None + ) -> None: + self._impl_policies: List[HTTPPolicy] = [] self._transport = transport for policy in policies or []: @@ -139,8 +137,7 @@ def __init__(self, transport, policies=None): if self._impl_policies: self._impl_policies[-1].next = _TransportRunner(self._transport) - def __enter__(self): - # type: () -> Pipeline + def __enter__(self) -> "Pipeline": self._transport.__enter__() # type: ignore return self @@ -148,8 +145,7 @@ def __exit__(self, *exc_details): # pylint: disable=arguments-differ self._transport.__exit__(*exc_details) @staticmethod - def _prepare_multipart_mixed_request(request): - # type: (HTTPRequestType) -> None + def _prepare_multipart_mixed_request(request: HTTPRequestType) -> None: """Will execute the multipart policies. Does nothing if "set_multipart_mixed" was never called. @@ -158,9 +154,9 @@ def _prepare_multipart_mixed_request(request): if not multipart_mixed_info: return - requests = multipart_mixed_info[0] # type: List[HTTPRequestType] - policies = multipart_mixed_info[1] # type: List[SansIOHTTPPolicy] - pipeline_options = multipart_mixed_info[3] # type: Dict[str, Any] + requests: List[HTTPRequestType] = multipart_mixed_info[0] + policies: List[SansIOHTTPPolicy] = multipart_mixed_info[1] + pipeline_options: Dict[str, Any] = multipart_mixed_info[3] # Apply on_requests concurrently to all requests import concurrent.futures @@ -180,8 +176,7 @@ def prepare_requests(req): _ for _ in executor.map(prepare_requests, requests) ] - def _prepare_multipart(self, request): - # type: (HTTPRequestType) -> None + def _prepare_multipart(self, request: HTTPRequestType) -> None: # This code is fine as long as HTTPRequestType is actually # azure.core.pipeline.transport.HTTPRequest, bu we don't check it in here # since we didn't see (yet) pipeline usage where it's not this actual instance @@ -189,8 +184,7 @@ def _prepare_multipart(self, request): self._prepare_multipart_mixed_request(request) request.prepare_multipart_body() # type: ignore - def run(self, request, **kwargs): - # type: (HTTPRequestType, Any) -> PipelineResponse + def run(self, request: HTTPRequestType, **kwargs: Any) -> PipelineResponse: """Runs the HTTP Request through the chained policies. :param request: The HTTP request object. 
@@ -200,9 +194,9 @@ def run(self, request, **kwargs): """ self._prepare_multipart(request) context = PipelineContext(self._transport, **kwargs) - pipeline_request = PipelineRequest( + pipeline_request: PipelineRequest[HTTPRequestType] = PipelineRequest( request, context - ) # type: PipelineRequest[HTTPRequestType] + ) first_node = ( self._impl_policies[0] if self._impl_policies diff --git a/sdk/core/azure-core/azure/core/pipeline/_base_async.py b/sdk/core/azure-core/azure/core/pipeline/_base_async.py index 3614e14cb2f3..7cbd458da942 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_base_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/_base_async.py @@ -23,7 +23,7 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- -from typing import Any, Union, List, Generic, TypeVar, Dict +from typing import Any, Union, Generic, TypeVar, List, Dict from contextlib import AbstractAsyncContextManager from azure.core.pipeline import PipelineRequest, PipelineResponse, PipelineContext @@ -154,9 +154,9 @@ async def _prepare_multipart_mixed_request(self, request: HTTPRequestType) -> No if not multipart_mixed_info: return - requests = multipart_mixed_info[0] # type: List[HTTPRequestType] - policies = multipart_mixed_info[1] # type: List[SansIOHTTPPolicy] - pipeline_options = multipart_mixed_info[3] # type: Dict[str, Any] + requests: List[HTTPRequestType] = multipart_mixed_info[0] + policies: List[SansIOHTTPPolicy] = multipart_mixed_info[1] + pipeline_options: Dict[str, Any] = multipart_mixed_info[3] async def prepare_requests(req): if req.multipart_mixed_info: diff --git a/sdk/core/azure-core/azure/core/pipeline/_tools.py b/sdk/core/azure-core/azure/core/pipeline/_tools.py index 0a30d6b1462c..c19460f73a2c 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_tools.py +++ b/sdk/core/azure-core/azure/core/pipeline/_tools.py @@ -40,8 +40,7 @@ def await_result(func, *args, **kwargs): return result -def is_rest(obj): - # type: (Any) -> bool +def is_rest(obj) -> bool: """Return whether a request or a response is a rest request / response. Checking whether the response has the object content can sometimes result @@ -52,8 +51,7 @@ def is_rest(obj): return hasattr(obj, "is_stream_consumed") or hasattr(obj, "content") -def handle_non_stream_rest_response(response): - # type: (RestHttpResponse) -> None +def handle_non_stream_rest_response(response: "RestHttpResponse") -> None: """Handle reading and closing of non stream rest responses. For our new rest responses, we have to call .read() and .close() for our non-stream responses. This way, we load in the body for users to access. diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py index 5dd120a3dd0b..2a70d753f33b 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py @@ -4,7 +4,7 @@ # license information. # ------------------------------------------------------------------------- import time -from typing import TYPE_CHECKING, Any, Dict, Optional # pylint:disable=unused-import +from typing import TYPE_CHECKING, Dict, Optional from . import HTTPPolicy, SansIOHTTPPolicy from ...exceptions import ServiceRequestError @@ -21,7 +21,7 @@ # pylint:disable=too-few-public-methods -class _BearerTokenCredentialPolicyBase(object): +class _BearerTokenCredentialPolicyBase: """Base class for a Bearer Token Credential Policy. 
:param credential: The credential. @@ -29,17 +29,19 @@ class _BearerTokenCredentialPolicyBase(object): :param str scopes: Lets you specify the type of access needed. """ - def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argument - # type: (TokenCredential, *str, **Any) -> None + def __init__( + self, + credential: "TokenCredential", + *scopes: str, + **kwargs # pylint:disable=unused-argument + ) -> None: super(_BearerTokenCredentialPolicyBase, self).__init__() self._scopes = scopes self._credential = credential - self._token = None # type: Optional[AccessToken] + self._token: Optional["AccessToken"] = None @staticmethod - def _enforce_https(request): - # type: (PipelineRequest) -> None - + def _enforce_https(request: "PipelineRequest") -> None: # move 'enforce_https' from options to context so it persists # across retries but isn't passed to a transport implementation option = request.context.options.pop("enforce_https", None) @@ -55,8 +57,7 @@ def _enforce_https(request): ) @staticmethod - def _update_headers(headers, token): - # type: (Dict[str, str], str) -> None + def _update_headers(headers: Dict[str, str], token: str) -> None: """Updates the Authorization header with the bearer token. :param dict headers: The HTTP Request headers @@ -65,8 +66,7 @@ def _update_headers(headers, token): headers["Authorization"] = "Bearer {}".format(token) @property - def _need_new_token(self): - # type: () -> bool + def _need_new_token(self) -> bool: return not self._token or self._token.expires_on - time.time() < 300 @@ -79,8 +79,7 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy): :raises: :class:`~azure.core.exceptions.ServiceRequestError` """ - def on_request(self, request): - # type: (PipelineRequest) -> None + def on_request(self, request: "PipelineRequest") -> None: """Called before the policy sends a request. The base implementation authorizes the request with a bearer token. @@ -93,8 +92,9 @@ def on_request(self, request): self._token = self._credential.get_token(*self._scopes) self._update_headers(request.http_request.headers, self._token.token) - def authorize_request(self, request, *scopes, **kwargs): - # type: (PipelineRequest, *str, **Any) -> None + def authorize_request( + self, request: "PipelineRequest", *scopes: str, **kwargs + ) -> None: """Acquire a token from the credential and authorize the request with it. Keyword arguments are passed to the credential's get_token method. The token will be cached and used to @@ -106,8 +106,7 @@ def authorize_request(self, request, *scopes, **kwargs): self._token = self._credential.get_token(*scopes, **kwargs) self._update_headers(request.http_request.headers, self._token.token) - def send(self, request): - # type: (PipelineRequest) -> PipelineResponse + def send(self, request: "PipelineRequest") -> "PipelineResponse": """Authorize request with a bearer token and send it to the next policy :param request: The pipeline request object @@ -135,8 +134,9 @@ def send(self, request): return response - def on_challenge(self, request, response): - # type: (PipelineRequest, PipelineResponse) -> bool + def on_challenge( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> bool: """Authorize request according to an authentication challenge This method is called when the resource provider responds 401 with a WWW-Authenticate header. 
@@ -148,8 +148,9 @@ def on_challenge(self, request, response): # pylint:disable=unused-argument,no-self-use return False - def on_response(self, request, response): - # type: (PipelineRequest, PipelineResponse) -> None + def on_response( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> None: """Executed after the request comes back from the next policy. :param request: Request to be modified after returning from the policy. @@ -158,8 +159,7 @@ def on_response(self, request, response): :type response: ~azure.core.pipeline.PipelineResponse """ - def on_exception(self, request): - # type: (PipelineRequest) -> None + def on_exception(self, request: "PipelineRequest") -> None: """Executed when an exception is raised while executing the next policy. This method is executed inside the exception handler. @@ -180,8 +180,12 @@ class AzureKeyCredentialPolicy(SansIOHTTPPolicy): :raises: ValueError or TypeError """ - def __init__(self, credential, name, **kwargs): # pylint: disable=unused-argument - # type: (AzureKeyCredential, str, **Any) -> None + def __init__( + self, + credential: "AzureKeyCredential", + name: str, + **kwargs # pylint: disable=unused-argument + ) -> None: super(AzureKeyCredentialPolicy, self).__init__() self._credential = credential if not name: @@ -202,8 +206,11 @@ class AzureSasCredentialPolicy(SansIOHTTPPolicy): :raises: ValueError or TypeError """ - def __init__(self, credential, **kwargs): # pylint: disable=unused-argument - # type: (AzureSasCredential, **Any) -> None + def __init__( + self, + credential: "AzureSasCredential", + **kwargs # pylint: disable=unused-argument + ) -> None: super(AzureSasCredentialPolicy, self).__init__() if not credential: raise ValueError("credential can not be None") diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py index 9a6f84c0b267..9bd4730e569b 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py @@ -5,7 +5,7 @@ # ------------------------------------------------------------------------- import asyncio import time -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Awaitable, Optional from azure.core.pipeline.policies import AsyncHTTPPolicy from azure.core.pipeline.policies._authentication import ( @@ -15,7 +15,6 @@ from .._tools_async import await_result if TYPE_CHECKING: - from typing import Any, Awaitable, Optional, Union from azure.core.credentials import AccessToken from azure.core.credentials_async import AsyncTokenCredential from azure.core.pipeline import PipelineRequest, PipelineResponse @@ -30,14 +29,14 @@ class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy): """ def __init__( - self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: "Any" + self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: Any ) -> None: # pylint:disable=unused-argument super().__init__() self._credential = credential self._lock = asyncio.Lock() self._scopes = scopes - self._token = None # type: Optional[AccessToken] + self._token: Optional["AccessToken"] = None async def on_request( self, request: "PipelineRequest" @@ -60,7 +59,7 @@ async def on_request( request.http_request.headers["Authorization"] = "Bearer " + self._token.token async def authorize_request( - self, request: "PipelineRequest", *scopes: str, **kwargs: "Any" + self, request: "PipelineRequest", *scopes: 
str, **kwargs: Any ) -> None: """Acquire a token from the credential and authorize the request with it. @@ -120,7 +119,7 @@ async def on_challenge( def on_response( self, request: "PipelineRequest", response: "PipelineResponse" - ) -> "Union[None, Awaitable[None]]": + ) -> Optional[Awaitable[None]]: """Executed after the request comes back from the next policy. :param request: Request to be modified after returning from the policy. diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_base.py b/sdk/core/azure-core/azure/core/pipeline/policies/_base.py index 0c588f8a46d5..8ed9fbb17ebd 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_base.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_base.py @@ -29,29 +29,29 @@ import logging from typing import ( + TYPE_CHECKING, Generic, TypeVar, Union, Any, - Dict, Optional, -) # pylint: disable=unused-import + Awaitable, + Dict, +) -try: - from typing import Awaitable # pylint: disable=unused-import -except ImportError: - pass +from azure.core.pipeline import PipelineRequest, PipelineResponse -from azure.core.pipeline import ABC, PipelineRequest, PipelineResponse +if TYPE_CHECKING: + from azure.core.pipeline.transport import HttpTransport -HTTPResponseType = TypeVar("HTTPResponseType") -HTTPRequestType = TypeVar("HTTPRequestType") +HTTPResponseTypeVar = TypeVar("HTTPResponseTypeVar") +HTTPRequestTypeVar = TypeVar("HTTPRequestTypeVar") _LOGGER = logging.getLogger(__name__) -class HTTPPolicy(ABC, Generic[HTTPRequestType, HTTPResponseType]): +class HTTPPolicy(abc.ABC, Generic[HTTPRequestTypeVar, HTTPResponseTypeVar]): """An HTTP policy ABC. Use with a synchronous pipeline. @@ -61,12 +61,16 @@ class HTTPPolicy(ABC, Generic[HTTPRequestType, HTTPResponseType]): :type next: ~azure.core.pipeline.policies.HTTPPolicy or ~azure.core.pipeline.transport.HttpTransport """ - def __init__(self): - self.next = None # type: Union[HTTPPolicy, HttpTransport] + next: Union[ + "HTTPPolicy[HTTPRequestTypeVar, HTTPResponseTypeVar]", + "HttpTransport[HTTPRequestTypeVar, HTTPResponseTypeVar]", + ] + """Pointer to the next policy or a transport. Will be set at pipeline creation.""" @abc.abstractmethod - def send(self, request): - # type: (PipelineRequest) -> PipelineResponse + def send( + self, request: PipelineRequest[HTTPRequestTypeVar] + ) -> PipelineResponse[HTTPRequestTypeVar, HTTPResponseTypeVar]: """Abstract send method for a synchronous pipeline. Mutates the request. Context content is dependent on the HttpTransport. @@ -78,7 +82,7 @@ def send(self, request): """ -class SansIOHTTPPolicy(Generic[HTTPRequestType, HTTPResponseType]): +class SansIOHTTPPolicy(Generic[HTTPRequestTypeVar, HTTPResponseTypeVar]): """Represents a sans I/O policy. SansIOHTTPPolicy is a base class for policies that only modify or @@ -90,16 +94,20 @@ class SansIOHTTPPolicy(Generic[HTTPRequestType, HTTPResponseType]): but they will then be tied to AsyncPipeline usage. """ - def on_request(self, request): - # type: (PipelineRequest) -> Union[None, Awaitable[None]] + def on_request( + self, request: PipelineRequest[HTTPRequestTypeVar] + ) -> Union[None, Awaitable[None]]: """Is executed before sending the request from next policy. :param request: Request to be modified before sent from next policy. 
:type request: ~azure.core.pipeline.PipelineRequest """ - def on_response(self, request, response): - # type: (PipelineRequest, PipelineResponse) -> Union[None, Awaitable[None]] + def on_response( + self, + request: PipelineRequest[HTTPRequestTypeVar], + response: PipelineResponse[HTTPRequestTypeVar, HTTPResponseTypeVar], + ) -> Union[None, Awaitable[None]]: """Is executed after the request comes back from the policy. :param request: Request to be modified after returning from the policy. @@ -109,8 +117,10 @@ def on_response(self, request, response): """ # pylint: disable=no-self-use - def on_exception(self, request): # pylint: disable=unused-argument - # type: (PipelineRequest) -> None + def on_exception( + self, + request: PipelineRequest[HTTPRequestTypeVar], # pylint: disable=unused-argument + ) -> None: """Is executed if an exception is raised while executing the next policy. This method is executed inside the exception handler. @@ -129,7 +139,7 @@ def on_exception(self, request): # pylint: disable=unused-argument return -class RequestHistory(object): +class RequestHistory: """A container for an attempted request and the applicable response. This is used to document requests/responses that resulted in redirected/retried requests. @@ -142,8 +152,13 @@ class RequestHistory(object): :param dict context: The pipeline context. """ - def __init__(self, http_request, http_response=None, error=None, context=None): - # type: (HTTPRequestType, Optional[HTTPResponseType], Exception, Optional[Dict[str, Any]]) -> None + def __init__( + self, + http_request: HTTPRequestTypeVar, + http_response: Optional[HTTPResponseTypeVar] = None, + error: Optional[Exception] = None, + context: Optional[Dict[str, Any]] = None, + ) -> None: self.http_request = copy.deepcopy(http_request) self.http_response = http_response self.error = error diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_base_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_base_async.py index f3718873a905..35c38488710b 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_base_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_base_async.py @@ -25,17 +25,20 @@ # -------------------------------------------------------------------------- import abc -from typing import Generic, TypeVar, Union, Any, cast +from typing import TYPE_CHECKING, Generic, TypeVar, Union -from azure.core.pipeline import PipelineRequest +from .. import PipelineRequest +if TYPE_CHECKING: + from ..transport._base_async import AsyncHttpTransport -AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType") -HTTPResponseType = TypeVar("HTTPResponseType") -HTTPRequestType = TypeVar("HTTPRequestType") +AsyncHTTPResponseTypeVar = TypeVar("AsyncHTTPResponseTypeVar") +HTTPResponseTypeVar = TypeVar("HTTPResponseTypeVar") +HTTPRequestTypeVar = TypeVar("HTTPRequestTypeVar") -class AsyncHTTPPolicy(abc.ABC, Generic[HTTPRequestType, AsyncHTTPResponseType]): + +class AsyncHTTPPolicy(abc.ABC, Generic[HTTPRequestTypeVar, AsyncHTTPResponseTypeVar]): """An async HTTP policy ABC. Use with an asynchronous pipeline. 
@@ -45,14 +48,14 @@ class AsyncHTTPPolicy(abc.ABC, Generic[HTTPRequestType, AsyncHTTPResponseType]): :type next: ~azure.core.pipeline.policies.AsyncHTTPPolicy or ~azure.core.pipeline.transport.AsyncHttpTransport """ - def __init__(self) -> None: - # next will be set once in the pipeline - from ..transport._base_async import AsyncHttpTransport - - self.next = cast(Union[AsyncHTTPPolicy, AsyncHttpTransport], None) + next: Union[ + "AsyncHTTPPolicy[HTTPRequestTypeVar, AsyncHTTPResponseTypeVar]", + "AsyncHttpTransport[HTTPRequestTypeVar, AsyncHTTPResponseTypeVar]", + ] + """Pointer to the next policy or a transport. Will be set at pipeline creation.""" @abc.abstractmethod - async def send(self, request: PipelineRequest): + async def send(self, request: PipelineRequest[HTTPRequestTypeVar]): """Abstract send method for a asynchronous pipeline. Mutates the request. Context content is dependent on the HttpTransport. diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_custom_hook.py b/sdk/core/azure-core/azure/core/pipeline/policies/_custom_hook.py index 1e17c2c30c99..c8c34eb16ef4 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_custom_hook.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_custom_hook.py @@ -41,8 +41,9 @@ def __init__( self._request_callback = kwargs.get("raw_request_hook") self._response_callback = kwargs.get("raw_response_hook") - def on_request(self, request): # pylint: disable=arguments-differ - # type: (PipelineRequest) -> None + def on_request( + self, request: PipelineRequest + ) -> None: # pylint: disable=arguments-differ """This is executed before sending the request to the next policy. :param request: The PipelineRequest object. @@ -59,8 +60,9 @@ def on_request(self, request): # pylint: disable=arguments-differ if response_callback: request.context["raw_response_hook"] = response_callback - def on_response(self, request, response): # pylint: disable=arguments-differ - # type: (PipelineRequest, PipelineResponse) -> None + def on_response( + self, request: PipelineRequest, response: PipelineResponse + ) -> None: # pylint: disable=arguments-differ """This is executed after the request comes back from the policy. :param request: The PipelineRequest object. diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_distributed_tracing.py b/sdk/core/azure-core/azure/core/pipeline/policies/_distributed_tracing.py index 787f1f5cf004..05f31620ffd0 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_distributed_tracing.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_distributed_tracing.py @@ -27,12 +27,7 @@ import logging import sys import urllib -from typing import ( - TYPE_CHECKING, - Optional, - Union, - Tuple, -) # pylint: disable=ungrouped-imports +from typing import TYPE_CHECKING, Optional, Union, Tuple from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.core.settings import settings @@ -44,22 +39,21 @@ HttpRequest, HttpResponse, AsyncHttpResponse, - ) # pylint: disable=ungrouped-imports + ) from azure.core.tracing._abstract_span import ( AbstractSpan, - ) # pylint: disable=ungrouped-imports + ) from azure.core.pipeline import ( PipelineRequest, PipelineResponse, - ) # pylint: disable=ungrouped-imports + ) HttpResponseType = Union[HttpResponse, AsyncHttpResponse] _LOGGER = logging.getLogger(__name__) -def _default_network_span_namer(http_request): - # type (HttpRequest) -> str +def _default_network_span_namer(http_request: "HttpRequest") -> str: """Extract the path to be used as network span name. 
:param http_request: The HTTP request @@ -92,8 +86,7 @@ def __init__(self, **kwargs): ) self._tracing_attributes = kwargs.get("tracing_attributes", {}) - def on_request(self, request): - # type: (PipelineRequest) -> None + def on_request(self, request: "PipelineRequest") -> None: ctxt = request.context.options try: span_impl_type = settings.tracing_implementation() @@ -115,14 +108,18 @@ def on_request(self, request): except Exception as err: # pylint: disable=broad-except _LOGGER.warning("Unable to start network span: %s", err) - def end_span(self, request, response=None, exc_info=None): - # type: (PipelineRequest, Optional[HttpResponseType], Optional[Tuple]) -> None + def end_span( + self, + request: "PipelineRequest", + response: Optional["HttpResponseType"] = None, + exc_info: Optional[Tuple] = None, + ) -> None: """Ends the span that is tracing the network and updates its status.""" if self.TRACING_CONTEXT not in request.context: return - span = request.context[self.TRACING_CONTEXT] # type: AbstractSpan - http_request = request.http_request # type: HttpRequest + span: "AbstractSpan" = request.context[self.TRACING_CONTEXT] + http_request: "HttpRequest" = request.http_request if span is not None: span.set_http_attributes(http_request, response=response) request_id = http_request.headers.get(self._REQUEST_ID) @@ -137,10 +134,10 @@ def end_span(self, request, response=None, exc_info=None): else: span.finish() - def on_response(self, request, response): - # type: (PipelineRequest, PipelineResponse) -> None + def on_response( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> None: self.end_span(request, response=response.http_response) - def on_exception(self, request): - # type: (PipelineRequest) -> None + def on_exception(self, request: "PipelineRequest") -> None: self.end_span(request, exc_info=sys.exc_info()) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_redirect.py b/sdk/core/azure-core/azure/core/pipeline/policies/_redirect.py index d1cd038ecb32..823216343d07 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_redirect.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_redirect.py @@ -27,7 +27,7 @@ This module is the requests implementation of Pipeline ABC """ import logging -from urllib.parse import urlparse # type: ignore +from urllib.parse import urlparse from azure.core.exceptions import TooManyRedirectsError @@ -37,7 +37,7 @@ _LOGGER = logging.getLogger(__name__) -class RedirectPolicyBase(object): +class RedirectPolicyBase: REDIRECT_STATUSES = frozenset([300, 301, 302, 303, 307, 308]) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_retry.py b/sdk/core/azure-core/azure/core/pipeline/policies/_retry.py index bc46dc6d8605..22e9d291f23e 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_retry.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_retry.py @@ -31,16 +31,6 @@ import logging import time from enum import Enum -from typing import ( # pylint: disable=unused-import - TYPE_CHECKING, - List, - Callable, - Iterator, - Any, - Union, - Dict, - Optional, -) from azure.core.pipeline import PipelineResponse from azure.core.exceptions import ( AzureError, @@ -64,7 +54,7 @@ class RetryMode(str, Enum, metaclass=CaseInsensitiveEnumMeta): Fixed = "fixed" -class RetryPolicyBase(object): +class RetryPolicyBase: # pylint: disable=too-many-instance-attributes #: Maximum backoff time. 
BACKOFF_MAX = 120 diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_universal.py b/sdk/core/azure-core/azure/core/pipeline/policies/_universal.py index d0bbd96b7fda..fea1da1c7d77 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_universal.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_universal.py @@ -35,19 +35,9 @@ import types import re import uuid -from typing import ( - IO, - TypeVar, - TYPE_CHECKING, - Type, - cast, - Any, - Union, - Dict, - Optional, - AnyStr, -) -import urllib +from typing import IO, cast, Union, Optional, AnyStr, Dict, MutableMapping +import urllib.parse +from typing_extensions import Protocol from azure.core import __version__ as azcore_version from azure.core.exceptions import DecodeError, raise_with_traceback @@ -55,18 +45,39 @@ from azure.core.pipeline import PipelineRequest, PipelineResponse from ._base import SansIOHTTPPolicy -if TYPE_CHECKING: - from azure.core.pipeline.transport import HttpResponse, AsyncHttpResponse _LOGGER = logging.getLogger(__name__) -ContentDecodePolicyType = TypeVar( - "ContentDecodePolicyType", bound="ContentDecodePolicy" -) -HTTPRequestType = TypeVar("HTTPRequestType") -HTTPResponseType = TypeVar("HTTPResponseType") -class HeadersPolicy(SansIOHTTPPolicy): +class HTTPRequestType(Protocol): + """Protocol compatible with new rest request and legacy transport request""" + + headers: MutableMapping[str, str] + url: str + method: str + body: bytes + + +class HTTPResponseType(Protocol): + """Protocol compatible with new rest response and legacy transport response""" + + @property + def headers(self) -> MutableMapping[str, str]: + ... + + @property + def status_code(self) -> int: + ... + + @property + def content_type(self) -> Optional[str]: + ... + + def text(self, encoding: Optional[str] = None) -> str: + ... + + +class HeadersPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): """A simple policy that sends the given headers with the request. This will overwrite any headers already defined in the request. Headers can be @@ -86,10 +97,9 @@ class HeadersPolicy(SansIOHTTPPolicy): """ def __init__( - self, base_headers=None, **kwargs - ): # pylint: disable=super-init-not-called - # type: (Dict[str, str], Any) -> None - self._headers = base_headers or {} + self, base_headers: Optional[Dict[str, str]] = None, **kwargs + ) -> None: # pylint: disable=super-init-not-called + self._headers: Dict[str, str] = base_headers or {} self._headers.update(kwargs.pop("headers", {})) @property @@ -97,7 +107,7 @@ def headers(self): """The current headers collection.""" return self._headers - def add_header(self, key, value): + def add_header(self, key: str, value: str) -> None: """Add a header to the configuration to be applied to all requests. :param str key: The header. @@ -105,8 +115,7 @@ def add_header(self, key, value): """ self._headers[key] = value - def on_request(self, request): - # type: (PipelineRequest) -> None + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: """Updates with the given headers before sending the request to the next policy. :param request: The PipelineRequest object @@ -118,11 +127,11 @@ def on_request(self, request): request.http_request.headers.update(additional_headers) -class _Unset(object): +class _Unset: pass -class RequestIdPolicy(SansIOHTTPPolicy): +class RequestIdPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): """A simple policy that sets the given request id in the header. 
This will overwrite request id that is already defined in the request. Request id can be @@ -142,20 +151,18 @@ class RequestIdPolicy(SansIOHTTPPolicy): :caption: Configuring a request id policy. """ - def __init__(self, **kwargs): # pylint: disable=super-init-not-called - # type: (dict) -> None + def __init__(self, **kwargs) -> None: # pylint: disable=super-init-not-called self._request_id = kwargs.pop("request_id", _Unset) self._auto_request_id = kwargs.pop("auto_request_id", True) - def set_request_id(self, value): + def set_request_id(self, value: str) -> None: """Add the request id to the configuration to be applied to all requests. :param str value: The request id value. """ self._request_id = value - def on_request(self, request): - # type: (PipelineRequest) -> None + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: """Updates with the given request id before sending the request to the next policy. :param request: The PipelineRequest object @@ -177,11 +184,11 @@ def on_request(self, request): return request_id = str(uuid.uuid1()) if request_id is not unset: - header = {"x-ms-client-request-id": request_id} + header = {"x-ms-client-request-id": cast(str, request_id)} request.http_request.headers.update(header) -class UserAgentPolicy(SansIOHTTPPolicy): +class UserAgentPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): """User-Agent Policy. Allows custom values to be added to the User-Agent header. :param str base_user_agent: Sets the base user agent value. @@ -206,9 +213,8 @@ class UserAgentPolicy(SansIOHTTPPolicy): _ENV_ADDITIONAL_USER_AGENT = "AZURE_HTTP_USER_AGENT" def __init__( - self, base_user_agent=None, **kwargs - ): # pylint: disable=super-init-not-called - # type: (Optional[str], **Any) -> None + self, base_user_agent: Optional[str] = None, **kwargs + ) -> None: # pylint: disable=super-init-not-called self.overwrite = kwargs.pop("user_agent_overwrite", False) self.use_env = kwargs.pop("user_agent_use_env", True) application_id = kwargs.pop("user_agent", None) @@ -225,8 +231,7 @@ def __init__( self._user_agent = "{} {}".format(application_id, self._user_agent) @property - def user_agent(self): - # type: () -> str + def user_agent(self) -> str: """The current user agent value.""" if self.use_env: add_user_agent_header = os.environ.get( @@ -236,15 +241,13 @@ def user_agent(self): return "{} {}".format(self._user_agent, add_user_agent_header) return self._user_agent - def add_user_agent(self, value): - # type: (str) -> None + def add_user_agent(self, value: str) -> None: """Add value to current user agent with a space. :param str value: value to add to user agent. """ self._user_agent = "{} {}".format(self._user_agent, value) - def on_request(self, request): - # type: (PipelineRequest) -> None + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: """Modifies the User-Agent header before the request is sent. :param request: The PipelineRequest object @@ -264,7 +267,7 @@ def on_request(self, request): http_request.headers[self._USERAGENT] = self.user_agent -class NetworkTraceLoggingPolicy(SansIOHTTPPolicy): +class NetworkTraceLoggingPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): """The logging policy in the pipeline is used to output HTTP network trace to the configured logger. 
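# Usage sketch for RequestIdPolicy and UserAgentPolicy as retyped above; the
# keyword names (request_id, auto_request_id, base_user_agent, user_agent) are
# the ones documented here, the values are illustrative.
from azure.core.pipeline.policies import RequestIdPolicy, UserAgentPolicy

request_id_policy = RequestIdPolicy(request_id="my-correlation-id", auto_request_id=False)

user_agent_policy = UserAgentPolicy(base_user_agent="azsdk-python-sample/1.0", user_agent="myapp/2.0")
user_agent_policy.add_user_agent("extra-component/3.1")
# Combines the application id, the base value, and AZURE_HTTP_USER_AGENT if set.
print(user_agent_policy.user_agent)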
@@ -287,8 +290,9 @@ def __init__( ): # pylint: disable=unused-argument self.enable_http_logger = logging_enable - def on_request(self, request): # pylint: disable=too-many-return-statements - # type: (PipelineRequest) -> None + def on_request( + self, request: PipelineRequest[HTTPRequestType] + ) -> None: # pylint: disable=too-many-return-statements """Logs HTTP request to the DEBUG logger. :param request: The PipelineRequest object. @@ -331,8 +335,11 @@ def on_request(self, request): # pylint: disable=too-many-return-statements except Exception as err: # pylint: disable=broad-except _LOGGER.debug("Failed to log request: %r", err) - def on_response(self, request, response): - # type: (PipelineRequest, PipelineResponse) -> None + def on_response( + self, + request: PipelineRequest[HTTPRequestType], + response: PipelineResponse[HTTPRequestType, HTTPResponseType], + ) -> None: """Logs HTTP response to the DEBUG logger. :param request: The PipelineRequest object. @@ -389,7 +396,10 @@ def DEFAULT_HEADERS_WHITELIST(cls, value): cls.DEFAULT_HEADERS_ALLOWLIST = value -class HttpLoggingPolicy(SansIOHTTPPolicy, metaclass=_HiddenClassProperties): +class HttpLoggingPolicy( + SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType], + metaclass=_HiddenClassProperties, +): """The Pipeline policy that handles logging of HTTP requests and responses.""" DEFAULT_HEADERS_ALLOWLIST = set( @@ -451,8 +461,9 @@ def _redact_header(self, key, value): else HttpLoggingPolicy.REDACTED_PLACEHOLDER ) - def on_request(self, request): # pylint: disable=too-many-return-statements - # type: (PipelineRequest) -> None + def on_request( # pylint: disable=too-many-return-statements + self, request: PipelineRequest[HTTPRequestType] + ) -> None: """Logs HTTP method, url and headers. :param request: The PipelineRequest object. :type request: ~azure.core.pipeline.PipelineRequest @@ -528,13 +539,23 @@ def on_request(self, request): # pylint: disable=too-many-return-statements except Exception as err: # pylint: disable=broad-except logger.warning("Failed to log request: %s", repr(err)) - def on_response(self, request, response): - # type: (PipelineRequest, PipelineResponse) -> None + def on_response( + self, + request: PipelineRequest[HTTPRequestType], + response: PipelineResponse[HTTPRequestType, HTTPResponseType], + ) -> None: http_response = response.http_response - try: - logger = response.context["logger"] + # Get logger in my context first (request has been retried) + # then read from kwargs (pop if that's the case) + # then use my instance logger + # If on_request was called, should always read from context + options = request.context.options + logger = request.context.setdefault( + "logger", options.pop("logger", self.logger) + ) + try: if not logger.isEnabledFor(logging.INFO): return @@ -556,7 +577,7 @@ def on_response(self, request, response): logger.warning("Failed to log response: %s", repr(err)) -class ContentDecodePolicy(SansIOHTTPPolicy): +class ContentDecodePolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): """Policy for decoding unstreamed response content. 
:param response_encoding: The encoding to use if known for this service (will disable auto-detection) @@ -570,17 +591,18 @@ class ContentDecodePolicy(SansIOHTTPPolicy): CONTEXT_NAME = "deserialized_data" def __init__( - self, response_encoding=None, **kwargs - ): # pylint: disable=unused-argument - # type: (Optional[str], Any) -> None + self, + response_encoding: Optional[str] = None, + **kwargs # pylint: disable=unused-argument + ) -> None: self._response_encoding = response_encoding @classmethod def deserialize_from_text( - cls, # type: Type[ContentDecodePolicyType] - data, # type: Optional[Union[AnyStr, IO]] - mime_type=None, # Optional[str] - response=None, # Optional[Union[HttpResponse, AsyncHttpResponse]] + cls, + data: Optional[Union[AnyStr, IO]], + mime_type: Optional[str] = None, + response: Optional[HTTPResponseType] = None, ): """Decode response data according to content-type. @@ -650,9 +672,9 @@ def _json_attemp(data): @classmethod def deserialize_from_http_generics( - cls, # type: Type[ContentDecodePolicyType] - response, # Union[HttpResponse, AsyncHttpResponse] - encoding=None, # Optional[str] + cls, + response: HTTPResponseType, + encoding: Optional[str] = None, ): """Deserialize from HTTP response. @@ -675,23 +697,20 @@ def deserialize_from_http_generics( # Rely on transport implementation to give me "text()" decoded correctly if hasattr(response, "read"): - try: - # since users can call deserialize_from_http_generics by themselves - # we want to make sure our new responses are read before we try to - # deserialize. Only read sync responses since we're in a sync function - if not inspect.iscoroutinefunction(response.read): - response.read() - except AttributeError: - # raises an AttributeError in 2.7 bc inspect.iscoroutinefunction was added in 3.5 - # Entering here means it's 2.7 and that the response has a read method, so we read - # bc it will be sync. - response.read() + # since users can call deserialize_from_http_generics by themselves + # we want to make sure our new responses are read before we try to + # deserialize. Only read sync responses since we're in a sync function + # + # Technically HttpResponse do not contain a "read()", but we don't know what + # people have been able to pass here, so keep this code for safety, + # even if it's likely dead code + if not inspect.iscoroutinefunction(response.read): # type: ignore + response.read() # type: ignore return cls.deserialize_from_text( response.text(encoding), mime_type, response=response ) - def on_request(self, request): - # type: (PipelineRequest) -> None + def on_request(self, request: PipelineRequest) -> None: options = request.context.options response_encoding = options.pop("response_encoding", self._response_encoding) if response_encoding: @@ -699,10 +718,9 @@ def on_request(self, request): def on_response( self, - request, # type: PipelineRequest[HTTPRequestType] - response, # type: PipelineResponse[HTTPRequestType, Union[HttpResponse, AsyncHttpResponse]] - ): - # type: (...) -> None + request: PipelineRequest[HTTPRequestType], + response: PipelineResponse[HTTPRequestType, HTTPResponseType], + ) -> None: """Extract data from the body of a REST response object. This will load the entire payload in memory. Will follow Content-Type to parse. @@ -730,7 +748,7 @@ def on_response( ) -class ProxyPolicy(SansIOHTTPPolicy): +class ProxyPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): """A proxy policy. 
Dictionary mapping protocol or protocol and host to the URL of the proxy @@ -754,8 +772,7 @@ def __init__( ): # pylint: disable=unused-argument,super-init-not-called self.proxies = proxies - def on_request(self, request): - # type: (PipelineRequest) -> None + def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: ctxt = request.context.options if self.proxies and "proxies" not in ctxt: ctxt["proxies"] = self.proxies diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py index d54a8ee38fee..1a5a568a6325 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py @@ -24,37 +24,34 @@ # # -------------------------------------------------------------------------- import abc +from collections.abc import MutableMapping +from contextlib import AbstractContextManager from email.message import Message import json import logging import time import copy -from urllib.parse import urlparse # type: ignore +from urllib.parse import urlparse import xml.etree.ElementTree as ET from typing import ( - TYPE_CHECKING, Generic, TypeVar, IO, - List, Union, Any, Mapping, - Dict, Optional, Tuple, Iterator, Type, + Dict, + List, ) from http.client import HTTPResponse as _HTTPResponse from azure.core.exceptions import HttpResponseError -from azure.core.pipeline import ( - ABC, - AbstractContextManager, -) from ...utils._utils import case_insensitive_dict from ...utils._pipeline_transport_rest_shared import ( _format_parameters_helper, @@ -68,12 +65,10 @@ ) -if TYPE_CHECKING: - from collections.abc import MutableMapping - HTTPResponseType = TypeVar("HTTPResponseType") HTTPRequestType = TypeVar("HTTPRequestType") PipelineType = TypeVar("PipelineType") +DataType = Optional[Union[bytes, Dict[str, Union[str, int]]]] _LOGGER = logging.getLogger(__name__) @@ -106,8 +101,7 @@ def _format_url_section(template, **kwargs): # No URL sections left - returning None -def _urljoin(base_url, stub_url): - # type: (str, str) -> str +def _urljoin(base_url: str, stub_url: str) -> str: """Append to end of base URL without losing query parameters. :param str base_url: The base URL. @@ -121,13 +115,12 @@ def _urljoin(base_url, stub_url): class HttpTransport( - AbstractContextManager, ABC, Generic[HTTPRequestType, HTTPResponseType] -): # type: ignore + AbstractContextManager, abc.ABC, Generic[HTTPRequestType, HTTPResponseType] +): """An http sender ABC.""" @abc.abstractmethod - def send(self, request, **kwargs): - # type: (HTTPRequestType, Any) -> HTTPResponseType + def send(self, request: HTTPRequestType, **kwargs) -> HTTPResponseType: """Send the request using this HTTP sender. :param request: The pipeline request object @@ -148,7 +141,7 @@ def sleep(self, duration): # pylint: disable=no-self-use time.sleep(duration) -class HttpRequest(object): +class HttpRequest: """Represents a HTTP request. URL can be given without query parameters, to be added later using "format_parameters". @@ -158,17 +151,23 @@ class HttpRequest(object): :param dict[str,str] headers: HTTP headers :param files: Files list. :param data: Body to be sent. - :type data: bytes or str. 
+ :type data: bytes or dict (for form) """ - def __init__(self, method, url, headers=None, files=None, data=None): - # type: (str, str, Mapping[str, str], Any, Any) -> None + def __init__( + self, + method: str, + url: str, + headers: Optional[Mapping[str, str]] = None, + files: Optional[Any] = None, + data: Optional[DataType] = None, + ) -> None: self.method = method self.url = url self.headers = case_insensitive_dict(headers) self.files = files self.data = data - self.multipart_mixed_info = None # type: Optional[Tuple] + self.multipart_mixed_info: Optional[Tuple] = None def __repr__(self): return "".format(self.method, self.url) @@ -184,7 +183,7 @@ def __deepcopy__(self, memo=None): return copy.copy(self) @property - def query(self): + def query(self) -> Dict[str, str]: """The query parameters of the request as a dict. :rtype: dict[str, str] @@ -195,20 +194,21 @@ def query(self): return {} @property - def body(self): + def body(self) -> DataType: """Alias to data. - :rtype: bytes + :rtype: bytes or dict """ return self.data @body.setter - def body(self, value): + def body(self, value: DataType): self.data = value @staticmethod - def _format_data(data): - # type: (Union[str, IO]) -> Union[Tuple[None, str], Tuple[Optional[str], IO, str]] + def _format_data( + data: Union[str, IO] + ) -> Union[Tuple[None, str], Tuple[Optional[str], IO, str]]: """Format field data according to whether it is a stream or a string for a form-data request. @@ -217,8 +217,7 @@ def _format_data(data): """ return _format_data_helper(data) - def format_parameters(self, params): - # type: (Dict[str, str]) -> None + def format_parameters(self, params: Dict[str, str]) -> None: """Format parameters into a valid query string. It's assumed all parameters have already been quoted as valid URL strings. @@ -313,8 +312,7 @@ def set_bytes_body(self, data): self.data = data self.files = None - def set_multipart_mixed(self, *requests, **kwargs): - # type: (HttpRequest, Any) -> None + def set_multipart_mixed(self, *requests: "HttpRequest", **kwargs) -> None: """Set the part of a multipart/mixed. Only supported args for now are HttpRequest objects. @@ -337,8 +335,7 @@ def set_multipart_mixed(self, *requests, **kwargs): kwargs, ) - def prepare_multipart_body(self, content_index=0): - # type: (int) -> int + def prepare_multipart_body(self, content_index: int = 0) -> int: """Will prepare the body of this request according to the multipart information. This call assumes the on_request policies have been applied already in their @@ -352,8 +349,7 @@ def prepare_multipart_body(self, content_index=0): """ return _prepare_multipart_body_helper(self, content_index) - def serialize(self): - # type: () -> bytes + def serialize(self) -> bytes: """Serialize this request using application/http spec. :rtype: bytes @@ -361,7 +357,7 @@ def serialize(self): return _serialize_request(self) -class _HttpResponseBase(object): +class _HttpResponseBase: """Represent a HTTP response. No body is defined here on purpose, since async pipeline @@ -374,23 +370,25 @@ class _HttpResponseBase(object): :param int block_size: Defaults to 4096 bytes. 
""" - def __init__(self, request, internal_response, block_size=None): - # type: (HttpRequest, Any, Optional[int]) -> None + def __init__( + self, + request: HttpRequest, + internal_response: Any, + block_size: Optional[int] = None, + ) -> None: self.request = request self.internal_response = internal_response - self.status_code = None # type: Optional[int] - self.headers = {} # type: MutableMapping[str, str] - self.reason = None # type: Optional[str] - self.content_type = None # type: Optional[str] - self.block_size = block_size or 4096 # Default to same as Requests + self.status_code: Optional[int] = None + self.headers: MutableMapping[str, str] = {} + self.reason: Optional[str] = None + self.content_type: Optional[str] = None + self.block_size: int = block_size or 4096 # Default to same as Requests - def body(self): - # type: () -> bytes + def body(self) -> bytes: """Return the whole body as bytes in memory.""" raise NotImplementedError() - def text(self, encoding=None): - # type: (str) -> str + def text(self, encoding: Optional[str] = None) -> str: """Return the whole body as a string. :param str encoding: The encoding to apply. If None, use "utf-8" with BOM parsing (utf-8-sig). @@ -400,15 +398,20 @@ def text(self, encoding=None): encoding = "utf-8-sig" return self.body().decode(encoding) - def _decode_parts(self, message, http_response_type, requests): - # type: (Message, Type[_HttpResponseBase], List[HttpRequest]) -> List[HttpResponse] + def _decode_parts( + self, + message: Message, + http_response_type: Type["_HttpResponseBase"], + requests: List[HttpRequest], + ) -> List["HttpResponse"]: """Rebuild an HTTP response from pure string.""" return _decode_parts_helper( self, message, http_response_type, requests, _deserialize_response ) - def _get_raw_parts(self, http_response_type=None): - # type (Optional[Type[_HttpResponseBase]]) -> Iterator[HttpResponse] + def _get_raw_parts( + self, http_response_type: Optional[Type["_HttpResponseBase"]] = None + ) -> Iterator["HttpResponse"]: """Assuming this body is multipart, return the iterator or parts. If parts are application/http use http_response_type or HttpClientTransportResponse @@ -418,12 +421,11 @@ def _get_raw_parts(self, http_response_type=None): self, http_response_type or HttpClientTransportResponse ) - def raise_for_status(self): - # type () -> None + def raise_for_status(self) -> None: """Raises an HttpResponseError if the response has an error status code. If response is good, does nothing. """ - if self.status_code >= 400: + if not self.status_code or self.status_code >= 400: raise HttpResponseError(response=self) def __repr__(self): @@ -437,8 +439,7 @@ def __repr__(self): class HttpResponse(_HttpResponseBase): # pylint: disable=abstract-method - def stream_download(self, pipeline, **kwargs): - # type: (PipelineType, **Any) -> Iterator[bytes] + def stream_download(self, pipeline: PipelineType, **kwargs) -> Iterator[bytes]: """Generator for streaming request body data. Should be implemented by sub-classes if streaming download @@ -447,8 +448,7 @@ def stream_download(self, pipeline, **kwargs): :rtype: iterator[bytes] """ - def parts(self): - # type: () -> Iterator[HttpResponse] + def parts(self) -> Iterator["HttpResponse"]: """Assuming the content-type is multipart/mixed, will return the parts as an iterator. 
:rtype: iterator[HttpResponse] @@ -496,7 +496,7 @@ def _deserialize_response( return http_response_type(http_request, response) -class PipelineClientBase(object): +class PipelineClientBase: """Base class for pipeline clients. :param str base_url: URL for the request. @@ -507,15 +507,14 @@ def __init__(self, base_url): def _request( self, - method, # type: str - url, # type: str - params, # type: Optional[Dict[str, str]] - headers, # type: Optional[Dict[str, str]] - content, # type: Any - form_content, # type: Optional[Dict[str, Any]] - stream_content, # type: Any - ): - # type: (...) -> HttpRequest + method: str, + url: str, + params: Optional[Dict[str, str]], + headers: Optional[Dict[str, str]], + content: Any, + form_content: Optional[Dict[str, Any]], + stream_content: Any, + ) -> HttpRequest: """Create HttpRequest object. If content is not None, guesses will be used to set the right body: @@ -563,8 +562,7 @@ def _request( return request - def format_url(self, url_template, **kwargs): - # type: (str, Any) -> str + def format_url(self, url_template: str, **kwargs) -> str: """Format request URL with the client base URL, unless the supplied URL is already absolute. @@ -588,13 +586,12 @@ def format_url(self, url_template, **kwargs): def get( self, - url, # type: str - params=None, # type: Optional[Dict[str, str]] - headers=None, # type: Optional[Dict[str, str]] - content=None, # type: Any - form_content=None, # type: Optional[Dict[str, Any]] - ): - # type: (...) -> HttpRequest + url: str, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + content: Any = None, + form_content: Optional[Dict[str, Any]] = None, + ) -> HttpRequest: """Create a GET request object. :param str url: The request URL. @@ -613,14 +610,13 @@ def get( def put( self, - url, # type: str - params=None, # type: Optional[Dict[str, str]] - headers=None, # type: Optional[Dict[str, str]] - content=None, # type: Any - form_content=None, # type: Optional[Dict[str, Any]] - stream_content=None, # type: Any - ): - # type: (...) -> HttpRequest + url: str, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + content: Any = None, + form_content: Optional[Dict[str, Any]] = None, + stream_content: Any = None, + ) -> HttpRequest: """Create a PUT request object. :param str url: The request URL. @@ -638,14 +634,13 @@ def put( def post( self, - url, # type: str - params=None, # type: Optional[Dict[str, str]] - headers=None, # type: Optional[Dict[str, str]] - content=None, # type: Any - form_content=None, # type: Optional[Dict[str, Any]] - stream_content=None, # type: Any - ): - # type: (...) -> HttpRequest + url: str, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + content: Any = None, + form_content: Optional[Dict[str, Any]] = None, + stream_content: Any = None, + ) -> HttpRequest: """Create a POST request object. :param str url: The request URL. @@ -663,14 +658,13 @@ def post( def head( self, - url, # type: str - params=None, # type: Optional[Dict[str, str]] - headers=None, # type: Optional[Dict[str, str]] - content=None, # type: Any - form_content=None, # type: Optional[Dict[str, Any]] - stream_content=None, # type: Any - ): - # type: (...) -> HttpRequest + url: str, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + content: Any = None, + form_content: Optional[Dict[str, Any]] = None, + stream_content: Any = None, + ) -> HttpRequest: """Create a HEAD request object. 
:param str url: The request URL. @@ -688,14 +682,13 @@ def head( def patch( self, - url, # type: str - params=None, # type: Optional[Dict[str, str]] - headers=None, # type: Optional[Dict[str, str]] - content=None, # type: Any - form_content=None, # type: Optional[Dict[str, Any]] - stream_content=None, # type: Any - ): - # type: (...) -> HttpRequest + url: str, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + content: Any = None, + form_content: Optional[Dict[str, Any]] = None, + stream_content: Any = None, + ) -> HttpRequest: """Create a PATCH request object. :param str url: The request URL. @@ -711,8 +704,14 @@ def patch( ) return request - def delete(self, url, params=None, headers=None, content=None, form_content=None): - # type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> HttpRequest + def delete( + self, + url: str, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + content: Any = None, + form_content: Optional[Dict[str, Any]] = None, + ) -> HttpRequest: """Create a DELETE request object. :param str url: The request URL. @@ -728,8 +727,14 @@ def delete(self, url, params=None, headers=None, content=None, form_content=None ) return request - def merge(self, url, params=None, headers=None, content=None, form_content=None): - # type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> HttpRequest + def merge( + self, + url: str, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + content: Any = None, + form_content: Optional[Dict[str, Any]] = None, + ) -> HttpRequest: """Create a MERGE request object. :param str url: The request URL. @@ -745,8 +750,13 @@ def merge(self, url, params=None, headers=None, content=None, form_content=None) ) return request - def options(self, url, params=None, headers=None, **kwargs): - # type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], Any) -> HttpRequest + def options( + self, + url: str, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + **kwargs + ) -> HttpRequest: """Create a OPTIONS request object. :param str url: The request URL. diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py b/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py index 890ce5486c5a..aa749b10b4cb 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py @@ -28,7 +28,7 @@ import abc from collections.abc import AsyncIterator from typing import AsyncIterator as AsyncIteratorType, TypeVar, Generic -from contextlib import AbstractAsyncContextManager # type: ignore +from contextlib import AbstractAsyncContextManager from ._base import ( _HttpResponseBase, diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py index d00051030592..365703bc712a 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py @@ -34,7 +34,7 @@ TYPE_CHECKING, overload, ) -import urllib3 # type: ignore +import urllib3 import requests @@ -115,7 +115,7 @@ async def send( # pylint:disable=invalid-overridden-method :keyword dict proxies: will define the proxy to use. 
Proxy is a dict (protocol, url) """ - @overload # type: ignore + @overload async def send( # pylint:disable=invalid-overridden-method self, request: "RestHttpRequest", **kwargs: Any ) -> "RestAsyncHttpResponse": @@ -282,4 +282,4 @@ class AsyncioRequestsTransportResponse(AsyncHttpResponse, RequestsTransportRespo def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]: # type: ignore """Generator for streaming request body data.""" - return AsyncioStreamDownloadGenerator(pipeline, self, **kwargs) # type: ignore + return AsyncioStreamDownloadGenerator(pipeline, self, **kwargs) diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py index 16974f9fda41..e28fdd9419f1 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py @@ -24,9 +24,9 @@ # # -------------------------------------------------------------------------- import logging -from typing import Iterator, Optional, Any, Union, TypeVar, overload, TYPE_CHECKING -import urllib3 # type: ignore -from urllib3.util.retry import Retry # type: ignore +from typing import Iterator, Optional, Union, TypeVar, overload, TYPE_CHECKING +import urllib3 +from urllib3.util.retry import Retry from urllib3.exceptions import ( DecodeError as CoreDecodeError, ReadTimeoutError, @@ -139,7 +139,7 @@ def text(self, encoding=None): return self.internal_response.text -class StreamDownloadGenerator(object): +class StreamDownloadGenerator: """Generator for streaming response data. :param pipeline: The pipeline object @@ -207,8 +207,7 @@ def __next__(self): class RequestsTransportResponse(HttpResponse, _RequestsTransportResponseBase): """Streaming of data from the response.""" - def stream_download(self, pipeline, **kwargs): - # type: (PipelineType, **Any) -> Iterator[bytes] + def stream_download(self, pipeline: PipelineType, **kwargs) -> Iterator[bytes]: """Generator for streaming request body data.""" return StreamDownloadGenerator(pipeline, self, **kwargs) @@ -240,23 +239,20 @@ class RequestsTransport(HttpTransport): _protocols = ["http://", "https://"] - def __init__(self, **kwargs): - # type: (Any) -> None + def __init__(self, **kwargs) -> None: self.session = kwargs.get("session", None) self._session_owner = kwargs.get("session_owner", True) self.connection_config = ConnectionConfiguration(**kwargs) self._use_env_settings = kwargs.pop("use_env_settings", True) - def __enter__(self): - # type: () -> RequestsTransport + def __enter__(self) -> "RequestsTransport": self.open() return self def __exit__(self, *args): # pylint: disable=arguments-differ self.close() - def _init_session(self, session): - # type: (requests.Session) -> None + def _init_session(self, session: requests.Session) -> None: """Init session level configuration of requests. This is initialization I want to do once only on a session. @@ -279,8 +275,7 @@ def close(self): self.session = None @overload - def send(self, request, **kwargs): - # type: (HttpRequest, Any) -> HttpResponse + def send(self, request: HttpRequest, **kwargs) -> HttpResponse: """Send a rest request and get back a rest response. :param request: The request object to be sent. 
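# Usage sketch for the synchronous transport above, used standalone through the
# context-manager protocol its __enter__/__exit__ now advertise. The URL is
# illustrative.
from azure.core.pipeline.transport import HttpRequest, RequestsTransport

with RequestsTransport(connection_timeout=30) as transport:
    response = transport.send(HttpRequest("GET", "https://www.example.com/"))
    print(response.status_code, response.headers.get("content-type"))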
@@ -294,8 +289,7 @@ def send(self, request, **kwargs): """ @overload - def send(self, request, **kwargs): - # type: (RestHttpRequest, Any) -> RestHttpResponse + def send(self, request: "RestHttpRequest", **kwargs) -> "RestHttpResponse": """Send an `azure.core.rest` request and get back a rest response. :param request: The request object to be sent. @@ -308,7 +302,7 @@ def send(self, request, **kwargs): :keyword dict proxies: will define the proxy to use. Proxy is a dict (protocol, url) """ - def send(self, request, **kwargs): # type: ignore + def send(self, request, **kwargs): """Send request object according to configuration. :param request: The request object to be sent. @@ -322,7 +316,7 @@ def send(self, request, **kwargs): # type: ignore """ self.open() response = None - error = None # type: Optional[AzureErrorUnion] + error: Optional[AzureErrorUnion] = None try: connection_timeout = kwargs.pop( diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py index e23a4416d478..68aa033b77d3 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py @@ -151,7 +151,7 @@ def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]: # ty return TrioStreamDownloadGenerator(pipeline, self, **kwargs) -class TrioRequestsTransport(RequestsAsyncTransportBase): # type: ignore +class TrioRequestsTransport(RequestsAsyncTransportBase): """Identical implementation as the synchronous RequestsTransport wrapped in a class with asynchronous methods. Uses the third party trio event loop. @@ -190,7 +190,7 @@ async def send( # pylint:disable=invalid-overridden-method :keyword dict proxies: will define the proxy to use. 
Proxy is a dict (protocol, url) """ - @overload # type: ignore + @overload async def send( # pylint:disable=invalid-overridden-method self, request: "RestHttpRequest", **kwargs: Any ) -> "RestAsyncHttpResponse": @@ -223,7 +223,7 @@ async def send( self.open() trio_limiter = kwargs.get("trio_limiter", None) response = None - error = None # type: Optional[AzureErrorUnion] + error: Optional[AzureErrorUnion] = None data_to_send = await self._retrieve_request_data(request) try: try: diff --git a/sdk/core/azure-core/azure/core/polling/_poller.py b/sdk/core/azure-core/azure/core/polling/_poller.py index e4a829d13061..c7d208d49548 100644 --- a/sdk/core/azure-core/azure/core/polling/_poller.py +++ b/sdk/core/azure-core/azure/core/polling/_poller.py @@ -27,7 +27,7 @@ import logging import threading import uuid -from typing import TypeVar, Generic, Any, Callable, List, Optional +from typing import TypeVar, Generic, Any, Callable, Optional, Tuple, List from azure.core.exceptions import AzureError from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.common import with_current_context @@ -41,28 +41,24 @@ class PollingMethod(Generic[PollingReturnType]): """ABC class for polling method.""" - def initialize(self, client, initial_response, deserialization_callback): - # type: (Any, Any, Any) -> None + def initialize( + self, client: Any, initial_response: Any, deserialization_callback: Any + ) -> None: raise NotImplementedError("This method needs to be implemented") - def run(self): - # type: () -> None + def run(self) -> None: raise NotImplementedError("This method needs to be implemented") - def status(self): - # type: () -> str + def status(self) -> str: raise NotImplementedError("This method needs to be implemented") - def finished(self): - # type: () -> bool + def finished(self) -> bool: raise NotImplementedError("This method needs to be implemented") - def resource(self): - # type: () -> PollingReturnType + def resource(self) -> PollingReturnType: raise NotImplementedError("This method needs to be implemented") - def get_continuation_token(self): - # type() -> str + def get_continuation_token(self) -> str: raise TypeError( "Polling method '{}' doesn't support get_continuation_token".format( self.__class__.__name__ @@ -70,8 +66,9 @@ def get_continuation_token(self): ) @classmethod - def from_continuation_token(cls, continuation_token, **kwargs): - # type(str, Any) -> Tuple[Any, Any, Callable] + def from_continuation_token( + cls, continuation_token: str, **kwargs + ) -> Tuple[Any, Any, Callable]: raise TypeError( "Polling method '{}' doesn't support from_continuation_token".format( cls.__name__ @@ -86,44 +83,41 @@ def __init__(self): self._initial_response = None self._deserialization_callback = None - def initialize(self, _, initial_response, deserialization_callback): - # type: (Any, Any, Callable) -> None + def initialize( + self, _: Any, initial_response: Any, deserialization_callback: Callable + ) -> None: self._initial_response = initial_response self._deserialization_callback = deserialization_callback - def run(self): - # type: () -> None + def run(self) -> None: """Empty run, no polling.""" - def status(self): - # type: () -> str + def status(self) -> str: """Return the current status as a string. :rtype: str """ return "succeeded" - def finished(self): - # type: () -> bool + def finished(self) -> bool: """Is this polling finished? 
:rtype: bool """ return True - def resource(self): - # type: () -> Any + def resource(self) -> Any: return self._deserialization_callback(self._initial_response) - def get_continuation_token(self): - # type() -> str + def get_continuation_token(self) -> str: import pickle return base64.b64encode(pickle.dumps(self._initial_response)).decode("ascii") @classmethod - def from_continuation_token(cls, continuation_token, **kwargs): - # type(str, Any) -> Tuple + def from_continuation_token( + cls, continuation_token: str, **kwargs + ) -> Tuple[Any, Any, Callable]: try: deserialization_callback = kwargs["deserialization_callback"] except KeyError: @@ -151,10 +145,13 @@ class LROPoller(Generic[PollingReturnType]): """ def __init__( - self, client, initial_response, deserialization_callback, polling_method - ): - # type: (Any, Any, Callable, PollingMethod[PollingReturnType]) -> None - self._callbacks = [] # type: List[Callable] + self, + client: Any, + initial_response: Any, + deserialization_callback: Callable, + polling_method: PollingMethod[PollingReturnType], + ) -> None: + self._callbacks: List[Callable] = [] self._polling_method = polling_method # This implicit test avoids bringing in an explicit dependency on Model directly @@ -211,13 +208,11 @@ def _start(self): call(self._polling_method) callbacks, self._callbacks = self._callbacks, [] - def polling_method(self): - # type: () -> PollingMethod[PollingReturnType] + def polling_method(self) -> PollingMethod[PollingReturnType]: """Return the polling method associated to this poller.""" return self._polling_method - def continuation_token(self): - # type: () -> str + def continuation_token(self) -> str: """Return a continuation token that allows to restart the poller later. :returns: An opaque continuation token @@ -226,8 +221,12 @@ def continuation_token(self): return self._polling_method.get_continuation_token() @classmethod - def from_continuation_token(cls, polling_method, continuation_token, **kwargs): - # type: (PollingMethod[PollingReturnType], str, Any) -> LROPoller[PollingReturnType] + def from_continuation_token( + cls, + polling_method: PollingMethod[PollingReturnType], + continuation_token: str, + **kwargs + ) -> "LROPoller[PollingReturnType]": ( client, initial_response, @@ -235,8 +234,7 @@ def from_continuation_token(cls, polling_method, continuation_token, **kwargs): ) = polling_method.from_continuation_token(continuation_token, **kwargs) return cls(client, initial_response, deserialization_callback, polling_method) - def status(self): - # type: () -> str + def status(self) -> str: """Returns the current status string. :returns: The current status string @@ -244,8 +242,7 @@ def status(self): """ return self._polling_method.status() - def result(self, timeout=None): - # type: (Optional[int]) -> PollingReturnType + def result(self, timeout: Optional[float] = None) -> PollingReturnType: """Return the result of the long running operation, or the result available after the specified timeout. @@ -257,8 +254,7 @@ def result(self, timeout=None): return self._polling_method.resource() @distributed_trace - def wait(self, timeout=None): - # type: (Optional[float]) -> None + def wait(self, timeout: Optional[float] = None) -> None: """Wait on the long running operation for a specified length of time. You can check if this call as ended with timeout with the "done()" method. 
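# Usage sketch for continuation_token()/from_continuation_token() as typed
# above. "poller", "client" and "deserialize" stand for objects a generated SDK
# normally supplies; LROBasePolling is azure-core's default polling method.
from azure.core.polling import LROPoller
from azure.core.polling.base_polling import LROBasePolling


def pause_and_resume(poller: LROPoller, client, deserialize):
    token = poller.continuation_token()        # opaque string, safe to persist
    restored = LROPoller.from_continuation_token(
        polling_method=LROBasePolling(),
        continuation_token=token,
        client=client,                          # same pipeline client as the original call
        deserialization_callback=deserialize,   # turns the final response into a model
    )
    return restored.result()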
@@ -277,8 +273,7 @@ def wait(self, timeout=None): except TypeError: # Was None pass - def done(self): - # type: () -> bool + def done(self) -> bool: """Check status of the long running operation. :returns: 'True' if the process has completed, else 'False'. @@ -286,8 +281,7 @@ def done(self): """ return self._thread is None or not self._thread.is_alive() - def add_done_callback(self, func): - # type: (Callable) -> None + def add_done_callback(self, func: Callable) -> None: """Add callback function to be run once the long running operation has completed - regardless of the status of the operation. @@ -300,8 +294,7 @@ def add_done_callback(self, func): # Let's add them still, for consistency (if you wish to access to it for some reasons) self._callbacks.append(func) - def remove_done_callback(self, func): - # type: (Callable) -> None + def remove_done_callback(self, func: Callable) -> None: """Remove a callback from the long running operation. :param callable func: The function to be removed from the callbacks. diff --git a/sdk/core/azure-core/azure/core/polling/base_polling.py b/sdk/core/azure-core/azure/core/polling/base_polling.py index 9a6a729d46f8..7562db2f9131 100644 --- a/sdk/core/azure-core/azure/core/polling/base_polling.py +++ b/sdk/core/azure-core/azure/core/polling/base_polling.py @@ -27,7 +27,7 @@ import base64 import json from enum import Enum -from typing import TYPE_CHECKING, Optional, Any, Union +from typing import TYPE_CHECKING, Optional, Any, Union, Tuple, Callable, Dict from ..exceptions import HttpResponseError, DecodeError from . import PollingMethod @@ -84,8 +84,7 @@ class OperationFailed(Exception): pass -def _as_json(response): - # type: (ResponseType) -> dict +def _as_json(response: "ResponseType") -> Dict[str, Any]: """Assuming this is not empty, return the content as JSON. Result/exceptions is not determined if you call this method without testing _is_empty. @@ -98,8 +97,7 @@ def _as_json(response): raise DecodeError("Error occurred in deserializing the response body.") -def _raise_if_bad_http_status_and_method(response): - # type: (ResponseType) -> None +def _raise_if_bad_http_status_and_method(response: "ResponseType") -> None: """Check response status code is valid. Must be 200, 201, 202, or 204. @@ -116,8 +114,7 @@ def _raise_if_bad_http_status_and_method(response): ) -def _is_empty(response): - # type: (ResponseType) -> bool +def _is_empty(response: "ResponseType") -> bool: """Check if response body contains meaningful content. :rtype: bool @@ -137,20 +134,17 @@ class LongRunningOperation(ABC): """ @abc.abstractmethod - def can_poll(self, pipeline_response): - # type: (PipelineResponseType) -> bool + def can_poll(self, pipeline_response: "PipelineResponseType") -> bool: """Answer if this polling method could be used.""" raise NotImplementedError() @abc.abstractmethod - def get_polling_url(self): - # type: () -> str + def get_polling_url(self) -> str: """Return the polling URL.""" raise NotImplementedError() @abc.abstractmethod - def set_initial_status(self, pipeline_response): - # type: (PipelineResponseType) -> str + def set_initial_status(self, pipeline_response: "PipelineResponseType") -> str: """Process first response after initiating long running operation. :param azure.core.pipeline.PipelineResponse response: initial REST call response. 
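# Usage sketch for add_done_callback()/remove_done_callback() from the
# _poller.py changes above; the callback receives the polling method once the
# operation has completed, whatever its outcome.
from azure.core.polling import LROPoller, PollingMethod


def _notify(polling_method: PollingMethod) -> None:
    print("LRO finished with status:", polling_method.status())


def watch(poller: LROPoller) -> None:
    poller.add_done_callback(_notify)
    poller.wait()   # _notify fires when the polling thread finishes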
@@ -158,14 +152,14 @@ def set_initial_status(self, pipeline_response): raise NotImplementedError() @abc.abstractmethod - def get_status(self, pipeline_response): - # type: (PipelineResponseType) -> str + def get_status(self, pipeline_response: "PipelineResponseType") -> str: """Return the status string extracted from this response.""" raise NotImplementedError() @abc.abstractmethod - def get_final_get_url(self, pipeline_response): - # type: (PipelineResponseType) -> Optional[str] + def get_final_get_url( + self, pipeline_response: "PipelineResponseType" + ) -> Optional[str]: """If a final GET is needed, returns the URL. :rtype: str @@ -211,13 +205,13 @@ def can_poll(self, pipeline_response): response = pipeline_response.http_response return self._operation_location_header in response.headers - def get_polling_url(self): - # type: () -> str + def get_polling_url(self) -> str: """Return the polling URL.""" return self._async_url - def get_final_get_url(self, pipeline_response): - # type: (PipelineResponseType) -> Optional[str] + def get_final_get_url( + self, pipeline_response: "PipelineResponseType" + ) -> Optional[str]: """If a final GET is needed, returns the URL. :rtype: str @@ -253,8 +247,7 @@ def get_final_get_url(self, pipeline_response): return None - def set_initial_status(self, pipeline_response): - # type: (PipelineResponseType) -> str + def set_initial_status(self, pipeline_response: "PipelineResponseType") -> str: """Process first response after initiating long running operation. :param azure.core.pipeline.PipelineResponse response: initial REST call response. @@ -268,16 +261,14 @@ def set_initial_status(self, pipeline_response): return "InProgress" raise OperationFailed("Operation failed or canceled") - def _set_async_url_if_present(self, response): - # type: (ResponseType) -> None + def _set_async_url_if_present(self, response: "ResponseType") -> None: self._async_url = response.headers[self._operation_location_header] location_url = response.headers.get("location") if location_url: self._location_url = location_url - def get_status(self, pipeline_response): - # type: (PipelineResponseType) -> str + def get_status(self, pipeline_response: "PipelineResponseType") -> str: """Process the latest status update retrieved from an "Operation-Location" header. :param azure.core.pipeline.PipelineResponse response: The response to extract the status. @@ -302,27 +293,25 @@ class LocationPolling(LongRunningOperation): def __init__(self): self._location_url = None - def can_poll(self, pipeline_response): - # type: (PipelineResponseType) -> bool + def can_poll(self, pipeline_response: "PipelineResponseType") -> bool: """Answer if this polling method could be used.""" response = pipeline_response.http_response return "location" in response.headers - def get_polling_url(self): - # type: () -> str + def get_polling_url(self) -> str: """Return the polling URL.""" return self._location_url - def get_final_get_url(self, pipeline_response): - # type: (PipelineResponseType) -> Optional[str] + def get_final_get_url( + self, pipeline_response: "PipelineResponseType" + ) -> Optional[str]: """If a final GET is needed, returns the URL. :rtype: str """ return None - def set_initial_status(self, pipeline_response): - # type: (PipelineResponseType) -> str + def set_initial_status(self, pipeline_response: "PipelineResponseType") -> str: """Process first response after initiating long running operation. :param azure.core.pipeline.PipelineResponse response: initial REST call response. 
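# Sketch of how the strategies above are consumed, assuming LROBasePolling's
# lro_algorithms keyword: the polling method walks the list in order and uses
# the first strategy whose can_poll() returns True for the initial response.
from azure.core.polling.base_polling import (
    LROBasePolling,
    LocationPolling,
    OperationResourcePolling,
    StatusCheckPolling,
)

polling_method = LROBasePolling(
    timeout=30,
    lro_algorithms=[
        OperationResourcePolling(),   # "Operation-Location"-style header
        LocationPolling(),            # "Location" header
        StatusCheckPolling(),         # no header: succeed on a 2xx initial response
    ],
)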
@@ -335,8 +324,7 @@ def set_initial_status(self, pipeline_response): return "InProgress" raise OperationFailed("Operation failed or canceled") - def get_status(self, pipeline_response): - # type: (PipelineResponseType) -> str + def get_status(self, pipeline_response: "PipelineResponseType") -> str: """Process the latest status update retrieved from a 'location' header. :param azure.core.pipeline.PipelineResponse response: latest REST call response. @@ -354,18 +342,15 @@ class StatusCheckPolling(LongRunningOperation): if not other polling are detected and status code is 2xx. """ - def can_poll(self, pipeline_response): - # type: (PipelineResponseType) -> bool + def can_poll(self, pipeline_response: "PipelineResponseType") -> bool: """Answer if this polling method could be used.""" return True - def get_polling_url(self): - # type: () -> str + def get_polling_url(self) -> str: """Return the polling URL.""" raise ValueError("This polling doesn't support polling") - def set_initial_status(self, pipeline_response): - # type: (PipelineResponseType) -> str + def set_initial_status(self, pipeline_response: "PipelineResponseType") -> str: """Process first response after initiating long running operation and set self.status attribute. @@ -373,12 +358,12 @@ def set_initial_status(self, pipeline_response): """ return "Succeeded" - def get_status(self, pipeline_response): - # type: (PipelineResponseType) -> str + def get_status(self, pipeline_response: "PipelineResponseType") -> str: return "Succeeded" - def get_final_get_url(self, pipeline_response): - # type: (PipelineResponseType) -> Optional[str] + def get_final_get_url( + self, pipeline_response: "PipelineResponseType" + ) -> Optional[str]: """If a final GET is needed, returns the URL. :rtype: str @@ -478,15 +463,15 @@ def initialize(self, client, initial_response, deserialization_callback): except OperationFailed as err: raise HttpResponseError(response=initial_response.http_response, error=err) - def get_continuation_token(self): - # type() -> str + def get_continuation_token(self) -> str: import pickle return base64.b64encode(pickle.dumps(self._initial_response)).decode("ascii") @classmethod - def from_continuation_token(cls, continuation_token, **kwargs): - # type(str, Any) -> Tuple + def from_continuation_token( + cls, continuation_token: str, **kwargs + ) -> Tuple[Any, Any, Callable]: try: client = kwargs["client"] except KeyError: @@ -557,8 +542,9 @@ def _poll(self): self._pipeline_response = self.request_status(final_get_url) _raise_if_bad_http_status_and_method(self._pipeline_response.http_response) - def _parse_resource(self, pipeline_response): - # type: (PipelineResponseType) -> Optional[Any] + def _parse_resource( + self, pipeline_response: "PipelineResponseType" + ) -> Optional[Any]: """Assuming this response is a resource, use the deserialization callback to parse it. If body is empty, assuming no resource to return. 
""" diff --git a/sdk/core/azure-core/azure/core/rest/_aiohttp.py b/sdk/core/azure-core/azure/core/rest/_aiohttp.py index 6e2037ef49eb..6b4775fc054b 100644 --- a/sdk/core/azure-core/azure/core/rest/_aiohttp.py +++ b/sdk/core/azure-core/azure/core/rest/_aiohttp.py @@ -26,6 +26,7 @@ import collections.abc import asyncio from itertools import groupby +from typing import Iterator, cast from multidict import CIMultiDict from ._http_response_impl_async import ( AsyncHttpResponseImpl, @@ -62,17 +63,20 @@ def __init__(self, items): super().__init__(items) self._items = items - def __iter__(self): + def __iter__(self) -> Iterator[str]: for key, _ in self._items: yield key def __contains__(self, key): - for k in self.__iter__(): - if key.lower() == k.lower(): - return True + try: + for k in self.__iter__(): + if cast(str, key).lower() == k.lower(): + return True + except AttributeError: # Catch "lower()" if key not a string + pass return False - def __repr__(self): + def __repr__(self) -> str: return f"dict_keys({list(self.__iter__())})" @@ -179,8 +183,7 @@ def __getstate__(self): return state @property - def content(self): - # type: (...) -> bytes + def content(self) -> bytes: """Return the response's content in bytes.""" if self._content is None: raise ResponseNotReadError(self) diff --git a/sdk/core/azure-core/azure/core/rest/_helpers.py b/sdk/core/azure-core/azure/core/rest/_helpers.py index 7482a1ad0ded..42dbff631b37 100644 --- a/sdk/core/azure-core/azure/core/rest/_helpers.py +++ b/sdk/core/azure-core/azure/core/rest/_helpers.py @@ -35,11 +35,11 @@ Tuple, IO, Any, - Dict, Iterable, MutableMapping, AsyncIterable, cast, + Dict, ) import xml.etree.ElementTree as ET from urllib.parse import urlparse @@ -188,7 +188,7 @@ def decode_to_text(encoding: Optional[str], content: bytes) -> str: return codecs.getincrementaldecoder("utf-8-sig")(errors="replace").decode(content) -class HttpRequestBackcompatMixin(object): +class HttpRequestBackcompatMixin: def __getattr__(self, attr): backcompat_attrs = [ "files", diff --git a/sdk/core/azure-core/azure/core/rest/_http_response_impl.py b/sdk/core/azure-core/azure/core/rest/_http_response_impl.py index 7ee5fc718876..4078696c7e1a 100644 --- a/sdk/core/azure-core/azure/core/rest/_http_response_impl.py +++ b/sdk/core/azure-core/azure/core/rest/_http_response_impl.py @@ -51,7 +51,7 @@ ) -class _HttpResponseBackcompatMixinBase(object): +class _HttpResponseBackcompatMixinBase: """Base Backcompat mixin for responses. This mixin is used by both sync and async HttpResponse @@ -171,28 +171,26 @@ class _HttpResponseBaseImpl( :keyword Callable stream_download_generator: The stream download generator that we use to stream the response. 
""" - def __init__(self, **kwargs): - # type: (Any) -> None + def __init__(self, **kwargs) -> None: super(_HttpResponseBaseImpl, self).__init__() self._request = kwargs.pop("request") self._internal_response = kwargs.pop("internal_response") - self._block_size = kwargs.pop("block_size", None) or 4096 # type: int - self._status_code = kwargs.pop("status_code") # type: int - self._reason = kwargs.pop("reason") # type: str - self._content_type = kwargs.pop("content_type") # type: str - self._headers = kwargs.pop("headers") # type: MutableMapping[str, str] - self._stream_download_generator = kwargs.pop( + self._block_size: int = kwargs.pop("block_size", None) or 4096 + self._status_code: int = kwargs.pop("status_code") + self._reason: str = kwargs.pop("reason") + self._content_type: str = kwargs.pop("content_type") + self._headers: MutableMapping[str, str] = kwargs.pop("headers") + self._stream_download_generator: Callable = kwargs.pop( "stream_download_generator" - ) # type: Callable + ) self._is_closed = False self._is_stream_consumed = False self._json = None # this is filled in ContentDecodePolicy, when we deserialize - self._content = None # type: Optional[bytes] - self._text = None # type: Optional[str] + self._content: Optional[bytes] = None + self._text: Optional[str] = None @property - def request(self): - # type: (...) -> _HttpRequest + def request(self) -> _HttpRequest: """The request that resulted in this response. :rtype: ~azure.core.rest.HttpRequest @@ -200,8 +198,7 @@ def request(self): return self._request @property - def url(self): - # type: (...) -> str + def url(self) -> str: """The URL that resulted in this response. :rtype: str @@ -209,8 +206,7 @@ def url(self): return self.request.url @property - def is_closed(self): - # type: (...) -> bool + def is_closed(self) -> bool: """Whether the network connection has been closed yet. :rtype: bool @@ -218,14 +214,12 @@ def is_closed(self): return self._is_closed @property - def is_stream_consumed(self): - # type: (...) -> bool + def is_stream_consumed(self) -> bool: """Whether the stream has been consumed""" return self._is_stream_consumed @property - def status_code(self): - # type: (...) -> int + def status_code(self) -> int: """The status code of this response. :rtype: int @@ -233,8 +227,7 @@ def status_code(self): return self._status_code @property - def headers(self): - # type: (...) -> MutableMapping[str, str] + def headers(self) -> MutableMapping[str, str]: """The response headers. :rtype: MutableMapping[str, str] @@ -242,8 +235,7 @@ def headers(self): return self._headers @property - def content_type(self): - # type: (...) -> Optional[str] + def content_type(self) -> Optional[str]: """The content type of the response. :rtype: optional[str] @@ -251,8 +243,7 @@ def content_type(self): return self._content_type @property - def reason(self): - # type: (...) -> str + def reason(self) -> str: """The reason phrase for this response. :rtype: str @@ -260,8 +251,7 @@ def reason(self): return self._reason @property - def encoding(self): - # type: (...) -> Optional[str] + def encoding(self) -> Optional[str]: """Returns the response encoding. :return: The response encoding. 
We either return the encoding set by the user, @@ -272,19 +262,17 @@ def encoding(self): try: return self._encoding except AttributeError: - self._encoding = get_charset_encoding(self) # type: Optional[str] + self._encoding: Optional[str] = get_charset_encoding(self) return self._encoding @encoding.setter - def encoding(self, value): - # type: (str) -> None + def encoding(self, value: str) -> None: """Sets the response encoding""" self._encoding = value self._text = None # clear text cache self._json = None # clear json cache as well - def text(self, encoding=None): - # type: (Optional[str]) -> str + def text(self, encoding: Optional[str] = None) -> str: """Returns the response body as a string :param optional[str] encoding: The encoding you want to decode the text with. Can @@ -298,8 +286,7 @@ def text(self, encoding=None): self._text = decode_to_text(self.encoding, self.content) return self._text - def json(self): - # type: (...) -> Any + def json(self) -> Any: """Returns the whole body as a json object. :return: The JSON deserialized response body @@ -320,8 +307,7 @@ def _stream_download_check(self): self._is_stream_consumed = True - def raise_for_status(self): - # type: (...) -> None + def raise_for_status(self) -> None: """Raises an HttpResponseError if the response has an error status code. If response is good, does nothing. @@ -330,15 +316,13 @@ def raise_for_status(self): raise HttpResponseError(response=self) @property - def content(self): - # type: (...) -> bytes + def content(self) -> bytes: """Return the response's content in bytes.""" if self._content is None: raise ResponseNotReadError(self) return self._content - def __repr__(self): - # type: (...) -> str + def __repr__(self) -> str: content_type_str = ( ", Content-Type: {}".format(self.content_type) if self.content_type else "" ) @@ -368,26 +352,22 @@ class HttpResponseImpl( :keyword Callable stream_download_generator: The stream download generator that we use to stream the response. """ - def __enter__(self): - # type: (...) -> HttpResponseImpl + def __enter__(self) -> "HttpResponseImpl": return self - def close(self): - # type: (...) -> None + def close(self) -> None: if not self.is_closed: self._is_closed = True self._internal_response.close() - def __exit__(self, *args): - # type: (...) -> None + def __exit__(self, *args) -> None: self.close() def _set_read_checks(self): self._is_stream_consumed = True self.close() - def read(self): - # type: (...) -> bytes + def read(self) -> bytes: """ Read the response's bytes. @@ -397,8 +377,7 @@ def read(self): self._set_read_checks() return self.content - def iter_bytes(self, **kwargs): - # type: (Any) -> Iterator[bytes] + def iter_bytes(self, **kwargs) -> Iterator[bytes]: """Iterates over the response's bytes. Will decompress in the process. :return: An iterator of bytes from the response @@ -418,8 +397,7 @@ def iter_bytes(self, **kwargs): yield part self.close() - def iter_raw(self, **kwargs): - # type: (Any) -> Iterator[bytes] + def iter_raw(self, **kwargs) -> Iterator[bytes]: """Iterates over the response's bytes. Will not decompress in the process. 
:return: An iterator of bytes from the response diff --git a/sdk/core/azure-core/azure/core/rest/_rest_py3.py b/sdk/core/azure-core/azure/core/rest/_rest_py3.py index a5dc41ed5536..f65dc7537eb1 100644 --- a/sdk/core/azure-core/azure/core/rest/_rest_py3.py +++ b/sdk/core/azure-core/azure/core/rest/_rest_py3.py @@ -34,6 +34,8 @@ Optional, Union, MutableMapping, + Dict, + AsyncContextManager, ) from ..utils._utils import case_insensitive_dict @@ -97,7 +99,7 @@ def __init__( headers: Optional[MutableMapping[str, str]] = None, json: Any = None, content: Optional[ContentType] = None, - data: Optional[dict] = None, + data: Optional[Dict[str, Any]] = None, files: Optional[FilesType] = None, **kwargs ): @@ -107,7 +109,7 @@ def __init__( if params: _format_parameters_helper(self, params) self._files = None - self._data = None # type: Any + self._data: Any = None default_headers = self._set_body( content=content, @@ -128,12 +130,12 @@ def __init__( def _set_body( self, content: Optional[ContentType] = None, - data: Optional[dict] = None, + data: Optional[Dict[str, Any]] = None, files: Optional[FilesType] = None, json: Any = None, ) -> MutableMapping[str, str]: """Sets the body of the request, and returns the default headers""" - default_headers = {} # type: MutableMapping[str, str] + default_headers: MutableMapping[str, str] = {} if data is not None and not isinstance(data, dict): # should we warn? content = data @@ -377,7 +379,7 @@ def __repr__(self) -> str: ) -class AsyncHttpResponse(_HttpResponseBase): +class AsyncHttpResponse(_HttpResponseBase, AsyncContextManager["AsyncHttpResponse"]): """Abstract base class for Async HTTP responses. Use this abstract base class to create your own transport responses. @@ -426,7 +428,3 @@ async def iter_bytes(self, **kwargs: Any) -> AsyncIterator[bytes]: @abc.abstractmethod async def close(self) -> None: ... - - @abc.abstractmethod - async def __aexit__(self, *args) -> None: - ... diff --git a/sdk/core/azure-core/azure/core/serialization.py b/sdk/core/azure-core/azure/core/serialization.py index f444a6dd85db..d3b809c75f3e 100644 --- a/sdk/core/azure-core/azure/core/serialization.py +++ b/sdk/core/azure-core/azure/core/serialization.py @@ -12,17 +12,15 @@ __all__ = ["NULL", "AzureJSONEncoder"] -TZ_UTC = timezone.utc # type: ignore +TZ_UTC = timezone.utc -class _Null(object): +class _Null: """To create a Falsy object""" def __bool__(self): return False - __nonzero__ = __bool__ # Python2 compatibility - NULL = _Null() """ @@ -31,8 +29,7 @@ def __bool__(self): """ -def _timedelta_as_isostr(td): - # type: (timedelta) -> str +def _timedelta_as_isostr(td: timedelta) -> str: """Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 
'P4DT12H30M05S' Function adapted from the Tin Can Python project: https://github.com/RusticiSoftware/TinCanPython @@ -82,8 +79,7 @@ def _timedelta_as_isostr(td): return "P" + date_str + time_str -def _datetime_as_isostr(dt): - # type: (Union[datetime, date, time, timedelta]) -> str +def _datetime_as_isostr(dt: Union[datetime, date, time, timedelta]) -> str: """Converts a datetime.(datetime|date|time|timedelta) object into an ISO 8601 formatted string""" # First try datetime.datetime if hasattr(dt, "year") and hasattr(dt, "hour"): diff --git a/sdk/core/azure-core/azure/core/settings.py b/sdk/core/azure-core/azure/core/settings.py index e8a0ccf4eb9c..ee804cfc4cb0 100644 --- a/sdk/core/azure-core/azure/core/settings.py +++ b/sdk/core/azure-core/azure/core/settings.py @@ -31,7 +31,7 @@ import logging import os import sys -from typing import Type, Optional, Dict, Callable, cast, Any, Union, TYPE_CHECKING +from typing import Type, Optional, Callable, cast, Union, Dict, TYPE_CHECKING from azure.core.tracing import AbstractSpan if TYPE_CHECKING: @@ -54,8 +54,7 @@ class _Unset(Enum): _unset = _Unset.token -def convert_bool(value): - # type: (Union[str, bool]) -> bool +def convert_bool(value: Union[str, bool]) -> bool: """Convert a string to True or False If a boolean is passed in, it is returned as-is. Otherwise the function @@ -89,8 +88,7 @@ def convert_bool(value): } -def convert_logging(value): - # type: (Union[str, int]) -> int +def convert_logging(value: Union[str, int]) -> int: """Convert a string to a Python logging level If a log level is passed in, it is returned as-is. Otherwise the function @@ -121,8 +119,7 @@ def convert_logging(value): return level -def get_opencensus_span(): - # type: () -> Optional[Type[AbstractSpan]] +def get_opencensus_span() -> Optional[Type[AbstractSpan]]: """Returns the OpenCensusSpan if opencensus is installed else returns None""" try: from azure.core.tracing.ext.opencensus_span import ( # pylint:disable=redefined-outer-name @@ -134,20 +131,20 @@ def get_opencensus_span(): return None -def get_opencensus_span_if_opencensus_is_imported(): - # type: () -> Optional[Type[AbstractSpan]] +def get_opencensus_span_if_opencensus_is_imported() -> Optional[Type[AbstractSpan]]: if "opencensus" not in sys.modules: return None return get_opencensus_span() -_tracing_implementation_dict = { +_tracing_implementation_dict: Dict[str, Callable[[], Optional[Type[AbstractSpan]]]] = { "opencensus": get_opencensus_span -} # type: Dict[str, Callable[[], Optional[Type[AbstractSpan]]]] +} -def convert_tracing_impl(value): - # type: (Union[str, Type[AbstractSpan]]) -> Optional[Type[AbstractSpan]] +def convert_tracing_impl( + value: Union[str, Type[AbstractSpan]] +) -> Optional[Type[AbstractSpan]]: """Convert a string to AbstractSpan If a AbstractSpan is passed in, it is returned as-is. 
Otherwise the function @@ -168,10 +165,9 @@ def convert_tracing_impl(value): value = cast(Type[AbstractSpan], value) return value - value = cast(str, value) # mypy clarity value = value.lower() get_wrapper_class = _tracing_implementation_dict.get(value, lambda: _unset) - wrapper_class = get_wrapper_class() # type: Union[None, _Unset, Type[AbstractSpan]] + wrapper_class: Optional[Union[_Unset, Type[AbstractSpan]]] = get_wrapper_class() if wrapper_class is _unset: raise ValueError( "Cannot convert {} to AbstractSpan, valid values are: {}".format( @@ -181,7 +177,7 @@ def convert_tracing_impl(value): return wrapper_class -class PrioritizedSetting(object): +class PrioritizedSetting: """Return a value for a global setting according to configuration precedence. The following methods are searched in order for the setting: @@ -220,12 +216,10 @@ def __init__( self._convert = convert if convert else lambda x: x self._user_value = _Unset - def __repr__(self): - # type () -> str + def __repr__(self) -> str: return "PrioritizedSetting(%r)" % self._name def __call__(self, value=None): - # type: (Any) -> Any """Return the setting value according to the standard precedence. :param time: value @@ -264,8 +258,7 @@ def __get__(self, instance, owner): def __set__(self, instance, value): self.set_value(value) - def set_value(self, value): - # type: (Any) -> None + def set_value(self, value) -> None: """Specify a value for this setting programmatically. A value set this way takes precedence over all other methods except @@ -278,8 +271,7 @@ def set_value(self, value): """ self._user_value = value - def unset_value(self): - # () -> None + def unset_value(self) -> None: """Unset the previous user value such that the priority is reset.""" self._user_value = _Unset @@ -292,7 +284,7 @@ def default(self): return self._default -class Settings(object): +class Settings: """Settings for globally used Azure configuration values. You probably don't want to create an instance of this class, but call the singleton instance: diff --git a/sdk/core/azure-core/azure/core/tracing/_abstract_span.py b/sdk/core/azure-core/azure/core/tracing/_abstract_span.py index 39185c0144f1..5e98b8da7fa2 100644 --- a/sdk/core/azure-core/azure/core/tracing/_abstract_span.py +++ b/sdk/core/azure-core/azure/core/tracing/_abstract_span.py @@ -8,11 +8,11 @@ TYPE_CHECKING, Any, Sequence, - Dict, Optional, Union, Callable, ContextManager, + Dict, ) if TYPE_CHECKING: @@ -54,32 +54,28 @@ class AbstractSpan(Protocol): """Wraps a span from a distributed tracing implementation.""" def __init__( # pylint: disable=super-init-not-called - self, span=None, name=None, **kwargs - ): - # type: (Optional[Any], Optional[str], Any) -> None + self, span: Optional[Any] = None, name: Optional[str] = None, **kwargs + ) -> None: """ If a span is given wraps the span. Else a new span is created. The optional argument name is given to the new span. """ - def span(self, name="child_span", **kwargs): - # type: (Optional[str], Any) -> AbstractSpan + def span(self, name: str = "child_span", **kwargs) -> "AbstractSpan": """ Create a child span for the current span and append it to the child spans list. The child span must be wrapped by an implementation of AbstractSpan """ @property - def kind(self): - # type: () -> Optional[SpanKind] + def kind(self) -> Optional[SpanKind]: """Get the span kind of this span. 
:rtype: SpanKind """ @kind.setter - def kind(self, value): - # type: (SpanKind) -> None + def kind(self, value: SpanKind) -> None: """Set the span kind of this span.""" def __enter__(self): @@ -88,22 +84,18 @@ def __enter__(self): def __exit__(self, exception_type, exception_value, traceback): """Finish a span.""" - def start(self): - # type: () -> None + def start(self) -> None: """Set the start time for a span.""" - def finish(self): - # type: () -> None + def finish(self) -> None: """Set the end time for a span.""" - def to_header(self): - # type: () -> Dict[str, str] + def to_header(self) -> Dict[str, str]: """ Returns a dictionary with the header labels and values. """ - def add_attribute(self, key, value): - # type: (str, Union[str, int]) -> None + def add_attribute(self, key: str, value: Union[str, int]) -> None: """ Add attribute (key value pair) to the current span. @@ -113,8 +105,9 @@ def add_attribute(self, key, value): :type value: str """ - def set_http_attributes(self, request, response=None): - # type: (HttpRequest, Optional[HttpResponseType]) -> None + def set_http_attributes( + self, request: "HttpRequest", response: Optional["HttpResponseType"] = None + ) -> None: """ Add correct attributes for a http client span. @@ -124,8 +117,7 @@ def set_http_attributes(self, request, response=None): :type response: ~azure.core.pipeline.transport.HttpResponse or ~azure.core.pipeline.transport.AsyncHttpResponse """ - def get_trace_parent(self): - # type: () -> str + def get_trace_parent(self) -> str: """Return traceparent string. :return: a traceparent string @@ -133,15 +125,13 @@ def get_trace_parent(self): """ @property - def span_instance(self): - # type: () -> Any + def span_instance(self) -> Any: """ Returns the span the class is wrapping. """ @classmethod - def link(cls, traceparent, attributes=None): - # type: (str, Attributes) -> None + def link(cls, traceparent: str, attributes: Optional["Attributes"] = None) -> None: """ Given a traceparent, extracts the context and links the context to the current tracer. @@ -150,8 +140,9 @@ def link(cls, traceparent, attributes=None): """ @classmethod - def link_from_headers(cls, headers, attributes=None): - # type: (Dict[str, str], Attributes) -> None + def link_from_headers( + cls, headers: Dict[str, str], attributes: Optional["Attributes"] = None + ) -> None: """ Given a dictionary, extracts the context and links the context to the current tracer. @@ -160,44 +151,38 @@ def link_from_headers(cls, headers, attributes=None): """ @classmethod - def get_current_span(cls): - # type: () -> Any + def get_current_span(cls) -> Any: """ Get the current span from the execution context. Return None otherwise. """ @classmethod - def get_current_tracer(cls): - # type: () -> Any + def get_current_tracer(cls) -> Any: """ Get the current tracer from the execution context. Return None otherwise. """ @classmethod - def set_current_span(cls, span): - # type: (Any) -> None + def set_current_span(cls, span: Any) -> None: """ Set the given span as the current span in the execution context. """ @classmethod - def set_current_tracer(cls, tracer): - # type: (Any) -> None + def set_current_tracer(cls, tracer: Any): """ Set the given tracer as the current tracer in the execution context. """ @classmethod - def change_context(cls, span): - # type: (AbstractSpan) -> ContextManager + def change_context(cls, span: "AbstractSpan") -> ContextManager: """Change the context for the life of this context manager. 
:rtype: contextmanager """ @classmethod - def with_current_context(cls, func): - # type: (Callable) -> Callable + def with_current_context(cls, func: Callable) -> Callable: """Passes the current spans to the new context the function will be run in. :param func: The function that will be run in the new context @@ -222,8 +207,9 @@ class HttpSpanMixin(_MIXIN_BASE): _HTTP_URL = "http.url" _HTTP_STATUS_CODE = "http.status_code" - def set_http_attributes(self, request, response=None): - # type: (HttpRequest, Optional[HttpResponseType]) -> None + def set_http_attributes( + self, request: "HttpRequest", response: Optional["HttpResponseType"] = None + ) -> None: """ Add correct attributes for a http client span. @@ -245,7 +231,7 @@ def set_http_attributes(self, request, response=None): self.add_attribute(self._HTTP_STATUS_CODE, 504) -class Link(object): +class Link: """ This is a wrapper class to link the context to the current tracer. :param headers: A dictionary of the request header as key value pairs. @@ -254,7 +240,8 @@ class Link(object): :type attributes: dict """ - def __init__(self, headers, attributes=None): - # type: (Dict[str, str], Attributes) -> None + def __init__( + self, headers: Dict[str, str], attributes: Optional["Attributes"] = None + ) -> None: self.headers = headers self.attributes = attributes diff --git a/sdk/core/azure-core/azure/core/tracing/common.py b/sdk/core/azure-core/azure/core/tracing/common.py index 6ea943608a19..b23186c855e8 100644 --- a/sdk/core/azure-core/azure/core/tracing/common.py +++ b/sdk/core/azure-core/azure/core/tracing/common.py @@ -25,29 +25,20 @@ # -------------------------------------------------------------------------- """Common functions shared by both the sync and the async decorators.""" from contextlib import contextmanager +from typing import Any, Optional, Callable, Type, Generator import warnings from ._abstract_span import AbstractSpan from ..settings import settings -try: - from typing import TYPE_CHECKING -except ImportError: - TYPE_CHECKING = False - -if TYPE_CHECKING: - from typing import Any, Optional, Union, Callable, List, Type, Generator - - __all__ = [ "change_context", "with_current_context", ] -def get_function_and_class_name(func, *args): - # type: (Callable, List[Any]) -> str +def get_function_and_class_name(func: Callable, *args) -> str: """ Given a function and its unamed arguments, returns class_name.function_name. It assumes the first argument is `self`. If there are no arguments then it only returns the function name. @@ -67,8 +58,7 @@ def get_function_and_class_name(func, *args): @contextmanager -def change_context(span): - # type: (Optional[AbstractSpan]) -> Generator +def change_context(span: Optional[AbstractSpan]) -> Generator: """Execute this block inside the given context and restore it afterwards. This does not start and ends the span, but just make sure all code is executed within @@ -80,7 +70,7 @@ def change_context(span): :type span: AbstractSpan :rtype: contextmanager """ - span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] + span_impl_type: Type[AbstractSpan] = settings.tracing_implementation() if span_impl_type is None or span is None: yield else: @@ -101,15 +91,14 @@ def change_context(span): span_impl_type.set_current_span(original_span) -def with_current_context(func): - # type: (Callable) -> Any +def with_current_context(func: Callable) -> Any: """Passes the current spans to the new context the function will be run in. 
:param func: The function that will be run in the new context :return: The func wrapped with correct context :rtype: callable """ - span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] + span_impl_type: Type[AbstractSpan] = settings.tracing_implementation() if span_impl_type is None: return func diff --git a/sdk/core/azure-core/azure/core/utils/_connection_string_parser.py b/sdk/core/azure-core/azure/core/utils/_connection_string_parser.py index a074df43b586..62a2f50989a8 100644 --- a/sdk/core/azure-core/azure/core/utils/_connection_string_parser.py +++ b/sdk/core/azure-core/azure/core/utils/_connection_string_parser.py @@ -7,8 +7,9 @@ from typing import Mapping -def parse_connection_string(conn_str, case_sensitive_keys=False): - # type: (str, bool) -> Mapping[str, str] +def parse_connection_string( + conn_str: str, case_sensitive_keys: bool = False +) -> Mapping[str, str]: """Parses the connection string into a dict of its component parts, with the option of preserving case of keys, and validates that each key in the connection string has a provided value. If case of keys is not preserved (ie. `case_sensitive_keys=False`), then a dict with LOWERCASE KEYS will be returned. diff --git a/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared.py b/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared.py index 6dfc924a4d2a..b9f280390e69 100644 --- a/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared.py +++ b/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared.py @@ -15,14 +15,13 @@ TYPE_CHECKING, cast, IO, - Dict, - List, Union, Tuple, Optional, Callable, Type, Iterator, + List, ) from http.client import HTTPConnection from urllib.parse import urlparse @@ -54,7 +53,7 @@ binary_type = str -class BytesIOSocket(object): +class BytesIOSocket: """Mocking the "makefile" of socket for HTTPResponse. This can be used to create a http.client.HTTPResponse object based on bytes and not a real socket. @@ -100,8 +99,7 @@ def _format_parameters_helper(http_request, params): http_request.url = http_request.url + query -def _pad_attr_name(attr, backcompat_attrs): - # type: (str, List[str]) -> str +def _pad_attr_name(attr: str, backcompat_attrs: List[str]) -> str: """Pad hidden attributes so users can access them. Currently, for our backcompat attributes, we define them @@ -114,8 +112,9 @@ def _pad_attr_name(attr, backcompat_attrs): return "_{}".format(attr) if attr in backcompat_attrs else attr -def _prepare_multipart_body_helper(http_request, content_index=0): - # type: (HTTPRequestType, int) -> int +def _prepare_multipart_body_helper( + http_request: "HTTPRequestType", content_index: int = 0 +) -> int: """Helper for prepare_multipart_body. Will prepare the body of this request according to the multipart information. 
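Aside (not part of the diff): a minimal usage sketch of the parse_connection_string helper annotated above. The connection string and its keys are illustrative; the only behavior assumed is what its docstring states, namely that keys are lower-cased by default and preserved when case_sensitive_keys=True.

>>> from azure.core.utils import parse_connection_string
>>> conn_str = "Endpoint=https://example.contoso.com/;AccessKey=abc123"  # illustrative values
>>> parse_connection_string(conn_str)["accesskey"]  # keys are lower-cased by default
'abc123'
>>> parse_connection_string(conn_str, case_sensitive_keys=True)["AccessKey"]  # original case kept
'abc123'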
@@ -133,8 +132,8 @@ def _prepare_multipart_body_helper(http_request, content_index=0): if not http_request.multipart_mixed_info: return 0 - requests = http_request.multipart_mixed_info[0] # type: List[HTTPRequestType] - boundary = http_request.multipart_mixed_info[2] # type: Optional[str] + requests: List["HTTPRequestType"] = http_request.multipart_mixed_info[0] + boundary: Optional[str] = http_request.multipart_mixed_info[2] # Update the main request with the body main_message = Message() @@ -169,7 +168,7 @@ def _prepare_multipart_body_helper(http_request, content_index=0): return content_index -class _HTTPSerializer(HTTPConnection, object): +class _HTTPSerializer(HTTPConnection): """Hacking the stdlib HTTPConnection to serialize HTTP request as strings.""" def __init__(self, *args, **kwargs): @@ -186,8 +185,7 @@ def send(self, data): self.buffer += data -def _serialize_request(http_request): - # type: (HTTPRequestType) -> bytes +def _serialize_request(http_request: "HTTPRequestType") -> bytes: """Helper for serialize. Serialize a request using the application/http spec/ @@ -196,6 +194,8 @@ def _serialize_request(http_request): to serialize. :rtype: bytes """ + if isinstance(http_request.body, dict): + raise TypeError("Cannot serialize an HTTPRequest with dict body.") serializer = _HTTPSerializer() serializer.request( method=http_request.method, @@ -207,13 +207,12 @@ def _serialize_request(http_request): def _decode_parts_helper( - response, # type: PipelineTransportHttpResponseBase - message, # type: Message - http_response_type, # type: Type[PipelineTransportHttpResponseBase] - requests, # type: List[PipelineTransportHttpRequest] - deserialize_response, # type: Callable -): - # type: (...) -> List[PipelineTransportHttpResponse] + response: "PipelineTransportHttpResponseBase", + message: Message, + http_response_type: Type["PipelineTransportHttpResponseBase"], + requests: List["PipelineTransportHttpRequest"], + deserialize_response: Callable, +) -> List["PipelineTransportHttpResponse"]: """Helper for _decode_parts. Rebuild an HTTP response from pure string. @@ -261,15 +260,16 @@ def _get_raw_parts_helper(response, http_response_type): + b"\r\n\r\n" + body_as_bytes ) - message = message_parser(http_body) # type: Message + message: Message = message_parser(http_body) requests = response.request.multipart_mixed_info[0] return response._decode_parts( # pylint: disable=protected-access message, http_response_type, requests ) -def _parts_helper(response): - # type: (PipelineTransportHttpResponse) -> Iterator[PipelineTransportHttpResponse] +def _parts_helper( + response: "PipelineTransportHttpResponse", +) -> Iterator["PipelineTransportHttpResponse"]: """Assuming the content-type is multipart/mixed, will return the parts as an iterator. :rtype: iterator[HttpResponse] @@ -282,9 +282,7 @@ def _parts_helper(response): responses = response._get_raw_parts() # pylint: disable=protected-access if response.request.multipart_mixed_info: - policies = response.request.multipart_mixed_info[ - 1 - ] # type: List[SansIOHTTPPolicy] + policies: List["SansIOHTTPPolicy"] = response.request.multipart_mixed_info[1] # Apply on_response concurrently to all requests import concurrent.futures @@ -309,8 +307,9 @@ def parse_responses(response): return responses -def _format_data_helper(data): - # type: (Union[str, IO]) -> Union[Tuple[None, str], Tuple[Optional[str], IO, str]] +def _format_data_helper( + data: Union[str, IO] +) -> Union[Tuple[None, str], Tuple[Optional[str], IO, str]]: """Helper for _format_data. 
Format field data according to whether it is a stream or @@ -331,9 +330,10 @@ def _format_data_helper(data): return (None, cast(str, data)) -def _aiohttp_body_helper(response): +def _aiohttp_body_helper( + response: "PipelineTransportAioHttpTransportResponse", +) -> bytes: # pylint: disable=protected-access - # type: (PipelineTransportAioHttpTransportResponse) -> bytes """Helper for body method of Aiohttp responses. Since aiohttp body methods need decompression work synchronously, diff --git a/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared_async.py b/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared_async.py index caa175db5f3a..b30b7c02349b 100644 --- a/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared_async.py +++ b/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared_async.py @@ -6,12 +6,11 @@ # -------------------------------------------------------------------------- import asyncio from collections.abc import AsyncIterator -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, List from ..pipeline import PipelineContext, PipelineRequest, PipelineResponse from ..pipeline._tools_async import await_result as _await_result if TYPE_CHECKING: - from typing import List from ..pipeline.policies import SansIOHTTPPolicy @@ -31,9 +30,9 @@ async def _parse_response(self): http_response_type=self._default_http_response_type ) if self._response.request.multipart_mixed_info: - policies = self._response.request.multipart_mixed_info[ - 1 - ] # type: List[SansIOHTTPPolicy] + policies: List[ + "SansIOHTTPPolicy" + ] = self._response.request.multipart_mixed_info[1] async def parse_responses(response): http_request = response.request diff --git a/sdk/core/azure-core/azure/core/utils/_utils.py b/sdk/core/azure-core/azure/core/utils/_utils.py index 16ef07f07c49..fc4a97433c06 100644 --- a/sdk/core/azure-core/azure/core/utils/_utils.py +++ b/sdk/core/azure-core/azure/core/utils/_utils.py @@ -7,7 +7,6 @@ import datetime from typing import ( Any, - Dict, Iterable, Iterator, Mapping, @@ -15,10 +14,11 @@ Optional, Tuple, Union, + Dict, ) from datetime import timezone -TZ_UTC = timezone.utc # type: ignore +TZ_UTC = timezone.utc class _FixedOffset(datetime.tzinfo): diff --git a/sdk/core/azure-core/pyproject.toml b/sdk/core/azure-core/pyproject.toml index c98953fca712..49ab63c7da4a 100644 --- a/sdk/core/azure-core/pyproject.toml +++ b/sdk/core/azure-core/pyproject.toml @@ -1,8 +1,9 @@ [tool.azure-sdk-build] +mypy = true type_check_samples = false -verifytypes = false +verifytypes = true pyright = false -# For test environments or static checks where a check should be run by default, not explicitly disabling will enable the check. +# For test environments or static checks where a check should be run by default, not explicitly disabling will enable the check. # pylint is enabled by default, so there is no reason for a pylint = true in every pyproject.toml. # # For newly added checks that are not enabled by default, packages should opt IN by " = true". \ No newline at end of file
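Closing aside (not part of the diff): a minimal sketch of the two settings converters that gained annotations above, convert_bool and convert_logging. The inputs are illustrative and the direct import from azure.core.settings is assumed from the module shown in the diff; only the documented behavior is relied on, i.e. strings are mapped to booleans or stdlib logging levels, and values that are already the right type pass through unchanged.

>>> import logging
>>> from azure.core.settings import convert_bool, convert_logging
>>> convert_bool("true")                # string mapped to a real boolean
True
>>> convert_bool(False)                 # booleans pass through as-is
False
>>> convert_logging("WARNING")          # level name mapped to the stdlib constant
30
>>> convert_logging(logging.DEBUG)      # int levels pass through as-is
10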