From 623cfba20add8d70030da888b2454f000f58725f Mon Sep 17 00:00:00 2001 From: John Whitlock Date: Mon, 20 May 2024 18:02:22 -0500 Subject: [PATCH] Reset the database at start of each email --- .../commands/process_emails_from_sqs.py | 17 +++++++- .../mgmt_process_emails_from_sqs_tests.py | 40 +++++++++++++++++++ 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/emails/management/commands/process_emails_from_sqs.py b/emails/management/commands/process_emails_from_sqs.py index 4a15c47de1..65c9ee835d 100644 --- a/emails/management/commands/process_emails_from_sqs.py +++ b/emails/management/commands/process_emails_from_sqs.py @@ -16,11 +16,12 @@ import time from datetime import UTC, datetime from multiprocessing import Pool -from typing import Any +from typing import Any, cast from urllib.parse import urlsplit from django import setup from django.core.management.base import CommandError +from django.db import connection from django.http import HttpResponse import boto3 @@ -435,7 +436,7 @@ def error_callback(exc_info: BaseException) -> None: pool_start_time = time.monotonic() with Pool(1, initializer=setup) as pool: future = pool.apply_async( - _sns_inbound_logic, + run_sns_inbound_logic, [topic_arn, message_type, verified_json_body], callback=success_callback, error_callback=error_callback, @@ -484,3 +485,15 @@ def pluralize(self, value: int, singular: str, plural: str | None = None) -> str return f"{value} {singular}" else: return f"{value} {plural or (singular + 's')}" + + +def run_sns_inbound_logic( + topic_arn: str, message_type: str, json_body: str +) -> HttpResponse: + # Reset any exiting connection, verify it is usable + with connection.cursor() as cursor: + cursor.db.queries_log.clear() + if not cursor.db.is_usable(): + cursor.db.close() + + return cast(HttpResponse, _sns_inbound_logic(topic_arn, message_type, json_body)) diff --git a/emails/tests/mgmt_process_emails_from_sqs_tests.py b/emails/tests/mgmt_process_emails_from_sqs_tests.py index 4e1b046999..0d12351e88 100644 --- a/emails/tests/mgmt_process_emails_from_sqs_tests.py +++ b/emails/tests/mgmt_process_emails_from_sqs_tests.py @@ -76,6 +76,22 @@ def mock_verify_from_sns() -> Iterator[Mock]: yield mock_verify_from_sns +@pytest.fixture(autouse=True) +def mock_django_db_connection() -> Iterator[Mock]: + """Mock django.db.connection, used to manually recycle the database connection.""" + + mock_db = Mock(spec_set=["is_usable", "close", "queries_log"]) + mock_db.is_usable.return_value = True + + mock_cursor = Mock() + mock_cursor.__enter__ = Mock(return_value=mock_cursor) + mock_cursor.__exit__ = Mock(return_value=False) + mock_cursor.db = mock_db + + with patch(f"{MOCK_BASE}.connection.cursor", return_value=mock_cursor): + yield mock_db + + @pytest.fixture(autouse=True) def mock_sns_inbound_logic() -> Iterator[Mock]: """Mock _sns_inbound_logic(topic_arn, message_type, json_body) to do nothing""" @@ -295,6 +311,7 @@ def test_one_message( mock_verify_from_sns: Mock, mock_sns_inbound_logic: Mock, mock_sqs_client: Mock, + mock_django_db_connection: Mock, caplog: LogCaptureFixture, ) -> None: """The command will process an available message.""" @@ -321,6 +338,9 @@ def test_one_message( "Notification", TEST_SNS_MESSAGE, ) + mock_django_db_connection.queries_log.clear.assert_called_once() + mock_django_db_connection.is_usable.assert_called_once() + mock_django_db_connection.close.assert_not_called() def test_keyboard_interrupt( @@ -476,6 +496,26 @@ def test_ses_timeout( assert mock_process_pool_future._timeouts == [1.0] * 120 +def test_db_is_unusable_is_closed( + mock_sqs_client: Mock, mock_django_db_connection: Mock, caplog: LogCaptureFixture +) -> None: + """If the database connection is unusable, it is closed so it will be refreshed.""" + mock_django_db_connection.is_usable.return_value = False + msg = fake_sqs_message(json.dumps(TEST_SNS_MESSAGE)) + mock_sqs_client.return_value = fake_queue([msg], []) + call_command(COMMAND_NAME) + summary = summary_from_exit_log(caplog) + assert summary["total_messages"] == 1 + msg.delete.assert_called() + rec2 = caplog.records[1] + assert rec2.msg == "Message processed" + rec2_extra = log_extra(rec2) + assert rec2_extra["success"] is True + mock_django_db_connection.queries_log.clear.assert_called_once() + mock_django_db_connection.is_usable.assert_called_once() + mock_django_db_connection.close.assert_called_once() + + def test_verify_from_sns_raises_openssl_error( mock_verify_from_sns: Mock, mock_sqs_client: Mock, caplog: LogCaptureFixture ) -> None: