Skip to content

Commit

Permalink
Update AWS auth manager to use Fastapi instead of Flask (apache#46381)
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck authored and insomnes committed Feb 6, 2025
1 parent 69ce43e commit 17f3799
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 288 deletions.
12 changes: 12 additions & 0 deletions airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2676,3 +2676,15 @@ dag_processor:
type: integer
example: ~
default: "30"
fastapi:
description: Configuration for the Fastapi webserver.
options:
base_url:
description: |
The base URL of the FastAPI endpoint. Airflow cannot guess what domain or CNAME you are using.
If the Airflow console (the front-end) and the FastAPI APIs are on different domains, this config
should contain the endpoint of the FastAPI APIs.
version_added: ~
type: string
example: ~
default: "http://localhost:29091"
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
from collections import defaultdict
from collections.abc import Container, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Any, cast

from flask import session, url_for
from fastapi import FastAPI
from flask import session

from airflow.auth.managers.base_auth_manager import BaseAuthManager
from airflow.auth.managers.models.resource_details import (
Expand All @@ -34,6 +35,7 @@
VariableDetails,
)
from airflow.cli.cli_config import CLICommand, DefaultHelpParser, GroupCommand
from airflow.configuration import conf
from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities
from airflow.providers.amazon.aws.auth_manager.avp.facade import (
Expand All @@ -43,11 +45,7 @@
from airflow.providers.amazon.aws.auth_manager.cli.definition import (
AWS_AUTH_MANAGER_COMMANDS,
)
from airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override import (
AwsSecurityManagerOverride,
)
from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
from airflow.providers.amazon.aws.auth_manager.views.auth import AwsAuthManagerAuthenticationViews
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS

if TYPE_CHECKING:
Expand All @@ -61,7 +59,6 @@
IsAuthorizedVariableRequest,
)
from airflow.auth.managers.models.resource_details import AssetDetails, ConfigurationDetails
from airflow.www.extensions.init_appbuilder import AirflowAppBuilder


class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
Expand All @@ -72,8 +69,6 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
authentication and authorization in Airflow.
"""

appbuilder: AirflowAppBuilder | None = None

def __init__(self) -> None:
if not AIRFLOW_V_3_0_PLUS:
raise AirflowOptionalProviderFeatureException(
Expand All @@ -87,12 +82,27 @@ def __init__(self) -> None:
def avp_facade(self):
return AwsAuthManagerAmazonVerifiedPermissionsFacade()

@cached_property
def fastapi_endpoint(self) -> str:
    """
    Return the ``base_url`` option of the ``[fastapi]`` configuration section.

    The value is read through ``conf.get`` on first access and then cached on the instance
    by ``cached_property``.
    """
    base_url = conf.get("fastapi", "base_url")
    return base_url

def get_user(self) -> AwsAuthManagerUser | None:
    """
    Return the ``aws_user`` entry of the session.

    :return: the stored user, or ``None`` when ``is_logged_in()`` reports that nobody is logged in.
    """
    # Guard first so the session is only indexed when the key is known to be present.
    if not self.is_logged_in():
        return None
    return session["aws_user"]

def is_logged_in(self) -> bool:
    """Tell whether the session currently contains an ``aws_user`` entry."""
    has_user_entry = "aws_user" in session
    return has_user_entry

def deserialize_user(self, token: dict[str, Any]) -> AwsAuthManagerUser:
    """
    Build a user object from a token dictionary.

    :param token: mapping whose entries are passed as keyword arguments to ``AwsAuthManagerUser``.
    :return: the constructed ``AwsAuthManagerUser``.
    """
    user = AwsAuthManagerUser(**token)
    return user

def serialize_user(self, user: AwsAuthManagerUser) -> dict[str, Any]:
    """
    Convert a user into a plain dictionary.

    :param user: the user to serialize.
    :return: a dictionary with the keys ``user_id``, ``groups``, ``username`` and ``email``.
    """
    serialized: dict[str, Any] = {}
    # Entries are filled in the same order as the keys are documented above.
    serialized["user_id"] = user.get_id()
    serialized["groups"] = user.get_groups()
    serialized["username"] = user.username
    serialized["email"] = user.email
    return serialized

def is_authorized_configuration(
self,
*,
Expand Down Expand Up @@ -367,14 +377,10 @@ def _has_access_to_menu_item(request: IsAuthorizedRequest):
return accessible_items

def get_url_login(self, **kwargs) -> str:
return url_for("AwsAuthManagerAuthenticationViews.login")
return f"{self.fastapi_endpoint}/auth/login"

def get_url_logout(self) -> str:
return url_for("AwsAuthManagerAuthenticationViews.logout")

@cached_property
def security_manager(self) -> AwsSecurityManagerOverride:
return AwsSecurityManagerOverride(self.appbuilder)
raise NotImplementedError()

@staticmethod
def get_cli_commands() -> list[CLICommand]:
Expand All @@ -387,9 +393,20 @@ def get_cli_commands() -> list[CLICommand]:
),
]

def register_views(self) -> None:
if self.appbuilder:
self.appbuilder.add_view_no_menu(AwsAuthManagerAuthenticationViews())
def get_fastapi_app(self) -> FastAPI | None:
    """
    Build the FastAPI sub application of the AWS auth manager.

    The login router is imported inside the method and registered on the new application.

    :return: the configured ``FastAPI`` application.
    """
    from airflow.providers.amazon.aws.auth_manager.router.login import login_router

    description = (
        "This is the AWS auth manager fastapi sub application. This API is only available if the "
        "auth manager used in the Airflow environment is AWS auth manager. "
        "This sub application provides login routes."
    )
    sub_app = FastAPI(title="AWS auth manager sub application", description=description)
    sub_app.include_router(login_router)
    return sub_app

@staticmethod
def _get_menu_item_request(resource_name: str) -> IsAuthorizedRequest:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import logging
from typing import Any

import anyio
from fastapi import HTTPException, Request
from starlette import status
from starlette.responses import RedirectResponse

from airflow.api_fastapi.app import get_auth_manager
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.configuration import conf
from airflow.providers.amazon.aws.auth_manager.constants import CONF_SAML_METADATA_URL_KEY, CONF_SECTION_NAME
from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser

try:
from onelogin.saml2.auth import OneLogin_Saml2_Auth
from onelogin.saml2.errors import OneLogin_Saml2_Error
from onelogin.saml2.idp_metadata_parser import OneLogin_Saml2_IdPMetadataParser
except ImportError:
raise ImportError(
"AWS auth manager requires the python3-saml library but it is not installed by default. "
"Please install the python3-saml library by running: "
"pip install apache-airflow-providers-amazon[python3-saml]"
)

log = logging.getLogger(__name__)
login_router = AirflowRouter(tags=["AWSAuthManagerLogin"])


@login_router.get("/login")
def login(request: Request):
    """Authenticate the user."""
    # ``saml_auth.login()`` returns the URL the caller must be redirected to.
    redirect_target = _init_saml_auth(request).login()
    return RedirectResponse(url=redirect_target)


@login_router.post("/login_callback")
def login_callback(request: Request):
    """Authenticate the user."""
    # Handles the POST carrying the SAML response, then redirects to the web app with a JWT token.
    # Raises ``HTTPException`` (500) when the response cannot be processed or the user is not
    # authenticated.
    saml_auth = _init_saml_auth(request)
    try:
        saml_auth.process_response()
    except OneLogin_Saml2_Error as e:
        log.exception(e)
        # Chain the original error so the root cause stays attached to the HTTP error.
        raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to authenticate") from e

    errors = saml_auth.get_errors()
    if not saml_auth.is_authenticated():
        error_reason = saml_auth.get_last_error_reason()
        log.error("Failed to authenticate")
        log.error("Errors: %s", errors)
        log.error("Error reason: %s", error_reason)
        raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, f"Failed to authenticate: {error_reason}")

    attributes = saml_auth.get_attributes()
    # "email" is optional: treat both a missing attribute and an attribute without any value as
    # "no email" instead of failing with an IndexError on an empty value list.
    email_values = attributes.get("email")
    user = AwsAuthManagerUser(
        user_id=attributes["id"][0],
        groups=attributes["groups"],
        username=saml_auth.get_nameid(),
        email=email_values[0] if email_values else None,
    )
    token = get_auth_manager().get_jwt_token(user)
    # 303 (See Other) so that the redirect following this POST is fetched with GET.
    return RedirectResponse(url=f"/webapp?token={token}", status_code=303)


def _init_saml_auth(request: Request) -> OneLogin_Saml2_Auth:
    """
    Create the SAML auth object for the given request.

    The service-provider settings defined here are merged with the identity-provider data
    returned by ``_get_idp_data`` and handed to ``OneLogin_Saml2_Auth`` together with the
    request data produced by ``_prepare_request``.
    """
    request_data = _prepare_request(request)
    base_url = conf.get(section="fastapi", key="base_url")
    assertion_consumer_service = {
        "url": f"{base_url}/auth/login_callback",
        "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
    }
    sp_settings = {
        # Keep the debug flag on: when it is turned off, no error reason is provided
        # in case of errors.
        "debug": True,
        "sp": {
            "entityId": "aws-auth-manager-saml-client",
            "assertionConsumerService": assertion_consumer_service,
        },
    }
    idp_data = _get_idp_data()
    return OneLogin_Saml2_Auth(
        request_data,
        OneLogin_Saml2_IdPMetadataParser.merge_settings(idp_data, sp_settings),
    )


def _prepare_request(request: Request) -> dict:
    """
    Convert the incoming request into the dictionary passed to ``OneLogin_Saml2_Auth``.

    Only the ``SAMLResponse`` and ``RelayState`` form fields are copied into ``post_data``,
    and only when they are present in the submitted form.
    """
    # Fall back to the client host, then to "localhost", when there is no "host" header.
    if request.client:
        fallback_host = request.client.host
    else:
        fallback_host = "localhost"
    host = request.headers.get("host", fallback_host)

    request_data: dict[str, Any] = {
        "https": "on" if request.url.scheme == "https" else "off",
        "http_host": host,
        "server_port": request.url.port,
        "script_name": request.url.path,
        "get_data": request.query_params,
        "post_data": {},
    }

    # ``request.form`` is executed through ``anyio.from_thread.run`` from this synchronous function.
    submitted_form = anyio.from_thread.run(request.form)
    post_data = request_data["post_data"]
    for field_name in ("SAMLResponse", "RelayState"):
        if field_name in submitted_form:
            post_data[field_name] = submitted_form[field_name]
    return request_data


def _get_idp_data() -> dict:
    """
    Load the identity-provider data from the configured SAML metadata URL.

    The URL is a mandatory configuration value; it is passed to
    ``OneLogin_Saml2_IdPMetadataParser.parse_remote`` and the parsed result is returned.
    """
    metadata_url = conf.get_mandatory_value(CONF_SECTION_NAME, CONF_SAML_METADATA_URL_KEY)
    idp_data = OneLogin_Saml2_IdPMetadataParser.parse_remote(metadata_url)
    return idp_data
Loading

0 comments on commit 17f3799

Please sign in to comment.