From f4e785c84042635fb6741965be3fb7d384164477 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Wed, 24 Aug 2022 10:55:21 -0700 Subject: [PATCH] [Storage] Fix public bucket source check in SkyPilot Storage (#1087) * fix public bucket detection * add nonexist bucket test * add private bucket test * lint * arg naming * Add check for non existent bucket in tests * lint * Add tcga-2-open to tests * Fix typo --- sky/data/storage.py | 7 ++---- tests/test_smoke.py | 56 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/sky/data/storage.py b/sky/data/storage.py index b3e7c64aaac..1dff096fe53 100644 --- a/sky/data/storage.py +++ b/sky/data/storage.py @@ -854,21 +854,18 @@ def _get_bucket(self) -> Tuple[StorageHandle, bool]: # This line does not error out if the bucket is an external public # bucket or if it is a user's bucket that is publicly # accessible. - self.client.get_public_access_block(Bucket=self.name) + self.client.head_bucket(Bucket=self.name) return bucket, False except aws.client_exception() as e: error_code = e.response['Error']['Code'] # AccessDenied error for buckets that are private and not owned by # user. - if error_code == 'AccessDenied': + if error_code == '403': command = f'aws s3 ls {self.name}' with ux_utils.print_exception_no_traceback(): raise exceptions.StorageBucketGetError( _BUCKET_FAIL_TO_CONNECT_MESSAGE.format( name=self.name, command=command)) from e - # Try private bucket case. 
- if data_utils.verify_s3_bucket(self.name): - return bucket, False if self.source is not None and self.source.startswith('s3://'): with ux_utils.print_exception_no_traceback(): diff --git a/tests/test_smoke.py b/tests/test_smoke.py index e22242e872a..1ecd3cf57e8 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -854,6 +854,7 @@ def test_new_bucket_creation_and_deletion(self, tmp_local_storage_obj, @pytest.mark.parametrize( 'tmp_public_storage_obj, store_type', [('s3://tcga-2-open', storage_lib.StoreType.S3), + ('s3://digitalcorpora', storage_lib.StoreType.S3), ('gs://gcp-public-data-sentinel-2', storage_lib.StoreType.GCS)], indirect=['tmp_public_storage_obj']) def test_public_bucket(self, tmp_public_storage_obj, store_type): @@ -865,6 +866,61 @@ def test_public_bucket(self, tmp_public_storage_obj, store_type): out = subprocess.check_output(['sky', 'storage', 'ls']) assert tmp_public_storage_obj.name not in out.decode('utf-8') + @pytest.mark.parametrize('nonexist_bucket_url', + ['s3://{random_name}', 'gs://{random_name}']) + def test_nonexistent_bucket(self, nonexist_bucket_url): + # Attempts to fetch a storage with a non-existent source. 
+ # Generate a random bucket name and verify it doesn't exist: + retry_count = 0 + while True: + nonexist_bucket_name = str(uuid.uuid4()) + if nonexist_bucket_url.startswith('s3'): + command = [ + 'aws', 's3api', 'head-bucket', '--bucket', + nonexist_bucket_name + ] + expected_output = '404' + elif nonexist_bucket_url.startswith('gs'): + command = [ + 'gsutil', 'ls', + nonexist_bucket_url.format(random_name=nonexist_bucket_name) + ] + expected_output = 'BucketNotFoundException' + else: + raise ValueError('Unsupported bucket type ' + f'{nonexist_bucket_url}') + + # Check if bucket exists using the cli: + try: + out = subprocess.check_output(command, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + out = e.output + out = out.decode('utf-8') + if expected_output in out: + break + else: + retry_count += 1 + if retry_count > 3: + raise RuntimeError('Unable to find a nonexistent bucket ' + 'to use. This is higly unlikely - ' + 'check if the tests are correct.') + + with pytest.raises( + sky.exceptions.StorageBucketGetError, + match='Attempted to connect to a non-existent bucket'): + storage_obj = storage_lib.Storage(source=nonexist_bucket_url.format( + random_name=nonexist_bucket_name)) + + @pytest.mark.parametrize('private_bucket', + [f's3://imagenet', f'gs://imagenet']) + def test_private_bucket(self, private_bucket): + # Attempts to access private buckets not belonging to the user. + # These buckets are known to be private, but may need to be updated if + # they are removed by their owners. + with pytest.raises(sky.exceptions.StorageBucketGetError, + match='the bucket name is taken'): + storage_obj = storage_lib.Storage(source=private_bucket) + @staticmethod def cli_ls_cmd(store_type, bucket_name): if store_type == storage_lib.StoreType.S3: