Skip to content

Commit

Permalink
Fix: user update on school update
Browse files Browse the repository at this point in the history
  • Loading branch information
Rotheem committed Jan 7, 2025
1 parent 010097a commit 3f2a87b
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 17 deletions.
24 changes: 20 additions & 4 deletions app/core/schools/cruds_schools.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,28 @@ async def get_schools(db: AsyncSession) -> Sequence[models_core.CoreSchool]:
async def get_school_by_id(
db: AsyncSession,
school_id: UUID,
) -> models_core.CoreSchool | None:
) -> schemas_core.CoreSchool | None:
"""Return school with id from database"""
result = await db.execute(
select(models_core.CoreSchool).where(models_core.CoreSchool.id == school_id),
result = (
(
await db.execute(
select(models_core.CoreSchool).where(
models_core.CoreSchool.id == school_id,
),
)
)
.scalars()
.first()
)
return (
schemas_core.CoreSchool(
name=result.name,
email_regex=result.email_regex,
id=result.id,
)
if result
else None
)
return result.scalars().first()


async def get_school_by_name(
Expand Down
20 changes: 10 additions & 10 deletions app/core/schools/endpoints_schools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
School management is part of the core of Hyperion. These endpoints allow managing schools.
"""

import logging
import re
import uuid

Expand All @@ -24,8 +23,6 @@

router = APIRouter(tags=["Schools"])

hyperion_error_logger = logging.getLogger("hyperion.error")


@router.get(
"/schools/",
Expand Down Expand Up @@ -136,7 +133,6 @@ async def update_school(

if not school:
raise HTTPException(status_code=404, detail="School not found")

# If the request ask to update the school name, we need to check it is available
if school_update.name and school_update.name != school.name:
if (
Expand All @@ -155,12 +151,16 @@ async def update_school(
school_id=school_id,
school_update=school_update,
)

if (
school_update.email_regex is not None
and school_update.email_regex != school.email_regex
):
await cruds_users.remove_users_from_school(db, school_id=school_id)
try:
await db.commit()
except IntegrityError:
await db.rollback()
raise
users = await cruds_users.get_users(
db,
schools_ids=[SchoolType.no_school.value],
Expand All @@ -175,11 +175,11 @@ async def update_school(
account_type=AccountType.other_school_student,
),
)
try:
await db.commit()
except IntegrityError:
await db.rollback()
raise
try:
await db.commit()
except IntegrityError:
await db.rollback()
raise


@router.delete(
Expand Down
27 changes: 24 additions & 3 deletions tests/test_schools.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

admin_user: models_core.CoreUser
ens_user: models_core.CoreUser
fake_ens_user: models_core.CoreUser
new_school_user: models_core.CoreUser

UNIQUE_TOKEN = "my_unique_token"
Expand All @@ -24,12 +25,12 @@

@pytest_asyncio.fixture(scope="module", autouse=True)
async def init_objects() -> None:
global admin_user, ens_user, new_school_user
global admin_user, ens_user, fake_ens_user, new_school_user

ens = models_core.CoreSchool(
id=id_test_ens,
name="ENS",
email_regex=r"^.*@ens.fr$",
email_regex=r"^.*@.*ens.fr$",
)
await add_object_to_db(ens)

Expand All @@ -41,6 +42,12 @@ async def init_objects() -> None:
email="[email protected]",
account_type=AccountType.other_school_student,
)
fake_ens_user = await create_user_with_groups(
[],
school_id=id_test_ens,
email="[email protected]",
account_type=AccountType.other_school_student,
)

new_school_user = await create_user_with_groups(
[],
Expand Down Expand Up @@ -123,7 +130,7 @@ def test_update_school(client: TestClient) -> None:

response = client.patch(
f"/schools/{id_test_ens}",
json={"name": "school ENS", "email_regex": r"^.*@.*ens.fr$"},
json={"name": "school ENS", "email_regex": r"^.*@ens.fr$"},
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 204
Expand All @@ -135,6 +142,20 @@ def test_update_school(client: TestClient) -> None:
data = response.json()
assert data["name"] == "school ENS"

response = client.get(
f"/users/{ens_user.id}",
headers={"Authorization": f"Bearer {token}"},
)
data = response.json()
assert data["school_id"] == str(id_test_ens)

response = client.get(
f"/users/{fake_ens_user.id}",
headers={"Authorization": f"Bearer {token}"},
)
data = response.json()
assert data["school_id"] == str(SchoolType.no_school.value)


def test_create_user_corresponding_to_school(
mocker: MockerFixture,
Expand Down

0 comments on commit 3f2a87b

Please sign in to comment.