Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added a password update route #15

Merged
merged 5 commits into from
Nov 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions src/app/api/routes/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from app.api import routing, security
from app.db import users
from app.api.schemas import UserInfo, UserRead, UserAuth, UserCreation
from app.api.schemas import UserInfo, UserCred, UserRead, UserAuth, UserCreation
from app.api.deps import get_current_user


Expand All @@ -16,11 +16,16 @@ async def get_my_user(me: UserRead = Security(get_current_user, scopes=["me"])):
return me


@router.put("/update-me", response_model=UserRead)
async def update_me(payload: UserInfo, me: UserRead = Security(get_current_user, scopes=["me"])):
@router.put("/update-info", response_model=UserRead)
async def update_my_info(payload: UserInfo, me: UserRead = Security(get_current_user, scopes=["me"])):
return await routing.update_entry(users, payload, me.id)


@router.put("/update-pwd", response_model=UserInfo)
async def update_my_password(payload: UserCred, me: UserRead = Security(get_current_user, scopes=["me"])):
return await routing.update_user_pwd(users, payload, me.id)


@router.post("/", response_model=UserRead, status_code=201)
async def create_user(payload: UserAuth, _=Security(get_current_user, scopes=["admin"])):
return await routing.create_user(users, payload)
Expand All @@ -45,6 +50,15 @@ async def update_user(
return await routing.update_entry(users, payload, user_id)


@router.put("/{user_id}/pwd", response_model=UserInfo)
async def reset_password(
payload: UserCred,
user_id: int = Path(..., gt=0),
_=Security(get_current_user, scopes=["admin"])
):
return await routing.update_user_pwd(users, payload, user_id)


@router.delete("/{user_id}/", response_model=UserRead)
async def delete_user(user_id: int = Path(..., gt=0), _=Security(get_current_user, scopes=["admin"])):
return await routing.delete_entry(users, user_id)
13 changes: 12 additions & 1 deletion src/app/api/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi import HTTPException, Path

from app.api import crud, security
from app.api.schemas import UserAuth, UserCreation
from app.api.schemas import UserAuth, UserCreation, UserCred, UserCredHash


async def create_entry(table: Table, payload: BaseModel):
Expand Down Expand Up @@ -54,3 +54,14 @@ async def create_user(user_table: Table, payload: UserAuth):
pwd = await security.hash_password(payload.password)
payload = UserCreation(username=payload.username, hashed_password=pwd, scopes=payload.scopes)
return await create_entry(user_table, payload)


async def update_user_pwd(user_table: Table, payload: UserCred, entry_id: int = Path(..., gt=0)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One should create another file to store those methods that do the interface between routing and crud. In my PR I added some lines in this file too. If this continue, the file will end up quite big.

entry = await get_entry(user_table, entry_id)
# Hash the password
pwd = await security.hash_password(payload.password)
# Update the password
payload = UserCredHash(hashed_password=pwd)
await crud.put(entry_id, payload, user_table)
# Return non-sensitive information
return {"username": entry["username"]}
81 changes: 28 additions & 53 deletions src/app/api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,42 @@
from app.db import SiteType, EventType, MediaType, AlertType


# Template class
class _CreatedAt(BaseModel):
created_at: datetime = None

@staticmethod
@validator('created_at', pre=True, always=True)
def default_ts_created(v):
return v or datetime.utcnow()


# Abstract information about a user
class UserInfo(BaseModel):
username: str = Field(..., min_length=3, max_length=50)


# Sensitive information about the user
class UserCred(BaseModel):
password: str


class UserCredHash(BaseModel):
hashed_password: str


# Visible info
class UserRead(UserInfo):
class UserRead(UserInfo, _CreatedAt):
id: int = Field(..., gt=0)
created_at: datetime = None

@staticmethod
@validator('created_at', pre=True, always=True)
def default_ts_created(v):
return v or datetime.utcnow()


# Authentication request
class UserAuth(UserInfo):
password: str
class UserAuth(UserInfo, UserCred):
scopes: Optional[str] = "me"


# Creation payload
class UserCreation(UserInfo):
hashed_password: str
class UserCreation(UserInfo, UserCredHash):
scopes: str


Expand All @@ -50,14 +61,8 @@ class SiteIn(BaseModel):
type: SiteType = SiteType.tower


class SiteOut(SiteIn):
class SiteOut(SiteIn, _CreatedAt):
id: int = Field(..., gt=0)
created_at: datetime = None

@staticmethod
@validator('created_at', pre=True, always=True)
def default_ts_created(v):
return v or datetime.utcnow()


class EventIn(BaseModel):
Expand All @@ -68,14 +73,8 @@ class EventIn(BaseModel):
start_ts: datetime = None


class EventOut(EventIn):
class EventOut(EventIn, _CreatedAt):
id: int = Field(..., gt=0)
created_at: datetime = None

@staticmethod
@validator('created_at', pre=True, always=True)
def default_ts_created(v):
return v or datetime.utcnow()


class DeviceIn(BaseModel):
Expand All @@ -90,29 +89,17 @@ class DeviceIn(BaseModel):
last_ping: datetime = None


class DeviceOut(DeviceIn):
class DeviceOut(DeviceIn, _CreatedAt):
id: int = Field(..., gt=0)
created_at: datetime = None

@staticmethod
@validator('created_at', pre=True, always=True)
def default_ts_created(v):
return v or datetime.utcnow()


class MediaIn(BaseModel):
device_id: int = Field(..., gt=0)
type: MediaType = MediaType.image


class MediaOut(MediaIn):
class MediaOut(MediaIn, _CreatedAt):
id: int = Field(..., gt=0)
created_at: datetime = None

@staticmethod
@validator('created_at', pre=True, always=True)
def default_ts_created(v):
return v or datetime.utcnow()


class InstallationIn(BaseModel):
Expand All @@ -127,14 +114,8 @@ class InstallationIn(BaseModel):
end_ts: datetime = None


class InstallationOut(InstallationIn):
class InstallationOut(InstallationIn, _CreatedAt):
id: int = Field(..., gt=0)
created_at: datetime = None

@staticmethod
@validator('created_at', pre=True, always=True)
def default_ts_created(v):
return v or datetime.utcnow()


class AlertIn(BaseModel):
Expand All @@ -147,11 +128,5 @@ class AlertIn(BaseModel):
is_acknowledged: bool = False


class AlertOut(AlertIn):
class AlertOut(AlertIn, _CreatedAt):
id: int = Field(..., gt=0)
created_at: datetime = None

@staticmethod
@validator('created_at', pre=True, always=True)
def default_ts_created(v):
return v or datetime.utcnow()
73 changes: 69 additions & 4 deletions src/tests/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ async def mock_get_all(table, query_filter=None):
assert [{k: v for k, v in r.items() if k != 'created_at'} for r in response.json()] == test_data


def test_update_user(test_app, monkeypatch):
def test_update_user_info(test_app, monkeypatch):
async def mock_get(entry_id, table):
return True

Expand All @@ -121,7 +121,7 @@ async def mock_put(entry_id, payload, table):

# test update connected_user (dont alter username for other tests)
test_update_data = {"username": "connected_user"}
response = test_app.put("/users/update-me", data=json.dumps(test_update_data))
response = test_app.put("/users/update-info", data=json.dumps(test_update_data))
assert response.status_code == 200
assert {k: v for k, v in response.json().items() if k != "created_at"} == {**test_update_data, "id": 99}

Expand All @@ -136,13 +136,78 @@ async def mock_put(entry_id, payload, table):
[0, {"username": "foo"}, 422],
],
)
def test_update_user_invalid(test_app, monkeypatch, user_id, payload, status_code):
def test_update_user_info_invalid(test_app, monkeypatch, user_id, payload, status_code):
async def mock_get(entry_id, table):
return None

