Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make placeholder of DbApiHook configurable in UI #38528

Merged
merged 24 commits into from
Apr 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
347dbe9
refactor: Moved placeholder property from OdbcHook class to parent Db…
davidblain-infrabel Mar 27, 2024
c0691eb
Merge branch 'main' into feature/configure-jdbc-hook-placeholder
dabla Mar 27, 2024
22b8efe
refactor: Import BaseHook under type checking block
davidblain-infrabel Mar 27, 2024
faf832b
refactor: Marked test_placeholder_config_from_extra as a db test
davidblain-infrabel Mar 27, 2024
1e6daa9
refactor: Moved mock_conn from conftest to test_utils module under co…
davidblain-infrabel Mar 27, 2024
7b3e535
refactor: Removed unnecessary else statement in placeholder property
davidblain-infrabel Mar 27, 2024
e9ad88a
refactor: Default placeholder can be a class/static variable as it's …
davidblain-infrabel Mar 27, 2024
a79de0b
refactor: Updated sql test with changes from main
davidblain-infrabel Mar 29, 2024
089cdb3
Merge branch 'main' into feature/configure-jdbc-hook-placeholder
davidblain-infrabel Mar 29, 2024
9957b5f
Merge branch 'main' into feature/configure-jdbc-hook-placeholder
dabla Mar 29, 2024
20cae13
refactor: Reformatted test
davidblain-infrabel Mar 30, 2024
0fbc20f
Merge branch 'main' into feature/configure-jdbc-hook-placeholder
dabla Mar 30, 2024
ebd6a7c
Merge branch 'main' into feature/configure-jdbc-hook-placeholder
dabla Apr 3, 2024
28afc30
Merge branch 'main' into feature/configure-jdbc-hook-placeholder
dabla Apr 3, 2024
5ab376c
Merge branch 'main' into feature/configure-jdbc-hook-placeholder
dabla Apr 3, 2024
e2ae2b0
Merge branch 'main' into feature/configure-jdbc-hook-placeholder
dabla Apr 4, 2024
aee9483
Update airflow/providers/common/sql/hooks/sql.py
dabla Apr 4, 2024
27039ea
Merge branch 'main' into feature/configure-jdbc-hook-placeholder
dabla Apr 4, 2024
e4c50a6
Merge branch 'main' into feature/configure-jdbc-hook-placeholder
dabla Apr 4, 2024
f3dda46
Merge branch 'main' into feature/configure-jdbc-hook-placeholder
dabla Apr 4, 2024
ad3d898
fix: Fixed name of constant SQL_PLACEHOLDERS being checked in placeho…
davidblain-infrabel Apr 4, 2024
c329894
Merge branch 'main' into feature/configure-jdbc-hook-placeholder
dabla Apr 4, 2024
c58dfe6
Merge branch 'main' into feature/configure-jdbc-hook-placeholder
dabla Apr 4, 2024
05c3dec
Merge branch 'main' into feature/configure-jdbc-hook-placeholder
dabla Apr 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from airflow.providers.openlineage.sqlparser import DatabaseInfo

T = TypeVar("T")
SQL_PLACEHOLDERS = frozenset({"%s", "?"})


def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool):
Expand Down Expand Up @@ -146,6 +147,8 @@ class DbApiHook(BaseHook):
connector: ConnectorProtocol | None = None
# Override with db-specific query to check connection
_test_connection_sql = "select 1"
# Default SQL placeholder
_placeholder: str = "%s"

def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwargs):
super().__init__()
Expand All @@ -164,7 +167,6 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa
self.__schema = schema
self.log_sql = log_sql
self.descriptions: list[Sequence[Sequence] | None] = []
self._placeholder: str = "%s"
self._insert_statement_format: str = kwargs.get(
"insert_statement_format", "INSERT INTO {} {} VALUES ({})"
)
Expand All @@ -173,7 +175,17 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa
)

@property
def placeholder(self) -> str:
def placeholder(self):
conn = self.get_connection(getattr(self, self.conn_name_attr))
placeholder = conn.extra_dejson.get("placeholder")
if placeholder in SQL_PLACEHOLDERS:
return placeholder
self.log.warning(
"Placeholder defined in Connection '%s' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
"and got ignored. Falling back to the default placeholder '%s'.",
placeholder,
self._placeholder,
)
return self._placeholder

def get_conn(self):
Expand Down
16 changes: 0 additions & 16 deletions airflow/providers/odbc/hooks/odbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.utils.helpers import merge_dicts

DEFAULT_ODBC_PLACEHOLDERS = frozenset({"%s", "?"})


