Skip to content

Commit

Permalink
make it b/c
Browse files Browse the repository at this point in the history
  • Loading branch information
hussein-awala committed Aug 6, 2023
1 parent 8beb232 commit 391f345
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 4 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/databricks/operators/databricks_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(
client_parameters: dict[str, Any] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
super().__init__(conn_id=databricks_conn_id, **kwargs)
self.databricks_conn_id = databricks_conn_id
self._output_path = output_path
self._output_format = output_format
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/exasol/operators/exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,4 @@ def __init__(
if schema is not None:
hook_params = kwargs.pop("hook_params", {})
kwargs["hook_params"] = {"schema": schema, **hook_params}
super().__init__(handler=handler, **kwargs)
super().__init__(conn_id=exasol_conn_id, handler=handler, **kwargs)
7 changes: 5 additions & 2 deletions airflow/providers/snowflake/operators/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def __init__(
"session_parameters": session_parameters,
**hook_params,
}
super().__init__(sql=sql, parameters=parameters, **kwargs)
super().__init__(sql=sql, parameters=parameters, conn_id=snowflake_conn_id, **kwargs)
self.query_ids: list[str] = []


Expand Down Expand Up @@ -293,7 +293,9 @@ def __init__(
"session_parameters": session_parameters,
**hook_params,
}
super().__init__(sql=sql, pass_value=pass_value, tolerance=tolerance, **kwargs)
super().__init__(
sql=sql, pass_value=pass_value, tolerance=tolerance, conn_id=snowflake_conn_id, **kwargs
)
self.query_ids: list[str] = []


Expand Down Expand Up @@ -376,6 +378,7 @@ def __init__(
metrics_thresholds=metrics_thresholds,
date_filter_column=date_filter_column,
days_back=days_back,
conn_id=snowflake_conn_id,
**kwargs,
)
self.query_ids: list[str] = []
Expand Down
32 changes: 32 additions & 0 deletions tests/providers/common/sql/operators/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1278,3 +1278,35 @@ def test_with_skip_in_branch_downstream_dependencies2(self, mock_get_db_hook):
assert ti.state == State.NONE
else:
raise ValueError(f"Invalid task id {ti.task_id} found!")


class TestBaseSQLOperatorSubClass:

from airflow.providers.common.sql.operators.sql import BaseSQLOperator

class NewStyleBaseSQLOperatorSubClass(BaseSQLOperator):
"""New style subclass of BaseSQLOperator"""

conn_id_field = "custom_conn_id_field"

def __init__(self, custom_conn_id_field="test_conn", **kwargs):
super().__init__(**kwargs)
self.custom_conn_id_field = custom_conn_id_field

class OldStyleBaseSQLOperatorSubClass(BaseSQLOperator):
"""Old style subclass of BaseSQLOperator"""

def __init__(self, custom_conn_id_field="test_conn", **kwargs):
super().__init__(conn_id=custom_conn_id_field, **kwargs)

@pytest.mark.parametrize(
"operator_class", [NewStyleBaseSQLOperatorSubClass, OldStyleBaseSQLOperatorSubClass]
)
@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_new_style_subclass(self, mock_get_connection, operator_class):
from airflow.providers.common.sql.hooks.sql import DbApiHook

op = operator_class(task_id="test_task")
mock_get_connection.return_value.get_hook.return_value = MagicMock(spec=DbApiHook)
op.get_db_hook()
mock_get_connection.assert_called_once_with("test_conn")
1 change: 1 addition & 0 deletions tests/providers/exasol/operators/test_exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def test_pass_parameters(self, mock_get_db_hook):
def test_overwrite_schema(self, mock_base_op):
ExasolOperator(task_id="TEST", sql="SELECT 1", schema="dummy")
mock_base_op.assert_called_once_with(
conn_id="exasol_default",
database=None,
hook_params={"schema": "dummy"},
default_args={},
Expand Down

0 comments on commit 391f345

Please sign in to comment.