Skip to content

Commit

Permalink
Issue ref delta-io#269: OAuth 2.0 Credential Format for Delta Sharing…
Browse files Browse the repository at this point in the history
… Python Client Pull Request

Signed-off-by: Dima Alberg <[email protected]>
  • Loading branch information
dialberg committed May 15, 2023
1 parent 7f01260 commit 372f0a3
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 13 deletions.
9 changes: 9 additions & 0 deletions python/NOTICE.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
###############################################################################################
#
# Copyright © 2023, 2023, Oracle and/or its affiliates.
# Issue ref #269: OAuth 2.0 Credential Format for Delta Sharing Client
# Code update:
# - protocol.py
# - rest_client.py
#
###############################################################################################
73 changes: 62 additions & 11 deletions python/delta_sharing/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@

@dataclass(frozen=True)
class DeltaSharingProfile:
CURRENT: ClassVar[int] = 1
CURRENT: ClassVar[int] = 2

share_credentials_version: int
endpoint: str
bearer_token: str
bearer_token: Optional[str] = None
expiration_time: Optional[str] = None
type: Optional[str] = None
token_endpoint: Optional[str] = None
client_id: Optional[str] = None
client_secret: Optional[str] = None
username: Optional[str] = None
password: Optional[str] = None

def __post_init__(self):
if self.share_credentials_version > DeltaSharingProfile.CURRENT:
Expand Down Expand Up @@ -56,16 +62,60 @@ def read_from_file(profile: Union[str, IO, Path]) -> "DeltaSharingProfile":
def from_json(json) -> "DeltaSharingProfile":
if isinstance(json, (str, bytes, bytearray)):
json = loads(json)

share_credentials_version = int(json["shareCredentialsVersion"])
endpoint = json["endpoint"]
if endpoint.endswith("/"):
if endpoint is not None and endpoint.endswith("/"):
endpoint = endpoint[:-1]
expiration_time = json.get("expirationTime")
return DeltaSharingProfile(
share_credentials_version=int(json["shareCredentialsVersion"]),
endpoint=endpoint,
bearer_token=json["bearerToken"],
expiration_time=expiration_time,
)

if share_credentials_version == 1:
return DeltaSharingProfile(
share_credentials_version=share_credentials_version,
endpoint=endpoint,
bearer_token=json["bearerToken"],
expiration_time=json.get("expirationTime"),
)
elif share_credentials_version == 2:
type = json["type"]
if type == "persistent_oauth2.0":
token_endpoint = json["tokenEndpoint"]
if token_endpoint is not None and token_endpoint.endswith("/"):
token_endpoint = token_endpoint[:-1]
return DeltaSharingProfile(
share_credentials_version=share_credentials_version,
type=type,
endpoint=endpoint,
token_endpoint=token_endpoint,
client_id=json["clientID"],
client_secret=json["clientSecret"],
)
elif type == "bearer_token":
return DeltaSharingProfile(
share_credentials_version=share_credentials_version,
type=type,
endpoint=endpoint,
bearer_token=json["bearerToken"],
expiration_time=json.get("expirationTime")
)
elif type == "basic":
return DeltaSharingProfile(
share_credentials_version=share_credentials_version,
type=type,
endpoint=endpoint,
username=json["username"],
password=json["password"],
)
else:
raise ValueError(
"The current release does not supports {type} type. "
"Please check type.")
else:
raise ValueError(
"'shareCredentialsVersion' in the profile is "
f"{share_credentials_version} which is too new. "
f"The current release supports version {DeltaSharingProfile.CURRENT} and below. "
"Please upgrade to a newer release."
)


@dataclass(frozen=True)
Expand Down Expand Up @@ -101,7 +151,8 @@ class Table:
def from_json(json) -> "Table":
if isinstance(json, (str, bytes, bytearray)):
json = loads(json)
return Table(name=json["name"], share=json["share"], schema=json["schema"])
return Table(name=json["name"], share=json["share"],
schema=json["schema"])


@dataclass(frozen=True)
Expand Down
49 changes: 47 additions & 2 deletions python/delta_sharing/rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,61 @@ def __init__(self, profile: DeltaSharingProfile, num_retries=10):
self._profile = profile
self._num_retries = num_retries
self._sleeper = lambda sleep_ms: time.sleep(sleep_ms / 1000)
self.auth_session(profile)

def auth_session(self, profile) -> requests.Session:
self._session = requests.Session()
self.auth_broker(profile)
if urlparse(profile.endpoint).hostname == "localhost":
self._session.verify = False

def auth_broker(self, profile):
if profile.share_credentials_version == 2:
if profile.type == "persistent_oauth2.0":
self.auth_persistent_oauth2(profile)
elif profile.type == "bearer_token":
self.auth_bearer_token(profile)
elif profile.type == "basic":
self.auth_basic(profile)
else:
self.auth_bearer_token(profile)
else:
self.auth_bearer_token(profile)

def auth_bearer_token(self, profile):
self._session.headers.update(
{
"Authorization": f"Bearer {profile.bearer_token}",
"User-Agent": DataSharingRestClient.USER_AGENT,
}
)
if urlparse(profile.endpoint).hostname == "localhost":
self._session.verify = False

def auth_persistent_oauth2(self, profile):
response = requests.post(profile.token_endpoint,
data={"grant_type": "client_credentials"},
auth=(profile.client_id,
profile.client_secret),)
bearer_token = "{}".format(response.json()["access_token"])

self._session.headers.update(
{
"Authorization": f"Bearer {bearer_token}",
"User-Agent": DataSharingRestClient.USER_AGENT,
}
)

def auth_basic(self, profile):
response = requests.post(profile.token_endpoint,
data={"grant_type": "client_credentials"},
auth=(profile.username, profile.password),)
bearer_token = "{}".format(response.json()["access_token"])

self._session.headers.update(
{
"Authorization": f"Bearer {bearer_token}",
"User-Agent": DataSharingRestClient.USER_AGENT,
}
)

@retry_with_exponential_backoff
def list_shares(
Expand Down

0 comments on commit 372f0a3

Please sign in to comment.