feat: Add OpenLineage support for MsSqlHook and MSSQLToGCSOperator (#45637)

Signed-off-by: Kacper Muda <[email protected]>
kacpermuda authored Jan 16, 2025
1 parent e7d8f5b commit 61ecbed
Showing 5 changed files with 125 additions and 6 deletions.
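
For orientation, a minimal usage sketch of the operator this commit instruments. It is not part of the change; the connection ID, SQL, and bucket below are hypothetical placeholders:

from airflow.providers.google.cloud.transfers.mssql_to_gcs import MSSQLToGCSOperator

upload = MSSQLToGCSOperator(
    task_id="mssql_to_gcs_example",
    mssql_conn_id="mssql_default",         # assumed connection ID
    sql="SELECT a, b, c FROM dbo.orders",  # hypothetical query
    bucket="example-bucket",               # hypothetical GCS bucket
    filename="exports/orders_{}.csv",
)

# After this commit, the OpenLineage listener can call
# upload.get_openlineage_facets_on_start() to collect input datasets parsed
# from the SQL plus the GCS output dataset.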
6 changes: 3 additions & 3 deletions dev/breeze/tests/test_selective_checks.py
@@ -1752,7 +1752,7 @@ def test_expected_output_push(
"airflow/datasets/",
),
{
"selected-providers-list-as-string": "amazon common.compat common.io common.sql dbt.cloud ftp google mysql openlineage postgres sftp snowflake trino",
"selected-providers-list-as-string": "amazon common.compat common.io common.sql dbt.cloud ftp google microsoft.mssql mysql openlineage postgres sftp snowflake trino",
"all-python-versions": "['3.9']",
"all-python-versions-list-as-string": "3.9",
"ci-image-build": "true",
@@ -1762,13 +1762,13 @@
"skip-providers-tests": "false",
"test-groups": "['core', 'providers']",
"docs-build": "true",
"docs-list-as-string": "apache-airflow amazon common.compat common.io common.sql dbt.cloud ftp google mysql openlineage postgres sftp snowflake trino",
"docs-list-as-string": "apache-airflow amazon common.compat common.io common.sql dbt.cloud ftp google microsoft.mssql mysql openlineage postgres sftp snowflake trino",
"skip-pre-commits": "check-provider-yaml-valid,flynt,identity,lint-helm-chart,mypy-airflow,mypy-dev,mypy-docs,mypy-providers,mypy-task-sdk,"
"ts-compile-format-lint-ui,ts-compile-format-lint-www",
"run-kubernetes-tests": "false",
"upgrade-to-newer-dependencies": "false",
"core-test-types-list-as-string": "API Always CLI Core Operators Other Serialization WWW",
"providers-test-types-list-as-string": "Providers[amazon] Providers[common.compat,common.io,common.sql,dbt.cloud,ftp,mysql,openlineage,postgres,sftp,snowflake,trino] Providers[google]",
"providers-test-types-list-as-string": "Providers[amazon] Providers[common.compat,common.io,common.sql,dbt.cloud,ftp,microsoft.mssql,mysql,openlineage,postgres,sftp,snowflake,trino] Providers[google]",
"needs-mypy": "false",
"mypy-checks": "[]",
},
3 changes: 2 additions & 1 deletion generated/provider_dependencies.json
@@ -860,7 +860,8 @@
"devel-deps": [],
"plugins": [],
"cross-providers-deps": [
"common.sql"
"common.sql",
"openlineage"
],
"excluded-python-versions": [],
"state": "ready"
29 changes: 27 additions & 2 deletions providers/src/airflow/providers/google/cloud/transfers/mssql_to_gcs.py
@@ -22,10 +22,15 @@
import datetime
import decimal
from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING

from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook

if TYPE_CHECKING:
from airflow.providers.openlineage.extractors import OperatorLineage


class MSSQLToGCSOperator(BaseSQLToGCSOperator):
"""
@@ -75,14 +80,17 @@ def __init__(
self.mssql_conn_id = mssql_conn_id
self.bit_fields = bit_fields or []

@cached_property
def db_hook(self) -> MsSqlHook:
return MsSqlHook(mssql_conn_id=self.mssql_conn_id)

def query(self):
"""
Query MSSQL and return a cursor of results.
:return: mssql cursor
"""
- mssql = MsSqlHook(mssql_conn_id=self.mssql_conn_id)
- conn = mssql.get_conn()
+ conn = self.db_hook.get_conn()
cursor = conn.cursor()
cursor.execute(self.sql)
return cursor
@@ -109,3 +117,20 @@ def convert_type(cls, value, schema_type, **kwargs):
if isinstance(value, (datetime.date, datetime.time)):
return value.isoformat()
return value

def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
from airflow.providers.common.compat.openlineage.facet import SQLJobFacet
from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
from airflow.providers.openlineage.extractors import OperatorLineage

sql_parsing_result = get_openlineage_facets_with_sql(
hook=self.db_hook,
sql=self.sql,
conn_id=self.mssql_conn_id,
database=None,
)
gcs_output_datasets = self._get_openlineage_output_datasets()
if sql_parsing_result:
sql_parsing_result.outputs = gcs_output_datasets
return sql_parsing_result
return OperatorLineage(outputs=gcs_output_datasets, job_facets={"sql": SQLJobFacet(self.sql)})
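
As a rough illustration, mirroring the unit test at the bottom of this commit, a successful SQL parse yields an OperatorLineage shaped roughly like the sketch below (dataset names and host are hypothetical). When parsing fails, the fallback above still reports the GCS outputs together with the SQLJobFacet.

from airflow.providers.common.compat.openlineage.facet import OutputDataset, SQLJobFacet
from airflow.providers.openlineage.extractors import OperatorLineage

# Illustration only: inputs would hold datasets parsed from the SQL,
# with namespaces like "mssql://host:1433"; outputs point at the GCS
# bucket and object prefix.
example = OperatorLineage(
    inputs=[],
    outputs=[OutputDataset(namespace="gs://bucket", name="dir")],
    job_facets={"sql": SQLJobFacet(query="SELECT a, b, c FROM my_db.my_table")},
)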
28 changes: 28 additions & 0 deletions providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py
@@ -29,6 +29,7 @@

if TYPE_CHECKING:
from airflow.providers.common.sql.dialects.dialect import Dialect
from airflow.providers.openlineage.sqlparser import DatabaseInfo


class MsSqlHook(DbApiHook):
@@ -117,3 +118,30 @@ def set_autocommit(

def get_autocommit(self, conn: PymssqlConnection):
return conn.autocommit_state

def get_openlineage_database_info(self, connection) -> DatabaseInfo:
"""Return MSSQL specific information for OpenLineage."""
from airflow.providers.openlineage.sqlparser import DatabaseInfo

return DatabaseInfo(
scheme=self.get_openlineage_database_dialect(connection),
authority=DbApiHook.get_openlineage_authority_part(connection, default_port=1433),
information_schema_columns=[
"table_schema",
"table_name",
"column_name",
"ordinal_position",
"data_type",
"table_catalog",
],
database=self.schema or self.connection.schema,
is_information_schema_cross_db=True,
)

def get_openlineage_database_dialect(self, connection) -> str:
"""Return database dialect."""
return "mssql"

def get_openlineage_default_schema(self) -> str | None:
"""Return current schema."""
return self.get_first("SELECT SCHEMA_NAME();")[0]
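
A hedged sketch of what the new hook methods feed the OpenLineage SQL parser follows; the connection ID is an assumption, and running it requires a configured Airflow connection:

from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook

hook = MsSqlHook(mssql_conn_id="mssql_default")  # assumed connection ID
conn = hook.get_connection(hook.mssql_conn_id)

info = hook.get_openlineage_database_info(conn)
# The "mssql" dialect and the "host:port" authority combine into dataset
# namespaces like "mssql://host:1433"; the port falls back to 1433 when
# the connection does not set one.
print(info.scheme, info.authority)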
65 changes: 65 additions & 0 deletions providers/tests/google/cloud/transfers/test_mssql_to_gcs.py
@@ -22,6 +22,12 @@

import pytest

from airflow.models import Connection
from airflow.providers.common.compat.openlineage.facet import (
OutputDataset,
SchemaDatasetFacetFields,
)
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.google.cloud.transfers.mssql_to_gcs import MSSQLToGCSOperator

TASK_ID = "test-mssql-to-gcs"
@@ -188,3 +194,62 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):

# once for the file and once for the schema
assert gcs_hook_mock.upload.call_count == 2

@pytest.mark.parametrize(
"connection_port, default_port, expected_port",
[(None, 4321, 4321), (1234, None, 1234), (1234, 4321, 1234)],
)
def test_execute_openlineage_events(self, connection_port, default_port, expected_port):
class DBApiHookForTests(DbApiHook):
conn_name_attr = "sql_default"
get_conn = mock.MagicMock(name="conn")
get_connection = mock.MagicMock()

def get_openlineage_database_info(self, connection):
from airflow.providers.openlineage.sqlparser import DatabaseInfo

return DatabaseInfo(
scheme="sqlscheme",
authority=DbApiHook.get_openlineage_authority_part(connection, default_port=default_port),
)

dbapi_hook = DBApiHookForTests()

class MSSQLToGCSOperatorForTest(MSSQLToGCSOperator):
@property
def db_hook(self):
return dbapi_hook

sql = """SELECT a,b,c from my_db.my_table"""
op = MSSQLToGCSOperatorForTest(task_id=TASK_ID, sql=sql, bucket="bucket", filename="dir/file{}.csv")
DB_SCHEMA_NAME = "PUBLIC"
rows = [
(DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_day_of_week", 1, "varchar"),
(DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_placed_on", 2, "timestamp"),
(DB_SCHEMA_NAME, "popular_orders_day_of_week", "orders_placed", 3, "int4"),
]
dbapi_hook.get_connection.return_value = Connection(
conn_id="sql_default", conn_type="mssql", host="host", port=connection_port
)
dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]

lineage = op.get_openlineage_facets_on_start()
assert len(lineage.inputs) == 1
assert lineage.inputs[0].namespace == f"sqlscheme://host:{expected_port}"
assert lineage.inputs[0].name == "PUBLIC.popular_orders_day_of_week"
assert len(lineage.inputs[0].facets) == 1
assert lineage.inputs[0].facets["schema"].fields == [
SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"),
SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"),
SchemaDatasetFacetFields(name="orders_placed", type="int4"),
]
assert lineage.outputs == [
OutputDataset(
namespace="gs://bucket",
name="dir",
)
]

assert len(lineage.job_facets) == 1
assert lineage.job_facets["sql"].query == sql
assert lineage.run_facets == {}
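
The parametrization above pins down port precedence in the dataset authority. Distilled into a hypothetical helper (a reduction for illustration, not the provider's code):

# The connection's explicit port wins over the hook's default_port.
def pick_port(connection_port, default_port):
    return connection_port if connection_port is not None else default_port

assert pick_port(None, 4321) == 4321  # no explicit port: default used
assert pick_port(1234, None) == 1234  # explicit port kept
assert pick_port(1234, 4321) == 1234  # explicit port wins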
