Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add field "manual bboxes" in the Detection table #374

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 42 additions & 3 deletions src/app/api/api_v1/endpoints/detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
from app.crud import CameraCRUD, DetectionCRUD, WebhookCRUD
from app.models import Camera, Detection, Role, UserRole
from app.schemas.detections import (
BOXES_PATTERN,
BOXES_PATTERN_WITH_CONFIDENCE,
COMPILED_BOXES_PATTERN,
COMPILED_BOXES_PATTERN_WITH_CONFIDENCE,
DetectionCreate,
DetectionLabel,
DetectionManualBboxes,
DetectionUrl,
DetectionWithUrl,
)
Expand All @@ -45,7 +47,7 @@
bboxes: str = Form(
...,
description="string representation of list of detection localizations, each represented as a tuple of relative coords (max 3 decimals) in order: xmin, ymin, xmax, ymax, conf",
pattern=BOXES_PATTERN,
pattern=BOXES_PATTERN_WITH_CONFIDENCE,
min_length=2,
max_length=settings.MAX_BBOX_STR_LENGTH,
),
Expand All @@ -58,7 +60,7 @@
telemetry_client.capture(f"camera|{token_payload.sub}", event="detections-create")

# Throw an error if the format is invalid and can't be captured by the regex
if any(box[0] >= box[2] or box[1] >= box[3] for box in COMPILED_BOXES_PATTERN.findall(bboxes)):
if any(box[0] >= box[2] or box[1] >= box[3] for box in COMPILED_BOXES_PATTERN_WITH_CONFIDENCE.findall(bboxes)):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="xmin & ymin are expected to be respectively smaller than xmax & ymax",
Expand Down Expand Up @@ -188,6 +190,43 @@
return await detections.update(detection_id, payload)


@router.patch("/{detection_id}/manualbboxes", status_code=status.HTTP_200_OK, summary="Update the manual_bboxes field")
async def manual_bboxes(
payload: DetectionManualBboxes,
detection_id: int = Path(..., gt=0),
cameras: CameraCRUD = Depends(get_camera_crud),
detections: DetectionCRUD = Depends(get_detection_crud),
token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT]),
) -> Detection:
telemetry_client.capture(
token_payload.sub, event="detections-manual-bboxes", properties={"detection_id": detection_id}
)
detection = cast(Detection, await detections.get(detection_id, strict=True))

matched_boxes = COMPILED_BOXES_PATTERN.findall(payload.manual_bboxes)

# Throw an error if no boxes match the expected pattern
if not matched_boxes:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="No valid bounding boxes found in the input"
)

# Throw an error if the format is invalid and can't be captured by the regex
if any(box[0] >= box[2] or box[1] >= box[3] for box in matched_boxes):
raise HTTPException(

Check warning on line 216 in src/app/api/api_v1/endpoints/detections.py

View check run for this annotation

Codecov / codecov/patch

src/app/api/api_v1/endpoints/detections.py#L216

Added line #L216 was not covered by tests
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="xmin & ymin are expected to be respectively smaller than xmax & ymax",
)

if UserRole.ADMIN in token_payload.scopes:
return await detections.update(detection_id, payload)

camera = cast(Camera, await cameras.get(detection.camera_id, strict=True))
if token_payload.organization_id != camera.organization_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access forbidden.")

Check warning on line 226 in src/app/api/api_v1/endpoints/detections.py

View check run for this annotation

Codecov / codecov/patch

src/app/api/api_v1/endpoints/detections.py#L226

Added line #L226 was not covered by tests
return await detections.update(detection_id, payload)


