Skip to content

Commit

Permalink
sdk/python: Refactored request client to move session specific proper…
Browse files Browse the repository at this point in the history
…ties to SessionManager

Signed-off-by: Aaron Wilson <[email protected]>
  • Loading branch information
aaronnw committed Aug 19, 2024
1 parent 1d41ece commit e743e01
Show file tree
Hide file tree
Showing 12 changed files with 315 additions and 186 deletions.
16 changes: 9 additions & 7 deletions python/aistore/pytorch/base_iter_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""

from itertools import islice
from typing import List, Union, Iterable, Dict, Iterator, Tuple
from aistore.sdk.ais_source import AISSource
from torch.utils.data import IterableDataset
from abc import ABC, abstractmethod
from aistore.pytorch.worker_request_client import WorkerRequestClient
import torch.utils.data as torch_utils
from itertools import islice
from abc import ABC, abstractmethod
from aistore.sdk.ais_source import AISSource
from aistore.pytorch.worker_session_manager import WorkerSessionManager


class AISBaseIterDataset(ABC, IterableDataset):
class AISBaseIterDataset(ABC, torch_utils.IterableDataset):
"""
A base class for creating AIS Iterable Datasets. Should not be instantiated directly. Subclasses
should implement :meth:`__iter__` which returns the samples from the dataset and can optionally
Expand Down Expand Up @@ -51,7 +50,10 @@ def _create_objects_iter(self) -> Iterable:
"""
for source in self._ais_source_list:
# Add pytorch worker support to the internal request client
source.client = WorkerRequestClient(source.client)
# TODO: Do not modify the provided source client
source.client.session_manager = WorkerSessionManager(
source.client.session_manager
)
if source not in self._prefix_map or self._prefix_map[source] is None:
for sample in source.list_all_objects_iter(prefix=""):
yield sample
Expand Down
11 changes: 7 additions & 4 deletions python/aistore/pytorch/base_map_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
"""

from typing import List, Union, Dict
from abc import ABC, abstractmethod
from torch.utils.data import Dataset
from aistore.sdk.ais_source import AISSource
from aistore.sdk.object import Object
from torch.utils.data import Dataset
from abc import ABC, abstractmethod
from aistore.pytorch.worker_request_client import WorkerRequestClient
from aistore.pytorch.worker_session_manager import WorkerSessionManager


class AISBaseMapDataset(ABC, Dataset):
Expand Down Expand Up @@ -52,7 +52,10 @@ def _create_objects_list(self) -> List[Object]:

for source in self._ais_source_list:
# Add pytorch worker support to the internal request client
source.client = WorkerRequestClient(source.client)
# TODO: Do not modify the provided source client
source.client.session_manager = WorkerSessionManager(
source.client.session_manager
)
if source not in self._prefix_map or self._prefix_map[source] is None:
samples.extend(list(source.list_all_objects_iter(prefix="")))
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,25 @@
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""

from aistore.sdk.request_client import RequestClient
from torch.utils.data import get_worker_info

from aistore.sdk.session_manager import SessionManager

class WorkerRequestClient(RequestClient):

class WorkerSessionManager(SessionManager):
"""
Extension that supports Pytorch and multiple workers of internal client for
buckets, objects, jobs, etc. to use for making requests to an AIS cluster.
Args:
client (RequestClient): Existing RequestClient to replace
session_manager (SessionManager): Existing SessionManager to replace
"""

def __init__(self, client: RequestClient):
def __init__(self, session_manager: SessionManager):
super().__init__(
endpoint=client._endpoint,
skip_verify=client._skip_verify,
ca_cert=client._ca_cert,
timeout=client._timeout,
retry=client._retry,
token=client._token,
retry=session_manager.retry,
ca_cert=session_manager.ca_cert,
skip_verify=session_manager.skip_verify,
)
self._worker_sessions = {}

Expand All @@ -43,5 +41,5 @@ def session(self):
return self._session
# if we only have one session but multiple workers, create more
if worker_info.id not in self._worker_sessions:
self._worker_sessions[worker_info.id] = self._create_new_session()
self._worker_sessions[worker_info.id] = self._create_session()
return self._worker_sessions[worker_info.id]
8 changes: 5 additions & 3 deletions python/aistore/sdk/authn/authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from aistore.sdk.authn.types import TokenMsg, LoginMsg
from aistore.sdk.authn.cluster_manager import ClusterManager
from aistore.sdk.authn.role_manager import RoleManager
from aistore.sdk.session_manager import SessionManager
from aistore.sdk.authn.user_manager import UserManager

# logging
Expand Down Expand Up @@ -49,12 +50,13 @@ def __init__(
token: Optional[str] = None,
):
logger.info("Initializing AuthNClient")
session_manager = SessionManager(
retry=retry, ca_cert=ca_cert, skip_verify=skip_verify
)
self._request_client = RequestClient(
endpoint=endpoint,
skip_verify=skip_verify,
ca_cert=ca_cert,
session_manager=session_manager,
timeout=timeout,
retry=retry,
token=token,
)
logger.info("AuthNClient initialized with endpoint: %s", endpoint)
Expand Down
9 changes: 8 additions & 1 deletion python/aistore/sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from aistore.sdk.cluster import Cluster
from aistore.sdk.dsort import Dsort
from aistore.sdk.request_client import RequestClient
from aistore.sdk.session_manager import SessionManager
from aistore.sdk.types import Namespace
from aistore.sdk.job import Job
from aistore.sdk.etl import Etl
Expand Down Expand Up @@ -47,8 +48,14 @@ def __init__(
retry: Retry = None,
token: str = None,
):
session_manager = SessionManager(
retry=retry, ca_cert=ca_cert, skip_verify=skip_verify
)
self._request_client = RequestClient(
endpoint, skip_verify, ca_cert, timeout, retry, token
endpoint=endpoint,
session_manager=session_manager,
timeout=timeout,
token=token,
)

def bucket(
Expand Down
87 changes: 27 additions & 60 deletions python/aistore/sdk/request_client.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,22 @@
#
# Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
#
import os
from urllib.parse import urljoin, urlencode
from typing import Optional, TypeVar, Tuple, Type, Union, Any, Dict
from requests import Session, Response
from requests.adapters import HTTPAdapter
from urllib3 import Retry
from typing import TypeVar, Type, Any, Dict, Optional, Tuple, Union
from requests import Response

from aistore.sdk.const import (
JSON_CONTENT_TYPE,
HEADER_USER_AGENT,
USER_AGENT_BASE,
HEADER_CONTENT_TYPE,
AIS_CLIENT_CA,
HEADER_AUTHORIZATION,
)
from aistore.sdk.session_manager import SessionManager
from aistore.sdk.utils import handle_errors, decode_response
from aistore.version import __version__ as sdk_version

T = TypeVar("T")
DEFAULT_RETRY = Retry(total=6, connect=3, backoff_factor=1)


# pylint: disable=unused-variable, duplicate-code, too-many-arguments
Expand All @@ -30,66 +26,24 @@ class RequestClient:
Args:
endpoint (str): AIStore endpoint
skip_verify (bool, optional): If True, skip SSL certificate verification. Defaults to False.
ca_cert (str, optional): Path to a CA certificate file for SSL verification.
session_manager (SessionManager): SessionManager for creating and accessing requests session
timeout (Union[float, Tuple[float, float], None], optional): Request timeout in seconds; a single float
for both connect/read timeouts (e.g., 5.0), a tuple for separate connect/read timeouts (e.g., (3.0, 10.0)),
or None to disable timeout.
retry (urllib3.Retry, optional): Retry configuration object from the urllib3 library.
Default: Retry(total=6, connect=3, backoff_factor=1).
token (str, optional): Authorization token.
"""

