diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml
index 0b14851140c5b..4e1a9873ceb8e 100644
--- a/airflow/providers/databricks/provider.yaml
+++ b/airflow/providers/databricks/provider.yaml
@@ -104,6 +104,7 @@ sensors:
   - integration-name: Databricks
     python-modules:
       - airflow.providers.databricks.sensors.databricks_sql
+      - airflow.providers.databricks.sensors.databricks_partition
 
 connection-types:
   - hook-class-name: airflow.providers.databricks.hooks.databricks.DatabricksHook
diff --git a/airflow/providers/databricks/sensors/databricks_partition.py b/airflow/providers/databricks/sensors/databricks_partition.py
new file mode 100644
index 0000000000000..94708df9950d0
--- /dev/null
+++ b/airflow/providers/databricks/sensors/databricks_partition.py
@@ -0,0 +1,228 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+"""This module contains Databricks sensors."""
+
+from __future__ import annotations
+
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Callable, Sequence
+
+from databricks.sql.utils import ParamEscaper
+
+from airflow.compat.functools import cached_property
+from airflow.exceptions import AirflowException
+from airflow.providers.common.sql.hooks.sql import fetch_all_handler
+from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
+from airflow.sensors.base import BaseSensorOperator
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class DatabricksPartitionSensor(BaseSensorOperator):
+    """
+    Sensor to detect the presence of table partitions in Databricks.
+
+    :param databricks_conn_id: Reference to :ref:`Databricks
+        connection id<howto/connection:databricks>` (templated), defaults to
+        DatabricksSqlHook.default_conn_name.
+    :param sql_warehouse_name: Optional name of Databricks SQL warehouse. If not specified, ``http_path``
+        must be provided as described below, defaults to None.
+    :param http_path: Optional string specifying HTTP path of Databricks SQL warehouse or All Purpose cluster.
+        If not specified, it should be either specified in the Databricks connection's
+        extra parameters, or ``sql_warehouse_name`` must be specified.
+    :param session_configuration: An optional dictionary of Spark session parameters. If not specified,
+        it could be specified in the Databricks connection's extra parameters, defaults to None.
+    :param http_headers: An optional list of (k, v) pairs
+        that will be set as HTTP headers on every request (templated).
+    :param catalog: An optional initial catalog to use.
+        Requires Databricks Runtime version 9.0+ (templated), defaults to "".
+    :param schema: An optional initial schema to use.
+        Requires Databricks Runtime version 9.0+ (templated), defaults to "default".
+    :param table_name: Name of the table to check partitions.
+    :param partitions: Name of the partitions to check.
+        Example: {"date": "2023-01-03", "name": ["abc", "def"]}
+    :param partition_operator: Optional comparison operator for partitions, such as >=.
+    :param handler: Handler for DbApiHook.run() to return results, defaults to fetch_all_handler.
+    :param client_parameters: Additional parameters internal to the Databricks SQL connector.
+    """
+
+    template_fields: Sequence[str] = (
+        "databricks_conn_id",
+        "catalog",
+        "schema",
+        "table_name",
+        "partitions",
+        "http_headers",
+    )
+
+    template_ext: Sequence[str] = (".sql",)
+    template_fields_renderers = {"sql": "sql"}
+
+    def __init__(
+        self,
+        *,
+        databricks_conn_id: str = DatabricksSqlHook.default_conn_name,
+        http_path: str | None = None,
+        sql_warehouse_name: str | None = None,
+        session_configuration=None,
+        http_headers: list[tuple[str, str]] | None = None,
+        catalog: str = "",
+        schema: str = "default",
+        table_name: str,
+        partitions: dict,
+        partition_operator: str = "=",
+        handler: Callable[[Any], Any] = fetch_all_handler,
+        client_parameters: dict[str, Any] | None = None,
+        **kwargs,
+    ) -> None:
+        self.databricks_conn_id = databricks_conn_id
+        self._http_path = http_path
+        self._sql_warehouse_name = sql_warehouse_name
+        self.session_config = session_configuration
+        self.http_headers = http_headers
+        self.catalog = catalog
+        self.schema = schema
+        self.caller = "DatabricksPartitionSensor"
+        self.partitions = partitions
+        self.partition_operator = partition_operator
+        self.table_name = table_name
+        self.client_parameters = client_parameters or {}
+        self.hook_params = kwargs.pop("hook_params", {})
+        self.handler = handler
+        self.escaper = ParamEscaper()
+        super().__init__(**kwargs)
+
+    def _sql_sensor(self, sql):
+        """Executes the supplied SQL statement using the hook object."""
+        hook = self._get_hook
+        sql_result = hook.run(
+            sql,
+            handler=self.handler if self.do_xcom_push else None,
+        )
+        self.log.debug("SQL result: %s", sql_result)
+        return sql_result
+
+    @cached_property
+    def _get_hook(self) -> DatabricksSqlHook:
+        """Creates and returns a DatabricksSqlHook object."""
+        return DatabricksSqlHook(
+            self.databricks_conn_id,
+            self._http_path,
+            self._sql_warehouse_name,
+            self.session_config,
+            self.http_headers,
+            self.catalog,
+            self.schema,
+            self.caller,
+            **self.client_parameters,
+            **self.hook_params,
+        )
+
+    def _check_table_partitions(self) -> list:
+        """
+        The method performs the following:
+        * Generates the fully qualified table name.
+        * Calls the partition query generation method.
+        * Runs the generated query via the ``_sql_sensor`` method.
+        """
+        if self.table_name.split(".")[0] == "delta":
+            _fully_qualified_table_name = self.table_name
+        else:
+            _fully_qualified_table_name = str(self.catalog + "." + self.schema + "." + self.table_name)
+        self.log.debug("Table name generated from arguments: %s", _fully_qualified_table_name)
+        _joiner_val = " AND "
+        _prefix = f"SELECT 1 FROM {_fully_qualified_table_name} WHERE"
+        _suffix = " LIMIT 1"
+
+        partition_sql = self._generate_partition_query(
+            prefix=_prefix,
+            suffix=_suffix,
+            joiner_val=_joiner_val,
+            opts=self.partitions,
+            table_name=_fully_qualified_table_name,
+            escape_key=False,
+        )
+        return self._sql_sensor(partition_sql)
+
+    def _generate_partition_query(
+        self,
+        prefix: str,
+        suffix: str,
+        joiner_val: str,
+        table_name: str,
+        opts: dict[str, str] | None = None,
+        escape_key: bool = False,
+    ) -> str:
+        """
+        Queries the table for available partitions.
+        Generates the SQL query based on the partition data types.
+            * For a list, it prepares the SQL in the format:
+                column_name in (value1, value2,...)
+            * For a numeric type, it prepares the format:
+                column_name =(or other provided operator such as >=) value
+            * For a date type, it prepares the format:
+                column_name =(or other provided operator such as >=) value
+        Once the filter predicates have been generated as above, the query
+        is prepared to be executed using the prefix and suffix supplied, which are:
+        "SELECT 1 FROM {_fully_qualified_table_name} WHERE" and "LIMIT 1".
+        """
+        partition_columns = self._sql_sensor(f"DESCRIBE DETAIL {table_name}")[0][7]
+        self.log.debug("Partition columns: %s", partition_columns)
+        if len(partition_columns) < 1:
+            raise AirflowException(f"Table {table_name} does not have partitions")
+        formatted_opts = ""
+        if opts is not None and len(opts) > 0:
+            output_list = []
+            for partition_col, partition_value in opts.items():
+                if escape_key:
+                    partition_col = self.escaper.escape_item(partition_col)
+                if partition_col in partition_columns:
+                    if isinstance(partition_value, list):
+                        output_list.append(f"""{partition_col} in {tuple(partition_value)}""")
+                        self.log.debug("List formatting for partitions: %s", output_list)
+                    if isinstance(partition_value, (int, float, complex)):
+                        output_list.append(
+                            f"""{partition_col}{self.partition_operator}{self.escaper.escape_item(partition_value)}"""
+                        )
+                    if isinstance(partition_value, (str, datetime)):
+                        output_list.append(
+                            f"""{partition_col}{self.partition_operator}{self.escaper.escape_item(partition_value)}"""
+                        )
+                else:
+                    raise AirflowException(
+                        f"Column {partition_col} not part of table partitions: {partition_columns}"
+                    )
+        else:
+            # Raises an exception if no partitions were supplied to check.
+            raise AirflowException("No partitions specified to check with the sensor.")
+        formatted_opts = f"{prefix} {joiner_val.join(output_list)} {suffix}"
+        self.log.debug("Formatted options: %s", formatted_opts)
+
+        return formatted_opts.strip()
+
+    def poke(self, context: Context) -> bool:
+        """Checks the table partitions and returns the results."""
+        partition_result = self._check_table_partitions()
+        self.log.debug("Partition sensor result: %s", partition_result)
+        if len(partition_result) >= 1:
+            return True
+        else:
+            raise AirflowException(f"Specified partition(s): {self.partitions} were not found.")
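Taken together, ``_check_table_partitions`` and ``_generate_partition_query`` turn the ``partitions``
dictionary into a single existence-check query. The standalone sketch below (not provider code) mimics
that construction; ``quote()`` is a simplified stand-in for ``ParamEscaper.escape_item`` and the table
and partition values are borrowed from the example DAG further down::

    from datetime import datetime


    def quote(value):
        """Simplified stand-in for databricks.sql.utils.ParamEscaper.escape_item()."""
        if isinstance(value, (int, float)):
            return str(value)
        return "'" + str(value).replace("'", "''") + "'"


    def build_partition_query(table_name: str, partitions: dict, operator: str = "=") -> str:
        """Builds the same shape of predicate query as DatabricksPartitionSensor does."""
        predicates = []
        for column, value in partitions.items():
            if isinstance(value, list):
                predicates.append(f"{column} in {tuple(value)}")
            elif isinstance(value, (int, float, str, datetime)):
                predicates.append(f"{column}{operator}{quote(value)}")
        return f"SELECT 1 FROM {table_name} WHERE {' AND '.join(predicates)} LIMIT 1"


    if __name__ == "__main__":
        print(
            build_partition_query(
                "hive_metastore.temp.sample_table_2",
                {"date": "2023-01-03", "name": ["abc", "def"]},
            )
        )
        # -> SELECT 1 FROM hive_metastore.temp.sample_table_2
        #    WHERE date='2023-01-03' AND name in ('abc', 'def') LIMIT 1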
diff --git a/docs/apache-airflow-providers-databricks/operators/sql.rst b/docs/apache-airflow-providers-databricks/operators/sql.rst
index fd0535283df17..55bbf64758562 100644
--- a/docs/apache-airflow-providers-databricks/operators/sql.rst
+++ b/docs/apache-airflow-providers-databricks/operators/sql.rst
@@ -113,10 +113,57 @@ Configuring Databricks connection to be used with the Sensor.
     :start-after: [START howto_sensor_databricks_connection_setup]
     :end-before: [END howto_sensor_databricks_connection_setup]
 
-Poking the specific table for existence of data/partition:
+Poking the specific table with the SQL statement:
 
 .. exampleinclude:: /../../tests/system/providers/databricks/example_databricks_sensors.py
     :language: python
     :dedent: 4
     :start-after: [START howto_sensor_databricks_sql]
     :end-before: [END howto_sensor_databricks_sql]
+
+
+DatabricksPartitionSensor
+=========================
+
+Sensors are a special type of Operator that are designed to do exactly one thing - wait for something to occur.
+They can be time-based, wait for a file, or wait for an external event, but all they do is wait until something
+happens and then succeed so their downstream tasks can run.
+
+The Databricks Partition Sensor checks whether a partition with the specified value(s) exists in a table.
+If it does not, the sensor waits until the partition value arrives. The waiting time and the check interval
+can be configured with the ``timeout`` and ``poke_interval`` parameters respectively.
+
+Use the :class:`~airflow.providers.databricks.sensors.databricks_partition.DatabricksPartitionSensor` to run the
+sensor for a table accessible via a Databricks SQL warehouse or interactive cluster.
+
+Using the Sensor
+----------------
+
+The sensor accepts a table name and one or more partition names and values from the user, and generates a SQL
+query that checks whether the specified partition values exist in the specified table.
+
+The main parameters are:
+
+* ``table_name`` (name of the table for the partition check).
+
+* ``partitions`` (names and values of the partitions to check).
+
+* ``partition_operator`` (comparison operator for partitions, used for a range or limit of values, such as
+  ``partition_name >= partition_value``; defaults to ``=``). `Databricks comparison operators
+  <https://docs.databricks.com/sql/language-manual/functions/comparisonoperators.html>`_ are supported.
+
+* One of ``sql_warehouse_name`` (name of the Databricks SQL warehouse to use) or ``http_path`` (HTTP path for the
+  Databricks SQL warehouse or Databricks cluster).
+
+Other parameters are optional and can be found in the class documentation.
+
+Examples
+--------
+Configuring Databricks connection to be used with the Sensor.
+
+.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks_sensors.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_databricks_connection_setup]
+    :end-before: [END howto_sensor_databricks_connection_setup]
+
+Poking the specific table for the existence of data in a partition:
+
+.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks_sensors.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_databricks_partition]
+    :end-before: [END howto_sensor_databricks_partition]
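The documentation above pulls its examples from the system test via ``exampleinclude``. For quick reference, a
minimal self-contained sketch of the documented usage could look roughly like this (connection id, warehouse
name, table, and partition values are placeholders, not fixed names)::

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.databricks.sensors.databricks_partition import DatabricksPartitionSensor

    with DAG(
        dag_id="databricks_partition_sensor_sketch",
        start_date=datetime(2023, 1, 1),
        schedule=None,  # Airflow 2.4+ style schedule argument
        catchup=False,
    ) as dag:
        wait_for_partition = DatabricksPartitionSensor(
            task_id="wait_for_partition",
            databricks_conn_id="databricks_default",  # assumed connection id
            sql_warehouse_name="my_sql_warehouse",  # or pass http_path instead
            catalog="hive_metastore",
            schema="default",
            table_name="my_table",  # placeholder table
            partitions={"date": "2023-01-03"},
            partition_operator=">=",
            timeout=60 * 10,
            poke_interval=60,
        )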
diff --git a/tests/providers/databricks/sensors/test_databricks_partition.py b/tests/providers/databricks/sensors/test_databricks_partition.py
new file mode 100644
index 0000000000000..56e77e5e59a38
--- /dev/null
+++ b/tests/providers/databricks/sensors/test_databricks_partition.py
@@ -0,0 +1,101 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from __future__ import annotations
+
+from datetime import datetime, timedelta
+from unittest.mock import patch
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.models import DAG
+from airflow.providers.common.sql.hooks.sql import fetch_all_handler
+from airflow.providers.databricks.sensors.databricks_partition import DatabricksPartitionSensor
+from airflow.utils import timezone
+
+TASK_ID = "db-partition-sensor"
+DEFAULT_CONN_ID = "databricks_default"
+HOST = "xx.cloud.databricks.com"
+HOST_WITH_SCHEME = "https://xx.cloud.databricks.com"
+PERSONAL_ACCESS_TOKEN = "token"
+
+DEFAULT_SCHEMA = "schema1"
+DEFAULT_CATALOG = "catalog1"
+DEFAULT_TABLE = "table1"
+DEFAULT_HTTP_PATH = "/sql/1.0/warehouses/xxxxx"
+DEFAULT_SQL_WAREHOUSE = "sql_warehouse_default"
+DEFAULT_CALLER = "TestDatabricksPartitionSensor"
+DEFAULT_PARTITION = {"date": "2023-01-01"}
+DEFAULT_DATE = timezone.datetime(2017, 1, 1)
+
+TIMESTAMP_TEST = datetime.now() - timedelta(days=30)
+
+sql_sensor = DatabricksPartitionSensor(
+    databricks_conn_id=DEFAULT_CONN_ID,
+    sql_warehouse_name=DEFAULT_SQL_WAREHOUSE,
+    task_id=TASK_ID,
+    table_name=DEFAULT_TABLE,
+    schema=DEFAULT_SCHEMA,
+    catalog=DEFAULT_CATALOG,
+    partitions=DEFAULT_PARTITION,
+    handler=fetch_all_handler,
+)
+
+
+class TestDatabricksPartitionSensor:
+    def setup_method(self):
+        args = {"owner": "airflow", "start_date": DEFAULT_DATE}
+        self.dag = DAG("test_dag_id", default_args=args)
+
+        self.partition_sensor = DatabricksPartitionSensor(
+            task_id=TASK_ID,
+            databricks_conn_id=DEFAULT_CONN_ID,
+            sql_warehouse_name=DEFAULT_SQL_WAREHOUSE,
+            dag=self.dag,
+            schema=DEFAULT_SCHEMA,
+            catalog=DEFAULT_CATALOG,
+            table_name=DEFAULT_TABLE,
+            partitions={"date": "2023-01-01"},
+            partition_operator="=",
+            timeout=30,
+            poke_interval=15,
+        )
+
+    def test_init(self):
+        assert self.partition_sensor.databricks_conn_id == "databricks_default"
+        assert self.partition_sensor.task_id == "db-partition-sensor"
+        assert self.partition_sensor._sql_warehouse_name == "sql_warehouse_default"
+        assert self.partition_sensor.poke_interval == 15
+
+    @pytest.mark.parametrize(
+        argnames=("sensor_poke_result", "expected_poke_result"), argvalues=[(True, True), (False, False)]
+    )
+    @patch.object(DatabricksPartitionSensor, "poke")
+    def test_poke(self, mock_poke, sensor_poke_result, expected_poke_result):
+        mock_poke.return_value = sensor_poke_result
+        assert self.partition_sensor.poke({}) == expected_poke_result
+
+    def test_unsupported_conn_type(self):
+        with pytest.raises(AirflowException):
+            self.partition_sensor.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+    @patch.object(DatabricksPartitionSensor, "poke")
+    def test_partition_sensor(self, patched_poke):
+        patched_poke.return_value = True
+        assert self.partition_sensor.poke({})
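The tests above patch ``poke`` and therefore never exercise the query construction itself. A hypothetical
additional test (a sketch, not part of this change) could patch the internal ``_sql_sensor`` helper and
inspect the SQL string the sensor builds; the mocked row mimics the ``DESCRIBE DETAIL`` result that the
sensor indexes into at position 7::

    from unittest.mock import patch

    from airflow.providers.databricks.sensors.databricks_partition import DatabricksPartitionSensor


    @patch.object(DatabricksPartitionSensor, "_sql_sensor")
    def test_generated_partition_sql(mock_sql_sensor):
        sensor = DatabricksPartitionSensor(
            task_id="partition_sql_check",
            sql_warehouse_name="sql_warehouse_default",
            catalog="catalog1",
            schema="schema1",
            table_name="table1",
            partitions={"date": "2023-01-01"},
        )
        # Row shaped like a DESCRIBE DETAIL result; index 7 holds the partition columns.
        mock_sql_sensor.return_value = [[None] * 7 + [["date"]]]

        sensor._check_table_partitions()

        # The last call to the patched helper receives the generated partition query.
        generated_sql = mock_sql_sensor.call_args[0][0]
        assert generated_sql.startswith("SELECT 1 FROM catalog1.schema1.table1 WHERE")
        assert "date" in generated_sql
        assert generated_sql.endswith("LIMIT 1")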
diff --git a/tests/system/providers/databricks/example_databricks_sensors.py b/tests/system/providers/databricks/example_databricks_sensors.py
index d2507118e2f27..fd572a6bd9055 100644
--- a/tests/system/providers/databricks/example_databricks_sensors.py
+++ b/tests/system/providers/databricks/example_databricks_sensors.py
@@ -22,6 +22,7 @@
 from datetime import datetime
 
 from airflow import DAG
+from airflow.providers.databricks.sensors.databricks_partition import DatabricksPartitionSensor
 from airflow.providers.databricks.sensors.databricks_sql import DatabricksSqlSensor
 
 # [Env variable to be used from the OS]
@@ -66,11 +67,26 @@
     )
     # [END howto_sensor_databricks_sql]
 
-    # This DAG contains only one task, so the below pattern (task1) is not necessary and does not
-    # affect the execution of the single DAG task which would run regardless of its presence.
-    # It is present here as a pattern to be expanded for users.
-    # For example, (task1 >> task 2 >> task3)
-    (sql_sensor)
+    # [START howto_sensor_databricks_partition]
+    # Example of using the Databricks Partition Sensor to check the presence
+    # of the specified partition(s) in a table.
+    partition_sensor = DatabricksPartitionSensor(
+        databricks_conn_id=connection_id,
+        sql_warehouse_name=sql_warehouse_name,
+        catalog="hive_metastore",
+        task_id="partition_sensor_task",
+        table_name="sample_table_2",
+        schema="temp",
+        partitions={"date": "2023-01-03", "name": ["abc", "def"]},
+        partition_operator="=",
+        timeout=60 * 2,
+    )
+    # [END howto_sensor_databricks_partition]
+
+    # Task dependency between the SQL sensor and the partition sensor.
+    # If the first task (sql_sensor) succeeds, the second task (partition_sensor)
+    # runs; otherwise the subsequent tasks and the DAG run are marked as failed.
+    (sql_sensor >> partition_sensor)
 
     from tests.system.utils.watcher import watcher
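The docs section earlier also mentions range-style checks via ``partition_operator``. As a sketch (not part of
the system test), a second sensor inside the same DAG could wait for any partition on or after a given date,
reusing the ``connection_id`` and ``sql_warehouse_name`` variables from the example above::

    # Hypothetical extra task for the example DAG: succeeds once any partition
    # with date >= 2023-01-01 exists in the same table.
    recent_partition_sensor = DatabricksPartitionSensor(
        databricks_conn_id=connection_id,
        sql_warehouse_name=sql_warehouse_name,
        catalog="hive_metastore",
        task_id="recent_partition_sensor_task",
        table_name="sample_table_2",
        schema="temp",
        partitions={"date": "2023-01-01"},
        partition_operator=">=",
        timeout=60 * 2,
    )

    sql_sensor >> partition_sensor >> recent_partition_sensor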