Skip to content

Commit

Permalink
openlineage, common.sql: provide OL SQL parser as internal OpenLineag…
Browse files Browse the repository at this point in the history
…e provider API (#31398)

* Add SQLParser class serving as the API for openlineage_sql library.
Implement base methods for SQLExecuteQueryOperator & DbApiHook.

Signed-off-by: Jakub Dardzinski <[email protected]>

Rename methods to expose their purpose for OpenLineage.

Signed-off-by: Jakub Dardzinski <[email protected]>

* Rewrite information schema query construction to SQLAlchemy ORM.

Signed-off-by: Jakub Dardzinski <[email protected]>

* Clean up in-class reference

Instead of referencing the SQLParser directly, modify various static
methods to class methods instead, so they can use the cls argument
to avoid spelling out the class name repeatedly.

Also added a few changes to better utilize type reference and eliminate
some verbose type annotations.

* Clean up typing and iterator usage

* Add static typing to hint returned type.

Signed-off-by: Jakub Dardzinski <[email protected]>

* Fix mypy issues.

Signed-off-by: Jakub Dardzinski <[email protected]>

---------

Signed-off-by: Jakub Dardzinski <[email protected]>
Co-authored-by: Tzu-ping Chung <[email protected]>
  • Loading branch information
JDarDagran and uranusjr authored Jun 29, 2023
1 parent e2e707c commit f2e2125
Show file tree
Hide file tree
Showing 11 changed files with 1,271 additions and 13 deletions.
59 changes: 56 additions & 3 deletions airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

from contextlib import closing
from datetime import datetime
from typing import Any, Callable, Iterable, Mapping, Protocol, Sequence, cast
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Protocol, Sequence, cast
from urllib.parse import urlparse

import sqlparse
from packaging.version import Version
Expand All @@ -28,6 +29,10 @@
from airflow.hooks.base import BaseHook
from airflow.version import version

if TYPE_CHECKING:
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.sqlparser import DatabaseInfo


def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool):
"""
Expand Down Expand Up @@ -255,8 +260,7 @@ def split_sql_string(sql: str) -> list[str]:
:return: list of individual expressions
"""
splits = sqlparse.split(sqlparse.format(sql, strip_comments=True))
statements: list[str] = list(filter(None, splits))
return statements
return [s for s in splits if s]

@property
def last_description(self) -> Sequence[Sequence] | None:
Expand Down Expand Up @@ -515,3 +519,52 @@ def test_connection(self):
message = str(e)

return status, message

def get_openlineage_database_info(self, connection) -> DatabaseInfo | None:
    """
    Return database specific information needed to generate and parse lineage metadata.

    The information is meant to help with constructing the information schema
    query and with creating the correct namespace. This base implementation
    provides nothing and returns ``None``.

    :param connection: Airflow connection, passed in to reduce calls of the
        ``get_connection`` method
    """
    return None

def get_openlineage_database_dialect(self, connection) -> str:
    """
    Return the database dialect used for SQL parsing.

    This base implementation ignores ``connection`` and always answers ``"generic"``.
    For a list of supported dialects check: https://openlineage.io/docs/development/sql#sql-dialects

    :param connection: Airflow connection (unused here)
    """
    dialect = "generic"
    return dialect

def get_openlineage_default_schema(self) -> str:
    """
    Return the default schema specific to the database.

    The hook's private schema attribute wins when it is set to a truthy value;
    otherwise ``"public"`` is used.

    .. seealso::
        - :class:`airflow.providers.openlineage.sqlparser.SQLParser`
    """
    configured_schema = self.__schema
    if configured_schema:
        return configured_schema
    return "public"

def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLineage | None:
    """
    Return additional database specific lineage, e.g. query execution information.

    This method is called only on completion of the task. This base
    implementation contributes nothing and returns ``None``.

    :param task_instance: this may be used to retrieve additional information
        that is collected during runtime of the task
    """
    return None

@staticmethod
def get_openlineage_authority_part(connection, default_port: int | None = None) -> str:
    """
    Build the authority part (``host[:port]``) from an Airflow Connection.

    This serves as a common method for several hooks. The authority represents
    the hostname and port of the connection and conforms to the OpenLineage naming
    convention for a number of databases (e.g. MySQL, Postgres, Trino).

    :param connection: object exposing ``get_uri()``; the returned URI is parsed
        with :func:`urllib.parse.urlparse`
    :param default_port: port to use when the connection URI does not carry one
    :return: ``"<hostname>:<port>"`` when a port is known, otherwise just the hostname
    """
    parsed = urlparse(connection.get_uri())
    # ``parsed.port`` is None when the URI has no explicit port; interpolating it
    # blindly would produce a bogus "host:None" authority.
    port = parsed.port or default_port
    if port:
        return f"{parsed.hostname}:{port}"
    return f"{parsed.hostname}"
59 changes: 59 additions & 0 deletions airflow/providers/common/sql/operators/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator, SkipMixin
from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler, return_single_query_results
from airflow.utils.helpers import merge_dicts

if TYPE_CHECKING:
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.utils.context import Context


Expand Down Expand Up @@ -290,6 +292,63 @@ def prepare_template(self) -> None:
if isinstance(self.parameters, str):
self.parameters = ast.literal_eval(self.parameters)

