Skip to content

Commit

Permalink
added support for using S3 single part uploads
Browse files Browse the repository at this point in the history
  • Loading branch information
adrpar committed Jan 20, 2020
1 parent a621aeb commit 34feb3f
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 21 deletions.
164 changes: 152 additions & 12 deletions smart_open/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def open(
session=None,
resource_kwargs=None,
multipart_upload_kwargs=None,
multipart_upload=True,
):
"""Open an S3 object for reading or writing.
Expand All @@ -100,6 +101,12 @@ def open(
multipart_upload_kwargs: dict, optional
Additional parameters to pass to boto3's initiate_multipart_upload function.
For writing only.
multipart_upload: bool, optional
Default: `True`
If set to `True`, will use multipart upload for writing to S3. If set
    to `False`, the S3 upload will use the Single-Part Upload API, which
    is better suited to small file sizes.
For writing only.
version_id: str, optional
Version of the object, used when reading object.
If None, will fetch the most recent version.
Expand Down Expand Up @@ -127,14 +134,23 @@ def open(
resource_kwargs=resource_kwargs,
)
elif mode == WRITE_BINARY:
fileobj = BufferedOutputBase(
bucket_id,
key_id,
min_part_size=min_part_size,
session=session,
multipart_upload_kwargs=multipart_upload_kwargs,
resource_kwargs=resource_kwargs,
)
if multipart_upload:
fileobj = BufferedMultiPartOutputBase(
bucket_id,
key_id,
min_part_size=min_part_size,
session=session,
multipart_upload_kwargs=multipart_upload_kwargs,
resource_kwargs=resource_kwargs,
)
else:
fileobj = BufferedSinglePartOutputBase(
bucket_id,
key_id,
session=session,
multipart_upload_kwargs=multipart_upload_kwargs,
resource_kwargs=resource_kwargs,
)
else:
assert False, 'unexpected mode: %r' % mode
return fileobj
Expand Down Expand Up @@ -479,8 +495,8 @@ def __repr__(self):
)


class BufferedOutputBase(io.BufferedIOBase):
"""Writes bytes to S3.
class BufferedMultiPartOutputBase(io.BufferedIOBase):
"""Writes bytes to S3 using the multi part API.
Implements the io.BufferedIOBase interface of the standard library."""

Expand Down Expand Up @@ -637,11 +653,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

def __str__(self):
return "smart_open.s3.BufferedOutputBase(%r, %r)" % (self._object.bucket_name, self._object.key)
return "smart_open.s3.BufferedMultiPartOutputBase(%r, %r)" % (self._object.bucket_name, self._object.key)

def __repr__(self):
return (
"smart_open.s3.BufferedOutputBase("
"smart_open.s3.BufferedMultiPartOutputBase("
"bucket=%r, "
"key=%r, "
"min_part_size=%r, "
Expand All @@ -658,6 +674,130 @@ def __repr__(self):
)


