diff --git a/python/NOTICE.txt b/python/NOTICE.txt new file mode 100644 index 000000000..957bf3dba --- /dev/null +++ b/python/NOTICE.txt @@ -0,0 +1,9 @@ +############################################################################################### +# +# Copyright © 2023, 2023, Oracle and/or its affiliates. +# Issue ref #269: OAuth 2.0 Credential Format for Delta Sharing Client +# Code update: +# - protocol.py +# - rest_client.py +# +############################################################################################### \ No newline at end of file diff --git a/python/delta_sharing/protocol.py b/python/delta_sharing/protocol.py index 289133864..c955ef714 100644 --- a/python/delta_sharing/protocol.py +++ b/python/delta_sharing/protocol.py @@ -23,12 +23,18 @@ @dataclass(frozen=True) class DeltaSharingProfile: - CURRENT: ClassVar[int] = 1 + CURRENT: ClassVar[int] = 2 share_credentials_version: int endpoint: str - bearer_token: str + bearer_token: Optional[str] = None expiration_time: Optional[str] = None + type: Optional[str] = None + token_endpoint: Optional[str] = None + client_id: Optional[str] = None + client_secret: Optional[str] = None + username: Optional[str] = None + password: Optional[str] = None def __post_init__(self): if self.share_credentials_version > DeltaSharingProfile.CURRENT: @@ -56,16 +62,60 @@ def read_from_file(profile: Union[str, IO, Path]) -> "DeltaSharingProfile": def from_json(json) -> "DeltaSharingProfile": if isinstance(json, (str, bytes, bytearray)): json = loads(json) + + share_credentials_version = int(json["shareCredentialsVersion"]) endpoint = json["endpoint"] - if endpoint.endswith("/"): + if endpoint is not None and endpoint.endswith("/"): endpoint = endpoint[:-1] - expiration_time = json.get("expirationTime") - return DeltaSharingProfile( - share_credentials_version=int(json["shareCredentialsVersion"]), - endpoint=endpoint, - bearer_token=json["bearerToken"], - expiration_time=expiration_time, - ) + + if share_credentials_version == 1: + return DeltaSharingProfile( + share_credentials_version=share_credentials_version, + endpoint=endpoint, + bearer_token=json["bearerToken"], + expiration_time=json.get("expirationTime"), + ) + elif share_credentials_version == 2: + type = json["type"] + if type == "persistent_oauth2.0": + token_endpoint = json["tokenEndpoint"] + if token_endpoint is not None and token_endpoint.endswith("/"): + token_endpoint = token_endpoint[:-1] + return DeltaSharingProfile( + share_credentials_version=share_credentials_version, + type=type, + endpoint=endpoint, + token_endpoint=token_endpoint, + client_id=json["clientID"], + client_secret=json["clientSecret"], + ) + elif type == "bearer_token": + return DeltaSharingProfile( + share_credentials_version=share_credentials_version, + type=type, + endpoint=endpoint, + bearer_token=json["bearerToken"], + expiration_time=json.get("expirationTime") + ) + elif type == "basic": + return DeltaSharingProfile( + share_credentials_version=share_credentials_version, + type=type, + endpoint=endpoint, + username=json["username"], + password=json["password"], + ) + else: + raise ValueError( + "The current release does not supports {type} type. " + "Please check type.") + else: + raise ValueError( + "'shareCredentialsVersion' in the profile is " + f"{share_credentials_version} which is too new. " + f"The current release supports version {DeltaSharingProfile.CURRENT} and below. " + "Please upgrade to a newer release." + ) @dataclass(frozen=True) @@ -101,7 +151,8 @@ class Table: def from_json(json) -> "Table": if isinstance(json, (str, bytes, bytearray)): json = loads(json) - return Table(name=json["name"], share=json["share"], schema=json["schema"]) + return Table(name=json["name"], share=json["share"], + schema=json["schema"]) @dataclass(frozen=True) diff --git a/python/delta_sharing/rest_client.py b/python/delta_sharing/rest_client.py index 55067991d..07e9d7cd3 100644 --- a/python/delta_sharing/rest_client.py +++ b/python/delta_sharing/rest_client.py @@ -143,16 +143,61 @@ def __init__(self, profile: DeltaSharingProfile, num_retries=10): self._profile = profile self._num_retries = num_retries self._sleeper = lambda sleep_ms: time.sleep(sleep_ms / 1000) + self.auth_session(profile) + def auth_session(self, profile) -> requests.Session: self._session = requests.Session() + self.auth_broker(profile) + if urlparse(profile.endpoint).hostname == "localhost": + self._session.verify = False + + def auth_broker(self, profile): + if profile.share_credentials_version == 2: + if profile.type == "persistent_oauth2.0": + self.auth_persistent_oauth2(profile) + elif profile.type == "bearer_token": + self.auth_bearer_token(profile) + elif profile.type == "basic": + self.auth_basic(profile) + else: + self.auth_bearer_token(profile) + else: + self.auth_bearer_token(profile) + + def auth_bearer_token(self, profile): self._session.headers.update( { "Authorization": f"Bearer {profile.bearer_token}", "User-Agent": DataSharingRestClient.USER_AGENT, } ) - if urlparse(profile.endpoint).hostname == "localhost": - self._session.verify = False + + def auth_persistent_oauth2(self, profile): + response = requests.post(profile.token_endpoint, + data={"grant_type": "client_credentials"}, + auth=(profile.client_id, + profile.client_secret),) + bearer_token = "{}".format(response.json()["access_token"]) + + self._session.headers.update( + { + "Authorization": f"Bearer {bearer_token}", + "User-Agent": DataSharingRestClient.USER_AGENT, + } + ) + + def auth_basic(self, profile): + response = requests.post(profile.token_endpoint, + data={"grant_type": "client_credentials"}, + auth=(profile.username, profile.password),) + bearer_token = "{}".format(response.json()["access_token"]) + + self._session.headers.update( + { + "Authorization": f"Bearer {bearer_token}", + "User-Agent": DataSharingRestClient.USER_AGENT, + } + ) @retry_with_exponential_backoff def list_shares(