diff --git a/app/api/routers/query.py b/app/api/routers/query.py index c5a526a..d6f0e5e 100644 --- a/app/api/routers/query.py +++ b/app/api/routers/query.py @@ -1,12 +1,32 @@ """Router for query path operations.""" -from fastapi import APIRouter, Depends, Response, status +from fastapi import APIRouter, Depends, HTTPException, Response, status +from fastapi.security import OAuth2 -from .. import crud +from .. import crud, security from ..models import CombinedQueryResponse, QueryModel +from ..security import verify_token + +# from fastapi.security import open_id_connect_url + router = APIRouter(prefix="/query", tags=["query"]) +# Adapted from info in https://github.com/tiangolo/fastapi/discussions/9137#discussioncomment-5157382 +oauth2_scheme = OAuth2( + flows={ + "implicit": { + "authorizationUrl": "https://accounts.google.com/o/oauth2/auth", + } + }, + # Don't automatically error out when request is not authenticated, to support optional authentication + auto_error=False, +) +# NOTE: Can also explicitly use OpenID Connect because Google supports it - results in the same behavior as the OAuth2 scheme above. +# openid_connect_scheme = open_id_connect_url.OpenIdConnect( +# openIdConnectUrl="https://accounts.google.com/.well-known/openid-configuration" +# ) + # We use the Response parameter below to change the status code of the response while still being able to validate the returned data using the response model. # (see https://fastapi.tiangolo.com/advanced/response-change-status-code/ for more info). @@ -16,9 +36,33 @@ # example responses for different status codes in the OpenAPI docs (less relevant for now since there is only one response model). @router.get("/", response_model=CombinedQueryResponse) async def get_query( - response: Response, query: QueryModel = Depends(QueryModel) + response: Response, + query: QueryModel = Depends(QueryModel), + token: str | None = Depends(oauth2_scheme), ): """When a GET request is sent, return list of dicts corresponding to subject-level metadata aggregated by dataset.""" + # NOTE: Currently, when the request is unauthenticated (missing or malformed authorization header -> missing token), + # the default response is a 403 Forbidden error. + # This doesn't fully align with HTTP status code conventions: + # - 401 Unauthorized should be used when the client lacks authentication credentials + # - 403 Forbidden should be used when the client has been authenticated but lacks the required permissions + # If we really care about returning a 401 Unauthorized error, we can use auto_error=False + # when creating the OAuth2 object and raise a custom HTTPException. + # See also https://github.com/tiangolo/fastapi/discussions/9130 + # if not token: + # raise HTTPException( + # status_code=status.HTTP_401_UNAUTHORIZED, + # detail="Not authenticated", + # headers={"WWW-Authenticate": "Bearer"}, + # ) + if security.AUTH_ENABLED: + if token is None: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authenticated", + ) + verify_token(token) + response_dict = await crud.get( query.min_age, query.max_age, diff --git a/app/api/security.py b/app/api/security.py new file mode 100644 index 0000000..cb159a7 --- /dev/null +++ b/app/api/security.py @@ -0,0 +1,43 @@ +import os + +from fastapi import HTTPException, status +from fastapi.security.utils import get_authorization_scheme_param +from google.auth.exceptions import GoogleAuthError +from google.auth.transport import requests +from google.oauth2 import id_token + +AUTH_ENABLED = os.environ.get("NB_ENABLE_AUTH", "True").lower() == "true" +CLIENT_ID = os.environ.get("NB_QUERY_CLIENT_ID", None) + + +def check_client_id(): + """Check if the CLIENT_ID environment variable is set.""" + # By default, if CLIENT_ID is not provided to verify_oauth2_token, + # Google will simply skip verifying the audience claim of ID tokens. + # This however can be a security risk, so we mandate that CLIENT_ID is set. + if AUTH_ENABLED and CLIENT_ID is None: + raise ValueError( + "Authentication has been enabled (NB_ENABLE_AUTH) but the environment variable NB_QUERY_CLIENT_ID is not set. " + "Please set NB_QUERY_CLIENT_ID to the Google client ID for your Neurobagel query tool deployment, to verify the audience claim of ID tokens." + ) + + +def verify_token(token: str): + """Verify the Google ID token. Raise an HTTPException if the token is invalid.""" + # Adapted from https://developers.google.com/identity/gsi/web/guides/verify-google-id-token#python + try: + # Extract the token from the "Bearer" scheme + # (See https://github.com/tiangolo/fastapi/blob/master/fastapi/security/oauth2.py#L473-L485) + # TODO: Check also if scheme of token is "Bearer"? + _, param = get_authorization_scheme_param(token) + id_info = id_token.verify_oauth2_token( + param, requests.Request(), CLIENT_ID + ) + # TODO: Remove print statement or turn into logging + print("Token verified: ", id_info) + except (GoogleAuthError, ValueError) as exc: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"Invalid token: {exc}", + headers={"WWW-Authenticate": "Bearer"}, + ) from exc diff --git a/app/main.py b/app/main.py index e88281a..6e6de54 100644 --- a/app/main.py +++ b/app/main.py @@ -11,6 +11,7 @@ from .api import utility as util from .api.routers import attributes, nodes, query +from .api.security import check_client_id logger = logging.getLogger("nb-f-API") stdout_handler = logging.StreamHandler() @@ -26,6 +27,7 @@ async def lifespan(app: FastAPI): """ Collect and store locally defined and public node details for federation upon startup and clears the index upon shutdown. """ + check_client_id() await util.create_federation_node_index() yield util.FEDERATION_NODES.clear() diff --git a/requirements.txt b/requirements.txt index 6314e26..ed881d1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,17 @@ anyio==3.6.2 attrs==23.1.0 +cachetools==5.3.3 certifi==2023.7.22 cfgv==3.3.1 +charset-normalizer==3.3.2 click==8.1.3 +colorama==0.4.6 coverage==7.0.0 distlib==0.3.6 exceptiongroup==1.0.4 fastapi==0.95.2 filelock==3.8.0 +google-auth==2.32.0 h11==0.14.0 httpcore==0.16.2 httpx==0.23.1 @@ -22,23 +26,28 @@ orjson==3.8.6 packaging==21.3 pandas==1.5.2 platformdirs==2.5.4 -pluggy==1.5.0 +pluggy==1.0.0 pre-commit==2.20.0 +pyasn1==0.6.0 +pyasn1_modules==0.4.0 pydantic==1.10.2 pyparsing==3.0.9 -pytest==8.2.1 +pytest==7.2.0 pytest-asyncio==0.23.7 python-dateutil==2.8.2 pytz==2023.3.post1 PyYAML==6.0 referencing==0.31.1 +requests==2.31.0 rfc3986==1.5.0 rpds-py==0.13.2 +rsa==4.9 six==1.16.0 sniffio==1.3.0 starlette==0.27.0 toml==0.10.2 tomli==2.0.1 typing_extensions==4.4.0 +urllib3==2.2.0 uvicorn==0.20.0 virtualenv==20.16.7 diff --git a/tests/conftest.py b/tests/conftest.py index a8b4b5f..3d67611 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,33 @@ def test_app(): yield client +@pytest.fixture +def disable_auth(monkeypatch): + """ + Disable the authentication requirement for the API to skip startup checks + (for when the tested route does not require authentication). + """ + monkeypatch.setattr("app.api.security.AUTH_ENABLED", False) + + +@pytest.fixture() +def mock_verify_token(): + """Mock a successful token verification that does not raise any exceptions.""" + + def _verify_token(token): + return None + + return _verify_token + + +@pytest.fixture() +def set_mock_verify_token(monkeypatch, mock_verify_token): + """Set the verify_token function to a mock that does not raise any exceptions.""" + monkeypatch.setattr( + "app.api.routers.query.verify_token", mock_verify_token + ) + + @pytest.fixture(scope="function") def set_valid_test_federation_nodes(monkeypatch): """Set two correctly formatted federation nodes for a test function (mocks the result of reading/parsing available public and local nodes on startup).""" diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 3f9506f..56eb42d 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -14,7 +14,9 @@ }, ], ) -def test_nodes_discovery_endpoint(test_app, monkeypatch, local_nodes): +def test_nodes_discovery_endpoint( + test_app, monkeypatch, local_nodes, disable_auth +): """Test that a federation node index is correctly created from locally set and remote node lists.""" def mock_parse_nodes_as_dict(path): @@ -59,7 +61,7 @@ def mock_httpx_get(**kwargs): def test_failed_public_nodes_fetching_raises_warning( - test_app, monkeypatch, caplog + test_app, monkeypatch, disable_auth, caplog ): """Test that when request for remote list of public nodes fails, an informative warning is raised and the federation node index only includes local nodes.""" @@ -95,7 +97,7 @@ def mock_httpx_get(**kwargs): assert warn_substr in caplog.text -def test_unset_local_nodes_raises_warning(test_app, monkeypatch): +def test_unset_local_nodes_raises_warning(test_app, monkeypatch, disable_auth): """Test that when no local nodes are set, an informative warning is raised and the federation node index only includes remote nodes.""" def mock_parse_nodes_as_dict(path): @@ -166,7 +168,9 @@ def test_missing_local_nodes_file_does_not_raise_error(tmp_path): assert util.parse_nodes_as_dict(expected_file_path) == {} -def test_no_available_nodes_raises_error(monkeypatch, test_app, caplog): +def test_no_available_nodes_raises_error( + monkeypatch, test_app, disable_auth, caplog +): """Test that when no local or remote nodes are available, an informative error is raised.""" def mock_parse_nodes_as_dict(path): diff --git a/tests/test_query.py b/tests/test_query.py index b731e91..7693592 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -6,6 +6,12 @@ from fastapi import status +@pytest.fixture() +def mock_token(): + """Create a mock token that is well-formed for testing purposes.""" + return "Bearer foo" + + @pytest.fixture() def mocked_single_matching_dataset_result(): """Valid aggregate query result for a single matching dataset.""" @@ -29,6 +35,8 @@ def test_partial_node_failure_responses_handled_gracefully( test_app, set_valid_test_federation_nodes, mocked_single_matching_dataset_result, + mock_token, + set_mock_verify_token, caplog, ): """ @@ -50,7 +58,10 @@ async def mock_httpx_get(self, **kwargs): monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get) - response = test_app.get("/query/") + response = test_app.get( + "/query/", + headers={"Authorization": mock_token}, + ) assert response.status_code == status.HTTP_207_MULTI_STATUS assert response.json() == { @@ -104,6 +115,8 @@ def test_partial_node_request_failures_handled_gracefully( test_app, set_valid_test_federation_nodes, mocked_single_matching_dataset_result, + mock_token, + set_mock_verify_token, error_to_raise, expected_node_message, caplog, @@ -123,7 +136,10 @@ async def mock_httpx_get(self, **kwargs): monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get) - response = test_app.get("/query/") + response = test_app.get( + "/query/", + headers={"Authorization": mock_token}, + ) assert response.status_code == status.HTTP_207_MULTI_STATUS @@ -153,6 +169,8 @@ def test_all_nodes_failure_handled_gracefully( monkeypatch, test_app, mock_failed_connection_httpx_get, + mock_token, + set_mock_verify_token, set_valid_test_federation_nodes, caplog, ): @@ -164,7 +182,10 @@ def test_all_nodes_failure_handled_gracefully( httpx.AsyncClient, "get", mock_failed_connection_httpx_get ) - response = test_app.get("/query/") + response = test_app.get( + "/query/", + headers={"Authorization": mock_token}, + ) # We expect 3 logs here: one warning for each failed node, and one error for the overall failure assert len(caplog.records) == 3 @@ -186,6 +207,8 @@ def test_all_nodes_success_handled_gracefully( caplog, set_valid_test_federation_nodes, mocked_single_matching_dataset_result, + mock_token, + set_mock_verify_token, ): """ Test that when queries sent to all nodes succeed, the federation API response includes an overall success status and no errors. @@ -201,7 +224,10 @@ async def mock_httpx_get(self, **kwargs): monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get) - response = test_app.get("/query/") + response = test_app.get( + "/query/", + headers={"Authorization": mock_token}, + ) assert response.status_code == status.HTTP_200_OK @@ -210,3 +236,26 @@ async def mock_httpx_get(self, **kwargs): assert response["errors"] == [] assert len(response["responses"]) == 2 assert "Requests to all nodes succeeded (2/2)" in caplog.text + + +def test_query_without_token_succeeds_when_auth_disabled( + monkeypatch, + test_app, + set_valid_test_federation_nodes, + mocked_single_matching_dataset_result, + disable_auth, +): + """ + Test that when authentication is disabled, a federated query request without a token succeeds. + """ + + async def mock_httpx_get(self, **kwargs): + return httpx.Response( + status_code=200, json=[mocked_single_matching_dataset_result] + ) + + monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get) + + response = test_app.get("/query/") + + assert response.status_code == status.HTTP_200_OK diff --git a/tests/test_security.py b/tests/test_security.py new file mode 100644 index 0000000..b657168 --- /dev/null +++ b/tests/test_security.py @@ -0,0 +1,64 @@ +import pytest +from fastapi import HTTPException + +from app.api.security import verify_token + + +def test_missing_client_id_raises_error_when_auth_enabled( + monkeypatch, test_app +): + """Test that a missing client ID raises an error on startup when authentication is enabled.""" + # We're using what should be default values of CLIENT_ID and AUTH_ENABLED here + # (if the corresponding environment variables are unset), + # but we set the values explicitly here for clarity + monkeypatch.setattr("app.api.security.CLIENT_ID", None) + monkeypatch.setattr("app.api.security.AUTH_ENABLED", True) + + with pytest.raises(ValueError) as exc_info: + with test_app: + pass + + assert "NB_QUERY_CLIENT_ID is not set" in str(exc_info.value) + + +# Ignore startup warning that is unrelated to the current test +@pytest.mark.filterwarnings( + "ignore:No local Neurobagel nodes defined or found" +) +def test_missing_client_id_ignored_when_auth_disabled(monkeypatch, test_app): + """Test that a missing client ID does not raise an error when authentication is disabled.""" + monkeypatch.setattr("app.api.security.CLIENT_ID", None) + monkeypatch.setattr("app.api.security.AUTH_ENABLED", False) + + with test_app: + pass + + +@pytest.mark.parametrize( + "invalid_token", + ["Bearer faketoken", "Bearer", "faketoken", "fakescheme faketoken"], +) +def test_invalid_token_raises_error(invalid_token): + """Test that an invalid token raises an error from the verification process.""" + with pytest.raises(HTTPException) as exc_info: + verify_token(invalid_token) + + assert exc_info.value.status_code == 401 + assert "Invalid token" in exc_info.value.detail + + +@pytest.mark.parametrize( + "invalid_auth_header", + [{}, {"Authorization": ""}, {"badheader": "badvalue"}], +) +def test_query_with_malformed_auth_header_fails( + test_app, set_mock_verify_token, invalid_auth_header +): + """Test that a request to the /query route with a missing or malformed authorization header, fails .""" + + response = test_app.get( + "/query/", + headers=invalid_auth_header, + ) + + assert response.status_code == 403