Skip to content

Commit

Permalink
tests: avoid using 'real' client in bucket unit tests
Browse files Browse the repository at this point in the history
Also, adjust 'Bucket.get_upload_policy' to use credentials from the client,
rather than its connection.

Remove last uses of '_Client' and '_Connection' shims:  use a mocked
client instead.

Toward #416.
  • Loading branch information
tseaver committed Jun 8, 2021
1 parent 6bd8a20 commit 7a93104
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 130 deletions.
2 changes: 1 addition & 1 deletion google/cloud/storage/bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -3181,7 +3181,7 @@ def generate_upload_policy(self, conditions, expiration=None, client=None):
to attach the signature.
"""
client = self._require_client(client)
credentials = client._base_connection.credentials
credentials = client._credentials
_signing.ensure_signed_credentials(credentials)

if expiration is None:
Expand Down
150 changes: 21 additions & 129 deletions tests/unit/test_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,15 +442,14 @@ def _get_default_timeout():
return _DEFAULT_TIMEOUT

@staticmethod
def _make_client(*args, **kw):
def _make_client(**kw):
from google.cloud.storage.client import Client

return Client(*args, **kw)
return mock.create_autospec(Client, instance=True, **kw)

def _make_one(self, client=None, name=None, properties=None, user_project=None):
if client is None:
connection = _Connection()
client = _Client(connection)
client = self._make_client()
if user_project is None:
bucket = self._get_target_class()(client, name=name)
else:
Expand Down Expand Up @@ -482,8 +481,7 @@ def test_ctor(self):
def test_ctor_w_user_project(self):
NAME = "name"
USER_PROJECT = "user-project-123"
connection = _Connection()
client = _Client(connection)
client = self._make_client()
bucket = self._make_one(client, name=NAME, user_project=USER_PROJECT)
self.assertEqual(bucket.name, NAME)
self.assertEqual(bucket._properties, {})
Expand Down Expand Up @@ -575,7 +573,7 @@ def test_notification_defaults(self):
PROJECT = "PROJECT"
BUCKET_NAME = "BUCKET_NAME"
TOPIC_NAME = "TOPIC_NAME"
client = _Client(_Connection(), project=PROJECT)
client = self._make_client(project=PROJECT)
bucket = self._make_one(client, name=BUCKET_NAME)

notification = bucket.notification(TOPIC_NAME)
Expand Down Expand Up @@ -603,7 +601,7 @@ def test_notification_explicit(self):
CUSTOM_ATTRIBUTES = {"attr1": "value1", "attr2": "value2"}
EVENT_TYPES = [OBJECT_FINALIZE_EVENT_TYPE, OBJECT_DELETE_EVENT_TYPE]
BLOB_NAME_PREFIX = "blob-name-prefix/"
client = _Client(_Connection(), project=PROJECT)
client = self._make_client(project=PROJECT)
bucket = self._make_one(client, name=BUCKET_NAME)

notification = bucket.notification(
Expand Down Expand Up @@ -1609,8 +1607,7 @@ def test_reload_w_metageneration_match(self):
)

def test_reload_w_generation_match(self):
connection = _Connection()
client = _Client(connection)
client = self._make_client()
bucket = self._make_one(client=client, name="name")

with self.assertRaises(TypeError):
Expand Down Expand Up @@ -3351,96 +3348,14 @@ def test_make_private_recursive_too_many(self):

client.list_blobs.assert_called_once()

def test_page_empty_response(self):
from google.api_core import page_iterator

connection = _Connection()
client = self._make_client()
client._base_connection = connection
name = "name"
bucket = self._make_one(client=client, name=name)
iterator = bucket.list_blobs()
page = page_iterator.Page(iterator, (), None)
iterator._page = page
blobs = list(page)
self.assertEqual(blobs, [])
self.assertEqual(iterator.prefixes, set())

def test_page_non_empty_response(self):
import six
from google.cloud.storage.blob import Blob

blob_name = "blob-name"
response = {"items": [{"name": blob_name}], "prefixes": ["foo"]}
connection = _Connection()
client = self._make_client()
client._base_connection = connection
name = "name"
bucket = self._make_one(client=client, name=name)

def fake_response():
return response

iterator = bucket.list_blobs()
iterator._get_next_page_response = fake_response

page = six.next(iterator.pages)
self.assertEqual(page.prefixes, ("foo",))
self.assertEqual(page.num_items, 1)
blob = six.next(page)
self.assertEqual(page.remaining, 0)
self.assertIsInstance(blob, Blob)
self.assertEqual(blob.name, blob_name)
self.assertEqual(iterator.prefixes, set(["foo"]))

def test_cumulative_prefixes(self):
import six
from google.cloud.storage.blob import Blob

BLOB_NAME = "blob-name1"
response1 = {
"items": [{"name": BLOB_NAME}],
"prefixes": ["foo"],
"nextPageToken": "s39rmf9",
}
response2 = {"items": [], "prefixes": ["bar"]}
client = self._make_client()
name = "name"
bucket = self._make_one(client=client, name=name)
responses = [response1, response2]

def fake_response():
return responses.pop(0)

iterator = bucket.list_blobs()
iterator._get_next_page_response = fake_response

# Parse first response.
pages_iter = iterator.pages
page1 = six.next(pages_iter)
self.assertEqual(page1.prefixes, ("foo",))
self.assertEqual(page1.num_items, 1)
blob = six.next(page1)
self.assertEqual(page1.remaining, 0)
self.assertIsInstance(blob, Blob)
self.assertEqual(blob.name, BLOB_NAME)
self.assertEqual(iterator.prefixes, set(["foo"]))
# Parse second response.
page2 = six.next(pages_iter)
self.assertEqual(page2.prefixes, ("bar",))
self.assertEqual(page2.num_items, 0)
self.assertEqual(iterator.prefixes, set(["foo", "bar"]))

def _test_generate_upload_policy_helper(self, **kwargs):
def _generate_upload_policy_helper(self, **kwargs):
import base64
import json

credentials = _create_signing_credentials()
credentials.signer_email = mock.sentinel.signer_email
credentials.sign_bytes.return_value = b"DEADBEEF"
connection = _Connection()
connection.credentials = credentials
client = _Client(connection)
client = self._make_client(_credentials=credentials)
name = "name"
bucket = self._make_one(client=client, name=name)

Expand Down Expand Up @@ -3477,7 +3392,7 @@ def _test_generate_upload_policy_helper(self, **kwargs):
def test_generate_upload_policy(self, now):
from google.cloud._helpers import _datetime_to_rfc3339

_, policy = self._test_generate_upload_policy_helper()
_, policy = self._generate_upload_policy_helper()

self.assertEqual(
policy["expiration"],
Expand All @@ -3489,15 +3404,13 @@ def test_generate_upload_policy_args(self):

expiration = datetime.datetime(1990, 5, 29)

_, policy = self._test_generate_upload_policy_helper(expiration=expiration)
_, policy = self._generate_upload_policy_helper(expiration=expiration)

self.assertEqual(policy["expiration"], _datetime_to_rfc3339(expiration))

def test_generate_upload_policy_bad_credentials(self):
credentials = object()
connection = _Connection()
connection.credentials = credentials
client = _Client(connection)
client = self._make_client(_credentials=credentials)
name = "name"
bucket = self._make_one(client=client, name=name)

Expand Down Expand Up @@ -3628,8 +3541,7 @@ def test_lock_retention_policy_w_user_project(self):

def test_generate_signed_url_w_invalid_version(self):
expiration = "2014-10-16T20:34:37.000Z"
connection = _Connection()
client = _Client(connection)
client = self._make_client()
bucket = self._make_one(name="bucket_name", client=client)
with self.assertRaises(ValueError):
bucket.generate_signed_url(expiration, version="nonesuch")
Expand Down Expand Up @@ -3665,8 +3577,7 @@ def _generate_signed_url_helper(
if expiration is None:
expiration = datetime.datetime.utcnow().replace(tzinfo=UTC) + delta

connection = _Connection()
client = _Client(connection)
client = self._make_client(_credentials=credentials)
bucket = self._make_one(name=bucket_name, client=client)

if version is None:
Expand Down Expand Up @@ -3726,32 +3637,33 @@ def _generate_signed_url_helper(
def test_get_bucket_from_string_w_valid_uri(self):
from google.cloud.storage.bucket import Bucket

connection = _Connection()
client = _Client(connection)
client = self._make_client()
BUCKET_NAME = "BUCKET_NAME"
uri = "gs://" + BUCKET_NAME

bucket = Bucket.from_string(uri, client)

self.assertIsInstance(bucket, Bucket)
self.assertIs(bucket.client, client)
self.assertEqual(bucket.name, BUCKET_NAME)

def test_get_bucket_from_string_w_invalid_uri(self):
from google.cloud.storage.bucket import Bucket

connection = _Connection()
client = _Client(connection)
client = self._make_client()

with pytest.raises(ValueError, match="URI scheme must be gs"):
Bucket.from_string("http://bucket_name", client)

def test_get_bucket_from_string_w_domain_name_bucket(self):
from google.cloud.storage.bucket import Bucket

connection = _Connection()
client = _Client(connection)
client = self._make_client()
BUCKET_NAME = "buckets.example.com"
uri = "gs://" + BUCKET_NAME

bucket = Bucket.from_string(uri, client)

self.assertIsInstance(bucket, Bucket)
self.assertIs(bucket.client, client)
self.assertEqual(bucket.name, BUCKET_NAME)
Expand Down Expand Up @@ -3886,23 +3798,3 @@ def test_it(self):
self.assertEqual(notification._topic_name, topic)
self.assertEqual(notification._topic_project, project)
self.assertEqual(notification._properties, item)


class _Connection(object):
credentials = None

def __init__(self):
pass

def api_request(self, **kw): # pragma: NO COVER
pass


class _Client(object):
def __init__(self, connection, project=None):
self._base_connection = connection
self.project = project

@property
def _credentials(self):
return self._base_connection.credentials

0 comments on commit 7a93104

Please sign in to comment.