diff --git a/python/aistore/sdk/object.py b/python/aistore/sdk/object.py index 6f8c04edf3..7dc60b6c2b 100644 --- a/python/aistore/sdk/object.py +++ b/python/aistore/sdk/object.py @@ -9,7 +9,6 @@ from aistore.sdk.const import ( DEFAULT_CHUNK_SIZE, HTTP_METHOD_DELETE, - HTTP_METHOD_GET, HTTP_METHOD_HEAD, HTTP_METHOD_PUT, QPARAM_ARCHPATH, @@ -115,7 +114,7 @@ def get( byte_range: str = None, ) -> ObjectReader: """ - Reads an object + Creates and returns an ObjectReader with access to object contents and optionally writes to a provided writer. Args: archive_settings (ArchiveSettings, optional): Settings for archive extraction @@ -129,7 +128,8 @@ def get( both the start and end of the range (e.g. "bytes=0-499" to request the first 500 bytes) Returns: - The stream of bytes to read an object or a file inside an archive. + An ObjectReader which can be iterated over to stream chunks of object content or used to read all content + directly. Raises: requests.RequestException: "There was an ambiguous exception that occurred while handling..." @@ -163,16 +163,11 @@ def get( # https://www.rfc-editor.org/rfc/rfc7233#section-2.1 headers = {HEADER_RANGE: byte_range} - resp = self._client.request( - HTTP_METHOD_GET, + obj_reader = ObjectReader( + client=self._client, path=self._object_path, params=params, - stream=True, headers=headers, - ) - obj_reader = ObjectReader( - stream=resp, - response_headers=resp.headers, chunk_size=chunk_size, ) if writer: diff --git a/python/aistore/sdk/object_reader.py b/python/aistore/sdk/object_reader.py index b265804f08..51f5a4fb97 100644 --- a/python/aistore/sdk/object_reader.py +++ b/python/aistore/sdk/object_reader.py @@ -1,7 +1,7 @@ -from typing import Iterator +from typing import Iterator, List, Dict, Optional import requests -from requests.structures import CaseInsensitiveDict -from aistore.sdk.const import DEFAULT_CHUNK_SIZE +from aistore.sdk.request_client import RequestClient +from aistore.sdk.const import DEFAULT_CHUNK_SIZE, HTTP_METHOD_GET, HTTP_METHOD_HEAD from aistore.sdk.object_attributes import ObjectAttributes @@ -11,15 +11,52 @@ class ObjectReader: attributes. """ + # pylint: disable=too-many-arguments def __init__( self, - response_headers: CaseInsensitiveDict, - stream: requests.Response, + client: RequestClient, + path: str, + params: List[str], + headers: Optional[Dict[str, str]] = None, chunk_size: int = DEFAULT_CHUNK_SIZE, ): + self._request_client = client + self._request_path = path + self._request_params = params + self._request_headers = headers self._chunk_size = chunk_size - self._stream = stream - self._attributes = ObjectAttributes(response_headers) + self._attributes = None + + def _head(self): + """ + Make a head request to AIS to update and return only object attributes. + + Returns: + ObjectAttributes for this object + + """ + resp = self._request_client.request( + HTTP_METHOD_HEAD, path=self._request_path, params=self._request_params + ) + return ObjectAttributes(resp.headers) + + def _make_request(self, stream): + """ + Make a request to AIS to get the object content. + + Returns: + requests.Response from AIS + + """ + resp = self._request_client.request( + HTTP_METHOD_GET, + path=self._request_path, + params=self._request_params, + stream=stream, + headers=self._request_headers, + ) + self._attributes = ObjectAttributes(resp.headers) + return resp @property def attributes(self) -> ObjectAttributes: @@ -29,21 +66,20 @@ def attributes(self) -> ObjectAttributes: Returns: ObjectAttributes: Parsed object attributes from the headers returned by AIS """ + if not self._attributes: + self._attributes = self._head() return self._attributes def read_all(self) -> bytes: """ - Read all byte data from the object content stream. + Read all byte data directly from the object response without using a stream. - This uses a bytes cast which makes it slightly slower and requires all object content to fit in memory at once. + This requires all object content to fit in memory at once and downloads all content before returning. Returns: bytes: Object content as bytes. """ - obj_arr = bytearray() - for chunk in self: - obj_arr.extend(chunk) - return bytes(obj_arr) + return self._make_request(stream=False).content def raw(self) -> requests.Response: """ @@ -52,16 +88,17 @@ def raw(self) -> requests.Response: Returns: requests.Response: Raw byte stream of the object content """ - return self._stream.raw + return self._make_request(stream=True).raw def __iter__(self) -> Iterator[bytes]: """ - Creates a generator to read the stream content in chunks. + Make a request to get a stream from the provided object and yield chunks of the stream content. Returns: - Iterator[bytes]: An iterator to access the next chunk of bytes + Iterator[bytes]: An iterator over each chunk of bytes in the object """ + stream = self._make_request(stream=True) try: - yield from self._stream.iter_content(chunk_size=self._chunk_size) + yield from stream.iter_content(chunk_size=self._chunk_size) finally: - self._stream.close() + stream.close() diff --git a/python/tests/integration/sdk/remote_enabled_test.py b/python/tests/integration/sdk/remote_enabled_test.py index 94ed69b838..c74ba468ec 100644 --- a/python/tests/integration/sdk/remote_enabled_test.py +++ b/python/tests/integration/sdk/remote_enabled_test.py @@ -39,7 +39,7 @@ def setUp(self) -> None: self.bck_name = random_string() self.client = Client(CLUSTER_ENDPOINT) self.buckets = [] - self.obj_prefix = f"{self._testMethodName}-{random_string(6)}-" + self.obj_prefix = f"{self._testMethodName}-{random_string(6)}" if REMOTE_SET: self.cloud_objects = [] diff --git a/python/tests/integration/sdk/test_bucket_ops.py b/python/tests/integration/sdk/test_bucket_ops.py index 386765ef18..6a6a91a5f3 100644 --- a/python/tests/integration/sdk/test_bucket_ops.py +++ b/python/tests/integration/sdk/test_bucket_ops.py @@ -175,7 +175,7 @@ def test_get_latest_flag(self): # cold GET must result in Error with self.assertRaises(AISError): - self.bucket.object(obj_name).get(latest=True) + self.bucket.object(obj_name).get(latest=True).read_all() @unittest.skipIf( not REMOTE_SET, @@ -221,7 +221,7 @@ def _verify_obj_res(self, expected_res_dict, expect_err=False): if expect_err: for obj_name in expected_res_dict: with self.assertRaises(AISError): - self.bucket.object(self.obj_prefix + obj_name).get() + self.bucket.object(self.obj_prefix + obj_name).get().read_all() else: for obj_name, expected_data in expected_res_dict.items(): res = self.bucket.object(self.obj_prefix + obj_name).get() @@ -279,9 +279,9 @@ def test_put_files_filtered(self): ) self.bucket.object(self.obj_prefix + included_filename).get() with self.assertRaises(AISError): - self.bucket.object(excluded_by_pattern).get() + self.bucket.object(excluded_by_pattern).get().read_all() with self.assertRaises(AISError): - self.bucket.object(excluded_by_prefix).get() + self.bucket.object(excluded_by_prefix).get().read_all() def test_put_files_dry_run(self): self._create_put_files_structure(TOP_LEVEL_FILES, LOWER_LEVEL_FILES) diff --git a/python/tests/integration/sdk/test_object_group_ops.py b/python/tests/integration/sdk/test_object_group_ops.py index 9155c65f38..963ebbc9a2 100644 --- a/python/tests/integration/sdk/test_object_group_ops.py +++ b/python/tests/integration/sdk/test_object_group_ops.py @@ -180,7 +180,7 @@ def test_copy_objects_latest_flag(self): ) self.client.job(job_id=copy_job).wait_for_idle(timeout=TEST_TIMEOUT) with self.assertRaises(AISError): - self.bucket.object(obj_name).get() + self.bucket.object(obj_name).get().read_all() @unittest.skipIf( not REMOTE_SET, @@ -247,7 +247,7 @@ def test_prefetch_objects_latest_flag(self): prefetch_job = self.bucket.objects(obj_names=[obj_name]).prefetch(latest=True) self.client.job(job_id=prefetch_job).wait_for_idle(timeout=TEST_TIMEOUT) with self.assertRaises(AISError): - self.bucket.object(obj_name).get() + self.bucket.object(obj_name).get().read_all() def _prefetch_and_check_with_latest(self, bucket, obj_name, expected, latest_flag): prefetch_job = bucket.objects(obj_names=[obj_name]).prefetch(latest=latest_flag) diff --git a/python/tests/integration/sdk/test_object_ops.py b/python/tests/integration/sdk/test_object_ops.py index 83342bce0b..362b5b79f2 100644 --- a/python/tests/integration/sdk/test_object_ops.py +++ b/python/tests/integration/sdk/test_object_ops.py @@ -50,20 +50,20 @@ def tearDown(self) -> None: def _test_get_obj(self, read_type, obj_name, exp_content): chunk_size = random.randrange(1, len(exp_content) + 10) - stream = self.bucket.object(obj_name).get(chunk_size=chunk_size) + reader = self.bucket.object(obj_name).get(chunk_size=chunk_size) - self.assertEqual(stream.attributes.size, len(exp_content)) - self.assertNotEqual(stream.attributes.checksum_type, "") - self.assertNotEqual(stream.attributes.checksum_value, "") - self.assertNotEqual(stream.attributes.access_time, "") + self.assertEqual(reader.attributes.size, len(exp_content)) + self.assertNotEqual(reader.attributes.checksum_type, "") + self.assertNotEqual(reader.attributes.checksum_value, "") + self.assertNotEqual(reader.attributes.access_time, "") if not REMOTE_SET: - self.assertNotEqual(stream.attributes.obj_version, "") - self.assertEqual(stream.attributes.custom_metadata, {}) + self.assertNotEqual(reader.attributes.obj_version, "") + self.assertEqual(reader.attributes.custom_metadata, {}) if read_type == OBJ_READ_TYPE_ALL: - obj = stream.read_all() + obj = reader.read_all() else: obj = b"" - for chunk in stream: + for chunk in reader: obj += chunk self.assertEqual(obj, exp_content) diff --git a/python/tests/unit/sdk/test_object.py b/python/tests/unit/sdk/test_object.py index eeb7bc2c4e..e0ebc0841f 100644 --- a/python/tests/unit/sdk/test_object.py +++ b/python/tests/unit/sdk/test_object.py @@ -7,7 +7,6 @@ from aistore.sdk.const import ( HTTP_METHOD_HEAD, DEFAULT_CHUNK_SIZE, - HTTP_METHOD_GET, HTTP_METHOD_PATCH, QPARAM_ARCHPATH, QPARAM_ARCHREGX, @@ -18,13 +17,7 @@ QPARAM_NEW_CUSTOM, HTTP_METHOD_PUT, HTTP_METHOD_DELETE, - HEADER_CONTENT_LENGTH, HEADER_OBJECT_APPEND_HANDLE, - AIS_CHECKSUM_VALUE, - AIS_CHECKSUM_TYPE, - AIS_ACCESS_TIME, - AIS_VERSION, - AIS_CUSTOM_MD, HTTP_METHOD_POST, ACT_PROMOTE, ACT_BLOB_DOWNLOAD, @@ -90,7 +83,7 @@ def test_get_default_params(self): def test_get(self): archpath_param = "archpath" chunk_size = "4mb" - num_workers = 10 + num_workers = "10" self.expected_params[QPARAM_ARCHPATH] = archpath_param self.expected_params[QPARAM_ARCHREGX] = "" self.expected_params[QPARAM_ARCHMODE] = None @@ -118,74 +111,35 @@ def test_get_archregex(self): self.get_exec_assert(archive_settings=archive_settings) def get_exec_assert(self, **kwargs): - content = b"123456789" - content_length = 9 - ais_check_val = "xyz" - ais_check_type = "md5" - ais_atime = "time string" - ais_version = "3" - custom_metadata_dict = {"key1": "val1", "key2": "val2"} - custom_metadata = ", ".join( - ["=".join(kv) for kv in custom_metadata_dict.items()] - ) - resp_headers = CaseInsensitiveDict( - { - HEADER_CONTENT_LENGTH: content_length, - AIS_CHECKSUM_VALUE: ais_check_val, - AIS_CHECKSUM_TYPE: ais_check_type, - AIS_ACCESS_TIME: ais_atime, - AIS_VERSION: ais_version, - AIS_CUSTOM_MD: custom_metadata, - } - ) - mock_response = Mock(Response) - mock_response.headers = resp_headers - mock_response.iter_content.return_value = content - mock_response.raw = content - expected_obj = ObjectReader( - response_headers=resp_headers, - stream=mock_response, - ) - self.mock_client.request.return_value = mock_response - - res = self.object.get(**kwargs) - blob_download_settings = kwargs.get( - "blob_download_settings", BlobDownloadSettings() - ) - chunk_size = blob_download_settings.chunk_size - num_workers = blob_download_settings.num_workers - headers = {} - if chunk_size or num_workers: - headers[HEADER_OBJECT_BLOB_DOWNLOAD] = "true" - if chunk_size: - headers[HEADER_OBJECT_BLOB_CHUNK_SIZE] = chunk_size - if num_workers: - headers[HEADER_OBJECT_BLOB_WORKERS] = num_workers - - self.assertEqual(expected_obj.raw(), res.raw()) - self.assertEqual(content_length, res.attributes.size) - self.assertEqual(ais_check_type, res.attributes.checksum_type) - self.assertEqual(ais_check_val, res.attributes.checksum_value) - self.assertEqual(ais_atime, res.attributes.access_time) - self.assertEqual(ais_version, res.attributes.obj_version) - self.assertEqual(custom_metadata_dict, res.attributes.custom_metadata) - self.mock_client.request.assert_called_with( - HTTP_METHOD_GET, - path=REQUEST_PATH, - params=self.expected_params, - stream=True, - headers=headers, - ) - - # Use the object reader iterator to call the stream with the chunk size - for _ in res: - continue - mock_response.iter_content.assert_called_with( - chunk_size=kwargs.get("chunk_size", DEFAULT_CHUNK_SIZE) - ) - - if "writer" in kwargs: - self.mock_writer.writelines.assert_called_with(res) + with patch( + "aistore.sdk.object.ObjectReader", return_value=Mock(spec=ObjectReader) + ) as mock_obj_reader: + res = self.object.get(**kwargs) + + blob_download_settings = kwargs.get( + "blob_download_settings", BlobDownloadSettings() + ) + blob_chunk_size = blob_download_settings.chunk_size + blob_workers = blob_download_settings.num_workers + expected_headers = kwargs.get("expected_headers", {}) + if blob_chunk_size or blob_workers: + expected_headers[HEADER_OBJECT_BLOB_DOWNLOAD] = "true" + if blob_chunk_size: + expected_headers[HEADER_OBJECT_BLOB_CHUNK_SIZE] = blob_chunk_size + if blob_workers: + expected_headers[HEADER_OBJECT_BLOB_WORKERS] = blob_workers + expected_chunk_size = kwargs.get("chunk_size", DEFAULT_CHUNK_SIZE) + + self.assertIsInstance(res, ObjectReader) + mock_obj_reader.assert_called_with( + client=self.mock_client, + path=REQUEST_PATH, + params=self.expected_params, + headers=expected_headers, + chunk_size=expected_chunk_size, + ) + if "writer" in kwargs: + self.mock_writer.writelines.assert_called_with(res) def test_get_url(self): expected_res = "full url" diff --git a/python/tests/unit/sdk/test_object_attributes.py b/python/tests/unit/sdk/test_object_attributes.py new file mode 100644 index 0000000000..45eddc96f9 --- /dev/null +++ b/python/tests/unit/sdk/test_object_attributes.py @@ -0,0 +1,49 @@ +import unittest + +from requests.structures import CaseInsensitiveDict +from aistore.sdk.object_attributes import ObjectAttributes + + +class TestObjectAttributes(unittest.TestCase): + def setUp(self): + self.md_dict = {"key1": "value1", "key2": "value2"} + self.md_string = "key1=value1,key2=value2" + self.response_headers = CaseInsensitiveDict( + { + "Content-Length": "1024", + "Ais-Checksum-Type": "md5", + "Ais-Checksum-Value": "abcdef1234567890", + "ais-atime": "2024-08-13T10:30:00Z", + "Ais-Version": "1.0", + "Ais-Custom-Md": self.md_string + ",invalid entry", + } + ) + self.attributes = ObjectAttributes(self.response_headers) + + def test_size(self): + self.assertEqual(1024, self.attributes.size) + + def test_checksum_type(self): + self.assertEqual("md5", self.attributes.checksum_type) + + def test_checksum_value(self): + self.assertEqual("abcdef1234567890", self.attributes.checksum_value) + + def test_access_time(self): + self.assertEqual("2024-08-13T10:30:00Z", self.attributes.access_time) + + def test_obj_version(self): + self.assertEqual("1.0", self.attributes.obj_version) + + def test_custom_metadata(self): + self.assertDictEqual(self.md_dict, self.attributes.custom_metadata) + + def test_missing_headers(self): + headers = CaseInsensitiveDict() + attributes = ObjectAttributes(headers) + self.assertEqual(0, attributes.size) + self.assertEqual("", attributes.checksum_type) + self.assertEqual("", attributes.checksum_value) + self.assertEqual("", attributes.access_time) + self.assertEqual("", attributes.obj_version) + self.assertDictEqual({}, attributes.custom_metadata) diff --git a/python/tests/unit/sdk/test_object_reader.py b/python/tests/unit/sdk/test_object_reader.py new file mode 100644 index 0000000000..f77197923b --- /dev/null +++ b/python/tests/unit/sdk/test_object_reader.py @@ -0,0 +1,87 @@ +import unittest +from unittest.mock import MagicMock, patch, Mock +import requests + +from aistore.sdk.object_reader import ObjectReader +from aistore.sdk.request_client import RequestClient +from aistore.sdk.const import HTTP_METHOD_GET, HTTP_METHOD_HEAD +from aistore.sdk.object_attributes import ObjectAttributes + + +class TestObjectReader(unittest.TestCase): + def setUp(self): + self.client = MagicMock(spec=RequestClient) + self.path = "/test/path" + self.params = ["param1", "param2"] + self.headers = {"header1": "req1", "header2": "req2"} + self.response_headers = {"attr1": "resp1", "attr2": "resp2"} + self.chunk_size = 1024 + self.object_reader = ObjectReader( + self.client, self.path, self.params, self.headers, self.chunk_size + ) + + @patch("aistore.sdk.object_reader.ObjectAttributes", autospec=True) + def test_attributes_head(self, mock_attr): + mock_response = Mock(spec=requests.Response, headers=self.response_headers) + self.client.request.return_value = mock_response + + res = self.object_reader.attributes + + self.assertEqual(mock_attr.return_value, res) + self.client.request.assert_called_once_with( + HTTP_METHOD_HEAD, path=self.path, params=self.params + ) + mock_attr.assert_called_with(self.response_headers) + + @patch("aistore.sdk.object_reader.ObjectAttributes", autospec=True) + def test_read_all(self, mock_attr): + chunk1 = b"chunk1" + chunk2 = b"chunk2" + mock_response = Mock( + spec=requests.Response, + content=chunk1 + chunk2, + headers=self.response_headers, + ) + self.client.request.return_value = mock_response + + content = self.object_reader.read_all() + + self.assertEqual(chunk1 + chunk2, content) + self.assert_make_request(mock_attr, stream=False) + + @patch("aistore.sdk.object_reader.ObjectAttributes", autospec=True) + def test_raw(self, mock_attr): + mock_response = Mock( + spec=requests.Response, raw=b"bytestream", headers=self.response_headers + ) + self.client.request.return_value = mock_response + + raw_stream = self.object_reader.raw() + + self.assertEqual(mock_response.raw, raw_stream) + self.assert_make_request(mock_attr, stream=True) + + @patch("aistore.sdk.object_reader.ObjectAttributes", autospec=True) + def test_iter(self, mock_attr): + expected_chunks = [b"chunk1", b"chunk2"] + mock_response = Mock(spec=requests.Response, headers=self.response_headers) + mock_response.iter_content.return_value = expected_chunks + self.client.request.return_value = mock_response + + chunks = list(self.object_reader) + + mock_response.iter_content.assert_called_once_with(chunk_size=self.chunk_size) + mock_response.close.assert_called_once() + self.assertEqual(expected_chunks, chunks) + self.assert_make_request(mock_attr, stream=True) + + def assert_make_request(self, mock_attr, stream): + self.client.request.assert_called_once_with( + HTTP_METHOD_GET, + path=self.path, + params=self.params, + stream=stream, + headers=self.headers, + ) + self.assertIsInstance(self.object_reader.attributes, ObjectAttributes) + mock_attr.assert_called_with(self.response_headers)