Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: making iam endpoint universe-aware #1604

Merged
merged 11 commits into from
Oct 19, 2024
18 changes: 6 additions & 12 deletions google/auth/impersonated_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,6 @@


_REFRESH_ERROR = "Unable to acquire impersonated credentials"
_UNIVERSE_DOMAIN_MATCH_SOURCE_ERROR = (
"The universe_domain "
"is not supported for impersonated credentials. The "
"credential uses the value from source_credentials."
)

_DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds

Expand Down Expand Up @@ -180,7 +175,6 @@ def __init__(
lifetime=_DEFAULT_TOKEN_LIFETIME_SECS,
quota_project_id=None,
iam_endpoint_override=None,
universe_domain=None,
):
"""
Args:
Expand Down Expand Up @@ -227,11 +221,14 @@ def __init__(
and self._source_credentials._always_use_jwt_access
):
self._source_credentials._create_self_signed_jwt(None)
<<<<<<< HEAD
sai-sunder-s marked this conversation as resolved.
Show resolved Hide resolved
if (
TimurSadykov marked this conversation as resolved.
Show resolved Hide resolved
universe_domain is not None
and universe_domain != self._source_credentials.universe_domain
):
raise exceptions.InvalidOperation(_UNIVERSE_DOMAIN_MATCH_SOURCE_ERROR)
=======
>>>>>>> 273a733 (fix: test updates)
self._universe_domain = source_credentials.universe_domain
self._target_principal = target_principal
self._target_scopes = target_scopes
Expand Down Expand Up @@ -289,15 +286,12 @@ def _update_token(self, request):
iam_endpoint_override=self._iam_endpoint_override,
)

def get_iam_sign_endpoint(self):
return iam._IAM_SIGN_ENDPOINT.format(
self.universe_domain, self._target_principal
)

def sign_bytes(self, message):
from google.auth.transport.requests import AuthorizedSession

iam_sign_endpoint = self.get_iam_sign_endpoint()
iam_sign_endpoint = iam._IAM_SIGN_ENDPOINT.format(
self.universe_domain, self._target_principal
)

body = {
"payload": base64.b64encode(message).decode("utf-8"),
Expand Down
20 changes: 20 additions & 0 deletions tests/compute_engine/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,16 @@ def test_with_target_audience_integration(self):
},
)

# mock information about universe_domain
responses.add(
responses.GET,
"http://metadata.google.internal/computeMetadata/v1/universe/"
"universe_domain",
status=200,
content_type="application/json",
json={},
)

# mock token for credentials
responses.add(
responses.GET,
Expand Down Expand Up @@ -659,6 +669,16 @@ def test_with_quota_project_integration(self):
},
)

# stubby response about universe_domain
responses.add(
responses.GET,
"http://metadata.google.internal/computeMetadata/v1/universe/"
"universe_domain",
status=200,
content_type="application/json",
json={},
)

# mock sign blob endpoint
signature = base64.b64encode(b"some-signature").decode("utf-8")
responses.add(
Expand Down
53 changes: 10 additions & 43 deletions tests/test_impersonated_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ def make_credentials(
lifetime=LIFETIME,
target_principal=TARGET_PRINCIPAL,
iam_endpoint_override=None,
universe_domain=None,
):

return Credentials(
Expand All @@ -134,7 +133,6 @@ def make_credentials(
delegates=self.DELEGATES,
lifetime=lifetime,
iam_endpoint_override=iam_endpoint_override,
universe_domain=universe_domain,
)

def test_get_cred_info(self):
Expand All @@ -148,30 +146,13 @@ def test_get_cred_info(self):
"principal": "[email protected]",
}

def test_explicit_universe_domain_matching_source(self):
source_credentials = service_account.Credentials(
SIGNER, "[email protected]", TOKEN_URI, universe_domain="foo.bar"
)
credentials = self.make_credentials(
universe_domain="foo.bar", source_credentials=source_credentials
)
assert credentials.universe_domain == "foo.bar"

def test_universe_domain_from_source(self):
def test_universe_domain_matching_source(self):
source_credentials = service_account.Credentials(
SIGNER, "[email protected]", TOKEN_URI, universe_domain="foo.bar"
)
credentials = self.make_credentials(source_credentials=source_credentials)
assert credentials.universe_domain == "foo.bar"

def test_explicit_universe_domain_not_matching_source(self):
with pytest.raises(exceptions.InvalidOperation) as excinfo:
self.make_credentials(universe_domain="foo.bar")

assert excinfo.match(
impersonated_credentials._UNIVERSE_DOMAIN_MATCH_SOURCE_ERROR
)

def test__make_copy_get_cred_info(self):
credentials = self.make_credentials()
credentials._cred_file_path = "/path/to/file"
Expand Down Expand Up @@ -417,28 +398,6 @@ def test_signer_email(self):
credentials = self.make_credentials(target_principal=self.TARGET_PRINCIPAL)
assert credentials.signer_email == self.TARGET_PRINCIPAL

def test_sign_endpoint(self):
source_credentials = service_account.Credentials(
SIGNER, "[email protected]", TOKEN_URI, universe_domain="foo.bar"
)
credentials = self.make_credentials(source_credentials=source_credentials)
assert (
credentials.get_iam_sign_endpoint()
== "https://iamcredentials.foo.bar/v1/projects/-/serviceAccounts/[email protected]:signBlob"
)

def test_sign_endpoint_explicit_universe_domain(self):
source_credentials = service_account.Credentials(
SIGNER, "[email protected]", TOKEN_URI, universe_domain="foo.bar"
)
credentials = self.make_credentials(
universe_domain="foo.bar", source_credentials=source_credentials
)
assert (
credentials.get_iam_sign_endpoint()
== "https://iamcredentials.foo.bar/v1/projects/-/serviceAccounts/[email protected]:signBlob"
)

def test_service_account_email(self):
credentials = self.make_credentials(target_principal=self.TARGET_PRINCIPAL)
assert credentials.service_account_email == self.TARGET_PRINCIPAL
Expand All @@ -460,11 +419,19 @@ def test_sign_bytes(self, mock_donor_credentials, mock_authorizedsession_sign):
request.return_value = response

credentials.refresh(request)

assert credentials.valid
assert not credentials.expired

signature = credentials.sign_bytes(b"signed bytes")
mock_authorizedsession_sign.assert_called_with(
TimurSadykov marked this conversation as resolved.
Show resolved Hide resolved
mock.ANY,
"POST",
"https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/[email protected]:signBlob",
None,
json={"payload": "c2lnbmVkIGJ5dGVz", "delegates": []},
headers={"Content-Type": "application/json"},
)

assert signature == b"signature"

def test_sign_bytes_failure(self):
Expand Down