diff --git a/python/aistore/pytorch/base_iter_dataset.py b/python/aistore/pytorch/base_iter_dataset.py index f98b6e8739a..0e36038025c 100644 --- a/python/aistore/pytorch/base_iter_dataset.py +++ b/python/aistore/pytorch/base_iter_dataset.py @@ -4,16 +4,15 @@ Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. """ +from itertools import islice from typing import List, Union, Iterable, Dict, Iterator, Tuple -from aistore.sdk.ais_source import AISSource -from torch.utils.data import IterableDataset -from abc import ABC, abstractmethod -from aistore.pytorch.worker_request_client import WorkerRequestClient import torch.utils.data as torch_utils -from itertools import islice +from abc import ABC, abstractmethod +from aistore.sdk.ais_source import AISSource +from aistore.pytorch.worker_session_manager import WorkerSessionManager -class AISBaseIterDataset(ABC, IterableDataset): +class AISBaseIterDataset(ABC, torch_utils.IterableDataset): """ A base class for creating AIS Iterable Datasets. Should not be instantiated directly. Subclasses should implement :meth:`__iter__` which returns the samples from the dataset and can optionally @@ -51,7 +50,10 @@ def _create_objects_iter(self) -> Iterable: """ for source in self._ais_source_list: # Add pytorch worker support to the internal request client - source.client = WorkerRequestClient(source.client) + # TODO: Do not modify the provided source client + source.client.session_manager = WorkerSessionManager( + source.client.session_manager + ) if source not in self._prefix_map or self._prefix_map[source] is None: for sample in source.list_all_objects_iter(prefix=""): yield sample diff --git a/python/aistore/pytorch/base_map_dataset.py b/python/aistore/pytorch/base_map_dataset.py index aaabe173f26..0e967964905 100644 --- a/python/aistore/pytorch/base_map_dataset.py +++ b/python/aistore/pytorch/base_map_dataset.py @@ -5,11 +5,11 @@ """ from typing import List, Union, Dict +from abc import ABC, abstractmethod +from torch.utils.data import Dataset from aistore.sdk.ais_source import AISSource from aistore.sdk.object import Object -from torch.utils.data import Dataset -from abc import ABC, abstractmethod -from aistore.pytorch.worker_request_client import WorkerRequestClient +from aistore.pytorch.worker_session_manager import WorkerSessionManager class AISBaseMapDataset(ABC, Dataset): @@ -52,7 +52,10 @@ def _create_objects_list(self) -> List[Object]: for source in self._ais_source_list: # Add pytorch worker support to the internal request client - source.client = WorkerRequestClient(source.client) + # TODO: Do not modify the provided source client + source.client.session_manager = WorkerSessionManager( + source.client.session_manager + ) if source not in self._prefix_map or self._prefix_map[source] is None: samples.extend(list(source.list_all_objects_iter(prefix=""))) else: diff --git a/python/aistore/pytorch/worker_request_client.py b/python/aistore/pytorch/worker_session_manager.py similarity index 73% rename from python/aistore/pytorch/worker_request_client.py rename to python/aistore/pytorch/worker_session_manager.py index c5eb4a198dc..837e5062f68 100644 --- a/python/aistore/pytorch/worker_request_client.py +++ b/python/aistore/pytorch/worker_session_manager.py @@ -8,27 +8,25 @@ Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. """ -from aistore.sdk.request_client import RequestClient from torch.utils.data import get_worker_info +from aistore.sdk.session_manager import SessionManager -class WorkerRequestClient(RequestClient): + +class WorkerSessionManager(SessionManager): """ Extension that supports Pytorch and multiple workers of internal client for buckets, objects, jobs, etc. to use for making requests to an AIS cluster. Args: - client (RequestClient): Existing RequestClient to replace + session_manager (SessionManager): Existing SessionManager to replace """ - def __init__(self, client: RequestClient): + def __init__(self, session_manager: SessionManager): super().__init__( - endpoint=client._endpoint, - skip_verify=client._skip_verify, - ca_cert=client._ca_cert, - timeout=client._timeout, - retry=client._retry, - token=client._token, + retry=session_manager.retry, + ca_cert=session_manager.ca_cert, + skip_verify=session_manager.skip_verify, ) self._worker_sessions = {} @@ -43,5 +41,5 @@ def session(self): return self._session # if we only have one session but multiple workers, create more if worker_info.id not in self._worker_sessions: - self._worker_sessions[worker_info.id] = self._create_new_session() + self._worker_sessions[worker_info.id] = self._create_session() return self._worker_sessions[worker_info.id] diff --git a/python/aistore/sdk/authn/authn_client.py b/python/aistore/sdk/authn/authn_client.py index 08f445eafdd..462b4cccfa0 100644 --- a/python/aistore/sdk/authn/authn_client.py +++ b/python/aistore/sdk/authn/authn_client.py @@ -13,6 +13,7 @@ from aistore.sdk.authn.types import TokenMsg, LoginMsg from aistore.sdk.authn.cluster_manager import ClusterManager from aistore.sdk.authn.role_manager import RoleManager +from aistore.sdk.session_manager import SessionManager from aistore.sdk.authn.user_manager import UserManager # logging @@ -49,12 +50,13 @@ def __init__( token: Optional[str] = None, ): logger.info("Initializing AuthNClient") + session_manager = SessionManager( + retry=retry, ca_cert=ca_cert, skip_verify=skip_verify + ) self._request_client = RequestClient( endpoint=endpoint, - skip_verify=skip_verify, - ca_cert=ca_cert, + session_manager=session_manager, timeout=timeout, - retry=retry, token=token, ) logger.info("AuthNClient initialized with endpoint: %s", endpoint) diff --git a/python/aistore/sdk/client.py b/python/aistore/sdk/client.py index fd8df6cc37d..3637eae6a30 100644 --- a/python/aistore/sdk/client.py +++ b/python/aistore/sdk/client.py @@ -14,6 +14,7 @@ from aistore.sdk.cluster import Cluster from aistore.sdk.dsort import Dsort from aistore.sdk.request_client import RequestClient +from aistore.sdk.session_manager import SessionManager from aistore.sdk.types import Namespace from aistore.sdk.job import Job from aistore.sdk.etl import Etl @@ -47,8 +48,14 @@ def __init__( retry: Retry = None, token: str = None, ): + session_manager = SessionManager( + retry=retry, ca_cert=ca_cert, skip_verify=skip_verify + ) self._request_client = RequestClient( - endpoint, skip_verify, ca_cert, timeout, retry, token + endpoint=endpoint, + session_manager=session_manager, + timeout=timeout, + token=token, ) def bucket( diff --git a/python/aistore/sdk/request_client.py b/python/aistore/sdk/request_client.py index ef6c60db848..959021319c4 100644 --- a/python/aistore/sdk/request_client.py +++ b/python/aistore/sdk/request_client.py @@ -1,26 +1,22 @@ # # Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. # -import os from urllib.parse import urljoin, urlencode -from typing import Optional, TypeVar, Tuple, Type, Union, Any, Dict -from requests import Session, Response -from requests.adapters import HTTPAdapter -from urllib3 import Retry +from typing import TypeVar, Type, Any, Dict, Optional, Tuple, Union +from requests import Response from aistore.sdk.const import ( JSON_CONTENT_TYPE, HEADER_USER_AGENT, USER_AGENT_BASE, HEADER_CONTENT_TYPE, - AIS_CLIENT_CA, HEADER_AUTHORIZATION, ) +from aistore.sdk.session_manager import SessionManager from aistore.sdk.utils import handle_errors, decode_response from aistore.version import __version__ as sdk_version T = TypeVar("T") -DEFAULT_RETRY = Retry(total=6, connect=3, backoff_factor=1) # pylint: disable=unused-variable, duplicate-code, too-many-arguments @@ -30,66 +26,24 @@ class RequestClient: Args: endpoint (str): AIStore endpoint - skip_verify (bool, optional): If True, skip SSL certificate verification. Defaults to False. - ca_cert (str, optional): Path to a CA certificate file for SSL verification. + session_manager (SessionManager): SessionManager for creating and accessing requests session timeout (Union[float, Tuple[float, float], None], optional): Request timeout in seconds; a single float for both connect/read timeouts (e.g., 5.0), a tuple for separate connect/read timeouts (e.g., (3.0, 10.0)), or None to disable timeout. - retry (urllib3.Retry, optional): Retry configuration object from the urllib3 library. - Default: Retry(total=6, connect=3, backoff_factor=1). token (str, optional): Authorization token. """ - # pylint:disable=too-many-instance-attributes def __init__( self, endpoint: str, - skip_verify: bool = False, - ca_cert: Optional[str] = None, + session_manager: SessionManager, timeout: Optional[Union[float, Tuple[float, float]]] = None, - retry: Retry = DEFAULT_RETRY, token: str = None, ): - self._endpoint = endpoint self._base_url = urljoin(endpoint, "v1") - self._skip_verify = skip_verify - self._ca_cert = ca_cert + self._session_manager = session_manager self._token = token self._timeout = timeout - self._retry = retry - self._session = self._create_new_session() - - def _create_new_session(self) -> Session: - """ - Creates a new requests session for HTTP requests. - - Returns: - New HTTP request Session - """ - request_session = Session() - if "https" in self._endpoint: - self._set_session_verification(request_session) - for protocol in ("http://", "https://"): - request_session.mount(protocol, HTTPAdapter(max_retries=self._retry)) - return request_session - - def _set_session_verification(self, request_session: Session): - """ - Set session verify value for validating the server's SSL certificate - The requests library allows this to be a boolean or a string path to the cert - If we do not skip verification, the order is: - 1. Provided cert path - 2. Cert path from env var. - 3. True (verify with system's approved CA list) - """ - if self._skip_verify: - request_session.verify = False - return - if self._ca_cert: - request_session.verify = self._ca_cert - return - env_crt = os.getenv(AIS_CLIENT_CA) - request_session.verify = env_crt if env_crt else True @property def base_url(self): @@ -99,21 +53,34 @@ def base_url(self): return self._base_url @property - def endpoint(self): + def timeout(self): + """ + Returns: Timeout for requests """ - Returns: AIS cluster endpoint + return self._timeout + + @timeout.setter + def timeout(self, timeout: Union[float, Tuple[float, float]]): + """ + Set timeout for all requests from this client + Args: + timeout: Request timeout """ - return self._endpoint + self._timeout = timeout @property - def session(self): + def session_manager(self) -> SessionManager: """ - Returns: Active request session + Returns: SessionManager used to create sessions for this client """ - return self._session + return self._session_manager + + @session_manager.setter + def session_manager(self, session_manager): + self._session_manager = session_manager @property - def token(self): + def token(self) -> str: """ Returns: Token for Authorization """ @@ -180,7 +147,7 @@ def request( if self.token: headers[HEADER_AUTHORIZATION] = f"Bearer {self.token}" - resp = self.session.request( + resp = self.session_manager.session.request( method, url, headers=headers, diff --git a/python/aistore/sdk/session_manager.py b/python/aistore/sdk/session_manager.py new file mode 100644 index 00000000000..a2e750be653 --- /dev/null +++ b/python/aistore/sdk/session_manager.py @@ -0,0 +1,90 @@ +""" +Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +""" + +import os +from typing import Optional + +from requests import Session +from requests.adapters import HTTPAdapter +from urllib3 import Retry + +from aistore.sdk.const import AIS_CLIENT_CA + +DEFAULT_RETRY = Retry(total=6, connect=3, backoff_factor=1) + + +class SessionManager: + """ + Class for storing and creating requests library sessions. + + Args: + retry (urllib3.Retry, optional): Retry configuration object from the urllib3 library. + Default: Retry(total=6, connect=3, backoff_factor=1). + skip_verify (bool, optional): If True, skip SSL certificate verification. Defaults to False. + ca_cert (str, optional): Path to a CA certificate file for SSL verification. Defaults to None. + """ + + def __init__( + self, + retry: Retry = DEFAULT_RETRY, + ca_cert: Optional[str] = None, + skip_verify: bool = False, + ): + self._retry = retry + self._ca_cert = ca_cert + self._skip_verify = skip_verify + self._session = None + + @property + def retry(self) -> Retry: + """Returns retry config for this session""" + return self._retry + + @property + def ca_cert(self) -> Optional[str]: + """Returns CA certificate for this session, if any""" + return self._ca_cert + + @property + def skip_verify(self) -> bool: + """Returns whether this session's requests skip server certificate verification""" + return self._skip_verify + + @property + def session(self) -> Session: + """Returns an existing `requests` session, creating a new one if needed""" + if self._session is None: + self._session = self._create_session() + return self._session + + def _set_session_verification(self, request_session: Session): + """ + Set session verify value for validating the server's SSL certificate + The requests library allows this to be a boolean or a string path to the cert + If we do not skip verification, the order is: + 1. Provided cert path + 2. Cert path from env var. + 3. True (verify with system's approved CA list) + """ + if self._skip_verify: + request_session.verify = False + return + if self._ca_cert: + request_session.verify = self._ca_cert + return + env_crt = os.getenv(AIS_CLIENT_CA) + request_session.verify = env_crt if env_crt else True + + def _create_session(self) -> Session: + """ + Creates a new `requests` session for HTTP requests. + + Returns: + New HTTP request Session + """ + request_session = Session() + self._set_session_verification(request_session) + for protocol in ("http://", "https://"): + request_session.mount(protocol, HTTPAdapter(max_retries=self._retry)) + return request_session diff --git a/python/tests/unit/sdk/authn/test_authn_client.py b/python/tests/unit/sdk/authn/test_authn_client.py index 6283186c81e..c2aac38c7da 100644 --- a/python/tests/unit/sdk/authn/test_authn_client.py +++ b/python/tests/unit/sdk/authn/test_authn_client.py @@ -19,15 +19,14 @@ def setUp(self) -> None: self.endpoint = "http://authn-endpoint" self.client = AuthNClient(self.endpoint) + @patch("aistore.sdk.authn.authn_client.SessionManager") @patch("aistore.sdk.authn.authn_client.RequestClient") - def test_init_defaults(self, mock_request_client): + def test_init_defaults(self, mock_request_client, mock_sm): AuthNClient(self.endpoint) mock_request_client.assert_called_with( endpoint=self.endpoint, - skip_verify=False, - ca_cert=None, + session_manager=mock_sm.return_value, timeout=None, - retry=None, token=None, ) @@ -37,8 +36,9 @@ def test_init_defaults(self, mock_request_client): (False, None, 30.0, Retry(total=20), None), (False, None, (10, 30.0), Retry(total=20), "dummy.token"), ) + @patch("aistore.sdk.authn.authn_client.SessionManager") @patch("aistore.sdk.authn.authn_client.RequestClient") - def test_init(self, test_case, mock_request_client): + def test_init(self, test_case, mock_request_client, mock_sm): skip_verify, ca_cert, timeout, retry, token = test_case # print all vars print( @@ -52,12 +52,15 @@ def test_init(self, test_case, mock_request_client): retry=retry, token=token, ) - mock_request_client.assert_called_with( - endpoint=self.endpoint, + mock_sm.assert_called_with( + retry=retry, skip_verify=skip_verify, ca_cert=ca_cert, + ) + mock_request_client.assert_called_with( + endpoint=self.endpoint, + session_manager=mock_sm.return_value, timeout=timeout, - retry=retry, token=token, ) diff --git a/python/tests/unit/sdk/test_bucket.py b/python/tests/unit/sdk/test_bucket.py index aba0c9409fe..7f19d0bad82 100644 --- a/python/tests/unit/sdk/test_bucket.py +++ b/python/tests/unit/sdk/test_bucket.py @@ -86,7 +86,7 @@ def test_default_props(self): def test_properties(self): self.assertEqual(self.mock_client, self.ais_bck.client) expected_ns = Namespace(uuid="ns-id", name="ns-name") - client = RequestClient("test client name", skip_verify=False, ca_cert="") + client = RequestClient("test client name", session_manager=Mock()) bck = Bucket( client=client, name=BCK_NAME, diff --git a/python/tests/unit/sdk/test_client.py b/python/tests/unit/sdk/test_client.py index a950331d24d..fc7d159e96d 100644 --- a/python/tests/unit/sdk/test_client.py +++ b/python/tests/unit/sdk/test_client.py @@ -3,7 +3,7 @@ # import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, Mock from urllib3.util import Retry from aistore.sdk import Client @@ -21,11 +21,15 @@ def setUp(self) -> None: self.endpoint = "https://aistore-endpoint" self.client = Client(self.endpoint) + @patch("aistore.sdk.client.SessionManager") @patch("aistore.sdk.client.RequestClient") - def test_init_defaults(self, mock_request_client): + def test_init_defaults(self, mock_request_client, mock_sm): Client(self.endpoint) mock_request_client.assert_called_with( - self.endpoint, False, None, None, None, None + endpoint=self.endpoint, + session_manager=mock_sm.return_value, + timeout=None, + token=None, ) @test_cases( @@ -34,8 +38,9 @@ def test_init_defaults(self, mock_request_client): (False, None, 30.0, Retry(total=4), None), (False, None, (10, 30.0), Retry(total=5, connect=2), "dummy.token"), ) + @patch("aistore.sdk.client.SessionManager") @patch("aistore.sdk.client.RequestClient") - def test_init(self, test_case, mock_request_client): + def test_init(self, test_case, mock_request_client, mock_sm): skip_verify, ca_cert, timeout, retry, token = test_case Client( self.endpoint, @@ -45,8 +50,14 @@ def test_init(self, test_case, mock_request_client): retry=retry, token=token, ) + mock_sm.assert_called_with( + retry=retry, ca_cert=ca_cert, skip_verify=skip_verify + ) mock_request_client.assert_called_with( - self.endpoint, skip_verify, ca_cert, timeout, retry, token + endpoint=self.endpoint, + session_manager=mock_sm.return_value, + timeout=timeout, + token=token, ) def test_bucket(self): @@ -54,7 +65,7 @@ def test_bucket(self): provider = "bucketProvider" namespace = Namespace(uuid="id", name="namespace") bucket = self.client.bucket(bck_name, provider, namespace) - self.assertEqual(self.endpoint, bucket.client.endpoint) + self.assertIn(self.endpoint, bucket.client.base_url) self.assertIsInstance(bucket.client, RequestClient) self.assertEqual(bck_name, bucket.name) self.assertEqual(provider, bucket.provider) @@ -62,7 +73,7 @@ def test_bucket(self): def test_cluster(self): res = self.client.cluster() - self.assertEqual(self.endpoint, res.client.endpoint) + self.assertIn(self.endpoint, res.client.base_url) self.assertIsInstance(res.client, RequestClient) self.assertIsInstance(res, Cluster) @@ -89,10 +100,10 @@ def test_fetch_object_from_url(self, mock_parse_url, mock_bucket): mock_parse_url.return_value = (provider, bck_name, obj_name) - mock_bucket_instance = MagicMock() + mock_bucket_instance = Mock() mock_bucket.return_value = mock_bucket_instance - expected_object = MagicMock() + expected_object = Mock() mock_bucket_instance.object.return_value = expected_object result = self.client.fetch_object_by_url(url) diff --git a/python/tests/unit/sdk/test_request_client.py b/python/tests/unit/sdk/test_request_client.py index 1048aaab7af..144c3573077 100644 --- a/python/tests/unit/sdk/test_request_client.py +++ b/python/tests/unit/sdk/test_request_client.py @@ -1,17 +1,16 @@ import unittest from unittest.mock import patch, Mock -import urllib3 -from requests import Response +from requests import Response, Session from aistore.sdk.const import ( JSON_CONTENT_TYPE, HEADER_USER_AGENT, USER_AGENT_BASE, HEADER_CONTENT_TYPE, - AIS_CLIENT_CA, ) -from aistore.sdk.request_client import RequestClient, DEFAULT_RETRY +from aistore.sdk.request_client import RequestClient +from aistore.sdk.session_manager import SessionManager from aistore.version import __version__ as sdk_version from tests.utils import test_cases @@ -19,95 +18,83 @@ class TestRequestClient(unittest.TestCase): # pylint: disable=unused-variable def setUp(self) -> None: self.endpoint = "https://aistore-endpoint" - self.mock_session = Mock() - with patch("aistore.sdk.request_client.Session") as mock_session: - mock_session.return_value = self.mock_session - self.request_client = RequestClient( - self.endpoint, skip_verify=True, ca_cert="" - ) - + self.mock_response = Mock(name="Mock response", spec=Response) + self.mock_session = Mock(name="Mock session", spec=Session) + self.mock_session.request.return_value = self.mock_response + self.mock_session_manager = Mock(spec=SessionManager, session=self.mock_session) self.request_headers = { HEADER_CONTENT_TYPE: JSON_CONTENT_TYPE, HEADER_USER_AGENT: f"{USER_AGENT_BASE}/{sdk_version}", } + self.default_request_client = RequestClient( + self.endpoint, self.mock_session_manager + ) - def test_default_session(self): - with patch( - "aistore.sdk.request_client.os.getenv", return_value=None - ) as mock_getenv: - self.request_client = RequestClient(self.endpoint) - mock_getenv.assert_called_with(AIS_CLIENT_CA) - self.assertEqual(True, self.request_client.session.verify) - self.assertEqual( - DEFAULT_RETRY, - self.request_client.session.get_adapter(self.endpoint).max_retries, - ) - self.assertIsNone(self.request_client.token) - - def test_custom_retry(self): - custom_retry = urllib3.util.Retry(total=40, connect=2) - self.request_client = RequestClient(self.endpoint, retry=custom_retry) + def test_init_default(self): + self.assertEqual(self.endpoint + "/v1", self.default_request_client.base_url) self.assertEqual( - custom_retry, - self.request_client.session.get_adapter(self.endpoint).max_retries, + self.mock_session_manager, self.default_request_client.session_manager ) + self.assertIsNone(self.default_request_client.timeout) - @patch("aistore.sdk.request_client.Session") - def test_session_timeout(self, mock_session_init): - mock_session = Mock() - mock_session.request.return_value = Mock(status_code=200) - mock_session_init.return_value = mock_session - custom_timeout = (10, 30.0) - request_client = RequestClient("", timeout=custom_timeout) - request_client.request("method", "path") - mock_session.request.assert_called_with( - "method", "v1/path", headers=self.request_headers, timeout=custom_timeout + @test_cases( + 10, + 30.0, + (10, 30.0), + ) + def test_init_properties(self, timeout): + auth_token = "any string" + request_client = RequestClient( + self.endpoint, self.mock_session_manager, timeout=timeout, token=auth_token ) + self.assertEqual(self.endpoint + "/v1", request_client.base_url) + self.assertEqual(self.mock_session_manager, request_client.session_manager) + self.assertEqual(timeout, request_client.timeout) + self.assertEqual(auth_token, request_client.token) + + def test_update_token(self): + auth_token = "any string" + self.default_request_client.token = auth_token + self.assertEqual(auth_token, self.default_request_client.token) @test_cases( - (("env-cert", "arg-cert", False), "arg-cert"), - (("env-cert", "arg-cert", True), False), - (("env-cert", None, False), "env-cert"), - ((True, None, False), True), - ((None, None, True), False), + 10, + 30.0, + (10, 30.0), ) - def test_session_tls(self, test_case): - env_cert, arg_cert, skip_verify = test_case[0] - with patch( - "aistore.sdk.request_client.os.getenv", return_value=env_cert - ) as mock_getenv: - self.request_client = RequestClient( - self.endpoint, skip_verify=skip_verify, ca_cert=arg_cert - ) - if not skip_verify and not arg_cert: - mock_getenv.assert_called_with(AIS_CLIENT_CA) - self.assertEqual(test_case[1], self.request_client.session.verify) - - def test_properties(self): - self.assertEqual(self.endpoint + "/v1", self.request_client.base_url) - self.assertEqual(self.endpoint, self.request_client.endpoint) + def test_update_timeout(self, timeout): + self.default_request_client.timeout = timeout + self.assertEqual(timeout, self.default_request_client.timeout) - @patch("aistore.sdk.request_client.RequestClient.request") @patch("aistore.sdk.request_client.decode_response") - def test_request_deserialize(self, mock_decode, mock_request): + def test_request_deserialize(self, mock_decode): method = "method" path = "path" decoded_value = "test value" custom_kw = "arg" mock_decode.return_value = decoded_value - mock_response = Mock(Response) - mock_request.return_value = mock_response + self.mock_response.status_code = 200 - res = self.request_client.request_deserialize( + res = self.default_request_client.request_deserialize( method, path, str, keyword=custom_kw ) + expected_url = self.endpoint + "/v1/" + path self.assertEqual(decoded_value, res) - mock_request.assert_called_with(method, path, keyword=custom_kw) - mock_decode.assert_called_with(str, mock_response) + self.mock_session.request.assert_called_with( + method, + expected_url, + headers=self.request_headers, + timeout=None, + keyword=custom_kw, + ) + mock_decode.assert_called_with(str, self.mock_response) - @test_cases(None, "http://custom_endpoint") - def test_request(self, endpoint_arg): + @test_cases((None, None), ("http://custom_endpoint", 30)) + def test_request(self, test_case): + endpoint_arg, timeout = test_case + if timeout: + self.default_request_client.timeout = timeout method = "request_method" path = "request_path" extra_kw_arg = "arg" @@ -116,13 +103,11 @@ def test_request(self, endpoint_arg): if endpoint_arg: req_url = f"{endpoint_arg}/v1/{path}" else: - req_url = f"{self.request_client.base_url}/{path}" + req_url = f"{self.default_request_client.base_url}/{path}" - mock_response = Mock() - mock_response.status_code = 200 - self.mock_session.request.return_value = mock_response + self.mock_response.status_code = 200 if endpoint_arg: - res = self.request_client.request( + res = self.default_request_client.request( method, path, endpoint=endpoint_arg, @@ -130,23 +115,22 @@ def test_request(self, endpoint_arg): keyword=extra_kw_arg, ) else: - res = self.request_client.request( + res = self.default_request_client.request( method, path, headers=extra_headers, keyword=extra_kw_arg ) self.mock_session.request.assert_called_with( method, req_url, headers=self.request_headers, - timeout=None, + timeout=timeout, keyword=extra_kw_arg, ) - self.assertEqual(mock_response, res) + self.assertEqual(self.mock_response, res) for response_code in [199, 300]: with patch("aistore.sdk.request_client.handle_errors") as mock_handle_err: - mock_response.status_code = response_code - self.mock_session.request.return_value = mock_response - res = self.request_client.request( + self.mock_response.status_code = response_code + res = self.default_request_client.request( method, path, endpoint=endpoint_arg, @@ -157,16 +141,16 @@ def test_request(self, endpoint_arg): method, req_url, headers=self.request_headers, - timeout=None, + timeout=timeout, keyword=extra_kw_arg, ) - self.assertEqual(mock_response, res) + self.assertEqual(self.mock_response, res) mock_handle_err.assert_called_once() def test_get_full_url(self): path = "/testpath/to_obj" params = {"p1key": "p1val", "p2key": "p2val"} - res = self.request_client.get_full_url(path, params) + res = self.default_request_client.get_full_url(path, params) self.assertEqual( "https://aistore-endpoint/v1/testpath/to_obj?p1key=p1val&p2key=p2val", res ) diff --git a/python/tests/unit/sdk/test_session_manager.py b/python/tests/unit/sdk/test_session_manager.py new file mode 100644 index 00000000000..69801374ab4 --- /dev/null +++ b/python/tests/unit/sdk/test_session_manager.py @@ -0,0 +1,62 @@ +import unittest +from unittest.mock import patch, Mock + +from requests.adapters import HTTPAdapter + +import urllib3 + +from aistore.sdk.const import AIS_CLIENT_CA +from aistore.sdk.session_manager import SessionManager, DEFAULT_RETRY +from tests.utils import test_cases + + +class TestSessionManager(unittest.TestCase): # pylint: disable=unused-variable + def setUp(self) -> None: + self.endpoint = "https://aistore-endpoint" + self.mock_session = Mock() + + def test_init_default(self): + session_manager = SessionManager() + self.assertEqual(DEFAULT_RETRY, session_manager.retry) + self.assertIsNone(session_manager.ca_cert) + self.assertFalse(session_manager.skip_verify) + + def test_init_args(self): + custom_retry = urllib3.Retry(total=9) + ca_cert_path = "/any/path" + session_manager = SessionManager( + retry=custom_retry, ca_cert=ca_cert_path, skip_verify=True + ) + self.assertEqual(custom_retry, session_manager.retry) + self.assertEqual(ca_cert_path, session_manager.ca_cert) + self.assertTrue(session_manager.skip_verify) + + def test_session_exists(self): + session_manager = SessionManager() + first_session = session_manager.session + self.assertEqual(first_session, session_manager.session) + + def test_create_custom_retry(self): + custom_retry = urllib3.util.Retry(total=40, connect=2) + session_manager = SessionManager(retry=custom_retry) + adapter = session_manager.session.get_adapter(self.endpoint) + self.assertIsInstance(adapter, HTTPAdapter) + self.assertEqual(custom_retry, adapter.max_retries) + + @test_cases( + (("env-cert", "arg-cert", False), "arg-cert"), + (("env-cert", "arg-cert", True), False), + (("env-cert", None, False), "env-cert"), + ((True, None, False), True), + ((None, None, True), False), + ) + def test_create_tls(self, test_case): + env_cert, arg_cert, skip_verify = test_case[0] + with patch( + "aistore.sdk.session_manager.os.getenv", return_value=env_cert + ) as mock_getenv: + session_manager = SessionManager(skip_verify=skip_verify, ca_cert=arg_cert) + session = session_manager.session + if not skip_verify and not arg_cert: + mock_getenv.assert_called_with(AIS_CLIENT_CA) + self.assertEqual(test_case[1], session.verify)