class S3ObjectStore(ObjectStore):
    """Object store that keeps each value as an object in one S3 bucket.

    Store keys are used verbatim as S3 object keys inside ``bucket_name``.
    The S3 client is obtained from ``AwsClientFactory``.
    """

    __slots__ = ("client", "bucket_name")

    def __init__(self, bucket_name: str, **client_factory_args):
        """Bind the store to ``bucket_name`` and build its S3 client.

        Args:
            bucket_name: Bucket that every get/put/delete call targets.
            **client_factory_args: Passed through unchanged to
                ``AwsClientFactory``.
        """
        client_factory = AwsClientFactory(**client_factory_args)
        self.client = client_factory.make_client("s3")
        self.bucket_name = bucket_name

    @staticmethod
    def _is_not_found(error: ClientError) -> bool:
        """Return True when ``error`` reports an HTTP 404 status.

        The lookup is tolerant: if the error response has no
        ``ResponseMetadata`` or no ``HTTPStatusCode`` entry the result is
        False, so callers re-raise the original ``ClientError`` instead of
        failing with an unrelated ``KeyError`` inside the handler.
        """
        metadata = error.response.get("ResponseMetadata") or {}
        return metadata.get("HTTPStatusCode") == 404

    def get(self, key: str) -> Optional[bytes]:
        """Return the bytes stored under ``key``, or None on a 404 response.

        Raises:
            ClientError: For any client error whose status is not 404.
        """
        try:
            response = self.client.get_object(Bucket=self.bucket_name, Key=key)
            return response["Body"].read()
        except ClientError as error:
            if self._is_not_found(error):
                return None
            # Bare raise keeps the original exception and traceback intact.
            raise

    def put(self, key: str, data: bytes):
        """Store ``data`` under ``key`` in the bucket."""
        self.client.put_object(Bucket=self.bucket_name, Key=key, Body=data)

    def delete(self, key: str):
        """Delete the object stored under ``key``; a 404 response is ignored.

        Raises:
            ClientError: For any client error whose status is not 404.
        """
        try:
            self.client.delete_object(Bucket=self.bucket_name, Key=key)
        except ClientError as error:
            if self._is_not_found(error):
                return
            raise
@pytest.fixture
def s3_client(mocker):
    """Provide a mock standing in for the S3 client."""
    fake_client = mocker.MagicMock()
    return fake_client


@pytest.fixture
def s3_object_store(s3_client):
    """Yield an ``S3ObjectStore`` whose client factory returns ``s3_client``."""

    def make_fake_client(self, service):
        return s3_client

    with pytest.MonkeyPatch.context() as patcher:
        patcher.setattr(
            "nodestream.pipeline.object_storage.AwsClientFactory.make_client",
            make_fake_client,
        )
        store = S3ObjectStore(BUCKET_NAME)
        yield store


def test_s3_object_store_get_found(s3_object_store, s3_client, mocker):
    """``get`` returns the bytes read from the response body."""
    body = mocker.MagicMock(read=lambda: SOME_DATA)
    s3_client.get_object.return_value = {"Body": body}

    assert_that(s3_object_store.get(SOME_KEY), equal_to(SOME_DATA))
    s3_client.get_object.assert_called_once_with(Bucket=BUCKET_NAME, Key=SOME_KEY)


def test_s3_object_store_get_not_found(s3_object_store, s3_client):
    """``get`` returns None when the client raises a 404 ``ClientError``."""
    not_found = ClientError(
        {"ResponseMetadata": {"HTTPStatusCode": 404}}, "get_object"
    )
    s3_client.get_object.side_effect = not_found

    assert_that(s3_object_store.get(SOME_KEY), is_(none()))
    s3_client.get_object.assert_called_once_with(Bucket=BUCKET_NAME, Key=SOME_KEY)


def test_s3_object_store_get_other_error(s3_object_store, s3_client):
    """``get`` lets a non-404 ``ClientError`` propagate."""
    bad_request = ClientError(
        {"ResponseMetadata": {"HTTPStatusCode": 400}}, "get_object"
    )
    s3_client.get_object.side_effect = bad_request

    with pytest.raises(ClientError):
        s3_object_store.get(SOME_KEY)


def test_s3_object_store_put(s3_object_store, s3_client):
    """``put`` forwards bucket, key and data to ``put_object``."""
    s3_object_store.put(SOME_KEY, SOME_DATA)

    expected_call = {"Bucket": BUCKET_NAME, "Key": SOME_KEY, "Body": SOME_DATA}
    s3_client.put_object.assert_called_once_with(**expected_call)
def test_s3_object_store_delete(s3_object_store, s3_client):
    """``delete`` forwards bucket and key to ``delete_object``."""
    s3_object_store.delete(SOME_KEY)

    expected_call = {"Bucket": BUCKET_NAME, "Key": SOME_KEY}
    s3_client.delete_object.assert_called_once_with(**expected_call)


def test_s3_object_store_delete_not_found(s3_object_store, s3_client):
    """``delete`` swallows a 404 ``ClientError`` from the client."""
    not_found = ClientError(
        {"ResponseMetadata": {"HTTPStatusCode": 404}}, "delete_object"
    )
    s3_client.delete_object.side_effect = not_found

    s3_object_store.delete(SOME_KEY)

    s3_client.delete_object.assert_called_once_with(Bucket=BUCKET_NAME, Key=SOME_KEY)


def test_s3_object_store_delete_other_error(s3_object_store, s3_client):
    """``delete`` lets a non-404 ``ClientError`` propagate."""
    bad_request = ClientError(
        {"ResponseMetadata": {"HTTPStatusCode": 400}}, "delete_object"
    )
    s3_client.delete_object.side_effect = bad_request

    with pytest.raises(ClientError):
        s3_object_store.delete(SOME_KEY)