Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MPP-3815: Reset the database at start of each email #4718

Merged
merged 1 commit into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions emails/management/commands/process_emails_from_sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
import time
from datetime import UTC, datetime
from multiprocessing import Pool
from typing import Any
from typing import Any, cast
from urllib.parse import urlsplit

from django import setup
from django.core.management.base import CommandError
from django.db import connection
from django.http import HttpResponse

import boto3
Expand Down Expand Up @@ -435,7 +436,7 @@ def error_callback(exc_info: BaseException) -> None:
pool_start_time = time.monotonic()
with Pool(1, initializer=setup) as pool:
future = pool.apply_async(
_sns_inbound_logic,
run_sns_inbound_logic,
[topic_arn, message_type, verified_json_body],
callback=success_callback,
error_callback=error_callback,
Expand Down Expand Up @@ -484,3 +485,15 @@ def pluralize(self, value: int, singular: str, plural: str | None = None) -> str
return f"{value} {singular}"
else:
return f"{value} {plural or (singular + 's')}"


def run_sns_inbound_logic(
topic_arn: str, message_type: str, json_body: str
) -> HttpResponse:
# Reset any exiting connection, verify it is usable
with connection.cursor() as cursor:
cursor.db.queries_log.clear()
if not cursor.db.is_usable():
cursor.db.close()

return cast(HttpResponse, _sns_inbound_logic(topic_arn, message_type, json_body))
40 changes: 40 additions & 0 deletions emails/tests/mgmt_process_emails_from_sqs_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,22 @@ def mock_verify_from_sns() -> Iterator[Mock]:
yield mock_verify_from_sns


@pytest.fixture(autouse=True)
def mock_django_db_connection() -> Iterator[Mock]:
"""Mock django.db.connection, used to manually recycle the database connection."""

mock_db = Mock(spec_set=["is_usable", "close", "queries_log"])
mock_db.is_usable.return_value = True

mock_cursor = Mock()
mock_cursor.__enter__ = Mock(return_value=mock_cursor)
mock_cursor.__exit__ = Mock(return_value=False)
mock_cursor.db = mock_db

with patch(f"{MOCK_BASE}.connection.cursor", return_value=mock_cursor):
yield mock_db


@pytest.fixture(autouse=True)
def mock_sns_inbound_logic() -> Iterator[Mock]:
"""Mock _sns_inbound_logic(topic_arn, message_type, json_body) to do nothing"""
Expand Down Expand Up @@ -295,6 +311,7 @@ def test_one_message(
mock_verify_from_sns: Mock,
mock_sns_inbound_logic: Mock,
mock_sqs_client: Mock,
mock_django_db_connection: Mock,
caplog: LogCaptureFixture,
) -> None:
"""The command will process an available message."""
Expand All @@ -321,6 +338,9 @@ def test_one_message(
"Notification",
TEST_SNS_MESSAGE,
)
mock_django_db_connection.queries_log.clear.assert_called_once()
mock_django_db_connection.is_usable.assert_called_once()
mock_django_db_connection.close.assert_not_called()


def test_keyboard_interrupt(
Expand Down Expand Up @@ -476,6 +496,26 @@ def test_ses_timeout(
assert mock_process_pool_future._timeouts == [1.0] * 120


def test_db_is_unusable_is_closed(
mock_sqs_client: Mock, mock_django_db_connection: Mock, caplog: LogCaptureFixture
) -> None:
"""If the database connection is unusable, it is closed so it will be refreshed."""
mock_django_db_connection.is_usable.return_value = False
msg = fake_sqs_message(json.dumps(TEST_SNS_MESSAGE))
mock_sqs_client.return_value = fake_queue([msg], [])
call_command(COMMAND_NAME)
summary = summary_from_exit_log(caplog)
assert summary["total_messages"] == 1
msg.delete.assert_called()
rec2 = caplog.records[1]
assert rec2.msg == "Message processed"
rec2_extra = log_extra(rec2)
assert rec2_extra["success"] is True
mock_django_db_connection.queries_log.clear.assert_called_once()
mock_django_db_connection.is_usable.assert_called_once()
mock_django_db_connection.close.assert_called_once()


def test_verify_from_sns_raises_openssl_error(
mock_verify_from_sns: Mock, mock_sqs_client: Mock, caplog: LogCaptureFixture
) -> None:
Expand Down