Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EventHubs] py3 typing #28385

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 71 additions & 53 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time
import functools
import collections
from typing import Any, Dict, Tuple, List, Optional, TYPE_CHECKING, cast, Union
from typing import Any, Dict, Tuple, List, Optional, TYPE_CHECKING, cast, Union, Callable
try:
from typing import TypeAlias # type: ignore
except ImportError:
Expand Down Expand Up @@ -47,29 +47,45 @@

if TYPE_CHECKING:
from azure.core.credentials import TokenCredential
from ._producer_client import EventHubProducerClient
from ._consumer_client import EventHubConsumerClient
from ._transport._base import AmqpTransport
try:
from uamqp import Message as uamqp_Message
from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth
from uamqp import ( # pylint:disable=unused-import
Message as uamqp_Message,
ReceiveClient as uamqp_ReceiveClient,
SendClient as uamqp_SendClient
)
from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth # pylint:disable=unused-import
except ImportError:
uamqp_Message = None
uamqp_JWTTokenAuth = None
pass
from ._pyamqp.message import Message
from ._pyamqp.authentication import JWTTokenAuth
from ._pyamqp import SendClient, ReceiveClient

_LOGGER = logging.getLogger(__name__)
_Address = collections.namedtuple("_Address", "hostname path")


def _parse_conn_str(conn_str, **kwargs):
# type: (str, Any) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]]
def _parse_conn_str(
conn_str: str,
**kwargs: Any
) -> Tuple[
str,
Optional[str],
Optional[str],
str,
Optional[str],
Optional[int]
]:
endpoint = None
shared_access_key_name = None
shared_access_key = None
entity_path = None # type: Optional[str]
shared_access_signature = None # type: Optional[str]
entity_path: Optional[str] = None
shared_access_signature: Optional[str] = None
shared_access_signature_expiry = None
eventhub_name = kwargs.pop("eventhub_name", None) # type: Optional[str]
check_case = kwargs.pop("check_case", False) # type: bool
eventhub_name: Optional[str] = kwargs.pop("eventhub_name", None)
check_case: bool = kwargs.pop("check_case", False)
conn_settings = core_parse_connection_string(
conn_str, case_sensitive_keys=check_case
)
Expand Down Expand Up @@ -97,7 +113,7 @@ def _parse_conn_str(conn_str, **kwargs):
try:
# Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs)
shared_access_signature_expiry = int(
shared_access_signature.split("se=")[1].split("&")[0] # type: ignore
shared_access_signature.split("se=")[1].split("&")[0]
)
except (
IndexError,
Expand All @@ -115,7 +131,7 @@ def _parse_conn_str(conn_str, **kwargs):
parsed = urlparse(endpoint)
if not parsed.netloc:
raise ValueError("Invalid Endpoint on the Connection String.")
host = cast(str, parsed.netloc.strip())
host = parsed.netloc.strip()

if any([shared_access_key, shared_access_key_name]) and not all(
[shared_access_key, shared_access_key_name]
Expand Down Expand Up @@ -145,8 +161,7 @@ def _parse_conn_str(conn_str, **kwargs):
)


def _generate_sas_token(uri, policy, key, expiry=None):
# type: (str, str, str, Optional[timedelta]) -> AccessToken
def _generate_sas_token(uri: str, policy: str, key: str, expiry: Optional[timedelta] = None) -> AccessToken:
"""Create a shared access signature token as a string literal.
:returns: SAS token as string literal.
:rtype: str
Expand All @@ -159,8 +174,8 @@ def _generate_sas_token(uri, policy, key, expiry=None):
token = generate_sas_token(uri, policy, key, abs_expiry)
return AccessToken(token=token, expires_on=abs_expiry)

def _build_uri(address, entity):
# type: (str, Optional[str]) -> str

def _build_uri(address: str, entity: Optional[str]) -> str:
parsed = urlparse(address)
if parsed.path:
return address
Expand All @@ -185,14 +200,14 @@ class EventHubSharedKeyCredential(object):
:param str key: The shared access key.
"""

def __init__(self, policy, key):
# type: (str, str) -> None
def __init__(self, policy: str, key: str) -> None:
self.policy = policy
self.key = key
self.token_type = b"servicebus.windows.net:sastoken"

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (str, Any) -> AccessToken
def get_token(
self, *scopes: str, **kwargs: Any # pylint:disable=unused-argument
) -> AccessToken:
if not scopes:
raise ValueError("No token scope provided.")
return _generate_sas_token(scopes[0], self.policy, self.key)
Expand All @@ -205,13 +220,13 @@ class EventhubAzureNamedKeyTokenCredential(object):
:type credential: ~azure.core.credentials.AzureNamedKeyCredential
"""

def __init__(self, azure_named_key_credential):
# type: (AzureNamedKeyCredential) -> None
def __init__(self, azure_named_key_credential: AzureNamedKeyCredential) -> None:
self._credential = azure_named_key_credential
self.token_type = b"servicebus.windows.net:sastoken"

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (str, Any) -> AccessToken
def get_token(
self, *scopes: str, **kwargs: Any # pylint:disable=unused-argument
) -> AccessToken:
if not scopes:
raise ValueError("No token scope provided.")
name, key = self._credential.named_key
Expand All @@ -225,8 +240,7 @@ class EventHubSASTokenCredential(object):
:param int expiry: The epoch timestamp
"""

def __init__(self, token, expiry):
# type: (str, int) -> None
def __init__(self, token: str, expiry: int) -> None:
"""
:param str token: The shared access token string
:param float expiry: The epoch timestamp
Expand All @@ -235,8 +249,9 @@ def __init__(self, token, expiry):
self.expiry = expiry
self.token_type = b"servicebus.windows.net:sastoken"

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (str, Any) -> AccessToken
def get_token(
self, *scopes: str, **kwargs: Any # pylint:disable=unused-argument
) -> AccessToken:
"""
This method is automatically called when token is about to expire.
"""
Expand All @@ -251,8 +266,7 @@ class EventhubAzureSasTokenCredential(object):
:type azure_sas_credential: ~azure.core.credentials.AzureSasCredential
"""

def __init__(self, azure_sas_credential):
# type: (AzureSasCredential) -> None
def __init__(self, azure_sas_credential: AzureSasCredential) -> None:
"""The shared access token credential used for authentication
when AzureSasCredential is provided.

Expand All @@ -262,8 +276,9 @@ def __init__(self, azure_sas_credential):
self._credential = azure_sas_credential
self.token_type = b"servicebus.windows.net:sastoken"

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (str, Any) -> AccessToken
def get_token(
self, *scopes: str, **kwargs: Any # pylint:disable=unused-argument
) -> AccessToken:
"""
This method is automatically called when token is about to expire.
"""
Expand Down Expand Up @@ -302,12 +317,14 @@ def __init__(
path = "/" + eventhub_name if eventhub_name else ""
self._address = _Address(hostname=fully_qualified_namespace, path=path)
self._container_id = CONTAINER_PREFIX + str(uuid.uuid4())[:8]
self._credential: Union[EventHubSharedKeyCredential, TokenCredential]
if isinstance(credential, AzureSasCredential):
self._credential = EventhubAzureSasTokenCredential(credential)
elif isinstance(credential, AzureNamedKeyCredential):
self._credential = EventhubAzureNamedKeyTokenCredential(credential) # type: ignore
self._credential = EventhubAzureNamedKeyTokenCredential(credential)
else:
self._credential = credential # type: ignore
self._credential = credential

self._keep_alive = kwargs.get("keep_alive", 30)
self._auto_reconnect = kwargs.get("auto_reconnect", True)
self._auth_uri = f"sb://{self._address.hostname}{self._address.path}"
Expand All @@ -324,8 +341,7 @@ def __init__(
self._idle_timeout = kwargs.get("idle_timeout", None)

@staticmethod
def _from_connection_string(conn_str, **kwargs):
# type: (str, Any) -> Dict[str, Any]
def _from_connection_string(conn_str: str, **kwargs: Any) -> Dict[str, Any]:
host, policy, key, entity, token, token_expiry = _parse_conn_str(
conn_str, **kwargs
)
Expand All @@ -337,7 +353,7 @@ def _from_connection_string(conn_str, **kwargs):
kwargs["credential"] = EventHubSharedKeyCredential(policy, key)
return kwargs

def _create_auth(self) -> Union[uamqp_JWTTokenAuth, JWTTokenAuth]:
def _create_auth(self) -> Union["uamqp_JWTTokenAuth", JWTTokenAuth]:
"""
Create an ~uamqp.authentication.SASTokenAuth instance
to authenticate the session.
Expand All @@ -363,14 +379,16 @@ def _create_auth(self) -> Union[uamqp_JWTTokenAuth, JWTTokenAuth]:
update_token=False,
)

