diff --git a/src/app/api/api_v1/endpoints/detections.py b/src/app/api/api_v1/endpoints/detections.py index bee700ac..f89642b8 100644 --- a/src/app/api/api_v1/endpoints/detections.py +++ b/src/app/api/api_v1/endpoints/detections.py @@ -25,10 +25,12 @@ from app.crud import CameraCRUD, DetectionCRUD, WebhookCRUD from app.models import Camera, Detection, Role, UserRole from app.schemas.detections import ( - BOXES_PATTERN, + BOXES_PATTERN_WITH_CONFIDENCE, COMPILED_BOXES_PATTERN, + COMPILED_BOXES_PATTERN_WITH_CONFIDENCE, DetectionCreate, DetectionLabel, + DetectionManualBboxes, DetectionUrl, DetectionWithUrl, ) @@ -45,7 +47,7 @@ async def create_detection( bboxes: str = Form( ..., description="string representation of list of detection localizations, each represented as a tuple of relative coords (max 3 decimals) in order: xmin, ymin, xmax, ymax, conf", - pattern=BOXES_PATTERN, + pattern=BOXES_PATTERN_WITH_CONFIDENCE, min_length=2, max_length=settings.MAX_BBOX_STR_LENGTH, ), @@ -58,7 +60,7 @@ async def create_detection( telemetry_client.capture(f"camera|{token_payload.sub}", event="detections-create") # Throw an error if the format is invalid and can't be captured by the regex - if any(box[0] >= box[2] or box[1] >= box[3] for box in COMPILED_BOXES_PATTERN.findall(bboxes)): + if any(box[0] >= box[2] or box[1] >= box[3] for box in COMPILED_BOXES_PATTERN_WITH_CONFIDENCE.findall(bboxes)): raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="xmin & ymin are expected to be respectively smaller than xmax & ymax", @@ -188,6 +190,43 @@ async def label_detection( return await detections.update(detection_id, payload) +@router.patch("/{detection_id}/manualbboxes", status_code=status.HTTP_200_OK, summary="Update the manual_bboxes field") +async def manual_bboxes( + payload: DetectionManualBboxes, + detection_id: int = Path(..., gt=0), + cameras: CameraCRUD = Depends(get_camera_crud), + detections: DetectionCRUD = Depends(get_detection_crud), + token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT]), +) -> Detection: + telemetry_client.capture( + token_payload.sub, event="detections-manual-bboxes", properties={"detection_id": detection_id} + ) + detection = cast(Detection, await detections.get(detection_id, strict=True)) + + matched_boxes = COMPILED_BOXES_PATTERN.findall(payload.manual_bboxes) + + # Throw an error if no boxes match the expected pattern + if not matched_boxes: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="No valid bounding boxes found in the input" + ) + + # Throw an error if the format is invalid and can't be captured by the regex + if any(box[0] >= box[2] or box[1] >= box[3] for box in matched_boxes): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="xmin & ymin are expected to be respectively smaller than xmax & ymax", + ) + + if UserRole.ADMIN in token_payload.scopes: + return await detections.update(detection_id, payload) + + camera = cast(Camera, await cameras.get(detection.camera_id, strict=True)) + if token_payload.organization_id != camera.organization_id: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access forbidden.") + return await detections.update(detection_id, payload) + + @router.delete("/{detection_id}", status_code=status.HTTP_200_OK, summary="Delete a detection") async def delete_detection( detection_id: int = Path(..., gt=0), diff --git a/src/app/models.py b/src/app/models.py index 6dee673a..e3f7e577 100644 --- a/src/app/models.py +++ b/src/app/models.py @@ -61,6 +61,7 @@ class Detection(SQLModel, table=True): bucket_key: str bboxes: str = Field(..., min_length=2, max_length=settings.MAX_BBOX_STR_LENGTH, nullable=False) is_wildfire: Union[bool, None] = None + manual_bboxes: str = Field(None, min_length=2, max_length=settings.MAX_BBOX_STR_LENGTH, nullable=True) created_at: datetime = Field(default_factory=datetime.utcnow, nullable=False) updated_at: datetime = Field(default_factory=datetime.utcnow, nullable=False) diff --git a/src/app/schemas/detections.py b/src/app/schemas/detections.py index eff9d052..fd7b4452 100644 --- a/src/app/schemas/detections.py +++ b/src/app/schemas/detections.py @@ -10,13 +10,23 @@ from app.core.config import settings from app.models import Detection -__all__ = ["Azimuth", "DetectionCreate", "DetectionLabel", "DetectionUrl"] +__all__ = ["Azimuth", "DetectionCreate", "DetectionLabel", "DetectionManualBboxes", "DetectionUrl"] class DetectionLabel(BaseModel): is_wildfire: bool +class DetectionManualBboxes(BaseModel): + manual_bboxes: str = Field( + ..., + min_length=2, + max_length=settings.MAX_BBOX_STR_LENGTH, + description="string representation of list of tuples where each tuple is a relative coordinate in order xmin, ymin, xmax, ymax", + json_schema_extra={"examples": ["[(0.1, 0.1, 0.9, 0.9)]"]}, + ) + + class Azimuth(BaseModel): azimuth: float = Field( ..., @@ -29,7 +39,10 @@ class Azimuth(BaseModel): # Regex for a float between 0 and 1, with a maximum of 3 decimals FLOAT_PATTERN = r"(0?\.[0-9]{1,3}|0|1)" -BOX_PATTERN = rf"\({FLOAT_PATTERN},{FLOAT_PATTERN},{FLOAT_PATTERN},{FLOAT_PATTERN},{FLOAT_PATTERN}\)" +BOX_PATTERN_WITH_CONFIDENCE = rf"\({FLOAT_PATTERN},{FLOAT_PATTERN},{FLOAT_PATTERN},{FLOAT_PATTERN},{FLOAT_PATTERN}\)" +BOXES_PATTERN_WITH_CONFIDENCE = rf"^\[{BOX_PATTERN_WITH_CONFIDENCE}(,{BOX_PATTERN_WITH_CONFIDENCE})*\]$" +COMPILED_BOXES_PATTERN_WITH_CONFIDENCE = re.compile(BOXES_PATTERN_WITH_CONFIDENCE) +BOX_PATTERN = rf"\({FLOAT_PATTERN},{FLOAT_PATTERN},{FLOAT_PATTERN},{FLOAT_PATTERN}\)" BOXES_PATTERN = rf"^\[{BOX_PATTERN}(,{BOX_PATTERN})*\]$" COMPILED_BOXES_PATTERN = re.compile(BOXES_PATTERN) diff --git a/src/migrations/versions/2024_05_30_1200-f84a0ed81bdc_init.py b/src/migrations/versions/2024_05_30_1200-f84a0ed81bdc_init.py index 91c5fb63..afd3e0ef 100755 --- a/src/migrations/versions/2024_05_30_1200-f84a0ed81bdc_init.py +++ b/src/migrations/versions/2024_05_30_1200-f84a0ed81bdc_init.py @@ -52,6 +52,8 @@ def upgrade() -> None: sa.Column("azimuth", sa.Float(), nullable=False), sa.Column("bucket_key", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("is_wildfire", sa.Boolean(), nullable=True), + sa.Column("bboxes", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("manual_bboxes", sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.Column("created_at", sa.DateTime(), nullable=False), sa.Column("updated_at", sa.DateTime(), nullable=False), sa.ForeignKeyConstraint( diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 0f885f43..51f177c0 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -96,6 +96,7 @@ "azimuth": 43.7, "bucket_key": "my_file", "is_wildfire": True, + "manual_bboxes": None, "bboxes": "[(.1,.1,.7,.8,.9)]", "created_at": datetime.strptime("2023-11-07T15:08:19.226673", dt_format), "updated_at": datetime.strptime("2023-11-07T15:08:19.226673", dt_format), @@ -106,6 +107,7 @@ "azimuth": 43.7, "bucket_key": "my_file", "is_wildfire": False, + "manual_bboxes": None, "bboxes": "[(.1,.1,.7,.8,.9)]", "created_at": datetime.strptime("2023-11-07T15:08:19.226673", dt_format), "updated_at": datetime.strptime("2023-11-07T15:08:19.226673", dt_format), @@ -116,6 +118,7 @@ "azimuth": 43.7, "bucket_key": "my_file", "is_wildfire": None, + "manual_bboxes": None, "bboxes": "[(.1,.1,.7,.8,.9)]", "created_at": datetime.strptime("2023-11-07T15:08:19.226673", dt_format), "updated_at": datetime.strptime("2023-11-07T15:08:19.226673", dt_format), diff --git a/src/tests/endpoints/test_detections.py b/src/tests/endpoints/test_detections.py index dc070d32..6712e995 100644 --- a/src/tests/endpoints/test_detections.py +++ b/src/tests/endpoints/test_detections.py @@ -17,7 +17,7 @@ (None, 0, {"azimuth": 45.6, "bboxes": []}, 422, None), (None, 1, {"azimuth": 45.6, "bboxes": (0.6, 0.6, 0.6, 0.6, 0.6)}, 422, None), (None, 1, {"azimuth": 45.6, "bboxes": "[(0.6, 0.6, 0.6, 0.6, 0.6)]"}, 422, None), - (None, 1, {"azimuth": 45.6, "bboxes": "[(0.6,0.6,0.7,0.7,0.6)]"}, 201, None), + (None, 1, {"azimuth": 45.6, "bboxes": "[(0.6,0.6,0.7,0.7,0.6)]", "manual_bboxes": None}, 201, None), ], ) @pytest.mark.asyncio @@ -181,7 +181,6 @@ async def test_fetch_unlabeled_detections( (0, 0, {"is_wildfire": True}, 422, None, None), (0, 1, {"label": True}, 422, None, None), (0, 1, {"is_wildfire": "hello"}, 422, None, None), - # (0, 1, {"is_wildfire": "True"}, 422, None, None), # odd, this works (0, 1, {"is_wildfire": True}, 200, None, 0), (0, 2, {"is_wildfire": True}, 200, None, 1), (1, 1, {"is_wildfire": True}, 200, None, 0), @@ -219,6 +218,50 @@ async def test_label_detection( } +@pytest.mark.parametrize( + ("user_idx", "detection_id", "payload", "status_code", "status_detail", "expected_idx"), + [ + (None, 1, {"manual_bboxes": True}, 401, "Not authenticated", None), + (0, 0, {"manual_bboxes": True}, 422, None, None), + (0, 1, {"label": True}, 422, None, None), + (0, 1, {"manual_bboxes": "hello"}, 422, None, None), + (0, 1, {"manual_bboxes": "[(.1,.1,.7,.8)]"}, 200, None, 0), + (0, 2, {"manual_bboxes": "[(.1,.1,.7,.8)]"}, 200, None, 1), + (1, 1, {"manual_bboxes": "[(.1,.1,.7,.8)]"}, 200, None, 0), + (1, 2, {"manual_bboxes": "[(.1,.1,.7,.8)]"}, 200, None, 1), + (2, 1, {"manual_bboxes": "[(.1,.1,.7,.8)]"}, 403, None, 0), + ], +) +@pytest.mark.asyncio +async def test_manual_bboxes( + async_client: AsyncClient, + detection_session: AsyncSession, + user_idx: Union[int, None], + detection_id: int, + payload: Dict[str, Any], + status_code: int, + status_detail: Union[str, None], + expected_idx: Union[int, None], +): + auth = None + if isinstance(user_idx, int): + auth = pytest.get_token( + pytest.user_table[user_idx]["id"], + pytest.user_table[user_idx]["role"].split(), + pytest.user_table[user_idx]["organization_id"], + ) + + response = await async_client.patch(f"/detections/{detection_id}/manualbboxes", json=payload, headers=auth) + assert response.status_code == status_code, print(response.__dict__) + if isinstance(status_detail, str): + assert response.json()["detail"] == status_detail + if response.status_code // 100 == 2: + assert response.json() == { + **{k: v for k, v in pytest.detection_table[expected_idx].items()}, + **payload, + } + + @pytest.mark.parametrize( ("user_idx", "detection_id", "status_code", "status_detail"), [