monkeypatch.setattr(crud, "get", mock_get)

response = test_app.put(f"/users/{user_id}/", data=json.dumps(payload),)
response = test_app.put(f"/users/{user_id}/", data=json.dumps(payload))
assert response.status_code == status_code, print(payload)


def test_update_user_pwd(test_app, monkeypatch):

test_data = [
{"id": 1, "username": "someone", "hashed_password": "first_hashed", "scopes": "me"},
{"id": 99, "username": "connected_user", "hashed_password": "first_hashed", "scopes": "me"}
]

async def mock_get(entry_id, table):
for entry in test_data:
if entry['id'] == entry_id:
return entry
return None

monkeypatch.setattr(crud, "get", mock_get)

async def mock_put(entry_id, payload, table):
return entry_id

monkeypatch.setattr(crud, "put", mock_put)

test_update_data = {"password": "my_password"}
test_response = {"username": "someone"}
response = test_app.put("/users/1/pwd", data=json.dumps(test_update_data))
assert response.status_code == 200
assert response.json() == test_response

# test update connected_user (dont alter username for other tests)
test_update_data = {"password": "my_password"}
test_response = {"username": "connected_user"}
response = test_app.put("/users/update-pwd", data=json.dumps(test_update_data))
assert response.status_code == 200
assert response.json() == test_response


@pytest.mark.parametrize(
"user_id, payload, status_code",
[
[1, {}, 422],
[1, {"description": "bar"}, 422],
[999, {"password": "my_pwd"}, 404],
[0, {"password": "my_pwd"}, 422],
],
)
def test_update_user_pwd_invalid(test_app, monkeypatch, user_id, payload, status_code):
test_data = [
{"id": 1, "username": "someone", "hashed_password": "first_hashed", "scopes": "me"},
{"id": 99, "username": "connected_user", "hashed_password": "first_hashed", "scopes": "me"}
]

async def mock_get(entry_id, table):
for entry in test_data:
if entry['id'] == entry_id:
return entry
return None
monkeypatch.setattr(crud, "get", mock_get)

async def mock_put(entry_id, payload, table):
return entry_id

monkeypatch.setattr(crud, "put", mock_put)

response = test_app.put(f"/users/{user_id}/pwd", data=json.dumps(payload))
assert response.status_code == status_code, print(payload)


Expand Down