# pylint:disable=too-many-instance-attributes
def __init__(
self,
endpoint: str,
skip_verify: bool = False,
ca_cert: Optional[str] = None,
session_manager: SessionManager,
timeout: Optional[Union[float, Tuple[float, float]]] = None,
retry: Retry = DEFAULT_RETRY,
token: str = None,
):
self._endpoint = endpoint
self._base_url = urljoin(endpoint, "v1")
self._skip_verify = skip_verify
self._ca_cert = ca_cert
self._session_manager = session_manager
self._token = token
self._timeout = timeout
self._retry = retry
self._session = self._create_new_session()

def _create_new_session(self) -> Session:
"""
Creates a new requests session for HTTP requests.
Returns:
New HTTP request Session
"""
request_session = Session()
if "https" in self._endpoint:
self._set_session_verification(request_session)
for protocol in ("http://", "https://"):
request_session.mount(protocol, HTTPAdapter(max_retries=self._retry))
return request_session

def _set_session_verification(self, request_session: Session):
"""
Set session verify value for validating the server's SSL certificate
The requests library allows this to be a boolean or a string path to the cert
If we do not skip verification, the order is:
1. Provided cert path
2. Cert path from env var.
3. True (verify with system's approved CA list)
"""
if self._skip_verify:
request_session.verify = False
return
if self._ca_cert:
request_session.verify = self._ca_cert
return
env_crt = os.getenv(AIS_CLIENT_CA)
request_session.verify = env_crt if env_crt else True

