sdk/python: Refactor object get to move request logic to object reader
Signed-off-by: Aaron Wilson <[email protected]>
aaronnw committed Aug 14, 2024
1 parent be347a6 commit 9562ed4
Showing 9 changed files with 242 additions and 120 deletions.
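In short, this refactor makes ObjectReader lazy: Object.get() no longer issues the HTTP GET itself but hands the RequestClient, request path, params, and headers to the reader, which only sends a request when content or attributes are actually read. A minimal usage sketch of the resulting behavior (the endpoint, bucket, and object names are hypothetical, and the top-level Client export is assumed):

    from aistore.sdk import Client  # assuming the SDK's top-level Client export

    client = Client("http://localhost:8080")  # hypothetical endpoint
    reader = client.bucket("my-bck").object("my-obj").get()
    # Before this commit, the GET had already been sent by get(); now no
    # request has gone out yet -- each call below issues its own request.
    data = reader.read_all()   # plain (non-streaming) GET, whole object in memory
    total = 0
    for chunk in reader:       # a fresh streaming GET, yielded chunk by chunk
        total += len(chunk)
    attrs = reader.attributes  # cached from the reads above; HEAD only if nothing was fetched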
15 changes: 5 additions & 10 deletions python/aistore/sdk/object.py
@@ -9,7 +9,6 @@
from aistore.sdk.const import (
DEFAULT_CHUNK_SIZE,
HTTP_METHOD_DELETE,
HTTP_METHOD_GET,
HTTP_METHOD_HEAD,
HTTP_METHOD_PUT,
QPARAM_ARCHPATH,
@@ -115,7 +114,7 @@ def get(
byte_range: str = None,
) -> ObjectReader:
"""
Reads an object
Creates and returns an ObjectReader with access to object contents and optionally writes to a provided writer.
Args:
archive_settings (ArchiveSettings, optional): Settings for archive extraction
@@ -129,7 +128,8 @@ def get(
both the start and end of the range (e.g. "bytes=0-499" to request the first 500 bytes)
Returns:
The stream of bytes to read an object or a file inside an archive.
An ObjectReader which can be iterated over to stream chunks of object content or used to read all content
directly.
Raises:
requests.RequestException: "There was an ambiguous exception that occurred while handling..."
Expand Down Expand Up @@ -163,16 +163,11 @@ def get(
# https://www.rfc-editor.org/rfc/rfc7233#section-2.1
headers = {HEADER_RANGE: byte_range}

resp = self._client.request(
HTTP_METHOD_GET,
obj_reader = ObjectReader(
client=self._client,
path=self._object_path,
params=params,
stream=True,
headers=headers,
)
obj_reader = ObjectReader(
stream=resp,
response_headers=resp.headers,
chunk_size=chunk_size,
)
if writer:
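De-interleaving the hunk above, the new tail of Object.get() reduces to constructing the reader and, when a writer is supplied, handing the reader to it. This is a reconstruction from the diff plus the unit-test expectations later in this commit (the view is truncated after `if writer:`); the params/headers assembly is elided:

    # Sketch of the new Object.get() tail, reconstructed from this diff.
    obj_reader = ObjectReader(
        client=self._client,   # RequestClient; no request is issued here
        path=self._object_path,
        params=params,
        headers=headers,
        chunk_size=chunk_size,
    )
    if writer:
        writer.writelines(obj_reader)  # iterating the reader drives a streaming GET
    return obj_reader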
73 changes: 55 additions & 18 deletions python/aistore/sdk/object_reader.py
@@ -1,7 +1,7 @@
from typing import Iterator
from typing import Iterator, List, Dict, Optional
import requests
from requests.structures import CaseInsensitiveDict
from aistore.sdk.const import DEFAULT_CHUNK_SIZE
from aistore.sdk.request_client import RequestClient
from aistore.sdk.const import DEFAULT_CHUNK_SIZE, HTTP_METHOD_GET, HTTP_METHOD_HEAD
from aistore.sdk.object_attributes import ObjectAttributes


@@ -11,15 +11,52 @@ class ObjectReader:
attributes.
"""

# pylint: disable=too-many-arguments
def __init__(
self,
response_headers: CaseInsensitiveDict,
stream: requests.Response,
client: RequestClient,
path: str,
params: List[str],
headers: Optional[Dict[str, str]] = None,
chunk_size: int = DEFAULT_CHUNK_SIZE,
):
self._request_client = client
self._request_path = path
self._request_params = params
self._request_headers = headers
self._chunk_size = chunk_size
self._stream = stream
self._attributes = ObjectAttributes(response_headers)
self._attributes = None

def _head(self):
"""
Make a HEAD request to AIS to fetch and return only the object attributes.
Returns:
ObjectAttributes for this object
"""
resp = self._request_client.request(
HTTP_METHOD_HEAD, path=self._request_path, params=self._request_params
)
return ObjectAttributes(resp.headers)

def _make_request(self, stream):
"""
Make a request to AIS to get the object content.
Returns:
requests.Response from AIS
"""
resp = self._request_client.request(
HTTP_METHOD_GET,
path=self._request_path,
params=self._request_params,
stream=stream,
headers=self._request_headers,
)
self._attributes = ObjectAttributes(resp.headers)
return resp

@property
def attributes(self) -> ObjectAttributes:
@@ -29,21 +66,20 @@ def attributes(self) -> ObjectAttributes:
Returns:
ObjectAttributes: Parsed object attributes from the headers returned by AIS
"""
if not self._attributes:
self._attributes = self._head()
return self._attributes

def read_all(self) -> bytes:
"""
Read all byte data from the object content stream.
Read all byte data directly from the object response without using a stream.
This uses a bytes cast which makes it slightly slower and requires all object content to fit in memory at once.
This requires all object content to fit in memory at once and downloads all content before returning.
Returns:
bytes: Object content as bytes.
"""
obj_arr = bytearray()
for chunk in self:
obj_arr.extend(chunk)
return bytes(obj_arr)
return self._make_request(stream=False).content

def raw(self) -> requests.Response:
"""
@@ -52,16 +88,17 @@ def raw(self) -> requests.Response:
Returns:
requests.Response: Raw byte stream of the object content
"""
return self._stream.raw
return self._make_request(stream=True).raw

def __iter__(self) -> Iterator[bytes]:
"""
Creates a generator to read the stream content in chunks.
Make a streaming request for the object and yield its content in chunks.
Returns:
Iterator[bytes]: An iterator to access the next chunk of bytes
Iterator[bytes]: An iterator over each chunk of bytes in the object
"""
stream = self._make_request(stream=True)
try:
yield from self._stream.iter_content(chunk_size=self._chunk_size)
yield from stream.iter_content(chunk_size=self._chunk_size)
finally:
self._stream.close()
stream.close()
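A design consequence worth calling out: the reader no longer wraps a single live response, so each read_all(), raw(), or iteration issues its own request (the reader is re-iterable, but every pass re-downloads the object), and attributes falls back to a HEAD request when no content has been fetched yet. A minimal sketch of the lazy behavior using a mocked client (the request path is hypothetical, and this assumes ObjectAttributes merely stores the headers it is given):

    from unittest.mock import Mock

    from aistore.sdk.object_reader import ObjectReader

    mock_client = Mock()
    mock_client.request.return_value.headers = {}  # empty response headers

    reader = ObjectReader(client=mock_client, path="v1/objects/bck/obj", params=[])
    mock_client.request.assert_not_called()  # constructing the reader sends nothing

    _ = reader.attributes                    # nothing fetched yet, so this issues a HEAD
    mock_client.request.assert_called_once()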
2 changes: 1 addition & 1 deletion python/tests/integration/sdk/remote_enabled_test.py
@@ -39,7 +39,7 @@ def setUp(self) -> None:
self.bck_name = random_string()
self.client = Client(CLUSTER_ENDPOINT)
self.buckets = []
self.obj_prefix = f"{self._testMethodName}-{random_string(6)}-"
self.obj_prefix = f"{self._testMethodName}-{random_string(6)}"

if REMOTE_SET:
self.cloud_objects = []
8 changes: 4 additions & 4 deletions python/tests/integration/sdk/test_bucket_ops.py
@@ -175,7 +175,7 @@ def test_get_latest_flag(self):

# cold GET must result in Error
with self.assertRaises(AISError):
self.bucket.object(obj_name).get(latest=True)
self.bucket.object(obj_name).get(latest=True).read_all()

@unittest.skipIf(
not REMOTE_SET,
@@ -221,7 +221,7 @@ def _verify_obj_res(self, expected_res_dict, expect_err=False):
if expect_err:
for obj_name in expected_res_dict:
with self.assertRaises(AISError):
self.bucket.object(self.obj_prefix + obj_name).get()
self.bucket.object(self.obj_prefix + obj_name).get().read_all()
else:
for obj_name, expected_data in expected_res_dict.items():
res = self.bucket.object(self.obj_prefix + obj_name).get()
@@ -279,9 +279,9 @@ def test_put_files_filtered(self):
)
self.bucket.object(self.obj_prefix + included_filename).get()
with self.assertRaises(AISError):
self.bucket.object(excluded_by_pattern).get()
self.bucket.object(excluded_by_pattern).get().read_all()
with self.assertRaises(AISError):
self.bucket.object(excluded_by_prefix).get()
self.bucket.object(excluded_by_prefix).get().read_all()

def test_put_files_dry_run(self):
self._create_put_files_structure(TOP_LEVEL_FILES, LOWER_LEVEL_FILES)
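The test updates in this file all follow one pattern: because get() no longer performs the request, a missing object cannot fail at get() time, so the assertions append read_all() to force the GET. A standalone sketch of the shifted failure point (endpoint and names are hypothetical, and the AISError import path is assumed):

    from aistore.sdk import Client
    from aistore.sdk.errors import AISError  # assumed import path

    bucket = Client("http://localhost:8080").bucket("my-bck")  # hypothetical
    reader = bucket.object("no-such-object").get()  # no request yet, so no error
    try:
        reader.read_all()  # the GET happens here; a missing object raises now
    except AISError as err:
        print("raised at read time:", err)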
4 changes: 2 additions & 2 deletions python/tests/integration/sdk/test_object_group_ops.py
@@ -180,7 +180,7 @@ def test_copy_objects_latest_flag(self):
)
self.client.job(job_id=copy_job).wait_for_idle(timeout=TEST_TIMEOUT)
with self.assertRaises(AISError):
self.bucket.object(obj_name).get()
self.bucket.object(obj_name).get().read_all()

@unittest.skipIf(
not REMOTE_SET,
@@ -247,7 +247,7 @@ def test_prefetch_objects_latest_flag(self):
prefetch_job = self.bucket.objects(obj_names=[obj_name]).prefetch(latest=True)
self.client.job(job_id=prefetch_job).wait_for_idle(timeout=TEST_TIMEOUT)
with self.assertRaises(AISError):
self.bucket.object(obj_name).get()
self.bucket.object(obj_name).get().read_all()

def _prefetch_and_check_with_latest(self, bucket, obj_name, expected, latest_flag):
prefetch_job = bucket.objects(obj_names=[obj_name]).prefetch(latest=latest_flag)
18 changes: 9 additions & 9 deletions python/tests/integration/sdk/test_object_ops.py
@@ -50,20 +50,20 @@ def tearDown(self) -> None:

def _test_get_obj(self, read_type, obj_name, exp_content):
chunk_size = random.randrange(1, len(exp_content) + 10)
stream = self.bucket.object(obj_name).get(chunk_size=chunk_size)
reader = self.bucket.object(obj_name).get(chunk_size=chunk_size)

self.assertEqual(stream.attributes.size, len(exp_content))
self.assertNotEqual(stream.attributes.checksum_type, "")
self.assertNotEqual(stream.attributes.checksum_value, "")
self.assertNotEqual(stream.attributes.access_time, "")
self.assertEqual(reader.attributes.size, len(exp_content))
self.assertNotEqual(reader.attributes.checksum_type, "")
self.assertNotEqual(reader.attributes.checksum_value, "")
self.assertNotEqual(reader.attributes.access_time, "")
if not REMOTE_SET:
self.assertNotEqual(stream.attributes.obj_version, "")
self.assertEqual(stream.attributes.custom_metadata, {})
self.assertNotEqual(reader.attributes.obj_version, "")
self.assertEqual(reader.attributes.custom_metadata, {})
if read_type == OBJ_READ_TYPE_ALL:
obj = stream.read_all()
obj = reader.read_all()
else:
obj = b""
for chunk in stream:
for chunk in reader:
obj += chunk
self.assertEqual(obj, exp_content)

106 changes: 30 additions & 76 deletions python/tests/unit/sdk/test_object.py
@@ -7,7 +7,6 @@
from aistore.sdk.const import (
HTTP_METHOD_HEAD,
DEFAULT_CHUNK_SIZE,
HTTP_METHOD_GET,
HTTP_METHOD_PATCH,
QPARAM_ARCHPATH,
QPARAM_ARCHREGX,
@@ -18,13 +17,7 @@
QPARAM_NEW_CUSTOM,
HTTP_METHOD_PUT,
HTTP_METHOD_DELETE,
HEADER_CONTENT_LENGTH,
HEADER_OBJECT_APPEND_HANDLE,
AIS_CHECKSUM_VALUE,
AIS_CHECKSUM_TYPE,
AIS_ACCESS_TIME,
AIS_VERSION,
AIS_CUSTOM_MD,
HTTP_METHOD_POST,
ACT_PROMOTE,
ACT_BLOB_DOWNLOAD,
@@ -90,7 +83,7 @@ def test_get_default_params(self):
def test_get(self):
archpath_param = "archpath"
chunk_size = "4mb"
num_workers = 10
num_workers = "10"
self.expected_params[QPARAM_ARCHPATH] = archpath_param
self.expected_params[QPARAM_ARCHREGX] = ""
self.expected_params[QPARAM_ARCHMODE] = None
@@ -118,74 +111,35 @@ def test_get_archregex(self):
self.get_exec_assert(archive_settings=archive_settings)

def get_exec_assert(self, **kwargs):
content = b"123456789"
content_length = 9
ais_check_val = "xyz"
ais_check_type = "md5"
ais_atime = "time string"
ais_version = "3"
custom_metadata_dict = {"key1": "val1", "key2": "val2"}
custom_metadata = ", ".join(
["=".join(kv) for kv in custom_metadata_dict.items()]
)
resp_headers = CaseInsensitiveDict(
{
HEADER_CONTENT_LENGTH: content_length,
AIS_CHECKSUM_VALUE: ais_check_val,
AIS_CHECKSUM_TYPE: ais_check_type,
AIS_ACCESS_TIME: ais_atime,
AIS_VERSION: ais_version,
AIS_CUSTOM_MD: custom_metadata,
}
)
mock_response = Mock(Response)
mock_response.headers = resp_headers
mock_response.iter_content.return_value = content
mock_response.raw = content
expected_obj = ObjectReader(
response_headers=resp_headers,
stream=mock_response,
)
self.mock_client.request.return_value = mock_response

res = self.object.get(**kwargs)
blob_download_settings = kwargs.get(
"blob_download_settings", BlobDownloadSettings()
)
chunk_size = blob_download_settings.chunk_size
num_workers = blob_download_settings.num_workers
headers = {}
if chunk_size or num_workers:
headers[HEADER_OBJECT_BLOB_DOWNLOAD] = "true"
if chunk_size:
headers[HEADER_OBJECT_BLOB_CHUNK_SIZE] = chunk_size
if num_workers:
headers[HEADER_OBJECT_BLOB_WORKERS] = num_workers

self.assertEqual(expected_obj.raw(), res.raw())
self.assertEqual(content_length, res.attributes.size)
self.assertEqual(ais_check_type, res.attributes.checksum_type)
self.assertEqual(ais_check_val, res.attributes.checksum_value)
self.assertEqual(ais_atime, res.attributes.access_time)
self.assertEqual(ais_version, res.attributes.obj_version)
self.assertEqual(custom_metadata_dict, res.attributes.custom_metadata)
self.mock_client.request.assert_called_with(
HTTP_METHOD_GET,
path=REQUEST_PATH,
params=self.expected_params,
stream=True,
headers=headers,
)

# Use the object reader iterator to call the stream with the chunk size
for _ in res:
continue
mock_response.iter_content.assert_called_with(
chunk_size=kwargs.get("chunk_size", DEFAULT_CHUNK_SIZE)
)

if "writer" in kwargs:
self.mock_writer.writelines.assert_called_with(res)
with patch(
"aistore.sdk.object.ObjectReader", return_value=Mock(spec=ObjectReader)
) as mock_obj_reader:
res = self.object.get(**kwargs)

blob_download_settings = kwargs.get(
"blob_download_settings", BlobDownloadSettings()
)
blob_chunk_size = blob_download_settings.chunk_size
blob_workers = blob_download_settings.num_workers
expected_headers = kwargs.get("expected_headers", {})
if blob_chunk_size or blob_workers:
expected_headers[HEADER_OBJECT_BLOB_DOWNLOAD] = "true"
if blob_chunk_size:
expected_headers[HEADER_OBJECT_BLOB_CHUNK_SIZE] = blob_chunk_size
if blob_workers:
expected_headers[HEADER_OBJECT_BLOB_WORKERS] = blob_workers
expected_chunk_size = kwargs.get("chunk_size", DEFAULT_CHUNK_SIZE)

self.assertIsInstance(res, ObjectReader)
mock_obj_reader.assert_called_with(
client=self.mock_client,
path=REQUEST_PATH,
params=self.expected_params,
headers=expected_headers,
chunk_size=expected_chunk_size,
)
if "writer" in kwargs:
self.mock_writer.writelines.assert_called_with(res)

def test_get_url(self):
expected_res = "full url"
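One note on the mocking strategy above: patch() must target the name where it is looked up, which is why the test patches aistore.sdk.object.ObjectReader (the name imported into object.py) rather than aistore.sdk.object_reader.ObjectReader. Distilled to its core, the pattern is:

    from unittest.mock import Mock, patch

    # Patch the symbol in the module that uses it, not where it is defined;
    # object.py imports ObjectReader into its own namespace.
    with patch("aistore.sdk.object.ObjectReader", return_value=Mock()) as reader_cls:
        # ... exercise code that constructs a reader via aistore.sdk.object ...
        pass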