diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py index 7b39362aec037..a53ce08e771b4 100644 --- a/superset/tasks/schedules.py +++ b/superset/tasks/schedules.py @@ -551,10 +551,10 @@ def schedule_alert_query( # pylint: disable=unused-argument if report_type == ScheduleType.alert: if is_test_alert and recipients: - deliver_alert(schedule, recipients) + deliver_alert(schedule.id, recipients) return - if run_alert_query(schedule): + if run_alert_query(schedule.id): # deliver_dashboard OR deliver_slice return else: @@ -567,7 +567,9 @@ class AlertState: PASS = "pass" -def deliver_alert(alert: Alert, recipients: Optional[str] = None) -> None: +def deliver_alert(alert_id: int, recipients: Optional[str] = None) -> None: + alert = db.session.query(Alert).get(alert_id) + logging.info("Triggering alert: %s", alert) img_data = None images = {} @@ -612,10 +614,12 @@ def deliver_alert(alert: Alert, recipients: Optional[str] = None) -> None: _deliver_email(recipients, deliver_as_group, subject, body, data, images) -def run_alert_query(alert: Alert) -> Optional[bool]: +def run_alert_query(alert_id: int) -> Optional[bool]: """ Execute alert.sql and return value if any rows are returned """ + alert = db.session.query(Alert).get(alert_id) + logger.info("Processing alert ID: %i", alert.id) database = alert.database if not database: @@ -650,7 +654,7 @@ def run_alert_query(alert: Alert) -> Optional[bool]: for row in df.to_records(): if any(row): state = AlertState.TRIGGER - deliver_alert(alert) + deliver_alert(alert.id) break if not state: state = AlertState.PASS diff --git a/tests/alerts_tests.py b/tests/alerts_tests.py index a7e032a0807eb..07205810edd05 100644 --- a/tests/alerts_tests.py +++ b/tests/alerts_tests.py @@ -86,21 +86,21 @@ def teardown_module(): @patch("superset.tasks.schedules.logging.Logger.error") def test_run_alert_query(mock_error, mock_deliver_alert): with app.app_context(): - run_alert_query(db.session.query(Alert).filter_by(id=1).one()) + run_alert_query(db.session.query(Alert).filter_by(id=1).one().id) alert1 = db.session.query(Alert).filter_by(id=1).one() assert mock_deliver_alert.call_count == 0 assert len(alert1.logs) == 1 assert alert1.logs[0].alert_id == 1 assert alert1.logs[0].state == "pass" - run_alert_query(db.session.query(Alert).filter_by(id=2).one()) + run_alert_query(db.session.query(Alert).filter_by(id=2).one().id) alert2 = db.session.query(Alert).filter_by(id=2).one() assert mock_deliver_alert.call_count == 1 assert len(alert2.logs) == 1 assert alert2.logs[0].alert_id == 2 assert alert2.logs[0].state == "trigger" - run_alert_query(db.session.query(Alert).filter_by(id=3).one()) + run_alert_query(db.session.query(Alert).filter_by(id=3).one().id) alert3 = db.session.query(Alert).filter_by(id=3).one() assert mock_deliver_alert.call_count == 1 assert mock_error.call_count == 2 @@ -108,11 +108,11 @@ def test_run_alert_query(mock_error, mock_deliver_alert): assert alert3.logs[0].alert_id == 3 assert alert3.logs[0].state == "error" - run_alert_query(db.session.query(Alert).filter_by(id=4).one()) + run_alert_query(db.session.query(Alert).filter_by(id=4).one().id) assert mock_deliver_alert.call_count == 1 assert mock_error.call_count == 3 - run_alert_query(db.session.query(Alert).filter_by(id=5).one()) + run_alert_query(db.session.query(Alert).filter_by(id=5).one().id) assert mock_deliver_alert.call_count == 1 assert mock_error.call_count == 4