Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve testing AWS Connection response #26953

Merged
merged 1 commit into from
Oct 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,21 +635,21 @@ def test_connection(self):
.. seealso::
https://docs.aws.amazon.com/STS/latest/APIReference/API_GetCallerIdentity.html
"""
orig_client_type, self.client_type = self.client_type, 'sts'
try:
res = self.get_client_type().get_caller_identity()
metadata = res.pop("ResponseMetadata", {})
if metadata.get("HTTPStatusCode") == 200:
return True, json.dumps(res)
else:
session = self.get_session()
conn_info = session.client("sts").get_caller_identity()
metadata = conn_info.pop("ResponseMetadata", {})
if metadata.get("HTTPStatusCode") != 200:
try:
return False, json.dumps(metadata)
except TypeError:
return False, str(metadata)
conn_info["credentials_method"] = session.get_credentials().method
conn_info["region_name"] = session.region_name
return True, ", ".join(f"{k}={v!r}" for k, v in conn_info.items())

except Exception as e:
return False, str(e)
finally:
self.client_type = orig_client_type
return False, str(f"{type(e).__name__!r} error occurred while testing connection: {e}")


class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]):
Expand Down
29 changes: 29 additions & 0 deletions tests/providers/amazon/aws/hooks/test_base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,35 @@ def test_hook_connection_test(self):
assert result
assert hook.client_type == "s3" # Same client_type which defined during initialisation

@mock.patch("boto3.session.Session")
def test_hook_connection_test_failed(self, mock_boto3_session):
"""Test ``test_connection`` failure."""
hook = AwsBaseHook(client_type="ec2")

# Tests that STS API return non 200 code. Under normal circumstances this is hardly possible.
response_metadata = {"HTTPStatusCode": 500, "reason": "Test Failure"}
mock_sts_client = mock.MagicMock()
mock_sts_client.return_value.get_caller_identity.return_value = {
"ResponseMetadata": response_metadata
}
mock_boto3_session.return_value.client = mock_sts_client
result, message = hook.test_connection()
assert not result
assert message == json.dumps(response_metadata)
mock_sts_client.assert_called_once_with("sts")

def mock_error():
raise ConnectionError("Test Error")

# Something bad happen during boto3.session.Session creation (e.g. wrong credentials or conn error)
mock_boto3_session.reset_mock()
mock_boto3_session.side_effect = mock_error
result, message = hook.test_connection()
assert not result
assert message == "'ConnectionError' error occurred while testing connection: Test Error"

assert hook.client_type == "ec2"

@mock.patch.dict(os.environ, {f"AIRFLOW_CONN_{MOCK_AWS_CONN_ID.upper()}": "aws://"})
def test_conn_config_conn_id_exists(self):
"""Test retrieve connection config if aws_conn_id exists."""
Expand Down