diff --git a/app/api.py b/app/api.py index 7ea2bdd32..46170ff86 100644 --- a/app/api.py +++ b/app/api.py @@ -10,6 +10,7 @@ from app.core.groups import endpoints_groups from app.core.notification import endpoints_notification from app.core.payment import endpoints_payment +from app.core.schools import endpoints_schools from app.core.users import endpoints_users from app.modules.module_list import module_list @@ -24,6 +25,7 @@ api_router.include_router(endpoints_notification.router) api_router.include_router(endpoints_payment.router) api_router.include_router(endpoints_users.router) +api_router.include_router(endpoints_schools.router) for module in module_list: api_router.include_router(module.router) diff --git a/app/app.py b/app/app.py index fdafb3461..8d0e2f2ff 100644 --- a/app/app.py +++ b/app/app.py @@ -27,6 +27,7 @@ from app.core.google_api.google_api import GoogleAPI from app.core.groups.groups_type import GroupType from app.core.log import LogConfig +from app.core.schools.schools_type import SchoolType from app.dependencies import ( get_db, get_redis_client, @@ -189,6 +190,32 @@ def initialize_groups( ) +def initialize_schools( + sync_engine: Engine, + hyperion_error_logger: logging.Logger, +) -> None: + """Add the necessary shools""" + + hyperion_error_logger.info("Startup: Adding new groups to the database") + with Session(sync_engine) as db: + for school in SchoolType: + exists = initialization.get_school_by_id_sync(school_id=school.value, db=db) + # We don't want to recreate the groups if they already exist + if not exists: + db_school = models_core.CoreSchool( + id=school.value, + name=school.name, + email_regex="null", + ) + + try: + initialization.create_school_sync(school=db_school, db=db) + except IntegrityError as error: + hyperion_error_logger.fatal( + f"Startup: Could not add school {db_school.name}<{db_school.id}> in the database: {error}", + ) + + def initialize_module_visibility( sync_engine: Engine, hyperion_error_logger: logging.Logger, @@ -301,6 +328,10 @@ def init_db( sync_engine=sync_engine, hyperion_error_logger=hyperion_error_logger, ) + initialize_schools( + sync_engine=sync_engine, + hyperion_error_logger=hyperion_error_logger, + ) initialize_module_visibility( sync_engine=sync_engine, hyperion_error_logger=hyperion_error_logger, diff --git a/app/core/auth/endpoints_auth.py b/app/core/auth/endpoints_auth.py index 666df75ce..abcc76937 100644 --- a/app/core/auth/endpoints_auth.py +++ b/app/core/auth/endpoints_auth.py @@ -41,7 +41,7 @@ from app.types.exceptions import AuthHTTPException from app.types.scopes_type import ScopeType from app.utils.auth.providers import BaseAuthClient -from app.utils.tools import is_user_member_of_an_allowed_group +from app.utils.tools import is_user_member_of_any_group router = APIRouter(tags=["Auth"]) @@ -308,7 +308,7 @@ async def authorize_validation( # The auth_client may restrict the usage of the client to specific Hyperion groups. # For example, only ECLAIR members may be allowed to access the wiki if auth_client.allowed_groups is not None: - if not is_user_member_of_an_allowed_group( + if not is_user_member_of_any_group( user=user, allowed_groups=auth_client.allowed_groups, ): diff --git a/app/core/groups/groups_type.py b/app/core/groups/groups_type.py index 8ec65a5a1..60faf5bf0 100644 --- a/app/core/groups/groups_type.py +++ b/app/core/groups/groups_type.py @@ -42,6 +42,7 @@ class AccountType(str, Enum): staff = "staff" association = "association" external = "external" + other_school_student = "other_school_student" demo = "demo" def __str__(self): @@ -57,11 +58,12 @@ def get_ecl_account_types() -> list[AccountType]: ] -def get_account_types_except_external() -> list[AccountType]: +def get_account_types_except_externals() -> list[AccountType]: return [ AccountType.student, AccountType.former_student, AccountType.staff, AccountType.association, AccountType.demo, + AccountType.other_school_student, ] diff --git a/app/core/models_core.py b/app/core/models_core.py index 2af794ae4..3de3827f3 100644 --- a/app/core/models_core.py +++ b/app/core/models_core.py @@ -1,6 +1,7 @@ """Common model files for all core in order to avoid circular import due to bidirectional relationship""" from datetime import date, datetime +from uuid import UUID from sqlalchemy import ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -29,6 +30,7 @@ class CoreUser(Base): index=True, ) # Use UUID later email: Mapped[str] = mapped_column(unique=True, index=True) + school_id: Mapped[UUID] = mapped_column(ForeignKey("core_school.id")) password_hash: Mapped[str] # Depending on the account type, the user may have different rights and access to different features # External users may exist for: @@ -54,6 +56,11 @@ class CoreUser(Base): lazy="selectin", default_factory=list, ) + school: Mapped["CoreSchool"] = relationship( + "CoreSchool", + lazy="selectin", + init=False, + ) class CoreUserUnconfirmed(Base): @@ -118,6 +125,14 @@ class CoreGroup(Base): ) +class CoreSchool(Base): + __tablename__ = "core_school" + + id: Mapped[PrimaryKey] + name: Mapped[str] = mapped_column(unique=True) + email_regex: Mapped[str] + + class CoreAssociationMembership(Base): __tablename__ = "core_association_membership" diff --git a/app/core/schemas_core.py b/app/core/schemas_core.py index 2fa48226b..a45768a86 100644 --- a/app/core/schemas_core.py +++ b/app/core/schemas_core.py @@ -1,6 +1,7 @@ """Common schemas file for endpoint /users et /groups because it would cause circular import""" from datetime import date, datetime +from uuid import UUID from pydantic import BaseModel, ConfigDict, Field from pydantic.functional_validators import field_validator @@ -19,6 +20,44 @@ class CoreInformation(BaseModel): minimal_titan_version_code: int +class CoreGroupBase(BaseModel): + """Base schema for group's model""" + + name: str + description: str | None = None + + _normalize_name = field_validator("name")(validators.trailing_spaces_remover) + + +class CoreGroupSimple(CoreGroupBase): + """Simplified schema for group's model, used when getting all groups""" + + id: str + model_config = ConfigDict(from_attributes=True) + + +class CoreSchoolBase(BaseModel): + """Schema for school's model""" + + name: str + email_regex: str + + _normalize_name = field_validator("name")(validators.trailing_spaces_remover) + + +class CoreSchool(CoreSchoolBase): + id: UUID + + +class CoreSchoolUpdate(BaseModel): + """Schema for school update""" + + name: str | None = None + email_regex: str | None = None + + _normalize_name = field_validator("name")(validators.trailing_spaces_remover) + + class CoreUserBase(BaseModel): """Base schema for user's model""" @@ -35,27 +74,13 @@ class CoreUserBase(BaseModel): ) -class CoreGroupBase(BaseModel): - """Base schema for group's model""" - - name: str - description: str | None = None - - _normalize_name = field_validator("name")(validators.trailing_spaces_remover) - - class CoreUserSimple(CoreUserBase): """Simplified schema for user's model, used when getting all users""" id: str account_type: AccountType - model_config = ConfigDict(from_attributes=True) - - -class CoreGroupSimple(CoreGroupBase): - """Simplified schema for group's model, used when getting all groups""" + school_id: UUID - id: str model_config = ConfigDict(from_attributes=True) @@ -63,13 +88,13 @@ class CoreUser(CoreUserSimple): """Schema for user's model similar to core_user table in database""" email: str - account_type: AccountType birthday: date | None = None promo: int | None = None floor: FloorsType | None = None phone: str | None = None created_on: datetime | None = None groups: list[CoreGroupSimple] = [] + school: CoreSchool | None = None class CoreUserUpdate(BaseModel): @@ -97,6 +122,8 @@ class CoreUserFusionRequest(BaseModel): class CoreUserUpdateAdmin(BaseModel): + email: str | None = None + school_id: UUID | None = None account_type: AccountType | None = None name: str | None = None firstname: str | None = None @@ -164,7 +191,7 @@ class CoreUserActivateRequest(CoreUserBase): floor: FloorsType | None = None promo: int | None = Field( default=None, - description="Promotion of the student, an integer like 21", + description="Promotion of the student, an integer like 2021", ) # Password validator @@ -189,13 +216,6 @@ class CoreGroupCreate(CoreGroupBase): """Model for group creation schema""" -class CoreGroupInDB(CoreGroupBase): - """Schema for user activation""" - - id: str - model_config = ConfigDict(from_attributes=True) - - class CoreGroupUpdate(BaseModel): """Schema for group update""" diff --git a/app/core/schools/__init__.py b/app/core/schools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app/core/schools/cruds_schools.py b/app/core/schools/cruds_schools.py new file mode 100644 index 000000000..54639e8ca --- /dev/null +++ b/app/core/schools/cruds_schools.py @@ -0,0 +1,86 @@ +"""File defining the functions called by the endpoints, making queries to the table using the models""" + +from collections.abc import Sequence +from uuid import UUID + +from sqlalchemy import delete, select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core import models_core, schemas_core + + +async def get_schools(db: AsyncSession) -> Sequence[models_core.CoreSchool]: + """Return all schools from database""" + + result = await db.execute(select(models_core.CoreSchool)) + return result.scalars().all() + + +async def get_school_by_id( + db: AsyncSession, + school_id: UUID, +) -> schemas_core.CoreSchool | None: + """Return school with id from database""" + result = ( + ( + await db.execute( + select(models_core.CoreSchool).where( + models_core.CoreSchool.id == school_id, + ), + ) + ) + .scalars() + .first() + ) + return ( + schemas_core.CoreSchool( + name=result.name, + email_regex=result.email_regex, + id=result.id, + ) + if result + else None + ) + + +async def get_school_by_name( + db: AsyncSession, + school_name: str, +) -> models_core.CoreSchool | None: + """Return school with name from database""" + result = await db.execute( + select(models_core.CoreSchool).where( + models_core.CoreSchool.name == school_name, + ), + ) + return result.scalars().first() + + +async def create_school( + school: models_core.CoreSchool, + db: AsyncSession, +) -> None: + """Create a new school in database and return it""" + + db.add(school) + + +async def delete_school(db: AsyncSession, school_id: UUID): + """Delete a school from database by id""" + + await db.execute( + delete(models_core.CoreSchool).where(models_core.CoreSchool.id == school_id), + ) + + +async def update_school( + db: AsyncSession, + school_id: UUID, + school_update: schemas_core.CoreSchoolUpdate, +): + await db.execute( + update(models_core.CoreSchool) + .where(models_core.CoreSchool.id == school_id) + .values(**school_update.model_dump(exclude_none=True)), + ) + await db.commit() diff --git a/app/core/schools/endpoints_schools.py b/app/core/schools/endpoints_schools.py new file mode 100644 index 000000000..1654a0d9e --- /dev/null +++ b/app/core/schools/endpoints_schools.py @@ -0,0 +1,220 @@ +""" +File defining the API itself, using fastAPI and schemas, and calling the cruds functions + +School management is part of the core of Hyperion. These endpoints allow managing schools. +""" + +import re +import uuid + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core import models_core, schemas_core +from app.core.groups.groups_type import AccountType, GroupType +from app.core.schools import cruds_schools +from app.core.schools.schools_type import SchoolType +from app.core.users import cruds_users +from app.dependencies import ( + get_db, + is_user_in, +) + +router = APIRouter(tags=["Schools"]) + + +@router.get( + "/schools/", + response_model=list[schemas_core.CoreSchool], + status_code=200, +) +async def read_schools( + db: AsyncSession = Depends(get_db), +): + """ + Return all schools from database as a list of dictionaries + """ + + schools = await cruds_schools.get_schools(db) + return schools + + +@router.get( + "/schools/{school_id}", + response_model=schemas_core.CoreSchool, + status_code=200, +) +async def read_school( + school_id: uuid.UUID, + db: AsyncSession = Depends(get_db), +): + """ + Return school with id from database as a dictionary. + + **This endpoint is only usable by administrators** + """ + + db_school = await cruds_schools.get_school_by_id(db=db, school_id=school_id) + if db_school is None: + raise HTTPException(status_code=404, detail="School not found") + return db_school + + +@router.post( + "/schools/", + response_model=schemas_core.CoreSchool, + status_code=201, +) +async def create_school( + school: schemas_core.CoreSchoolBase, + db: AsyncSession = Depends(get_db), + user: schemas_core.CoreUser = Depends(is_user_in(GroupType.admin)), +): + """ + Create a new school and add users to it based on the email regex. + + **This endpoint is only usable by administrators** + """ + if ( # We can't have two schools with the same name + await cruds_schools.get_school_by_name(school_name=school.name, db=db) + is not None + ): + raise HTTPException( + status_code=400, + detail=f"A school with the name {school.name} already exist", + ) + + try: + db_school = models_core.CoreSchool( + id=uuid.uuid4(), + name=school.name, + email_regex=school.email_regex, + ) + await cruds_schools.create_school(school=db_school, db=db) + users = await cruds_users.get_users( + db=db, + schools_ids=[SchoolType.no_school.value], + ) + for db_user in users: + if re.match(db_school.email_regex, db_user.email): + await cruds_users.update_user( + db, + db_user.id, + schemas_core.CoreUserUpdateAdmin( + school_id=db_school.id, + account_type=AccountType.other_school_student, + ), + ) + await db.commit() + except IntegrityError: + await db.rollback() + raise + else: + return db_school + + +@router.patch( + "/schools/{school_id}", + status_code=204, +) +async def update_school( + school_id: uuid.UUID, + school_update: schemas_core.CoreSchoolUpdate, + db: AsyncSession = Depends(get_db), + user: schemas_core.CoreUser = Depends(is_user_in(GroupType.admin)), +): + """ + Update the name or the description of a school. + + **This endpoint is only usable by administrators** + """ + school = await cruds_schools.get_school_by_id(db=db, school_id=school_id) + + if not school: + raise HTTPException(status_code=404, detail="School not found") + # If the request ask to update the school name, we need to check it is available + if school_update.name and school_update.name != school.name: + if ( + await cruds_schools.get_school_by_name( + school_name=school_update.name, + db=db, + ) + is not None + ): + raise HTTPException( + status_code=400, + detail=f"A school with the name {school.name} already exist", + ) + await cruds_schools.update_school( + db=db, + school_id=school_id, + school_update=school_update, + ) + if ( + school_update.email_regex is not None + and school_update.email_regex != school.email_regex + ): + await cruds_users.remove_users_from_school(db, school_id=school_id) + try: + await db.commit() + except IntegrityError: + await db.rollback() + raise + users = await cruds_users.get_users( + db, + schools_ids=[SchoolType.no_school.value], + ) + for db_user in users: + if re.match(school_update.email_regex, db_user.email): + await cruds_users.update_user( + db, + db_user.id, + schemas_core.CoreUserUpdateAdmin( + school_id=school.id, + account_type=AccountType.other_school_student, + ), + ) + try: + await db.commit() + except IntegrityError: + await db.rollback() + raise + + +@router.delete( + "/schools/{school_id}", + status_code=204, +) +async def delete_school( + school_id: uuid.UUID, + db: AsyncSession = Depends(get_db), + user: schemas_core.CoreUser = Depends(is_user_in(GroupType.admin)), +): + """ + Delete school from database. + This will remove the school from all users but won't delete any user. + + `SchoolTypes` schools can not be deleted. + + **This endpoint is only usable by administrators** + """ + + if school_id in (SchoolType.list()): + raise HTTPException( + status_code=400, + detail="SchoolTypes schools can not be deleted", + ) + + school = await cruds_schools.get_school_by_id(db=db, school_id=school_id) + if school is None: + raise HTTPException(status_code=404, detail="School not found") + + await cruds_users.remove_users_from_school(db=db, school_id=school_id) + + await cruds_schools.delete_school(db=db, school_id=school_id) + try: + await db.commit() + except IntegrityError: + await db.rollback() + raise diff --git a/app/core/schools/schools_type.py b/app/core/schools/schools_type.py new file mode 100644 index 000000000..ed8ccb468 --- /dev/null +++ b/app/core/schools/schools_type.py @@ -0,0 +1,24 @@ +from enum import Enum +from uuid import UUID + + +class SchoolType(Enum): + """ + In Hyperion, each user must have a school. Belonging to a school gives access to a set of specific endpoints. + + This class defines the basic schools available in Hyperion. + Other schools can be added by the administrator using the API. + """ + + # Account types + no_school = UUID("dce19aa2-8863-4c93-861e-fb7be8f610ed") + centrale_lyon = UUID("d9772da7-1142-4002-8b86-b694b431dfed") + + # Auth related groups + + def __str__(self): + return f"{self.name}<{self.value}>" + + @staticmethod + def list(): + return [school.value for school in SchoolType] diff --git a/app/core/users/cruds_users.py b/app/core/users/cruds_users.py index 7b3c4c560..d6923608b 100644 --- a/app/core/users/cruds_users.py +++ b/app/core/users/cruds_users.py @@ -1,6 +1,7 @@ """File defining the functions called by the endpoints, making queries to the table using the models""" from collections.abc import Sequence +from uuid import UUID from sqlalchemy import ForeignKey, and_, delete, not_, or_, select, update from sqlalchemy.exc import IntegrityError @@ -10,6 +11,7 @@ from app.core import models_core, schemas_core from app.core.groups.groups_type import AccountType +from app.core.schools.schools_type import SchoolType async def count_users(db: AsyncSession) -> int: @@ -25,53 +27,72 @@ async def get_users( excluded_account_types: list[AccountType] | None = None, included_groups: list[str] | None = None, excluded_groups: list[str] | None = None, + schools_ids: list[UUID] | None = None, ) -> Sequence[models_core.CoreUser]: """ Return all users from database. Parameters `excluded_account_types` and `excluded_groups` can be used to filter results. """ - included_account_types = included_account_types or list(AccountType) + included_account_types = included_account_types or None excluded_account_types = excluded_account_types or [] included_groups = included_groups or [] excluded_groups = excluded_groups or [] + schools_ids = schools_ids or None + + # We want, for each group that should be included check if + # - at least one of the user's groups match the expected group + included_group_condition = [ + models_core.CoreUser.groups.any( + models_core.CoreGroup.id == group_id, + ) + for group_id in included_groups + ] + included_account_type_condition = ( + or_( + False, + *[ + models_core.CoreUser.account_type == account_type + for account_type in included_account_types + ], + ) + if included_account_types + else and_(True) + ) + # We want, for each group that should not be included + # check that the following condition is false : + # - at least one of the user's groups match the expected group + excluded_group_condition = [ + not_( + models_core.CoreUser.groups.any( + models_core.CoreGroup.id == group_id, + ), + ) + for group_id in excluded_groups + ] + excluded_account_type_condition = [ + not_( + models_core.CoreUser.account_type == account_type, + ) + for account_type in excluded_account_types + ] + school_condition = ( + or_( + *[models_core.CoreUser.school_id == school_id for school_id in schools_ids], + ) + if schools_ids + else and_(True) + ) result = await db.execute( select(models_core.CoreUser).where( and_( True, - # We want, for each group that should be included check if - # - at least one of the user's groups match the expected group - *[ - models_core.CoreUser.groups.any( - models_core.CoreGroup.id == group_id, - ) - for group_id in included_groups - ], - or_( - False, - *[ - models_core.CoreUser.account_type == account_type - for account_type in included_account_types - ], - ), - *[ - not_( - models_core.CoreUser.account_type == account_type, - ) - for account_type in excluded_account_types - ], - # We want, for each group that should not be included - # check that the following condition is false : - # - at least one of the user's groups match the expected group - *[ - not_( - models_core.CoreUser.groups.any( - models_core.CoreGroup.id == group_id, - ), - ) - for group_id in excluded_groups - ], + *included_group_condition, + included_account_type_condition, + *excluded_account_type_condition, + *excluded_group_condition, + school_condition, ), ), ) @@ -120,7 +141,6 @@ async def update_user( .where(models_core.CoreUser.id == user_id) .values(**user_update.model_dump(exclude_none=True)), ) - await db.commit() async def create_unconfirmed_user( @@ -277,6 +297,22 @@ async def update_user_password_by_id( await db.commit() +async def remove_users_from_school( + db: AsyncSession, + school_id: UUID, +): + await db.execute( + update(models_core.CoreUser) + .where( + models_core.CoreUser.school_id == school_id, + ) + .values( + school_id=SchoolType.no_school.value, + account_type=AccountType.external, + ), + ) + + async def fusion_users( db: AsyncSession, user_kept_id: str, diff --git a/app/core/users/endpoints_users.py b/app/core/users/endpoints_users.py index 9c01c54e2..0a59dacb8 100644 --- a/app/core/users/endpoints_users.py +++ b/app/core/users/endpoints_users.py @@ -1,5 +1,4 @@ import logging -import re import string import uuid from datetime import UTC, datetime, timedelta @@ -18,13 +17,16 @@ ) from fastapi.responses import FileResponse from fastapi.templating import Jinja2Templates +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from app.core import models_core, schemas_core, security, standard_responses from app.core.config import Settings from app.core.groups import cruds_groups from app.core.groups.groups_type import AccountType, GroupType +from app.core.schools.schools_type import SchoolType from app.core.users import cruds_users +from app.core.users.tools_users import get_account_type_and_school_id_from_email from app.dependencies import ( get_db, get_request_id, @@ -50,13 +52,9 @@ templates = Jinja2Templates(directory="assets/templates") -ECL_STAFF_REGEX = r"^[\w\-.]*@(enise\.)?ec-lyon\.fr$" -ECL_STUDENT_REGEX = r"^[\w\-.]*@((etu(-enise)?)|(ecl\d{2}))\.ec-lyon\.fr$" -ECL_FORMER_STUDENT_REGEX = r"^[\w\-.]*@centraliens-lyon\.net$" - @router.get( - "/users/", + "/users", response_model=list[schemas_core.CoreUserSimple], status_code=200, ) @@ -361,46 +359,18 @@ async def activate_user( detail=f"The account with the email {unconfirmed_user.email} is already confirmed", ) - # Check the account type - - # For staff and student - # ^[\w\-.]*@((etu(-enise)?|enise)\.)?ec-lyon\.fr$ - # For staff - # ^[\w\-.]*@(enise\.)?ec-lyon\.fr$ - # For student - # ^[\w\-.]*@etu(-enise)?\.ec-lyon\.fr$ - - # For former students - # ^[\w\-.]*@centraliens-lyon\.net$ - - # All accepted emails - # ^[\w\-.]*@(((etu(-enise)?|enise)\.)?ec-lyon\.fr|centraliens-lyon\.net)$ - - # By default we mark the user as external - # but if it has an ECL email address, we will mark it as member - account_type = AccountType.external - if re.match(ECL_STAFF_REGEX, unconfirmed_user.email): - # Its a staff email address - account_type = AccountType.staff - elif re.match( - ECL_STUDENT_REGEX, - unconfirmed_user.email, - ): - # Its a student email address - account_type = AccountType.student - elif re.match( - ECL_FORMER_STUDENT_REGEX, - unconfirmed_user.email, - ): - # Its a former student email address - account_type = AccountType.former_student - + # Get the account type and school_id from the email + account_type, school_id = await get_account_type_and_school_id_from_email( + email=unconfirmed_user.email, + db=db, + ) # A password should have been provided password_hash = security.get_password_hash(user.password) confirmed_user = models_core.CoreUser( id=unconfirmed_user.id, email=unconfirmed_user.email, + school_id=school_id, account_type=account_type, password_hash=password_hash, name=user.name, @@ -605,16 +575,9 @@ async def migrate_mail( settings: Settings = Depends(get_settings), ): """ - Due to a change in the email format, all student users need to migrate their email address. This endpoint will send a confirmation code to the user's new email address. He will need to use this code to confirm the change with `/users/confirm-mail-migration` endpoint. """ - if not re.match(ECL_STUDENT_REGEX, mail_migration.new_email): - raise HTTPException( - status_code=400, - detail="The new email address must match the new ECL format for student users", - ) - existing_user = await cruds_users.get_user_by_email( db=db, email=mail_migration.new_email, @@ -635,6 +598,17 @@ async def migrate_mail( ) return + # We need to make sur the user will keep the same school if he is not a no_school user + _, new_school_id = await get_account_type_and_school_id_from_email( + email=mail_migration.new_email, + db=db, + ) + if user.school_id is not SchoolType.no_school and user.school_id != new_school_id: + raise HTTPException( + status_code=400, + detail="New email address is not compatible with the current school", + ) + await create_and_send_email_migration( user_id=user.id, new_email=mail_migration.new_email, @@ -654,8 +628,8 @@ async def migrate_mail_confirm( db: AsyncSession = Depends(get_db), ): """ - Due to a change in the email format, all student users need to migrate their email address. This endpoint will updates the user new email address. + The user will need to use the confirmation code sent by the `/users/migrate-mail` endpoint. """ migration_object = await cruds_users.get_email_migration_code_by_token( @@ -692,17 +666,31 @@ async def migrate_mail_confirm( detail="User not found", ) + account, new_school_id = await get_account_type_and_school_id_from_email( + email=migration_object.new_email, + db=db, + ) try: await cruds_users.update_user( - db, - migration_object.user_id, + db=db, + user_id=migration_object.user_id, user_update=schemas_core.CoreUserUpdateAdmin( email=migration_object.new_email, - account_type=AccountType.student, + account_type=account, + school_id=new_school_id, ), ) + await db.commit() + except Exception as error: + await db.rollback() raise HTTPException(status_code=400, detail=str(error)) + except IntegrityError: + await db.rollback() + raise HTTPException( + status_code=400, + detail="Email migration failed due to database integrity error", + ) await cruds_users.delete_email_migration_code_by_token( confirmation_token=token, @@ -824,6 +812,14 @@ async def update_current_user( """ await cruds_users.update_user(db=db, user_id=user.id, user_update=user_update) + try: + await db.commit() + except IntegrityError: + await db.rollback() + raise HTTPException( + status_code=400, + detail="Update failed due to database integrity error", + ) @router.post( @@ -894,6 +890,14 @@ async def update_user( raise HTTPException(status_code=404, detail="User not found") await cruds_users.update_user(db=db, user_id=user_id, user_update=user_update) + try: + await db.commit() + except IntegrityError: + await db.rollback() + raise HTTPException( + status_code=400, + detail="Update failed due to database integrity error", + ) @router.post( diff --git a/app/core/users/tools_users.py b/app/core/users/tools_users.py new file mode 100644 index 000000000..2caaa606d --- /dev/null +++ b/app/core/users/tools_users.py @@ -0,0 +1,35 @@ +import re +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.groups.groups_type import AccountType +from app.core.schools import cruds_schools +from app.core.schools.schools_type import SchoolType + +ECL_STAFF_REGEX = r"^[\w\-.]*@(enise\.)?ec-lyon\.fr$" +ECL_STUDENT_REGEX = r"^[\w\-.]*@((etu(-enise)?)|(ecl\d{2}))\.ec-lyon\.fr$" +ECL_FORMER_STUDENT_REGEX = r"^[\w\-.]*@centraliens-lyon\.net$" + + +async def get_account_type_and_school_id_from_email( + email: str, + db: AsyncSession, +) -> tuple[AccountType, UUID]: + """Return the account type from the email""" + if re.match(ECL_STAFF_REGEX, email): + return AccountType.staff, SchoolType.centrale_lyon.value + if re.match(ECL_STUDENT_REGEX, email): + return AccountType.student, SchoolType.centrale_lyon.value + if re.match(ECL_FORMER_STUDENT_REGEX, email): + return AccountType.former_student, SchoolType.centrale_lyon.value + schools = await cruds_schools.get_schools(db) + + schools = [school for school in schools if school.id not in SchoolType.list()] + school = next( + (school for school in schools if re.match(school.email_regex, email)), + None, + ) + if school: + return AccountType.other_school_student, school.id + return AccountType.external, SchoolType.no_school.value diff --git a/app/dependencies.py b/app/dependencies.py index 1531408f9..8c9bcfd9e 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -33,7 +33,10 @@ async def get_users(db: AsyncSession = Depends(get_db)): from app.utils.auth import auth_utils from app.utils.communication.notifications import NotificationManager, NotificationTool from app.utils.redis import connect -from app.utils.tools import is_user_external, is_user_member_of_an_allowed_group +from app.utils.tools import ( + is_user_external, + is_user_member_of_any_group, +) # We could maybe use hyperion.security hyperion_access_logger = logging.getLogger("hyperion.access") @@ -332,12 +335,12 @@ def is_user( status_code=403, detail="Unauthorized, user is an external user", ) - if is_user_member_of_an_allowed_group(user, excluded_groups): + if is_user_member_of_any_group(user, excluded_groups): raise HTTPException( status_code=403, detail=f"Unauthorized, user is a member of any of the groups {excluded_groups}", ) - if included_groups is not None and not is_user_member_of_an_allowed_group( + if included_groups is not None and not is_user_member_of_any_group( user, included_groups, ): diff --git a/app/modules/advert/endpoints_advert.py b/app/modules/advert/endpoints_advert.py index a52958cc4..1c5740b4a 100644 --- a/app/modules/advert/endpoints_advert.py +++ b/app/modules/advert/endpoints_advert.py @@ -24,7 +24,7 @@ from app.utils.tools import ( get_file_from_data, is_group_id_valid, - is_user_member_of_an_allowed_group, + is_user_member_of_any_group, save_file_as_data, ) @@ -246,7 +246,7 @@ async def create_advert( detail="Invalid advertiser_id", ) - if not is_user_member_of_an_allowed_group(user, [advertiser.group_manager_id]): + if not is_user_member_of_any_group(user, [advertiser.group_manager_id]): raise HTTPException( status_code=403, detail=f"Unauthorized to manage {advertiser.name} adverts", @@ -301,7 +301,7 @@ async def update_advert( detail="Invalid advert_id", ) - if not is_user_member_of_an_allowed_group( + if not is_user_member_of_any_group( user, [advert.advertiser.group_manager_id], ): @@ -338,7 +338,7 @@ async def delete_advert( detail="Invalid advert_id", ) - if not is_user_member_of_an_allowed_group( + if not is_user_member_of_any_group( user, [GroupType.admin, advert.advertiser.group_manager_id], ): diff --git a/app/modules/amap/endpoints_amap.py b/app/modules/amap/endpoints_amap.py index c797eec55..b1238a174 100644 --- a/app/modules/amap/endpoints_amap.py +++ b/app/modules/amap/endpoints_amap.py @@ -25,7 +25,7 @@ from app.types.module import Module from app.utils.communication.notifications import NotificationTool from app.utils.redis import locker_get, locker_set -from app.utils.tools import is_user_member_of_an_allowed_group +from app.utils.tools import is_user_member_of_any_group module = Module( root="amap", @@ -415,8 +415,7 @@ async def add_order_to_delievery( raise HTTPException(status_code=400, detail="Invalid request") if not ( - user.id == order.user_id - or is_user_member_of_an_allowed_group(user, [GroupType.amap]) + user.id == order.user_id or is_user_member_of_any_group(user, [GroupType.amap]) ): raise HTTPException( status_code=403, @@ -537,7 +536,7 @@ async def edit_order_from_delivery( if not ( user.id == previous_order.user_id - or is_user_member_of_an_allowed_group(user, [GroupType.amap]) + or is_user_member_of_any_group(user, [GroupType.amap]) ): raise HTTPException( status_code=403, @@ -637,7 +636,7 @@ async def remove_order( **A member of the group AMAP can delete orders of other users** """ - is_user_admin = is_user_member_of_an_allowed_group(user, [GroupType.amap]) + is_user_admin = is_user_member_of_any_group(user, [GroupType.amap]) order = await cruds_amap.get_order_by_id(db=db, order_id=order_id) if not order: raise HTTPException(status_code=404, detail="No order found") @@ -830,9 +829,7 @@ async def get_cash_by_id( if user_db is None: raise HTTPException(status_code=404, detail="User not found") - if not ( - user_id == user.id or is_user_member_of_an_allowed_group(user, [GroupType.amap]) - ): + if not (user_id == user.id or is_user_member_of_any_group(user, [GroupType.amap])): raise HTTPException( status_code=403, detail="Users that are not member of the group AMAP can only access the endpoint for their own user_id.", @@ -967,9 +964,7 @@ async def get_orders_of_user( if not user_requested: raise HTTPException(status_code=404, detail="User not found") - if not ( - user_id == user.id or is_user_member_of_an_allowed_group(user, [GroupType.amap]) - ): + if not (user_id == user.id or is_user_member_of_any_group(user, [GroupType.amap])): raise HTTPException( status_code=403, detail="Users that are not member of the group AMAP can only access the endpoint for their own user_id.", diff --git a/app/modules/booking/endpoints_booking.py b/app/modules/booking/endpoints_booking.py index f5d06bd19..e80cd7529 100644 --- a/app/modules/booking/endpoints_booking.py +++ b/app/modules/booking/endpoints_booking.py @@ -20,7 +20,7 @@ from app.modules.booking.types_booking import Decision from app.types.module import Module from app.utils.communication.notifications import NotificationTool -from app.utils.tools import is_group_id_valid, is_user_member_of_an_allowed_group +from app.utils.tools import is_group_id_valid, is_user_member_of_any_group module = Module( root="booking", @@ -312,7 +312,7 @@ async def edit_booking( if not ( (user.id == booking.applicant_id and booking.decision == Decision.pending) - or is_user_member_of_an_allowed_group(user, [booking.room.manager.group_id]) + or is_user_member_of_any_group(user, [booking.room.manager.group_id]) ): raise HTTPException( status_code=403, @@ -350,7 +350,7 @@ async def confirm_booking( booking_id=booking_id, ) - if is_user_member_of_an_allowed_group(user, [booking.room.manager.group_id]): + if is_user_member_of_any_group(user, [booking.room.manager.group_id]): await cruds_booking.confirm_booking( booking_id=booking_id, decision=decision, diff --git a/app/modules/calendar/endpoints_calendar.py b/app/modules/calendar/endpoints_calendar.py index e980993c1..233607c75 100644 --- a/app/modules/calendar/endpoints_calendar.py +++ b/app/modules/calendar/endpoints_calendar.py @@ -11,7 +11,7 @@ from app.modules.calendar import cruds_calendar, models_calendar, schemas_calendar from app.modules.calendar.types_calendar import Decision from app.types.module import Module -from app.utils.tools import is_user_member_of_an_allowed_group +from app.utils.tools import is_user_member_of_any_group module = Module( root="event", @@ -69,7 +69,7 @@ async def get_applicant_bookings( **Usable by the user or admins** """ - if user.id == applicant_id or is_user_member_of_an_allowed_group( + if user.id == applicant_id or is_user_member_of_any_group( user, [GroupType.BDE], ): @@ -163,7 +163,7 @@ async def edit_bookings_id( if event is not None and not ( (user.id == event.applicant_id and event.decision == Decision.pending) - or is_user_member_of_an_allowed_group(user, [GroupType.BDE]) + or is_user_member_of_any_group(user, [GroupType.BDE]) ): raise HTTPException( status_code=403, @@ -212,7 +212,7 @@ async def delete_bookings_id( if event is not None and ( (user.id == event.applicant_id and event.decision == Decision.pending) - or is_user_member_of_an_allowed_group(user, [GroupType.BDE]) + or is_user_member_of_any_group(user, [GroupType.BDE]) ): await cruds_calendar.delete_event(event_id=event_id, db=db) diff --git a/app/modules/campaign/endpoints_campaign.py b/app/modules/campaign/endpoints_campaign.py index 3ac902a1c..5aeb89715 100644 --- a/app/modules/campaign/endpoints_campaign.py +++ b/app/modules/campaign/endpoints_campaign.py @@ -24,7 +24,7 @@ from app.types.module import Module from app.utils.tools import ( get_file_from_data, - is_user_member_of_an_allowed_group, + is_user_member_of_any_group, save_file_as_data, ) @@ -54,7 +54,7 @@ async def get_sections( voters = await cruds_campaign.get_voters(db) voters_groups = [voter.group_id for voter in voters] voters_groups.append(GroupType.CAA) - if not is_user_member_of_an_allowed_group(user, voters_groups): + if not is_user_member_of_any_group(user, voters_groups): raise HTTPException( status_code=403, detail="Access forbidden : you are not a poll member", @@ -142,7 +142,7 @@ async def get_lists( voters = await cruds_campaign.get_voters(db) voters_groups = [voter.group_id for voter in voters] voters_groups.append(GroupType.CAA) - if not is_user_member_of_an_allowed_group(user, voters_groups): + if not is_user_member_of_any_group(user, voters_groups): raise HTTPException( status_code=403, detail="Access forbidden : you are not a poll member", @@ -607,7 +607,7 @@ async def vote( """ voters = await cruds_campaign.get_voters(db) voters_groups = [voter.group_id for voter in voters] - if not is_user_member_of_an_allowed_group(user, voters_groups): + if not is_user_member_of_any_group(user, voters_groups): raise HTTPException( status_code=403, detail="Access forbidden : you are not a poll member", @@ -674,7 +674,7 @@ async def get_sections_already_voted( """ voters = await cruds_campaign.get_voters(db) voters_groups = [voter.group_id for voter in voters] - if not is_user_member_of_an_allowed_group(user, voters_groups): + if not is_user_member_of_any_group(user, voters_groups): raise HTTPException( status_code=403, detail="Access forbidden : you are not a poll member", @@ -709,7 +709,7 @@ async def get_results( voters = await cruds_campaign.get_voters(db) voters_groups = [voter.group_id for voter in voters] voters_groups.append(GroupType.CAA) - if not is_user_member_of_an_allowed_group(user, voters_groups): + if not is_user_member_of_any_group(user, voters_groups): raise HTTPException( status_code=403, detail="Access forbidden : you are not a poll member", @@ -719,7 +719,7 @@ async def get_results( if ( status == StatusType.counting - and is_user_member_of_an_allowed_group(user, [GroupType.CAA]) + and is_user_member_of_any_group(user, [GroupType.CAA]) ) or status == StatusType.published: votes = await cruds_campaign.get_votes(db=db) @@ -764,7 +764,7 @@ async def get_status_vote( voters = await cruds_campaign.get_voters(db) voters_groups = [voter.group_id for voter in voters] voters_groups.append(GroupType.CAA) - if not is_user_member_of_an_allowed_group(user, voters_groups): + if not is_user_member_of_any_group(user, voters_groups): raise HTTPException( status_code=403, detail="Access forbidden : you are not a poll member", @@ -864,7 +864,7 @@ async def read_campaigns_logo( voters = await cruds_campaign.get_voters(db) voters_groups = [voter.group_id for voter in voters] voters_groups.append(GroupType.CAA) - if not is_user_member_of_an_allowed_group(user, voters_groups): + if not is_user_member_of_any_group(user, voters_groups): raise HTTPException( status_code=403, detail="Access forbidden : you are not a poll member", diff --git a/app/modules/cdr/endpoints_cdr.py b/app/modules/cdr/endpoints_cdr.py index 5be87ee05..17f13392c 100644 --- a/app/modules/cdr/endpoints_cdr.py +++ b/app/modules/cdr/endpoints_cdr.py @@ -12,6 +12,7 @@ WebSocket, ) from fastapi.responses import FileResponse +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from app.core import models_core, schemas_core @@ -54,7 +55,7 @@ from app.utils.tools import ( create_and_send_email_migration, get_core_data, - is_user_member_of_an_allowed_group, + is_user_member_of_any_group, set_core_data, ) @@ -83,7 +84,7 @@ async def get_cdr_users( **User must be part of a seller group to use this endpoint** """ if not ( - is_user_member_of_an_allowed_group(user, [GroupType.admin_cdr]) + is_user_member_of_any_group(user, [GroupType.admin_cdr]) or await cruds_cdr.get_sellers_by_group_ids( db=db, group_ids=[g.id for g in user.groups], @@ -117,7 +118,7 @@ async def get_cdr_users_pending_validation( **User must be part of a seller group to use this endpoint** """ if not ( - is_user_member_of_an_allowed_group(user, [GroupType.admin_cdr]) + is_user_member_of_any_group(user, [GroupType.admin_cdr]) or await cruds_cdr.get_sellers_by_group_ids( db=db, group_ids=[g.id for g in user.groups], @@ -142,6 +143,7 @@ async def get_cdr_users_pending_validation( return [ schemas_cdr.CdrUser( account_type=user.account_type, + school_id=user.school_id, curriculum=curriculum_memberships_mapping.get(user.id, None), promo=user.promo, email=user.email, @@ -174,7 +176,7 @@ async def get_cdr_user( """ if user.id != user_id: if not ( - is_user_member_of_an_allowed_group(user, [GroupType.admin_cdr]) + is_user_member_of_any_group(user, [GroupType.admin_cdr]) or await cruds_cdr.get_sellers_by_group_ids( db=db, group_ids=[g.id for g in user.groups], @@ -195,6 +197,7 @@ async def get_cdr_user( return schemas_cdr.CdrUser( account_type=user_db.account_type, + school_id=user_db.school_id, name=user_db.name, firstname=user_db.firstname, nickname=user_db.nickname, @@ -280,9 +283,13 @@ async def update_cdr_user( floor=user_update.floor, ), ) + await db.commit() except Exception: await db.rollback() raise + except IntegrityError: + await db.rollback() + raise user_db = await get_user_by_id(db, user_id) if not user_db: @@ -302,6 +309,7 @@ async def update_cdr_user( curriculum=schemas_cdr.CurriculumComplete( **curriculum.__dict__, ), + school_id=user_db.school_id, account_type=user_db.account_type, name=user_db.name, firstname=user_db.firstname, @@ -354,7 +362,7 @@ async def get_sellers_by_user_id( **User must be authenticated to use this endpoint** """ - if is_user_member_of_an_allowed_group( + if is_user_member_of_any_group( user=user, allowed_groups=[GroupType.admin_cdr], ): @@ -554,7 +562,7 @@ async def get_all_products( db, [x.id for x in user.groups], ) - if not (sellers or is_user_member_of_an_allowed_group(user, [GroupType.admin_cdr])): + if not (sellers or is_user_member_of_any_group(user, [GroupType.admin_cdr])): raise HTTPException( status_code=403, detail="You must be a seller to get all documents.", @@ -1167,7 +1175,7 @@ async def get_all_sellers_documents( db, [x.id for x in user.groups], ) - if not (sellers or is_user_member_of_an_allowed_group(user, [GroupType.admin_cdr])): + if not (sellers or is_user_member_of_any_group(user, [GroupType.admin_cdr])): raise HTTPException( status_code=403, detail="You must be a seller to get all documents.", @@ -1267,8 +1275,7 @@ async def get_purchases_by_user_id( **User must get his own purchases or be CDR Admin to use this endpoint** """ if not ( - user_id == user.id - or is_user_member_of_an_allowed_group(user, [GroupType.admin_cdr]) + user_id == user.id or is_user_member_of_any_group(user, [GroupType.admin_cdr]) ): raise HTTPException( status_code=403, @@ -1842,8 +1849,7 @@ async def get_signatures_by_user_id( **User must get his own signatures or be CDR Admin to use this endpoint** """ if not ( - user_id == user.id - or is_user_member_of_an_allowed_group(user, [GroupType.admin_cdr]) + user_id == user.id or is_user_member_of_any_group(user, [GroupType.admin_cdr]) ): raise HTTPException( status_code=403, @@ -1919,7 +1925,7 @@ async def create_signature( user_id == user.id and signature.signature_type == DocumentSignatureType.numeric ) - or is_user_member_of_an_allowed_group(user=user, allowed_groups=sellers_groups) + or is_user_member_of_any_group(user=user, allowed_groups=sellers_groups) ): raise HTTPException( status_code=403, @@ -2089,8 +2095,7 @@ async def create_curriculum_membership( **User must add a curriculum to themself or be CDR Admin to use this endpoint** """ if not ( - user_id == user.id - or is_user_member_of_an_allowed_group(user, [GroupType.admin_cdr]) + user_id == user.id or is_user_member_of_any_group(user, [GroupType.admin_cdr]) ): raise HTTPException( status_code=403, @@ -2151,6 +2156,7 @@ async def create_curriculum_membership( message=schemas_cdr.NewUserWSMessageModel( data=schemas_cdr.CdrUser( account_type=db_user.account_type, + school_id=db_user.school_id, curriculum=schemas_cdr.CurriculumComplete( id=wanted_curriculum.id, name=wanted_curriculum.name, @@ -2191,8 +2197,7 @@ async def update_curriculum_membership( **User must add a curriculum to themself or be CDR Admin to use this endpoint** """ if not ( - user_id == user.id - or is_user_member_of_an_allowed_group(user, [GroupType.admin_cdr]) + user_id == user.id or is_user_member_of_any_group(user, [GroupType.admin_cdr]) ): raise HTTPException( status_code=403, @@ -2231,6 +2236,7 @@ async def update_curriculum_membership( message=schemas_cdr.UpdateUserWSMessageModel( data=schemas_cdr.CdrUser( account_type=db_user.account_type, + school_id=db_user.school_id, curriculum=schemas_cdr.CurriculumComplete( id=curriculum.id, name=curriculum.name, @@ -2280,8 +2286,7 @@ async def delete_curriculum_membership( detail="Invalid curriculum_id", ) if not ( - user_id == user.id - or is_user_member_of_an_allowed_group(user, [GroupType.admin_cdr]) + user_id == user.id or is_user_member_of_any_group(user, [GroupType.admin_cdr]) ): raise HTTPException( status_code=403, @@ -2311,6 +2316,7 @@ async def delete_curriculum_membership( message=schemas_cdr.UpdateUserWSMessageModel( data=schemas_cdr.CdrUser( account_type=db_user.account_type, + school_id=db_user.school_id, curriculum=None, promo=db_user.promo, email=db_user.email, @@ -2347,8 +2353,7 @@ async def get_payments_by_user_id( **User must get his own payments or be CDR Admin to use this endpoint** """ if not ( - user_id == user.id - or is_user_member_of_an_allowed_group(user, [GroupType.admin_cdr]) + user_id == user.id or is_user_member_of_any_group(user, [GroupType.admin_cdr]) ): raise HTTPException( status_code=403, @@ -2484,6 +2489,7 @@ async def get_payment_url( ) user_schema = schemas_core.CoreUser( account_type=user.account_type, + school_id=user.school_id, email=user.email, birthday=user.birthday, promo=user.promo, @@ -2535,8 +2541,7 @@ async def get_memberships_by_user_id( user: models_core.CoreUser = Depends(is_user()), ): if not ( - user_id == user.id - or is_user_member_of_an_allowed_group(user, [GroupType.admin_cdr]) + user_id == user.id or is_user_member_of_any_group(user, [GroupType.admin_cdr]) ): raise HTTPException( status_code=403, @@ -2709,8 +2714,7 @@ async def get_tickets_of_user( user: models_core.CoreUser = Depends(is_user()), ): if not ( - is_user_member_of_an_allowed_group(user, [GroupType.admin_cdr]) - or user_id == user.id + is_user_member_of_any_group(user, [GroupType.admin_cdr]) or user_id == user.id ): raise HTTPException( status_code=403, diff --git a/app/modules/cdr/utils_cdr.py b/app/modules/cdr/utils_cdr.py index b80c45640..cbd658fce 100644 --- a/app/modules/cdr/utils_cdr.py +++ b/app/modules/cdr/utils_cdr.py @@ -20,7 +20,7 @@ PaymentType, ) from app.utils.tools import ( - is_user_member_of_an_allowed_group, + is_user_member_of_any_group, ) hyperion_error_logger = logging.getLogger("hyperion.error") @@ -81,7 +81,7 @@ async def is_user_in_a_seller_group( detail="Seller not found.", ) - if is_user_member_of_an_allowed_group( + if is_user_member_of_any_group( user=user, allowed_groups=[str(seller.group_id), GroupType.admin_cdr], ): diff --git a/app/modules/loan/endpoints_loan.py b/app/modules/loan/endpoints_loan.py index 9d327590e..e93e7f269 100644 --- a/app/modules/loan/endpoints_loan.py +++ b/app/modules/loan/endpoints_loan.py @@ -23,7 +23,7 @@ from app.utils.tools import ( is_group_id_valid, is_user_id_valid, - is_user_member_of_an_allowed_group, + is_user_member_of_any_group, ) if TYPE_CHECKING: @@ -186,7 +186,7 @@ async def get_loans_by_loaner( ) # The user should be a member of the loaner's manager group - if not is_user_member_of_an_allowed_group(user, [loaner.group_manager_id]): + if not is_user_member_of_any_group(user, [loaner.group_manager_id]): raise HTTPException( status_code=403, detail=f"Unauthorized to manage {loaner_id} loaner", @@ -244,7 +244,7 @@ async def get_items_by_loaner( detail="Invalid loaner_id", ) # The user should be a member of the loaner's manager group - if not is_user_member_of_an_allowed_group(user, [loaner.group_manager_id]): + if not is_user_member_of_any_group(user, [loaner.group_manager_id]): raise HTTPException( status_code=403, detail=f"Unauthorized to manage {loaner_id} loaner", @@ -288,7 +288,7 @@ async def create_items_for_loaner( detail="Invalid loaner_id", ) # The user should be a member of the loaner's manager group - if not is_user_member_of_an_allowed_group(user, [loaner.group_manager_id]): + if not is_user_member_of_any_group(user, [loaner.group_manager_id]): raise HTTPException( status_code=403, detail=f"Unauthorized to manage {loaner_id} loaner", @@ -366,7 +366,7 @@ async def update_items_for_loaner( detail=f"Item {item_id} does not belong to {loaner_id} loaner", ) # The user should be a member of the loaner's manager group - if not is_user_member_of_an_allowed_group(user, [loaner.group_manager_id]): + if not is_user_member_of_any_group(user, [loaner.group_manager_id]): raise HTTPException( status_code=403, detail=f"Unauthorized to manage {loaner_id} loaner", @@ -407,7 +407,7 @@ async def delete_loaner_item( detail=f"Item {item_id} does not belong to {loaner_id} loaner", ) # The user should be a member of the loaner's manager group - if not is_user_member_of_an_allowed_group(user, [item.loaner.group_manager_id]): + if not is_user_member_of_any_group(user, [item.loaner.group_manager_id]): raise HTTPException( status_code=403, detail=f"Unauthorized to manage {loaner_id} loaner", @@ -482,7 +482,7 @@ async def get_current_user_loaners( user_loaners: list[models_loan.Loaner] = [ loaner for loaner in existing_loaners - if is_user_member_of_an_allowed_group( + if is_user_member_of_any_group( allowed_groups=[loaner.group_manager_id], user=user, ) @@ -520,7 +520,7 @@ async def create_loan( detail="Invalid loaner_id", ) # The user should be a member of the loaner's manager group - if not is_user_member_of_an_allowed_group(user, [loaner.group_manager_id]): + if not is_user_member_of_any_group(user, [loaner.group_manager_id]): raise HTTPException( status_code=403, detail=f"Unauthorized to manage {loan_creation.loaner_id} loaner", @@ -676,7 +676,7 @@ async def update_loan( ) # The user should be a member of the loaner's manager group - if not is_user_member_of_an_allowed_group(user, [loan.loaner.group_manager_id]): + if not is_user_member_of_any_group(user, [loan.loaner.group_manager_id]): raise HTTPException( status_code=403, detail=f"Unauthorized to manage {loan.loaner_id} loaner", @@ -796,7 +796,7 @@ async def delete_loan( ) # The user should be a member of the loaner's manager group - if not is_user_member_of_an_allowed_group(user, [loan.loaner.group_manager_id]): + if not is_user_member_of_any_group(user, [loan.loaner.group_manager_id]): raise HTTPException( status_code=403, detail=f"Unauthorized to manage {loan.loaner_id} loaner", @@ -848,7 +848,7 @@ async def return_loan( ) # The user should be a member of the loaner's manager group - if not is_user_member_of_an_allowed_group(user, [loan.loaner.group_manager_id]): + if not is_user_member_of_any_group(user, [loan.loaner.group_manager_id]): raise HTTPException( status_code=403, detail=f"Unauthorized to manage {loan.loaner_id} loaner", @@ -909,7 +909,7 @@ async def extend_loan( ) end = loan.end # The user should be a member of the loaner's manager group - if not is_user_member_of_an_allowed_group(user, [loan.loaner.group_manager_id]): + if not is_user_member_of_any_group(user, [loan.loaner.group_manager_id]): raise HTTPException( status_code=403, detail=f"Unauthorized to manage {loan.loaner_id} loaner", diff --git a/app/modules/phonebook/endpoints_phonebook.py b/app/modules/phonebook/endpoints_phonebook.py index 4ef86d34b..b23a20044 100644 --- a/app/modules/phonebook/endpoints_phonebook.py +++ b/app/modules/phonebook/endpoints_phonebook.py @@ -21,7 +21,7 @@ from app.types.module import Module from app.utils.tools import ( get_file_from_data, - is_user_member_of_an_allowed_group, + is_user_member_of_any_group, save_file_as_data, ) @@ -115,7 +115,7 @@ async def create_association( **This endpoint is only usable by CAA, BDE** """ - if not is_user_member_of_an_allowed_group( + if not is_user_member_of_any_group( user=user, allowed_groups=[GroupType.CAA, GroupType.BDE], ): @@ -164,7 +164,7 @@ async def update_association( **This endpoint is only usable by CAA, BDE and association's president** """ if not ( - is_user_member_of_an_allowed_group( + is_user_member_of_any_group( user=user, allowed_groups=[GroupType.CAA, GroupType.BDE], ) @@ -225,7 +225,7 @@ async def deactivate_association( **This endpoint is only usable by CAA and BDE** """ - if not is_user_member_of_an_allowed_group( + if not is_user_member_of_any_group( user=user, allowed_groups=[GroupType.CAA, GroupType.BDE], ): @@ -253,7 +253,7 @@ async def delete_association( **This endpoint is only usable by CAA and BDE** """ - if not is_user_member_of_an_allowed_group( + if not is_user_member_of_any_group( user=user, allowed_groups=[GroupType.CAA, GroupType.BDE], ): @@ -405,7 +405,7 @@ async def create_membership( ) if not ( - is_user_member_of_an_allowed_group( + is_user_member_of_any_group( user=user, allowed_groups=[GroupType.CAA, GroupType.BDE], ) @@ -423,7 +423,7 @@ async def create_membership( if membership.role_tags is not None: if RoleTags.president.value in membership.role_tags.split( ";", - ) and not is_user_member_of_an_allowed_group( + ) and not is_user_member_of_any_group( user=user, allowed_groups=[GroupType.CAA, GroupType.BDE], ): @@ -510,7 +510,7 @@ async def update_membership( ) if not ( - is_user_member_of_an_allowed_group( + is_user_member_of_any_group( user=user, allowed_groups=[GroupType.CAA, GroupType.BDE], ) @@ -528,7 +528,7 @@ async def update_membership( if updated_membership.role_tags is not None: if RoleTags.president.value in updated_membership.role_tags.split( ";", - ) and not is_user_member_of_an_allowed_group( + ) and not is_user_member_of_any_group( user=user, allowed_groups=[GroupType.CAA, GroupType.BDE], ): @@ -573,7 +573,7 @@ async def delete_membership( ) if not ( - is_user_member_of_an_allowed_group( + is_user_member_of_any_group( user=user, allowed_groups=[GroupType.CAA, GroupType.BDE], ) @@ -618,7 +618,7 @@ async def create_association_logo( **The user must be a member of the group CAA or BDE to use this endpoint** """ - if not is_user_member_of_an_allowed_group( + if not is_user_member_of_any_group( user=user, allowed_groups=[GroupType.CAA, GroupType.BDE], ) and not await cruds_phonebook.is_user_president( diff --git a/app/modules/raffle/endpoints_raffle.py b/app/modules/raffle/endpoints_raffle.py index 3091b1a0e..ba8a268e2 100644 --- a/app/modules/raffle/endpoints_raffle.py +++ b/app/modules/raffle/endpoints_raffle.py @@ -26,7 +26,7 @@ from app.utils.tools import ( get_display_name, get_file_from_data, - is_user_member_of_an_allowed_group, + is_user_member_of_any_group, save_file_as_data, ) @@ -101,7 +101,7 @@ async def edit_raffle( if not raffle: raise HTTPException(status_code=404, detail="Raffle not found") - if not is_user_member_of_an_allowed_group(user, [raffle.group_id]): + if not is_user_member_of_any_group(user, [raffle.group_id]): raise HTTPException( status_code=403, detail=f"{user.id} user is unauthorized to manage the raffle {raffle_id}", @@ -138,7 +138,7 @@ async def delete_raffle( if not raffle: raise HTTPException(status_code=404, detail="Raffle not found") - if not is_user_member_of_an_allowed_group(user, [raffle.group_id]): + if not is_user_member_of_any_group(user, [raffle.group_id]): raise HTTPException( status_code=403, detail=f"{user.id} user is unauthorized to manage the raffle {raffle_id}", @@ -219,7 +219,7 @@ async def create_current_raffle_logo( if not raffle: raise HTTPException(status_code=404, detail="Raffle not found") - if not is_user_member_of_an_allowed_group(user, [raffle.group_id]): + if not is_user_member_of_any_group(user, [raffle.group_id]): raise HTTPException( status_code=403, detail=f"{user.id} user is unauthorized to manage the raffle {raffle_id}", @@ -306,7 +306,7 @@ async def create_packticket( if not raffle: raise HTTPException(status_code=404, detail="Raffle not found") - if not is_user_member_of_an_allowed_group(user, [raffle.group_id]): + if not is_user_member_of_any_group(user, [raffle.group_id]): raise HTTPException( status_code=403, detail=f"{user.id} user is unauthorized to manage the raffle {packticket.raffle_id}", @@ -349,7 +349,7 @@ async def edit_packticket( if not raffle: raise HTTPException(status_code=404, detail="Raffle not found") - if not is_user_member_of_an_allowed_group(user, [raffle.group_id]): + if not is_user_member_of_any_group(user, [raffle.group_id]): raise HTTPException( status_code=403, detail=f"{user.id} user is unauthorized to manage the raffle {packticket.raffle_id}", @@ -394,7 +394,7 @@ async def delete_packticket( if not raffle: raise HTTPException(status_code=404, detail="Raffle not found") - if not is_user_member_of_an_allowed_group(user, [raffle.group_id]): + if not is_user_member_of_any_group(user, [raffle.group_id]): raise HTTPException( status_code=403, detail=f"{user.id} user is unauthorized to manage the raffle {packticket.raffle_id}", @@ -540,10 +540,7 @@ async def get_tickets_by_userid( if user_db is None: raise HTTPException(status_code=404, detail="User not found") - if not ( - user_id == user.id - or is_user_member_of_an_allowed_group(user, [GroupType.admin]) - ): + if not (user_id == user.id or is_user_member_of_any_group(user, [GroupType.admin])): raise HTTPException( status_code=403, detail="Users that are not member of the group admin can only access the endpoint for their own user_id.", @@ -580,7 +577,7 @@ async def get_tickets_by_raffleid( if not raffle: raise HTTPException(status_code=404, detail="Raffle not found") - if not is_user_member_of_an_allowed_group(user, [raffle.group_id]): + if not is_user_member_of_any_group(user, [raffle.group_id]): raise HTTPException( status_code=403, detail=f"{user.id} user is unauthorized to manage the raffle {raffle_id}", @@ -624,7 +621,7 @@ async def create_prize( if not raffle: raise HTTPException(status_code=404, detail="Raffle not found") - if not is_user_member_of_an_allowed_group(user, [raffle.group_id]): + if not is_user_member_of_any_group(user, [raffle.group_id]): raise HTTPException( status_code=403, detail=f"{user.id} user is unauthorized to manage the raffle {raffle.id}", @@ -667,7 +664,7 @@ async def edit_prize( if not raffle: raise HTTPException(status_code=404, detail="Raffle not found") - if not is_user_member_of_an_allowed_group(user, [raffle.group_id]): + if not is_user_member_of_any_group(user, [raffle.group_id]): raise HTTPException( status_code=403, detail=f"{user.id} user is unauthorized to manage the raffle {raffle.id}", @@ -706,7 +703,7 @@ async def delete_prize( if not raffle: raise HTTPException(status_code=404, detail="Raffle not found") - if not is_user_member_of_an_allowed_group(user, [raffle.group_id]): + if not is_user_member_of_any_group(user, [raffle.group_id]): raise HTTPException( status_code=403, detail=f"{user.id} user is unauthorized to manage the raffle {raffle.id}", @@ -760,7 +757,7 @@ async def create_prize_picture( if not raffle: raise HTTPException(status_code=404, detail="Raffle not found") - if not is_user_member_of_an_allowed_group(user, [raffle.group_id]): + if not is_user_member_of_any_group(user, [raffle.group_id]): raise HTTPException( status_code=403, detail=f"{user.id} user is unauthorized to manage the raffle {raffle.id}", @@ -849,7 +846,7 @@ async def get_cash_by_id( if user_db is None: raise HTTPException(status_code=404, detail="User not found") - if user_id == user.id or is_user_member_of_an_allowed_group( + if user_id == user.id or is_user_member_of_any_group( user, [GroupType.admin], ): @@ -981,7 +978,7 @@ async def draw_winner( if not raffle: raise HTTPException(status_code=404, detail="Raffle not found") - if not is_user_member_of_an_allowed_group(user, [raffle.group_id]): + if not is_user_member_of_any_group(user, [raffle.group_id]): raise HTTPException( status_code=403, detail=f"{user.id} user is unauthorized to manage the raffle {raffle.id}", @@ -1032,7 +1029,7 @@ async def open_raffle( detail=f"You can't mark a raffle as open if it is not in creation mode. The current mode is {raffle.status}.", ) - if not is_user_member_of_an_allowed_group(user, [raffle.group_id]): + if not is_user_member_of_any_group(user, [raffle.group_id]): raise HTTPException( status_code=403, detail=f"{user.id} user is unauthorized to manage the raffle {raffle_id}", @@ -1070,7 +1067,7 @@ async def lock_raffle( detail=f"You can't mark a raffle as locked if it is not in open mode. The current mode is {raffle.status}.", ) - if not is_user_member_of_an_allowed_group(user, [raffle.group_id]): + if not is_user_member_of_any_group(user, [raffle.group_id]): raise HTTPException( status_code=403, detail=f"{user.id} user is unauthorized to manage the raffle {raffle_id}", diff --git a/app/modules/raid/endpoints_raid.py b/app/modules/raid/endpoints_raid.py index 776b6bcb9..8dc104412 100644 --- a/app/modules/raid/endpoints_raid.py +++ b/app/modules/raid/endpoints_raid.py @@ -36,7 +36,7 @@ get_core_data, get_file_from_data, get_random_string, - is_user_member_of_an_allowed_group, + is_user_member_of_any_group, save_file_as_data, set_core_data, ) @@ -65,7 +65,7 @@ async def get_participant_by_id( """ Get a participant by id """ - if participant_id != user.id and not is_user_member_of_an_allowed_group( + if participant_id != user.id and not is_user_member_of_any_group( user, [GroupType.raid_admin], ): @@ -520,7 +520,7 @@ async def read_document( user.id, participant.id, db, - ) and not is_user_member_of_an_allowed_group(user, [GroupType.raid_admin]): + ) and not is_user_member_of_any_group(user, [GroupType.raid_admin]): raise HTTPException( status_code=403, detail="The owner of this document is not a member of your team.", @@ -1087,13 +1087,15 @@ async def get_payment_url( if not participant.payment: checkout_name += " + " checkout_name += "T Shirt taille" + participant.t_shirt_size.value + user_dict = user.__dict__ + user_dict.pop("school", None) checkout = await payment_tool.init_checkout( module=module.root, helloasso_slug="AEECL", checkout_amount=price, checkout_name=checkout_name, redirection_uri=settings.RAID_PAYMENT_REDIRECTION_URL or "", - payer_user=schemas_core.CoreUser(**user.__dict__), + payer_user=schemas_core.CoreUser(**user_dict), db=db, ) hyperion_error_logger.info(f"RAID: Logging Checkout id {checkout.id}") diff --git a/app/utils/auth/providers.py b/app/utils/auth/providers.py index ffbd896b1..c7f544107 100644 --- a/app/utils/auth/providers.py +++ b/app/utils/auth/providers.py @@ -7,12 +7,12 @@ from app.core.groups.groups_type import ( AccountType, GroupType, - get_account_types_except_external, + get_account_types_except_externals, get_ecl_account_types, ) from app.types.floors_type import FloorsType from app.types.scopes_type import ScopeType -from app.utils.tools import get_display_name, is_user_member_of_an_allowed_group +from app.utils.tools import get_display_name, is_user_member_of_any_group class BaseAuthClient: @@ -39,7 +39,7 @@ class BaseAuthClient: # Restrict the authentication to this client to specific Hyperion account types. # When set to `None`, users from any account type can use the auth client allowed_account_types: list[AccountType] | None = ( - get_account_types_except_external() + get_account_types_except_externals() ) # redirect_uri should alway match the one provided by the client redirect_uri: list[str] @@ -152,7 +152,7 @@ def get_userinfo(cls, user: models_core.CoreUser): "groups": [group.name for group in user.groups] + [user.account_type.value], "email": user.email, "picture": f"https://hyperion.myecl.fr/users/{user.id}/profile-picture", - "is_admin": is_user_member_of_an_allowed_group(user, [GroupType.admin]), + "is_admin": is_user_member_of_any_group(user, [GroupType.admin]), } @@ -398,5 +398,5 @@ def get_userinfo(cls, user: models_core.CoreUser): "firstname": user.firstname, "lastname": user.name, "email": user.email, - "is_admin": is_user_member_of_an_allowed_group(user, [GroupType.admin]), + "is_admin": is_user_member_of_any_group(user, [GroupType.admin]), } diff --git a/app/utils/initialization.py b/app/utils/initialization.py index 2aaca3a4d..ebccd9459 100644 --- a/app/utils/initialization.py +++ b/app/utils/initialization.py @@ -4,7 +4,7 @@ from sqlalchemy import Connection, MetaData, delete, select from sqlalchemy.engine import Engine, create_engine from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session, selectinload +from sqlalchemy.orm import Session from app.core import models_core from app.core.config import Settings @@ -87,10 +87,8 @@ def get_group_by_id_sync(group_id: str, db: Session) -> models_core.CoreGroup | Return group with id from database """ result = db.execute( - select(models_core.CoreGroup) - .where(models_core.CoreGroup.id == group_id) - .options( - selectinload(models_core.CoreGroup.members), + select(models_core.CoreGroup).where( + models_core.CoreGroup.id == group_id, ), # needed to load the members from the relationship ) return result.scalars().first() @@ -140,6 +138,33 @@ def set_core_data_crud_sync( return core_data +def get_school_by_id_sync(school_id: str, db: Session) -> models_core.CoreSchool | None: + """ + Return group with id from database + """ + result = db.execute( + select(models_core.CoreSchool).where(models_core.CoreSchool.id == school_id), + ) + return result.scalars().first() + + +def create_school_sync( + school: models_core.CoreSchool, + db: Session, +) -> models_core.CoreSchool: + """ + Create a new group in database and return it + """ + db.add(school) + try: + db.commit() + except IntegrityError: + db.rollback() + raise + else: + return school + + def delete_core_data_crud_sync(schema: str, db: Session) -> None: """ Delete core data with schema from database diff --git a/app/utils/tools.py b/app/utils/tools.py index 7b0077ad4..913e307e6 100644 --- a/app/utils/tools.py +++ b/app/utils/tools.py @@ -91,7 +91,7 @@ def unaccent(s: str) -> str: return [user for user, _ in reversed(scored)] -def is_user_member_of_an_allowed_group( +def is_user_member_of_any_group( user: models_core.CoreUser, allowed_groups: list[str] | list[GroupType], ) -> bool: diff --git a/migrations/versions/27-schools.py b/migrations/versions/27-schools.py new file mode 100644 index 000000000..c2fbcaaf4 --- /dev/null +++ b/migrations/versions/27-schools.py @@ -0,0 +1,263 @@ +"""schools + +Create Date: 2024-10-26 19:04:51.089828 +""" + +import enum +from collections.abc import Sequence +from typing import TYPE_CHECKING +from uuid import UUID + +if TYPE_CHECKING: + from pytest_alembic import MigrationContext + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "a1e6e8b52103" +down_revision: str | None = "53c163acf327" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +centrale_regex = r"^[\w\-.]*@(etu(-enise)?\.)?ec-lyon\.fr$" + + +class SchoolType(enum.Enum): + no_school = UUID("dce19aa2-8863-4c93-861e-fb7be8f610ed") + centrale_lyon = UUID("d9772da7-1142-4002-8b86-b694b431dfed") + + +class AccountType(enum.Enum): + student = "student" + former_student = "former_student" + staff = "staff" + association = "association" + external = "external" + demo = "demo" + + +class AccountType2(enum.Enum): + student = "student" + former_student = "former_student" + staff = "staff" + association = "association" + external = "external" + other_school_student = "other_school_student" + demo = "demo" + + +DEMO_ID = "9bccbd61-2af3-4bd6-adb7-2a5e48756f66" +ECLAIR_ID = "e68d744f-472f-49e5-896f-662d83be7b9a" + +ECL_STAFF_REGEX = r"^[\w\-.]*@(enise\.)?ec-lyon\.fr$" +ECL_STUDENT_REGEX = r"^[\w\-.]*@((etu(-enise)?)|(ecl\d{2}))\.ec-lyon\.fr$" +ECL_FORMER_STUDENT_REGEX = r"^[\w\-.]*@centraliens-lyon\.net$" + +school_table = sa.Table( + "core_school", + sa.MetaData(), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("email_regex", sa.String(), nullable=False), +) +user_table = sa.Table( + "core_user", + sa.MetaData(), + sa.Column("id", sa.String(), nullable=False), + sa.Column("email", sa.String(), nullable=False), + sa.Column("account_type", sa.Enum(AccountType, name="accounttype"), nullable=False), + sa.Column("school_id", sa.Uuid(), nullable=False), +) + +visibility_table = sa.Table( + "module_account_type_visibility", + sa.MetaData(), + sa.Column("root", sa.String(), nullable=False), + sa.Column( + "allowed_account_type", + sa.Enum(AccountType, name="accounttype"), + nullable=False, + ), +) + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "core_school", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("name", sa.String(), nullable=False, unique=True), + sa.Column("email_regex", sa.String(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + + conn = op.get_bind() + + for school in SchoolType: + conn.execute( + sa.insert(school_table).values( + id=school.value, + name=school.name, + email_regex=".*", + ), + ) + + users = conn.execute( + sa.select(user_table.c.id, user_table.c.account_type), + ).fetchall() + + visibilities = conn.execute( + sa.select(visibility_table.c.root, visibility_table.c.allowed_account_type), + ).fetchall() + + with op.batch_alter_table("core_user") as batch_op: + batch_op.add_column( + sa.Column( + "school_id", + sa.Uuid(), + nullable=False, + server_default=str(SchoolType.no_school.value), + ), + ) + batch_op.create_foreign_key( + "core_user_school_id", + "core_school", + ["school_id"], + ["id"], + ) + batch_op.drop_column("account_type") + + op.drop_table("module_account_type_visibility") + + sa.Enum(AccountType, name="accounttype").drop( + conn, + ) + + op.create_table( + "module_account_type_visibility", + sa.Column("root", sa.String(), nullable=False), + sa.Column( + "allowed_account_type", + sa.Enum( + AccountType2, + name="accounttype", + ), + nullable=False, + ), + sa.PrimaryKeyConstraint("root", "allowed_account_type"), + ) + + with op.batch_alter_table("core_user") as batch_op: + batch_op.add_column( + sa.Column( + "account_type", + sa.Enum(AccountType2, name="accounttype"), + nullable=False, + server_default="external", + ), + ) + + for user in users: + conn.execute( + sa.update(user_table) + .where(user_table.c.id == user.id) + .values( + account_type=user.account_type, + school_id=SchoolType.centrale_lyon.value + if user.account_type != AccountType.external + else SchoolType.no_school.value, + ), + ) + + for visibility in visibilities: + conn.execute( + sa.insert(visibility_table).values( + root=visibility.root, + allowed_account_type=visibility.allowed_account_type, + ), + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + + conn = op.get_bind() + + users = conn.execute( + sa.select(user_table.c.id, user_table.c.account_type), + ).fetchall() + + visibilities = conn.execute( + sa.select(visibility_table.c.root, visibility_table.c.allowed_account_type), + ).fetchall() + + with op.batch_alter_table("core_user") as batch_op: + batch_op.drop_constraint("core_user_school_id", type_="foreignkey") + batch_op.drop_column("school_id") + batch_op.drop_column("account_type") + + op.drop_table("module_account_type_visibility") + + sa.Enum(AccountType2, name="accounttype").drop( + op.get_bind(), + ) + + op.create_table( + "module_account_type_visibility", + sa.Column("root", sa.String(), nullable=False), + sa.Column( + "allowed_account_type", + sa.Enum(AccountType, name="accounttype"), + nullable=False, + ), + sa.PrimaryKeyConstraint("root", "allowed_account_type"), + ) + + with op.batch_alter_table("core_user") as batch_op: + batch_op.add_column( + sa.Column( + "account_type", + sa.Enum(AccountType, name="accounttype"), + nullable=False, + server_default="external", + ), + ) + + op.drop_table("core_school") + + for user in users: + conn.execute( + sa.update(user_table) + .where(user_table.c.id == user.id) + .values( + account_type=user.account_type + if user.account_type != AccountType2.other_school_student + else AccountType.external, + ), + ) + + for visibility in visibilities: + conn.execute( + sa.insert(visibility_table).values( + root=visibility.root, + allowed_account_type=visibility.allowed_account_type, + ), + ) + # ### end Alembic commands ### + + +def pre_test_upgrade( + alembic_runner: "MigrationContext", + alembic_connection: sa.Connection, +) -> None: + pass + + +def test_upgrade( + alembic_runner: "MigrationContext", + alembic_connection: sa.Connection, +) -> None: + pass diff --git a/tests/commons.py b/tests/commons.py index fb5dee6db..cde930ffa 100644 --- a/tests/commons.py +++ b/tests/commons.py @@ -17,6 +17,7 @@ from app.core.groups.groups_type import AccountType, GroupType from app.core.payment import cruds_payment, models_payment, schemas_payment from app.core.payment.payment_tool import PaymentTool +from app.core.schools.schools_type import SchoolType from app.core.users import cruds_users from app.dependencies import get_settings from app.types.exceptions import RedisConnectionError @@ -119,6 +120,7 @@ def override_get_scheduler( async def create_user_with_groups( groups: list[GroupType], account_type: AccountType = AccountType.student, + school_id: SchoolType | uuid.UUID = SchoolType.centrale_lyon, user_id: str | None = None, email: str | None = None, password: str | None = None, @@ -135,10 +137,12 @@ async def create_user_with_groups( user_id = user_id or str(uuid.uuid4()) password_hash = security.get_password_hash(password or get_random_string()) + school_id = school_id.value if isinstance(school_id, SchoolType) else school_id user = models_core.CoreUser( id=user_id, email=email or (get_random_string() + "@etu.ec-lyon.fr"), + school_id=school_id, password_hash=password_hash, name=name or get_random_string(), firstname=firstname or get_random_string(), diff --git a/tests/test_payment.py b/tests/test_payment.py index 8f5a09055..37cbb61fe 100644 --- a/tests/test_payment.py +++ b/tests/test_payment.py @@ -76,7 +76,15 @@ async def init_objects() -> None: user = await create_user_with_groups( groups=[], ) - user_schema = schemas_core.CoreUser(**user.__dict__) + school = schemas_core.CoreSchool( + id=user.school.id, + name=user.school.name, + email_regex=user.school.email_regex, + ) + user_dict = user.__dict__ + user_dict.pop("school") + + user_schema = schemas_core.CoreUser(**user_dict, school=school) # Test endpoints # diff --git a/tests/test_schools.py b/tests/test_schools.py new file mode 100644 index 000000000..ab6f3c29e --- /dev/null +++ b/tests/test_schools.py @@ -0,0 +1,251 @@ +from uuid import UUID + +import pytest_asyncio +from fastapi.testclient import TestClient +from pytest_mock import MockerFixture + +from app.core import models_core +from app.core.groups.groups_type import AccountType, GroupType +from app.core.schools.schools_type import SchoolType +from tests.commons import ( + add_object_to_db, + create_api_access_token, + create_user_with_groups, +) + +admin_user: models_core.CoreUser +ens_user: models_core.CoreUser +fake_ens_user: models_core.CoreUser +new_school_user: models_core.CoreUser + +UNIQUE_TOKEN = "my_unique_token" + +id_test_ens = UUID("4d133de7-24c4-4dbc-be73-4705a2ddd315") + + +@pytest_asyncio.fixture(scope="module", autouse=True) +async def init_objects() -> None: + global admin_user, ens_user, fake_ens_user, new_school_user + + ens = models_core.CoreSchool( + id=id_test_ens, + name="ENS", + email_regex=r"^.*@.*ens.fr$", + ) + await add_object_to_db(ens) + + admin_user = await create_user_with_groups([GroupType.admin]) + + ens_user = await create_user_with_groups( + [], + school_id=id_test_ens, + email="test@ens.fr", + account_type=AccountType.other_school_student, + ) + fake_ens_user = await create_user_with_groups( + [], + school_id=id_test_ens, + email="test@fakeens.fr", + account_type=AccountType.other_school_student, + ) + + new_school_user = await create_user_with_groups( + [], + school_id=SchoolType.no_school, + email="test@school.fr", + account_type=AccountType.external, + ) + + +def test_read_schools(client: TestClient) -> None: + token = create_api_access_token(admin_user) + + response = client.get( + "/schools/", + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 200 + + +def test_read_school(client: TestClient) -> None: + token = create_api_access_token(admin_user) + + response = client.get( + f"/schools/{id_test_ens}", + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["name"] == "ENS" + + +def test_create_school_with_used_name(client: TestClient) -> None: + token = create_api_access_token(admin_user) + + response = client.post( + "/schools/", + json={ + "name": "ENS", + "email_regex": r"^.*@ens.fr$", + }, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 400 + + +def test_create_school(client: TestClient) -> None: + token = create_api_access_token(admin_user) + + school = client.post( + "/schools/", + json={ + "name": "school", + "email_regex": r"^.*@school\.fr$", + }, + headers={"Authorization": f"Bearer {token}"}, + ) + assert school.status_code == 201 + + response = client.get( + f"/users/{new_school_user.id}", + headers={"Authorization": f"Bearer {token}"}, + ) + data = response.json() + assert data["school_id"] == school.json()["id"] + + +def test_update_school_with_used_name(client: TestClient) -> None: + token = create_api_access_token(admin_user) + + response = client.patch( + f"/schools/{id_test_ens}", + json={"name": "centrale_lyon"}, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 400 + + +def test_update_school(client: TestClient) -> None: + token = create_api_access_token(admin_user) + + response = client.patch( + f"/schools/{id_test_ens}", + json={"name": "school ENS", "email_regex": r"^.*@ens.fr$"}, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 204 + + response = client.get( + f"/schools/{id_test_ens}", + headers={"Authorization": f"Bearer {token}"}, + ) + data = response.json() + assert data["name"] == "school ENS" + + response = client.get( + f"/users/{ens_user.id}", + headers={"Authorization": f"Bearer {token}"}, + ) + data = response.json() + assert data["school_id"] == str(id_test_ens) + + response = client.get( + f"/users/{fake_ens_user.id}", + headers={"Authorization": f"Bearer {token}"}, + ) + data = response.json() + assert data["school_id"] == str(SchoolType.no_school.value) + + +def test_create_user_corresponding_to_school( + mocker: MockerFixture, + client: TestClient, +) -> None: + token = create_api_access_token(admin_user) + + response = client.post( + "/schools/", + json={ + "name": "ENS Lyon", + "email_regex": r"^[\w\-.]*@ens-lyon\.fr$", + }, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 201 + school_id = response.json()["id"] + + mocker.patch( + "app.core.users.endpoints_users.security.generate_token", + return_value=UNIQUE_TOKEN, + ) + + response = client.post( + "/users/create", + json={ + "email": "new_user@ens-lyon.fr", + }, + ) + assert response.status_code == 201 + + response = client.post( + "/users/activate", + json={ + "activation_token": UNIQUE_TOKEN, + "password": "password", + "firstname": "new_user_firstname", + "name": "new_user_name", + }, + ) + + assert response.status_code == 201 + + users = client.get( + "/users", + headers={"Authorization": f"Bearer {token}"}, + ) + user = next( + user + for user in users.json() + if user["firstname"] == "new_user_firstname" and user["name"] == "new_user_name" + ) + + user_detail = client.get( + f"/users/{user['id']}", + headers={"Authorization": f"Bearer {token}"}, + ) + + assert user_detail.json()["school_id"] == school_id + + +def test_delete_base_school(client: TestClient) -> None: + token = create_api_access_token(admin_user) + + response = client.delete( + f"/schools/{SchoolType.centrale_lyon.value}", + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 400 + + +def test_delete_school(client: TestClient) -> None: + token = create_api_access_token(admin_user) + + response = client.delete( + f"/schools/{id_test_ens}", + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 204 + + response = client.get( + f"/schools/{id_test_ens}", + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 404 + assert response.json() == {"detail": "School not found"} + + response = client.get( + f"/users/{ens_user.id}", + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 200 + assert response.json()["school_id"] == str(SchoolType.no_school.value) diff --git a/tests/test_user_fusion.py b/tests/test_user_fusion.py index a34df7da6..5b4ae4965 100644 --- a/tests/test_user_fusion.py +++ b/tests/test_user_fusion.py @@ -57,7 +57,7 @@ async def init_objects() -> None: user_id=student_user_to_keep.id, membership=AvailableAssociationMembership.aeecl, start_date=datetime.now(tz=UTC).date() - timedelta(days=565), - end_date=datetime.now(tz=UTC).date() + timedelta(days=165), + end_date=datetime.now(tz=UTC).date() + timedelta(days=465), ) await add_object_to_db(core_association_membership_user_kept) diff --git a/tests/test_users.py b/tests/test_users.py index a7076df5f..8dc4a7ed0 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -2,11 +2,14 @@ import pytest import pytest_asyncio +from fastapi import HTTPException from fastapi.testclient import TestClient from pytest_mock import MockerFixture from app.core import models_core from app.core.groups.groups_type import AccountType, GroupType +from app.core.schools.schools_type import SchoolType +from app.dependencies import is_user from tests.commons import ( create_api_access_token, create_user_with_groups, @@ -107,6 +110,16 @@ def test_get_account_types(client: TestClient) -> None: assert data == list(AccountType) +def test_restrict_access_on_group(client: TestClient) -> None: + with pytest.raises( + HTTPException, + match="Unauthorized, user is a member of any of the groups ", + ): + is_user( + excluded_groups=[GroupType.amap], + )(user_with_group) + + def test_read_current_user(client: TestClient) -> None: token = create_api_access_token(student_user) response = client.get( @@ -158,19 +171,35 @@ def test_create_user_by_user_with_email( @pytest.mark.parametrize( - ("email", "expected_code", "expected_account_type"), + ("email", "expected_code", "expected_account_type", "expected_school_id"), [ - ("fab1@etu.ec-lyon.fr", 201, AccountType.student), - ("fab2@ec-lyon.fr", 201, AccountType.staff), - ("fab3@centraliens-lyon.net", 201, AccountType.former_student), - ("fab4@test.fr", 201, AccountType.external), - ("fab5@ecl22.ec-lyon.fr", 201, AccountType.student), + ( + "fab1@etu.ec-lyon.fr", + 201, + AccountType.student, + SchoolType.centrale_lyon, + ), + ("fab2@ec-lyon.fr", 201, AccountType.staff, SchoolType.centrale_lyon), + ( + "fab3@centraliens-lyon.net", + 201, + AccountType.former_student, + SchoolType.centrale_lyon, + ), + ("fab4@test.fr", 201, AccountType.external, SchoolType.no_school), + ( + "fab5@ecl22.ec-lyon.fr", + 201, + AccountType.student, + SchoolType.centrale_lyon, + ), ], ) def test_create_and_activate_user( email: str, expected_code: int, expected_account_type: AccountType, + expected_school_id: SchoolType, mocker: MockerFixture, client: TestClient, ) -> None: @@ -211,6 +240,13 @@ def test_create_and_activate_user( assert user is not None assert user["account_type"] == expected_account_type.value + user_detail = client.get( + f"/users/{user['id']}", + headers={"Authorization": f"Bearer {token_admin_user}"}, + ) + + assert user_detail.json()["school_id"] == str(expected_school_id.value) + @pytest.mark.parametrize( ("email", "expected_error"), @@ -326,6 +362,18 @@ def test_update_user(client: TestClient) -> None: assert response.status_code == 204 +async def test_migrate_mail_with_school_change(client: TestClient) -> None: + token = create_api_access_token(student_user_with_old_email) + + # Start the migration process + response = client.post( + "/users/migrate-mail", + json={"new_email": "fabristpp.eclair@gmail.com"}, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 400 + + async def test_migrate_mail(mocker: MockerFixture, client: TestClient) -> None: # NOTE: we don't want to mock app.core.security.generate_token but # app.core.users.endpoints_users.security.generate_token which is the imported version of the function