Skip to content

Commit

Permalink
FIX: DbApiHook.insert_rows unnecessarily restarting connections (#40615)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: David Blain <[email protected]>
  • Loading branch information
dabla and davidblain-infrabel authored Jul 5, 2024
1 parent 1dc582d commit 3f0979c
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 9 deletions.
20 changes: 13 additions & 7 deletions airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import warnings
from contextlib import closing, contextmanager
from datetime import datetime
from functools import cached_property
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -54,6 +55,7 @@
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.sqlparser import DatabaseInfo


T = TypeVar("T")
SQL_PLACEHOLDERS = frozenset({"%s", "?"})

Expand Down Expand Up @@ -181,24 +183,28 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa
"replace_statement_format", "REPLACE INTO {} {} VALUES ({})"
)

def get_conn_id(self) -> str:
    """Return the connection id configured on this hook.

    The id is read from the instance attribute whose name is stored in
    ``conn_name_attr``.

    This is a plain method, not a property: the other methods of the hook
    invoke it as ``self.get_conn_id()``, which would fail with
    ``TypeError: 'str' object is not callable`` if it were a property.

    :return: the connection id.
    """
    return getattr(self, self.conn_name_attr)

@cached_property
def placeholder(self):
    """Return the SQL parameter placeholder to use for this hook.

    The value comes from the ``placeholder`` key of the connection extras when
    it is one of ``SQL_PLACEHOLDERS``. Any other configured value is ignored
    with a warning, and the hook's default ``_placeholder`` is returned instead.

    The result is cached on the instance, so the connection is looked up only
    once no matter how often the placeholder is read.

    :return: the placeholder string.
    """
    # Resolve the id once; it is needed for both the lookup and the warning.
    conn_id = self.get_conn_id()
    conn = self.get_connection(conn_id)
    placeholder = conn.extra_dejson.get("placeholder")
    if placeholder:
        if placeholder in SQL_PLACEHOLDERS:
            return placeholder
        self.log.warning(
            "Placeholder '%s' defined in Connection '%s' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
            "and got ignored. Falling back to the default placeholder '%s'.",
            placeholder,
            conn_id,
            self._placeholder,
        )
    return self._placeholder

def get_conn(self):
    """Return a connection object.

    The Airflow connection is fetched once via ``get_conn_id()`` and its host,
    port, login and schema are passed to ``self.connector.connect``.
    """
    db = self.get_connection(self.get_conn_id())
    return self.connector.connect(host=db.host, port=db.port, username=db.login, schema=db.schema)

def get_uri(self) -> str:
Expand All @@ -207,7 +213,7 @@ def get_uri(self) -> str:
:return: the extracted uri.
"""
conn = self.get_connection(getattr(self, self.conn_name_attr))
conn = self.get_connection(self.get_conn_id())
conn.schema = self.__schema or conn.schema
return conn.get_uri()

Expand Down Expand Up @@ -502,7 +508,7 @@ def set_autocommit(self, conn, autocommit):
if not self.supports_autocommit and autocommit:
self.log.warning(
"%s connection doesn't support autocommit but autocommit activated.",
getattr(self, self.conn_name_attr),
self.get_conn_id(),
)
conn.autocommit = autocommit

Expand Down
5 changes: 4 additions & 1 deletion airflow/providers/common/sql/hooks/sql.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ from airflow.exceptions import (
from airflow.hooks.base import BaseHook as BaseHook
from airflow.providers.openlineage.extractors import OperatorLineage as OperatorLineage
from airflow.providers.openlineage.sqlparser import DatabaseInfo as DatabaseInfo
from functools import cached_property as cached_property
from pandas import DataFrame as DataFrame
from sqlalchemy.engine import URL as URL
from typing import Any, Callable, Generator, Iterable, Mapping, Protocol, Sequence, TypeVar, overload
Expand All @@ -63,7 +64,9 @@ class DbApiHook(BaseHook):
log_sql: Incomplete
descriptions: Incomplete
def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwargs) -> None: ...
@property

def get_conn_id(self) -> str: ...
@cached_property
def placeholder(self): ...
def get_conn(self): ...
def get_uri(self) -> str: ...
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/common/sql/hooks/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def test_placeholder_with_invalid_placeholder_in_extra(self, caplog):
)

assert self.db_hook.placeholder == "%s"
assert any(
assert (
"Placeholder defined in Connection 'test_conn_id' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
"and got ignored. Falling back to the default placeholder '%s'." in message
for message in caplog.messages
Expand Down
19 changes: 19 additions & 0 deletions tests/providers/common/sql/hooks/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#
from __future__ import annotations

import logging
import warnings
from unittest.mock import MagicMock

Expand Down Expand Up @@ -256,3 +257,21 @@ def test_make_common_data_structure_no_deprecated_method(self):
def test_placeholder_config_from_extra(self):
    """A supported placeholder set in the connection extras is returned by the hook."""
    conn_params = {"extra": {"placeholder": "?"}}
    hook = mock_hook(DbApiHook, conn_params=conn_params)
    assert hook.placeholder == "?"

@pytest.mark.db_test
def test_placeholder_config_from_extra_when_not_in_default_sql_placeholders(self, caplog):
    """An unsupported placeholder in the extras is ignored, with a warning, in favour of the default."""
    expected_warning = (
        "Placeholder '!' defined in Connection 'default_conn_id' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
        f"and got ignored. Falling back to the default placeholder '{DbApiHook._placeholder}'."
    )
    unsupported_params = {"extra": {"placeholder": "!"}}
    with caplog.at_level(logging.WARNING, logger="airflow.providers.common.sql.hooks.test_sql"):
        hook = mock_hook(DbApiHook, conn_params=unsupported_params)
        # Reading the property is what triggers the warning being checked below.
        assert hook.placeholder == "%s"
        assert expected_warning in caplog.text

@pytest.mark.db_test
def test_placeholder_multiple_times_and_make_sure_connection_is_only_invoked_once(self):
    """Reading ``placeholder`` repeatedly must look the connection up only once."""
    hook = mock_hook(DbApiHook)
    attempts = 0
    while attempts < 10:
        assert hook.placeholder == "%s"
        attempts += 1
    # The mocked hook counts get_connection calls; caching should keep it at one.
    assert hook.connection_invocations == 1
2 changes: 2 additions & 0 deletions tests/providers/common/sql/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ def mock_hook(hook_class: type[BaseHook], hook_params=None, conn_params=None):

class MockedHook(hook_class): # type: ignore[misc, valid-type]
conn_name_attr = "test_conn_id"
connection_invocations = 0

@classmethod
def get_connection(cls, conn_id: str):
cls.connection_invocations += 1
return connection

def get_conn(self):
Expand Down

0 comments on commit 3f0979c

Please sign in to comment.