diff --git a/controllers/v2/shift/api.py b/controllers/v2/shift/api.py index fe00f716..578a382a 100644 --- a/controllers/v2/shift/api.py +++ b/controllers/v2/shift/api.py @@ -1,8 +1,10 @@ from flask_restful import reqparse, Resource, marshal_with, inputs, marshal + +from exception.client_exception import ConflictError from .response_models import shift -from domain import UserType +from domain import UserType, ShiftVolunteerStatus from repository.shift_repository import ShiftRepository -from services.jwk import requires_auth, is_user_or_has_role +from services.jwk import requires_auth, is_user_or_has_role, requires_admin, has_role from controllers.v2.v2_blueprint import v2_api import logging @@ -11,7 +13,8 @@ parser.add_argument('title', type=str) parser.add_argument('start', type=inputs.datetime_from_iso8601, required=True, help="Start time cannot be blank!") parser.add_argument('end', type=inputs.datetime_from_iso8601, required=True, help="End time cannot be blank!") -parser.add_argument('roles', type=list, location='json', required=True, help="Roles cannot be blank!") +parser.add_argument('vehicle_type', type=int, required=True, help="Vehicle type cannot be blank!") + parser_modify_status = reqparse.RequestParser() parser_modify_status.add_argument('status', type=str, location='json', required=True, help="Status cannot be blank!") @@ -22,29 +25,50 @@ class VolunteerShiftV2(Resource): def __init__(self, shift_repository: ShiftRepository = ShiftRepository()): self.shift_repository = shift_repository + @requires_auth + @has_role(UserType.ROOT_ADMIN) + def post(self, user_id): + try: + args = parser.parse_args() + title = args['title'] + start = args['start'] + end = args['end'] + vehicle_type = args['vehicle_type'] + new_shift_id = self.shift_repository.post_shift_request(user_id, title, start, end, vehicle_type) + if new_shift_id: + return {"shift_id": new_shift_id}, 200 + else: + return {"message": "Failed to create shift."}, 400 + except Exception as e: + logging.error(f"Error creating new shift request: {e}") + return {"message": "Internal server error"}, 500 + + @requires_auth @is_user_or_has_role(None, UserType.ROOT_ADMIN) def get(self, user_id): try: shifts = self.shift_repository.get_shift(user_id) - if shifts: - return marshal(shifts, shift), 200 - else: - return {"message": "No shift record found."}, 400 + return marshal(shifts, shift), 200 except Exception as e: logging.error(f"Error retrieving shifts for user {user_id}: {e}") return {"message": "Internal server error"}, 500 - + @requires_auth + @is_user_or_has_role(None, UserType.ROOT_ADMIN) def put(self, user_id, shift_id): args = parser_modify_status.parse_args() status = args["status"] + status_enum = ShiftVolunteerStatus[status.upper()] try: - success = self.shift_repository.update_shift_status(user_id, shift_id, status) + success = self.shift_repository.update_shift_status(user_id, shift_id, status_enum) if success: return {"message": "Status updated successfully"}, 200 else: return {"message": "No user or shift record is found, status not updated."}, 400 + except ConflictError as e: # Handle conflict error + logging.error(f"Conflict when updating shift for user {user_id}: {e}") + return {"message": "Shift time conflict detected. Cannot confirm shift."}, 409 except Exception as e: logging.error(f"Error updating shifts for user {user_id}: {e}") return {"message": "Internal server error"}, 500 diff --git a/controllers/v2/shift/response_models.py b/controllers/v2/shift/response_models.py index 47d202aa..87fefecb 100644 --- a/controllers/v2/shift/response_models.py +++ b/controllers/v2/shift/response_models.py @@ -2,8 +2,8 @@ shift = { 'shiftId': fields.Integer, - 'status': fields.String, 'title': fields.String, 'start': fields.DateTime(dt_format='iso8601'), - 'end': fields.DateTime(dt_format='iso8601') + 'end': fields.DateTime(dt_format='iso8601'), + 'status': fields.String } diff --git a/controllers/v2/unavailability/api.py b/controllers/v2/unavailability/api.py index 0145cfa5..e07dd027 100644 --- a/controllers/v2/unavailability/api.py +++ b/controllers/v2/unavailability/api.py @@ -67,12 +67,11 @@ def __init__(self, event_repository: EventRepository = EventRepository()): @requires_auth @is_user_or_has_role(None, UserType.ROOT_ADMIN) def get(self, user_id): - volunteer_unavailability_record = self.event_repository.get_event(user_id) - if volunteer_unavailability_record is not None and volunteer_unavailability_record != []: - return volunteer_unavailability_record - elif volunteer_unavailability_record == []: - return {"message": "No unavailability record found."}, 400 - else: + try: + volunteer_unavailability_record = self.event_repository.get_event(user_id) + return volunteer_unavailability_record, 200 + except Exception as e: + logging.error(f"Error retrieving shifts for user {user_id}: {e}") return {"message": "Internal server error"}, 500 @requires_auth diff --git a/domain/entity/__init__.py b/domain/entity/__init__.py index 36b5620e..dd68c272 100644 --- a/domain/entity/__init__.py +++ b/domain/entity/__init__.py @@ -15,5 +15,4 @@ from .shift_request import ShiftRequest from .shift_request_volunteer import ShiftRequestVolunteer from .shift_position import ShiftPosition -from .fcm_tokens import FCMToken - +from .fcm_tokens import FCMToken \ No newline at end of file diff --git a/domain/entity/shift_position.py b/domain/entity/shift_position.py index ee41a167..c4fff299 100644 --- a/domain/entity/shift_position.py +++ b/domain/entity/shift_position.py @@ -1,21 +1,16 @@ -from datetime import datetime - -from sqlalchemy import Column, String, DateTime, ForeignKey, Integer, Enum +from sqlalchemy import Column, String, ForeignKey, Integer from sqlalchemy.orm import relationship - from domain.base import Base class ShiftPosition(Base): __tablename__ = 'shift_position' - id = Column(Integer, primary_key=True, autoincrement=True) - shift_id = Column(Integer, ForeignKey('shift_request.id'), nullable=False) - role_code = Column(String(256), ForeignKey('role.code'), nullable=False) + shift_id = Column(Integer, ForeignKey('shift_request.id'), name='shift_id', nullable=False) + role_code = Column(String(256), ForeignKey('role.code'), name='role_code', nullable=False) # Many-to-one relationship with Role role = relationship("Role") - # One-to-one relationship with ShiftRequestVolunteer using backref volunteer = relationship("ShiftRequestVolunteer", uselist=False, backref="shift_position", primaryjoin="ShiftPosition.id == ShiftRequestVolunteer.position_id") diff --git a/domain/entity/shift_request.py b/domain/entity/shift_request.py index cbfd37e1..79d64233 100644 --- a/domain/entity/shift_request.py +++ b/domain/entity/shift_request.py @@ -15,13 +15,13 @@ class ShiftRequest(Base): title = Column(String(29), name='title', nullable=False) startTime = Column(DateTime, name='from', nullable=False) endTime = Column(DateTime, name='to', nullable=False) - status = Column(Enum(ShiftStatus), name='status', default=ShiftStatus.WAITING, nullable=False) + status = Column(Enum(ShiftStatus), name='status', default=ShiftStatus.SUBMITTED, nullable=False) update_date_time = Column(DateTime, name='last_update_datetime', default=datetime.now(), nullable=False) insert_date_time = Column(DateTime, name='created_datetime', default=datetime.now(), nullable=False) - Column() - user = relationship("User") + + user = relationship("User") # One-to-many relationship: A shift can have multiple positions positions = relationship("ShiftPosition", backref="shift_request") - - + # One-to-many relationship: A shift can be assigned to many volunteers + volunteers = relationship("ShiftRequestVolunteer", backref="shift_request") diff --git a/domain/entity/shift_request_volunteer.py b/domain/entity/shift_request_volunteer.py index 93f2de9b..1ab6337a 100644 --- a/domain/entity/shift_request_volunteer.py +++ b/domain/entity/shift_request_volunteer.py @@ -18,5 +18,4 @@ class ShiftRequestVolunteer(Base): update_date_time = Column(DateTime, name='last_update_datetime', default=datetime.now(), nullable=False) insert_date_time = Column(DateTime, name='created_datetime', default=datetime.now(), nullable=False) - shift_request = relationship("ShiftRequest") user = relationship("User") \ No newline at end of file diff --git a/domain/entity/unavailability_time.py b/domain/entity/unavailability_time.py index 67759b12..83be3477 100644 --- a/domain/entity/unavailability_time.py +++ b/domain/entity/unavailability_time.py @@ -12,4 +12,5 @@ class UnavailabilityTime(Base): start = Column(DateTime, nullable=False, default=datetime.now()) end = Column(DateTime, nullable=False, default=datetime.now()) status = Column(Boolean, nullable=False, default=1) + is_shift = Column(Boolean, nullable=False, default=False) UniqueConstraint(eventId, userId, name='event') diff --git a/domain/type/shift_record.py b/domain/type/shift_record.py index ecbaf730..aa2507ed 100644 --- a/domain/type/shift_record.py +++ b/domain/type/shift_record.py @@ -7,7 +7,7 @@ @dataclass class ShiftRecord: shiftId: int - status: ShiftVolunteerStatus title: str start: datetime - end: datetime \ No newline at end of file + end: datetime + status: ShiftVolunteerStatus \ No newline at end of file diff --git a/domain/type/shift_status.py b/domain/type/shift_status.py index 314deff6..8ff34e01 100644 --- a/domain/type/shift_status.py +++ b/domain/type/shift_status.py @@ -2,7 +2,7 @@ class ShiftStatus(Enum): - WAITING = "waiting" - UNSUBMITTED = "un-submitted" - INPROGRESS = "in-progress" + PENDING = "pending" + SUBMITTED = "submitted" + CONFIRMED = "confirmed" COMPLETED = "completed" \ No newline at end of file diff --git a/domain/type/shift_volunteer_status.py b/domain/type/shift_volunteer_status.py index 25c73532..54f0cd8d 100644 --- a/domain/type/shift_volunteer_status.py +++ b/domain/type/shift_volunteer_status.py @@ -3,5 +3,5 @@ class ShiftVolunteerStatus(Enum): PENDING = "pending" - CONFIRMED = "confirmed" + ACCEPTED = "accepted" REJECTED = "rejected" \ No newline at end of file diff --git a/exception/client_exception.py b/exception/client_exception.py index e252c5af..ee9a24fc 100644 --- a/exception/client_exception.py +++ b/exception/client_exception.py @@ -18,3 +18,12 @@ def __init__(self, *args): def __str__(self): # Optionally customize the string representation for this specific error return f"InvalidArgumentError: unexpected values in the payload" + +class ConflictError(FireAppException): + def __init__(self, *args): + # Call the superclass constructor with a default message and any additional arguments + super().__init__(f"Shift time conflict detected", *args) + + def __str__(self): + # Optionally customize the string representation for this specific error + return f"ConflictError: This shift is conflict with other confirmed shift." \ No newline at end of file diff --git a/repository/shift_repository.py b/repository/shift_repository.py index ab28d227..2eec1c10 100644 --- a/repository/shift_repository.py +++ b/repository/shift_repository.py @@ -1,13 +1,10 @@ import logging from typing import List -from dataclasses import asdict -from flask import jsonify +from datetime import datetime -from datetime import datetime, timezone - -from exception import EventNotFoundError, InvalidArgumentError -from domain import session_scope, ShiftRequestVolunteer, ShiftRequest, ShiftRecord +from exception.client_exception import ConflictError +from domain import session_scope, ShiftRequestVolunteer, ShiftRequest, ShiftPosition, Role, ShiftRecord, ShiftStatus, UnavailabilityTime, ShiftVolunteerStatus class ShiftRepository: @@ -15,6 +12,103 @@ class ShiftRepository: def __init__(self): pass + def post_shift_request(self, user_id, title, start_time, end_time, vehicle_type): + """ + Creates a new shift request and associated shift positions based on the vehicle type. + + Parameters: + ---------- + user_id : int + The ID of the user creating the shift request. + title : str + The title of the shift request. + start_time : datetime + The start time of the shift. + end_time : datetime + The end time of the shift. + vehicle_type : int + The type of vehicle associated with the shift (determines roles). + + Returns: + ------- + int or None + The ID of the newly created shift request if successful, otherwise None. + """ + now = datetime.now() # Get the current timestamp + with session_scope() as session: + try: + # Create a new ShiftRequest object and populate its fields + shift_request = ShiftRequest( + user_id=user_id, + title=title, + startTime=start_time, + endTime=end_time, + status=ShiftStatus.PENDING, # need to be changed to submitted after linda pr approved + update_date_time=now, # Update timestamp + insert_date_time=now # Insert timestamp + ) + + # Add the ShiftRequest object to the session and commit + session.add(shift_request) + session.commit() + # Add roles to the position table + positions_created = self.create_positions(session, shift_request.id, vehicle_type) + if not positions_created: + session.rollback() + return None + session.commit() + return shift_request.id # Optionally return the created ShiftRequest object id + + except Exception as e: + logging.error(f"Error creating new shift request: {e}") + session.rollback() + return None + + + + + def create_positions(self, session, shiftId, vehicleType): + """ + Creates shift positions based on the vehicle type for a given shift request. + + Parameters: + ---------- + session : Session + The active database session. + shiftId : int + The ID of the shift request to create positions for. + vehicleType : int + The type of vehicle, which determines the roles to be assigned. + + Returns: + ------- + bool + True if positions are successfully created, otherwise False. + """ + try: + if vehicleType == 1: # Heavy Tanker + roleCodes = ['crewLeader', 'driver', 'advanced', 'advanced', 'basic', 'basic'] + elif vehicleType == 2: # Medium Tanker + roleCodes = ['crewLeader', 'driver', 'advanced', 'basic'] + elif vehicleType == 3: # Light Unit + roleCodes = ['driver', 'basic'] + else: + logging.error(f"Invalid vehicle type: {vehicleType}") + return False + + # create record in ShiftPosition + for roleCode in roleCodes: + shift_position = ShiftPosition( + shift_id=shiftId, + role_code=roleCode + ) + session.add(shift_position) + return True + except Exception as e: + logging.error(f"Error creating positions: {e}") + return False + + def get_shift(self, userId) -> List[ShiftRecord]: """ Retrieves all shift events for a given user that have not ended yet. @@ -27,8 +121,8 @@ def get_shift(self, userId) -> List[ShiftRecord]: try: # only show the shift that is end in the future shifts = session.query(ShiftRequestVolunteer).join(ShiftRequest).filter( - ShiftRequestVolunteer.user_id == userId, - ShiftRequest.endTime > now).all() + ShiftRequestVolunteer.user_id == userId, + ShiftRequest.endTime > now).all() # check if there's some results if not shifts: logging.info(f"No active shifts found for user {userId}") @@ -38,17 +132,17 @@ def get_shift(self, userId) -> List[ShiftRecord]: for shift in shifts: shift_record = ShiftRecord( shiftId=shift.request_id, - status=shift.status.value, title=shift.shift_request.title, start=shift.shift_request.startTime, - end=shift.shift_request.endTime) + end=shift.shift_request.endTime, + status=shift.status.value) shift_records.append(shift_record) return shift_records except Exception as e: logging.error(f"Error retrieving shifts for user {userId}: {e}") - return [] + raise - def update_shift_status(self, user_id, shift_id, new_status): + def update_shift_status(self, user_id, shift_id, new_status: ShiftVolunteerStatus): """ Updates the status of a volunteer's shift request in the database. @@ -74,18 +168,74 @@ def update_shift_status(self, user_id, shift_id, new_status): # If record exists, update the status if shift_request_volunteer: + # check for conflict + is_conflict = self.check_conflict_shifts(session, user_id, shift_id) + if is_conflict and new_status == ShiftVolunteerStatus.ACCEPTED: + # Raise the ConflictError if there's a conflict + raise ConflictError(f"Shift {shift_id} conflicts with other confirmed shifts.") + # update status shift_request_volunteer.status = new_status shift_request_volunteer.last_update_datetime = datetime.now() + # If the new status is CONFIRMED, add an unavailability time record + if new_status == ShiftVolunteerStatus.ACCEPTED: + # Fetch start and end times from the ShiftRequest table + shift_request = session.query(ShiftRequest).filter_by(id=shift_id).first() + if shift_request: + unavailability_record = UnavailabilityTime( + userId=user_id, + title=f"shift {shift_id}", + periodicity=1, + start=shift_request.startTime, + end=shift_request.endTime, + is_shift=True + ) + session.add(unavailability_record) session.commit() return True else: logging.info(f"No shift request volunteer with user id {user_id} and shift {shift_id} not found") return False + except ConflictError as e: + logging.error(f"ConflictError: {e}") + raise except Exception as e: session.rollback() logging.error(f"Error updating shift request for user {user_id} and shift_id {shift_id}: {e}") return False + def check_conflict_shifts(self, session, userId, shiftId): + """ + Check if a given user has any conflicting confirmed shifts with the current shift request. + + :param session: Database session. + :param userId: the user id of the current shift request to check for conflicts. + :param shiftId: the ID of the current shift request to check for conflicts. + :return: True if there is a conflict, False if no conflicts are found. + """ + try: + # Query all confirmed shifts for the user + confirmed_shifts = session.query(ShiftRequestVolunteer).join(ShiftRequest).filter( + ShiftRequestVolunteer.user_id == userId, + ShiftRequestVolunteer.status == ShiftVolunteerStatus.ACCEPTED + ).all() + # The current shift information with start time and end time + current_shift_information = session.query(ShiftRequestVolunteer).join(ShiftRequest).filter( + ShiftRequestVolunteer.user_id == userId, + ShiftRequestVolunteer.request_id == shiftId + ).first() + # Iterate over all confirmed shifts and check for time conflicts + for shift in confirmed_shifts: + if (shift.shift_request.startTime < current_shift_information.shift_request.endTime and + current_shift_information.shift_request.startTime < shift.shift_request.endTime): + # A conflict is found if the time ranges overlap + return True + # If no conflicts are found, return False + return False + except Exception as e: + # Log the error and return False in case of an exception + logging.error(f"Error checking shift conflicts for user {userId} and request {shiftId}: {e}") + return False + def save_shift_volunteers(self, volunteers: List[ShiftRequestVolunteer]) -> None: """ Saves a list of ShiftRequestVolunteer objects to the database. @@ -98,4 +248,4 @@ def save_shift_volunteers(self, volunteers: List[ShiftRequestVolunteer]) -> None except Exception as e: session.rollback() logging.error(f"Error saving shift volunteers: {e}") - raise + raise \ No newline at end of file diff --git a/repository/volunteer_unavailability_v2.py b/repository/volunteer_unavailability_v2.py index 50f141a3..b8bb2c4d 100644 --- a/repository/volunteer_unavailability_v2.py +++ b/repository/volunteer_unavailability_v2.py @@ -45,7 +45,7 @@ def get_event(self, userId): try: # only show the unavailability time that is end in the future events = session.query(UnavailabilityTime).filter( - UnavailabilityTime.userId == userId, UnavailabilityTime.status == 1, + UnavailabilityTime.userId == userId, UnavailabilityTime.status == 1, UnavailabilityTime.is_shift == False, UnavailabilityTime.end > now).all() if events: event_records = [] @@ -64,8 +64,8 @@ def get_event(self, userId): else: return [] except Exception as e: - logging.error(e) - return None + logging.error(f"Database error occurred: {e}") + raise # copy from repository.unavailability_repository.py def create_event(self, userId, title, startTime, endTime, periodicity): diff --git a/tests/functional/test_unavailability_get.py b/tests/functional/test_unavailability_get.py index ffbdbd1a..27819592 100644 --- a/tests/functional/test_unavailability_get.py +++ b/tests/functional/test_unavailability_get.py @@ -18,8 +18,8 @@ def test_get_volunteer_unavailability_no_records(test_client, create_user): user_id = create_user # Assuming this user has no unavailability records test_client.post(f"/v2/volunteers/{user_id}/unavailability") response = test_client.get(f"/v2/volunteers/{user_id}/unavailability") - assert response.status_code == 400 # Assuming the endpoint returns a 400 status for no records found - assert response.json == {"message": "No unavailability record found."} # Expected response body for no records + assert response.status_code == 200 # Assuming the endpoint returns a 200 status for no records found (empty list) + # assert response.json == {"message": "No unavailability record found."} # Expected response body for no records