def _close_connection(self):
# type: () -> None
def _close_connection(self) -> None:
self._conn_manager.reset_connection_if_broken()

def _backoff(
self, retried_times, last_exception, timeout_time=None, entity_name=None
):
# type: (int, Exception, Optional[int], Optional[str]) -> None
self,
retried_times: int,
last_exception: Exception,
timeout_time: Optional[int] = None,
entity_name: Optional[str] = None
) -> None:
entity_name = entity_name or self._container_id
backoff = _get_backoff_time(
self._config.retry_mode,
Expand All @@ -396,7 +414,7 @@ def _backoff(
raise last_exception

def _management_request(
self, mgmt_msg: Union[uamqp_Message, Message], op_type: bytes
self, mgmt_msg: Union["uamqp_Message", Message], op_type: bytes
) -> Any:
# pylint:disable=assignment-from-none
retried_times = 0
Expand Down Expand Up @@ -472,20 +490,18 @@ def _get_eventhub_properties(self) -> Dict[str, Any]:
]
return output

def _get_partition_ids(self):
# type:() -> List[str]
def _get_partition_ids(self) -> List[str]:
return self._get_eventhub_properties()["partition_ids"]

def _get_partition_properties(self, partition_id):
# type:(str) -> Dict[str, Any]
def _get_partition_properties(self, partition_id: str) -> Dict[str, Any]:
mgmt_msg = self._amqp_transport.build_message(
application_properties={
"name": self.eventhub_name,
"partition": partition_id,
}
)
response = self._management_request(mgmt_msg, op_type=MGMT_PARTITION_OPERATION)
partition_info = response.value # type: Dict[bytes, Any]
partition_info: Dict[bytes, Any] = response.value
output = {}
if partition_info:
output["eventhub_name"] = partition_info[b"name"].decode("utf-8")
Expand All @@ -505,8 +521,7 @@ def _get_partition_properties(self, partition_id):
)
return output

def _close(self):
# type:() -> None
def _close(self) -> None:
self._conn_manager.close_connection()


Expand All @@ -521,6 +536,7 @@ def _create_handler(self, auth):
pass

def _check_closed(self):
self._name: str
if self.closed:
raise ClientClosedError(
f"{self._name} has been closed. Please create a new one to handle event data."
Expand All @@ -529,6 +545,9 @@ def _check_closed(self):
def _open(self):
"""Open the EventHubConsumer/EventHubProducer using the supplied connection."""
# pylint: disable=protected-access
self._handler: Union["uamqp_SendClient", "uamqp_ReceiveClient", SendClient, ReceiveClient]
self._client: Union["EventHubProducerClient", "EventHubConsumerClient"]
self._amqp_transport: "AmqpTransport"
if not self.running:
if self._handler:
self._handler.close()
Expand Down Expand Up @@ -600,8 +619,7 @@ def _do_retryable_operation(self, operation, timeout=None, **kwargs):
)
raise last_exception

def close(self):
# type:() -> None
def close(self) -> None:
"""
Close down the handler. If the handler has already closed,
this will be a no op.
Expand Down
Loading