Skip to content

Commit

Permalink
feat(providers/microsoft): add DefaultAzureCredential support to Azur…
Browse files Browse the repository at this point in the history
…eContainerInstanceHook (apache#33467)

* feat(provider/microsoft): add DefaultAzureCredential compatibility to azure-python-sdk through AzureIdentityCredentialAdapter wrapper
https://stackoverflow.com/questions/63384092/exception-attributeerror-defaultazurecredential-object-has-no-attribute-sig

* feat(providers/microsoft): add DefaultAzureCredential support to AzureContainerInstanceHook

* fix(providers/microsfot): replace AzureIdentityCredentialAdapter with DefaultAzureCredential due to backward compatibility
  • Loading branch information
Lee-W authored Aug 24, 2023
1 parent 626d3da commit faa50cb
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 4 deletions.
16 changes: 12 additions & 4 deletions airflow/providers/microsoft/azure/hooks/base_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import AzureIdentityCredentialAdapter


class AzureBaseHook(BaseHook):
Expand Down Expand Up @@ -124,10 +125,17 @@ def get_conn(self) -> Any:
self.log.info("Getting connection using a JSON config.")
return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json)

self.log.info("Getting connection using specific credentials and subscription_id.")
return self.sdk_client(
credentials=ServicePrincipalCredentials(
credentials: ServicePrincipalCredentials | AzureIdentityCredentialAdapter
if all([conn.login, conn.password, tenant]):
self.log.info("Getting connection using specific credentials and subscription_id.")
credentials = ServicePrincipalCredentials(
client_id=conn.login, secret=conn.password, tenant=tenant
),
)
else:
self.log.info("Using DefaultAzureCredential as credential")
credentials = AzureIdentityCredentialAdapter()

return self.sdk_client(
credentials=credentials,
subscription_id=subscription_id,
)
50 changes: 50 additions & 0 deletions airflow/providers/microsoft/azure/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@

import warnings

from azure.core.pipeline import PipelineContext, PipelineRequest
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
from azure.core.pipeline.transport import HttpRequest
from azure.identity import DefaultAzureCredential
from msrest.authentication import BasicTokenAuthentication


def get_field(*, conn_id: str, conn_type: str, extras: dict, field_name: str):
"""Get field from extra, first checking short name, then for backcompat we check for prefixed name."""
Expand All @@ -43,3 +49,47 @@ def get_field(*, conn_id: str, conn_type: str, extras: dict, field_name: str):
if ret == "":
return None
return ret


class AzureIdentityCredentialAdapter(BasicTokenAuthentication):
"""Adapt azure-identity credentials for backward compatibility.
Adapt credentials from azure-identity to be compatible with SD
that needs msrestazure or azure.common.credentials
Check https://stackoverflow.com/questions/63384092/exception-attributeerror-defaultazurecredential-object-has-no-attribute-sig
"""

def __init__(self, credential=None, resource_id="https://management.azure.com/.default", **kwargs):
"""Adapt azure-identity credentials for backward compatibility.
:param credential: Any azure-identity credential (DefaultAzureCredential by default)
:param str resource_id: The scope to use to get the token (default ARM)
"""
super().__init__(None)
if credential is None:
credential = DefaultAzureCredential()
self._policy = BearerTokenCredentialPolicy(credential, resource_id, **kwargs)

def _make_request(self):
return PipelineRequest(
HttpRequest("AzureIdentityCredentialAdapter", "https://fakeurl"), PipelineContext(None)
)

def set_token(self):
"""Ask the azure-core BearerTokenCredentialPolicy policy to get a token.
Using the policy gives us for free the caching system of azure-core.
We could make this code simpler by using private method, but by definition
I can't assure they will be there forever, so mocking a fake call to the policy
to extract the token, using 100% public API.
"""
request = self._make_request()
self._policy.on_request(request)
# Read Authorization, and get the second part after Bearer
token = request.http_request.headers["Authorization"].split(" ", 1)[1]
self.token = {"access_token": token}

def signed_session(self, azure_session=None):
self.set_token()
return super().signed_session(azure_session)

0 comments on commit faa50cb

Please sign in to comment.