Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added flag, whether or not a multi tenant app and added tests to test… #17

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 38 additions & 20 deletions fastapi_microsoft_identity/auth_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,20 @@
client_id=None
b2c_policy_name = None
b2c_domain_name = None
multitenant = False

def initialize(
tenant_id_,
client_id_,
b2c_policy_name_=None,
b2c_domain_name_=None):
global tenant_id, client_id, b2c_policy_name, b2c_domain_name
b2c_domain_name_=None, **kwargs):
global tenant_id, client_id, b2c_policy_name, b2c_domain_name, multitenant
tenant_id = tenant_id_
client_id = client_id_
b2c_policy_name = b2c_policy_name_
b2c_domain_name = b2c_domain_name_
multitenant = kwargs.get('multitenant') or False


class AuthError(Exception):
def __init__(self, error_msg:str, status_code:int):
Expand Down Expand Up @@ -91,12 +94,10 @@ async def decorated(*args, **kwargs):
try:
token = get_token_auth_header(kwargs["request"])
url = f'https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys'

async with httpx.AsyncClient() as client:
resp: Response = await client.get(url)
if resp.status_code != 200:
raise AuthError("Problem with Azure AD discovery URL", status_code=404)

jwks = resp.json()
unverified_header = jwt.get_unverified_header(token)
rsa_key = {}
Expand All @@ -109,18 +110,19 @@ async def decorated(*args, **kwargs):
"n": key["n"],
"e": key["e"]
}
except Exception:
except Exception as e:
return fastapi.Response(content="Invalid_header: Unable to parse authentication", status_code= 401)
if rsa_key:
try :
token_version = __get_token_version(token)
__decode_JWT(token_version, token, rsa_key)
return await f(*args, **kwargs)
except AuthError as auth_err:
fastapi.Response(content=auth_err.error_msg, status_code=auth_err.status_code)
return fastapi.Response(content=auth_err.error_msg, status_code=auth_err.status_code)
return fastapi.Response(content="Invalid header error: Unable to find appropriate key", status_code=401)
return decorated


def requires_b2c_auth(f):
@wraps(f)
async def decorated(*args, **kwargs):
Expand Down Expand Up @@ -163,13 +165,21 @@ def __decode_B2C_JWT(token_version, token, rsa_key):
else:
_issuer = f'https://{b2c_domain_name}.b2clogin.com/{tenant_id}/v2.0'.lower()
try:
payload = jwt.decode(
token,
rsa_key,
algorithms=["RS256"],
audience=client_id,
issuer=_issuer
)
if multitenant:
payload = jwt.decode(
token,
rsa_key,
algorithms=["RS256"],
audience=client_id
)
else:
payload = jwt.decode(
token,
rsa_key,
algorithms=["RS256"],
audience=client_id,
issuer=_issuer
)
except jwt.ExpiredSignatureError:
raise AuthError("Token error: The token has expired", 401)
except jwt.JWTClaimsError:
Expand All @@ -185,13 +195,21 @@ def __decode_JWT(token_version, token, rsa_key):
_issuer = f'https://login.microsoftonline.com/{tenant_id}/v2.0'
_audience=f'{client_id}'
try:
payload = jwt.decode(
token,
rsa_key,
algorithms=["RS256"],
audience=_audience,
issuer=_issuer
)
if multitenant:
payload = jwt.decode(
token,
rsa_key,
algorithms=["RS256"],
audience=_audience
)
else:
payload = jwt.decode(
token,
rsa_key,
algorithms=["RS256"],
audience=_audience,
issuer=_issuer
)
except jwt.ExpiredSignatureError:
raise AuthError("Token error: The token has expired", 401)
except jwt.JWTClaimsError:
Expand Down
48 changes: 44 additions & 4 deletions tests/test_azure_ad.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
# Sample Test passing with nose and pytest
import pytest
from fastapi import Request
from fastapi import Request, Response
import sys
import os


container_folder = os.path.abspath(os.path.join(
os.path.dirname(__file__), '..'
))
sys.path.insert(0, container_folder)

from fastapi_microsoft_identity import auth_service, AuthError
from fastapi_microsoft_identity import auth_service, AuthError, requires_auth
from multidict import MultiDict

client_id = os.environ.get('CLIENT_ID') or "66ba9476-0700-4178-81ea-fbeb7097c28e"
tenant_id = os.environ.get('TENANT_ID') or "de2656e6-585f-4684-8e65-3ce50a7770a8"