def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
    """
    Produce OpenLineage metadata for this operator's SQL before it runs.

    ``None`` is returned when the OpenLineage SQL parser cannot be imported,
    when the hook offers no database info (method missing or returning
    ``None``), or when building the parser raises ``AttributeError``.
    Otherwise the result of
    ``SQLParser.generate_openlineage_metadata_from_sql`` is returned.
    """
    try:
        from airflow.providers.openlineage.sqlparser import SQLParser
    except ImportError:
        # OpenLineage provider is not available: nothing to report.
        return None

    db_hook = self.get_db_hook()
    conn_id = getattr(db_hook, db_hook.conn_name_attr)
    conn = db_hook.get_connection(conn_id)

    try:
        db_info = db_hook.get_openlineage_database_info(conn)
    except AttributeError:
        self.log.debug("%s has no database info provided", db_hook)
        return None
    if db_info is None:
        return None

    try:
        dialect = db_hook.get_openlineage_database_dialect(conn)
        default_schema = db_hook.get_openlineage_default_schema()
        parser = SQLParser(dialect=dialect, default_schema=default_schema)
    except AttributeError:
        self.log.debug("%s failed to get database dialect", db_hook)
        return None

    return parser.generate_openlineage_metadata_from_sql(
        sql=self.sql,
        hook=db_hook,
        database_info=db_info,
        database=self.database,
    )

def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
    """
    Produce OpenLineage metadata once the task has finished.

    Starts from the lineage computed by ``get_openlineage_facets_on_start`` and,
    when the hook supplies database specific lineage, merges both into a new
    ``OperatorLineage`` (inputs/outputs concatenated, facets merged with
    ``merge_dicts``).

    :param task_instance: forwarded to the hook's
        ``get_openlineage_database_specific_lineage``
    :return: the start lineage unchanged (possibly ``None``) if ``OperatorLineage``
        cannot be imported; otherwise the start lineage (or an empty
        ``OperatorLineage``) optionally merged with the hook's lineage
    """
    start_lineage = self.get_openlineage_facets_on_start()

    # The import must happen before ``OperatorLineage`` is first used: a
    # function-level import makes the name local, so touching it earlier raises
    # UnboundLocalError.
    try:
        from airflow.providers.openlineage.extractors import OperatorLineage
    except ImportError:
        return start_lineage

    operator_lineage = start_lineage or OperatorLineage()

    hook = self.get_db_hook()
    try:
        database_specific_lineage = hook.get_openlineage_database_specific_lineage(task_instance)
    except AttributeError:
        # Hook does not implement the optional OpenLineage method.
        database_specific_lineage = None

    if database_specific_lineage is None:
        return operator_lineage

    return OperatorLineage(
        inputs=operator_lineage.inputs + database_specific_lineage.inputs,
        outputs=operator_lineage.outputs + database_specific_lineage.outputs,
        run_facets=merge_dicts(operator_lineage.run_facets, database_specific_lineage.run_facets),
        job_facets=merge_dicts(operator_lineage.job_facets, database_specific_lineage.job_facets),
    )


class SQLColumnCheckOperator(BaseSQLOperator):
"""
Expand Down
29 changes: 22 additions & 7 deletions airflow/providers/openlineage/extractors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from attrs import Factory, define

from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import TaskInstanceState
from openlineage.client.facet import BaseFacet
from openlineage.client.run import Dataset

Expand Down Expand Up @@ -88,16 +89,30 @@ def extract(self) -> OperatorLineage | None:
return None

def extract_on_complete(self, task_instance) -> OperatorLineage | None:
    """
    Extract lineage after the task has finished.

    For a task instance in the ``FAILED`` state the operator's
    ``get_openlineage_facets_on_failure`` is tried first. Next comes
    ``get_openlineage_facets_on_complete``. If neither is available the
    result of :meth:`extract` is returned.

    :param task_instance: task instance whose ``state`` selects the method and
        which is passed on to the selected method
    """
    method_names = []
    if task_instance.state == TaskInstanceState.FAILED:
        method_names.append("get_openlineage_facets_on_failure")
    method_names.append("get_openlineage_facets_on_complete")

    for method_name in method_names:
        candidate = getattr(self.operator, method_name, None)
        if candidate and callable(candidate):
            return self._get_openlineage_facets(candidate, task_instance)

    return self.extract()

def _get_openlineage_facets(self, get_facets_method, *args) -> OperatorLineage | None:
facets: OperatorLineage = get_facets_method(*args)
return OperatorLineage(
inputs=facets.inputs,
outputs=facets.outputs,
run_facets=facets.run_facets,
job_facets=facets.job_facets,
)
try:
facets = get_facets_method(*args)
except ImportError:
self.log.exception(
"OpenLineage provider method failed to import OpenLineage integration. "
"This should not happen."
)
except Exception:
self.log.exception("OpenLineage provider method failed to extract data from provider. ")
else:
return OperatorLineage(
inputs=facets.inputs,
outputs=facets.outputs,
run_facets=facets.run_facets,
job_facets=facets.job_facets,
)
return None
Loading

0 comments on commit f2e2125

Please sign in to comment.