[Storage] Fix public bucket source check in SkyPilot Storage (#1087)
* fix public bucket detection

* add nonexist bucket test

* add private bucket test

* lint

* arg naming

* Add check for non existent bucket in tests

* lint

* Add tcga-2-open to tests

* Fix typo
romilbhardwaj authored Aug 24, 2022
1 parent 70aab9d commit f4e785c
Showing 2 changed files with 58 additions and 5 deletions.
7 changes: 2 additions & 5 deletions sky/data/storage.py
@@ -854,21 +854,18 @@ def _get_bucket(self) -> Tuple[StorageHandle, bool]:
            # This line does not error out if the bucket is an external public
            # bucket or if it is a user's bucket that is publicly
            # accessible.
-           self.client.get_public_access_block(Bucket=self.name)
+           self.client.head_bucket(Bucket=self.name)
            return bucket, False
        except aws.client_exception() as e:
            error_code = e.response['Error']['Code']
            # AccessDenied error for buckets that are private and not owned by
            # user.
-           if error_code == 'AccessDenied':
+           if error_code == '403':
                command = f'aws s3 ls {self.name}'
                with ux_utils.print_exception_no_traceback():
                    raise exceptions.StorageBucketGetError(
                        _BUCKET_FAIL_TO_CONNECT_MESSAGE.format(
                            name=self.name, command=command)) from e
-           # Try private bucket case.
-           if data_utils.verify_s3_bucket(self.name):
-               return bucket, False

        if self.source is not None and self.source.startswith('s3://'):
            with ux_utils.print_exception_no_traceback():
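For context, a minimal standalone sketch (not part of this commit) of the behavior the new check relies on: `get_public_access_block` needs bucket-owner permissions, so it appears to fail with AccessDenied even for public buckets owned by other accounts, whereas `head_bucket` succeeds for any bucket the caller can reach. With boto3 and default AWS credentials, `head_bucket` raises a `ClientError` whose error code is the string '403' for private buckets owned by someone else and '404' for nonexistent buckets, which is why the string check changed. The helper name `probe_bucket` and the example bucket are illustrative only.

    # Standalone sketch, not part of the commit: assumes boto3 is installed
    # and default AWS credentials are configured. probe_bucket() and the
    # example bucket name are illustrative only.
    import boto3
    from botocore.exceptions import ClientError


    def probe_bucket(name: str) -> str:
        """Classify a bucket via HeadBucket, loosely mirroring the new check."""
        client = boto3.client('s3')
        try:
            # Succeeds for buckets the caller can reach, including public
            # buckets owned by other accounts.
            client.head_bucket(Bucket=name)
            return 'accessible'
        except ClientError as e:
            error_code = e.response['Error']['Code']
            if error_code == '403':
                # Bucket exists but is private and not owned by the caller.
                return 'private'
            if error_code == '404':
                # Bucket does not exist.
                return 'nonexistent'
            raise


    if __name__ == '__main__':
        # A public bucket used in the tests below; expected to print
        # 'accessible'.
        print(probe_bucket('digitalcorpora'))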
56 changes: 56 additions & 0 deletions tests/test_smoke.py
@@ -854,6 +854,7 @@ def test_new_bucket_creation_and_deletion(self, tmp_local_storage_obj,
@pytest.mark.parametrize(
'tmp_public_storage_obj, store_type',
[('s3://tcga-2-open', storage_lib.StoreType.S3),
('s3://digitalcorpora', storage_lib.StoreType.S3),
('gs://gcp-public-data-sentinel-2', storage_lib.StoreType.GCS)],
indirect=['tmp_public_storage_obj'])
def test_public_bucket(self, tmp_public_storage_obj, store_type):
@@ -865,6 +866,61 @@ def test_public_bucket(self, tmp_public_storage_obj, store_type):
out = subprocess.check_output(['sky', 'storage', 'ls'])
assert tmp_public_storage_obj.name not in out.decode('utf-8')

@pytest.mark.parametrize('nonexist_bucket_url',
['s3://{random_name}', 'gs://{random_name}'])
def test_nonexistent_bucket(self, nonexist_bucket_url):
# Attempts to create a storage object with a non-existent source.
# Generate a random bucket name and verify it doesn't exist:
retry_count = 0
while True:
nonexist_bucket_name = str(uuid.uuid4())
if nonexist_bucket_url.startswith('s3'):
command = [
'aws', 's3api', 'head-bucket', '--bucket',
nonexist_bucket_name
]
expected_output = '404'
elif nonexist_bucket_url.startswith('gs'):
command = [
'gsutil', 'ls',
nonexist_bucket_url.format(random_name=nonexist_bucket_name)
]
expected_output = 'BucketNotFoundException'
else:
raise ValueError('Unsupported bucket type '
f'{nonexist_bucket_url}')

# Check if bucket exists using the cli:
try:
out = subprocess.check_output(command, stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as e:
out = e.output
out = out.decode('utf-8')
if expected_output in out:
break
else:
retry_count += 1
if retry_count > 3:
raise RuntimeError('Unable to find a nonexistent bucket '
'to use. This is highly unlikely - '
'check if the tests are correct.')

with pytest.raises(
sky.exceptions.StorageBucketGetError,
match='Attempted to connect to a non-existent bucket'):
storage_obj = storage_lib.Storage(source=nonexist_bucket_url.format(
random_name=nonexist_bucket_name))

@pytest.mark.parametrize('private_bucket',
['s3://imagenet', 'gs://imagenet'])
def test_private_bucket(self, private_bucket):
# Attempts to access private buckets not belonging to the user.
# These buckets are known to be private, but may need to be updated if
# they are removed by their owners.
with pytest.raises(sky.exceptions.StorageBucketGetError,
match='the bucket name is taken'):
storage_obj = storage_lib.Storage(source=private_bucket)

@staticmethod
def cli_ls_cmd(store_type, bucket_name):
if store_type == storage_lib.StoreType.S3:
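The existence probe in test_nonexistent_bucket shells out to the cloud CLIs: `aws s3api head-bucket` prints an error containing '404' for a missing bucket, and `gsutil ls` prints 'BucketNotFoundException', so matching those strings in the combined stdout/stderr distinguishes a bucket that does not exist from one that merely denies access. As a standalone illustration (assuming the `aws` and `gsutil` CLIs are installed and configured), the same check can be written as a small helper; `bucket_is_missing` is a hypothetical name, not part of the repository.

    # Standalone sketch, not part of the commit: assumes the AWS CLI and
    # gsutil are installed and configured. bucket_is_missing() is an
    # illustrative helper, not repository code.
    import subprocess


    def bucket_is_missing(bucket_url: str) -> bool:
        """Return True if the bucket behind an s3:// or gs:// URL does not exist."""
        if bucket_url.startswith('s3://'):
            name = bucket_url[len('s3://'):]
            command = ['aws', 's3api', 'head-bucket', '--bucket', name]
            # A missing bucket yields "An error occurred (404) ... Not Found".
            expected_output = '404'
        elif bucket_url.startswith('gs://'):
            command = ['gsutil', 'ls', bucket_url]
            # A missing bucket yields a BucketNotFoundException message.
            expected_output = 'BucketNotFoundException'
        else:
            raise ValueError(f'Unsupported bucket type {bucket_url}')
        try:
            out = subprocess.check_output(command, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as e:
            out = e.output
        return expected_output in out.decode('utf-8')

To exercise just the new cases, something like `pytest tests/test_smoke.py -k 'nonexistent_bucket or private_bucket'` should work, assuming the suite's usual pytest invocation and configured cloud credentials.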
