diff --git a/src/app/api/routes/users.py b/src/app/api/routes/users.py index 419abfd9..0a83bb12 100644 --- a/src/app/api/routes/users.py +++ b/src/app/api/routes/users.py @@ -4,7 +4,7 @@ from app.api import routing, security from app.db import users -from app.api.schemas import UserInfo, UserRead, UserAuth, UserCreation +from app.api.schemas import UserInfo, UserCred, UserRead, UserAuth, UserCreation from app.api.deps import get_current_user @@ -16,11 +16,16 @@ async def get_my_user(me: UserRead = Security(get_current_user, scopes=["me"])): return me -@router.put("/update-me", response_model=UserRead) -async def update_me(payload: UserInfo, me: UserRead = Security(get_current_user, scopes=["me"])): +@router.put("/update-info", response_model=UserRead) +async def update_my_info(payload: UserInfo, me: UserRead = Security(get_current_user, scopes=["me"])): return await routing.update_entry(users, payload, me.id) +@router.put("/update-pwd", response_model=UserInfo) +async def update_my_password(payload: UserCred, me: UserRead = Security(get_current_user, scopes=["me"])): + return await routing.update_user_pwd(users, payload, me.id) + + @router.post("/", response_model=UserRead, status_code=201) async def create_user(payload: UserAuth, _=Security(get_current_user, scopes=["admin"])): return await routing.create_user(users, payload) @@ -45,6 +50,15 @@ async def update_user( return await routing.update_entry(users, payload, user_id) +@router.put("/{user_id}/pwd", response_model=UserInfo) +async def reset_password( + payload: UserCred, + user_id: int = Path(..., gt=0), + _=Security(get_current_user, scopes=["admin"]) +): + return await routing.update_user_pwd(users, payload, user_id) + + @router.delete("/{user_id}/", response_model=UserRead) async def delete_user(user_id: int = Path(..., gt=0), _=Security(get_current_user, scopes=["admin"])): return await routing.delete_entry(users, user_id) diff --git a/src/app/api/routing.py b/src/app/api/routing.py index 0124582d..3b3eb07e 100644 --- a/src/app/api/routing.py +++ b/src/app/api/routing.py @@ -4,7 +4,7 @@ from fastapi import HTTPException, Path from app.api import crud, security -from app.api.schemas import UserAuth, UserCreation +from app.api.schemas import UserAuth, UserCreation, UserCred, UserCredHash async def create_entry(table: Table, payload: BaseModel): @@ -54,3 +54,14 @@ async def create_user(user_table: Table, payload: UserAuth): pwd = await security.hash_password(payload.password) payload = UserCreation(username=payload.username, hashed_password=pwd, scopes=payload.scopes) return await create_entry(user_table, payload) + + +async def update_user_pwd(user_table: Table, payload: UserCred, entry_id: int = Path(..., gt=0)): + entry = await get_entry(user_table, entry_id) + # Hash the password + pwd = await security.hash_password(payload.password) + # Update the password + payload = UserCredHash(hashed_password=pwd) + await crud.put(entry_id, payload, user_table) + # Return non-sensitive information + return {"username": entry["username"]} diff --git a/src/app/api/schemas.py b/src/app/api/schemas.py index 27614a2a..116eebde 100644 --- a/src/app/api/schemas.py +++ b/src/app/api/schemas.py @@ -5,31 +5,42 @@ from app.db import SiteType, EventType, MediaType, AlertType +# Template class +class _CreatedAt(BaseModel): + created_at: datetime = None + + @staticmethod + @validator('created_at', pre=True, always=True) + def default_ts_created(v): + return v or datetime.utcnow() + + # Abstract information about a user class UserInfo(BaseModel): username: str = Field(..., min_length=3, max_length=50) +# Sensitive information about the user +class UserCred(BaseModel): + password: str + + +class UserCredHash(BaseModel): + hashed_password: str + + # Visible info -class UserRead(UserInfo): +class UserRead(UserInfo, _CreatedAt): id: int = Field(..., gt=0) - created_at: datetime = None - - @staticmethod - @validator('created_at', pre=True, always=True) - def default_ts_created(v): - return v or datetime.utcnow() # Authentication request -class UserAuth(UserInfo): - password: str +class UserAuth(UserInfo, UserCred): scopes: Optional[str] = "me" # Creation payload -class UserCreation(UserInfo): - hashed_password: str +class UserCreation(UserInfo, UserCredHash): scopes: str @@ -50,14 +61,8 @@ class SiteIn(BaseModel): type: SiteType = SiteType.tower -class SiteOut(SiteIn): +class SiteOut(SiteIn, _CreatedAt): id: int = Field(..., gt=0) - created_at: datetime = None - - @staticmethod - @validator('created_at', pre=True, always=True) - def default_ts_created(v): - return v or datetime.utcnow() class EventIn(BaseModel): @@ -68,14 +73,8 @@ class EventIn(BaseModel): start_ts: datetime = None -class EventOut(EventIn): +class EventOut(EventIn, _CreatedAt): id: int = Field(..., gt=0) - created_at: datetime = None - - @staticmethod - @validator('created_at', pre=True, always=True) - def default_ts_created(v): - return v or datetime.utcnow() class DeviceIn(BaseModel): @@ -90,14 +89,8 @@ class DeviceIn(BaseModel): last_ping: datetime = None -class DeviceOut(DeviceIn): +class DeviceOut(DeviceIn, _CreatedAt): id: int = Field(..., gt=0) - created_at: datetime = None - - @staticmethod - @validator('created_at', pre=True, always=True) - def default_ts_created(v): - return v or datetime.utcnow() class MediaIn(BaseModel): @@ -105,14 +98,8 @@ class MediaIn(BaseModel): type: MediaType = MediaType.image -class MediaOut(MediaIn): +class MediaOut(MediaIn, _CreatedAt): id: int = Field(..., gt=0) - created_at: datetime = None - - @staticmethod - @validator('created_at', pre=True, always=True) - def default_ts_created(v): - return v or datetime.utcnow() class InstallationIn(BaseModel): @@ -127,14 +114,8 @@ class InstallationIn(BaseModel): end_ts: datetime = None -class InstallationOut(InstallationIn): +class InstallationOut(InstallationIn, _CreatedAt): id: int = Field(..., gt=0) - created_at: datetime = None - - @staticmethod - @validator('created_at', pre=True, always=True) - def default_ts_created(v): - return v or datetime.utcnow() class AlertIn(BaseModel): @@ -147,11 +128,5 @@ class AlertIn(BaseModel): is_acknowledged: bool = False -class AlertOut(AlertIn): +class AlertOut(AlertIn, _CreatedAt): id: int = Field(..., gt=0) - created_at: datetime = None - - @staticmethod - @validator('created_at', pre=True, always=True) - def default_ts_created(v): - return v or datetime.utcnow() diff --git a/src/tests/test_users.py b/src/tests/test_users.py index fb7a0915..db1192c2 100644 --- a/src/tests/test_users.py +++ b/src/tests/test_users.py @@ -103,7 +103,7 @@ async def mock_get_all(table, query_filter=None): assert [{k: v for k, v in r.items() if k != 'created_at'} for r in response.json()] == test_data -def test_update_user(test_app, monkeypatch): +def test_update_user_info(test_app, monkeypatch): async def mock_get(entry_id, table): return True @@ -121,7 +121,7 @@ async def mock_put(entry_id, payload, table): # test update connected_user (dont alter username for other tests) test_update_data = {"username": "connected_user"} - response = test_app.put("/users/update-me", data=json.dumps(test_update_data)) + response = test_app.put("/users/update-info", data=json.dumps(test_update_data)) assert response.status_code == 200 assert {k: v for k, v in response.json().items() if k != "created_at"} == {**test_update_data, "id": 99} @@ -136,13 +136,78 @@ async def mock_put(entry_id, payload, table): [0, {"username": "foo"}, 422], ], ) -def test_update_user_invalid(test_app, monkeypatch, user_id, payload, status_code): +def test_update_user_info_invalid(test_app, monkeypatch, user_id, payload, status_code): async def mock_get(entry_id, table): return None monkeypatch.setattr(crud, "get", mock_get) - response = test_app.put(f"/users/{user_id}/", data=json.dumps(payload),) + response = test_app.put(f"/users/{user_id}/", data=json.dumps(payload)) + assert response.status_code == status_code, print(payload) + + +def test_update_user_pwd(test_app, monkeypatch): + + test_data = [ + {"id": 1, "username": "someone", "hashed_password": "first_hashed", "scopes": "me"}, + {"id": 99, "username": "connected_user", "hashed_password": "first_hashed", "scopes": "me"} + ] + + async def mock_get(entry_id, table): + for entry in test_data: + if entry['id'] == entry_id: + return entry + return None + + monkeypatch.setattr(crud, "get", mock_get) + + async def mock_put(entry_id, payload, table): + return entry_id + + monkeypatch.setattr(crud, "put", mock_put) + + test_update_data = {"password": "my_password"} + test_response = {"username": "someone"} + response = test_app.put("/users/1/pwd", data=json.dumps(test_update_data)) + assert response.status_code == 200 + assert response.json() == test_response + + # test update connected_user (dont alter username for other tests) + test_update_data = {"password": "my_password"} + test_response = {"username": "connected_user"} + response = test_app.put("/users/update-pwd", data=json.dumps(test_update_data)) + assert response.status_code == 200 + assert response.json() == test_response + + +@pytest.mark.parametrize( + "user_id, payload, status_code", + [ + [1, {}, 422], + [1, {"description": "bar"}, 422], + [999, {"password": "my_pwd"}, 404], + [0, {"password": "my_pwd"}, 422], + ], +) +def test_update_user_pwd_invalid(test_app, monkeypatch, user_id, payload, status_code): + test_data = [ + {"id": 1, "username": "someone", "hashed_password": "first_hashed", "scopes": "me"}, + {"id": 99, "username": "connected_user", "hashed_password": "first_hashed", "scopes": "me"} + ] + + async def mock_get(entry_id, table): + for entry in test_data: + if entry['id'] == entry_id: + return entry + return None + monkeypatch.setattr(crud, "get", mock_get) + + async def mock_put(entry_id, payload, table): + return entry_id + + monkeypatch.setattr(crud, "put", mock_put) + + response = test_app.put(f"/users/{user_id}/pwd", data=json.dumps(payload)) assert response.status_code == status_code, print(payload)