Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: DbApiHook.insert_rows unnecessarily restarting connections #40615

Merged
merged 13 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import warnings
from contextlib import closing, contextmanager
from datetime import datetime
from functools import cached_property
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -181,19 +182,21 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa
"replace_statement_format", "REPLACE INTO {} {} VALUES ({})"
)

@property
@cached_property
def placeholder(self):
conn = self.get_connection(getattr(self, self.conn_name_attr))
conn_id = getattr(self, self.conn_name_attr)
conn = self.get_connection(conn_id)
placeholder = conn.extra_dejson.get("placeholder")
if placeholder:
if placeholder in SQL_PLACEHOLDERS:
return placeholder
self.log.warning(
"Placeholder defined in Connection '%s' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
"and got ignored. Falling back to the default placeholder '%s'.",
self.conn_name_attr,
self._placeholder,
)
if placeholder in SQL_PLACEHOLDERS:
return placeholder
self.log.warning(
"Placeholder '%s' defined in Connection '%s' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
"and got ignored. Falling back to the default placeholder '%s'.",
placeholder,
conn_id,
self._placeholder,
)
return self._placeholder

def get_conn(self):
Expand Down
18 changes: 18 additions & 0 deletions tests/providers/common/sql/hooks/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,21 @@ def test_make_common_data_structure_no_deprecated_method(self):
def test_placeholder_config_from_extra(self):
dbapi_hook = mock_hook(DbApiHook, conn_params={"extra": {"placeholder": "?"}})
assert dbapi_hook.placeholder == "?"

@pytest.mark.db_test
def test_placeholder_config_from_extra_when_not_in_default_sql_placeholders(self):
dbapi_hook = mock_hook(DbApiHook, conn_params={"extra": {"placeholder": "{}"}})
assert dbapi_hook.placeholder == "%s"
dbapi_hook.log.warning.assert_called_with(
"Placeholder '%s' defined in Connection '%s' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
"and got ignored. Falling back to the default placeholder '%s'.",
"{}",
"default_conn_id",
DbApiHook._placeholder)

@pytest.mark.db_test
def test_placeholder_multiple_times_and_make_sure_connection_is_only_invoked_once(self):
dbapi_hook = mock_hook(DbApiHook)
for number_of_invocations in range(0, 10):
assert dbapi_hook.placeholder == "%s"
assert dbapi_hook.connection_invocations == 1
8 changes: 7 additions & 1 deletion tests/providers/common/sql/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import logging
from typing import TYPE_CHECKING
from unittest import mock

Expand Down Expand Up @@ -52,12 +53,17 @@ def mock_hook(hook_class: type[BaseHook], hook_params=None, conn_params=None):

class MockedHook(hook_class): # type: ignore[misc, valid-type]
conn_name_attr = "test_conn_id"
log = mock.MagicMock(spec=logging.Logger)
connection_invocations = 0

@classmethod
def get_connection(cls, conn_id: str):
cls.connection_invocations += 1
return connection

def get_conn(self):
return conn

return MockedHook(**hook_params)
hook = MockedHook(**hook_params)
hook.log.setLevel(logging.DEBUG)
return hook
Loading