@property
def base_url(self):
Expand All @@ -99,21 +53,34 @@ def base_url(self):
return self._base_url

@property
def endpoint(self):
def timeout(self):
"""
Returns: Timeout for requests
"""
Returns: AIS cluster endpoint
return self._timeout

@timeout.setter
def timeout(self, timeout: Union[float, Tuple[float, float]]):
"""
Set timeout for all requests from this client
Args:
timeout: Request timeout
"""
return self._endpoint
self._timeout = timeout

@property
def session(self):
def session_manager(self) -> SessionManager:
"""
Returns: Active request session
Returns: SessionManager used to create sessions for this client
"""
return self._session
return self._session_manager

@session_manager.setter
def session_manager(self, session_manager):
self._session_manager = session_manager

@property
def token(self):
def token(self) -> str:
"""
Returns: Token for Authorization
"""
Expand Down Expand Up @@ -180,7 +147,7 @@ def request(
if self.token:
headers[HEADER_AUTHORIZATION] = f"Bearer {self.token}"

resp = self.session.request(
resp = self.session_manager.session.request(
method,
url,
headers=headers,
Expand Down
90 changes: 90 additions & 0 deletions python/aistore/sdk/session_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""

import os
from typing import Optional

from requests import Session
from requests.adapters import HTTPAdapter
from urllib3 import Retry

from aistore.sdk.const import AIS_CLIENT_CA

DEFAULT_RETRY = Retry(total=6, connect=3, backoff_factor=1)


class SessionManager:
    """
    Class for storing and creating requests library sessions.

    Args:
        retry (urllib3.Retry, optional): Retry configuration object from the urllib3 library.
            Default: Retry(total=6, connect=3, backoff_factor=1). Passing None selects the same default.
        ca_cert (str, optional): Path to a CA certificate file for SSL verification. Defaults to None.
        skip_verify (bool, optional): If True, skip SSL certificate verification. Defaults to False.
    """

    def __init__(
        self,
        retry: Optional[Retry] = DEFAULT_RETRY,
        ca_cert: Optional[str] = None,
        skip_verify: bool = False,
    ):
        # Callers may forward an unset `retry` as None; normalize it here so the
        # documented default is what gets mounted on the session adapters instead
        # of None being handed straight to HTTPAdapter.
        self._retry = DEFAULT_RETRY if retry is None else retry
        self._ca_cert = ca_cert
        self._skip_verify = skip_verify
        # The session is created lazily on first access of the `session` property.
        self._session = None

    @property
    def retry(self) -> Retry:
        """Returns retry config for this session"""
        return self._retry

    @property
    def ca_cert(self) -> Optional[str]:
        """Returns CA certificate for this session, if any"""
        return self._ca_cert

    @property
    def skip_verify(self) -> bool:
        """Returns whether this session's requests skip server certificate verification"""
        return self._skip_verify

    @property
    def session(self) -> Session:
        """Returns an existing `requests` session, creating a new one if needed"""
        if self._session is None:
            self._session = self._create_session()
        return self._session

    def _set_session_verification(self, request_session: Session):
        """
        Set session verify value for validating the server's SSL certificate
        The requests library allows this to be a boolean or a string path to the cert
        If we do not skip verification, the order is:
            1. Provided cert path
            2. Cert path from env var.
            3. True (verify with system's approved CA list)

        Args:
            request_session (Session): Session whose `verify` attribute is set in place
        """
        if self._skip_verify:
            request_session.verify = False
            return
        if self._ca_cert:
            request_session.verify = self._ca_cert
            return
        # An unset or empty environment variable falls through to True.
        env_crt = os.getenv(AIS_CLIENT_CA)
        request_session.verify = env_crt if env_crt else True

    def _create_session(self) -> Session:
        """
        Creates a new `requests` session for HTTP requests.

        Returns:
            New HTTP request Session
        """
        request_session = Session()
        self._set_session_verification(request_session)
        # Mount a retrying adapter for both schemes so every request made
        # through this session uses the configured retry policy.
        for protocol in ("http://", "https://"):
            request_session.mount(protocol, HTTPAdapter(max_retries=self._retry))
        return request_session
Loading

0 comments on commit e743e01

Please sign in to comment.