Skip to content

Commit

Permalink
code review
Browse files Browse the repository at this point in the history
  • Loading branch information
abikouo committed Oct 14, 2022
1 parent 505747d commit 5697191
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 46 deletions.
3 changes: 0 additions & 3 deletions plugins/module_utils/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,6 @@ def resource(self, service):
return boto3_conn(self, conn_type='resource', resource=service,
region=region, endpoint=endpoint_url, **aws_connect_kwargs)

def boto3_conn(self, **kwargs):
return boto3_conn(self, **kwargs)

@property
def region(self):
return get_aws_region(self, True)
Expand Down
21 changes: 10 additions & 11 deletions plugins/module_utils/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from ansible.module_utils.basic import to_text
from ansible.module_utils.six.moves.urllib.parse import urlparse

from .botocore import boto3_conn

try:
from botocore.client import Config
from botocore.exceptions import BotoCoreError, ClientError
Expand Down Expand Up @@ -88,20 +90,20 @@ def calculate_etag_content(module, content, etag, s3, bucket, obj, version=None)
return '"{0}"'.format(md5(content).hexdigest())


def validate_bucket_name(module, name):
def validate_bucket_name(name):
# See: https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html
if len(name) < 3:
module.fail_json(msg='the length of an S3 bucket must be at least 3 characters')
return 'the length of an S3 bucket must be at least 3 characters'
if len(name) > 63:
module.fail_json(msg='the length of an S3 bucket cannot exceed 63 characters')
return 'the length of an S3 bucket cannot exceed 63 characters'

legal_characters = string.ascii_lowercase + ".-" + string.digits
illegal_characters = [c for c in name if c not in legal_characters]
if illegal_characters:
module.fail_json(msg='invalid character(s) found in the bucket name')
return 'invalid character(s) found in the bucket name'
if name[-1] not in string.ascii_lowercase + string.digits:
module.fail_json(msg='bucket names must begin and end with a letter or number')
return True
return 'bucket names must begin and end with a letter or number'
return None


# Spot special case of fakes3.
Expand Down Expand Up @@ -143,7 +145,7 @@ def parse_default_endpoint(url, mode, encryption_mode, dualstack, sig_4):
return result