class OdbcHook(DbApiHook):
"""
Expand Down Expand Up @@ -202,20 +200,6 @@ def get_conn(self) -> Connection:
conn = connect(self.odbc_connection_string, **self.connect_kwargs)
return conn

@property
def placeholder(self):
placeholder = self.connection.extra_dejson.get("placeholder")
if placeholder in DEFAULT_ODBC_PLACEHOLDERS:
return placeholder
else:
self.log.warning(
"Placeholder defined in Connection '%s' is not listed in 'DEFAULT_ODBC_PLACEHOLDERS' "
"and got ignored. Falling back to the default placeholder '%s'.",
placeholder,
self._placeholder,
)
return self._placeholder

def get_uri(self) -> str:
"""URI invoked in :meth:`~airflow.providers.common.sql.hooks.sql.DbApiHook.get_sqlalchemy_engine`."""
quoted_conn_str = quote_plus(self.odbc_connection_string)
Expand Down
75 changes: 38 additions & 37 deletions tests/providers/common/sql/hooks/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from __future__ import annotations

import warnings
from typing import Any
from unittest.mock import MagicMock

import pytest
Expand All @@ -28,6 +27,7 @@
from airflow.models import Connection
from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler
from airflow.utils.session import provide_session
from tests.providers.common.sql.test_utils import mock_hook

TASK_ID = "sql-operator"
HOST = "host"
Expand Down Expand Up @@ -214,39 +214,40 @@ def mock_execute(*args, **kwargs):
dbapi_hook.get_conn.return_value.cursor.return_value.close.assert_called()


@pytest.mark.db_test
@pytest.mark.parametrize(
"empty_statement",
[
pytest.param([], id="Empty list"),
pytest.param("", id="Empty string"),
pytest.param("\n", id="Only EOL"),
],
)
def test_no_query(empty_statement):
dbapi_hook = DBApiHookForTests()
dbapi_hook.get_conn.return_value.cursor.rowcount = 0
with pytest.raises(ValueError) as err:
dbapi_hook.run(sql=empty_statement)
assert err.value.args[0] == "List of SQL statements is empty"


@pytest.mark.db_test
def test_make_common_data_structure_hook_has_deprecated_method():
"""If hook implements ``_make_serializable`` warning should be raised on call."""

class DBApiHookForMakeSerializableTests(DBApiHookForTests):
def _make_serializable(self, result: Any):
return result

hook = DBApiHookForMakeSerializableTests()
with pytest.warns(AirflowProviderDeprecationWarning, match="`_make_serializable` method is deprecated"):
hook._make_common_data_structure(["foo", "bar", "baz"])


@pytest.mark.db_test
def test_make_common_data_structure_no_deprecated_method():
"""If hook not implements ``_make_serializable`` there is no warning should be raised on call."""
with warnings.catch_warnings():
warnings.simplefilter("error", AirflowProviderDeprecationWarning)
DBApiHookForTests()._make_common_data_structure(["foo", "bar", "baz"])
class TestDbApiHook:
@pytest.mark.db_test
@pytest.mark.parametrize(
"empty_statement",
[
pytest.param([], id="Empty list"),
pytest.param("", id="Empty string"),
pytest.param("\n", id="Only EOL"),
],
)
def test_no_query(self, empty_statement):
dbapi_hook = mock_hook(DbApiHook)
with pytest.raises(ValueError) as err:
dbapi_hook.run(sql=empty_statement)
assert err.value.args[0] == "List of SQL statements is empty"

@pytest.mark.db_test
def test_make_common_data_structure_hook_has_deprecated_method(self):
"""If hook implements ``_make_serializable`` warning should be raised on call."""
hook = mock_hook(DbApiHook)
hook._make_serializable = lambda result: result
with pytest.warns(
AirflowProviderDeprecationWarning, match="`_make_serializable` method is deprecated"
):
hook._make_common_data_structure(["foo", "bar", "baz"])

@pytest.mark.db_test
def test_make_common_data_structure_no_deprecated_method(self):
"""If hook not implements ``_make_serializable`` there is no warning should be raised on call."""
with warnings.catch_warnings():
warnings.simplefilter("error", AirflowProviderDeprecationWarning)
mock_hook(DbApiHook)._make_common_data_structure(["foo", "bar", "baz"])

@pytest.mark.db_test
def test_placeholder_config_from_extra(self):
dbapi_hook = mock_hook(DbApiHook, conn_params={"extra": {"placeholder": "?"}})
assert dbapi_hook.placeholder == "?"
55 changes: 55 additions & 0 deletions tests/providers/common/sql/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import TYPE_CHECKING
from unittest import mock

from airflow.models import Connection

if TYPE_CHECKING:
from airflow.hooks.base import BaseHook


def mock_hook(hook_class: type[BaseHook], hook_params=None, conn_params=None):
hook_params = hook_params or {}
conn_params = conn_params or {}
connection = Connection(
**{
**dict(login="login", password="password", host="host", schema="schema", port=1234),
**conn_params,
}
)

cursor = mock.MagicMock(
rowcount=0, spec=["description", "rowcount", "execute", "fetchall", "fetchone", "close"]
)
conn = mock.MagicMock()
conn.cursor.return_value = cursor

class MockedHook(hook_class):
conn_name_attr = "test_conn_id"

@classmethod
def get_connection(cls, conn_id: str):
return connection

def get_conn(self):
return conn

return MockedHook(**hook_params)
Loading