diff --git a/app/auth/security.py b/app/auth/security.py index 721bc12a..9e108617 100644 --- a/app/auth/security.py +++ b/app/auth/security.py @@ -4,9 +4,9 @@ from passlib.hash import argon2 import jwt from jwt.exceptions import InvalidTokenError -from fastapi.security import OAuth2PasswordBearer -from fastapi import HTTPException, Depends, status -from app.models.users import User, TokenData, Token +from fastapi.security import OAuth2PasswordBearer, SecurityScopes +from fastapi import HTTPException, Depends, status, Security +from app.models.users import User, TokenData from app.config import Config from app.db import get_session from sqlmodel import Session @@ -16,7 +16,6 @@ ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 30 SECRET_KEY = Config.SECRET_KEY - oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @@ -61,12 +60,15 @@ def authenticate_user(session, username: str, password: str) -> str | User | boo return user -def create_access_token(data: dict, expires_delta: timedelta | None = None) -> Token: +def create_access_token( + data: dict, scopes: list | None = None, expires_delta: timedelta | None = None +) -> str: """ Creates the JWT access token with an expiry time. Args: data: A dictionary containing the username. + scopes: A list of scopes assigned to the user. expires_delta: A timedelta of the expiry time of the token. Returns: @@ -77,6 +79,8 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None) -> T the expiry field to ensure it can still be read as a standalone object. """ to_encode = data.copy() + scopes = scopes or [] + to_encode.update({"scopes": scopes}) if expires_delta: expire = datetime.now(timezone.utc) + expires_delta else: @@ -88,7 +92,12 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None) -> T return encoded_jwt +def token_decode(token: str) -> dict: + return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + + async def get_current_user( + security_scopes: SecurityScopes, token: Annotated[str, Depends(oauth2_scheme)], session: Annotated[Session, Depends(get_session)], ): @@ -96,7 +105,9 @@ async def get_current_user( Checks the current user token to return a user. Args: + security_scopes: Security scopes user should have access to. token: Uses the oauth2 scheme to get the current JWT. + session: Uses the session object to get the current user. Returns: user: Returns the current user object by verifying against the JWT. @@ -105,11 +116,19 @@ async def get_current_user( HTTP_Exception: If authentication fails, a HTTP 401 Unauthorised error is raised with a message indicating that the credentials could not be validated. """ + credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}, ) + + scopes_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not enough permissions", + headers={"WWW-Authenticate": f'Bearer scope="{security_scopes.scope_str}"'}, + ) + try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) username: str = payload.get("sub") @@ -120,15 +139,24 @@ async def get_current_user( user = session.get(User, token_data.username) if user is None: raise credentials_exception + + if security_scopes.scopes and not user.scopes: + raise scopes_exception + + for scope in security_scopes.scopes: + if scope not in user.scopes: + raise scopes_exception + return user async def get_current_active_user( - current_user: Annotated[User, Depends(get_current_user)], + current_user: Annotated[User, Security(get_current_user, scopes=[])], ): if current_user.disabled: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="User Disabled", ) + return current_user diff --git a/app/db/migrations/versions/d4223a91f1de_user_scopes.py b/app/db/migrations/versions/d4223a91f1de_user_scopes.py new file mode 100644 index 00000000..2f5f53a9 --- /dev/null +++ b/app/db/migrations/versions/d4223a91f1de_user_scopes.py @@ -0,0 +1,31 @@ +"""user scopes + +Revision ID: d4223a91f1de +Revises: c4b9d0057513 +Create Date: 2024-11-21 11:33:22.255372 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "d4223a91f1de" +down_revision: Union[str, None] = "c4b9d0057513" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("users", sa.Column("scopes", sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("users", "scopes") + # ### end Alembic commands ### diff --git a/app/models/users.py b/app/models/users.py index e1e5e1cd..ef968130 100644 --- a/app/models/users.py +++ b/app/models/users.py @@ -1,4 +1,22 @@ -from sqlmodel import Field, SQLModel +from sqlmodel import Field, SQLModel, JSON +from enum import Enum +from typing import List + + +class UserScopes(str, Enum): + READ = "read" + CREATE = "create" + UPDATE = "update" + DELETE = "delete" + + @classmethod + def as_list(cls): + # Iterate over the values only + return [member.value for member in cls] + + @classmethod + def as_dict(cls) -> dict: + return {member.name: member.value for member in cls} class Token(SQLModel): @@ -27,3 +45,4 @@ class User(SQLModel, table=True): email: str | None = None full_name: str | None = None disabled: bool = Field(default=False) + scopes: List[UserScopes] = Field(sa_type=JSON, default=[], nullable=True) diff --git a/app/routers/case_information.py b/app/routers/case_information.py index b95f03f5..b1786a16 100644 --- a/app/routers/case_information.py +++ b/app/routers/case_information.py @@ -1,18 +1,18 @@ +import structlog from typing import Sequence - -from fastapi import APIRouter, HTTPException, Depends - +from uuid import UUID +from fastapi import APIRouter, HTTPException, Security, Depends +from sqlmodel import Session, select from app.models.cases import ( CaseRequest, Case, CaseResponse, CaseUpdateRequest, ) -from sqlmodel import Session, select from app.db import get_session from app.auth.security import get_current_active_user -from uuid import UUID -import structlog +from app.models.users import UserScopes + logger = structlog.getLogger(__name__) @@ -25,7 +25,12 @@ ) -@router.get("/{case_id}", tags=["cases"], response_model=CaseResponse) +@router.get( + "/{case_id}", + tags=["cases"], + response_model=CaseResponse, + dependencies=[Security(get_current_active_user, scopes=[UserScopes.READ])], +) async def read_case(case_id: UUID, session: Session = Depends(get_session)) -> Case: case: Case | None = session.get(Case, case_id) if not case: @@ -33,13 +38,23 @@ async def read_case(case_id: UUID, session: Session = Depends(get_session)) -> C return case -@router.get("/", tags=["cases"]) +@router.get( + "/", + tags=["cases"], + dependencies=[Security(get_current_active_user, scopes=[UserScopes.READ])], +) async def read_all_cases(session: Session = Depends(get_session)) -> Sequence[Case]: cases = session.exec(select(Case)).all() return cases -@router.post("/", tags=["cases"], response_model=CaseResponse, status_code=201) +@router.post( + "/", + tags=["cases"], + response_model=CaseResponse, + status_code=201, + dependencies=[Security(get_current_active_user, scopes=[UserScopes.CREATE])], +) def create_case( request: CaseRequest, session: Session = Depends(get_session), @@ -50,7 +65,12 @@ def create_case( return case -@router.put("/{case_id}", tags=["cases"], response_model=CaseResponse) +@router.put( + "/{case_id}", + tags=["cases"], + response_model=CaseResponse, + dependencies=[Security(get_current_active_user, scopes=[UserScopes.UPDATE])], +) def update_case( case_id: UUID, request: CaseUpdateRequest, session: Session = Depends(get_session) ): diff --git a/app/routers/security.py b/app/routers/security.py index 97ac13d8..a099906f 100644 --- a/app/routers/security.py +++ b/app/routers/security.py @@ -48,6 +48,8 @@ async def login_for_access_token( ) access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) access_token = create_access_token( - data={"sub": user.username}, expires_delta=access_token_expires + data={"sub": user.username}, + expires_delta=access_token_expires, + scopes=user.scopes, ) return Token(access_token=str(access_token), token_type="bearer") diff --git a/bin/add_users.py b/bin/add_users.py index f9e1045a..b99fc0c6 100644 --- a/bin/add_users.py +++ b/bin/add_users.py @@ -5,7 +5,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from app.db import get_session -from app.models.users import User +from app.models.users import User, UserScopes from app.auth.security import get_password_hash import logging @@ -28,6 +28,7 @@ def add_users(users_list_dict: list[dict]): username = user_info.get("username") password = user_info.get("password") disabled = user_info.get("disabled") + scopes = user_info.get("scopes", []) if not username or not password: logging.warning( @@ -43,7 +44,10 @@ def add_users(users_list_dict: list[dict]): password = get_password_hash(password) new_user = User( - username=username, hashed_password=password, disabled=disabled + username=username, + hashed_password=password, + disabled=disabled, + scopes=scopes, ) session.add(new_user) @@ -51,7 +55,12 @@ def add_users(users_list_dict: list[dict]): users_to_add = [ - {"username": "cla_admin", "password": "cla_admin", "disabled": False}, + { + "username": "cla_admin", + "password": "cla_admin", + "disabled": False, + "scopes": UserScopes.as_list(), + }, {"username": "janedoe", "password": "password", "disabled": True}, ] diff --git a/docs/source/documentation/case.html.md.erb b/docs/source/documentation/case.html.md.erb index 38c01315..94e9f69b 100644 --- a/docs/source/documentation/case.html.md.erb +++ b/docs/source/documentation/case.html.md.erb @@ -9,6 +9,8 @@ title: Case model ``` POST /cases ``` +### Scope +create A case can be created using the following request schema @@ -118,18 +120,24 @@ You will receive the following response schema: ``` GET /cases/ ``` +### Scope +read ### Gets all case information for a given case id ``` GET /cases/{case_id} ``` +### Scope +read ### Modify a case ``` -PATCH /cases/{case_id} +PUT /cases/{case_id} ``` +### Scope +update A case can be modified by providing a new case with the following schema. diff --git a/docs/source/documentation/scopes.erb b/docs/source/documentation/scopes.erb new file mode 100644 index 00000000..a1fe17ec --- /dev/null +++ b/docs/source/documentation/scopes.erb @@ -0,0 +1,24 @@ +--- +title: Scopes +--- + +## Scopes +There are four scopes which each correspond to a http method: +- create +- read +- update +- delete + +Scopes are only assignable by the CLA team and you need to request which scopes you api requires as part of your account creation process + +### Adding/updating user scopes +The following overwrites the current user scopes with the ones given +`python manage.py user-scopes-add --scope=create --scope=read --scope=update` + +### Listing user scopes +To get a list of current scopes assign to a user do +`python manage.py user-scopes-list ` + +### List routes with scopes +To get a list of all the routes which includes their scopes do +`python manage.py routes-list` \ No newline at end of file diff --git a/manage.py b/manage.py new file mode 100755 index 00000000..b16150d9 --- /dev/null +++ b/manage.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python + +import json +import typer +from typing import List +from typing_extensions import Annotated +from sqlmodel import Session +from sqlmodel.sql.expression import select +from fastapi import Depends +from fastapi.params import Security +from app.models.users import User, UserScopes +from app.db import get_session +from app.main import create_app + +app = typer.Typer() + +session: Session = next(get_session()) + + +@app.command() +def user_scopes_add( + username: str, scope: Annotated[List[UserScopes], typer.Option()] +) -> None: + statement = select(User).where(User.username == username) + user: User = session.exec(statement).first() + print( + f"Replacing user {user.username} current scopes {user.scopes} with new scopes {scope}..." + ) + user.scopes = scope + session.add(user) + session.commit() + print("Done") + + +@app.command() +def user_scopes_list(username: str) -> None: + statement = select(User).where(User.username == username) + user: User = session.exec(statement).first() + if not user.scopes: + print(f"{user.username} has no scopes") + return + print(f"{user.username} has scopes {user.scopes}") + + +@app.command() +def routes_list(): + fastapi_app = create_app() + routes = {} + for route in fastapi_app.routes: + dependencies = getattr(route, "dependencies", []) + routes[route.path] = {"scopes": get_scopes_from_dependencies(dependencies)} + + output = json.dumps(routes, indent=4) + print(output) + + +def get_scopes_from_dependencies(dependencies: List[Depends]): + scopes = [] + for dependency in dependencies: + if isinstance(dependency, Security): + items = getattr(dependency, "scopes", []) + for item in items: + scopes.append(item) + return scopes + + +if __name__ == "__main__": + app() diff --git a/requirements/source/requirements-base.in b/requirements/source/requirements-base.in index 914cd4ef..da676e0f 100644 --- a/requirements/source/requirements-base.in +++ b/requirements/source/requirements-base.in @@ -11,3 +11,4 @@ pyjwt passlib argon2_cffi structlog +typer \ No newline at end of file diff --git a/tests/auth/test_auth.py b/tests/auth/test_auth.py index c32639ff..2f7ca988 100644 --- a/tests/auth/test_auth.py +++ b/tests/auth/test_auth.py @@ -1,16 +1,20 @@ +from typing import List + from fastapi.testclient import TestClient +from sqlmodel import Session from app.auth.security import ( create_access_token, verify_password, get_password_hash, authenticate_user, + token_decode, ACCESS_TOKEN_EXPIRE_MINUTES, ) from freezegun import freeze_time import pytest from jwt import ExpiredSignatureError from datetime import timedelta, datetime -from app.models.users import User +from app.models.users import User, UserScopes def test_auth_fail_case(client: TestClient): @@ -82,10 +86,12 @@ def test_password_hashing(): def test_create_token(): - jwt = create_access_token( - data={"sub": "cla_admin"}, expires_delta=timedelta(minutes=30) + token = create_access_token( + data={"sub": "cla_admin"}, expires_delta=timedelta(minutes=30), scopes=[] ) - assert len(jwt) == 129 + expected_keys = ["sub", "scopes", "exp"] + token_data = token_decode(token) + assert list(token_data.keys()) == expected_keys def test_token_with_no_expire(): @@ -101,9 +107,60 @@ def test_token_with_no_expire(): def test_token_defined_expiry(): with freeze_time("2024-08-23 10:00:00"): token = create_access_token( - data={"sub": "cla_admin"}, expires_delta=timedelta(minutes=1) + data={"sub": "cla_admin"}, + expires_delta=timedelta(minutes=1), ) assert token is not None with freeze_time("2024-08-23 10:05:00"): assert pytest.raises(ExpiredSignatureError) + + +def test_scopes_missing_scopes(client: TestClient, session: Session): + # Create the test user with no given scopes + # They should not be able to access the GET /cases resource as that requires the UserScopes.READ scope + assert_user_scope(session, client, [], "/cases", 401) + + +def test_scopes_incorrect_scope(client: TestClient, session: Session): + # Create the test user with a UserScopes.CREATE scope + # They should not be able to access the GET /cases resource as that requires the UserScopes.READ scope + assert_user_scope(session, client, [UserScopes.CREATE], "/cases", 401) + + +def test_scopes_correct_scope(client: TestClient, session: Session): + # Create the test user with a UserScopes.READ scope + # They should be able to access the GET /cases resource as that requires the UserScopes.READ scope + assert_user_scope(session, client, [UserScopes.READ], "/cases", 200) + + +def assert_user_scope( + session: Session, + client: TestClient, + scopes: List[UserScopes], + resource: str, + expected_status_code, +): + # Create the test user with given scopes + username = "test_assert_user_scope" + password = "" + user = User( + username=username, hashed_password=get_password_hash(password), scopes=scopes + ) + session.add(user) + session.commit() + + # Obtain an access token for the test user + response = client.post( + "/token", + data={"username": username, "password": password}, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + token = response.json()["access_token"] + token_data = token_decode(token) + assert token_data["scopes"] == scopes + + # Attempt to access a resource with the test user + client.headers["Authorization"] = f"Bearer {token}" + response = client.get(resource) + assert response.status_code == expected_status_code diff --git a/tests/conftest.py b/tests/conftest.py index a075e060..c4a1868e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import sessionmaker from app.auth.security import get_password_hash -from app.models.users import User +from app.models.users import User, UserScopes SECRET_KEY = "TEST_KEY" @@ -33,7 +33,15 @@ def session_fixture(): password = get_password_hash(password) new_user = User( - username=username, hashed_password=password, disabled=disabled + username=username, + hashed_password=password, + disabled=disabled, + scopes=[ + UserScopes.CREATE, + UserScopes.READ, + UserScopes.UPDATE, + UserScopes.DELETE, + ], ) db_session.add(new_user)