From bf2cdcc18bde49220a745d3c8a90ba5802e8da1a Mon Sep 17 00:00:00 2001 From: Otrek Wilke Date: Wed, 31 May 2023 21:05:26 +0200 Subject: [PATCH] added flag, whether or not a multi tenant app and added tests to test decorator --- fastapi_microsoft_identity/auth_service.py | 58 ++++++++++++++-------- tests/test_azure_ad.py | 48 ++++++++++++++++-- 2 files changed, 82 insertions(+), 24 deletions(-) diff --git a/fastapi_microsoft_identity/auth_service.py b/fastapi_microsoft_identity/auth_service.py index d9c024e..5196531 100644 --- a/fastapi_microsoft_identity/auth_service.py +++ b/fastapi_microsoft_identity/auth_service.py @@ -9,17 +9,20 @@ client_id=None b2c_policy_name = None b2c_domain_name = None +multitenant = False def initialize( tenant_id_, client_id_, b2c_policy_name_=None, - b2c_domain_name_=None): - global tenant_id, client_id, b2c_policy_name, b2c_domain_name + b2c_domain_name_=None, **kwargs): + global tenant_id, client_id, b2c_policy_name, b2c_domain_name, multitenant tenant_id = tenant_id_ client_id = client_id_ b2c_policy_name = b2c_policy_name_ b2c_domain_name = b2c_domain_name_ + multitenant = kwargs.get('multitenant') or False + class AuthError(Exception): def __init__(self, error_msg:str, status_code:int): @@ -91,12 +94,10 @@ async def decorated(*args, **kwargs): try: token = get_token_auth_header(kwargs["request"]) url = f'https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys' - async with httpx.AsyncClient() as client: resp: Response = await client.get(url) if resp.status_code != 200: raise AuthError("Problem with Azure AD discovery URL", status_code=404) - jwks = resp.json() unverified_header = jwt.get_unverified_header(token) rsa_key = {} @@ -109,7 +110,7 @@ async def decorated(*args, **kwargs): "n": key["n"], "e": key["e"] } - except Exception: + except Exception as e: return fastapi.Response(content="Invalid_header: Unable to parse authentication", status_code= 401) if rsa_key: try : @@ -117,10 +118,11 @@ async def decorated(*args, **kwargs): __decode_JWT(token_version, token, rsa_key) return await f(*args, **kwargs) except AuthError as auth_err: - fastapi.Response(content=auth_err.error_msg, status_code=auth_err.status_code) + return fastapi.Response(content=auth_err.error_msg, status_code=auth_err.status_code) return fastapi.Response(content="Invalid header error: Unable to find appropriate key", status_code=401) return decorated + def requires_b2c_auth(f): @wraps(f) async def decorated(*args, **kwargs): @@ -163,13 +165,21 @@ def __decode_B2C_JWT(token_version, token, rsa_key): else: _issuer = f'https://{b2c_domain_name}.b2clogin.com/{tenant_id}/v2.0'.lower() try: - payload = jwt.decode( - token, - rsa_key, - algorithms=["RS256"], - audience=client_id, - issuer=_issuer - ) + if multitenant: + payload = jwt.decode( + token, + rsa_key, + algorithms=["RS256"], + audience=client_id + ) + else: + payload = jwt.decode( + token, + rsa_key, + algorithms=["RS256"], + audience=client_id, + issuer=_issuer + ) except jwt.ExpiredSignatureError: raise AuthError("Token error: The token has expired", 401) except jwt.JWTClaimsError: @@ -185,13 +195,21 @@ def __decode_JWT(token_version, token, rsa_key): _issuer = f'https://login.microsoftonline.com/{tenant_id}/v2.0' _audience=f'{client_id}' try: - payload = jwt.decode( - token, - rsa_key, - algorithms=["RS256"], - audience=_audience, - issuer=_issuer - ) + if multitenant: + payload = jwt.decode( + token, + rsa_key, + algorithms=["RS256"], + audience=_audience + ) + else: + payload = jwt.decode( + token, + rsa_key, + algorithms=["RS256"], + audience=_audience, + issuer=_issuer + ) except jwt.ExpiredSignatureError: raise AuthError("Token error: The token has expired", 401) except jwt.JWTClaimsError: diff --git a/tests/test_azure_ad.py b/tests/test_azure_ad.py index abe34d8..b542e4d 100644 --- a/tests/test_azure_ad.py +++ b/tests/test_azure_ad.py @@ -1,23 +1,30 @@ # Sample Test passing with nose and pytest import pytest -from fastapi import Request +from fastapi import Request, Response import sys import os + container_folder = os.path.abspath(os.path.join( os.path.dirname(__file__), '..' )) sys.path.insert(0, container_folder) -from fastapi_microsoft_identity import auth_service, AuthError +from fastapi_microsoft_identity import auth_service, AuthError, requires_auth from multidict import MultiDict +client_id = os.environ.get('CLIENT_ID') or "66ba9476-0700-4178-81ea-fbeb7097c28e" +tenant_id = os.environ.get('TENANT_ID') or "de2656e6-585f-4684-8e65-3ce50a7770a8" + user_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Imk2bEdrM0ZaenhSY1ViMkMzbkVRN3N5SEpsWSJ9.eyJhdWQiOiI2ZTc0MTcyYi1iZTU2LTQ4NDMtOWZmNC1lNjZhMzliYjEyZTMiLCJpc3MiOiJodHRwczovL2xvZ2luLm1pY3Jvc29mdG9ubGluZS5jb20vNzJmOTg4YmYtODZmMS00MWFmLTkxYWItMmQ3Y2QwMTFkYjQ3L3YyLjAiLCJpYXQiOjE1MzcyMzEwNDgsIm5iZiI6MTUzNzIzMTA0OCwiZXhwIjoxNTM3MjM0OTQ4LCJhaW8iOiJBWFFBaS84SUFBQUF0QWFaTG8zQ2hNaWY2S09udHRSQjdlQnE0L0RjY1F6amNKR3hQWXkvQzNqRGFOR3hYZDZ3TklJVkdSZ2hOUm53SjFsT2NBbk5aY2p2a295ckZ4Q3R0djMzMTQwUmlvT0ZKNGJDQ0dWdW9DYWcxdU9UVDIyMjIyZ0h3TFBZUS91Zjc5UVgrMEtJaWpkcm1wNjlSY3R6bVE9PSIsImF6cCI6IjZlNzQxNzJiLWJlNTYtNDg0My05ZmY0LWU2NmEzOWJiMTJlMyIsImF6cGFjciI6IjAiLCJuYW1lIjoiQWJlIExpbmNvbG4iLCJvaWQiOiI2OTAyMjJiZS1mZjFhLTRkNTYtYWJkMS03ZTRmN2QzOGU0NzQiLCJwcmVmZXJyZWRfdXNlcm5hbWUiOiJhYmVsaUBtaWNyb3NvZnQuY29tIiwicmgiOiJJIiwic2NwIjoiYWNjZXNzX2FzX3VzZXIiLCJzdWIiOiJIS1pwZmFIeVdhZGVPb3VZbGl0anJJLUtmZlRtMjIyWDVyclYzeERxZktRIiwidGlkIjoiNzJmOTg4YmYtODZmMS00MWFmLTkxYWItMmQ3Y2QwMTFkYjQ3IiwidXRpIjoiZnFpQnFYTFBqMGVRYTgyUy1JWUZBQSIsInZlciI6IjIuMCJ9.pj4N-w_3Us9DrBLfpCt" application_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ik1yNS1BVWliZkJpaTdOZDFqQmViYXhib1hXMCJ9.eyJhdWQiOiJkZTI2NTZlNi01ODVmLTQ2ODQtOGU2NS0zY2U1MGE3NzcwYTgiLCJpc3MiOiJodHRwczovL2xvZ2luLm1pY3Jvc29mdG9ubGluZS5jb20vNjZiYTk0NzYtMDcwMC00MTc4LTgxZWEtZmJlYjcwOTdjMjhlL3YyLjAiLCJpYXQiOjE2NDYxNjc1NzIsIm5iZiI6MTY0NjE2NzU3MiwiZXhwIjoxNjQ2MTcxNDcyLCJhaW8iOiJFMlpnWVBoN2RoSFBLNGNuOFE1TUFob01CenAwQVE9PSIsImF6cCI6ImY3NTllY2FiLWM0NWUtNDVlZS1hYWZmLWJlMzJhZmM3ZGU5YiIsImF6cGFjciI6IjEiLCJvaWQiOiIyMTJkOGM2ZS05YzdmLTQ4MWEtOGZkOC1kOTllMzVhOWNiMWMiLCJyaCI6IjAuQVZBQWRwUzZaZ0FIZUVHQjZ2dnJjSmZDanVaV0p0NWZXSVJHam1VODVRcDNjS2hfQUFBLiIsInJvbGVzIjpbImFwcC53ZWF0aGVyLnJlYWQiXSwic3ViIjoiMjEyZDhjNmUtOWM3Zi00ODFhLThmZDgtZDk5ZTM1YTljYjFjIiwidGlkIjoiNjZiYTk0NzYtMDcwMC00MTc4LTgxZWEtZmJlYjcwOTdjMjhlIiwidXRpIjoiX1lYOWhSbElvMGVwX2c3bk9KeXpBUSIsInZlciI6IjIuMCJ9.omq5Abe7rObD_-NDZ64KB3hf3pfCOCS4Sk3cz-jA_4cd49zwzq7wOI8CtXq5vhLUpbwRGCGiZqG-WYmTrTmDwNn2KcsEL8SQkKK5FCOriit8PrDVBAbidAAZsp8OgchhuNBdzp4wUUB7X3cQPk2g6XVOchqvw6MJZVFxi8r5Kqxq8AMJJlHO-ijUX5qKRcrIHkhezFjtGs-TV1dgdpGshKcWhpA635ehRFigY0Hry6vyYaPuiwufp2iMXJ1ZT6ZHqFIE_HeQNLTo39zV5CzVQ4UHH9gDMHfqSbEEO79JyZfNF_HjH40fmvj5HKA8nOEL_LG7fFy3p4BPiVAeUqeUvw" +my_user_token = os.environ.get("USER_TOKEN") auth_service.initialize( - "66ba9476-0700-4178-81ea-fbeb7097c28e", - "de2656e6-585f-4684-8e65-3ce50a7770a8") + tenant_id, + client_id, + multitenant=True + ) def test_auth_header_has_token(): headers_with_auth = MultiDict([("Authorization", f'Bearer {user_token}'), ("Content-Type", "application/json")]) @@ -80,3 +87,36 @@ def test_can_retrieve_application_token_claims(): assert claims != None, "Retrieved token claims successfully!" +@pytest.mark.asyncio +async def test_requires_auth_decorator(): + headers_with_auth = MultiDict([("Authorization", f'Bearer {my_user_token}'), ("Content-Type", "application/json")]) + request = Request + request.headers = headers_with_auth + + # a dummy function that is to be decorated + async def test_func(request: Request)-> str: + return "test" + + result = requires_auth(test_func) + response = await result(request=request) + if isinstance(response, Response): + print(f"auth resulted in {response.body}") + assert response == "test", "Requires auth decorator works!" + + +@pytest.mark.asyncio +async def test_can_find_user_but_from_different_tenant(): + other_tenant_token = os.environ.get("OTHER_TENANT_TOKEN") + headers_with_auth = MultiDict([("Authorization", f'Bearer {other_tenant_token}'), ("Content-Type", "application/json")]) + request = Request + request.headers = headers_with_auth + + # a dummy function that is to be decorated + async def test_func(request: Request)-> str: + return "test" + + result = requires_auth(test_func) + response = await result(request=request) + if isinstance(response, Response): + print(f"auth resulted in {response.body}") + assert response == "test", "Requires auth decorator works!"