Skip to content

Commit

Permalink
Add support for external IdP OIDC token retrieval
Browse files Browse the repository at this point in the history
using OAuth2.0 Client Credentials Grant for
Google Cloud Operators.

This feature enables OIDC token retrieval from
any generic Identity Provider (IdP) that uses the OAuth 2.0
Client Credentials Grant Flow. Additionally, it lays the groundwork
for integrating other custom OIDC token retrieval methods.

related: apache#35899

Co-authored-by: Gonçalo Azevedo <[email protected]>
  • Loading branch information
dybolo and vugonz committed Jun 9, 2024
1 parent fc4fbb3 commit ee193fe
Show file tree
Hide file tree
Showing 5 changed files with 412 additions and 0 deletions.
48 changes: 48 additions & 0 deletions airflow/providers/google/cloud/utils/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud._internal_client.secret_manager_client import _SecretManagerClient
from airflow.providers.google.cloud.utils.external_token_supplier import (
ClientCredentialsGrantFlowTokenSupplier,
)
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.process_utils import patch_environ

Expand Down Expand Up @@ -210,6 +213,10 @@ def __init__(
target_principal: str | None = None,
delegates: Sequence[str] | None = None,
is_anonymous: bool | None = None,
idp_issuer_url: str | None = None,
client_id: str | None = None,
client_secret: str | None = None,
idp_extra_params_dict: dict[str, str] | None = None,
) -> None:
super().__init__()
key_options = [key_path, keyfile_dict, credential_config_file, key_secret_name, is_anonymous]
Expand All @@ -229,6 +236,10 @@ def __init__(
self.target_principal = target_principal
self.delegates = delegates
self.is_anonymous = is_anonymous
self.idp_issuer_url = idp_issuer_url
self.client_id = client_id
self.client_secret = client_secret
self.idp_extra_params_dict = idp_extra_params_dict

def get_credentials_and_project(self) -> tuple[Credentials, str]:
"""
Expand All @@ -248,6 +259,10 @@ def get_credentials_and_project(self) -> tuple[Credentials, str]:
credentials, project_id = self._get_credentials_using_key_secret_name()
elif self.keyfile_dict:
credentials, project_id = self._get_credentials_using_keyfile_dict()
elif self.idp_issuer_url:
credentials, project_id = (
self._get_credentials_using_credential_config_file_and_token_supplier()
)
elif self.credential_config_file:
credentials, project_id = self._get_credentials_using_credential_config_file()
else:
Expand Down Expand Up @@ -357,6 +372,24 @@ def _get_credentials_using_credential_config_file(self) -> tuple[Credentials, st

return credentials, project_id

def _get_credentials_using_credential_config_file_and_token_supplier(self) -> tuple[Credentials, str]:
    """
    Load credentials from the credential configuration, fed by an external IdP token supplier.

    The credential configuration is parsed into a dict, a
    ``ClientCredentialsGrantFlowTokenSupplier`` built from this provider's IdP settings is
    attached under the ``subject_token_supplier`` key, and the result is handed to
    ``google.auth.load_credentials_from_dict`` together with ``self.scopes``.

    :return: The ``(credentials, project_id)`` pair returned by
        ``google.auth.load_credentials_from_dict``.
    :raises AirflowException: If no credential configuration was provided.
    """
    self._log_info(
        "Getting connection using credential configuration file and external Identity Provider."
    )

    if not self.credential_config_file:
        raise AirflowException(
            "Credential configuration is needed to use authentication by External Identity Provider."
        )

    info = _get_info_from_credential_configuration_file(self.credential_config_file)
    info["subject_token_supplier"] = ClientCredentialsGrantFlowTokenSupplier(
        oidc_issuer_url=self.idp_issuer_url,
        client_id=self.client_id,
        client_secret=self.client_secret,
        # Forward the user-supplied extra request parameters; previously they were
        # accepted by the provider but never reached the token request.
        **(self.idp_extra_params_dict or {}),
    )

    credentials, project_id = google.auth.load_credentials_from_dict(info=info, scopes=self.scopes)
    return credentials, project_id

def _get_credentials_using_adc(self) -> tuple[Credentials, str]:
self._log_info(
"Getting connection using `google.auth.default()` since no explicit credentials are provided."
Expand Down Expand Up @@ -426,3 +459,18 @@ def _get_project_id_from_service_account_email(service_account_email: str) -> st
raise AirflowException(
f"Could not extract project_id from service account's email: {service_account_email}."
)


def _get_info_from_credential_configuration_file(credential_configuration_file):
if isinstance(credential_configuration_file, str) and os.path.exists(credential_configuration_file):
with open(credential_configuration_file) as file_obj:
try:
info = json.load(file_obj)
except ValueError:
raise AirflowException("Credentials Configuration File is not a valid json file.")
else:
try:
info = json.loads(credential_configuration_file)
except json.decoder.JSONDecodeError:
raise AirflowException("Invalid JSON.")
return info
146 changes: 146 additions & 0 deletions airflow/providers/google/cloud/utils/external_token_supplier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import time
from functools import wraps
from typing import TYPE_CHECKING, Any

import requests
from google.auth.exceptions import RefreshError
from google.auth.identity_pool import SubjectTokenSupplier

if TYPE_CHECKING:
from google.auth.external_account import SupplierContext
from google.auth.transport import Request

from airflow.utils.log.logging_mixin import LoggingMixin


def cache_token_decorator(get_subject_token_method):
    """Cache calls to ``SubjectTokenSupplier`` instances' ``get_subject_token`` methods.

    Different instances of a same SubjectTokenSupplier class with the same credentials, oidc issuer url
    and extra request parameters share access tokens.

    :param get_subject_token_method: A method that returns both a token and an integer specifying
        the time in seconds until the token expires

    See also:
        https://googleapis.dev/python/google-auth/latest/reference/google.auth.identity_pool.html#google.auth.identity_pool.SubjectTokenSupplier.get_subject_token
    """
    cache: dict[tuple, dict[str, str | float]] = {}

    @wraps(get_subject_token_method)
    def wrapper(supplier_instance: SubjectTokenSupplier, *args, **kwargs) -> str:
        """Obeys the interface set by ``SubjectTokenSupplier`` for ``get_subject_token`` methods.

        :param supplier_instance: the SubjectTokenSupplier instance whose get_subject_token method is being decorated
        :return: The token string
        """
        # A tuple key keeps the fields separate: plain string concatenation would make
        # ("ab", "c") and ("a", "bc") collide. Extra parameter *values* are part of the
        # key too, because e.g. a different ``scope`` yields a different token.
        cache_key = (
            supplier_instance.oidc_issuer_url,
            supplier_instance.client_id,
            supplier_instance.client_secret,
            tuple(
                sorted((name, repr(value)) for name, value in supplier_instance.extra_params_kwargs.items())
            ),
        )

        cached = cache.get(cache_key)
        if cached is None or cached["expiration_time"] < time.monotonic():
            supplier_instance.log.info("OIDC token missing or expired")
            try:
                access_token, expires_in = get_subject_token_method(supplier_instance, *args, **kwargs)
                if not isinstance(expires_in, int) or not isinstance(access_token, str):
                    # assume error if strange values are provided
                    raise RefreshError("IdP returned an access token or expiry of unexpected type")

            except RefreshError:
                supplier_instance.log.error("Failed retrieving new OIDC Token from IdP")
                raise

            cached = {
                "access_token": access_token,
                # Monotonic clock: expiry tracking is unaffected by wall-clock adjustments.
                "expiration_time": time.monotonic() + float(expires_in),
            }
            cache[cache_key] = cached

            supplier_instance.log.info("New OIDC token retrieved, expires in %s seconds.", expires_in)

        return cached["access_token"]

    return wrapper


class ClientCredentialsGrantFlowTokenSupplier(LoggingMixin, SubjectTokenSupplier):
    """
    Class that retrieves an OIDC token from an external IdP using OAuth2.0 Client Credentials Grant flow.

    This class implements the ``SubjectTokenSupplier`` interface class used by ``google.auth.identity_pool.Credentials``

    :param oidc_issuer_url: URL of the IdP that performs OAuth2.0 Client Credentials Grant flow and returns an OIDC token.
    :param client_id: Client ID of the application requesting the token
    :param client_secret: Client secret of the application requesting the token
    :param extra_params_kwargs: Extra parameters to be passed in the payload of the POST request to the `oidc_issuer_url`

    See also:
        https://googleapis.dev/python/google-auth/latest/reference/google.auth.identity_pool.html#google.auth.identity_pool.SubjectTokenSupplier
    """

    # Upper bound (seconds) for the token request. ``requests`` applies no timeout by
    # default, so an unresponsive IdP would otherwise block the caller indefinitely.
    REQUEST_TIMEOUT_SECONDS = 30

    def __init__(
        self,
        oidc_issuer_url: str,
        client_id: str,
        client_secret: str,
        **extra_params_kwargs: Any,
    ) -> None:
        super().__init__()
        self.oidc_issuer_url = oidc_issuer_url
        self.client_id = client_id
        self.client_secret = client_secret
        self.extra_params_kwargs = extra_params_kwargs

    @cache_token_decorator
    def get_subject_token(self, context: SupplierContext, request: Request) -> tuple[str, int]:
        """
        Perform Client Credentials Grant flow with IdP and retrieves an OIDC token and expiration time.

        :param context: Supplier context passed in by the caller; not used by this implementation.
        :param request: Transport request object passed in by the caller; not used by this implementation.
        :return: The ``access_token`` and ``expires_in`` values from the IdP's JSON response.
        :raises RefreshError: If the request fails, the response is not a JSON object,
            or it lacks ``access_token`` / ``expires_in``.
        """
        self.log.info("Requesting new OIDC token from external IdP.")
        try:
            response = requests.post(
                self.oidc_issuer_url,
                data={
                    "grant_type": "client_credentials",
                    "client_id": self.client_id,
                    "client_secret": self.client_secret,
                    **self.extra_params_kwargs,
                },
                timeout=self.REQUEST_TIMEOUT_SECONDS,
            )
            response.raise_for_status()
        except requests.RequestException as e:
            # RequestException is the base of HTTPError, ConnectionError and Timeout.
            raise RefreshError(str(e)) from e

        try:
            response_dict = response.json()
        except ValueError as e:
            # requests' JSON decode errors derive from ValueError.
            raise RefreshError(f"Didn't get a json response from {self.oidc_issuer_url}") from e

        # These fields are required (and the payload must be a JSON object to carry them)
        if not isinstance(response_dict, dict) or {"access_token", "expires_in"} - response_dict.keys():
            # TODO more information about the error can be provided in the exception by inspecting the response
            raise RefreshError(f"No access token returned from {self.oidc_issuer_url}")

        return response_dict["access_token"], response_dict["expires_in"]
30 changes: 30 additions & 0 deletions airflow/providers/google/common/hooks/base_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,20 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
"impersonation_chain": StringField(
lazy_gettext("Impersonation Chain"), widget=BS3TextFieldWidget()
),
"idp_issuer_url": StringField(
lazy_gettext("IdP Token Issue URL (Client Credentials Grant Flow)"),
widget=BS3TextFieldWidget(),
),
"client_id": StringField(
lazy_gettext("Client ID (Client Credentials Grant Flow)"), widget=BS3TextFieldWidget()
),
"client_secret": StringField(
lazy_gettext("Client Secret (Client Credentials Grant Flow)"),
widget=BS3PasswordFieldWidget(),
),
"idp_extra_parameters": StringField(
lazy_gettext("IdP Extra Request Parameters"), widget=BS3TextFieldWidget()
),
"is_anonymous": BooleanField(
lazy_gettext("Anonymous credentials (ignores all other settings)"), default=False
),
Expand Down Expand Up @@ -305,6 +319,18 @@ def get_credentials_and_project_id(self) -> tuple[Credentials, str | None]:
target_principal, delegates = _get_target_principal_and_delegates(self.impersonation_chain)
is_anonymous = self._get_field("is_anonymous")

idp_issuer_url: str | None = self._get_field("idp_issuer_url", None)
client_id: str | None = self._get_field("client_id", None)
client_secret: str | None = self._get_field("client_secret", None)
idp_extra_params: str | None = self._get_field("idp_extra_params", None)

idp_extra_params_dict: dict[str, str] | None = None
if idp_extra_params:
try:
idp_extra_params_dict = json.loads(idp_extra_params)
except json.decoder.JSONDecodeError:
raise AirflowException("Invalid JSON.")

credentials, project_id = get_credentials_and_project_id(
key_path=key_path,
keyfile_dict=keyfile_dict_json,
Expand All @@ -316,6 +342,10 @@ def get_credentials_and_project_id(self) -> tuple[Credentials, str | None]:
target_principal=target_principal,
delegates=delegates,
is_anonymous=is_anonymous,
idp_issuer_url=idp_issuer_url,
client_id=client_id,
client_secret=client_secret,
idp_extra_params_dict=idp_extra_params_dict,
)

overridden_project_id = self._get_field("project")
Expand Down
53 changes: 53 additions & 0 deletions tests/providers/google/cloud/utils/test_credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,21 @@
ACCOUNT_3_ANOTHER_PROJECT = "account_3@another_project_id.iam.gserviceaccount.com"
ANOTHER_PROJECT_ID = "another_project_id"
CRED_PROVIDER_LOGGER_NAME = "airflow.providers.google.cloud.utils.credentials_provider._CredentialProvider"

# Fixture values for the external Identity Provider (Client Credentials Grant) tests.
IDP_LINK = "http://example.com/idp"
CLIENT_ID = "your-client-id"
CLIENT_SECRET = "your-client-secret"
TEST_AUDIENCE = "test-audience"
TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt"
ACCOUNT_IMPERSONATION = "http://example.com/impersonate"

# Credential configuration passed as a JSON document string rather than a file path.
CREDENTIAL_CONFIG_FILE = (
    f'{{"audience": "{TEST_AUDIENCE}", '
    f'"subject_token_type": "{TOKEN_TYPE}", '
    f'"service_account_impersonation_url": "{ACCOUNT_IMPERSONATION}"}}'
)


@pytest.fixture
Expand Down Expand Up @@ -411,6 +426,44 @@ def test_disable_logging(self, mock_default, mock_info, mock_file, assert_no_log
disable_logging=True,
)

@mock.patch("google.auth.load_credentials_from_dict", return_value=("CREDENTIALS", "PROJECT_ID"))
def test_get_credentials_using_identity_provider(self, mock_load_credentials_from_dict, caplog):
    """Check that IdP settings plus a credential config are loaded via ``load_credentials_from_dict``."""
    expected_info = {
        "audience": TEST_AUDIENCE,
        "subject_token_type": TOKEN_TYPE,
        "service_account_impersonation_url": ACCOUNT_IMPERSONATION,
        # The supplier instance is created inside the provider, so only its presence is checked.
        "subject_token_supplier": ANY,
    }
    expected_message = (
        "Getting connection using credential configuration file and external Identity Provider."
    )

    with caplog.at_level(level=logging.DEBUG, logger=CRED_PROVIDER_LOGGER_NAME):
        caplog.clear()
        result = get_credentials_and_project_id(
            credential_config_file=CREDENTIAL_CONFIG_FILE,
            idp_issuer_url=IDP_LINK,
            client_id=CLIENT_ID,
            client_secret=CLIENT_SECRET,
        )

    mock_load_credentials_from_dict.assert_called_once_with(info=expected_info, scopes=ANY)
    assert result == ("CREDENTIALS", "PROJECT_ID")
    assert expected_message in caplog.messages

def test_get_credentials_using_idp_no_credential_config(self):
    """Check that IdP settings without a credential configuration raise ``AirflowException``."""
    expected_error = re.escape(
        "Credential configuration is needed to use authentication by External Identity Provider."
    )
    idp_settings = {
        "idp_issuer_url": IDP_LINK,
        "client_id": CLIENT_ID,
        "client_secret": CLIENT_SECRET,
    }

    with pytest.raises(AirflowException, match=expected_error):
        get_credentials_and_project_id(**idp_settings)


class TestGetScopes:
def test_get_scopes_with_default(self):
Expand Down
Loading

0 comments on commit ee193fe

Please sign in to comment.