From 7797a9953ab400836ff461b60dabaea0204513b9 Mon Sep 17 00:00:00 2001
From: Harshal Sheth
Date: Thu, 13 Oct 2022 14:39:42 -0700
Subject: [PATCH 01/16] feat(ingest): loosen sqlalchemy dep & support airflow 2.3

update pytest-docker
feat(ingest): loosen sqlalchemy dep
ci matrix
update airflow stuff to match
more airflow fixes
always install the right sqlalchemy stubs
sqlalchemy compat type fixes
fix ci
fix types in oracle
fix more stuff with lint
fix attrs typing
more airflow type annotations
refactor iolet preprocessing
more type fixes
even more type fixes
markupsafe compat
fix airflow tests
more airflow compat
tweak workflow
fix imports
skip dag bag load test
fix sqlalchemy mypy
final type fixes
ignore more mypy issues
add assert to handle mypy issue
fix athena test
modify airflow plugin
---
 .github/workflows/metadata-ingestion.yml      | 12 +++-
 .../airflow-plugin/build.gradle               | 10 +--
 .../airflow-plugin/setup.py                   | 10 +--
 .../datahub_airflow_plugin/datahub_plugin.py  |  5 +-
 metadata-ingestion/build.gradle               | 22 ++++--
 .../docs/sources/oracle/oracle.md             |  2 +-
 metadata-ingestion/pyproject.toml             |  1 +
 .../scripts/install-sqlalchemy-stubs.sh       | 28 ++++++++
 metadata-ingestion/setup.cfg                  |  3 +-
 metadata-ingestion/setup.py                   | 17 ++---
 .../ingestion/source/ge_data_profiler.py      |  3 +
 .../datahub/ingestion/source/kafka_connect.py |  3 +-
 .../datahub/ingestion/source/sql/athena.py    |  2 +-
 .../datahub/ingestion/source/sql/bigquery.py  |  5 +-
 .../src/datahub/ingestion/source/sql/mssql.py |  9 ++-
 .../datahub/ingestion/source/sql/oracle.py    | 14 +++-
 .../ingestion/source/sql/sql_common.py        | 15 ++--
 .../src/datahub/ingestion/source/sql/trino.py |  5 +-
 .../ingestion/source/usage/redshift_usage.py  | 19 ++++-
 .../integrations/great_expectations/action.py |  3 +
 .../datahub/utilities/_markupsafe_compat.py   | 11 +++
 .../utilities/sqlalchemy_query_combiner.py    | 11 ++-
 .../src/datahub_provider/_airflow_compat.py   | 27 ++++++++
 .../src/datahub_provider/_lineage_core.py     | 68 +++++++++++++++++--
 .../client/airflow_generator.py               | 46 +++++++++----
 .../src/datahub_provider/hooks/datahub.py     | 10 ++-
 .../src/datahub_provider/lineage/datahub.py   |  2 +
 .../test_helpers/sqlalchemy_mypy_plugin.py    | 33 +++++++++
 metadata-ingestion/tests/unit/test_airflow.py | 32 +++++---
 .../tests/unit/test_athena_source.py          |  5 +-
 30 files changed, 340 insertions(+), 93 deletions(-)
 create mode 100755 metadata-ingestion/scripts/install-sqlalchemy-stubs.sh
 create mode 100644 metadata-ingestion/src/datahub/utilities/_markupsafe_compat.py
 create mode 100644 metadata-ingestion/src/datahub_provider/_airflow_compat.py
 create mode 100644 metadata-ingestion/tests/test_helpers/sqlalchemy_mypy_plugin.py

diff --git a/.github/workflows/metadata-ingestion.yml b/.github/workflows/metadata-ingestion.yml
index 329006b1fa87d4..1008cc1240d9bc 100644
--- a/.github/workflows/metadata-ingestion.yml
+++ b/.github/workflows/metadata-ingestion.yml
@@ -38,6 +38,14 @@ jobs:
             "testIntegrationBatch1",
             "testSlowIntegration",
           ]
+        include:
+          - python-version: "3.7"
+            extraPythonRequirement: "sqlalchemy==1.3.24"
+          - python-version: "3.7"
+            command: "testAirflow1"
+            extraPythonRequirement: "sqlalchemy==1.3.24"
+          - python-version: "3.10"
+            extraPythonRequirement: "sqlalchemy~=1.4.0"
       fail-fast: false
     steps:
       - uses: actions/checkout@v3
@@ -50,8 +58,8 @@ jobs:
           hadoop-version: "3.2"
       - name: Install dependencies
        run: ./metadata-ingestion/scripts/install_deps.sh
-      - name: Run metadata-ingestion tests
-        run: ./gradlew :metadata-ingestion:build :metadata-ingestion:${{ matrix.command }}
+      - name: Run
metadata-ingestion tests (extras ${{ matrix.extraPythonRequirement }}) + run: ./gradlew -Pextra_pip_requirements='${{ matrix.extraPythonRequirement }}' :metadata-ingestion:${{ matrix.command }} - name: pip freeze show list installed if: always() run: source metadata-ingestion/venv/bin/activate && pip freeze diff --git a/metadata-ingestion-modules/airflow-plugin/build.gradle b/metadata-ingestion-modules/airflow-plugin/build.gradle index e8b1b0839187d2..2a0d6f76f9d8e6 100644 --- a/metadata-ingestion-modules/airflow-plugin/build.gradle +++ b/metadata-ingestion-modules/airflow-plugin/build.gradle @@ -7,8 +7,10 @@ ext { venv_name = 'venv' } +def pip_install_command = "USE_DEV_VERSION=1 ${venv_name}/bin/pip install -e ../../metadata-ingestion" + task checkPythonVersion(type: Exec) { - commandLine python_executable, '-c', 'import sys; assert sys.version_info >= (3, 6)' + commandLine python_executable, '-c', 'import sys; assert sys.version_info >= (3, 7)' } task environmentSetup(type: Exec, dependsOn: checkPythonVersion) { @@ -20,7 +22,7 @@ task environmentSetup(type: Exec, dependsOn: checkPythonVersion) { task installPackage(type: Exec, dependsOn: environmentSetup) { inputs.file file('setup.py') outputs.dir("${venv_name}") - commandLine "${venv_name}/bin/pip", 'install', '-e', '.' + commandLine 'bash', '-x', '-c', "${pip_install_command} -e ." } task install(dependsOn: [installPackage]) @@ -30,7 +32,7 @@ task installDev(type: Exec, dependsOn: [install]) { outputs.dir("${venv_name}") outputs.file("${venv_name}/.build_install_dev_sentinel") commandLine 'bash', '-x', '-c', - "${venv_name}/bin/pip install -e .[dev] && touch ${venv_name}/.build_install_dev_sentinel" + "${pip_install_command} -e .[dev] && touch ${venv_name}/.build_install_dev_sentinel" } task lint(type: Exec, dependsOn: installDev) { @@ -65,7 +67,7 @@ task installDevTest(type: Exec, dependsOn: [installDev]) { outputs.dir("${venv_name}") outputs.file("${venv_name}/.build_install_dev_test_sentinel") commandLine 'bash', '-x', '-c', - "${venv_name}/bin/pip install -e .[dev,integration-tests] && touch ${venv_name}/.build_install_dev_test_sentinel" + "${pip_install_command} -e .[dev,integration-tests] && touch ${venv_name}/.build_install_dev_test_sentinel" } def testFile = hasProperty('testFile') ? testFile : 'unknown' diff --git a/metadata-ingestion-modules/airflow-plugin/setup.py b/metadata-ingestion-modules/airflow-plugin/setup.py index 56c250fd9f317b..8b8952a5f0dcfa 100644 --- a/metadata-ingestion-modules/airflow-plugin/setup.py +++ b/metadata-ingestion-modules/airflow-plugin/setup.py @@ -1,9 +1,11 @@ import os import pathlib -from typing import Dict, Set import setuptools +USE_DEV_VERSION = os.environ.get("USE_DEV_VERSION", "0") == "1" + + package_metadata: dict = {} with open("./src/datahub_airflow_plugin/__init__.py") as fp: exec(fp.read(), package_metadata) @@ -23,9 +25,9 @@ def get_long_description(): "typing-inspect", "pydantic>=1.5.1", "apache-airflow >= 2.0.2", - "acryl-datahub[airflow] >= 0.8.36", - # Pinned dependencies to make dependency resolution faster. 
- "sqlalchemy==1.3.24", + "acryl-datahub[airflow] >= 0.8.36" + if not USE_DEV_VERSION + else "acryl-datahub[airflow] == 0.0.0.dev0", } diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py index 92b5cffa588e78..81541326b172aa 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py @@ -1,6 +1,6 @@ import contextlib import traceback -from typing import Any, Iterable +from typing import Any, Dict, Iterable import attr from airflow.configuration import conf @@ -10,6 +10,7 @@ from airflow.utils.module_loading import import_string from cattr import structure from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult +from datahub_provider._lineage_core import preprocess_task_iolets from datahub_provider.client.airflow_generator import AirflowGenerator from datahub_provider.hooks.datahub import DatahubGenericHook from datahub_provider.lineage.datahub import DatahubLineageConfig @@ -39,6 +40,8 @@ def get_lineage_config() -> DatahubLineageConfig: def get_inlets_from_task(task: BaseOperator, context: Any) -> Iterable[Any]: + # TODO fix for https://github.com/apache/airflow/commit/1b1f3fabc5909a447a6277cafef3a0d4ef1f01ae + inlets = [] if isinstance(task._inlets, (str, BaseOperator)) or attr.has(task._inlets): # type: ignore inlets = [ diff --git a/metadata-ingestion/build.gradle b/metadata-ingestion/build.gradle index 7c8b302616060f..b0ee814cafa9f5 100644 --- a/metadata-ingestion/build.gradle +++ b/metadata-ingestion/build.gradle @@ -7,8 +7,12 @@ ext { venv_name = 'venv' } +if (!project.hasProperty("extra_pip_requirements")) { + ext.extra_pip_requirements = "" +} + task checkPythonVersion(type: Exec) { - commandLine python_executable, '-c', 'import sys; assert sys.version_info >= (3, 6)' + commandLine python_executable, '-c', 'import sys; assert sys.version_info >= (3, 7)' } task environmentSetup(type: Exec, dependsOn: checkPythonVersion) { @@ -24,7 +28,7 @@ task runPreFlightScript(type: Exec, dependsOn: environmentSetup) { task installPackage(type: Exec, dependsOn: runPreFlightScript) { inputs.file file('setup.py') outputs.dir("${venv_name}") - commandLine "${venv_name}/bin/pip", 'install', '-e', '.' + commandLine 'bash', '-x', '-c', "${venv_name}/bin/pip install -e . 
${extra_pip_requirements}" } task codegen(type: Exec, dependsOn: [environmentSetup, installPackage, ':metadata-events:mxe-schemas:build']) { @@ -40,7 +44,7 @@ task installDev(type: Exec, dependsOn: [install]) { outputs.dir("${venv_name}") outputs.file("${venv_name}/.build_install_dev_sentinel") commandLine 'bash', '-x', '-c', - "${venv_name}/bin/pip install -e .[dev] && touch ${venv_name}/.build_install_dev_sentinel" + "${venv_name}/bin/pip install -e .[dev] ${extra_pip_requirements} && touch ${venv_name}/.build_install_dev_sentinel" } @@ -67,15 +71,21 @@ task lint(type: Exec, dependsOn: installDev) { */ commandLine 'bash', '-c', "find ${venv_name}/lib -path *airflow/_vendor/connexion/spec.py -exec sed -i.bak -e '169,169s/ # type: List\\[str\\]//g' {} \\; && " + - "source ${venv_name}/bin/activate && set -x && black --check --diff src/ tests/ examples/ && isort --check --diff src/ tests/ examples/ && flake8 --count --statistics src/ tests/ examples/ && mypy src/ tests/ examples/" + "source ${venv_name}/bin/activate && set -x && " + + "./scripts/install-sqlalchemy-stubs.sh && " + + "black --check --diff src/ tests/ examples/ && " + + "isort --check --diff src/ tests/ examples/ && " + + "flake8 --count --statistics src/ tests/ examples/ && " + + "mypy --show-traceback --show-error-codes src/ tests/ examples/" } task lintFix(type: Exec, dependsOn: installDev) { commandLine 'bash', '-c', "source ${venv_name}/bin/activate && set -x && " + + "./scripts/install-sqlalchemy-stubs.sh && " + "black src/ tests/ examples/ && " + "isort src/ tests/ examples/ && " + "flake8 src/ tests/ examples/ && " + - "mypy src/ tests/ examples/" + "mypy --show-traceback --show-error-codes src/ tests/ examples/" } task testQuick(type: Exec, dependsOn: installDev) { @@ -92,7 +102,7 @@ task installDevTest(type: Exec, dependsOn: [install]) { outputs.dir("${venv_name}") outputs.file("${venv_name}/.build_install_dev_test_sentinel") commandLine 'bash', '-c', - "${venv_name}/bin/pip install -e .[dev,integration-tests] && touch ${venv_name}/.build_install_dev_test_sentinel" + "${venv_name}/bin/pip install -e .[dev,integration-tests] ${extra_pip_requirements} && touch ${venv_name}/.build_install_dev_test_sentinel" } def testFile = hasProperty('testFile') ? testFile : 'unknown' diff --git a/metadata-ingestion/docs/sources/oracle/oracle.md b/metadata-ingestion/docs/sources/oracle/oracle.md index 6043e1c9156632..1ddd772e07af94 100644 --- a/metadata-ingestion/docs/sources/oracle/oracle.md +++ b/metadata-ingestion/docs/sources/oracle/oracle.md @@ -1 +1 @@ -As a SQL-based service, the Athena integration is also supported by our SQL profiler. See here for more details on configuration. \ No newline at end of file +As a SQL-based service, the Oracle integration is also supported by our SQL profiler. See here for more details on configuration. 
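For context on the extra_pip_requirements Gradle property added to metadata-ingestion/build.gradle above: its value is appended verbatim to the venv's pip install commands, so a specific dependency pin can be exercised locally in the same way the CI matrix does. A hypothetical invocation, using the testQuick task defined in that build file, would look something like:

    ./gradlew -Pextra_pip_requirements='sqlalchemy~=1.4.0' :metadata-ingestion:testQuick

When the property is not supplied, it falls back to an empty string and the install commands behave as before.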
diff --git a/metadata-ingestion/pyproject.toml b/metadata-ingestion/pyproject.toml
index 60b67ca4429223..cf390684bb76db 100644
--- a/metadata-ingestion/pyproject.toml
+++ b/metadata-ingestion/pyproject.toml
@@ -14,6 +14,7 @@ target-version = ['py36', 'py37', 'py38']
 [tool.isort]
 combine_as_imports = true
 indent = ' '
+known_future_library = ['__future__', 'datahub.utilities._markupsafe_compat', 'datahub_provider._airflow_compat']
 profile = 'black'
 sections = 'FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER'
 skip_glob = 'src/datahub/metadata'
diff --git a/metadata-ingestion/scripts/install-sqlalchemy-stubs.sh b/metadata-ingestion/scripts/install-sqlalchemy-stubs.sh
new file mode 100755
index 00000000000000..7c14a06464f99e
--- /dev/null
+++ b/metadata-ingestion/scripts/install-sqlalchemy-stubs.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+
+set -euo pipefail
+
+# ASSUMPTION: This assumes that we're running from inside the venv.
+
+SQLALCHEMY_VERSION=$(python -c 'import sqlalchemy; print(sqlalchemy.__version__)')
+
+if [[ $SQLALCHEMY_VERSION == 1.3.* ]]; then
+    ENSURE_NOT_INSTALLED=sqlalchemy2-stubs
+    ENSURE_INSTALLED=sqlalchemy-stubs
+elif [[ $SQLALCHEMY_VERSION == 1.4.* ]]; then
+    ENSURE_NOT_INSTALLED=sqlalchemy-stubs
+    ENSURE_INSTALLED=sqlalchemy2-stubs
+else
+    echo "Unsupported SQLAlchemy version: $SQLALCHEMY_VERSION"
+    exit 1
+fi
+
+FORCE_REINSTALL=""
+if pip show $ENSURE_NOT_INSTALLED >/dev/null 2>&1 ; then
+    pip uninstall --yes $ENSURE_NOT_INSTALLED
+    FORCE_REINSTALL="--force-reinstall"
+fi
+
+if [ -n "$FORCE_REINSTALL" ] || ! pip show $ENSURE_INSTALLED >/dev/null 2>&1 ; then
+    pip install $FORCE_REINSTALL $ENSURE_INSTALLED
+fi
diff --git a/metadata-ingestion/setup.cfg b/metadata-ingestion/setup.cfg
index 2467c61983e5d8..f9a8ba2a54e413 100644
--- a/metadata-ingestion/setup.cfg
+++ b/metadata-ingestion/setup.cfg
@@ -22,7 +22,7 @@ ban-relative-imports = true
 
 [mypy]
 plugins =
-  sqlmypy,
+  ./tests/test_helpers/sqlalchemy_mypy_plugin.py,
   pydantic.mypy
 exclude = ^(venv|build|dist)/
 ignore_missing_imports = yes
@@ -54,6 +54,7 @@ disallow_untyped_defs = yes
 asyncio_mode = auto
 addopts = --cov=src --cov-report term-missing --cov-config setup.cfg --strict-markers
 markers =
+    airflow: marks tests related to airflow (deselect with '-m not airflow')
     slow_unit: marks tests to only run slow unit tests (deselect with '-m not slow_unit')
     integration: marks tests to only run in integration (deselect with '-m "not integration"')
     integration_batch_1: mark tests to only run in batch 1 of integration tests. This is done mainly for parallelisation (deselect with '-m not integration_batch_1')
diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py
index 0adcd27efcf993..4c3ec8590bdd23 100644
--- a/metadata-ingestion/setup.py
+++ b/metadata-ingestion/setup.py
@@ -101,7 +101,7 @@ def get_long_description():
 
 sql_common = {
     # Required for all SQL sources.
-    "sqlalchemy==1.3.24",
+    "sqlalchemy>=1.3.24, <2",
     # Required for SQL profiling.
     "great-expectations>=0.15.12",
     # GE added handling for higher version of jinja2
@@ -145,6 +145,12 @@ def get_long_description():
     "more-itertools>=8.12.0",
 }
 
+clickhouse_common = {
+    # Clickhouse 0.1.8 requires SQLAlchemy 1.3.x, while the newer versions
+    # allow SQLAlchemy 1.4.x.
+ "clickhouse-sqlalchemy>=0.1.8", +} + redshift_common = { "sqlalchemy-redshift", "psycopg2-binary", @@ -237,12 +243,8 @@ def get_long_description(): "sqllineage==1.3.6", "sql_metadata", }, # deprecated, but keeping the extra for backwards compatibility - "clickhouse": sql_common | {"clickhouse-sqlalchemy==0.1.8"}, - "clickhouse-usage": sql_common - | usage_common - | { - "clickhouse-sqlalchemy==0.1.8", - }, + "clickhouse": sql_common | clickhouse_common, + "clickhouse-usage": sql_common | usage_common | clickhouse_common, "datahub-lineage-file": set(), "datahub-business-glossary": set(), "delta-lake": {*data_lake_profiling, *delta_lake}, @@ -337,7 +339,6 @@ def get_long_description(): mypy_stubs = { "types-dataclasses", - "sqlalchemy-stubs", "types-pkg_resources", "types-six", "types-python-dateutil", diff --git a/metadata-ingestion/src/datahub/ingestion/source/ge_data_profiler.py b/metadata-ingestion/src/datahub/ingestion/source/ge_data_profiler.py index d7a9816d455a6f..d969944d286d52 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/ge_data_profiler.py +++ b/metadata-ingestion/src/datahub/ingestion/source/ge_data_profiler.py @@ -1,3 +1,5 @@ +from datahub.utilities._markupsafe_compat import MARKUPSAFE_PATCHED + import collections import concurrent.futures import contextlib @@ -51,6 +53,7 @@ get_query_columns, ) +assert MARKUPSAFE_PATCHED logger: logging.Logger = logging.getLogger(__name__) P = ParamSpec("P") diff --git a/metadata-ingestion/src/datahub/ingestion/source/kafka_connect.py b/metadata-ingestion/src/datahub/ingestion/source/kafka_connect.py index 0dc00f955f5223..0e3487eb927a16 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/kafka_connect.py +++ b/metadata-ingestion/src/datahub/ingestion/source/kafka_connect.py @@ -276,7 +276,8 @@ def get_parser( url_instance = make_url(url) source_platform = get_platform_from_sqlalchemy_uri(str(url_instance)) database_name = url_instance.database - db_connection_url = f"{url_instance.drivername}://{url_instance.host}:{url_instance.port}/{url_instance.database}" + assert database_name + db_connection_url = f"{url_instance.drivername}://{url_instance.host}:{url_instance.port}/{database_name}" topic_prefix = self.connector_manifest.config.get("topic.prefix", None) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py index 594c7857afb35b..f9ad4fc3ea722a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py @@ -109,7 +109,7 @@ def get_table_properties( self, inspector: Inspector, schema: str, table: str ) -> Tuple[Optional[str], Dict[str, str], Optional[str]]: if not self.cursor: - self.cursor = inspector.dialect._raw_connection(inspector.engine).cursor() + self.cursor = inspector.engine.raw_connection().cursor() assert self.cursor # Unfortunately properties can be only get through private methods as those are not exposed diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/bigquery.py b/metadata-ingestion/src/datahub/ingestion/source/sql/bigquery.py index 715f265b6bb662..0613fda4595b24 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/bigquery.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/bigquery.py @@ -792,8 +792,11 @@ def get_latest_partition( # Bigquery only supports one partition column # https://stackoverflow.com/questions/62886213/adding-multiple-partitioned-columns-to-bigquery-table-from-sql-query row = 
result.fetchone() + if row and hasattr(row, "_asdict"): + # Compat with sqlalchemy 1.4 Row type. + row = row._asdict() if row: - return BigQueryPartitionColumn(**row) + return BigQueryPartitionColumn(**row.items()) return None def get_shard_from_table(self, table: str) -> Tuple[str, Optional[str]]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/mssql.py b/metadata-ingestion/src/datahub/ingestion/source/sql/mssql.py index 12e972018e2406..54880a895989e8 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/mssql.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/mssql.py @@ -11,7 +11,6 @@ from sqlalchemy import create_engine, inspect from sqlalchemy.engine.base import Connection from sqlalchemy.engine.reflection import Inspector -from sqlalchemy.engine.result import ResultProxy, RowProxy from datahub.configuration.common import AllowDenyPattern from datahub.ingestion.api.common import PipelineContext @@ -135,7 +134,7 @@ def __init__(self, config: SQLServerConfig, ctx: PipelineContext): def _populate_table_descriptions(self, conn: Connection, db_name: str) -> None: # see https://stackoverflow.com/questions/5953330/how-do-i-map-the-id-in-sys-extended-properties-to-an-object-name # also see https://www.mssqltips.com/sqlservertip/5384/working-with-sql-server-extended-properties/ - table_metadata: ResultProxy = conn.execute( + table_metadata = conn.execute( """ SELECT SCHEMA_NAME(T.SCHEMA_ID) AS schema_name, @@ -149,13 +148,13 @@ def _populate_table_descriptions(self, conn: Connection, db_name: str) -> None: AND EP.CLASS = 1 """ ) - for row in table_metadata: # type: RowProxy + for row in table_metadata: self.table_descriptions[ f"{db_name}.{row['schema_name']}.{row['table_name']}" ] = row["table_description"] def _populate_column_descriptions(self, conn: Connection, db_name: str) -> None: - column_metadata: RowProxy = conn.execute( + column_metadata = conn.execute( """ SELECT SCHEMA_NAME(T.SCHEMA_ID) AS schema_name, @@ -172,7 +171,7 @@ def _populate_column_descriptions(self, conn: Connection, db_name: str) -> None: AND EP.CLASS = 1 """ ) - for row in column_metadata: # type: RowProxy + for row in column_metadata: self.column_descriptions[ f"{db_name}.{row['schema_name']}.{row['table_name']}.{row['column_name']}" ] = row["column_description"] diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/oracle.py b/metadata-ingestion/src/datahub/ingestion/source/sql/oracle.py index efe62c5e7efb29..480ab2c46d588e 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/oracle.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/oracle.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Iterable, List, Optional, Tuple, cast +from typing import Any, Iterable, List, NoReturn, Optional, Tuple, cast from unittest.mock import patch # This import verifies that the dependencies are available. 
@@ -35,6 +35,10 @@ assert OracleDialect.ischema_names +def _raise_err(exc: Exception) -> NoReturn: + raise exc + + def output_type_handler(cursor, name, defaultType, size, precision, scale): """Add CLOB and BLOB support to Oracle connection.""" @@ -94,7 +98,9 @@ def get_schema_names(self) -> List[str]: s = "SELECT username FROM dba_users ORDER BY username" cursor = self._inspector_instance.bind.execute(s) return [ - self._inspector_instance.dialect.normalize_name(row[0]) for row in cursor + self._inspector_instance.dialect.normalize_name(row[0]) + or _raise_err(ValueError(f"Invalid schema name: {row[0]}")) + for row in cursor ] def get_table_names(self, schema: str = None, order_by: str = None) -> List[str]: @@ -121,7 +127,9 @@ def get_table_names(self, schema: str = None, order_by: str = None) -> List[str] cursor = self._inspector_instance.bind.execute(sql.text(sql_str), owner=schema) return [ - self._inspector_instance.dialect.normalize_name(row[0]) for row in cursor + self._inspector_instance.dialect.normalize_name(row[0]) + or _raise_err(ValueError(f"Invalid table name: {row[0]}")) + for row in cursor ] def __getattr__(self, item: str) -> Any: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py index 257c84ff5f2e9f..aa035e7494f995 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py @@ -27,6 +27,7 @@ from sqlalchemy.engine.reflection import Inspector from sqlalchemy.exc import ProgrammingError from sqlalchemy.sql import sqltypes as types +from sqlalchemy.types import TypeDecorator, TypeEngine from datahub.configuration.common import AllowDenyPattern from datahub.emitter.mce_builder import ( @@ -328,7 +329,7 @@ class SqlWorkUnit(MetadataWorkUnit): pass -_field_type_mapping: Dict[Type[types.TypeEngine], Type] = { +_field_type_mapping: Dict[Type[TypeEngine], Type] = { types.Integer: NumberTypeClass, types.Numeric: NumberTypeClass, types.Boolean: BooleanTypeClass, @@ -366,30 +367,28 @@ class SqlWorkUnit(MetadataWorkUnit): # assigns the NullType by default. We want to carry this warning through. types.NullType: NullTypeClass, } -_known_unknown_field_types: Set[Type[types.TypeEngine]] = { +_known_unknown_field_types: Set[Type[TypeEngine]] = { types.Interval, types.CLOB, } -def register_custom_type( - tp: Type[types.TypeEngine], output: Optional[Type] = None -) -> None: +def register_custom_type(tp: Type[TypeEngine], output: Optional[Type] = None) -> None: if output: _field_type_mapping[tp] = output else: _known_unknown_field_types.add(tp) -class _CustomSQLAlchemyDummyType(types.TypeDecorator): +class _CustomSQLAlchemyDummyType(TypeDecorator): impl = types.LargeBinary -def make_sqlalchemy_type(name: str) -> Type[types.TypeEngine]: +def make_sqlalchemy_type(name: str) -> Type[TypeEngine]: # This usage of type() dynamically constructs a class. # See https://stackoverflow.com/a/15247202/5004662 and # https://docs.python.org/3/library/functions.html#type. 
- sqlalchemy_type: Type[types.TypeEngine] = type( + sqlalchemy_type: Type[TypeEngine] = type( name, (_CustomSQLAlchemyDummyType,), { diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/trino.py b/metadata-ingestion/src/datahub/ingestion/source/sql/trino.py index 5a64612022644e..45df2bd91a1475 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/trino.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/trino.py @@ -4,15 +4,12 @@ from typing import Any, Dict, List, Optional import sqlalchemy - -# This import verifies that the dependencies are available. -import trino.sqlalchemy # noqa: F401 from pydantic.fields import Field from sqlalchemy import exc, sql from sqlalchemy.engine import reflection from sqlalchemy.engine.reflection import Inspector from sqlalchemy.sql import sqltypes -from sqlalchemy.sql.type_api import TypeEngine +from sqlalchemy.types import TypeEngine from trino.exceptions import TrinoQueryError from trino.sqlalchemy import datatype, error from trino.sqlalchemy.dialect import TrinoDialect diff --git a/metadata-ingestion/src/datahub/ingestion/source/usage/redshift_usage.py b/metadata-ingestion/src/datahub/ingestion/source/usage/redshift_usage.py index d34989182cec32..2b7001761f210f 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/usage/redshift_usage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/usage/redshift_usage.py @@ -3,13 +3,12 @@ import logging import time from datetime import datetime -from typing import Dict, Iterable, List, Optional, Set +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set from pydantic.fields import Field from pydantic.main import BaseModel from sqlalchemy import create_engine from sqlalchemy.engine import Engine -from sqlalchemy.engine.result import ResultProxy, RowProxy import datahub.emitter.mce_builder as builder from datahub.configuration.source_common import EnvBasedSourceConfigBase @@ -39,6 +38,13 @@ logger = logging.getLogger(__name__) +if TYPE_CHECKING: + try: + from sqlalchemy.engine import Row # type: ignore + except ImportError: + # See https://github.com/python/mypy/issues/1153. + from sqlalchemy.engine.result import RowProxy as Row # type: ignore + REDSHIFT_DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S" @@ -267,7 +273,7 @@ def _make_sql_engine(self) -> Engine: logger.debug(f"sql_alchemy_url = {url}") return create_engine(url, **self.config.options) - def _should_process_row(self, row: RowProxy) -> bool: + def _should_process_row(self, row: "Row") -> bool: # Check for mandatory proerties being present first. missing_props: List[str] = [ prop @@ -295,10 +301,13 @@ def _should_process_row(self, row: RowProxy) -> bool: def _gen_access_events_from_history_query( self, query: str, engine: Engine ) -> Iterable[RedshiftAccessEvent]: - results: ResultProxy = engine.execute(query) - for row in results: # type: RowProxy + results = engine.execute(query) + for row in results: if not self._should_process_row(row): continue + if hasattr(row, "_asdict"): + # Compatibility with sqlalchemy 1.4.x. + row = row._asdict() access_event = RedshiftAccessEvent(**dict(row.items())) # Replace database name with the alias name if one is provided in the config. 
if self.config.database_alias: diff --git a/metadata-ingestion/src/datahub/integrations/great_expectations/action.py b/metadata-ingestion/src/datahub/integrations/great_expectations/action.py index b0449105168d32..eb3c610dc9fabc 100644 --- a/metadata-ingestion/src/datahub/integrations/great_expectations/action.py +++ b/metadata-ingestion/src/datahub/integrations/great_expectations/action.py @@ -1,3 +1,5 @@ +from datahub.utilities._markupsafe_compat import MARKUPSAFE_PATCHED + import json import logging import os @@ -56,6 +58,7 @@ from datahub.metadata.schema_classes import PartitionSpecClass, PartitionTypeClass from datahub.utilities.sql_parser import DefaultSQLParser +assert MARKUPSAFE_PATCHED logger = logging.getLogger(__name__) if os.getenv("DATAHUB_DEBUG", False): handler = logging.StreamHandler(stream=sys.stdout) diff --git a/metadata-ingestion/src/datahub/utilities/_markupsafe_compat.py b/metadata-ingestion/src/datahub/utilities/_markupsafe_compat.py new file mode 100644 index 00000000000000..801cf2fbfab0ae --- /dev/null +++ b/metadata-ingestion/src/datahub/utilities/_markupsafe_compat.py @@ -0,0 +1,11 @@ +try: + import markupsafe + + # This monkeypatch hack is required for markupsafe>=2.1.0 and older versions of Jinja2. + # Changelog: https://markupsafe.palletsprojects.com/en/2.1.x/changes/#version-2-1-0 + # Example discussion: https://github.com/aws/aws-sam-cli/issues/3661. + markupsafe.soft_unicode = markupsafe.soft_str # type: ignore[attr-defined] + + MARKUPSAFE_PATCHED = True +except ImportError: + MARKUPSAFE_PATCHED = False diff --git a/metadata-ingestion/src/datahub/utilities/sqlalchemy_query_combiner.py b/metadata-ingestion/src/datahub/utilities/sqlalchemy_query_combiner.py index 947f5e30d62c89..29a51007b15d4a 100644 --- a/metadata-ingestion/src/datahub/utilities/sqlalchemy_query_combiner.py +++ b/metadata-ingestion/src/datahub/utilities/sqlalchemy_query_combiner.py @@ -7,7 +7,7 @@ import string import threading import unittest.mock -from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, cast import greenlet import sqlalchemy @@ -39,7 +39,8 @@ def __getitem__(self, k): # type: ignore class _ResultProxyFake: - # This imitates the interface provided by sqlalchemy.engine.result.ResultProxy. + # This imitates the interface provided by sqlalchemy.engine.result.ResultProxy (sqlalchemy 1.3.x) + # or sqlalchemy.engine.Result (1.4.x). # Adapted from https://github.com/rajivsarvepalli/mock-alchemy/blob/2eba95588e7693aab973a6d60441d2bc3c4ea35d/src/mock_alchemy/mocking.py#L213 def __init__(self, result: List[_RowProxyFake]) -> None: @@ -363,7 +364,11 @@ def _execute_queue_fallback(self, main_greenlet: greenlet.greenlet) -> None: *query_future.multiparams, **query_future.params, ) - query_future.res = res + + # The actual execute method returns a CursorResult on SQLAlchemy 1.4.x + # and a ResultProxy on SQLAlchemy 1.3.x. Both interfaces are shimmed + # by _ResultProxyFake. 
+ query_future.res = cast(_ResultProxyFake, res) except Exception as e: query_future.exc = e finally: diff --git a/metadata-ingestion/src/datahub_provider/_airflow_compat.py b/metadata-ingestion/src/datahub_provider/_airflow_compat.py new file mode 100644 index 00000000000000..3493bf721c1a45 --- /dev/null +++ b/metadata-ingestion/src/datahub_provider/_airflow_compat.py @@ -0,0 +1,27 @@ +from datahub.utilities._markupsafe_compat import MARKUPSAFE_PATCHED + +from airflow.hooks.base import BaseHook +from airflow.models.baseoperator import BaseOperator + +try: + from airflow.models.mappedoperator import MappedOperator + from airflow.models.operator import Operator +except ModuleNotFoundError: + Operator = BaseOperator # type: ignore + MappedOperator = None # type: ignore + +try: + from airflow.sensors.external_task import ExternalTaskSensor +except ImportError: + from airflow.sensors.external_task_sensor import ExternalTaskSensor # type: ignore + +assert MARKUPSAFE_PATCHED + +__all__ = [ + "MARKUPSAFE_PATCHED", + "BaseHook", + "Operator", + "BaseOperator", + "MappedOperator", + "ExternalTaskSensor", +] diff --git a/metadata-ingestion/src/datahub_provider/_lineage_core.py b/metadata-ingestion/src/datahub_provider/_lineage_core.py index aa7b61e8c8b71c..78b4a78260d1de 100644 --- a/metadata-ingestion/src/datahub_provider/_lineage_core.py +++ b/metadata-ingestion/src/datahub_provider/_lineage_core.py @@ -1,5 +1,7 @@ +from datahub_provider._airflow_compat import Operator + from datetime import datetime -from typing import TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING, Dict, List, Optional import datahub.emitter.mce_builder as builder from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult @@ -10,7 +12,6 @@ if TYPE_CHECKING: from airflow import DAG - from airflow.models.baseoperator import BaseOperator from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance @@ -45,9 +46,21 @@ def make_emitter_hook(self) -> "DatahubGenericHook": return DatahubGenericHook(self.datahub_conn_id) +def _task_underscore_inlets(operator: "Operator") -> Optional[List]: + if hasattr(operator, "_inlets"): + return operator._inlets # type: ignore[attr-defined,union-attr] + return None + + +def _task_underscore_outlets(operator: "Operator") -> Optional[List]: + if hasattr(operator, "_outlets"): + return operator._outlets # type: ignore[attr-defined,union-attr] + return None + + def send_lineage_to_datahub( config: DatahubBasicLineageConfig, - operator: "BaseOperator", + operator: "Operator", inlets: List[_Entity], outlets: List[_Entity], context: Dict, @@ -56,7 +69,7 @@ def send_lineage_to_datahub( return dag: "DAG" = context["dag"] - task: "BaseOperator" = context["task"] + task: "Operator" = context["task"] ti: "TaskInstance" = context["task_instance"] hook = config.make_emitter_hook() @@ -110,3 +123,50 @@ def send_lineage_to_datahub( end_timestamp_millis=int(datetime.utcnow().timestamp() * 1000), ) operator.log.info(f"Emitted from Lineage: {dpi}") + + +def preprocess_task_iolets(task: "Operator", context: Dict) -> None: + # This is necessary to avoid issues with circular imports. + from airflow.lineage import prepare_lineage + + from datahub_provider.hooks.datahub import AIRFLOW_1 + + # Detect Airflow 1.10.x inlet/outlet configurations in Airflow 2.x, and + # convert to the newer version. This code path will only be triggered + # when 2.x receives a 1.10.x inlet/outlet config. + needs_repeat_preparation = False + + # Translate inlets. 
+ previous_inlets = _task_underscore_inlets(task) + if ( + not AIRFLOW_1 + and previous_inlets is not None + and isinstance(previous_inlets, list) + and len(previous_inlets) == 1 + and isinstance(previous_inlets[0], dict) + ): + from airflow.lineage import AUTO + + task._inlets = [ # type: ignore[attr-defined,union-attr] + # See https://airflow.apache.org/docs/apache-airflow/1.10.15/lineage.html. + *previous_inlets[0].get("datasets", []), # assumes these are attr-annotated + *previous_inlets[0].get("task_ids", []), + *([AUTO] if previous_inlets[0].get("auto", False) else []), + ] + needs_repeat_preparation = True + + # Translate outlets. + previous_outlets = _task_underscore_outlets(task) + if ( + not AIRFLOW_1 + and previous_inlets is not None + and isinstance(previous_outlets, list) + and len(previous_outlets) == 1 + and isinstance(previous_outlets[0], dict) + ): + task._outlets = [*previous_outlets[0].get("datasets", [])] # type: ignore[attr-defined,union-attr] + needs_repeat_preparation = True + + # Rerun the lineage preparation routine, now that the old format has been translated to the new one. + if needs_repeat_preparation: + prepare_lineage(lambda self, ctx: None)(task, context) diff --git a/metadata-ingestion/src/datahub_provider/client/airflow_generator.py b/metadata-ingestion/src/datahub_provider/client/airflow_generator.py index 69943df50d3afb..1faedd516c3011 100644 --- a/metadata-ingestion/src/datahub_provider/client/airflow_generator.py +++ b/metadata-ingestion/src/datahub_provider/client/airflow_generator.py @@ -1,4 +1,6 @@ -from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast +from datahub_provider._airflow_compat import BaseOperator, ExternalTaskSensor, Operator + +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast from airflow.configuration import conf @@ -13,16 +15,22 @@ if TYPE_CHECKING: from airflow import DAG - from airflow.models import BaseOperator, DagRun, TaskInstance + from airflow.models import DagRun, TaskInstance from datahub.emitter.kafka_emitter import DatahubKafkaEmitter from datahub.emitter.rest_emitter import DatahubRestEmitter +def _task_downstream_task_ids(operator: "Operator") -> Set[str]: + if hasattr(operator, "downstream_task_ids"): + return operator.downstream_task_ids + return operator._downstream_task_id # type: ignore[attr-defined,union-attr] + + class AirflowGenerator: @staticmethod def _get_dependencies( - task: "BaseOperator", dag: "DAG", flow_urn: DataFlowUrn + task: "Operator", dag: "DAG", flow_urn: DataFlowUrn ) -> List[DataJobUrn]: # resolve URNs for upstream nodes in subdags upstream of the current task. @@ -47,7 +55,7 @@ def _get_dependencies( ) # if subdag task is a leaf task, then link it as an upstream task - if len(upstream_subdag_task._downstream_task_ids) == 0: + if len(_task_downstream_task_ids(upstream_subdag_task)) == 0: upstream_subdag_task_urns.append(upstream_subdag_task_urn) # resolve URNs for upstream nodes that trigger the subdag containing the current task. 
@@ -59,7 +67,7 @@ def _get_dependencies( if ( dag.is_subdag and dag.parent_dag is not None - and len(task._upstream_task_ids) == 0 + and len(task.upstream_task_ids) == 0 ): # filter through the parent dag's tasks and find the subdag trigger(s) @@ -83,7 +91,7 @@ def _get_dependencies( ) # if the task triggers the subdag, link it to this node in the subdag - if subdag_task_id in upstream_task._downstream_task_ids: + if subdag_task_id in _task_downstream_task_ids(upstream_task): upstream_subdag_triggers.append(upstream_task_urn) # If the operator is an ExternalTaskSensor then we set the remote task as upstream. @@ -91,8 +99,6 @@ def _get_dependencies( # jobflow to anothet jobflow. external_task_upstreams = [] if task.task_type == "ExternalTaskSensor": - from airflow.sensors.external_task_sensor import ExternalTaskSensor - task = cast(ExternalTaskSensor, task) if hasattr(task, "external_task_id") and task.external_task_id is not None: external_task_upstreams = [ @@ -173,7 +179,11 @@ def generate_dataflow( return data_flow @staticmethod - def _get_description(task: "BaseOperator") -> Optional[str]: + def _get_description(task: "Operator") -> Optional[str]: + if not isinstance(task, BaseOperator): + # TODO: Get docs for mapped operators. + return None + if hasattr(task, "doc") and task.doc: return task.doc elif hasattr(task, "doc_md") and task.doc_md: @@ -189,9 +199,9 @@ def _get_description(task: "BaseOperator") -> Optional[str]: @staticmethod def generate_datajob( cluster: str, - task: "BaseOperator", + task: "Operator", dag: "DAG", - set_dependendecies: bool = True, + set_dependencies: bool = True, capture_owner: bool = True, capture_tags: bool = True, ) -> DataJob: @@ -200,7 +210,7 @@ def generate_datajob( :param cluster: str :param task: TaskIntance :param dag: DAG - :param set_dependendecies: bool - whether to extract dependencies from airflow task + :param set_dependencies: bool - whether to extract dependencies from airflow task :param capture_owner: bool - whether to extract owner from airflow task :param capture_tags: bool - whether to set tags automatically from airflow task :return: DataJob - returns the generated DataJob object @@ -209,6 +219,8 @@ def generate_datajob( orchestrator="airflow", env=cluster, flow_id=dag.dag_id ) datajob = DataJob(id=task.task_id, flow_urn=dataflow_urn) + + # TODO add support for MappedOperator datajob.description = AirflowGenerator._get_description(task) job_property_bag: Dict[str, str] = {} @@ -228,6 +240,11 @@ def generate_datajob( "task_id", "trigger_rule", "wait_for_downstream", + # In Airflow 2.3, _downstream_task_ids was renamed to downstream_task_ids + "downstream_task_ids", + # In Airflow 2.4, _inlets and _outlets were removed in favor of non-private versions. 
+ "inlets", + "outlets", ] for key in allowed_task_keys: @@ -244,7 +261,7 @@ def generate_datajob( if capture_tags and dag.tags: datajob.tags.update(dag.tags) - if set_dependendecies: + if set_dependencies: datajob.upstream_urns.extend( AirflowGenerator._get_dependencies( task=task, dag=dag, flow_urn=datajob.flow_urn @@ -256,7 +273,7 @@ def generate_datajob( @staticmethod def create_datajob_instance( cluster: str, - task: "BaseOperator", + task: "Operator", dag: "DAG", data_job: Optional[DataJob] = None, ) -> DataProcessInstance: @@ -282,6 +299,7 @@ def run_dataflow( dataflow = AirflowGenerator.generate_dataflow(cluster, dag_run.dag) if start_timestamp_millis is None: + assert dag_run.execution_date start_timestamp_millis = int(dag_run.execution_date.timestamp() * 1000) dpi = DataProcessInstance.from_dataflow(dataflow=dataflow, id=dag_run.run_id) diff --git a/metadata-ingestion/src/datahub_provider/hooks/datahub.py b/metadata-ingestion/src/datahub_provider/hooks/datahub.py index c35076b17f2322..95dbb1b0bc9085 100644 --- a/metadata-ingestion/src/datahub_provider/hooks/datahub.py +++ b/metadata-ingestion/src/datahub_provider/hooks/datahub.py @@ -1,3 +1,5 @@ +from datahub_provider._airflow_compat import AIRFLOW_1, BaseHook + from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from airflow.exceptions import AirflowException @@ -9,6 +11,8 @@ ) if TYPE_CHECKING: + from airflow.models.connection import Connection + from datahub.emitter.kafka_emitter import DatahubKafkaEmitter from datahub.emitter.rest_emitter import DatahubRestEmitter from datahub.ingestion.sink.datahub_kafka import KafkaSinkConfig @@ -51,12 +55,14 @@ def get_ui_field_behaviour() -> Dict: } def _get_config(self) -> Tuple[str, Optional[str], Optional[int]]: - conn = self.get_connection(self.datahub_rest_conn_id) + conn: "Connection" = self.get_connection(self.datahub_rest_conn_id) + host = conn.host if host is None: raise AirflowException("host parameter is required") + password = conn.password timeout_sec = conn.extra_dejson.get("timeout_sec") - return (host, conn.password, timeout_sec) + return (host, password, timeout_sec) def make_emitter(self) -> "DatahubRestEmitter": import datahub.emitter.rest_emitter diff --git a/metadata-ingestion/src/datahub_provider/lineage/datahub.py b/metadata-ingestion/src/datahub_provider/lineage/datahub.py index fb3728aa53f471..8883f09adbabb2 100644 --- a/metadata-ingestion/src/datahub_provider/lineage/datahub.py +++ b/metadata-ingestion/src/datahub_provider/lineage/datahub.py @@ -6,6 +6,7 @@ from datahub_provider._lineage_core import ( DatahubBasicLineageConfig, + preprocess_task_iolets, send_lineage_to_datahub, ) @@ -78,6 +79,7 @@ def send_lineage( try: context = context or {} # ensure not None to satisfy mypy + preprocess_task_iolets(operator, context) send_lineage_to_datahub( config, operator, operator.inlets, operator.outlets, context ) diff --git a/metadata-ingestion/tests/test_helpers/sqlalchemy_mypy_plugin.py b/metadata-ingestion/tests/test_helpers/sqlalchemy_mypy_plugin.py new file mode 100644 index 00000000000000..f51377caf3534d --- /dev/null +++ b/metadata-ingestion/tests/test_helpers/sqlalchemy_mypy_plugin.py @@ -0,0 +1,33 @@ +# On SQLAlchemy 1.4.x, the mypy plugin is built-in. +# However, with SQLAlchemy 1.3.x, it requires the sqlalchemy-stubs package and hence has a separate import. +# This file serves as a thin shim layer that directs mypy to the appropriate plugin implementation. 
+try: + from mypy.semanal import SemanticAnalyzer + from sqlalchemy.ext.mypy.plugin import plugin + + # On SQLAlchemy >=1.4, <=1.4.29, the mypy plugin is incompatible with newer versions of mypy. + # See https://github.com/sqlalchemy/sqlalchemy/commit/aded8b11d9eccbd1f2b645a94338e34a3d234bc9 + # and https://github.com/sqlalchemy/sqlalchemy/issues/7496. + # To fix this, we need to patch the mypy plugin interface. + # + # We cannot set a min version of SQLAlchemy because of the bigquery SQLAlchemy package. + # See https://github.com/googleapis/python-bigquery-sqlalchemy/issues/385. + _named_type_original = SemanticAnalyzer.named_type + _named_type_translations = { + "__builtins__.object": "builtins.object", + "__builtins__.str": "builtins.str", + "__builtins__.list": "builtins.list", + "__sa_Mapped": "sqlalchemy.orm.attributes.Mapped", + } + + def _named_type_shim(self, fullname, *args, **kwargs): + if fullname in _named_type_translations: + fullname = _named_type_translations[fullname] + + return _named_type_original(self, fullname, *args, **kwargs) + + SemanticAnalyzer.named_type = _named_type_shim # type: ignore +except ModuleNotFoundError: + from sqlmypy import plugin # type: ignore[no-redef] + +__all__ = ["plugin"] diff --git a/metadata-ingestion/tests/unit/test_airflow.py b/metadata-ingestion/tests/unit/test_airflow.py index f0e951d9985e5f..82e76f43b00c72 100644 --- a/metadata-ingestion/tests/unit/test_airflow.py +++ b/metadata-ingestion/tests/unit/test_airflow.py @@ -1,3 +1,5 @@ +from datahub_provider._airflow_compat import MARKUPSAFE_PATCHED + import datetime import json import os @@ -13,19 +15,19 @@ import pytest from airflow.lineage import apply_lineage, prepare_lineage from airflow.models import DAG, Connection, DagBag, DagRun, TaskInstance +from airflow.operators.dummy import DummyOperator from airflow.utils.dates import days_ago -try: - from airflow.operators.dummy import DummyOperator -except ModuleNotFoundError: - from airflow.operators.dummy_operator import DummyOperator - import datahub.emitter.mce_builder as builder from datahub_provider import get_provider_info from datahub_provider.entities import Dataset from datahub_provider.hooks.datahub import DatahubKafkaHook, DatahubRestHook from datahub_provider.operators.datahub import DatahubEmitterOperator +assert MARKUPSAFE_PATCHED + +pytestmark = pytest.mark.airflow + # Approach suggested by https://stackoverflow.com/a/11887885/5004662. 
AIRFLOW_VERSION = packaging.version.parse(airflow.version.version) @@ -73,6 +75,10 @@ def test_airflow_provider_info(): assert get_provider_info() +@pytest.mark.skipif( + AIRFLOW_VERSION < packaging.version.parse("2.0.0"), + reason="the examples use list-style lineage, which is only supported on Airflow 2.x", +) def test_dags_load_with_no_errors(pytestconfig): airflow_examples_folder = ( pytestconfig.rootpath / "src/datahub_provider/example_dags" @@ -99,6 +105,7 @@ def patch_airflow_connection(conn: Connection) -> Iterator[Connection]: @mock.patch("datahub.emitter.rest_emitter.DatahubRestEmitter", autospec=True) def test_datahub_rest_hook(mock_emitter): with patch_airflow_connection(datahub_rest_connection_config) as config: + assert config.conn_id hook = DatahubRestHook(config.conn_id) hook.emit_mces([lineage_mce]) @@ -112,6 +119,7 @@ def test_datahub_rest_hook_with_timeout(mock_emitter): with patch_airflow_connection( datahub_rest_connection_config_with_timeout ) as config: + assert config.conn_id hook = DatahubRestHook(config.conn_id) hook.emit_mces([lineage_mce]) @@ -123,6 +131,7 @@ def test_datahub_rest_hook_with_timeout(mock_emitter): @mock.patch("datahub.emitter.kafka_emitter.DatahubKafkaEmitter", autospec=True) def test_datahub_kafka_hook(mock_emitter): with patch_airflow_connection(datahub_kafka_connection_config) as config: + assert config.conn_id hook = DatahubKafkaHook(config.conn_id) hook.emit_mces([lineage_mce]) @@ -135,6 +144,7 @@ def test_datahub_kafka_hook(mock_emitter): @mock.patch("datahub_provider.hooks.datahub.DatahubRestHook.emit_mces") def test_datahub_lineage_operator(mock_emit): with patch_airflow_connection(datahub_rest_connection_config) as config: + assert config.conn_id task = DatahubEmitterOperator( task_id="emit_lineage", datahub_conn_id=config.conn_id, @@ -331,6 +341,7 @@ def test_lineage_backend(mock_emit, inlets, outlets): ) @mock.patch("datahub_provider.hooks.datahub.DatahubRestHook.make_emitter") def test_lineage_backend_capture_executions(mock_emit, inlets, outlets): + # TODO: Merge this code into the test above to reduce duplication. 
DEFAULT_DATE = datetime.datetime(2020, 5, 17) mock_emitter = Mock() mock_emit.return_value = mock_emitter @@ -375,10 +386,6 @@ def test_lineage_backend_capture_executions(mock_emit, inlets, outlets): ti = TaskInstance(task=op2, execution_date=DEFAULT_DATE) # Ignoring type here because DagRun state is just a sring at Airflow 1 dag_run = DagRun(state="success", run_id=f"scheduled_{DEFAULT_DATE}") # type: ignore - ti.dag_run = dag_run - ti.start_date = datetime.datetime.utcnow() - ti.execution_date = DEFAULT_DATE - else: from airflow.utils.state import DagRunState @@ -386,9 +393,10 @@ def test_lineage_backend_capture_executions(mock_emit, inlets, outlets): dag_run = DagRun( state=DagRunState.SUCCESS, run_id=f"scheduled_{DEFAULT_DATE}" ) - ti.dag_run = dag_run - ti.start_date = datetime.datetime.utcnow() - ti.execution_date = DEFAULT_DATE + + ti.dag_run = dag_run # type: ignore + ti.start_date = datetime.datetime.utcnow() + ti.execution_date = DEFAULT_DATE ctx1 = { "dag": dag, diff --git a/metadata-ingestion/tests/unit/test_athena_source.py b/metadata-ingestion/tests/unit/test_athena_source.py index f92318f69f0948..60083a961d3b6b 100644 --- a/metadata-ingestion/tests/unit/test_athena_source.py +++ b/metadata-ingestion/tests/unit/test_athena_source.py @@ -65,9 +65,8 @@ def test_athena_get_table_properties(): mock_cursor = mock.MagicMock() mock_inspector = mock.MagicMock() - mock_inspector.engine.return_value = mock.MagicMock() - mock_inspector.dialect._raw_connection.return_value = mock_cursor - mock_inspector.dialect._raw_connection().cursor()._get_table_metadata.return_value = AthenaTableMetadata( + mock_inspector.engine.raw_connection().cursor.return_value = mock_cursor + mock_cursor._get_table_metadata.return_value = AthenaTableMetadata( response=table_metadata ) From e81284f50bdfd136f2721101c970726cd787f5e9 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Tue, 1 Nov 2022 23:59:40 -0700 Subject: [PATCH 02/16] fix compat --- .github/workflows/metadata-ingestion.yml | 5 +---- metadata-ingestion/src/datahub_provider/_airflow_compat.py | 2 -- metadata-ingestion/src/datahub_provider/hooks/datahub.py | 2 -- metadata-ingestion/src/datahub_provider/lineage/datahub.py | 2 -- 4 files changed, 1 insertion(+), 10 deletions(-) diff --git a/.github/workflows/metadata-ingestion.yml b/.github/workflows/metadata-ingestion.yml index 1008cc1240d9bc..93da99b4151a9b 100644 --- a/.github/workflows/metadata-ingestion.yml +++ b/.github/workflows/metadata-ingestion.yml @@ -40,10 +40,7 @@ jobs: ] include: - python-version: "3.7" - extraPythonRequirement: "sqlalchemy==1.3.24" - - python-version: "3.7" - command: "testAirflow1" - extraPythonRequirement: "sqlalchemy==1.3.24" + extraPythonRequirement: "sqlalchemy==1.3.24 apache-airflow==2.0.2" - python-version: "3.10" extraPythonRequirement: "sqlalchemy~=1.4.0" fail-fast: false diff --git a/metadata-ingestion/src/datahub_provider/_airflow_compat.py b/metadata-ingestion/src/datahub_provider/_airflow_compat.py index 3493bf721c1a45..d342e49f5342f8 100644 --- a/metadata-ingestion/src/datahub_provider/_airflow_compat.py +++ b/metadata-ingestion/src/datahub_provider/_airflow_compat.py @@ -1,6 +1,5 @@ from datahub.utilities._markupsafe_compat import MARKUPSAFE_PATCHED -from airflow.hooks.base import BaseHook from airflow.models.baseoperator import BaseOperator try: @@ -19,7 +18,6 @@ __all__ = [ "MARKUPSAFE_PATCHED", - "BaseHook", "Operator", "BaseOperator", "MappedOperator", diff --git a/metadata-ingestion/src/datahub_provider/hooks/datahub.py 
b/metadata-ingestion/src/datahub_provider/hooks/datahub.py index 95dbb1b0bc9085..1343e34954975a 100644 --- a/metadata-ingestion/src/datahub_provider/hooks/datahub.py +++ b/metadata-ingestion/src/datahub_provider/hooks/datahub.py @@ -1,5 +1,3 @@ -from datahub_provider._airflow_compat import AIRFLOW_1, BaseHook - from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from airflow.exceptions import AirflowException diff --git a/metadata-ingestion/src/datahub_provider/lineage/datahub.py b/metadata-ingestion/src/datahub_provider/lineage/datahub.py index 8883f09adbabb2..fb3728aa53f471 100644 --- a/metadata-ingestion/src/datahub_provider/lineage/datahub.py +++ b/metadata-ingestion/src/datahub_provider/lineage/datahub.py @@ -6,7 +6,6 @@ from datahub_provider._lineage_core import ( DatahubBasicLineageConfig, - preprocess_task_iolets, send_lineage_to_datahub, ) @@ -79,7 +78,6 @@ def send_lineage( try: context = context or {} # ensure not None to satisfy mypy - preprocess_task_iolets(operator, context) send_lineage_to_datahub( config, operator, operator.inlets, operator.outlets, context ) From 451e1539fcd20600af1534b7c4c78eee8d2267b0 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Wed, 2 Nov 2022 11:10:01 -0700 Subject: [PATCH 03/16] save file --- .../src/datahub_provider/_lineage_core.py | 47 ------------------- 1 file changed, 47 deletions(-) diff --git a/metadata-ingestion/src/datahub_provider/_lineage_core.py b/metadata-ingestion/src/datahub_provider/_lineage_core.py index 78b4a78260d1de..ee15f8ca54760f 100644 --- a/metadata-ingestion/src/datahub_provider/_lineage_core.py +++ b/metadata-ingestion/src/datahub_provider/_lineage_core.py @@ -123,50 +123,3 @@ def send_lineage_to_datahub( end_timestamp_millis=int(datetime.utcnow().timestamp() * 1000), ) operator.log.info(f"Emitted from Lineage: {dpi}") - - -def preprocess_task_iolets(task: "Operator", context: Dict) -> None: - # This is necessary to avoid issues with circular imports. - from airflow.lineage import prepare_lineage - - from datahub_provider.hooks.datahub import AIRFLOW_1 - - # Detect Airflow 1.10.x inlet/outlet configurations in Airflow 2.x, and - # convert to the newer version. This code path will only be triggered - # when 2.x receives a 1.10.x inlet/outlet config. - needs_repeat_preparation = False - - # Translate inlets. - previous_inlets = _task_underscore_inlets(task) - if ( - not AIRFLOW_1 - and previous_inlets is not None - and isinstance(previous_inlets, list) - and len(previous_inlets) == 1 - and isinstance(previous_inlets[0], dict) - ): - from airflow.lineage import AUTO - - task._inlets = [ # type: ignore[attr-defined,union-attr] - # See https://airflow.apache.org/docs/apache-airflow/1.10.15/lineage.html. - *previous_inlets[0].get("datasets", []), # assumes these are attr-annotated - *previous_inlets[0].get("task_ids", []), - *([AUTO] if previous_inlets[0].get("auto", False) else []), - ] - needs_repeat_preparation = True - - # Translate outlets. - previous_outlets = _task_underscore_outlets(task) - if ( - not AIRFLOW_1 - and previous_inlets is not None - and isinstance(previous_outlets, list) - and len(previous_outlets) == 1 - and isinstance(previous_outlets[0], dict) - ): - task._outlets = [*previous_outlets[0].get("datasets", [])] # type: ignore[attr-defined,union-attr] - needs_repeat_preparation = True - - # Rerun the lineage preparation routine, now that the old format has been translated to the new one. 
- if needs_repeat_preparation: - prepare_lineage(lambda self, ctx: None)(task, context) From 080fbb02fb40ba139ec56e5e0dc2f06c0785be48 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Wed, 2 Nov 2022 17:00:21 -0700 Subject: [PATCH 04/16] fix tests --- .github/workflows/metadata-ingestion.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/metadata-ingestion.yml b/.github/workflows/metadata-ingestion.yml index 93da99b4151a9b..b409fe1f0504f9 100644 --- a/.github/workflows/metadata-ingestion.yml +++ b/.github/workflows/metadata-ingestion.yml @@ -40,9 +40,9 @@ jobs: ] include: - python-version: "3.7" - extraPythonRequirement: "sqlalchemy==1.3.24 apache-airflow==2.0.2" + extraPythonRequirement: "sqlalchemy==1.3.24 apache-airflow~=2.2.0" - python-version: "3.10" - extraPythonRequirement: "sqlalchemy~=1.4.0" + extraPythonRequirement: "sqlalchemy~=1.4.0 apache-airflow~=2.3.0" fail-fast: false steps: - uses: actions/checkout@v3 From e57a474964b8db1c396408c24045a7d48bd3d909 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Wed, 2 Nov 2022 21:44:42 -0700 Subject: [PATCH 05/16] move datahub plugin core into main codebase --- .../datahub_airflow_plugin/datahub_plugin.py | 367 +----------------- .../client/airflow_generator.py | 1 + .../src/datahub_provider/hooks/_plugin.py | 365 +++++++++++++++++ 3 files changed, 368 insertions(+), 365 deletions(-) create mode 100644 metadata-ingestion/src/datahub_provider/hooks/_plugin.py diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py index 81541326b172aa..e54363618a7e38 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py @@ -1,365 +1,2 @@ -import contextlib -import traceback -from typing import Any, Dict, Iterable - -import attr -from airflow.configuration import conf -from airflow.lineage import PIPELINE_OUTLETS -from airflow.models.baseoperator import BaseOperator -from airflow.plugins_manager import AirflowPlugin -from airflow.utils.module_loading import import_string -from cattr import structure -from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult -from datahub_provider._lineage_core import preprocess_task_iolets -from datahub_provider.client.airflow_generator import AirflowGenerator -from datahub_provider.hooks.datahub import DatahubGenericHook -from datahub_provider.lineage.datahub import DatahubLineageConfig - - -def get_lineage_config() -> DatahubLineageConfig: - """Load the lineage config from airflow.cfg.""" - - enabled = conf.get("datahub", "enabled", fallback=True) - datahub_conn_id = conf.get("datahub", "conn_id", fallback="datahub_rest_default") - cluster = conf.get("datahub", "cluster", fallback="prod") - graceful_exceptions = conf.get("datahub", "graceful_exceptions", fallback=True) - capture_tags_info = conf.get("datahub", "capture_tags_info", fallback=True) - capture_ownership_info = conf.get( - "datahub", "capture_ownership_info", fallback=True - ) - capture_executions = conf.get("datahub", "capture_executions", fallback=True) - return DatahubLineageConfig( - enabled=enabled, - datahub_conn_id=datahub_conn_id, - cluster=cluster, - graceful_exceptions=graceful_exceptions, - capture_ownership_info=capture_ownership_info, - capture_tags_info=capture_tags_info, - capture_executions=capture_executions, 
- ) - - -def get_inlets_from_task(task: BaseOperator, context: Any) -> Iterable[Any]: - # TODO fix for https://github.com/apache/airflow/commit/1b1f3fabc5909a447a6277cafef3a0d4ef1f01ae - - inlets = [] - if isinstance(task._inlets, (str, BaseOperator)) or attr.has(task._inlets): # type: ignore - inlets = [ - task._inlets, - ] - - if task._inlets and isinstance(task._inlets, list): - inlets = [] - task_ids = ( - {o for o in task._inlets if isinstance(o, str)} - .union(op.task_id for op in task._inlets if isinstance(op, BaseOperator)) - .intersection(task.get_flat_relative_ids(upstream=True)) - ) - - from airflow.lineage import AUTO - - # pick up unique direct upstream task_ids if AUTO is specified - if AUTO.upper() in task._inlets or AUTO.lower() in task._inlets: - print("Picking up unique direct upstream task_ids as AUTO is specified") - task_ids = task_ids.union( - task_ids.symmetric_difference(task.upstream_task_ids) - ) - - inlets = task.xcom_pull( - context, task_ids=list(task_ids), dag_id=task.dag_id, key=PIPELINE_OUTLETS - ) - - # re-instantiate the obtained inlets - inlets = [ - structure(item["data"], import_string(item["type_name"])) - # _get_instance(structure(item, Metadata)) - for sublist in inlets - if sublist - for item in sublist - ] - - for inlet in task._inlets: - if type(inlet) != str: - inlets.append(inlet) - - return inlets - - -def datahub_on_failure_callback(context, *args, **kwargs): - ti = context["ti"] - task: "BaseOperator" = ti.task - dag = context["dag"] - - # This code is from the original airflow lineage code -> - # https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py - inlets = get_inlets_from_task(task, context) - - emitter = ( - DatahubGenericHook(context["_datahub_config"].datahub_conn_id) - .get_underlying_hook() - .make_emitter() - ) - - dataflow = AirflowGenerator.generate_dataflow( - cluster=context["_datahub_config"].cluster, - dag=dag, - capture_tags=context["_datahub_config"].capture_tags_info, - capture_owner=context["_datahub_config"].capture_ownership_info, - ) - dataflow.emit(emitter) - - task.log.info(f"Emitted Datahub DataFlow: {dataflow}") - - datajob = AirflowGenerator.generate_datajob( - cluster=context["_datahub_config"].cluster, - task=context["ti"].task, - dag=dag, - capture_tags=context["_datahub_config"].capture_tags_info, - capture_owner=context["_datahub_config"].capture_ownership_info, - ) - - for inlet in inlets: - datajob.inlets.append(inlet.urn) - - for outlet in task._outlets: - datajob.outlets.append(outlet.urn) - - task.log.info(f"Emitted Datahub DataJob: {datajob}") - datajob.emit(emitter) - - if context["_datahub_config"].capture_executions: - dpi = AirflowGenerator.run_datajob( - emitter=emitter, - cluster=context["_datahub_config"].cluster, - ti=context["ti"], - dag=dag, - dag_run=context["dag_run"], - datajob=datajob, - start_timestamp_millis=int(ti.start_date.timestamp() * 1000), - ) - - task.log.info(f"Emitted Start Datahub Dataprocess Instance: {dpi}") - - dpi = AirflowGenerator.complete_datajob( - emitter=emitter, - cluster=context["_datahub_config"].cluster, - ti=context["ti"], - dag_run=context["dag_run"], - result=InstanceRunResult.FAILURE, - dag=dag, - datajob=datajob, - end_timestamp_millis=int(ti.end_date.timestamp() * 1000), - ) - task.log.info(f"Emitted Completed Datahub Dataprocess Instance: {dpi}") - - -def datahub_on_success_callback(context, *args, **kwargs): - ti = context["ti"] - task: "BaseOperator" = ti.task - dag = context["dag"] - - # This code is from the original airflow 
lineage code -> - # https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py - inlets = get_inlets_from_task(task, context) - - emitter = ( - DatahubGenericHook(context["_datahub_config"].datahub_conn_id) - .get_underlying_hook() - .make_emitter() - ) - - dataflow = AirflowGenerator.generate_dataflow( - cluster=context["_datahub_config"].cluster, - dag=dag, - capture_tags=context["_datahub_config"].capture_tags_info, - capture_owner=context["_datahub_config"].capture_ownership_info, - ) - dataflow.emit(emitter) - - task.log.info(f"Emitted Datahub DataFlow: {dataflow}") - - datajob = AirflowGenerator.generate_datajob( - cluster=context["_datahub_config"].cluster, - task=task, - dag=dag, - capture_tags=context["_datahub_config"].capture_tags_info, - capture_owner=context["_datahub_config"].capture_ownership_info, - ) - - for inlet in inlets: - datajob.inlets.append(inlet.urn) - - # We have to use _outlets because outlets is empty - for outlet in task._outlets: - datajob.outlets.append(outlet.urn) - - task.log.info(f"Emitted Datahub dataJob: {datajob}") - datajob.emit(emitter) - - if context["_datahub_config"].capture_executions: - dpi = AirflowGenerator.run_datajob( - emitter=emitter, - cluster=context["_datahub_config"].cluster, - ti=context["ti"], - dag=dag, - dag_run=context["dag_run"], - datajob=datajob, - start_timestamp_millis=int(ti.start_date.timestamp() * 1000), - ) - - task.log.info(f"Emitted Start Datahub Dataprocess Instance: {dpi}") - - dpi = AirflowGenerator.complete_datajob( - emitter=emitter, - cluster=context["_datahub_config"].cluster, - ti=context["ti"], - dag_run=context["dag_run"], - result=InstanceRunResult.SUCCESS, - dag=dag, - datajob=datajob, - end_timestamp_millis=int(ti.end_date.timestamp() * 1000), - ) - task.log.info(f"Emitted Completed Data Process Instance: {dpi}") - - -def datahub_pre_execution(context): - ti = context["ti"] - task: "BaseOperator" = ti.task - dag = context["dag"] - - task.log.info("Running Datahub pre_execute method") - - emitter = ( - DatahubGenericHook(context["_datahub_config"].datahub_conn_id) - .get_underlying_hook() - .make_emitter() - ) - - # This code is from the original airflow lineage code -> - # https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py - inlets = get_inlets_from_task(task, context) - - datajob = AirflowGenerator.generate_datajob( - cluster=context["_datahub_config"].cluster, - task=context["ti"].task, - dag=dag, - capture_tags=context["_datahub_config"].capture_tags_info, - capture_owner=context["_datahub_config"].capture_ownership_info, - ) - - for inlet in inlets: - datajob.inlets.append(inlet.urn) - - for outlet in task._outlets: - datajob.outlets.append(outlet.urn) - - datajob.emit(emitter) - task.log.info(f"Emitting Datahub DataJob: {datajob}") - - if context["_datahub_config"].capture_executions: - dpi = AirflowGenerator.run_datajob( - emitter=emitter, - cluster=context["_datahub_config"].cluster, - ti=context["ti"], - dag=dag, - dag_run=context["dag_run"], - datajob=datajob, - start_timestamp_millis=int(ti.start_date.timestamp() * 1000), - ) - - task.log.info(f"Emitting Datahub Dataprocess Instance: {dpi}") - - -def _wrap_pre_execution(pre_execution): - def custom_pre_execution(context): - config = get_lineage_config() - context["_datahub_config"] = config - datahub_pre_execution(context) - - # Call original policy - if pre_execution: - pre_execution(context) - - return custom_pre_execution - - -def _wrap_on_failure_callback(on_failure_callback): - def 
custom_on_failure_callback(context): - config = get_lineage_config() - context["_datahub_config"] = config - try: - datahub_on_failure_callback(context) - except Exception as e: - if not config.graceful_exceptions: - raise e - else: - print(f"Exception: {traceback.format_exc()}") - - # Call original policy - if on_failure_callback: - on_failure_callback(context) - - return custom_on_failure_callback - - -def _wrap_on_success_callback(on_success_callback): - def custom_on_success_callback(context): - config = get_lineage_config() - context["_datahub_config"] = config - try: - datahub_on_success_callback(context) - except Exception as e: - if not config.graceful_exceptions: - raise e - else: - print(f"Exception: {traceback.format_exc()}") - - if on_success_callback: - on_success_callback(context) - - return custom_on_success_callback - - -def task_policy(task: BaseOperator) -> None: - print(f"Setting task policy for Dag: {task.dag_id} Task: {task.task_id}") - # task.add_inlets(["auto"]) - # task.pre_execute = _wrap_pre_execution(task.pre_execute) - task.on_failure_callback = _wrap_on_failure_callback(task.on_failure_callback) - task.on_success_callback = _wrap_on_success_callback(task.on_success_callback) - # task.pre_execute = _wrap_pre_execution(task.pre_execute) - - -def _wrap_task_policy(policy): - if policy and hasattr(policy, "_task_policy_patched_by"): - return policy - - def custom_task_policy(task): - policy(task) - task_policy(task) - - setattr(custom_task_policy, "_task_policy_patched_by", "datahub_plugin") - return custom_task_policy - - -def _patch_policy(settings): - print("Patching datahub policy") - if hasattr(settings, "task_policy"): - datahub_task_policy = _wrap_task_policy(settings.task_policy) - settings.task_policy = datahub_task_policy - - -def _patch_datahub_policy(): - with contextlib.suppress(ImportError): - import airflow_local_settings - - _patch_policy(airflow_local_settings) - from airflow.models.dagbag import settings - - _patch_policy(settings) - - -_patch_datahub_policy() - - -class DatahubPlugin(AirflowPlugin): - name = "datahub_plugin" +# This package serves as a shim, but the actual implementation lives in datahub_provider from the acryl-datahub package. 
+from datahub_provider._plugin import DatahubPlugin diff --git a/metadata-ingestion/src/datahub_provider/client/airflow_generator.py b/metadata-ingestion/src/datahub_provider/client/airflow_generator.py index 1faedd516c3011..063630e595c116 100644 --- a/metadata-ingestion/src/datahub_provider/client/airflow_generator.py +++ b/metadata-ingestion/src/datahub_provider/client/airflow_generator.py @@ -302,6 +302,7 @@ def run_dataflow( assert dag_run.execution_date start_timestamp_millis = int(dag_run.execution_date.timestamp() * 1000) + assert dag_run.run_id dpi = DataProcessInstance.from_dataflow(dataflow=dataflow, id=dag_run.run_id) # This property only exists in Airflow2 diff --git a/metadata-ingestion/src/datahub_provider/hooks/_plugin.py b/metadata-ingestion/src/datahub_provider/hooks/_plugin.py new file mode 100644 index 00000000000000..f341f714ae7a6a --- /dev/null +++ b/metadata-ingestion/src/datahub_provider/hooks/_plugin.py @@ -0,0 +1,365 @@ +import contextlib +import traceback +from typing import Any, Iterable + +import attr +from airflow.configuration import conf +from airflow.lineage import PIPELINE_OUTLETS +from airflow.models.baseoperator import BaseOperator +from airflow.plugins_manager import AirflowPlugin +from airflow.utils.module_loading import import_string +from cattr import structure + +from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult +from datahub_provider.client.airflow_generator import AirflowGenerator +from datahub_provider.hooks.datahub import DatahubGenericHook +from datahub_provider.lineage.datahub import DatahubLineageConfig + + +def get_lineage_config() -> DatahubLineageConfig: + """Load the lineage config from airflow.cfg.""" + + enabled = conf.get("datahub", "enabled", fallback=True) + datahub_conn_id = conf.get("datahub", "conn_id", fallback="datahub_rest_default") + cluster = conf.get("datahub", "cluster", fallback="prod") + graceful_exceptions = conf.get("datahub", "graceful_exceptions", fallback=True) + capture_tags_info = conf.get("datahub", "capture_tags_info", fallback=True) + capture_ownership_info = conf.get( + "datahub", "capture_ownership_info", fallback=True + ) + capture_executions = conf.get("datahub", "capture_executions", fallback=True) + return DatahubLineageConfig( + enabled=enabled, + datahub_conn_id=datahub_conn_id, + cluster=cluster, + graceful_exceptions=graceful_exceptions, + capture_ownership_info=capture_ownership_info, + capture_tags_info=capture_tags_info, + capture_executions=capture_executions, + ) + + +def get_inlets_from_task(task: BaseOperator, context: Any) -> Iterable[Any]: + # TODO fix for https://github.com/apache/airflow/commit/1b1f3fabc5909a447a6277cafef3a0d4ef1f01ae + + inlets = [] + if isinstance(task._inlets, (str, BaseOperator)) or attr.has(task._inlets): # type: ignore + inlets = [ + task._inlets, + ] + + if task._inlets and isinstance(task._inlets, list): + inlets = [] + task_ids = ( + {o for o in task._inlets if isinstance(o, str)} + .union(op.task_id for op in task._inlets if isinstance(op, BaseOperator)) + .intersection(task.get_flat_relative_ids(upstream=True)) + ) + + from airflow.lineage import AUTO + + # pick up unique direct upstream task_ids if AUTO is specified + if AUTO.upper() in task._inlets or AUTO.lower() in task._inlets: + print("Picking up unique direct upstream task_ids as AUTO is specified") + task_ids = task_ids.union( + task_ids.symmetric_difference(task.upstream_task_ids) + ) + + inlets = task.xcom_pull( + context, task_ids=list(task_ids), dag_id=task.dag_id, 
key=PIPELINE_OUTLETS + ) + + # re-instantiate the obtained inlets + inlets = [ + structure(item["data"], import_string(item["type_name"])) + # _get_instance(structure(item, Metadata)) + for sublist in inlets + if sublist + for item in sublist + ] + + for inlet in task._inlets: + if type(inlet) != str: + inlets.append(inlet) + + return inlets + + +def datahub_on_failure_callback(context, *args, **kwargs): + ti = context["ti"] + task: "BaseOperator" = ti.task + dag = context["dag"] + + # This code is from the original airflow lineage code -> + # https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py + inlets = get_inlets_from_task(task, context) + + emitter = ( + DatahubGenericHook(context["_datahub_config"].datahub_conn_id) + .get_underlying_hook() + .make_emitter() + ) + + dataflow = AirflowGenerator.generate_dataflow( + cluster=context["_datahub_config"].cluster, + dag=dag, + capture_tags=context["_datahub_config"].capture_tags_info, + capture_owner=context["_datahub_config"].capture_ownership_info, + ) + dataflow.emit(emitter) + + task.log.info(f"Emitted Datahub DataFlow: {dataflow}") + + datajob = AirflowGenerator.generate_datajob( + cluster=context["_datahub_config"].cluster, + task=context["ti"].task, + dag=dag, + capture_tags=context["_datahub_config"].capture_tags_info, + capture_owner=context["_datahub_config"].capture_ownership_info, + ) + + for inlet in inlets: + datajob.inlets.append(inlet.urn) + + for outlet in task._outlets: + datajob.outlets.append(outlet.urn) + + task.log.info(f"Emitted Datahub DataJob: {datajob}") + datajob.emit(emitter) + + if context["_datahub_config"].capture_executions: + dpi = AirflowGenerator.run_datajob( + emitter=emitter, + cluster=context["_datahub_config"].cluster, + ti=context["ti"], + dag=dag, + dag_run=context["dag_run"], + datajob=datajob, + start_timestamp_millis=int(ti.start_date.timestamp() * 1000), + ) + + task.log.info(f"Emitted Start Datahub Dataprocess Instance: {dpi}") + + dpi = AirflowGenerator.complete_datajob( + emitter=emitter, + cluster=context["_datahub_config"].cluster, + ti=context["ti"], + dag_run=context["dag_run"], + result=InstanceRunResult.FAILURE, + dag=dag, + datajob=datajob, + end_timestamp_millis=int(ti.end_date.timestamp() * 1000), + ) + task.log.info(f"Emitted Completed Datahub Dataprocess Instance: {dpi}") + + +def datahub_on_success_callback(context, *args, **kwargs): + ti = context["ti"] + task: "BaseOperator" = ti.task + dag = context["dag"] + + # This code is from the original airflow lineage code -> + # https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py + inlets = get_inlets_from_task(task, context) + + emitter = ( + DatahubGenericHook(context["_datahub_config"].datahub_conn_id) + .get_underlying_hook() + .make_emitter() + ) + + dataflow = AirflowGenerator.generate_dataflow( + cluster=context["_datahub_config"].cluster, + dag=dag, + capture_tags=context["_datahub_config"].capture_tags_info, + capture_owner=context["_datahub_config"].capture_ownership_info, + ) + dataflow.emit(emitter) + + task.log.info(f"Emitted Datahub DataFlow: {dataflow}") + + datajob = AirflowGenerator.generate_datajob( + cluster=context["_datahub_config"].cluster, + task=task, + dag=dag, + capture_tags=context["_datahub_config"].capture_tags_info, + capture_owner=context["_datahub_config"].capture_ownership_info, + ) + + for inlet in inlets: + datajob.inlets.append(inlet.urn) + + # We have to use _outlets because outlets is empty + for outlet in task._outlets: + datajob.outlets.append(outlet.urn) 
+ + task.log.info(f"Emitted Datahub dataJob: {datajob}") + datajob.emit(emitter) + + if context["_datahub_config"].capture_executions: + dpi = AirflowGenerator.run_datajob( + emitter=emitter, + cluster=context["_datahub_config"].cluster, + ti=context["ti"], + dag=dag, + dag_run=context["dag_run"], + datajob=datajob, + start_timestamp_millis=int(ti.start_date.timestamp() * 1000), + ) + + task.log.info(f"Emitted Start Datahub Dataprocess Instance: {dpi}") + + dpi = AirflowGenerator.complete_datajob( + emitter=emitter, + cluster=context["_datahub_config"].cluster, + ti=context["ti"], + dag_run=context["dag_run"], + result=InstanceRunResult.SUCCESS, + dag=dag, + datajob=datajob, + end_timestamp_millis=int(ti.end_date.timestamp() * 1000), + ) + task.log.info(f"Emitted Completed Data Process Instance: {dpi}") + + +def datahub_pre_execution(context): + ti = context["ti"] + task: "BaseOperator" = ti.task + dag = context["dag"] + + task.log.info("Running Datahub pre_execute method") + + emitter = ( + DatahubGenericHook(context["_datahub_config"].datahub_conn_id) + .get_underlying_hook() + .make_emitter() + ) + + # This code is from the original airflow lineage code -> + # https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py + inlets = get_inlets_from_task(task, context) + + datajob = AirflowGenerator.generate_datajob( + cluster=context["_datahub_config"].cluster, + task=context["ti"].task, + dag=dag, + capture_tags=context["_datahub_config"].capture_tags_info, + capture_owner=context["_datahub_config"].capture_ownership_info, + ) + + for inlet in inlets: + datajob.inlets.append(inlet.urn) + + for outlet in task._outlets: + datajob.outlets.append(outlet.urn) + + datajob.emit(emitter) + task.log.info(f"Emitting Datahub DataJob: {datajob}") + + if context["_datahub_config"].capture_executions: + dpi = AirflowGenerator.run_datajob( + emitter=emitter, + cluster=context["_datahub_config"].cluster, + ti=context["ti"], + dag=dag, + dag_run=context["dag_run"], + datajob=datajob, + start_timestamp_millis=int(ti.start_date.timestamp() * 1000), + ) + + task.log.info(f"Emitting Datahub Dataprocess Instance: {dpi}") + + +def _wrap_pre_execution(pre_execution): + def custom_pre_execution(context): + config = get_lineage_config() + context["_datahub_config"] = config + datahub_pre_execution(context) + + # Call original policy + if pre_execution: + pre_execution(context) + + return custom_pre_execution + + +def _wrap_on_failure_callback(on_failure_callback): + def custom_on_failure_callback(context): + config = get_lineage_config() + context["_datahub_config"] = config + try: + datahub_on_failure_callback(context) + except Exception as e: + if not config.graceful_exceptions: + raise e + else: + print(f"Exception: {traceback.format_exc()}") + + # Call original policy + if on_failure_callback: + on_failure_callback(context) + + return custom_on_failure_callback + + +def _wrap_on_success_callback(on_success_callback): + def custom_on_success_callback(context): + config = get_lineage_config() + context["_datahub_config"] = config + try: + datahub_on_success_callback(context) + except Exception as e: + if not config.graceful_exceptions: + raise e + else: + print(f"Exception: {traceback.format_exc()}") + + if on_success_callback: + on_success_callback(context) + + return custom_on_success_callback + + +def task_policy(task: BaseOperator) -> None: + print(f"Setting task policy for Dag: {task.dag_id} Task: {task.task_id}") + # task.add_inlets(["auto"]) + # task.pre_execute = 
_wrap_pre_execution(task.pre_execute) + task.on_failure_callback = _wrap_on_failure_callback(task.on_failure_callback) + task.on_success_callback = _wrap_on_success_callback(task.on_success_callback) + # task.pre_execute = _wrap_pre_execution(task.pre_execute) + + +def _wrap_task_policy(policy): + if policy and hasattr(policy, "_task_policy_patched_by"): + return policy + + def custom_task_policy(task): + policy(task) + task_policy(task) + + setattr(custom_task_policy, "_task_policy_patched_by", "datahub_plugin") + return custom_task_policy + + +def _patch_policy(settings): + print("Patching datahub policy") + if hasattr(settings, "task_policy"): + datahub_task_policy = _wrap_task_policy(settings.task_policy) + settings.task_policy = datahub_task_policy + + +def _patch_datahub_policy(): + with contextlib.suppress(ImportError): + import airflow_local_settings + + _patch_policy(airflow_local_settings) + from airflow.models.dagbag import settings + + _patch_policy(settings) + + +_patch_datahub_policy() + + +class DatahubPlugin(AirflowPlugin): + name = "datahub_plugin" From 62ed57ef29149ac086c9e53e77c01a36d9d72d5c Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Wed, 2 Nov 2022 22:57:18 -0700 Subject: [PATCH 06/16] fix lint --- .../src/datahub_provider/client/airflow_generator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/metadata-ingestion/src/datahub_provider/client/airflow_generator.py b/metadata-ingestion/src/datahub_provider/client/airflow_generator.py index 063630e595c116..31eed2d922fc3a 100644 --- a/metadata-ingestion/src/datahub_provider/client/airflow_generator.py +++ b/metadata-ingestion/src/datahub_provider/client/airflow_generator.py @@ -354,6 +354,7 @@ def complete_dataflow( assert dag_run.dag dataflow = AirflowGenerator.generate_dataflow(cluster, dag_run.dag) + assert dag_run.run_id dpi = DataProcessInstance.from_dataflow(dataflow=dataflow, id=dag_run.run_id) if end_timestamp_millis is None: if dag_run.end_date is None: @@ -394,6 +395,7 @@ def run_datajob( if datajob is None: datajob = AirflowGenerator.generate_datajob(cluster, ti.task, dag) + assert dag_run.run_id dpi = DataProcessInstance.from_datajob( datajob=datajob, id=f"{dag.dag_id}_{ti.task_id}_{dag_run.run_id}", From 184e846ef60a3d305ccd23d0d71c922a31b54d52 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Wed, 2 Nov 2022 23:20:56 -0700 Subject: [PATCH 07/16] fix plugin --- .../airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py index e54363618a7e38..525c2318d5fca9 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py @@ -1,2 +1,2 @@ # This package serves as a shim, but the actual implementation lives in datahub_provider from the acryl-datahub package. 
-from datahub_provider._plugin import DatahubPlugin +from datahub_provider._plugin import DatahubPlugin # noqa: F401 From a483f2ff3ad3b543d195ccb3de6340d55de98c4f Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Fri, 4 Nov 2022 09:23:37 -0700 Subject: [PATCH 08/16] fix plugin location --- metadata-ingestion/src/datahub_provider/{hooks => }/_plugin.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename metadata-ingestion/src/datahub_provider/{hooks => }/_plugin.py (100%) diff --git a/metadata-ingestion/src/datahub_provider/hooks/_plugin.py b/metadata-ingestion/src/datahub_provider/_plugin.py similarity index 100% rename from metadata-ingestion/src/datahub_provider/hooks/_plugin.py rename to metadata-ingestion/src/datahub_provider/_plugin.py From 7a40495f178f87b93d34041e165426d21194eeac Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Fri, 4 Nov 2022 09:24:05 -0700 Subject: [PATCH 09/16] cleanup dep setup --- metadata-ingestion-modules/airflow-plugin/build.gradle | 2 +- metadata-ingestion-modules/airflow-plugin/setup.py | 6 +----- .../src/datahub_airflow_plugin/datahub_plugin.py | 4 +++- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/metadata-ingestion-modules/airflow-plugin/build.gradle b/metadata-ingestion-modules/airflow-plugin/build.gradle index 2a0d6f76f9d8e6..b627723d3b95d1 100644 --- a/metadata-ingestion-modules/airflow-plugin/build.gradle +++ b/metadata-ingestion-modules/airflow-plugin/build.gradle @@ -7,7 +7,7 @@ ext { venv_name = 'venv' } -def pip_install_command = "USE_DEV_VERSION=1 ${venv_name}/bin/pip install -e ../../metadata-ingestion" +def pip_install_command = "${venv_name}/bin/pip install -e ../../metadata-ingestion" task checkPythonVersion(type: Exec) { commandLine python_executable, '-c', 'import sys; assert sys.version_info >= (3, 7)' diff --git a/metadata-ingestion-modules/airflow-plugin/setup.py b/metadata-ingestion-modules/airflow-plugin/setup.py index 8b8952a5f0dcfa..8c6338c114d88b 100644 --- a/metadata-ingestion-modules/airflow-plugin/setup.py +++ b/metadata-ingestion-modules/airflow-plugin/setup.py @@ -3,8 +3,6 @@ import setuptools -USE_DEV_VERSION = os.environ.get("USE_DEV_VERSION", "0") == "1" - package_metadata: dict = {} with open("./src/datahub_airflow_plugin/__init__.py") as fp: @@ -25,9 +23,7 @@ def get_long_description(): "typing-inspect", "pydantic>=1.5.1", "apache-airflow >= 2.0.2", - "acryl-datahub[airflow] >= 0.8.36" - if not USE_DEV_VERSION - else "acryl-datahub[airflow] == 0.0.0.dev0", + f"acryl-datahub[airflow] == {package_metadata['__version__']}", } diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py index 525c2318d5fca9..226a7382f75954 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py @@ -1,2 +1,4 @@ -# This package serves as a shim, but the actual implementation lives in datahub_provider from the acryl-datahub package. +# This package serves as a shim, but the actual implementation lives in datahub_provider +# from the acryl-datahub package. We leave this shim here to avoid breaking existing +# Airflow installs. 
from datahub_provider._plugin import DatahubPlugin # noqa: F401 From f4c76bb0bfede3adf440b8d836e8490119eab953 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Fri, 4 Nov 2022 11:16:44 -0700 Subject: [PATCH 10/16] test with airflow 2.4 --- .github/workflows/metadata-ingestion.yml | 2 +- metadata-ingestion/src/datahub_provider/_plugin.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/metadata-ingestion.yml b/.github/workflows/metadata-ingestion.yml index b409fe1f0504f9..9e6a9da5bdd760 100644 --- a/.github/workflows/metadata-ingestion.yml +++ b/.github/workflows/metadata-ingestion.yml @@ -42,7 +42,7 @@ jobs: - python-version: "3.7" extraPythonRequirement: "sqlalchemy==1.3.24 apache-airflow~=2.2.0" - python-version: "3.10" - extraPythonRequirement: "sqlalchemy~=1.4.0 apache-airflow~=2.3.0" + extraPythonRequirement: "sqlalchemy~=1.4.0 apache-airflow~=2.4.0" fail-fast: false steps: - uses: actions/checkout@v3 diff --git a/metadata-ingestion/src/datahub_provider/_plugin.py b/metadata-ingestion/src/datahub_provider/_plugin.py index f341f714ae7a6a..a838d9377df53f 100644 --- a/metadata-ingestion/src/datahub_provider/_plugin.py +++ b/metadata-ingestion/src/datahub_provider/_plugin.py @@ -40,7 +40,8 @@ def get_lineage_config() -> DatahubLineageConfig: def get_inlets_from_task(task: BaseOperator, context: Any) -> Iterable[Any]: - # TODO fix for https://github.com/apache/airflow/commit/1b1f3fabc5909a447a6277cafef3a0d4ef1f01ae + # TODO: Fix for https://github.com/apache/airflow/commit/1b1f3fabc5909a447a6277cafef3a0d4ef1f01ae + # in Airflow 2.4. inlets = [] if isinstance(task._inlets, (str, BaseOperator)) or attr.has(task._inlets): # type: ignore @@ -342,17 +343,19 @@ def custom_task_policy(task): def _patch_policy(settings): - print("Patching datahub policy") if hasattr(settings, "task_policy"): datahub_task_policy = _wrap_task_policy(settings.task_policy) settings.task_policy = datahub_task_policy def _patch_datahub_policy(): + print("Patching datahub policy") + with contextlib.suppress(ImportError): import airflow_local_settings _patch_policy(airflow_local_settings) + from airflow.models.dagbag import settings _patch_policy(settings) From 44462c6f54e2678621c7b9eb31ca986bfceadddf Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Sat, 5 Nov 2022 23:33:03 -0700 Subject: [PATCH 11/16] remove unused method --- .../src/datahub_provider/_lineage_core.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/metadata-ingestion/src/datahub_provider/_lineage_core.py b/metadata-ingestion/src/datahub_provider/_lineage_core.py index ee15f8ca54760f..4941d062fef945 100644 --- a/metadata-ingestion/src/datahub_provider/_lineage_core.py +++ b/metadata-ingestion/src/datahub_provider/_lineage_core.py @@ -1,12 +1,11 @@ -from datahub_provider._airflow_compat import Operator - from datetime import datetime -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List import datahub.emitter.mce_builder as builder from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult from datahub.configuration.common import ConfigModel from datahub.utilities.urns.dataset_urn import DatasetUrn +from datahub_provider._airflow_compat import Operator from datahub_provider.client.airflow_generator import AirflowGenerator from datahub_provider.entities import _Entity @@ -46,18 +45,6 @@ def make_emitter_hook(self) -> "DatahubGenericHook": return DatahubGenericHook(self.datahub_conn_id) -def 
_task_underscore_inlets(operator: "Operator") -> Optional[List]: - if hasattr(operator, "_inlets"): - return operator._inlets # type: ignore[attr-defined,union-attr] - return None - - -def _task_underscore_outlets(operator: "Operator") -> Optional[List]: - if hasattr(operator, "_outlets"): - return operator._outlets # type: ignore[attr-defined,union-attr] - return None - - def send_lineage_to_datahub( config: DatahubBasicLineageConfig, operator: "Operator", From 2ee6589af38248ca61660fd033f3a1bef4390b12 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Sat, 5 Nov 2022 23:33:09 -0700 Subject: [PATCH 12/16] possibly useless commit --- .../src/datahub_provider/_plugin.py | 92 ++++--------------- 1 file changed, 19 insertions(+), 73 deletions(-) diff --git a/metadata-ingestion/src/datahub_provider/_plugin.py b/metadata-ingestion/src/datahub_provider/_plugin.py index a838d9377df53f..428f2507eb548e 100644 --- a/metadata-ingestion/src/datahub_provider/_plugin.py +++ b/metadata-ingestion/src/datahub_provider/_plugin.py @@ -1,6 +1,6 @@ import contextlib import traceback -from typing import Any, Iterable +from typing import Any, Iterable, List, Optional import attr from airflow.configuration import conf @@ -11,6 +11,7 @@ from cattr import structure from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult +from datahub_provider._airflow_compat import Operator from datahub_provider.client.airflow_generator import AirflowGenerator from datahub_provider.hooks.datahub import DatahubGenericHook from datahub_provider.lineage.datahub import DatahubLineageConfig @@ -39,9 +40,22 @@ def get_lineage_config() -> DatahubLineageConfig: ) +def _task_inlets(operator: "Operator") -> Optional[List]: + if hasattr(operator, "_inlets"): + return operator._inlets # type: ignore[attr-defined,union-attr] + return operator.inlets + + +def _task_outlets(operator: "Operator") -> Optional[List]: + if hasattr(operator, "_outlets"): + return operator._outlets # type: ignore[attr-defined,union-attr] + return operator.outlets + + def get_inlets_from_task(task: BaseOperator, context: Any) -> Iterable[Any]: # TODO: Fix for https://github.com/apache/airflow/commit/1b1f3fabc5909a447a6277cafef3a0d4ef1f01ae # in Airflow 2.4. 
+ # TODO: ignore/handle airflow's dataset type in our lineage inlets = [] if isinstance(task._inlets, (str, BaseOperator)) or attr.has(task._inlets): # type: ignore @@ -86,75 +100,7 @@ def get_inlets_from_task(task: BaseOperator, context: Any) -> Iterable[Any]: return inlets -def datahub_on_failure_callback(context, *args, **kwargs): - ti = context["ti"] - task: "BaseOperator" = ti.task - dag = context["dag"] - - # This code is from the original airflow lineage code -> - # https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py - inlets = get_inlets_from_task(task, context) - - emitter = ( - DatahubGenericHook(context["_datahub_config"].datahub_conn_id) - .get_underlying_hook() - .make_emitter() - ) - - dataflow = AirflowGenerator.generate_dataflow( - cluster=context["_datahub_config"].cluster, - dag=dag, - capture_tags=context["_datahub_config"].capture_tags_info, - capture_owner=context["_datahub_config"].capture_ownership_info, - ) - dataflow.emit(emitter) - - task.log.info(f"Emitted Datahub DataFlow: {dataflow}") - - datajob = AirflowGenerator.generate_datajob( - cluster=context["_datahub_config"].cluster, - task=context["ti"].task, - dag=dag, - capture_tags=context["_datahub_config"].capture_tags_info, - capture_owner=context["_datahub_config"].capture_ownership_info, - ) - - for inlet in inlets: - datajob.inlets.append(inlet.urn) - - for outlet in task._outlets: - datajob.outlets.append(outlet.urn) - - task.log.info(f"Emitted Datahub DataJob: {datajob}") - datajob.emit(emitter) - - if context["_datahub_config"].capture_executions: - dpi = AirflowGenerator.run_datajob( - emitter=emitter, - cluster=context["_datahub_config"].cluster, - ti=context["ti"], - dag=dag, - dag_run=context["dag_run"], - datajob=datajob, - start_timestamp_millis=int(ti.start_date.timestamp() * 1000), - ) - - task.log.info(f"Emitted Start Datahub Dataprocess Instance: {dpi}") - - dpi = AirflowGenerator.complete_datajob( - emitter=emitter, - cluster=context["_datahub_config"].cluster, - ti=context["ti"], - dag_run=context["dag_run"], - result=InstanceRunResult.FAILURE, - dag=dag, - datajob=datajob, - end_timestamp_millis=int(ti.end_date.timestamp() * 1000), - ) - task.log.info(f"Emitted Completed Datahub Dataprocess Instance: {dpi}") - - -def datahub_on_success_callback(context, *args, **kwargs): +def datahub_task_status_callback(context, status): ti = context["ti"] task: "BaseOperator" = ti.task dag = context["dag"] @@ -215,7 +161,7 @@ def datahub_on_success_callback(context, *args, **kwargs): cluster=context["_datahub_config"].cluster, ti=context["ti"], dag_run=context["dag_run"], - result=InstanceRunResult.SUCCESS, + result=status, dag=dag, datajob=datajob, end_timestamp_millis=int(ti.end_date.timestamp() * 1000), @@ -289,7 +235,7 @@ def custom_on_failure_callback(context): config = get_lineage_config() context["_datahub_config"] = config try: - datahub_on_failure_callback(context) + datahub_task_status_callback(context, status=InstanceRunResult.FAILURE) except Exception as e: if not config.graceful_exceptions: raise e @@ -308,7 +254,7 @@ def custom_on_success_callback(context): config = get_lineage_config() context["_datahub_config"] = config try: - datahub_on_success_callback(context) + datahub_task_status_callback(context, status=InstanceRunResult.SUCCESS) except Exception as e: if not config.graceful_exceptions: raise e From 96e26227fa891a14dad73acb084918d9a2dd3839 Mon Sep 17 00:00:00 2001 From: treff7es Date: Thu, 10 Nov 2022 11:03:08 +0100 Subject: [PATCH 13/16] Making mypy happy --- 
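The _task_inlets/_task_outlets helpers touched in this patch rely on a plain hasattr fallback: prefer the private _inlets/_outlets attributes that exist on Airflow < 2.4 operators, and fall back to the public inlets/outlets fields on Airflow >= 2.4, where the private names were removed. A minimal, self-contained sketch of that behavior (the operator classes below are illustrative stand-ins for this note only, not real Airflow classes):

    from typing import Any, List


    def _task_inlets(operator: Any) -> List:
        # Airflow < 2.4 stores lineage on the private `_inlets`; 2.4+ only exposes `inlets`.
        return operator._inlets if hasattr(operator, "_inlets") else operator.inlets


    class LegacyOperator:  # stand-in for an Airflow < 2.4 operator
        _inlets = ["auto"]
        inlets: List = []  # the public attribute is empty pre-2.4


    class ModernOperator:  # stand-in for an Airflow >= 2.4 operator
        inlets = ["upstream_task"]


    assert _task_inlets(LegacyOperator()) == ["auto"]
    assert _task_inlets(ModernOperator()) == ["upstream_task"]

The same fallback applies to _task_outlets, which is why the callbacks below stop reaching into task._outlets directly.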
.../src/datahub_provider/_lineage_core.py | 3 +- .../src/datahub_provider/_plugin.py | 43 +++++++++++-------- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/metadata-ingestion/src/datahub_provider/_lineage_core.py b/metadata-ingestion/src/datahub_provider/_lineage_core.py index 4941d062fef945..89956c9354b553 100644 --- a/metadata-ingestion/src/datahub_provider/_lineage_core.py +++ b/metadata-ingestion/src/datahub_provider/_lineage_core.py @@ -1,3 +1,5 @@ +from datahub_provider._airflow_compat import Operator + from datetime import datetime from typing import TYPE_CHECKING, Dict, List @@ -5,7 +7,6 @@ from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult from datahub.configuration.common import ConfigModel from datahub.utilities.urns.dataset_urn import DatasetUrn -from datahub_provider._airflow_compat import Operator from datahub_provider.client.airflow_generator import AirflowGenerator from datahub_provider.entities import _Entity diff --git a/metadata-ingestion/src/datahub_provider/_plugin.py b/metadata-ingestion/src/datahub_provider/_plugin.py index 428f2507eb548e..ea24af77fed739 100644 --- a/metadata-ingestion/src/datahub_provider/_plugin.py +++ b/metadata-ingestion/src/datahub_provider/_plugin.py @@ -1,8 +1,9 @@ +from datahub_provider._airflow_compat import Operator + import contextlib import traceback -from typing import Any, Iterable, List, Optional +from typing import Any, Iterable, List -import attr from airflow.configuration import conf from airflow.lineage import PIPELINE_OUTLETS from airflow.models.baseoperator import BaseOperator @@ -11,7 +12,6 @@ from cattr import structure from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult -from datahub_provider._airflow_compat import Operator from datahub_provider.client.airflow_generator import AirflowGenerator from datahub_provider.hooks.datahub import DatahubGenericHook from datahub_provider.lineage.datahub import DatahubLineageConfig @@ -40,15 +40,18 @@ def get_lineage_config() -> DatahubLineageConfig: ) -def _task_inlets(operator: "Operator") -> Optional[List]: +def _task_inlets(operator: "Operator") -> List: + # From Airflow 2.4 _inlets is dropped and inlets used consistently. Earlier it was not the case, so we have to stick there to _inlets if hasattr(operator, "_inlets"): - return operator._inlets # type: ignore[attr-defined,union-attr] + return operator._inlets # type: ignore[attr-defined, union-attr] return operator.inlets -def _task_outlets(operator: "Operator") -> Optional[List]: +def _task_outlets(operator: "Operator") -> List: + # From Airflow 2.4 _outlets is dropped and inlets used consistently. Earlier it was not the case, so we have to stick there to _outlets + # We have to use _outlets because outlets is empty in Airflow < 2.4.0 if hasattr(operator, "_outlets"): - return operator._outlets # type: ignore[attr-defined,union-attr] + return operator._outlets # type: ignore[attr-defined, union-attr] return operator.outlets @@ -57,24 +60,26 @@ def get_inlets_from_task(task: BaseOperator, context: Any) -> Iterable[Any]: # in Airflow 2.4. 
# TODO: ignore/handle airflow's dataset type in our lineage - inlets = [] - if isinstance(task._inlets, (str, BaseOperator)) or attr.has(task._inlets): # type: ignore + inlets: List[Any] = [] + task_inlets = _task_inlets(task) + # From Airflow 2.3 this should be AbstractOperator but due to compatibility reason lets use BaseOperator + if isinstance(task_inlets, (str, BaseOperator)): inlets = [ - task._inlets, + task_inlets, ] - if task._inlets and isinstance(task._inlets, list): + if task_inlets and isinstance(task_inlets, list): inlets = [] task_ids = ( - {o for o in task._inlets if isinstance(o, str)} - .union(op.task_id for op in task._inlets if isinstance(op, BaseOperator)) + {o for o in task_inlets if isinstance(o, str)} + .union(op.task_id for op in task_inlets if isinstance(op, BaseOperator)) .intersection(task.get_flat_relative_ids(upstream=True)) ) from airflow.lineage import AUTO # pick up unique direct upstream task_ids if AUTO is specified - if AUTO.upper() in task._inlets or AUTO.lower() in task._inlets: + if AUTO.upper() in task_inlets or AUTO.lower() in task_inlets: print("Picking up unique direct upstream task_ids as AUTO is specified") task_ids = task_ids.union( task_ids.symmetric_difference(task.upstream_task_ids) @@ -93,7 +98,7 @@ def get_inlets_from_task(task: BaseOperator, context: Any) -> Iterable[Any]: for item in sublist ] - for inlet in task._inlets: + for inlet in task_inlets: if type(inlet) != str: inlets.append(inlet) @@ -136,8 +141,8 @@ def datahub_task_status_callback(context, status): for inlet in inlets: datajob.inlets.append(inlet.urn) - # We have to use _outlets because outlets is empty - for outlet in task._outlets: + task_outlets = _task_outlets(task) + for outlet in task_outlets: datajob.outlets.append(outlet.urn) task.log.info(f"Emitted Datahub dataJob: {datajob}") @@ -197,7 +202,9 @@ def datahub_pre_execution(context): for inlet in inlets: datajob.inlets.append(inlet.urn) - for outlet in task._outlets: + task_outlets = _task_outlets(task) + + for outlet in task_outlets: datajob.outlets.append(outlet.urn) datajob.emit(emitter) From 4196e712c217a44f52f1843eeaebce12bbba83b7 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Fri, 11 Nov 2022 12:29:25 -0500 Subject: [PATCH 14/16] include sinks in airflow plugin --- metadata-ingestion-modules/airflow-plugin/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metadata-ingestion-modules/airflow-plugin/setup.py b/metadata-ingestion-modules/airflow-plugin/setup.py index 8c6338c114d88b..cd0162aa0e73c7 100644 --- a/metadata-ingestion-modules/airflow-plugin/setup.py +++ b/metadata-ingestion-modules/airflow-plugin/setup.py @@ -23,7 +23,7 @@ def get_long_description(): "typing-inspect", "pydantic>=1.5.1", "apache-airflow >= 2.0.2", - f"acryl-datahub[airflow] == {package_metadata['__version__']}", + f"acryl-datahub[airflow,datahub-rest,datahub-kafka] == {package_metadata['__version__']}", } From aa73660dc02deb9361b46f2a03c01eedcccd3de9 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Fri, 11 Nov 2022 12:30:28 -0500 Subject: [PATCH 15/16] Revert "include sinks in airflow plugin" This reverts commit 4196e712c217a44f52f1843eeaebce12bbba83b7. 
--- metadata-ingestion-modules/airflow-plugin/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metadata-ingestion-modules/airflow-plugin/setup.py b/metadata-ingestion-modules/airflow-plugin/setup.py index cd0162aa0e73c7..8c6338c114d88b 100644 --- a/metadata-ingestion-modules/airflow-plugin/setup.py +++ b/metadata-ingestion-modules/airflow-plugin/setup.py @@ -23,7 +23,7 @@ def get_long_description(): "typing-inspect", "pydantic>=1.5.1", "apache-airflow >= 2.0.2", - f"acryl-datahub[airflow,datahub-rest,datahub-kafka] == {package_metadata['__version__']}", + f"acryl-datahub[airflow] == {package_metadata['__version__']}", } From 6e91c8683bcd1ce535dba023ce4eb1f2721b499d Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Fri, 11 Nov 2022 12:31:36 -0500 Subject: [PATCH 16/16] include hooks in airflow setup.py --- .github/workflows/metadata-ingestion.yml | 2 +- metadata-ingestion/setup.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/metadata-ingestion.yml b/.github/workflows/metadata-ingestion.yml index 9e6a9da5bdd760..043a32a4ce06fd 100644 --- a/.github/workflows/metadata-ingestion.yml +++ b/.github/workflows/metadata-ingestion.yml @@ -42,7 +42,7 @@ jobs: - python-version: "3.7" extraPythonRequirement: "sqlalchemy==1.3.24 apache-airflow~=2.2.0" - python-version: "3.10" - extraPythonRequirement: "sqlalchemy~=1.4.0 apache-airflow~=2.4.0" + extraPythonRequirement: "sqlalchemy~=1.4.0 apache-airflow>=2.4.0" fail-fast: false steps: - uses: actions/checkout@v3 diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py index 76e5572d311f12..04af3e0dc1d180 100644 --- a/metadata-ingestion/setup.py +++ b/metadata-ingestion/setup.py @@ -55,6 +55,10 @@ def get_long_description(): "click-spinner", } +rest_common = { + "requests", +} + kafka_common = { # The confluent_kafka package provides a number of pre-built wheels for # various platforms and architectures. However, it does not provide wheels @@ -220,10 +224,12 @@ def get_long_description(): plugins: Dict[str, Set[str]] = { # Sink plugins. "datahub-kafka": kafka_common, - "datahub-rest": {"requests"}, + "datahub-rest": rest_common, # Integrations. "airflow": { "apache-airflow >= 2.0.2", + *rest_common, + *kafka_common, }, "circuit-breaker": { "gql>=3.3.0",