From 34feb3f7eb45ae0417011a15b4ad59c70a8deb43 Mon Sep 17 00:00:00 2001 From: Adrian Partl Date: Sun, 8 Dec 2019 21:24:29 +0100 Subject: [PATCH] added support for using S3 single part uploads --- smart_open/s3.py | 164 ++++++++++++++++++++++++++-- smart_open/tests/test_s3.py | 111 +++++++++++++++++-- smart_open/tests/test_smart_open.py | 2 +- 3 files changed, 256 insertions(+), 21 deletions(-) diff --git a/smart_open/s3.py b/smart_open/s3.py index 80948ad2..249513c7 100644 --- a/smart_open/s3.py +++ b/smart_open/s3.py @@ -78,6 +78,7 @@ def open( session=None, resource_kwargs=None, multipart_upload_kwargs=None, + multipart_upload=True, ): """Open an S3 object for reading or writing. @@ -100,6 +101,12 @@ def open( multipart_upload_kwargs: dict, optional Additional parameters to pass to boto3's initiate_multipart_upload function. For writing only. + multipart_upload: bool, optional + Default: `True` + If set to `True`, will use multipart upload for writing to S3. If set + to `False`, S3 upload will use the S3 Single-Part Upload API, which + is better suited for small files. + For writing only. version_id: str, optional Version of the object, used when reading object. If None, will fetch the most recent version.
@@ -127,14 +134,23 @@ def open( resource_kwargs=resource_kwargs, ) elif mode == WRITE_BINARY: - fileobj = BufferedOutputBase( - bucket_id, - key_id, - min_part_size=min_part_size, - session=session, - multipart_upload_kwargs=multipart_upload_kwargs, - resource_kwargs=resource_kwargs, - ) + if multipart_upload: + fileobj = BufferedMultiPartOutputBase( + bucket_id, + key_id, + min_part_size=min_part_size, + session=session, + multipart_upload_kwargs=multipart_upload_kwargs, + resource_kwargs=resource_kwargs, + ) + else: + fileobj = BufferedSinglePartOutputBase( + bucket_id, + key_id, + session=session, + multipart_upload_kwargs=multipart_upload_kwargs, + resource_kwargs=resource_kwargs, + ) else: assert False, 'unexpected mode: %r' % mode return fileobj @@ -479,8 +495,8 @@ def __repr__(self): ) -class BufferedOutputBase(io.BufferedIOBase): - """Writes bytes to S3. +class BufferedMultiPartOutputBase(io.BufferedIOBase): + """Writes bytes to S3 using the multi part API. Implements the io.BufferedIOBase interface of the standard library.""" @@ -637,11 +653,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.close() def __str__(self): - return "smart_open.s3.BufferedOutputBase(%r, %r)" % (self._object.bucket_name, self._object.key) + return "smart_open.s3.BufferedMultiPartOutputBase(%r, %r)" % (self._object.bucket_name, self._object.key) def __repr__(self): return ( - "smart_open.s3.BufferedOutputBase(" + "smart_open.s3.BufferedMultiPartOutputBase(" "bucket=%r, " "key=%r, " "min_part_size=%r, " @@ -658,6 +674,130 @@ def __repr__(self): ) +class BufferedSinglePartOutputBase(io.BufferedIOBase): + """Writes bytes to S3 using the single part API. 
+ + Implements the io.BufferedIOBase interface of the standard library.""" + + def __init__( + self, + bucket, + key, + session=None, + resource_kwargs=None, + multipart_upload_kwargs=None, + ): + + self._session = session + self._resource_kwargs = resource_kwargs + + if session is None: + session = boto3.Session() + if resource_kwargs is None: + resource_kwargs = {} + if multipart_upload_kwargs is None: + multipart_upload_kwargs = {} + + self._multipart_upload_kwargs = multipart_upload_kwargs + + s3 = session.resource('s3', **resource_kwargs) + try: + self._object = s3.Object(bucket, key) + s3.meta.client.head_bucket(Bucket=bucket) + except botocore.client.ClientError: + raise ValueError('the bucket %r does not exist, or is forbidden for access' % bucket) + + self._buf = io.BytesIO() + self._total_bytes = 0 + self._closed = False + + # + # This member is part of the io.BufferedIOBase interface. + # + self.raw = None + + def flush(self): + pass + + # + # Override some methods from io.IOBase. + # + def close(self): + self._buf.seek(0) + + try: + self._object.put(Body=self._buf, **self._multipart_upload_kwargs) + except botocore.client.ClientError: + raise ValueError('the bucket %r does not exist, or is forbidden for access' % self._object.bucket_name) + + logger.debug("direct upload finished") + self._closed = True + + @property + def closed(self): + return self._closed + + def writable(self): + """Return True if the stream supports writing.""" + return True + + def tell(self): + """Return the current stream position.""" + return self._total_bytes + + # + # io.BufferedIOBase methods. + # + def detach(self): + raise io.UnsupportedOperation("detach() not supported") + + def write(self, b): + """Write the given buffer (bytes, bytearray, memoryview or any buffer + interface implementation) into the buffer. Content of the buffer will be + written to S3 on close as a single-part upload. 
+ + For more information about buffers, see https://docs.python.org/3/c-api/buffer.html""" + + length = self._buf.write(b) + self._total_bytes += length + return length + + def terminate(self): + """Nothing to cancel in single-part uploads.""" + return + + # + # Internal methods. + # + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None: + self.terminate() + else: + self.close() + + def __str__(self): + return "smart_open.s3.BufferedSinglePartOutputBase(%r, %r)" % (self._object.bucket_name, self._object.key) + + def __repr__(self): + return ( + "smart_open.s3.BufferedSinglePartOutputBase(" + "bucket=%r, " + "key=%r, " + "session=%r, " + "resource_kwargs=%r, " + "multipart_upload_kwargs=%r)" + ) % ( + self._object.bucket_name, + self._object.key, + self._session, + self._resource_kwargs, + self._multipart_upload_kwargs, + ) + + def _accept_all(key): return True diff --git a/smart_open/tests/test_s3.py b/smart_open/tests/test_s3.py index 840b95a6..e39bb8d3 100644 --- a/smart_open/tests/test_s3.py +++ b/smart_open/tests/test_s3.py @@ -302,7 +302,7 @@ def test_to_boto3(self): @maybe_mock_s3 -class BufferedOutputBaseTest(unittest.TestCase): +class BufferedMultiPartOutputBaseTest(unittest.TestCase): """ Test writing into s3 files. 
@@ -318,7 +318,7 @@ def test_write_01(self): test_string = u"žluťoučký koníček".encode('utf8') # write into key - with smart_open.s3.BufferedOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fout: + with smart_open.s3.BufferedMultiPartOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fout: fout.write(test_string) # read key and test content @@ -329,7 +329,7 @@ def test_write_01(self): def test_write_01a(self): """Does s3 write fail on incorrect input?""" try: - with smart_open.s3.BufferedOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fin: + with smart_open.s3.BufferedMultiPartOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fin: fin.write(None) except TypeError: pass @@ -338,7 +338,7 @@ def test_write_01a(self): def test_write_02(self): """Does s3 write unicode-utf8 conversion work?""" - smart_open_write = smart_open.s3.BufferedOutputBase(BUCKET_NAME, WRITE_KEY_NAME) + smart_open_write = smart_open.s3.BufferedMultiPartOutputBase(BUCKET_NAME, WRITE_KEY_NAME) smart_open_write.tell() logger.info("smart_open_write: %r", smart_open_write) with smart_open_write as fout: @@ -348,7 +348,7 @@ def test_write_02(self): def test_write_03(self): """Does s3 multipart chunking work correctly?""" # write - smart_open_write = smart_open.s3.BufferedOutputBase( + smart_open_write = smart_open.s3.BufferedMultiPartOutputBase( BUCKET_NAME, WRITE_KEY_NAME, min_part_size=10 ) with smart_open_write as fout: @@ -369,7 +369,7 @@ def test_write_03(self): def test_write_04(self): """Does writing no data cause key with an empty value to be created?""" - smart_open_write = smart_open.s3.BufferedOutputBase(BUCKET_NAME, WRITE_KEY_NAME) + smart_open_write = smart_open.s3.BufferedMultiPartOutputBase(BUCKET_NAME, WRITE_KEY_NAME) with smart_open_write as fout: # noqa pass @@ -380,7 +380,7 @@ def test_write_04(self): def test_gzip(self): expected = u'а не спеть ли мне песню... 
о любви'.encode('utf-8') - with smart_open.s3.BufferedOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fout: + with smart_open.s3.BufferedMultiPartOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fout: with gzip.GzipFile(fileobj=fout, mode='w') as zipfile: zipfile.write(expected) @@ -397,7 +397,7 @@ def test_buffered_writer_wrapper_works(self): """ expected = u'не думай о секундах свысока' - with smart_open.s3.BufferedOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fout: + with smart_open.s3.BufferedMultiPartOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fout: with io.BufferedWriter(fout) as sub_out: sub_out.write(expected.encode('utf-8')) @@ -450,6 +450,101 @@ def test_to_boto3(self): self.assertEqual(contents, boto3_body) +@maybe_mock_s3 +class BufferedSinglePartOutputBaseTest(unittest.TestCase): + """ + Test writing into s3 files using single part upload. + + """ + def setUp(self): + ignore_resource_warnings() + + def tearDown(self): + cleanup_bucket() + + def test_write_01(self): + """Does writing into s3 work correctly?""" + test_string = u"žluťoučký koníček".encode('utf8') + + # write into key + with smart_open.s3.BufferedSinglePartOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fout: + fout.write(test_string) + + # read key and test content + output = list(smart_open.smart_open("s3://{}/{}".format(BUCKET_NAME, WRITE_KEY_NAME), "rb")) + + self.assertEqual(output, [test_string]) + + def test_write_01a(self): + """Does s3 write fail on incorrect input?""" + try: + with smart_open.s3.BufferedSinglePartOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fin: + fin.write(None) + except TypeError: + pass + else: + self.fail() + + def test_write_02(self): + """Does s3 write unicode-utf8 conversion work?""" + test_string = u"testžížáč".encode("utf-8") + + smart_open_write = smart_open.s3.BufferedSinglePartOutputBase(BUCKET_NAME, WRITE_KEY_NAME) + smart_open_write.tell() + logger.info("smart_open_write: %r", smart_open_write) + with smart_open_write as fout: + fout.write(test_string) + 
self.assertEqual(fout.tell(), 14) + + def test_write_04(self): + """Does writing no data cause key with an empty value to be created?""" + smart_open_write = smart_open.s3.BufferedSinglePartOutputBase(BUCKET_NAME, WRITE_KEY_NAME) + with smart_open_write as fout: # noqa + pass + + # read back the same key and check its content + output = list(smart_open.smart_open("s3://{}/{}".format(BUCKET_NAME, WRITE_KEY_NAME))) + + self.assertEqual(output, []) + + def test_buffered_writer_wrapper_works(self): + """ + Ensure that we can wrap a smart_open s3 stream in a BufferedWriter, which + passes a memoryview object to the underlying stream in python >= 2.7 + """ + expected = u'не думай о секундах свысока' + + with smart_open.s3.BufferedSinglePartOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fout: + with io.BufferedWriter(fout) as sub_out: + sub_out.write(expected.encode('utf-8')) + + with smart_open.smart_open("s3://{}/{}".format(BUCKET_NAME, WRITE_KEY_NAME)) as fin: + with io.TextIOWrapper(fin, encoding='utf-8') as text: + actual = text.read() + + self.assertEqual(expected, actual) + + def test_nonexisting_bucket(self): + expected = u"выйду ночью в поле с конём".encode('utf-8') + with self.assertRaises(ValueError): + with smart_open.s3.open('thisbucketdoesntexist', 'mykey', 'wb', multipart_upload=False) as fout: + fout.write(expected) + + def test_double_close(self): + text = u'там за туманами, вечными, пьяными'.encode('utf-8') + fout = smart_open.s3.open(BUCKET_NAME, 'key', 'wb', multipart_upload=False) + fout.write(text) + fout.close() + fout.close() + + def test_flush_close(self): + text = u'там за туманами, вечными, пьяными'.encode('utf-8') + fout = smart_open.s3.open(BUCKET_NAME, 'key', 'wb', multipart_upload=False) + fout.write(text) + fout.flush() + fout.close() + + class ClampTest(unittest.TestCase): def test(self): self.assertEqual(smart_open.s3.clamp(5, 0, 10), 5) diff --git a/smart_open/tests/test_smart_open.py b/smart_open/tests/test_smart_open.py index 
2ea0a359..4d975584 100644 --- a/smart_open/tests/test_smart_open.py +++ b/smart_open/tests/test_smart_open.py @@ -1501,7 +1501,7 @@ def test_respects_endpoint_url_read(self, mock_open): expected = {'endpoint_url': 'https://play.min.io:9000'} self.assertEqual(mock_open.call_args[1]['resource_kwargs'], expected) - @mock.patch('smart_open.s3.BufferedOutputBase') + @mock.patch('smart_open.s3.BufferedMultiPartOutputBase') def test_respects_endpoint_url_write(self, mock_open): url = 's3://key_id:secret_key@play.min.io:9000@smart-open-test/README.rst' smart_open.open(url, 'wb')