class BufferedSinglePartOutputBase(io.BufferedIOBase):
    """Writes bytes to S3 using the single part API.

    Implements the io.BufferedIOBase interface of the standard library.

    All written data is accumulated in an in-memory buffer and uploaded
    to S3 in a single ``Object.put`` call when the stream is closed.
    """

    def __init__(
        self,
        bucket,
        key,
        session=None,
        resource_kwargs=None,
        multipart_upload_kwargs=None,
    ):
        """
        Parameters
        ----------
        bucket: str
            Name of the destination S3 bucket.
        key: str
            Key of the destination S3 object.
        session: boto3.Session, optional
            Custom boto3 session; a default one is created when omitted.
        resource_kwargs: dict, optional
            Additional keyword arguments for ``session.resource('s3', ...)``.
        multipart_upload_kwargs: dict, optional
            Additional keyword arguments forwarded to ``Object.put``.
            NOTE(review): kept under this name for interface compatibility
            with the multipart writer, even though no multipart upload
            happens here.

        Raises
        ------
        ValueError
            If the bucket does not exist or access to it is forbidden.
        """
        self._session = session
        self._resource_kwargs = resource_kwargs

        if session is None:
            session = boto3.Session()
        if resource_kwargs is None:
            resource_kwargs = {}
        if multipart_upload_kwargs is None:
            multipart_upload_kwargs = {}

        self._multipart_upload_kwargs = multipart_upload_kwargs

        s3 = session.resource('s3', **resource_kwargs)
        try:
            self._object = s3.Object(bucket, key)
            s3.meta.client.head_bucket(Bucket=bucket)
        except botocore.client.ClientError:
            raise ValueError('the bucket %r does not exist, or is forbidden for access' % bucket)

        self._buf = io.BytesIO()    # accumulates all writes until close()
        self._total_bytes = 0
        self._closed = False

        #
        # This member is part of the io.BufferedIOBase interface.
        #
        self.raw = None

    def flush(self):
        #
        # Nothing to flush: data is only sent to S3 on close().
        #
        pass

    #
    # Override some methods from io.IOBase.
    #
    def close(self):
        """Upload the buffered data to S3 and mark the stream closed.

        Idempotent, as required by the io.IOBase contract: calling close()
        more than once must have no effect.  Without this guard, a second
        close() would upload the buffered data to S3 a second time.
        """
        if self._closed:
            return

        self._buf.seek(0)

        try:
            self._object.put(Body=self._buf, **self._multipart_upload_kwargs)
        except botocore.client.ClientError:
            raise ValueError('the bucket %r does not exist, or is forbidden for access' % self._object.bucket_name)

        logger.debug("direct upload finished")
        self._closed = True

    @property
    def closed(self):
        return self._closed

    def writable(self):
        """Return True if the stream supports writing."""
        return True

    def tell(self):
        """Return the current stream position (total bytes written so far)."""
        return self._total_bytes

    #
    # io.BufferedIOBase methods.
    #
    def detach(self):
        raise io.UnsupportedOperation("detach() not supported")

    def write(self, b):
        """Write the given buffer (bytes, bytearray, memoryview or any buffer
        interface implementation) into the buffer. Content of the buffer will be
        written to S3 on close as a single-part upload.

        For more information about buffers, see https://docs.python.org/3/c-api/buffer.html"""

        length = self._buf.write(b)
        self._total_bytes += length
        return length

    def terminate(self):
        """Nothing to cancel in single-part uploads."""
        return

    #
    # Internal methods.
    #
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        #
        # Upload only on a clean exit; when an exception propagated out of
        # the with-block, discard the buffer instead of uploading it.
        #
        if exc_type is not None:
            self.terminate()
        else:
            self.close()

    def __str__(self):
        return "smart_open.s3.BufferedSinglePartOutputBase(%r, %r)" % (self._object.bucket_name, self._object.key)

    def __repr__(self):
        return (
            "smart_open.s3.BufferedSinglePartOutputBase("
            "bucket=%r, "
            "key=%r, "
            "session=%r, "
            "resource_kwargs=%r, "
            "multipart_upload_kwargs=%r)"
        ) % (
            self._object.bucket_name,
            self._object.key,
            self._session,
            self._resource_kwargs,
            self._multipart_upload_kwargs,
        )


def _accept_all(key):
return True

Expand Down
111 changes: 103 additions & 8 deletions smart_open/tests/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def test_to_boto3(self):


@maybe_mock_s3
class BufferedOutputBaseTest(unittest.TestCase):
class BufferedMultiPartOutputBaseTest(unittest.TestCase):
"""
Test writing into s3 files.
Expand All @@ -318,7 +318,7 @@ def test_write_01(self):
test_string = u"žluťoučký koníček".encode('utf8')

# write into key
with smart_open.s3.BufferedOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fout:
with smart_open.s3.BufferedMultiPartOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fout:
fout.write(test_string)

# read key and test content
Expand All @@ -329,7 +329,7 @@ def test_write_01(self):
def test_write_01a(self):
"""Does s3 write fail on incorrect input?"""
try:
with smart_open.s3.BufferedOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fin:
with smart_open.s3.BufferedMultiPartOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fin:
fin.write(None)
except TypeError:
pass
Expand All @@ -338,7 +338,7 @@ def test_write_01a(self):

def test_write_02(self):
"""Does s3 write unicode-utf8 conversion work?"""
smart_open_write = smart_open.s3.BufferedOutputBase(BUCKET_NAME, WRITE_KEY_NAME)
smart_open_write = smart_open.s3.BufferedMultiPartOutputBase(BUCKET_NAME, WRITE_KEY_NAME)
smart_open_write.tell()
logger.info("smart_open_write: %r", smart_open_write)
with smart_open_write as fout:
Expand All @@ -348,7 +348,7 @@ def test_write_02(self):
def test_write_03(self):
"""Does s3 multipart chunking work correctly?"""
# write
smart_open_write = smart_open.s3.BufferedOutputBase(
smart_open_write = smart_open.s3.BufferedMultiPartOutputBase(
BUCKET_NAME, WRITE_KEY_NAME, min_part_size=10
)
with smart_open_write as fout:
Expand All @@ -369,7 +369,7 @@ def test_write_03(self):

def test_write_04(self):
"""Does writing no data cause key with an empty value to be created?"""
smart_open_write = smart_open.s3.BufferedOutputBase(BUCKET_NAME, WRITE_KEY_NAME)
smart_open_write = smart_open.s3.BufferedMultiPartOutputBase(BUCKET_NAME, WRITE_KEY_NAME)
with smart_open_write as fout: # noqa
pass

Expand All @@ -380,7 +380,7 @@ def test_write_04(self):

def test_gzip(self):
expected = u'а не спеть ли мне песню... о любви'.encode('utf-8')
with smart_open.s3.BufferedOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fout:
with smart_open.s3.BufferedMultiPartOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fout:
with gzip.GzipFile(fileobj=fout, mode='w') as zipfile:
zipfile.write(expected)

Expand All @@ -397,7 +397,7 @@ def test_buffered_writer_wrapper_works(self):
"""
expected = u'не думай о секундах свысока'

