Skip to content

Commit

Permalink
Reset the database at start of each email
Browse files Browse the repository at this point in the history
  • Loading branch information
jwhitlock committed May 20, 2024
1 parent 20ed779 commit 623cfba
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 2 deletions.
17 changes: 15 additions & 2 deletions emails/management/commands/process_emails_from_sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
import time
from datetime import UTC, datetime
from multiprocessing import Pool
from typing import Any
from typing import Any, cast
from urllib.parse import urlsplit

from django import setup
from django.core.management.base import CommandError
from django.db import connection
from django.http import HttpResponse

import boto3
Expand Down Expand Up @@ -435,7 +436,7 @@ def error_callback(exc_info: BaseException) -> None:
pool_start_time = time.monotonic()
with Pool(1, initializer=setup) as pool:
future = pool.apply_async(
_sns_inbound_logic,
run_sns_inbound_logic,
[topic_arn, message_type, verified_json_body],
callback=success_callback,
error_callback=error_callback,
Expand Down Expand Up @@ -484,3 +485,15 @@ def pluralize(self, value: int, singular: str, plural: str | None = None) -> str
return f"{value} {singular}"
else:
return f"{value} {plural or (singular + 's')}"


def run_sns_inbound_logic(
topic_arn: str, message_type: str, json_body: str
) -> HttpResponse:
# Reset any exiting connection, verify it is usable
with connection.cursor() as cursor:
cursor.db.queries_log.clear()
if not cursor.db.is_usable():
cursor.db.close()

return cast(HttpResponse, _sns_inbound_logic(topic_arn, message_type, json_body))
40 changes: 40 additions & 0 deletions emails/tests/mgmt_process_emails_from_sqs_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,22 @@ def mock_verify_from_sns() -> Iterator[Mock]:
yield mock_verify_from_sns


@pytest.fixture(autouse=True)
def mock_django_db_connection() -> Iterator[Mock]:
"""Mock django.db.connection, used to manually recycle the database connection."""

mock_db = Mock(spec_set=["is_usable", "close", "queries_log"])
mock_db.is_usable.return_value = True

mock_cursor = Mock()
mock_cursor.__enter__ = Mock(return_value=mock_cursor)
mock_cursor.__exit__ = Mock(return_value=False)
mock_cursor.db = mock_db

with patch(f"{MOCK_BASE}.connection.cursor", return_value=mock_cursor):
yield mock_db


@pytest.fixture(autouse=True)
def mock_sns_inbound_logic() -> Iterator[Mock]:
"""Mock _sns_inbound_logic(topic_arn, message_type, json_body) to do nothing"""
Expand Down Expand Up @@ -295,6 +311,7 @@ def test_one_message(
mock_verify_from_sns: Mock,
mock_sns_inbound_logic: Mock,
mock_sqs_client: Mock,
mock_django_db_connection: Mock,
caplog: LogCaptureFixture,
) -> None:
"""The command will process an available message."""
Expand All @@ -321,6 +338,9 @@ def test_one_message(
"Notification",
TEST_SNS_MESSAGE,
)
mock_django_db_connection.queries_log.clear.assert_called_once()
mock_django_db_connection.is_usable.assert_called_once()
mock_django_db_connection.close.assert_not_called()


def test_keyboard_interrupt(
Expand Down Expand Up @@ -476,6 +496,26 @@ def test_ses_timeout(
assert mock_process_pool_future._timeouts == [1.0] * 120


def test_db_is_unusable_is_closed(
mock_sqs_client: Mock, mock_django_db_connection: Mock, caplog: LogCaptureFixture
) -> None:
"""If the database connection is unusable, it is closed so it will be refreshed."""
mock_django_db_connection.is_usable.return_value = False
msg = fake_sqs_message(json.dumps(TEST_SNS_MESSAGE))
mock_sqs_client.return_value = fake_queue([msg], [])
call_command(COMMAND_NAME)
summary = summary_from_exit_log(caplog)
assert summary["total_messages"] == 1
msg.delete.assert_called()
rec2 = caplog.records[1]
assert rec2.msg == "Message processed"
rec2_extra = log_extra(rec2)
assert rec2_extra["success"] is True
mock_django_db_connection.queries_log.clear.assert_called_once()
mock_django_db_connection.is_usable.assert_called_once()
mock_django_db_connection.close.assert_called_once()


def test_verify_from_sns_raises_openssl_error(
mock_verify_from_sns: Mock, mock_sqs_client: Mock, caplog: LogCaptureFixture
) -> None:
Expand Down

0 comments on commit 623cfba

Please sign in to comment.