def get_s3_connection(module, aws_connect_kwargs, location, ceph, endpoint_url, sig_4=False):
def get_s3_connection(mode, encryption_mode, dualstack, aws_connect_kwargs, location, ceph, endpoint_url, sig_4=False):
params = dict(
conn_type='client',
resource='s3',
Expand All @@ -155,10 +157,7 @@ def get_s3_connection(module, aws_connect_kwargs, location, ceph, endpoint_url,
elif is_fakes3(endpoint_url):
endpoint_p = parse_fakes3_endpoint(endpoint_url)
else:
mode = module.params.get("mode")
encryption_mode = module.params.get("encryption_mode")
dualstack = module.params.get("dualstack")
endpoint_p = parse_default_endpoint(endpoint_url, mode, encryption_mode, dualstack, sig_4)

params.update(endpoint_p)
return module.boto3_conn(**params)
return boto3_conn(**params)
4 changes: 3 additions & 1 deletion plugins/modules/s3_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -1129,7 +1129,9 @@ def main():
region, _ec2_url, aws_connect_kwargs = get_aws_connection_info(module, boto3=True)

if module.params.get('validate_bucket_name'):
validate_bucket_name(module, module.params["name"])
err = validate_bucket_name(module.params["name"])
if err:
module.fail_json(msg=err)

if region in ('us-east-1', '', None):
# default to US Standard region
Expand Down
11 changes: 7 additions & 4 deletions plugins/modules/s3_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,7 @@ def main():
max_keys = module.params.get('max_keys')
metadata = module.params.get('metadata')
mode = module.params.get('mode')
encryption_mode = module.params.get('encryption_mode')
obj = module.params.get('object')
version = module.params.get('version')
overwrite = module.params.get('overwrite')
Expand All @@ -986,7 +987,9 @@ def main():
bucket_canned_acl = ["private", "public-read", "public-read-write", "authenticated-read"]

if module.params.get('validate_bucket_name'):
validate_bucket_name(module, bucket)
err = validate_bucket_name(bucket)
if err:
module.fail_json(msg=err)

if overwrite not in ['always', 'never', 'different', 'latest']:
if module.boolean(overwrite):
Expand Down Expand Up @@ -1030,7 +1033,7 @@ def main():
if endpoint_url:
for key in ['validate_certs', 'security_token', 'profile_name']:
aws_connect_kwargs.pop(key, None)
s3 = get_s3_connection(module, aws_connect_kwargs, location, ceph, endpoint_url, sig_v4)
s3 = get_s3_connection(mode, encryption_mode, dualstack, aws_connect_kwargs, location, ceph, endpoint_url, sig_v4)

validate = not ignore_nonexistent_bucket

Expand Down Expand Up @@ -1081,7 +1084,7 @@ def main():
try:
download_s3file(module, s3, bucket, obj, dest, retries, version=version)
except Sigv4Required:
s3 = get_s3_connection(module, aws_connect_kwargs, location, ceph, endpoint_url, sig_4=True)
s3 = get_s3_connection(mode, encryption_mode, dualstack, aws_connect_kwargs, location, ceph, endpoint_url, sig_4=True)
download_s3file(module, s3, bucket, obj, dest, retries, version=version)

if mode == 'put':
Expand Down Expand Up @@ -1203,7 +1206,7 @@ def main():
try:
download_s3str(module, s3, bucket, obj, version=version)
except Sigv4Required:
s3 = get_s3_connection(module, aws_connect_kwargs, location, ceph, endpoint_url, sig_4=True)
s3 = get_s3_connection(mode, encryption_mode, dualstack, aws_connect_kwargs, location, ceph, endpoint_url, sig_4=True)
download_s3str(module, s3, bucket, obj, version=version)
elif version is not None:
module.fail_json(msg="Key %s with version id %s does not exist." % (obj, version))
Expand Down
10 changes: 9 additions & 1 deletion plugins/modules/s3_object_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,8 @@ def main():
endpoint_url = module.params.get('endpoint_url')
dualstack = module.params.get('dualstack')
ceph = module.params.get('ceph')
mode = module.params.get('mode')
encryption_mode = module.params.get('encryption_mode')

if not endpoint_url and 'S3_URL' in os.environ:
endpoint_url = os.environ['S3_URL']
Expand All @@ -720,7 +722,13 @@ def main():
location = region
for key in ['validate_certs', 'security_token', 'profile_name']:
aws_connect_kwargs.pop(key, None)
connection = get_s3_connection(module, aws_connect_kwargs, location, ceph, endpoint_url)
connection = get_s3_connection(mode,
encryption_mode,
dualstack,
aws_connect_kwargs,
location,
ceph,
endpoint_url)
else:
try:
connection = module.client('s3', retry_decorator=AWSRetry.jittered_backoff())
Expand Down
39 changes: 13 additions & 26 deletions tests/unit/module_utils/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def test_calculate_etag_failure(m_checksum_file, m_checksum_content, using_file)


@pytest.mark.parametrize(
"bucket_name,error",
"bucket_name,result",
[
("docexamplebucket1", None),
("log-delivery-march-2020", None),
Expand All @@ -224,19 +224,9 @@ def test_calculate_etag_failure(m_checksum_file, m_checksum_content, using_file)
("my", "the length of an S3 bucket must be at least 3 characters")
]
)
def test_validate_bucket_name(bucket_name, error):
def test_validate_bucket_name(bucket_name, result):

module = MagicMock()
module.fail_json.side_effect = SystemExit(1)

if error:
with pytest.raises(SystemExit):
s3.validate_bucket_name(module, bucket_name)

module.fail_json.assert_called_with(msg=error)
else:
assert s3.validate_bucket_name(module, bucket_name)
module.fail_json.assert_not_called()
assert result == s3.validate_bucket_name(bucket_name)


mod_urlparse = "ansible_collections.amazon.aws.plugins.module_utils.s3.urlparse"
Expand Down Expand Up @@ -392,7 +382,9 @@ def test_parse_default_endpoint(m_config, mode, encryption_mode, dualstack, sig_
@patch('ansible_collections.amazon.aws.plugins.module_utils.s3.parse_fakes3_endpoint')
@patch('ansible_collections.amazon.aws.plugins.module_utils.s3.is_fakes3')
@patch('ansible_collections.amazon.aws.plugins.module_utils.s3.parse_ceph_endpoint')
def test_get_s3_connection(m_parse_ceph_endpoint,
@patch('ansible_collections.amazon.aws.plugins.module_utils.s3.boto3_conn')
def test_get_s3_connection(m_boto3_conn,
m_parse_ceph_endpoint,
m_is_fakes3,
m_parse_fakes3_endpoint,
m_parse_default_endpoint,
Expand All @@ -402,14 +394,13 @@ def test_get_s3_connection(m_parse_ceph_endpoint,
url = "https://my-bucket.s3.us-west-2.amazonaws.com"
region = "us-east-1"
aws_connect_kwargs = {"aws_secret_key": "secret123!", "aws_access_key": "ABCDEFG"}
params = {"mode": "put", "encryption_mode": "aws:test", "dualstack": False}
mode = "put"
encryption_mode = "aws:test"
dualstack = False
sig_4 = False

endpoint = {"endpoint": url, "config": {"s3": True, "signature": "s123"}}

module = MagicMock()
module.params = params

m_is_fakes3.return_value = isfakes3
if ceph:
m_parse_ceph_endpoint.return_value = endpoint
Expand All @@ -422,7 +413,7 @@ def test_get_s3_connection(m_parse_ceph_endpoint,
expected.update(aws_connect_kwargs)
expected.update(endpoint)

result = s3.get_s3_connection(module, aws_connect_kwargs, region, ceph, url, sig_4)
result = s3.get_s3_connection(mode, encryption_mode, dualstack, aws_connect_kwargs, region, ceph, url, sig_4)

if ceph:
m_parse_ceph_endpoint.assert_called_with(url)
Expand All @@ -434,14 +425,10 @@ def test_get_s3_connection(m_parse_ceph_endpoint,
m_parse_default_endpoint.assert_not_called()
else:
m_parse_default_endpoint.assert_called_with(
url,
params.get("mode"),
params.get("encryption_mode"),
params.get("dualstack"),
sig_4
url, mode, encryption_mode, dualstack, sig_4
)
m_parse_ceph_endpoint.assert_not_called()
m_parse_fakes3_endpoint.assert_not_called()

assert result == module.boto3_conn.return_value
module.boto3_conn.assert_called_with(**expected)
assert result == m_boto3_conn.return_value
m_boto3_conn.assert_called_with(**expected)

0 comments on commit 5697191

Please sign in to comment.