user_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Imk2bEdrM0ZaenhSY1ViMkMzbkVRN3N5SEpsWSJ9.eyJhdWQiOiI2ZTc0MTcyYi1iZTU2LTQ4NDMtOWZmNC1lNjZhMzliYjEyZTMiLCJpc3MiOiJodHRwczovL2xvZ2luLm1pY3Jvc29mdG9ubGluZS5jb20vNzJmOTg4YmYtODZmMS00MWFmLTkxYWItMmQ3Y2QwMTFkYjQ3L3YyLjAiLCJpYXQiOjE1MzcyMzEwNDgsIm5iZiI6MTUzNzIzMTA0OCwiZXhwIjoxNTM3MjM0OTQ4LCJhaW8iOiJBWFFBaS84SUFBQUF0QWFaTG8zQ2hNaWY2S09udHRSQjdlQnE0L0RjY1F6amNKR3hQWXkvQzNqRGFOR3hYZDZ3TklJVkdSZ2hOUm53SjFsT2NBbk5aY2p2a295ckZ4Q3R0djMzMTQwUmlvT0ZKNGJDQ0dWdW9DYWcxdU9UVDIyMjIyZ0h3TFBZUS91Zjc5UVgrMEtJaWpkcm1wNjlSY3R6bVE9PSIsImF6cCI6IjZlNzQxNzJiLWJlNTYtNDg0My05ZmY0LWU2NmEzOWJiMTJlMyIsImF6cGFjciI6IjAiLCJuYW1lIjoiQWJlIExpbmNvbG4iLCJvaWQiOiI2OTAyMjJiZS1mZjFhLTRkNTYtYWJkMS03ZTRmN2QzOGU0NzQiLCJwcmVmZXJyZWRfdXNlcm5hbWUiOiJhYmVsaUBtaWNyb3NvZnQuY29tIiwicmgiOiJJIiwic2NwIjoiYWNjZXNzX2FzX3VzZXIiLCJzdWIiOiJIS1pwZmFIeVdhZGVPb3VZbGl0anJJLUtmZlRtMjIyWDVyclYzeERxZktRIiwidGlkIjoiNzJmOTg4YmYtODZmMS00MWFmLTkxYWItMmQ3Y2QwMTFkYjQ3IiwidXRpIjoiZnFpQnFYTFBqMGVRYTgyUy1JWUZBQSIsInZlciI6IjIuMCJ9.pj4N-w_3Us9DrBLfpCt"
application_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ik1yNS1BVWliZkJpaTdOZDFqQmViYXhib1hXMCJ9.eyJhdWQiOiJkZTI2NTZlNi01ODVmLTQ2ODQtOGU2NS0zY2U1MGE3NzcwYTgiLCJpc3MiOiJodHRwczovL2xvZ2luLm1pY3Jvc29mdG9ubGluZS5jb20vNjZiYTk0NzYtMDcwMC00MTc4LTgxZWEtZmJlYjcwOTdjMjhlL3YyLjAiLCJpYXQiOjE2NDYxNjc1NzIsIm5iZiI6MTY0NjE2NzU3MiwiZXhwIjoxNjQ2MTcxNDcyLCJhaW8iOiJFMlpnWVBoN2RoSFBLNGNuOFE1TUFob01CenAwQVE9PSIsImF6cCI6ImY3NTllY2FiLWM0NWUtNDVlZS1hYWZmLWJlMzJhZmM3ZGU5YiIsImF6cGFjciI6IjEiLCJvaWQiOiIyMTJkOGM2ZS05YzdmLTQ4MWEtOGZkOC1kOTllMzVhOWNiMWMiLCJyaCI6IjAuQVZBQWRwUzZaZ0FIZUVHQjZ2dnJjSmZDanVaV0p0NWZXSVJHam1VODVRcDNjS2hfQUFBLiIsInJvbGVzIjpbImFwcC53ZWF0aGVyLnJlYWQiXSwic3ViIjoiMjEyZDhjNmUtOWM3Zi00ODFhLThmZDgtZDk5ZTM1YTljYjFjIiwidGlkIjoiNjZiYTk0NzYtMDcwMC00MTc4LTgxZWEtZmJlYjcwOTdjMjhlIiwidXRpIjoiX1lYOWhSbElvMGVwX2c3bk9KeXpBUSIsInZlciI6IjIuMCJ9.omq5Abe7rObD_-NDZ64KB3hf3pfCOCS4Sk3cz-jA_4cd49zwzq7wOI8CtXq5vhLUpbwRGCGiZqG-WYmTrTmDwNn2KcsEL8SQkKK5FCOriit8PrDVBAbidAAZsp8OgchhuNBdzp4wUUB7X3cQPk2g6XVOchqvw6MJZVFxi8r5Kqxq8AMJJlHO-ijUX5qKRcrIHkhezFjtGs-TV1dgdpGshKcWhpA635ehRFigY0Hry6vyYaPuiwufp2iMXJ1ZT6ZHqFIE_HeQNLTo39zV5CzVQ4UHH9gDMHfqSbEEO79JyZfNF_HjH40fmvj5HKA8nOEL_LG7fFy3p4BPiVAeUqeUvw"
my_user_token = os.environ.get("USER_TOKEN")

auth_service.initialize(
"66ba9476-0700-4178-81ea-fbeb7097c28e",
"de2656e6-585f-4684-8e65-3ce50a7770a8")
tenant_id,
client_id,
multitenant=True
)

def test_auth_header_has_token():
headers_with_auth = MultiDict([("Authorization", f'Bearer {user_token}'), ("Content-Type", "application/json")])
Expand Down Expand Up @@ -80,3 +87,36 @@ def test_can_retrieve_application_token_claims():
assert claims != None, "Retrieved token claims successfully!"


@pytest.mark.asyncio
async def test_requires_auth_decorator():
headers_with_auth = MultiDict([("Authorization", f'Bearer {my_user_token}'), ("Content-Type", "application/json")])
request = Request
request.headers = headers_with_auth

# a dummy function that is to be decorated
async def test_func(request: Request)-> str:
return "test"

result = requires_auth(test_func)
response = await result(request=request)
if isinstance(response, Response):
print(f"auth resulted in {response.body}")
assert response == "test", "Requires auth decorator works!"


@pytest.mark.asyncio
async def test_can_find_user_but_from_different_tenant():
other_tenant_token = os.environ.get("OTHER_TENANT_TOKEN")
headers_with_auth = MultiDict([("Authorization", f'Bearer {other_tenant_token}'), ("Content-Type", "application/json")])
request = Request
request.headers = headers_with_auth

# a dummy function that is to be decorated
async def test_func(request: Request)-> str:
return "test"

result = requires_auth(test_func)
response = await result(request=request)
if isinstance(response, Response):
print(f"auth resulted in {response.body}")
assert response == "test", "Requires auth decorator works!"