Skip to content

Commit

Permalink
feat: Added event existence check upon alert creation (#149)
Browse files Browse the repository at this point in the history
* feat: Made event_id optional

* feat: Added event check in alert creation

* fix: Fixed import

* fix: Fixed event check

* test: Added unittest

* test: Added datetime util function

* fix: Fixed ongoing alert fetching

* test: Added flexibility to ongoing alert fetching test

* test: Updated unittests

* style: Fixed lint

* refactor: Added alert relaxation duration in config

* refactor: Moved event checking to crud

* refactor: Reflected changes on client

* test: Updated unittests

* style: Fixed lint
  • Loading branch information
frgfm authored Apr 5, 2021
1 parent 67d893a commit f9e46fc
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 28 deletions.
29 changes: 24 additions & 5 deletions client/pyroclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
from urllib.parse import urljoin
import io
from typing import Dict, Any
from typing import Dict, Any, Optional

from .exceptions import HTTPRequestException

Expand Down Expand Up @@ -82,8 +82,14 @@ def heartbeat(self) -> Response:
"""
return requests.put(self.routes["heartbeat"], headers=self.headers)

def update_my_location(self, lat: float = None, lon: float = None,
elevation: float = None, yaw: float = None, pitch: float = None) -> Response:
def update_my_location(
self,
lat: Optional[float] = None,
lon: Optional[float] = None,
elevation: Optional[float] = None,
yaw: Optional[float] = None,
pitch: Optional[float] = None
) -> Response:
"""Updates the location of the device
Example::
Expand Down Expand Up @@ -160,7 +166,14 @@ def create_no_alert_site(self, lat: float, lon: float, name: str, country: str,
payload["group_id"] = group_id
return requests.post(self.routes["no-alert-site"], headers=self.headers, json=payload)

def send_alert(self, lat: float, lon: float, event_id: int, device_id: int, media_id: int = None) -> Response:
def send_alert(
self,
lat: float,
lon: float,
device_id: int,
event_id: Optional[int] = None,
media_id: Optional[int] = None
) -> Response:
"""Raise an alert to the API.
Example::
Expand Down Expand Up @@ -189,7 +202,13 @@ def send_alert(self, lat: float, lon: float, event_id: int, device_id: int, medi

return requests.post(self.routes["send-alert"], headers=self.headers, json=payload)

def send_alert_from_device(self, lat: float, lon: float, event_id: int, media_id: int = None) -> Response:
def send_alert_from_device(
self,
lat: float,
lon: float,
event_id: Optional[int] = None,
media_id: Optional[int] = None
) -> Response:
"""Raise an alert to the API from a device (no need to specify device ID).
Example::
Expand Down
42 changes: 40 additions & 2 deletions src/app/api/crud/alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

from typing import List, Dict, Mapping, Any
from sqlalchemy import Table
from typing import List, Dict, Mapping, Any, Optional
from sqlalchemy import Table, and_
from datetime import datetime, timedelta

from app.db import alerts
from app.api import crud
from app.api.routes.events import create_event
from app.api.schemas import EventIn, AlertIn, AlertOut
import app.config as cfg


async def fetch_ongoing_alerts(
Expand All @@ -25,3 +30,36 @@ async def fetch_ongoing_alerts(
query = query.where(~getattr(table.c, "event_id").in_(all_closed_events))

return await crud.base.database.fetch_all(query=query)


async def resolve_previous_alert(device_id: int) -> Optional[AlertOut]:
# check whether there is an alert in the last 5 min by the same device
max_ts = datetime.utcnow() - timedelta(seconds=cfg.ALERT_RELAXATION_SECONDS)
query = (
alerts.select()
.where(
and_(
alerts.c.device_id == device_id,
alerts.c.created_at >= max_ts
)
)
.order_by(alerts.c.created_at.desc())
.limit(1)
)

entries = await crud.base.database.fetch_all(query=query)

return entries[0] if len(entries) > 0 else None


async def create_event_if_inexistant(payload: AlertIn) -> int:
# check whether there is an alert in the last 5 min by the same device
previous_alert = await resolve_previous_alert(payload.device_id)
if previous_alert is None:
# Create an event & get the ID
event = await create_event(EventIn(lat=payload.lat, lon=payload.lon, start_ts=datetime.utcnow()))
event_id = event['id']
# Get event ref
else:
event_id = previous_alert['event_id']
return event_id
13 changes: 9 additions & 4 deletions src/app/api/routes/alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import List
from fastapi import APIRouter, Path, Security, HTTPException, status
from sqlalchemy import select
from datetime import datetime, timedelta

from app.api import crud
from app.db import alerts, events, media
Expand Down Expand Up @@ -41,8 +42,13 @@ async def create_alert(payload: AlertIn, _=Security(get_current_access, scopes=[
Below, click on "Schema" for more detailed information about arguments
or "Example Value" to get a concrete idea of arguments
"""

if payload.media_id is not None:
await check_media_existence(payload.media_id)

if payload.event_id is None:
payload.event_id = await crud.alerts.create_event_if_inexistant(payload)

return await crud.create_entry(alerts, payload)


Expand All @@ -56,9 +62,8 @@ async def create_alert_from_device(payload: AlertBase,
Below, click on "Schema" for more detailed information about arguments
or "Example Value" to get a concrete idea of arguments
"""
if payload.media_id is not None:
await check_media_existence(payload.media_id)
return await crud.create_entry(alerts, AlertIn(**payload.dict(), device_id=device.id))

return await create_alert(AlertIn(**payload.dict(), device_id=device.id))


@router.get("/{alert_id}/", response_model=AlertOut, summary="Get information about a specific alert")
Expand Down Expand Up @@ -136,7 +141,7 @@ async def fetch_ongoing_alerts(_=Security(get_current_access, scopes=[AccessType
.where(
alerts.c.event_id.in_(
select([events.c.id])
.where(events.c.end_ts.isnot(None))
.where(events.c.end_ts.is_(None))
)
)
)
Expand Down
2 changes: 1 addition & 1 deletion src/app/api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ class AlertMediaId(BaseModel):


class AlertBase(_FlatLocation, AlertMediaId):
event_id: int = Field(..., gt=0)
event_id: int = Field(None, gt=0)
is_acknowledged: bool = Field(False)
azimuth: float = Field(default=None, gt=0, lt=360)

Expand Down
2 changes: 2 additions & 0 deletions src/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
TEST_DATABASE_URL: str = os.getenv("TEST_DATABASE_URL")
LOGO_URL: str = "https://pyronear.org/img/logo_letters.png"

ALERT_RELAXATION_SECONDS: int = 5 * 60


SECRET_KEY: str = secrets.token_urlsafe(32)
if DEBUG:
Expand Down
43 changes: 27 additions & 16 deletions src/tests/routes/test_alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from app import db
from app.api import crud
from tests.db_utils import get_entry, fill_table
from tests.utils import update_only_datetime, parse_time
from tests.utils import update_only_datetime, parse_time, ts_to_string


USER_TABLE = [
Expand Down Expand Up @@ -57,6 +57,8 @@
"azimuth": 47., "is_acknowledged": True, "created_at": "2020-10-13T09:18:45.447773"},
{"id": 3, "device_id": 2, "event_id": 2, "media_id": None, "lat": 10., "lon": 8.,
"azimuth": 123., "is_acknowledged": False, "created_at": "2020-11-03T11:18:45.447773"},
{"id": 4, "device_id": 1, "event_id": 3, "media_id": None, "lat": 0., "lon": 0.,
"azimuth": 47., "is_acknowledged": True, "created_at": ts_to_string(datetime.utcnow())},
]

USER_TABLE_FOR_DB = list(map(update_only_datetime, USER_TABLE))
Expand Down Expand Up @@ -147,7 +149,8 @@ async def test_fetch_ongoing_alerts(test_app_asyncio, init_test_db, access_idx,
assert response.json()['detail'] == status_details

if response.status_code // 100 == 2:
assert response.json() == ALERT_TABLE[:2]
event_ids = [entry['id'] for entry in EVENT_TABLE if entry['end_ts'] is None]
assert response.json() == [entry for entry in ALERT_TABLE if entry['event_id'] in event_ids]


@pytest.mark.parametrize(
Expand All @@ -174,21 +177,23 @@ async def test_fetch_unacknowledged_alerts(test_app_asyncio, init_test_db, acces


@pytest.mark.parametrize(
"access_idx, payload, status_code, status_details",
"access_idx, payload, expected_event_id, status_code, status_details",
[
[0, {"device_id": 2, "event_id": 2, "lat": 10., "lon": 8., "azimuth": 47.5},
[0, {"device_id": 2, "event_id": 2, "lat": 10., "lon": 8., "azimuth": 47.5}, None,
401, "Permission denied"],
[1, {"device_id": 2, "event_id": 2, "lat": 10., "lon": 8., "azimuth": 47.5}, 201, None],
[2, {"device_id": 2, "event_id": 2, "lat": 10., "lon": 8., "azimuth": 47.5},
[1, {"device_id": 2, "event_id": 2, "lat": 10., "lon": 8., "azimuth": 47.5}, None, 201, None],
[1, {"device_id": 2, "lat": 10., "lon": 8., "azimuth": 47.5}, 4, 201, None],
[1, {"device_id": 1, "lat": 10., "lon": 8., "azimuth": 47.5}, 3, 201, None],
[2, {"device_id": 2, "event_id": 2, "lat": 10., "lon": 8., "azimuth": 47.5}, None,
401, "Permission denied"],
[1, {"event_id": 2, "lat": 10., "lon": 8.}, 422, None],
[1, {"device_id": 2, "event_id": 2, "lat": 10., "lon": 8., "azimuth": "hello"}, 422, None],
[1, {"device_id": 2, "event_id": 2, "lat": 10., "lon": 8., "azimuth": -5.}, 422, None],
[1, {"event_id": 2, "lat": 10., "lon": 8.}, None, 422, None],
[1, {"device_id": 2, "event_id": 2, "lat": 10., "lon": 8., "azimuth": "hello"}, None, 422, None],
[1, {"device_id": 2, "event_id": 2, "lat": 10., "lon": 8., "azimuth": -5.}, None, 422, None],
],
)
@pytest.mark.asyncio
async def test_create_alert(test_app_asyncio, init_test_db, test_db,
access_idx, payload, status_code, status_details):
access_idx, payload, expected_event_id, status_code, status_details):

# Create a custom access token
auth = await pytest.get_token(ACCESS_TABLE[access_idx]['id'], ACCESS_TABLE[access_idx]['scope'].split())
Expand All @@ -203,6 +208,8 @@ async def test_create_alert(test_app_asyncio, init_test_db, test_db,
json_response = response.json()
test_response = {"id": len(ALERT_TABLE) + 1, **payload,
"media_id": None, "is_acknowledged": False}
if isinstance(expected_event_id, int):
test_response['event_id'] = expected_event_id
assert {k: v for k, v in json_response.items() if k != 'created_at'} == test_response

new_alert = await get_entry(test_db, db.alerts, json_response["id"])
Expand All @@ -211,16 +218,18 @@ async def test_create_alert(test_app_asyncio, init_test_db, test_db,


@pytest.mark.parametrize(
"access_idx, payload, status_code, status_details",
"access_idx, payload, expected_event_id, status_code, status_details",
[
[0, {"event_id": 2, "lat": 10., "lon": 8.}, 401, "Permission denied"],
[1, {"event_id": 2, "lat": 10., "lon": 8.}, 401, "Permission denied"],
[2, {"event_id": 2, "lat": 10., "lon": 8.}, 201, None],
[0, {"event_id": 2, "lat": 10., "lon": 8.}, None, 401, "Permission denied"],
[1, {"event_id": 2, "lat": 10., "lon": 8.}, None, 401, "Permission denied"],
[2, {"event_id": 2, "lat": 10., "lon": 8.}, None, 201, None],
[2, {"lat": 10., "lon": 8.}, 3, 201, None],
[3, {"lat": 10., "lon": 8.}, 4, 201, None],
],
)
@pytest.mark.asyncio
async def test_create_alert_by_device(test_app_asyncio, init_test_db, test_db,
access_idx, payload, status_code, status_details):
access_idx, payload, expected_event_id, status_code, status_details):

# Create a custom access token
auth = await pytest.get_token(ACCESS_TABLE[access_idx]['id'], ACCESS_TABLE[access_idx]['scope'].split())
Expand All @@ -242,6 +251,8 @@ async def test_create_alert_by_device(test_app_asyncio, init_test_db, test_db,
test_response = {"id": len(ALERT_TABLE) + 1,
"device_id": device_id, **payload,
"media_id": None, "is_acknowledged": False, "azimuth": None}
if isinstance(expected_event_id, int):
test_response['event_id'] = expected_event_id
assert {k: v for k, v in json_response.items() if k != 'created_at'} == test_response
new_alert = await get_entry(test_db, db.alerts, json_response["id"])
new_alert = dict(**new_alert)
Expand All @@ -257,7 +268,7 @@ async def test_create_alert_by_device(test_app_asyncio, init_test_db, test_db,
[1, {}, 1, 422, None],
[1, {"device_id": 2, "event_id": 2, "lat": 10., "lon": 8., "is_acknowledged": True}, 999,
404, "Entry not found"],
[1, {"device_id": 2, "lat": 10., "lon": 8., "is_acknowledged": True}, 1,
[1, {"device_id": 2, "lat": 10., "is_acknowledged": True}, 1,
422, None],
[1, {"device_id": 2, "event_id": 2, "lat": 10., "lon": 8., "is_acknowledged": True}, 0,
422, None],
Expand Down
4 changes: 4 additions & 0 deletions src/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ def update_only_datetime(entity_as_dict):

def parse_time(d):
return datetime.strptime(d, DATETIME_FORMAT)


def ts_to_string(ts):
return datetime.strftime(ts, DATETIME_FORMAT)

0 comments on commit f9e46fc

Please sign in to comment.