From 56556738421b57d5d9ba9929020370fef3582a54 Mon Sep 17 00:00:00 2001
From: Kacper Muda
Date: Fri, 3 Jan 2025 17:09:45 +0100
Subject: [PATCH] feat: Add OpenLineage support for some SQL to GCS operators

Signed-off-by: Kacper Muda
---
 generated/provider_dependencies.json           |  4 +-
 .../providers/common/compat/__init__.py        |  2 +-
 .../common/compat/openlineage/utils/sql.py     | 84 +++++++++++++++++++
 .../providers/common/compat/provider.yaml      |  1 +
 .../google/cloud/transfers/mysql_to_gcs.py     | 29 ++++++-
 .../google/cloud/transfers/postgres_to_gcs.py  | 29 ++++++-
 .../google/cloud/transfers/sql_to_gcs.py       | 15 ++++
 .../google/cloud/transfers/trino_to_gcs.py     | 27 +++++-
 .../airflow/providers/google/provider.yaml     |  2 +-
 .../providers/openlineage/provider.yaml        |  2 +-
 .../providers/openlineage/sqlparser.py         | 39 +++++++++
 .../cloud/transfers/test_mysql_to_gcs.py       | 66 +++++++++++++++
 .../cloud/transfers/test_postgres_to_gcs.py    | 70 +++++++++++++++-
 .../google/cloud/transfers/test_sql_to_gcs.py  | 19 +++++
 .../cloud/transfers/test_trino_to_gcs.py       | 69 ++++++++++++++-
 15 files changed, 445 insertions(+), 13 deletions(-)
 create mode 100644 providers/src/airflow/providers/common/compat/openlineage/utils/sql.py

diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 6e0616ac4cc6d..6ca4511088ccf 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -628,7 +628,7 @@
   "google": {
     "deps": [
       "PyOpenSSL>=23.0.0",
-      "apache-airflow-providers-common-compat>=1.3.0",
+      "apache-airflow-providers-common-compat>=1.4.0",
       "apache-airflow-providers-common-sql>=1.20.0",
       "apache-airflow>=2.9.0",
       "asgiref>=3.5.2",
@@ -970,7 +970,7 @@
   },
   "openlineage": {
     "deps": [
-      "apache-airflow-providers-common-compat>=1.3.0",
+      "apache-airflow-providers-common-compat>=1.4.0",
       "apache-airflow-providers-common-sql>=1.20.0",
       "apache-airflow>=2.9.0",
       "attrs>=22.2",
diff --git a/providers/src/airflow/providers/common/compat/__init__.py b/providers/src/airflow/providers/common/compat/__init__.py
index 21133a52bb083..bee2112ac7343 100644
--- a/providers/src/airflow/providers/common/compat/__init__.py
+++ b/providers/src/airflow/providers/common/compat/__init__.py
@@ -29,7 +29,7 @@

 __all__ = ["__version__"]

-__version__ = "1.3.0"
+__version__ = "1.4.0"

 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.9.0"
diff --git a/providers/src/airflow/providers/common/compat/openlineage/utils/sql.py b/providers/src/airflow/providers/common/compat/openlineage/utils/sql.py
new file mode 100644
index 0000000000000..9a9618d7a2eb6
--- /dev/null
+++ b/providers/src/airflow/providers/common/compat/openlineage/utils/sql.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+log = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from airflow.providers.openlineage.sqlparser import get_openlineage_facets_with_sql
+
+else:
+    try:
+        from airflow.providers.openlineage.sqlparser import get_openlineage_facets_with_sql
+    except ImportError:
+
+        def get_openlineage_facets_with_sql(
+            hook,
+            sql: str | list[str],
+            conn_id: str,
+            database: str | None,
+        ):
+            try:
+                from airflow.providers.openlineage.sqlparser import SQLParser
+            except ImportError:
+                log.debug("SQLParser could not be imported from OpenLineage provider.")
+                return None
+
+            try:
+                from airflow.providers.openlineage.utils.utils import should_use_external_connection
+
+                use_external_connection = should_use_external_connection(hook)
+            except ImportError:
+                # OpenLineage provider release < 1.8.0 - we always use connection
+                use_external_connection = True
+
+            connection = hook.get_connection(conn_id)
+            try:
+                database_info = hook.get_openlineage_database_info(connection)
+            except AttributeError:
+                log.debug("%s has no database info provided", hook)
+                database_info = None
+
+            if database_info is None:
+                return None
+
+            try:
+                sql_parser = SQLParser(
+                    dialect=hook.get_openlineage_database_dialect(connection),
+                    default_schema=hook.get_openlineage_default_schema(),
+                )
+            except AttributeError:
+                log.debug("%s failed to get database dialect", hook)
+                return None
+
+            operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
+                sql=sql,
+                hook=hook,
+                database_info=database_info,
+                database=database,
+                sqlalchemy_engine=hook.get_sqlalchemy_engine(),
+                use_connection=use_external_connection,
+            )
+
+            return operator_lineage
+
+
+__all__ = ["get_openlineage_facets_with_sql"]
diff --git a/providers/src/airflow/providers/common/compat/provider.yaml b/providers/src/airflow/providers/common/compat/provider.yaml
index 34be19b27b665..2a2f96af1fef0 100644
--- a/providers/src/airflow/providers/common/compat/provider.yaml
+++ b/providers/src/airflow/providers/common/compat/provider.yaml
@@ -25,6 +25,7 @@ state: ready
 source-date-epoch: 1731569875
 # note that those versions are maintained by release manager - do not update them manually
 versions:
+  - 1.4.0
   - 1.3.0
   - 1.2.2
   - 1.2.1
diff --git a/providers/src/airflow/providers/google/cloud/transfers/mysql_to_gcs.py b/providers/src/airflow/providers/google/cloud/transfers/mysql_to_gcs.py
index b0eae584f7b4e..5371a7d5a6d9f 100644
--- a/providers/src/airflow/providers/google/cloud/transfers/mysql_to_gcs.py
+++ b/providers/src/airflow/providers/google/cloud/transfers/mysql_to_gcs.py
@@ -22,6 +22,8 @@
 import base64
 from datetime import date, datetime, time, timedelta
 from decimal import Decimal
+from functools import cached_property
+from typing import TYPE_CHECKING

 try:
     from MySQLdb.constants import FIELD_TYPE
@@ -37,6 +39,9 @@
 from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
 from airflow.providers.mysql.hooks.mysql import MySqlHook

+if TYPE_CHECKING:
+    from airflow.providers.openlineage.extractors import OperatorLineage
+

 class MySQLToGCSOperator(BaseSQLToGCSOperator):
     """
@@ -77,10 +82,13 @@ def __init__(self, *, mysql_conn_id="mysql_default", ensure_utc=False, **kwargs)
         self.mysql_conn_id = mysql_conn_id
         self.ensure_utc = ensure_utc

+    @cached_property
+    def db_hook(self) -> MySqlHook:
+        return MySqlHook(mysql_conn_id=self.mysql_conn_id)
+
     def query(self):
         """Query mysql and returns a cursor to the results."""
-        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
-        conn = mysql.get_conn()
+        conn = self.db_hook.get_conn()
         cursor = conn.cursor()
         if self.ensure_utc:
             # Ensure TIMESTAMP results are in UTC
@@ -140,3 +148,20 @@ def convert_type(self, value, schema_type: str, **kwargs):
         else:
             value = base64.standard_b64encode(value).decode("ascii")
         return value
+
+    def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
+        from airflow.providers.common.compat.openlineage.facet import SQLJobFacet
+        from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        sql_parsing_result = get_openlineage_facets_with_sql(
+            hook=self.db_hook,
+            sql=self.sql,
+            conn_id=self.mysql_conn_id,
+            database=None,
+        )
+        gcs_output_datasets = self._get_openlineage_output_datasets()
+        if sql_parsing_result:
+            sql_parsing_result.outputs = gcs_output_datasets
+            return sql_parsing_result
+        return OperatorLineage(outputs=gcs_output_datasets, job_facets={"sql": SQLJobFacet(self.sql)})
diff --git a/providers/src/airflow/providers/google/cloud/transfers/postgres_to_gcs.py b/providers/src/airflow/providers/google/cloud/transfers/postgres_to_gcs.py
index 1c52f8497cd00..c1a731c9b44ac 100644
--- a/providers/src/airflow/providers/google/cloud/transfers/postgres_to_gcs.py
+++ b/providers/src/airflow/providers/google/cloud/transfers/postgres_to_gcs.py
@@ -24,6 +24,8 @@
 import time
 import uuid
 from decimal import Decimal
+from functools import cached_property
+from typing import TYPE_CHECKING

 import pendulum
 from slugify import slugify
@@ -31,6 +33,9 @@
 from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
 from airflow.providers.postgres.hooks.postgres import PostgresHook

+if TYPE_CHECKING:
+    from airflow.providers.openlineage.extractors import OperatorLineage
+

 class _PostgresServerSideCursorDecorator:
     """
@@ -132,10 +137,13 @@
             )
             return None

+    @cached_property
+    def db_hook(self) -> PostgresHook:
+        return PostgresHook(postgres_conn_id=self.postgres_conn_id)
+
     def query(self):
         """Query Postgres and returns a cursor to the results."""
-        hook = PostgresHook(postgres_conn_id=self.postgres_conn_id)
-        conn = hook.get_conn()
+        conn = self.db_hook.get_conn()
         cursor = conn.cursor(name=self._unique_name())
         cursor.execute(self.sql, self.parameters)
         if self.use_server_side_cursor:
@@ -180,3 +188,20 @@ def convert_type(self, value, schema_type, stringify_dict=True):
         if isinstance(value, Decimal):
             return float(value)
         return value
+
+    def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
+        from airflow.providers.common.compat.openlineage.facet import SQLJobFacet
+        from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        sql_parsing_result = get_openlineage_facets_with_sql(
+            hook=self.db_hook,
+            sql=self.sql,
+            conn_id=self.postgres_conn_id,
+            database=self.db_hook.database,
+        )
+        gcs_output_datasets = self._get_openlineage_output_datasets()
+        if sql_parsing_result:
+            sql_parsing_result.outputs = gcs_output_datasets
+            return sql_parsing_result
+        return OperatorLineage(outputs=gcs_output_datasets, job_facets={"sql": SQLJobFacet(self.sql)})
diff --git a/providers/src/airflow/providers/google/cloud/transfers/sql_to_gcs.py b/providers/src/airflow/providers/google/cloud/transfers/sql_to_gcs.py
index 2ce2260720aae..6953c70876733 100644
--- a/providers/src/airflow/providers/google/cloud/transfers/sql_to_gcs.py
+++ b/providers/src/airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -34,6 +34,7 @@
 from airflow.providers.google.cloud.hooks.gcs import GCSHook

 if TYPE_CHECKING:
+    from airflow.providers.common.compat.openlineage.facet import OutputDataset
     from airflow.utils.context import Context


@@ -151,6 +152,7 @@ def __init__(
         self.partition_columns = partition_columns
         self.write_on_empty = write_on_empty
         self.parquet_row_group_size = parquet_row_group_size
+        self._uploaded_file_names: list[str] = []

     def execute(self, context: Context):
         if self.partition_columns:
@@ -501,3 +503,16 @@ def _upload_to_gcs(self, file_to_upload):
             gzip=self.gzip if is_data_file else False,
             metadata=metadata,
         )
+        self._uploaded_file_names.append(object_name)
+
+    def _get_openlineage_output_datasets(self) -> list[OutputDataset]:
+        """Retrieve OpenLineage output datasets."""
+        from airflow.providers.common.compat.openlineage.facet import OutputDataset
+        from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path
+
+        return [
+            OutputDataset(
+                namespace=f"gs://{self.bucket}",
+                name=extract_ds_name_from_gcs_path(self.filename.split("{}", maxsplit=1)[0]),
+            )
+        ]
diff --git a/providers/src/airflow/providers/google/cloud/transfers/trino_to_gcs.py b/providers/src/airflow/providers/google/cloud/transfers/trino_to_gcs.py
index eeacc41d54e2f..206505f2f1eab 100644
--- a/providers/src/airflow/providers/google/cloud/transfers/trino_to_gcs.py
+++ b/providers/src/airflow/providers/google/cloud/transfers/trino_to_gcs.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations

+from functools import cached_property
 from typing import TYPE_CHECKING, Any

 from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
@@ -26,6 +27,8 @@
     from trino.client import TrinoResult
     from trino.dbapi import Cursor as TrinoCursor

+    from airflow.providers.openlineage.extractors import OperatorLineage
+

 class _TrinoToGCSTrinoCursorAdapter:
     """
@@ -181,10 +184,13 @@ def __init__(self, *, trino_conn_id: str = "trino_default", **kwargs):
         super().__init__(**kwargs)
         self.trino_conn_id = trino_conn_id

+    @cached_property
+    def db_hook(self) -> TrinoHook:
+        return TrinoHook(trino_conn_id=self.trino_conn_id)
+
     def query(self):
         """Query trino and returns a cursor to the results."""
-        trino = TrinoHook(trino_conn_id=self.trino_conn_id)
-        conn = trino.get_conn()
+        conn = self.db_hook.get_conn()
         cursor = conn.cursor()
         self.log.info("Executing: %s", self.sql)
         cursor.execute(self.sql)
@@ -207,3 +213,20 @@ def convert_type(self, value, schema_type, **kwargs):
         :param schema_type: BigQuery data type
         """
         return value
+
+    def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
+        from airflow.providers.common.compat.openlineage.facet import SQLJobFacet
+        from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        sql_parsing_result = get_openlineage_facets_with_sql(
+            hook=self.db_hook,
+            sql=self.sql,
+            conn_id=self.trino_conn_id,
+            database=None,
+        )
+        gcs_output_datasets = self._get_openlineage_output_datasets()
+        if sql_parsing_result:
+            sql_parsing_result.outputs = gcs_output_datasets
+            return sql_parsing_result
+        return OperatorLineage(outputs=gcs_output_datasets, job_facets={"sql": SQLJobFacet(self.sql)})
diff --git a/providers/src/airflow/providers/google/provider.yaml b/providers/src/airflow/providers/google/provider.yaml
index 8b5ed8539cfec..ff129b1ac5e60 100644
--- a/providers/src/airflow/providers/google/provider.yaml
+++ b/providers/src/airflow/providers/google/provider.yaml
@@ -101,7 +101,7 @@ versions:

 dependencies:
   - apache-airflow>=2.9.0
-  - apache-airflow-providers-common-compat>=1.3.0
+  - apache-airflow-providers-common-compat>=1.4.0
   - apache-airflow-providers-common-sql>=1.20.0
   - asgiref>=3.5.2
   - dill>=0.2.3
diff --git a/providers/src/airflow/providers/openlineage/provider.yaml b/providers/src/airflow/providers/openlineage/provider.yaml
index 71115b099d47f..b20f4756ef907 100644
--- a/providers/src/airflow/providers/openlineage/provider.yaml
+++ b/providers/src/airflow/providers/openlineage/provider.yaml
@@ -54,7 +54,7 @@ versions:
 dependencies:
   - apache-airflow>=2.9.0
   - apache-airflow-providers-common-sql>=1.20.0
-  - apache-airflow-providers-common-compat>=1.3.0
+  - apache-airflow-providers-common-compat>=1.4.0
   - attrs>=22.2
   - openlineage-integration-common>=1.24.2
   - openlineage-python>=1.24.2
diff --git a/providers/src/airflow/providers/openlineage/sqlparser.py b/providers/src/airflow/providers/openlineage/sqlparser.py
index 9751af3f7941e..b4225909d7b20 100644
--- a/providers/src/airflow/providers/openlineage/sqlparser.py
+++ b/providers/src/airflow/providers/openlineage/sqlparser.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations

+import logging
 from typing import TYPE_CHECKING, Callable

 import sqlparse
@@ -30,6 +31,7 @@
     create_information_schema_query,
     get_table_schemas,
 )
+from airflow.providers.openlineage.utils.utils import should_use_external_connection
 from airflow.typing_compat import TypedDict
 from airflow.utils.log.logging_mixin import LoggingMixin

@@ -38,6 +40,9 @@
     from sqlalchemy.engine import Engine

     from airflow.hooks.base import BaseHook
+    from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+log = logging.getLogger(__name__)

 DEFAULT_NAMESPACE = "default"
 DEFAULT_INFORMATION_SCHEMA_COLUMNS = [
@@ -397,3 +402,37 @@ def _get_tables_hierarchy(
         tables = schemas.setdefault(normalize_name(table.schema) if table.schema else None, [])
         tables.append(table.name)
     return hierarchy
+
+
+def get_openlineage_facets_with_sql(
+    hook: DbApiHook, sql: str | list[str], conn_id: str, database: str | None
+) -> OperatorLineage | None:
+    connection = hook.get_connection(conn_id)
+    try:
+        database_info = hook.get_openlineage_database_info(connection)
+    except AttributeError:
+        database_info = None
+
+    if database_info is None:
+        log.debug("%s has no database info provided", hook)
+        return None
+
+    try:
+        sql_parser = SQLParser(
+            dialect=hook.get_openlineage_database_dialect(connection),
+            default_schema=hook.get_openlineage_default_schema(),
+        )
+    except AttributeError:
+        log.debug("%s failed to get database dialect", hook)
+        return None
+
+    operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
+        sql=sql,
+        hook=hook,
+        database_info=database_info,
+        database=database,
+        sqlalchemy_engine=hook.get_sqlalchemy_engine(),
+        use_connection=should_use_external_connection(hook),
+    )
+
+    return operator_lineage
diff --git a/providers/tests/google/cloud/transfers/test_mysql_to_gcs.py b/providers/tests/google/cloud/transfers/test_mysql_to_gcs.py
index 01c3498c954ba..331ca0c0243ae 100644
--- a/providers/tests/google/cloud/transfers/test_mysql_to_gcs.py
+++ b/providers/tests/google/cloud/transfers/test_mysql_to_gcs.py
@@ -23,6 +23,13 @@

 import pytest

+from airflow.models import Connection
+from airflow.providers.common.compat.openlineage.facet import (
+    OutputDataset,
+    SchemaDatasetFacetFields,
+)
+from airflow.providers.common.sql.hooks.sql import DbApiHook
+
 TASK_ID = "test-mysql-to-gcs"
 MYSQL_CONN_ID = "mysql_conn_test"
 TZ_QUERY = "SET time_zone = '+00:00'"
@@ -360,3 +367,62 @@ def test_execute_with_query_error(self, mock_gcs_hook, mock_mysql_hook):
         )
         with pytest.raises(ProgrammingError):
             op.execute(None)
+
+    @pytest.mark.parametrize(
+        "connection_port, default_port, expected_port",
+        [(None, 4321, 4321), (1234, None, 1234), (1234, 4321, 1234)],
+    )
+    def test_execute_openlineage_events(self, connection_port, default_port, expected_port):
+        class DBApiHookForTests(DbApiHook):
+            conn_name_attr = "sql_default"
+            get_conn = mock.MagicMock(name="conn")
+            get_connection = mock.MagicMock()
+
+            def get_openlineage_database_info(self, connection):
+                from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
+                return DatabaseInfo(
+                    scheme="sqlscheme",
+                    authority=DbApiHook.get_openlineage_authority_part(connection, default_port=default_port),
+                )
+
+        dbapi_hook = DBApiHookForTests()
+
+        class MySQLToGCSOperatorForTest(MySQLToGCSOperator):
+            @property
+            def db_hook(self):
+                return dbapi_hook
+
+        sql = """SELECT a,b,c from my_db.my_table"""
+        op = MySQLToGCSOperatorForTest(task_id=TASK_ID, sql=sql, bucket="bucket", filename="dir/file{}.csv")
+        DB_SCHEMA_NAME = "PUBLIC"
+        rows = [
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_day_of_week", 1, "varchar"),
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_placed_on", 2, "timestamp"),
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "orders_placed", 3, "int4"),
+        ]
+        dbapi_hook.get_connection.return_value = Connection(
+            conn_id="sql_default", conn_type="mysql", host="host", port=connection_port
+        )
+        dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]
+
+        lineage = op.get_openlineage_facets_on_start()
+        assert len(lineage.inputs) == 1
+        assert lineage.inputs[0].namespace == f"sqlscheme://host:{expected_port}"
+        assert lineage.inputs[0].name == "PUBLIC.popular_orders_day_of_week"
+        assert len(lineage.inputs[0].facets) == 1
+        assert lineage.inputs[0].facets["schema"].fields == [
+            SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"),
+            SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"),
+            SchemaDatasetFacetFields(name="orders_placed", type="int4"),
+        ]
+        assert lineage.outputs == [
+            OutputDataset(
+                namespace="gs://bucket",
+                name="dir",
+            )
+        ]
+
+        assert len(lineage.job_facets) == 1
+        assert lineage.job_facets["sql"].query == sql
+        assert lineage.run_facets == {}
diff --git a/providers/tests/google/cloud/transfers/test_postgres_to_gcs.py b/providers/tests/google/cloud/transfers/test_postgres_to_gcs.py
index dbc68935b687a..96ea68c9aab98 100644
--- a/providers/tests/google/cloud/transfers/test_postgres_to_gcs.py
+++ b/providers/tests/google/cloud/transfers/test_postgres_to_gcs.py
@@ -18,10 +18,16 @@
 from __future__ import annotations

 import datetime
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch

 import pytest

+from airflow.models import Connection
+from airflow.providers.common.compat.openlineage.facet import (
+    OutputDataset,
+    SchemaDatasetFacetFields,
+)
+from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.google.cloud.transfers.postgres_to_gcs import PostgresToGCSOperator
 from airflow.providers.postgres.hooks.postgres import PostgresHook

@@ -217,3 +223,65 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):

         # once for the file and once for the schema
         assert gcs_hook_mock.upload.call_count == 2
+
+    @pytest.mark.parametrize(
+        "connection_port, default_port, expected_port",
+        [(None, 4321, 4321), (1234, None, 1234), (1234, 4321, 1234)],
+    )
+    def test_execute_openlineage_events(self, connection_port, default_port, expected_port):
+        class DBApiHookForTests(DbApiHook):
+            conn_name_attr = "sql_default"
+            get_conn = MagicMock(name="conn")
+            get_connection = MagicMock()
+            database = None
+
+            def get_openlineage_database_info(self, connection):
+                from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
+                return DatabaseInfo(
+                    scheme="sqlscheme",
+                    authority=DbApiHook.get_openlineage_authority_part(connection, default_port=default_port),
+                )
+
+        dbapi_hook = DBApiHookForTests()
+
+        class PostgresToGCSOperatorForTest(PostgresToGCSOperator):
+            @property
+            def db_hook(self):
+                return dbapi_hook
+
+        sql = """SELECT a,b,c from my_db.my_table"""
+        op = PostgresToGCSOperatorForTest(
+            task_id=TASK_ID, sql=sql, bucket="bucket", filename="dir/file{}.csv"
+        )
+        DB_SCHEMA_NAME = "PUBLIC"
+        rows = [
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_day_of_week", 1, "varchar"),
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_placed_on", 2, "timestamp"),
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "orders_placed", 3, "int4"),
+        ]
+        dbapi_hook.get_connection.return_value = Connection(
+            conn_id="sql_default", conn_type="postgresql", host="host", port=connection_port
+        )
+        dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]
+
+        lineage = op.get_openlineage_facets_on_start()
+        assert len(lineage.inputs) == 1
+        assert lineage.inputs[0].namespace == f"sqlscheme://host:{expected_port}"
+        assert lineage.inputs[0].name == "PUBLIC.popular_orders_day_of_week"
+        assert len(lineage.inputs[0].facets) == 1
+        assert lineage.inputs[0].facets["schema"].fields == [
+            SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"),
+            SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"),
+            SchemaDatasetFacetFields(name="orders_placed", type="int4"),
+        ]
+        assert lineage.outputs == [
+            OutputDataset(
+                namespace="gs://bucket",
+                name="dir",
+            )
+        ]
+
+        assert len(lineage.job_facets) == 1
+        assert lineage.job_facets["sql"].query == sql
+        assert lineage.run_facets == {}
diff --git a/providers/tests/google/cloud/transfers/test_sql_to_gcs.py b/providers/tests/google/cloud/transfers/test_sql_to_gcs.py
index a65c7fd52919f..b252ae198d6a7 100644
--- a/providers/tests/google/cloud/transfers/test_sql_to_gcs.py
+++ b/providers/tests/google/cloud/transfers/test_sql_to_gcs.py
@@ -581,3 +581,22 @@ def test__write_local_data_files_csv_writes_empty_file_with_write_on_empty(self)

         df = pd.read_csv(file.name)
         assert len(df.index) == 0
+
+    @pytest.mark.parametrize(
+        ("filename", "expected_name"),
+        (
+            ("file_{}.csv", "/"),
+            ("dir/file_{}.csv", "dir"),
+            ("{}.csv", "/"),
+            ("file.csv", "file.csv"),
+            ("dir/file.csv", "dir/file.csv"),
+        ),
+    )
+    def test__get_openlineage_output_datasets(self, filename, expected_name):
+        op = DummySQLToGCSOperator(
+            task_id=TASK_ID, sql="SELECT * FROM a.b", bucket="my-bucket", filename=filename
+        )
+        result = op._get_openlineage_output_datasets()
+        assert len(result) == 1
+        assert result[0].namespace == "gs://my-bucket"
+        assert result[0].name == expected_name
diff --git a/providers/tests/google/cloud/transfers/test_trino_to_gcs.py b/providers/tests/google/cloud/transfers/test_trino_to_gcs.py
index 14aaf2cda9716..9511bd4ac91ae 100644
--- a/providers/tests/google/cloud/transfers/test_trino_to_gcs.py
+++ b/providers/tests/google/cloud/transfers/test_trino_to_gcs.py
@@ -17,8 +17,16 @@
 # under the License.
 from __future__ import annotations

-from unittest.mock import patch
+from unittest.mock import MagicMock, patch

+import pytest
+
+from airflow.models import Connection
+from airflow.providers.common.compat.openlineage.facet import (
+    OutputDataset,
+    SchemaDatasetFacetFields,
+)
+from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.google.cloud.transfers.trino_to_gcs import TrinoToGCSOperator

 TASK_ID = "test-trino-to-gcs"
@@ -325,3 +333,62 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):

         # once for the file and once for the schema
         assert mock_gcs_hook.return_value.upload.call_count == 2
+
+    @pytest.mark.parametrize(
+        "connection_port, default_port, expected_port",
+        [(None, 4321, 4321), (1234, None, 1234), (1234, 4321, 1234)],
+    )
+    def test_execute_openlineage_events(self, connection_port, default_port, expected_port):
+        class DBApiHookForTests(DbApiHook):
+            conn_name_attr = "sql_default"
+            get_conn = MagicMock(name="conn")
+            get_connection = MagicMock()
+
+            def get_openlineage_database_info(self, connection):
+                from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
+                return DatabaseInfo(
+                    scheme="sqlscheme",
+                    authority=DbApiHook.get_openlineage_authority_part(connection, default_port=default_port),
+                )
+
+        dbapi_hook = DBApiHookForTests()
+
+        class TrinoToGCSOperatorForTest(TrinoToGCSOperator):
+            @property
+            def db_hook(self):
+                return dbapi_hook
+
+        sql = """SELECT a,b,c from my_db.my_table"""
+        op = TrinoToGCSOperatorForTest(task_id=TASK_ID, sql=sql, bucket="bucket", filename="dir/file{}.csv")
+        DB_SCHEMA_NAME = "PUBLIC"
+        rows = [
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_day_of_week", 1, "varchar"),
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_placed_on", 2, "timestamp"),
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "orders_placed", 3, "int4"),
+        ]
+        dbapi_hook.get_connection.return_value = Connection(
+            conn_id="sql_default", conn_type="trino", host="host", port=connection_port
+        )
+        dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]
+
+        lineage = op.get_openlineage_facets_on_start()
+        assert len(lineage.inputs) == 1
+        assert lineage.inputs[0].namespace == f"sqlscheme://host:{expected_port}"
+        assert lineage.inputs[0].name == "PUBLIC.popular_orders_day_of_week"
+        assert len(lineage.inputs[0].facets) == 1
+        assert lineage.inputs[0].facets["schema"].fields == [
+            SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"),
+            SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"),
+            SchemaDatasetFacetFields(name="orders_placed", type="int4"),
+        ]
+        assert lineage.outputs == [
+            OutputDataset(
+                namespace="gs://bucket",
+                name="dir",
+            )
+        ]
+
+        assert len(lineage.job_facets) == 1
+        assert lineage.job_facets["sql"].query == sql
+        assert lineage.run_facets == {}
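
Usage sketch (not part of the patch): with this change the operators emit
lineage automatically through get_openlineage_facets_on_start() whenever the
OpenLineage provider is installed and a transport is configured; on older
OpenLineage releases the common.compat shim above degrades gracefully to a
SQL job facet plus the GCS output dataset. A minimal DAG exercising the
Postgres path, assuming a configured postgres_default connection and an
existing bucket (the DAG id, table, bucket, and filename below are
illustrative only):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.transfers.postgres_to_gcs import PostgresToGCSOperator

    with DAG(dag_id="orders_to_gcs", start_date=datetime(2025, 1, 1), schedule=None):
        # Input tables are parsed from the SQL; the output dataset is derived
        # from the bucket and the filename prefix before the "{}" placeholder,
        # so this task reports gs://my-bucket + "exports" as its output.
        PostgresToGCSOperator(
            task_id="export_orders",
            postgres_conn_id="postgres_default",
            sql="SELECT order_id, placed_on FROM public.orders",
            bucket="my-bucket",
            filename="exports/orders_{}.csv",
        )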