diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
index 06d91f8f4103e..4dba2d843a380 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -20,6 +20,7 @@
 import warnings
 from contextlib import closing, contextmanager
 from datetime import datetime
+from functools import cached_property
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -54,6 +55,7 @@
     from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.providers.openlineage.sqlparser import DatabaseInfo
 
+
 T = TypeVar("T")
 SQL_PLACEHOLDERS = frozenset({"%s", "?"})
 
@@ -181,24 +183,30 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa
             "replace_statement_format", "REPLACE INTO {} {} VALUES ({})"
         )
 
-    @property
+    def get_conn_id(self) -> str:
+        """Return the connection id stored in the attribute named by ``conn_name_attr``."""
+        return getattr(self, self.conn_name_attr)
+
+    @cached_property
     def placeholder(self):
-        conn = self.get_connection(getattr(self, self.conn_name_attr))
+        """Return the SQL parameter placeholder; resolved on first access, then cached on the instance."""
+        conn = self.get_connection(self.get_conn_id())
         placeholder = conn.extra_dejson.get("placeholder")
         if placeholder:
             if placeholder in SQL_PLACEHOLDERS:
                 return placeholder
             self.log.warning(
-                "Placeholder defined in Connection '%s' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
+                "Placeholder '%s' defined in Connection '%s' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
                 "and got ignored. Falling back to the default placeholder '%s'.",
-                self.conn_name_attr,
+                placeholder,
+                self.get_conn_id(),
                 self._placeholder,
             )
         return self._placeholder
 
     def get_conn(self):
         """Return a connection object."""
-        db = self.get_connection(getattr(self, cast(str, self.conn_name_attr)))
+        db = self.get_connection(self.get_conn_id())
         return self.connector.connect(host=db.host, port=db.port, username=db.login, schema=db.schema)
 
     def get_uri(self) -> str:
@@ -207,7 +215,7 @@ def get_uri(self) -> str:
 
         :return: the extracted uri.
         """
-        conn = self.get_connection(getattr(self, self.conn_name_attr))
+        conn = self.get_connection(self.get_conn_id())
         conn.schema = self.__schema or conn.schema
         return conn.get_uri()
 
@@ -502,7 +510,7 @@ def set_autocommit(self, conn, autocommit):
         if not self.supports_autocommit and autocommit:
             self.log.warning(
                 "%s connection doesn't support autocommit but autocommit activated.",
-                getattr(self, self.conn_name_attr),
+                self.get_conn_id(),
             )
         conn.autocommit = autocommit
 
diff --git a/airflow/providers/common/sql/hooks/sql.pyi b/airflow/providers/common/sql/hooks/sql.pyi
index 16c3d6592a341..27142aeaf2a0f 100644
--- a/airflow/providers/common/sql/hooks/sql.pyi
+++ b/airflow/providers/common/sql/hooks/sql.pyi
@@ -40,6 +40,7 @@ from airflow.exceptions import (
 from airflow.hooks.base import BaseHook as BaseHook
 from airflow.providers.openlineage.extractors import OperatorLineage as OperatorLineage
 from airflow.providers.openlineage.sqlparser import DatabaseInfo as DatabaseInfo
+from functools import cached_property as cached_property
 from pandas import DataFrame as DataFrame
 from sqlalchemy.engine import URL as URL
 from typing import Any, Callable, Generator, Iterable, Mapping, Protocol, Sequence, TypeVar, overload
@@ -63,7 +64,9 @@ class DbApiHook(BaseHook):
     log_sql: Incomplete
     descriptions: Incomplete
     def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwargs) -> None: ...
-    @property
+
+    def get_conn_id(self) -> str: ...
+    @cached_property
     def placeholder(self): ...
     def get_conn(self): ...
     def get_uri(self) -> str: ...
diff --git a/tests/providers/common/sql/hooks/test_dbapi.py b/tests/providers/common/sql/hooks/test_dbapi.py
index 090ec80e682b1..10e4756307e69 100644
--- a/tests/providers/common/sql/hooks/test_dbapi.py
+++ b/tests/providers/common/sql/hooks/test_dbapi.py
@@ -476,7 +476,9 @@ def test_placeholder_with_invalid_placeholder_in_extra(self, caplog):
         )
 
         assert self.db_hook.placeholder == "%s"
+        # Keep any(): a bare parenthesised generator expression is always truthy and would assert nothing.
+        # Match only the part of the warning that does not depend on the placeholder value or connection id.
         assert any(
-            "Placeholder defined in Connection 'test_conn_id' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
+            "is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
             "and got ignored. Falling back to the default placeholder '%s'." in message
             for message in caplog.messages
diff --git a/tests/providers/common/sql/hooks/test_sql.py b/tests/providers/common/sql/hooks/test_sql.py
index 5d3f4acfb16f0..866a6f55287d6 100644
--- a/tests/providers/common/sql/hooks/test_sql.py
+++ b/tests/providers/common/sql/hooks/test_sql.py
@@ -18,6 +18,7 @@
 #
 from __future__ import annotations
 
+import logging
 import warnings
 from unittest.mock import MagicMock
 
@@ -256,3 +257,21 @@ def test_make_common_data_structure_no_deprecated_method(self):
     def test_placeholder_config_from_extra(self):
         dbapi_hook = mock_hook(DbApiHook, conn_params={"extra": {"placeholder": "?"}})
         assert dbapi_hook.placeholder == "?"
+
+    @pytest.mark.db_test
+    def test_placeholder_config_from_extra_when_not_in_default_sql_placeholders(self, caplog):
+        with caplog.at_level(logging.WARNING, logger="airflow.providers.common.sql.hooks.test_sql"):
+            dbapi_hook = mock_hook(DbApiHook, conn_params={"extra": {"placeholder": "!"}})
+            assert dbapi_hook.placeholder == "%s"
+            assert (
+                "Placeholder '!' defined in Connection 'default_conn_id' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
+                f"and got ignored. Falling back to the default placeholder '{DbApiHook._placeholder}'."
+                in caplog.text
+            )
+
+    @pytest.mark.db_test
+    def test_placeholder_multiple_times_and_make_sure_connection_is_only_invoked_once(self):
+        dbapi_hook = mock_hook(DbApiHook)
+        for _ in range(10):
+            assert dbapi_hook.placeholder == "%s"
+        assert dbapi_hook.connection_invocations == 1
diff --git a/tests/providers/common/sql/test_utils.py b/tests/providers/common/sql/test_utils.py
index 3f5255f8bdb16..7c76f3a7fa507 100644
--- a/tests/providers/common/sql/test_utils.py
+++ b/tests/providers/common/sql/test_utils.py
@@ -52,9 +52,11 @@ def mock_hook(hook_class: type[BaseHook], hook_params=None, conn_params=None):
 
     class MockedHook(hook_class):  # type: ignore[misc, valid-type]
         conn_name_attr = "test_conn_id"
+        connection_invocations = 0
 
         @classmethod
         def get_connection(cls, conn_id: str):
+            cls.connection_invocations += 1
             return connection
 
         def get_conn(self):