with smart_open.s3.BufferedOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fout:
with smart_open.s3.BufferedMultiPartOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fout:
with io.BufferedWriter(fout) as sub_out:
sub_out.write(expected.encode('utf-8'))

Expand Down Expand Up @@ -450,6 +450,101 @@ def test_to_boto3(self):
self.assertEqual(contents, boto3_body)


@maybe_mock_s3
class BufferedSinglePartOutputBaseTest(unittest.TestCase):
    """Tests for writing s3 files via the single part upload writer."""

    def setUp(self):
        ignore_resource_warnings()

    def tearDown(self):
        cleanup_bucket()

    def test_write_01(self):
        """Does writing into s3 work correctly?"""
        test_string = u"žluťoučký koníček".encode('utf8')

        # Write the payload, then read the key back and compare.
        writer = smart_open.s3.BufferedSinglePartOutputBase(BUCKET_NAME, WRITE_KEY_NAME)
        with writer as fout:
            fout.write(test_string)

        output = list(smart_open.smart_open("s3://{}/{}".format(BUCKET_NAME, WRITE_KEY_NAME), "rb"))
        self.assertEqual(output, [test_string])

    def test_write_01a(self):
        """Does s3 write fail on incorrect input?"""
        with self.assertRaises(TypeError):
            with smart_open.s3.BufferedSinglePartOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fin:
                fin.write(None)

    def test_write_02(self):
        """Does s3 write unicode-utf8 conversion work?"""
        test_string = u"testžížáč".encode("utf-8")

        writer = smart_open.s3.BufferedSinglePartOutputBase(BUCKET_NAME, WRITE_KEY_NAME)
        writer.tell()
        logger.info("smart_open_write: %r", writer)
        with writer as fout:
            fout.write(test_string)
            self.assertEqual(fout.tell(), 14)

    def test_write_04(self):
        """Does writing no data cause key with an empty value to be created?"""
        writer = smart_open.s3.BufferedSinglePartOutputBase(BUCKET_NAME, WRITE_KEY_NAME)
        with writer as fout:  # noqa
            pass

        # An empty upload should still create the key; reading it yields nothing.
        output = list(smart_open.smart_open("s3://{}/{}".format(BUCKET_NAME, WRITE_KEY_NAME)))
        self.assertEqual(output, [])

    def test_buffered_writer_wrapper_works(self):
        """
        Ensure that we can wrap a smart_open s3 stream in a BufferedWriter, which
        passes a memoryview object to the underlying stream in python >= 2.7
        """
        expected = u'не думай о секундах свысока'

        with smart_open.s3.BufferedSinglePartOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fout:
            with io.BufferedWriter(fout) as sub_out:
                sub_out.write(expected.encode('utf-8'))

        with smart_open.smart_open("s3://{}/{}".format(BUCKET_NAME, WRITE_KEY_NAME)) as fin:
            with io.TextIOWrapper(fin, encoding='utf-8') as text:
                actual = text.read()

        self.assertEqual(expected, actual)

    def test_nonexisting_bucket(self):
        """Writing to a missing bucket must raise ValueError."""
        expected = u"выйду ночью в поле с конём".encode('utf-8')
        with self.assertRaises(ValueError):
            with smart_open.s3.open('thisbucketdoesntexist', 'mykey', 'wb', multipart_upload=False) as fout:
                fout.write(expected)

    def test_double_close(self):
        """Closing the writer twice must not fail."""
        text = u'там за туманами, вечными, пьяными'.encode('utf-8')
        fout = smart_open.s3.open(BUCKET_NAME, 'key', 'wb', multipart_upload=False)
        fout.write(text)
        fout.close()
        fout.close()

    def test_flush_close(self):
        """Flushing before close must not fail."""
        text = u'там за туманами, вечными, пьяными'.encode('utf-8')
        fout = smart_open.s3.open(BUCKET_NAME, 'key', 'wb', multipart_upload=False)
        fout.write(text)
        fout.flush()
        fout.close()


class ClampTest(unittest.TestCase):
def test(self):
self.assertEqual(smart_open.s3.clamp(5, 0, 10), 5)
Expand Down
2 changes: 1 addition & 1 deletion smart_open/tests/test_smart_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -1501,7 +1501,7 @@ def test_respects_endpoint_url_read(self, mock_open):
expected = {'endpoint_url': 'https://play.min.io:9000'}
self.assertEqual(mock_open.call_args[1]['resource_kwargs'], expected)

@mock.patch('smart_open.s3.BufferedOutputBase')
@mock.patch('smart_open.s3.BufferedMultiPartOutputBase')
def test_respects_endpoint_url_write(self, mock_open):
url = 's3://key_id:[email protected]:9000@smart-open-test/README.rst'
smart_open.open(url, 'wb')
Expand Down

0 comments on commit 34feb3f

Please sign in to comment.