Skip to content

Commit

Permalink
proper error handling in authentication middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasScholl committed Feb 25, 2022
1 parent db8521f commit 7a6893a
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 17 deletions.
25 changes: 17 additions & 8 deletions fastapi_auth_middleware/middleware.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Tuple
from typing import Tuple, Callable, List

from fastapi import FastAPI
from starlette.authentication import AuthenticationBackend, AuthCredentials, AuthenticationError, BaseUser
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import HTTPConnection
from starlette.requests import HTTPConnection, Request
from starlette.responses import JSONResponse


class FastAPIUser(BaseUser):
Expand Down Expand Up @@ -44,7 +45,7 @@ def identity(self) -> str:
class FastAPIAuthBackend(AuthenticationBackend):
""" Auth Backend for FastAPI """

def __init__(self, verify_authorization_header: callable):
def __init__(self, verify_authorization_header: Callable[[str], Tuple[List[str], BaseUser]]):
""" Auth Backend constructor. Part of an AuthenticationMiddleware as backend.
Args:
Expand All @@ -64,19 +65,27 @@ async def authenticate(self, conn: HTTPConnection) -> Tuple[AuthCredentials, Bas
if "Authorization" not in conn.headers:
raise AuthenticationError("Authorization header missing")

authorization_header: str = conn.headers["Authorization"]
scopes, user = self.verify_authorization_header(authorization_header)
try:
authorization_header: str = conn.headers["Authorization"]
scopes, user = self.verify_authorization_header(authorization_header)
except Exception as exception:
raise AuthenticationError(exception) from None

return AuthCredentials(scopes=scopes), user


def AuthMiddleware(app: FastAPI, verify_authorization_header: callable):
def AuthMiddleware(
app: FastAPI,
verify_authorization_header: Callable[[str], Tuple[List[str], BaseUser]],
auth_error_handler: Callable[[Request, AuthenticationError], JSONResponse] = None
):
""" Factory method, returning an AuthenticationMiddleware
Intentionally not named with lower snake case convention as this is a factory method returning a class. Should feel like a class.
Args:
app (FastAPI): The FastAPI instance the middleware should be applied to. The `add_middleware` function of FastAPI adds the app as first argument by default.
verify_authorization_header (callable): A function handle that returns a list of scopes and a BaseUser
verify_authorization_header (Callable[[str], Tuple[List[str], BaseUser]]): A function handle that returns a list of scopes and a BaseUser
auth_error_handler (Callable[[Request, Exception], JSONResponse]): Optional error handler for creating responses when an exception was raised in verify_authorization_header
Examples:
```python
Expand All @@ -89,4 +98,4 @@ def verify_authorization_header(auth_header: str) -> Tuple[List[str], FastAPIUse
app.add_middleware(AuthMiddleware, verify_authorization_header=verify_authorization_header)
```
"""
return AuthenticationMiddleware(app, backend=FastAPIAuthBackend(verify_authorization_header))
return AuthenticationMiddleware(app, backend=FastAPIAuthBackend(verify_authorization_header), on_error=auth_error_handler)
43 changes: 34 additions & 9 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Callable

from _pytest.fixtures import fixture
from fastapi import FastAPI
from starlette.authentication import requires
from starlette.authentication import requires, AuthenticationError
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.testclient import TestClient

from fastapi_auth_middleware import AuthMiddleware, FastAPIUser
Expand All @@ -20,10 +23,14 @@ def verify_authorization_header_basic_admin_scope(auth_header: str):
return scopes, user


def raise_exception_in_verify_authorization_header(_):
raise Exception('some auth error occured')


# Sample app with simple routes, takes a verify_authorization_header callable that is applied to the middleware
def fastapi_app(verify_authorization_header: callable):
def fastapi_app(verify_authorization_header: Callable, auth_error_handler: Callable = None):
app = FastAPI()
app.add_middleware(AuthMiddleware, verify_authorization_header=verify_authorization_header)
app.add_middleware(AuthMiddleware, verify_authorization_header=verify_authorization_header, auth_error_handler=auth_error_handler)

@app.get("/")
def home():
Expand All @@ -49,26 +56,44 @@ class TestBasicBehaviour:
"""

@fixture
def client(self):
def client(self) -> TestClient:
app = fastapi_app(verify_authorization_header_basic)
return TestClient(app)

@fixture
def client_with_scopes(self):
def client_with_scopes(self) -> TestClient:
app = fastapi_app(verify_authorization_header_basic_admin_scope)
return TestClient(app)

def test_home_fail_no_header(self, client):
def test_home_fail_no_header(self, client: TestClient):
assert client.get("/").status_code == 400

def test_home_succeed(self, client):
def test_home_succeed(self, client: TestClient):
assert client.get("/", headers={"Authorization": "ey.."}).status_code == 200

def test_user_attributes(self, client):
def test_user_attributes(self, client: TestClient):
request = client.get("/user", headers={"Authorization": "ey.."})
assert request.status_code == 200
assert request.content == b'"True Code Specialist 1"' # b'"{user.is_authenticated} {user.display_name} {user.identity}"'

def test_scopes(self, client, client_with_scopes):
def test_scopes(self, client: TestClient, client_with_scopes: TestClient):
assert client.get("/admin-scope", headers={"Authorization": "ey.."}).status_code == 403 # Does not contain the requested scope
assert client_with_scopes.get("/admin-scope", headers={"Authorization": "ey.."}).status_code == 200 # Contains the requested scope

def test_fail_auth_error(self):
app = fastapi_app(verify_authorization_header=raise_exception_in_verify_authorization_header)
client_with_auth_error = TestClient(app=app)

response = client_with_auth_error.get('/', headers={"Authorization": "ey.."})
assert response.status_code == 400

def test_fail_auth_error_with_custom_handler(self):
def handle_auth_error(request: Request, exception: AuthenticationError):
assert isinstance(exception, AuthenticationError)
return JSONResponse(content={'message': str(exception)}, status_code=401)

app = fastapi_app(verify_authorization_header=raise_exception_in_verify_authorization_header, auth_error_handler=handle_auth_error)
client_with_auth_error = TestClient(app=app)

response = client_with_auth_error.get('/', headers={"Authorization": "ey.."})
assert response.status_code == 401

0 comments on commit 7a6893a

Please sign in to comment.