Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
closes #1
  • Loading branch information
alukach committed Aug 22, 2024
1 parent 2c65bb0 commit 4c1cdd7
Show file tree
Hide file tree
Showing 6 changed files with 259 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Lint
on:
push:
paths:
- "eoapi/**"
- "**/*.py"

jobs:
pre-commit:
Expand Down
25 changes: 25 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
name: Test

on:
push:
paths:
- "**/*.py"
- "pyproject.toml"

jobs:
pytest:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: "pip"

- name: Install dependencies
run: pip install -e ".[testing]"

- name: Run tests
run: pytest
Empty file added eoapi/__init__.py
Empty file.
19 changes: 11 additions & 8 deletions eoapi/auth_utils/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,28 +69,31 @@ def create_auth_token_dependency(
"""

def auth_token(
token_str: Annotated[str, Security(auth_scheme)],
auth_header: Annotated[str, Security(auth_scheme)],
required_scopes: security.SecurityScopes,
):
token_parts = token_str.split(" ")
# Extract token from header
token_parts = auth_header.split(" ")
if len(token_parts) != 2 or token_parts[0].lower() != "bearer":
logger.error(f"Invalid token: {auth_header}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authorization header",
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
else:
[_, token] = token_parts
[_, token] = token_parts

# Parse & validate token
try:
key = jwks_client.get_signing_key_from_jwt(token).key
payload = jwt.decode(
token,
jwks_client.get_signing_key_from_jwt(token).key,
key,
algorithms=["RS256"],
# NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40)
audience=allowed_jwt_audiences,
)
except jwt.exceptions.InvalidTokenError as e:
except (jwt.exceptions.InvalidTokenError, jwt.exceptions.DecodeError) as e:
logger.exception(f"InvalidTokenError: {e=}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
Expand Down Expand Up @@ -124,7 +127,7 @@ def apply_auth_dependencies(
"""
# Ignore paths without dependants, e.g. /api, /api.html, /docs/oauth2-redirect
if not hasattr(api_route, "dependant"):
logger.warn(
logger.warning(
f"Route {api_route} has no dependant, not apply auth dependency"
)
return
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ authors = [
{name = "Anthony Lukach", email = "[email protected]"},
]
dependencies = [
"cryptography>=43.0.0",
"fastapi-slim>=0.111.0",
"pydantic-settings>=2.2.1",
"pyjwt>=2.9.0",
"cryptography>=43.0.0",
]
description = "Authentication & authorization helpers for eoAPI"
dynamic = ["version"]
Expand All @@ -32,6 +32,8 @@ lint = [
"pre-commit",
]
testing = [
"pytest>=6.0",
"coverage",
"httpx>=0.27.0",
"jwcrypto>=1.5.6",
"pytest>=6.0",
]
218 changes: 218 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
import json
from typing import Any, Dict
from unittest.mock import MagicMock, patch

import jwt
import pytest
from cryptography.hazmat.primitives.asymmetric import rsa
from fastapi import FastAPI, HTTPException, Security, status, testclient
from jwcrypto.jwt import JWK, JWT

from eoapi.auth_utils import OpenIdConnectAuth


@pytest.fixture
def test_key() -> "JWK":
return JWK.generate(
kty="RSA", size=2048, kid="test", use="sig", e="AQAB", alg="RS256"
)


@pytest.fixture
def public_key(test_key: "JWK") -> Dict[str, Any]:
return test_key.export_public(as_dict=True)


@pytest.fixture
def private_key(test_key: "JWK") -> Dict[str, Any]:
return test_key.export_private(as_dict=True)


@pytest.fixture(autouse=True)
def mock_jwks(public_key: "rsa.RSAPrivateKey"):
mock_oidc_config = {"jwks_uri": "https://example.com/jwks"}

mock_jwks = {"keys": [public_key]}

with (
patch("urllib.request.urlopen") as mock_urlopen,
patch("jwt.PyJWKClient.fetch_data") as mock_fetch_data,
):
mock_oidc_config_response = MagicMock()
mock_oidc_config_response.read.return_value = json.dumps(
mock_oidc_config
).encode()
mock_oidc_config_response.status = 200

mock_urlopen.return_value.__enter__.return_value = mock_oidc_config_response
mock_fetch_data.return_value = mock_jwks
yield mock_urlopen


@pytest.fixture
def token_builder(test_key: "JWK"):
def build_token(payload: Dict[str, Any], key=None) -> str:
jwt_token = JWT(
header={k: test_key.get(k) for k in ["alg", "kid"]},
claims=payload,
)
jwt_token.make_signed_token(key or test_key)
return jwt_token.serialize()

return build_token


@pytest.fixture
def test_app():
app = FastAPI()

@app.get("/test-route")
def test():
return {"message": "Hello World"}

return app


@pytest.fixture
def test_client(test_app):
return testclient.TestClient(test_app)


def test_oidc_auth_initialization(mock_jwks: MagicMock):
"""
Auth object is initialized with the correct dependencies.
"""
openid_configuration_url = "https://example.com/.well-known/openid-configuration"
auth = OpenIdConnectAuth(openid_configuration_url=openid_configuration_url)
assert auth.jwks_client is not None
assert auth.auth_scheme is not None
assert auth.valid_token_dependency is not None
mock_jwks.assert_called_once_with(openid_configuration_url)


def test_auth_token_valid(token_builder):
"""
Auth token dependency returns the token payload when the token is valid.
"""
token = token_builder({"scope": "test_scope"})

auth = OpenIdConnectAuth(
openid_configuration_url="https://example.com/.well-known/openid-configuration"
)

token_payload = auth.valid_token_dependency(
auth_header=f"Bearer {token}", required_scopes=Security([])
)
assert token_payload["scope"] == "test_scope"


def test_auth_token_invalid_audience(token_builder):
"""
Auth token dependency throws 401 when the token audience is invalid.
"""
token = token_builder({"scope": "test_scope", "aud": "test_audience"})

auth = OpenIdConnectAuth(
openid_configuration_url="https://example.com/.well-known/openid-configuration"
)

with pytest.raises(HTTPException) as exc_info:
auth.valid_token_dependency(
auth_header=f"Bearer {token}", required_scopes=Security([])
)

assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.detail == "Could not validate credentials"
assert isinstance(exc_info.value.__cause__, jwt.exceptions.InvalidAudienceError)


def test_auth_token_invalid_signature(token_builder):
"""
Auth token dependency throws 401 when the token signature is invalid.
"""
other_key = JWK.generate(
kty="RSA", size=2048, kid="test", use="sig", e="AQAB", alg="RS256"
)
token = token_builder({"scope": "test_scope", "aud": "test_audience"}, other_key)

auth = OpenIdConnectAuth(
openid_configuration_url="https://example.com/.well-known/openid-configuration"
)

with pytest.raises(HTTPException) as exc_info:
auth.valid_token_dependency(
auth_header=f"Bearer {token}", required_scopes=Security([])
)

assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.detail == "Could not validate credentials"
assert isinstance(exc_info.value.__cause__, jwt.exceptions.InvalidSignatureError)


@pytest.mark.parametrize(
"token",
[
"foo",
"Bearer foo",
"Bearer foo.bar.xyz",
"Basic foo",
],
)
def test_auth_token_invalid_token(token):
"""
Auth token dependency throws 401 when the token is invalid.
"""
auth = OpenIdConnectAuth(
openid_configuration_url="https://example.com/.well-known/openid-configuration"
)

with pytest.raises(HTTPException) as exc_info:
auth.valid_token_dependency(auth_header=token, required_scopes=Security([]))

assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.detail == "Could not validate credentials"


def test_apply_auth_dependencies(test_app, test_client):
auth = OpenIdConnectAuth(
openid_configuration_url="https://example.com/.well-known/openid-configuration"
)

for route in test_app.routes:
auth.apply_auth_dependencies(
api_route=route, required_token_scopes=["test_scope"]
)

resp = test_client.get("/test-route")
assert resp.json() == {"detail": "Not authenticated"}
assert resp.status_code == status.HTTP_403_FORBIDDEN


@pytest.mark.parametrize(
"required_sent_response",
[
("a", "b", status.HTTP_401_UNAUTHORIZED),
("a b c", "a b", status.HTTP_401_UNAUTHORIZED),
("a", "a", status.HTTP_200_OK),
(None, None, status.HTTP_200_OK),
(None, "a", status.HTTP_200_OK),
("a b c", "d c b a", status.HTTP_200_OK),
],
)
def test_reject_wrong_scope(
test_app, test_client, token_builder, required_sent_response
):
auth = OpenIdConnectAuth(
openid_configuration_url="https://example.com/.well-known/openid-configuration"
)

scope_required, scope_sent, expected_status = required_sent_response
for route in test_app.routes:
auth.apply_auth_dependencies(
api_route=route,
required_token_scopes=scope_required.split(" ") if scope_required else None,
)

token = token_builder({"scope": scope_sent})
resp = test_client.get("/test-route", headers={"Authorization": f"Bearer {token}"})
assert resp.status_code == expected_status

0 comments on commit 4c1cdd7

Please sign in to comment.