Skip to content

Commit

Permalink
[AIRFLOW-5335] Update GCSHook methods so they need min IAM perms (#5939)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaxil authored Aug 29, 2019
1 parent 3c78919 commit b1d3d55
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 50 deletions.
36 changes: 18 additions & 18 deletions airflow/contrib/hooks/gcs_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ def copy(self, source_bucket, source_object, destination_bucket=None,
raise ValueError('source_bucket and source_object cannot be empty.')

client = self.get_conn()
source_bucket = client.get_bucket(source_bucket)
source_bucket = client.bucket(source_bucket)
source_object = source_bucket.blob(source_object)
destination_bucket = client.get_bucket(destination_bucket)
destination_bucket = client.bucket(destination_bucket)
destination_object = source_bucket.copy_blob(
blob=source_object,
destination_bucket=destination_bucket,
Expand Down Expand Up @@ -133,9 +133,9 @@ def rewrite(self, source_bucket, source_object, destination_bucket,
raise ValueError('source_bucket and source_object cannot be empty.')

client = self.get_conn()
source_bucket = client.get_bucket(source_bucket)
source_bucket = client.bucket(source_bucket)
source_object = source_bucket.blob(blob_name=source_object)
destination_bucket = client.get_bucket(destination_bucket)
destination_bucket = client.bucket(destination_bucket)

token, bytes_rewritten, total_bytes = destination_bucket.blob(
blob_name=destination_object).rewrite(
Expand Down Expand Up @@ -169,7 +169,7 @@ def download(self, bucket_name, object_name, filename=None):
:type filename: str
"""
client = self.get_conn()
bucket = client.get_bucket(bucket_name)
bucket = client.bucket(bucket_name)
blob = bucket.blob(blob_name=object_name)

if filename:
Expand Down Expand Up @@ -204,7 +204,7 @@ def upload(self, bucket_name, object_name, filename,
filename = filename_gz

client = self.get_conn()
bucket = client.get_bucket(bucket_name)
bucket = client.bucket(bucket_name)
blob = bucket.blob(blob_name=object_name)
blob.upload_from_filename(filename=filename,
content_type=mime_type)
Expand All @@ -224,7 +224,7 @@ def exists(self, bucket_name, object_name):
:type object_name: str
"""
client = self.get_conn()
bucket = client.get_bucket(bucket_name)
bucket = client.bucket(bucket_name)
blob = bucket.blob(blob_name=object_name)
return blob.exists()

Expand All @@ -241,9 +241,12 @@ def is_updated_after(self, bucket_name, object_name, ts):
:type ts: datetime.datetime
"""
client = self.get_conn()
bucket = storage.Bucket(client=client, name=bucket_name)
bucket = client.bucket(bucket_name)
blob = bucket.get_blob(blob_name=object_name)
blob.reload()

if blob is None:
raise ValueError("Object ({}) not found in Bucket ({})".format(
object_name, bucket_name))

blob_update_time = blob.updated

Expand All @@ -270,7 +273,7 @@ def delete(self, bucket_name, object_name):
:type object_name: str
"""
client = self.get_conn()
bucket = client.get_bucket(bucket_name)
bucket = client.bucket(bucket_name)
blob = bucket.blob(blob_name=object_name)
blob.delete()

Expand All @@ -294,7 +297,7 @@ def list(self, bucket_name, versions=None, max_results=None, prefix=None, delimi
:return: a stream of object names matching the filtering criteria
"""
client = self.get_conn()
bucket = client.get_bucket(bucket_name)
bucket = client.bucket(bucket_name)

ids = []
page_token = None
Expand Down Expand Up @@ -338,9 +341,8 @@ def get_size(self, bucket_name, object_name):
object_name,
bucket_name)
client = self.get_conn()
bucket = client.get_bucket(bucket_name)
bucket = client.bucket(bucket_name)
blob = bucket.get_blob(blob_name=object_name)
blob.reload()
blob_size = blob.size
self.log.info('The file size of %s is %s bytes.', object_name, blob_size)
return blob_size
Expand All @@ -358,9 +360,8 @@ def get_crc32c(self, bucket_name, object_name):
self.log.info('Retrieving the crc32c checksum of '
'object_name: %s in bucket_name: %s', object_name, bucket_name)
client = self.get_conn()
bucket = client.get_bucket(bucket_name)
bucket = client.bucket(bucket_name)
blob = bucket.get_blob(blob_name=object_name)
blob.reload()
blob_crc32c = blob.crc32c
self.log.info('The crc32c checksum of %s is %s', object_name, blob_crc32c)
return blob_crc32c
Expand All @@ -378,9 +379,8 @@ def get_md5hash(self, bucket_name, object_name):
self.log.info('Retrieving the MD5 hash of '
'object: %s in bucket: %s', object_name, bucket_name)
client = self.get_conn()
bucket = client.get_bucket(bucket_name)
bucket = client.bucket(bucket_name)
blob = bucket.get_blob(blob_name=object_name)
blob.reload()
blob_md5hash = blob.md5_hash
self.log.info('The md5Hash of %s is %s', object_name, blob_md5hash)
return blob_md5hash
Expand Down Expand Up @@ -550,7 +550,7 @@ def compose(self, bucket_name, source_objects, destination_object):
self.log.info("Composing %s to %s in the bucket %s",
source_objects, destination_object, bucket_name)
client = self.get_conn()
bucket = client.get_bucket(bucket_name)
bucket = client.bucket(bucket_name)
destination_blob = bucket.blob(destination_object)
destination_blob.compose(
sources=[
Expand Down
88 changes: 56 additions & 32 deletions tests/contrib/hooks/test_gcs_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import os
import tempfile
import unittest
from datetime import datetime

import dateutil
from google.cloud import storage
from google.cloud import exceptions

Expand Down Expand Up @@ -69,14 +71,25 @@ def setUp(self):
self.gcs_hook = gcs_hook.GoogleCloudStorageHook(
google_cloud_storage_conn_id='test')

def test_storage_client_creation(self):
with mock.patch('google.cloud.storage.Client') as mock_client:
gcs_hook_1 = gcs_hook.GoogleCloudStorageHook()
gcs_hook_1.get_conn()

# test that Storage Client is called with required arguments
mock_client.assert_called_once_with(
client_info=mock.ANY,
credentials=mock.ANY,
project=mock.ANY)

@mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
def test_exists(self, mock_service):
test_bucket = 'test_bucket'
test_object = 'test_object'

# Given
get_bucket_mock = mock_service.return_value.get_bucket
blob_object = get_bucket_mock.return_value.blob
bucket_mock = mock_service.return_value.bucket
blob_object = bucket_mock.return_value.blob
exists_method = blob_object.return_value.exists
exists_method.return_value = True

Expand All @@ -85,7 +98,7 @@ def test_exists(self, mock_service):

# Then
self.assertTrue(response)
get_bucket_mock.assert_called_once_with(test_bucket)
bucket_mock.assert_called_once_with(test_bucket)
blob_object.assert_called_once_with(blob_name=test_object)
exists_method.assert_called_once_with()

Expand All @@ -95,8 +108,8 @@ def test_exists_nonexisting_object(self, mock_service):
test_object = 'test_object'

# Given
get_bucket_mock = mock_service.return_value.get_bucket
blob_object = get_bucket_mock.return_value.blob
bucket_mock = mock_service.return_value.bucket
blob_object = bucket_mock.return_value.blob
exists_method = blob_object.return_value.exists
exists_method.return_value = False

Expand All @@ -106,6 +119,24 @@ def test_exists_nonexisting_object(self, mock_service):
# Then
self.assertFalse(response)

@mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
def test_is_updated_after(self, mock_service):
test_bucket = 'test_bucket'
test_object = 'test_object'

# Given
mock_service.return_value.bucket.return_value.get_blob\
.return_value.updated = datetime(2019, 8, 28, 14, 7, 20, 700000, dateutil.tz.tzutc())

# When
response = self.gcs_hook.is_updated_after(
bucket_name=test_bucket, object_name=test_object,
ts=datetime(2018, 1, 1, 1, 1, 1)
)

# Then
self.assertTrue(response)

@mock.patch('google.cloud.storage.Bucket')
@mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
def test_copy(self, mock_service, mock_bucket):
Expand All @@ -121,9 +152,9 @@ def test_copy(self, mock_service, mock_bucket):
name=destination_object)

# Given
get_bucket_mock = mock_service.return_value.get_bucket
get_bucket_mock.return_value = mock_bucket
copy_method = get_bucket_mock.return_value.copy_blob
bucket_mock = mock_service.return_value.bucket
bucket_mock.return_value = mock_bucket
copy_method = bucket_mock.return_value.copy_blob
copy_method.return_value = destination_blob

# When
Expand Down Expand Up @@ -206,9 +237,9 @@ def test_rewrite(self, mock_service, mock_bucket):
source_blob = mock_bucket.blob(source_object)

# Given
get_bucket_mock = mock_service.return_value.get_bucket
get_bucket_mock.return_value = mock_bucket
get_blob_method = get_bucket_mock.return_value.blob
bucket_mock = mock_service.return_value.bucket
bucket_mock.return_value = mock_bucket
get_blob_method = bucket_mock.return_value.blob
rewrite_method = get_blob_method.return_value.rewrite
rewrite_method.side_effect = [(None, mock.ANY, mock.ANY), (mock.ANY, mock.ANY, mock.ANY)]

Expand Down Expand Up @@ -280,8 +311,8 @@ def test_delete_nonexisting_object(self, mock_service):
test_bucket = 'test_bucket'
test_object = 'test_object'

get_bucket_method = mock_service.return_value.get_bucket
blob = get_bucket_method.return_value.blob
bucket_method = mock_service.return_value.bucket
blob = bucket_method.return_value.blob
delete_method = blob.return_value.delete
delete_method.side_effect = exceptions.NotFound(message="Not Found")

Expand All @@ -294,52 +325,45 @@ def test_object_get_size(self, mock_service):
test_object = 'test_object'
returned_file_size = 1200

get_bucket_method = mock_service.return_value.get_bucket
get_blob_method = get_bucket_method.return_value.get_blob
bucket_method = mock_service.return_value.bucket
get_blob_method = bucket_method.return_value.get_blob
get_blob_method.return_value.size = returned_file_size

response = self.gcs_hook.get_size(bucket_name=test_bucket,
object_name=test_object)

self.assertEqual(response, returned_file_size)
get_blob_method.return_value.reload.assert_called_once_with()

@mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
def test_object_get_crc32c(self, mock_service):
test_bucket = 'test_bucket'
test_object = 'test_object'
returned_file_crc32c = "xgdNfQ=="

get_bucket_method = mock_service.return_value.get_bucket
get_blob_method = get_bucket_method.return_value.get_blob
bucket_method = mock_service.return_value.bucket
get_blob_method = bucket_method.return_value.get_blob
get_blob_method.return_value.crc32c = returned_file_crc32c

response = self.gcs_hook.get_crc32c(bucket_name=test_bucket,
object_name=test_object)

self.assertEqual(response, returned_file_crc32c)

# Check that reload method is called
get_blob_method.return_value.reload.assert_called_once_with()

@mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
def test_object_get_md5hash(self, mock_service):
test_bucket = 'test_bucket'
test_object = 'test_object'
returned_file_md5hash = "leYUJBUWrRtks1UeUFONJQ=="

get_bucket_method = mock_service.return_value.get_bucket
get_blob_method = get_bucket_method.return_value.get_blob
bucket_method = mock_service.return_value.bucket
get_blob_method = bucket_method.return_value.get_blob
get_blob_method.return_value.md5_hash = returned_file_md5hash

response = self.gcs_hook.get_md5hash(bucket_name=test_bucket,
object_name=test_object)

self.assertEqual(response, returned_file_md5hash)

# Check that reload method is called
get_blob_method.return_value.reload.assert_called_once_with()

@mock.patch('google.cloud.storage.Bucket')
@mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
def test_create_bucket(self, mock_service, mock_bucket):
Expand Down Expand Up @@ -416,9 +440,9 @@ def test_compose(self, mock_service, mock_blob):
test_source_objects = ['test_object_1', 'test_object_2', 'test_object_3']
test_destination_object = 'test_object_composed'

mock_service.return_value.get_bucket.return_value\
mock_service.return_value.bucket.return_value\
.blob.return_value = mock_blob(blob_name=mock.ANY)
method = mock_service.return_value.get_bucket.return_value.blob\
method = mock_service.return_value.bucket.return_value.blob\
.return_value.compose

self.gcs_hook.compose(
Expand Down Expand Up @@ -492,7 +516,7 @@ def test_download_as_string(self, mock_service):
test_object = 'test_object'
test_object_bytes = io.BytesIO(b"input")

download_method = mock_service.return_value.get_bucket.return_value \
download_method = mock_service.return_value.bucket.return_value \
.blob.return_value.download_as_string
download_method.return_value = test_object_bytes

Expand All @@ -510,11 +534,11 @@ def test_download_to_file(self, mock_service):
test_object_bytes = io.BytesIO(b"input")
test_file = 'test_file'

download_filename_method = mock_service.return_value.get_bucket.return_value \
download_filename_method = mock_service.return_value.bucket.return_value \
.blob.return_value.download_to_filename
download_filename_method.return_value = None

download_as_a_string_method = mock_service.return_value.get_bucket.return_value \
download_as_a_string_method = mock_service.return_value.bucket.return_value \
.blob.return_value.download_as_string
download_as_a_string_method.return_value = test_object_bytes

Expand Down Expand Up @@ -546,7 +570,7 @@ def test_upload(self, mock_service):
test_bucket = 'test_bucket'
test_object = 'test_object'

upload_method = mock_service.return_value.get_bucket.return_value\
upload_method = mock_service.return_value.bucket.return_value\
.blob.return_value.upload_from_filename
upload_method.return_value = None

Expand Down

0 comments on commit b1d3d55

Please sign in to comment.