Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue ref #269: OAuth 2.0 Credential Format for Delta Sharing Python … #309

Merged
merged 6 commits into from
Jun 5, 2023
9 changes: 9 additions & 0 deletions python/NOTICE.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
###############################################################################################
#
# Copyright © 2023, 2023, Oracle and/or its affiliates.
# Issue ref #269: OAuth 2.0 Credential Format for Delta Sharing Client
# Code update:
# - protocol.py
# - rest_client.py
#
###############################################################################################
73 changes: 62 additions & 11 deletions python/delta_sharing/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@

@dataclass(frozen=True)
class DeltaSharingProfile:
CURRENT: ClassVar[int] = 1
CURRENT: ClassVar[int] = 2

share_credentials_version: int
endpoint: str
bearer_token: str
bearer_token: Optional[str] = None
expiration_time: Optional[str] = None
type: Optional[str] = None
token_endpoint: Optional[str] = None
client_id: Optional[str] = None
client_secret: Optional[str] = None
username: Optional[str] = None
password: Optional[str] = None

def __post_init__(self):
if self.share_credentials_version > DeltaSharingProfile.CURRENT:
Expand Down Expand Up @@ -56,16 +62,60 @@ def read_from_file(profile: Union[str, IO, Path]) -> "DeltaSharingProfile":
def from_json(json) -> "DeltaSharingProfile":
if isinstance(json, (str, bytes, bytearray)):
json = loads(json)

share_credentials_version = int(json["shareCredentialsVersion"])
endpoint = json["endpoint"]
if endpoint.endswith("/"):
if endpoint is not None and endpoint.endswith("/"):
endpoint = endpoint[:-1]
expiration_time = json.get("expirationTime")
return DeltaSharingProfile(
share_credentials_version=int(json["shareCredentialsVersion"]),
endpoint=endpoint,
bearer_token=json["bearerToken"],
expiration_time=expiration_time,
)

if share_credentials_version == 1:
return DeltaSharingProfile(
share_credentials_version=share_credentials_version,
endpoint=endpoint,
bearer_token=json["bearerToken"],
expiration_time=json.get("expirationTime"),
)
elif share_credentials_version == 2:
type = json["type"]
if type == "persistent_oauth2.0":
token_endpoint = json["tokenEndpoint"]
if token_endpoint is not None and token_endpoint.endswith("/"):
token_endpoint = token_endpoint[:-1]
return DeltaSharingProfile(
share_credentials_version=share_credentials_version,
type=type,
endpoint=endpoint,
token_endpoint=token_endpoint,
client_id=json["clientID"],
client_secret=json["clientSecret"],
)
elif type == "bearer_token":
return DeltaSharingProfile(
share_credentials_version=share_credentials_version,
type=type,
endpoint=endpoint,
bearer_token=json["bearerToken"],
expiration_time=json.get("expirationTime")
)
elif type == "basic":
return DeltaSharingProfile(
share_credentials_version=share_credentials_version,
type=type,
endpoint=endpoint,
username=json["username"],
password=json["password"],
)
else:
raise ValueError(
"The current release does not supports {type} type. "
"Please check type.")
else:
raise ValueError(
"'shareCredentialsVersion' in the profile is "
f"{share_credentials_version} which is too new. "
f"The current release supports version {DeltaSharingProfile.CURRENT} and below. "
"Please upgrade to a newer release."
)


@dataclass(frozen=True)
Expand Down Expand Up @@ -101,7 +151,8 @@ class Table:
def from_json(json) -> "Table":
if isinstance(json, (str, bytes, bytearray)):
json = loads(json)
return Table(name=json["name"], share=json["share"], schema=json["schema"])
return Table(name=json["name"], share=json["share"],
schema=json["schema"])


@dataclass(frozen=True)
Expand Down
49 changes: 47 additions & 2 deletions python/delta_sharing/rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,61 @@ def __init__(self, profile: DeltaSharingProfile, num_retries=10):
self._profile = profile
self._num_retries = num_retries
self._sleeper = lambda sleep_ms: time.sleep(sleep_ms / 1000)
self.auth_session(profile)

def auth_session(self, profile):
self._session = requests.Session()
self.auth_broker(profile)
if urlparse(profile.endpoint).hostname == "localhost":
self._session.verify = False

def auth_broker(self, profile):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you make all internal methods private by prepend the _? e.g. _auth_broker

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

if profile.share_credentials_version == 2:
if profile.type == "persistent_oauth2.0":
self.auth_persistent_oauth2(profile)
elif profile.type == "bearer_token":
self.auth_bearer_token(profile)
elif profile.type == "basic":
self.auth_basic(profile)
else:
self.auth_bearer_token(profile)
else:
self.auth_bearer_token(profile)

def auth_bearer_token(self, profile):
self._session.headers.update(
{
"Authorization": f"Bearer {profile.bearer_token}",
"User-Agent": DataSharingRestClient.USER_AGENT,
}
)
if urlparse(profile.endpoint).hostname == "localhost":
self._session.verify = False

def auth_persistent_oauth2(self, profile):
response = requests.post(profile.token_endpoint,
data={"grant_type": "client_credentials"},
auth=(profile.client_id,
profile.client_secret),)
bearer_token = "{}".format(response.json()["access_token"])

self._session.headers.update(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The content-type and accept headers are not meant to be sent to the API headers but the token exchange endpoints.

Copy link
Contributor Author

@dialberg dialberg May 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

{
"Authorization": f"Bearer {bearer_token}",
"User-Agent": DataSharingRestClient.USER_AGENT,
}
)

def auth_basic(self, profile):
response = requests.post(profile.token_endpoint,
data={"grant_type": "client_credentials"},
auth=(profile.username, profile.password),)
bearer_token = "{}".format(response.json()["access_token"])

self._session.headers.update(
{
"Authorization": f"Bearer {bearer_token}",
"User-Agent": DataSharingRestClient.USER_AGENT,
}
)

@retry_with_exponential_backoff
def list_shares(
Expand Down