From 90b8561f3dba9e7f4ecf36fcbfef3c7a85ae9263 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Fri, 6 Nov 2020 01:04:50 +0100 Subject: [PATCH 1/5] refactor: Renamed update routes of users --- src/app/api/routes/users.py | 4 ++-- src/tests/test_users.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/app/api/routes/users.py b/src/app/api/routes/users.py index 419abfd9..875dfa98 100644 --- a/src/app/api/routes/users.py +++ b/src/app/api/routes/users.py @@ -16,8 +16,8 @@ async def get_my_user(me: UserRead = Security(get_current_user, scopes=["me"])): return me -@router.put("/update-me", response_model=UserRead) -async def update_me(payload: UserInfo, me: UserRead = Security(get_current_user, scopes=["me"])): +@router.put("/update-info", response_model=UserRead) +async def update_my_info(payload: UserInfo, me: UserRead = Security(get_current_user, scopes=["me"])): return await routing.update_entry(users, payload, me.id) diff --git a/src/tests/test_users.py b/src/tests/test_users.py index fb7a0915..5f31892e 100644 --- a/src/tests/test_users.py +++ b/src/tests/test_users.py @@ -103,7 +103,7 @@ async def mock_get_all(table, query_filter=None): assert [{k: v for k, v in r.items() if k != 'created_at'} for r in response.json()] == test_data -def test_update_user(test_app, monkeypatch): +def test_update_user_info(test_app, monkeypatch): async def mock_get(entry_id, table): return True @@ -121,7 +121,7 @@ async def mock_put(entry_id, payload, table): # test update connected_user (dont alter username for other tests) test_update_data = {"username": "connected_user"} - response = test_app.put("/users/update-me", data=json.dumps(test_update_data)) + response = test_app.put("/users/update-info", data=json.dumps(test_update_data)) assert response.status_code == 200 assert {k: v for k, v in response.json().items() if k != "created_at"} == {**test_update_data, "id": 99} @@ -136,13 +136,13 @@ async def mock_put(entry_id, payload, table): [0, {"username": "foo"}, 422], ], ) -def test_update_user_invalid(test_app, monkeypatch, user_id, payload, status_code): +def test_update_user_info_invalid(test_app, monkeypatch, user_id, payload, status_code): async def mock_get(entry_id, table): return None monkeypatch.setattr(crud, "get", mock_get) - response = test_app.put(f"/users/{user_id}/", data=json.dumps(payload),) + response = test_app.put(f"/users/{user_id}/", data=json.dumps(payload)) assert response.status_code == status_code, print(payload) From b21c42f433795894ae7587faf22c626c5d53f0a8 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Fri, 6 Nov 2020 01:05:13 +0100 Subject: [PATCH 2/5] feat: Added update password function --- src/app/api/routing.py | 13 ++++++++++++- src/app/api/schemas.py | 9 +++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/app/api/routing.py b/src/app/api/routing.py index 0124582d..3b3eb07e 100644 --- a/src/app/api/routing.py +++ b/src/app/api/routing.py @@ -4,7 +4,7 @@ from fastapi import HTTPException, Path from app.api import crud, security -from app.api.schemas import UserAuth, UserCreation +from app.api.schemas import UserAuth, UserCreation, UserCred, UserCredHash async def create_entry(table: Table, payload: BaseModel): @@ -54,3 +54,14 @@ async def create_user(user_table: Table, payload: UserAuth): pwd = await security.hash_password(payload.password) payload = UserCreation(username=payload.username, hashed_password=pwd, scopes=payload.scopes) return await create_entry(user_table, payload) + + +async def update_user_pwd(user_table: Table, payload: UserCred, entry_id: int = Path(..., gt=0)): + entry = await get_entry(user_table, entry_id) + # Hash the password + pwd = await security.hash_password(payload.password) + # Update the password + payload = UserCredHash(hashed_password=pwd) + await crud.put(entry_id, payload, user_table) + # Return non-sensitive information + return {"username": entry["username"]} diff --git a/src/app/api/schemas.py b/src/app/api/schemas.py index 27614a2a..230282b5 100644 --- a/src/app/api/schemas.py +++ b/src/app/api/schemas.py @@ -10,6 +10,15 @@ class UserInfo(BaseModel): username: str = Field(..., min_length=3, max_length=50) +# Sensitive information about the user +class UserCred(BaseModel): + password: str + + +class UserCredHash(BaseModel): + hashed_password: str + + # Visible info class UserRead(UserInfo): id: int = Field(..., gt=0) From 2afdef2f675dbdca52c8dc2990eb5cab2f9cb8e7 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Fri, 6 Nov 2020 01:05:30 +0100 Subject: [PATCH 3/5] feat: Added self and admin password update routes --- src/app/api/routes/users.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/app/api/routes/users.py b/src/app/api/routes/users.py index 875dfa98..0a83bb12 100644 --- a/src/app/api/routes/users.py +++ b/src/app/api/routes/users.py @@ -4,7 +4,7 @@ from app.api import routing, security from app.db import users -from app.api.schemas import UserInfo, UserRead, UserAuth, UserCreation +from app.api.schemas import UserInfo, UserCred, UserRead, UserAuth, UserCreation from app.api.deps import get_current_user @@ -21,6 +21,11 @@ async def update_my_info(payload: UserInfo, me: UserRead = Security(get_current_ return await routing.update_entry(users, payload, me.id) +@router.put("/update-pwd", response_model=UserInfo) +async def update_my_password(payload: UserCred, me: UserRead = Security(get_current_user, scopes=["me"])): + return await routing.update_user_pwd(users, payload, me.id) + + @router.post("/", response_model=UserRead, status_code=201) async def create_user(payload: UserAuth, _=Security(get_current_user, scopes=["admin"])): return await routing.create_user(users, payload) @@ -45,6 +50,15 @@ async def update_user( return await routing.update_entry(users, payload, user_id) +@router.put("/{user_id}/pwd", response_model=UserInfo) +async def reset_password( + payload: UserCred, + user_id: int = Path(..., gt=0), + _=Security(get_current_user, scopes=["admin"]) +): + return await routing.update_user_pwd(users, payload, user_id) + + @router.delete("/{user_id}/", response_model=UserRead) async def delete_user(user_id: int = Path(..., gt=0), _=Security(get_current_user, scopes=["admin"])): return await routing.delete_entry(users, user_id) From b41127b3e50eb2f911ce3130dacb459af99563a5 Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Fri, 6 Nov 2020 01:05:36 +0100 Subject: [PATCH 4/5] test: Added unittests --- src/tests/test_users.py | 65 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/src/tests/test_users.py b/src/tests/test_users.py index 5f31892e..db1192c2 100644 --- a/src/tests/test_users.py +++ b/src/tests/test_users.py @@ -146,6 +146,71 @@ async def mock_get(entry_id, table): assert response.status_code == status_code, print(payload) +def test_update_user_pwd(test_app, monkeypatch): + + test_data = [ + {"id": 1, "username": "someone", "hashed_password": "first_hashed", "scopes": "me"}, + {"id": 99, "username": "connected_user", "hashed_password": "first_hashed", "scopes": "me"} + ] + + async def mock_get(entry_id, table): + for entry in test_data: + if entry['id'] == entry_id: + return entry + return None + + monkeypatch.setattr(crud, "get", mock_get) + + async def mock_put(entry_id, payload, table): + return entry_id + + monkeypatch.setattr(crud, "put", mock_put) + + test_update_data = {"password": "my_password"} + test_response = {"username": "someone"} + response = test_app.put("/users/1/pwd", data=json.dumps(test_update_data)) + assert response.status_code == 200 + assert response.json() == test_response + + # test update connected_user (dont alter username for other tests) + test_update_data = {"password": "my_password"} + test_response = {"username": "connected_user"} + response = test_app.put("/users/update-pwd", data=json.dumps(test_update_data)) + assert response.status_code == 200 + assert response.json() == test_response + + +@pytest.mark.parametrize( + "user_id, payload, status_code", + [ + [1, {}, 422], + [1, {"description": "bar"}, 422], + [999, {"password": "my_pwd"}, 404], + [0, {"password": "my_pwd"}, 422], + ], +) +def test_update_user_pwd_invalid(test_app, monkeypatch, user_id, payload, status_code): + test_data = [ + {"id": 1, "username": "someone", "hashed_password": "first_hashed", "scopes": "me"}, + {"id": 99, "username": "connected_user", "hashed_password": "first_hashed", "scopes": "me"} + ] + + async def mock_get(entry_id, table): + for entry in test_data: + if entry['id'] == entry_id: + return entry + return None + monkeypatch.setattr(crud, "get", mock_get) + + async def mock_put(entry_id, payload, table): + return entry_id + + monkeypatch.setattr(crud, "put", mock_put) + + response = test_app.put(f"/users/{user_id}/pwd", data=json.dumps(payload)) + assert response.status_code == status_code, print(payload) + + def test_remove_user(test_app, monkeypatch): test_data = {"username": "someone", "id": 1} From 7c84e7c16729f0b5d95f15cea5bc7c214e1c27e4 Mon Sep 17 00:00:00 2001 From: frgfm Date: Fri, 6 Nov 2020 11:40:23 +0100 Subject: [PATCH 5/5] refactor: Refactored schemas Added a template class for automatic creation date resolution --- src/app/api/schemas.py | 72 +++++++++++------------------------------- 1 file changed, 19 insertions(+), 53 deletions(-) diff --git a/src/app/api/schemas.py b/src/app/api/schemas.py index 230282b5..116eebde 100644 --- a/src/app/api/schemas.py +++ b/src/app/api/schemas.py @@ -5,6 +5,16 @@ from app.db import SiteType, EventType, MediaType, AlertType +# Template class +class _CreatedAt(BaseModel): + created_at: datetime = None + + @staticmethod + @validator('created_at', pre=True, always=True) + def default_ts_created(v): + return v or datetime.utcnow() + + # Abstract information about a user class UserInfo(BaseModel): username: str = Field(..., min_length=3, max_length=50) @@ -20,25 +30,17 @@ class UserCredHash(BaseModel): # Visible info -class UserRead(UserInfo): +class UserRead(UserInfo, _CreatedAt): id: int = Field(..., gt=0) - created_at: datetime = None - - @staticmethod - @validator('created_at', pre=True, always=True) - def default_ts_created(v): - return v or datetime.utcnow() # Authentication request -class UserAuth(UserInfo): - password: str +class UserAuth(UserInfo, UserCred): scopes: Optional[str] = "me" # Creation payload -class UserCreation(UserInfo): - hashed_password: str +class UserCreation(UserInfo, UserCredHash): scopes: str @@ -59,14 +61,8 @@ class SiteIn(BaseModel): type: SiteType = SiteType.tower -class SiteOut(SiteIn): +class SiteOut(SiteIn, _CreatedAt): id: int = Field(..., gt=0) - created_at: datetime = None - - @staticmethod - @validator('created_at', pre=True, always=True) - def default_ts_created(v): - return v or datetime.utcnow() class EventIn(BaseModel): @@ -77,14 +73,8 @@ class EventIn(BaseModel): start_ts: datetime = None -class EventOut(EventIn): +class EventOut(EventIn, _CreatedAt): id: int = Field(..., gt=0) - created_at: datetime = None - - @staticmethod - @validator('created_at', pre=True, always=True) - def default_ts_created(v): - return v or datetime.utcnow() class DeviceIn(BaseModel): @@ -99,14 +89,8 @@ class DeviceIn(BaseModel): last_ping: datetime = None -class DeviceOut(DeviceIn): +class DeviceOut(DeviceIn, _CreatedAt): id: int = Field(..., gt=0) - created_at: datetime = None - - @staticmethod - @validator('created_at', pre=True, always=True) - def default_ts_created(v): - return v or datetime.utcnow() class MediaIn(BaseModel): @@ -114,14 +98,8 @@ class MediaIn(BaseModel): type: MediaType = MediaType.image -class MediaOut(MediaIn): +class MediaOut(MediaIn, _CreatedAt): id: int = Field(..., gt=0) - created_at: datetime = None - - @staticmethod - @validator('created_at', pre=True, always=True) - def default_ts_created(v): - return v or datetime.utcnow() class InstallationIn(BaseModel): @@ -136,14 +114,8 @@ class InstallationIn(BaseModel): end_ts: datetime = None -class InstallationOut(InstallationIn): +class InstallationOut(InstallationIn, _CreatedAt): id: int = Field(..., gt=0) - created_at: datetime = None - - @staticmethod - @validator('created_at', pre=True, always=True) - def default_ts_created(v): - return v or datetime.utcnow() class AlertIn(BaseModel): @@ -156,11 +128,5 @@ class AlertIn(BaseModel): is_acknowledged: bool = False -class AlertOut(AlertIn): +class AlertOut(AlertIn, _CreatedAt): id: int = Field(..., gt=0) - created_at: datetime = None - - @staticmethod - @validator('created_at', pre=True, always=True) - def default_ts_created(v): - return v or datetime.utcnow()