@router.delete("/{detection_id}", status_code=status.HTTP_200_OK, summary="Delete a detection")
async def delete_detection(
detection_id: int = Path(..., gt=0),
Expand Down
1 change: 1 addition & 0 deletions src/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class Detection(SQLModel, table=True):
bucket_key: str
bboxes: str = Field(..., min_length=2, max_length=settings.MAX_BBOX_STR_LENGTH, nullable=False)
is_wildfire: Union[bool, None] = None
manual_bboxes: str = Field(None, min_length=2, max_length=settings.MAX_BBOX_STR_LENGTH, nullable=True)
created_at: datetime = Field(default_factory=datetime.utcnow, nullable=False)
updated_at: datetime = Field(default_factory=datetime.utcnow, nullable=False)

Expand Down
17 changes: 15 additions & 2 deletions src/app/schemas/detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,23 @@
from app.core.config import settings
from app.models import Detection

__all__ = ["Azimuth", "DetectionCreate", "DetectionLabel", "DetectionUrl"]
__all__ = ["Azimuth", "DetectionCreate", "DetectionLabel", "DetectionManualBboxes", "DetectionUrl"]


class DetectionLabel(BaseModel):
is_wildfire: bool


class DetectionManualBboxes(BaseModel):
manual_bboxes: str = Field(
...,
min_length=2,
max_length=settings.MAX_BBOX_STR_LENGTH,
description="string representation of list of tuples where each tuple is a relative coordinate in order xmin, ymin, xmax, ymax",
json_schema_extra={"examples": ["[(0.1, 0.1, 0.9, 0.9)]"]},
)


class Azimuth(BaseModel):
azimuth: float = Field(
...,
Expand All @@ -29,7 +39,10 @@ class Azimuth(BaseModel):

# Regex for a float between 0 and 1, with a maximum of 3 decimals
FLOAT_PATTERN = r"(0?\.[0-9]{1,3}|0|1)"
BOX_PATTERN = rf"\({FLOAT_PATTERN},{FLOAT_PATTERN},{FLOAT_PATTERN},{FLOAT_PATTERN},{FLOAT_PATTERN}\)"
BOX_PATTERN_WITH_CONFIDENCE = rf"\({FLOAT_PATTERN},{FLOAT_PATTERN},{FLOAT_PATTERN},{FLOAT_PATTERN},{FLOAT_PATTERN}\)"
BOXES_PATTERN_WITH_CONFIDENCE = rf"^\[{BOX_PATTERN_WITH_CONFIDENCE}(,{BOX_PATTERN_WITH_CONFIDENCE})*\]$"
COMPILED_BOXES_PATTERN_WITH_CONFIDENCE = re.compile(BOXES_PATTERN_WITH_CONFIDENCE)
BOX_PATTERN = rf"\({FLOAT_PATTERN},{FLOAT_PATTERN},{FLOAT_PATTERN},{FLOAT_PATTERN}\)"
BOXES_PATTERN = rf"^\[{BOX_PATTERN}(,{BOX_PATTERN})*\]$"
COMPILED_BOXES_PATTERN = re.compile(BOXES_PATTERN)

Expand Down
2 changes: 2 additions & 0 deletions src/migrations/versions/2024_05_30_1200-f84a0ed81bdc_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def upgrade() -> None:
sa.Column("azimuth", sa.Float(), nullable=False),
sa.Column("bucket_key", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("is_wildfire", sa.Boolean(), nullable=True),
sa.Column("bboxes", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("manual_bboxes", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column("updated_at", sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(
Expand Down
3 changes: 3 additions & 0 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
"azimuth": 43.7,
"bucket_key": "my_file",
"is_wildfire": True,
"manual_bboxes": None,
"bboxes": "[(.1,.1,.7,.8,.9)]",
"created_at": datetime.strptime("2023-11-07T15:08:19.226673", dt_format),
"updated_at": datetime.strptime("2023-11-07T15:08:19.226673", dt_format),
Expand All @@ -106,6 +107,7 @@
"azimuth": 43.7,
"bucket_key": "my_file",
"is_wildfire": False,
"manual_bboxes": None,
"bboxes": "[(.1,.1,.7,.8,.9)]",
"created_at": datetime.strptime("2023-11-07T15:08:19.226673", dt_format),
"updated_at": datetime.strptime("2023-11-07T15:08:19.226673", dt_format),
Expand All @@ -116,6 +118,7 @@
"azimuth": 43.7,
"bucket_key": "my_file",
"is_wildfire": None,
"manual_bboxes": None,
"bboxes": "[(.1,.1,.7,.8,.9)]",
"created_at": datetime.strptime("2023-11-07T15:08:19.226673", dt_format),
"updated_at": datetime.strptime("2023-11-07T15:08:19.226673", dt_format),
Expand Down
47 changes: 45 additions & 2 deletions src/tests/endpoints/test_detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
(None, 0, {"azimuth": 45.6, "bboxes": []}, 422, None),
(None, 1, {"azimuth": 45.6, "bboxes": (0.6, 0.6, 0.6, 0.6, 0.6)}, 422, None),
(None, 1, {"azimuth": 45.6, "bboxes": "[(0.6, 0.6, 0.6, 0.6, 0.6)]"}, 422, None),
(None, 1, {"azimuth": 45.6, "bboxes": "[(0.6,0.6,0.7,0.7,0.6)]"}, 201, None),
(None, 1, {"azimuth": 45.6, "bboxes": "[(0.6,0.6,0.7,0.7,0.6)]", "manual_bboxes": None}, 201, None),
],
)
@pytest.mark.asyncio
Expand Down Expand Up @@ -181,7 +181,6 @@ async def test_fetch_unlabeled_detections(
(0, 0, {"is_wildfire": True}, 422, None, None),
(0, 1, {"label": True}, 422, None, None),
(0, 1, {"is_wildfire": "hello"}, 422, None, None),
# (0, 1, {"is_wildfire": "True"}, 422, None, None), # odd, this works
(0, 1, {"is_wildfire": True}, 200, None, 0),
(0, 2, {"is_wildfire": True}, 200, None, 1),
(1, 1, {"is_wildfire": True}, 200, None, 0),
Expand Down Expand Up @@ -219,6 +218,50 @@ async def test_label_detection(
}


@pytest.mark.parametrize(
("user_idx", "detection_id", "payload", "status_code", "status_detail", "expected_idx"),
[
(None, 1, {"manual_bboxes": True}, 401, "Not authenticated", None),
(0, 0, {"manual_bboxes": True}, 422, None, None),
(0, 1, {"label": True}, 422, None, None),
(0, 1, {"manual_bboxes": "hello"}, 422, None, None),
(0, 1, {"manual_bboxes": "[(.1,.1,.7,.8)]"}, 200, None, 0),
(0, 2, {"manual_bboxes": "[(.1,.1,.7,.8)]"}, 200, None, 1),
(1, 1, {"manual_bboxes": "[(.1,.1,.7,.8)]"}, 200, None, 0),
(1, 2, {"manual_bboxes": "[(.1,.1,.7,.8)]"}, 200, None, 1),
(2, 1, {"manual_bboxes": "[(.1,.1,.7,.8)]"}, 403, None, 0),
],
)
@pytest.mark.asyncio
async def test_manual_bboxes(
async_client: AsyncClient,
detection_session: AsyncSession,
user_idx: Union[int, None],
detection_id: int,
payload: Dict[str, Any],
status_code: int,
status_detail: Union[str, None],
expected_idx: Union[int, None],
):
auth = None
if isinstance(user_idx, int):
auth = pytest.get_token(
pytest.user_table[user_idx]["id"],
pytest.user_table[user_idx]["role"].split(),
pytest.user_table[user_idx]["organization_id"],
)

response = await async_client.patch(f"/detections/{detection_id}/manualbboxes", json=payload, headers=auth)
assert response.status_code == status_code, print(response.__dict__)
if isinstance(status_detail, str):
assert response.json()["detail"] == status_detail
if response.status_code // 100 == 2:
assert response.json() == {
**{k: v for k, v in pytest.detection_table[expected_idx].items()},
**payload,
}


@pytest.mark.parametrize(
("user_idx", "detection_id", "status_code", "status_detail"),
[
Expand Down
Loading