From 261025500acafab75441d8431740d328e09ab172 Mon Sep 17 00:00:00 2001
From: Tamas Nemeth
Date: Tue, 12 Jul 2022 19:20:27 +0200
Subject: [PATCH] fix(airflow): fix for failing serialisation when Param was
 specified + support for external task sensor (#5368)

fixes #4546
---
 .../client/airflow_generator.py | 58 +++++++++++--------
 1 file changed, 33 insertions(+), 25 deletions(-)

diff --git a/metadata-ingestion/src/datahub_provider/client/airflow_generator.py b/metadata-ingestion/src/datahub_provider/client/airflow_generator.py
index b5c389d298969e..b7864ddb71ea60 100644
--- a/metadata-ingestion/src/datahub_provider/client/airflow_generator.py
+++ b/metadata-ingestion/src/datahub_provider/client/airflow_generator.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
 
 from airflow.configuration import conf
 
@@ -87,6 +87,27 @@ def _get_dependencies(
                 if subdag_task_id in upstream_task._downstream_task_ids:
                     upstream_subdag_triggers.append(upstream_task_urn)
 
+        # If the operator is an ExternalTaskSensor, then we set the remote task as upstream.
+        # It is possible to tie an external sensor to a DAG if external_task_id is omitted, but currently
+        # we can't tie a jobflow to another jobflow.
+        external_task_upstreams = []
+        if task.task_type == "ExternalTaskSensor":
+            from airflow.sensors.external_task_sensor import ExternalTaskSensor
+
+            task = cast(ExternalTaskSensor, task)
+            if hasattr(task, "external_task_id") and task.external_task_id is not None:
+                external_task_upstreams = [
+                    DataJobUrn.create_from_ids(
+                        job_id=task.external_task_id,
+                        data_flow_urn=str(
+                            DataFlowUrn.create_from_ids(
+                                orchestrator=flow_urn.get_orchestrator_name(),
+                                flow_id=task.external_dag_id,
+                                env=flow_urn.get_env(),
+                            )
+                        ),
+                    )
+                ]
         # exclude subdag operator tasks since these are not emitted, resulting in empty metadata
         upstream_tasks = (
             [
@@ -96,6 +117,7 @@ def _get_dependencies(
             ]
             + upstream_subdag_task_urns
             + upstream_subdag_triggers
+            + external_task_upstreams
         )
 
         return upstream_tasks
@@ -114,8 +136,6 @@ def generate_dataflow(
         :param capture_owner:
         :return: DataFlow - Data generated dataflow
         """
-        from airflow.serialization.serialized_objects import SerializedDAG
-
         id = dag.dag_id
         orchestrator = "airflow"
         description = f"{dag.description}\n\n{dag.doc_md or ''}"
@@ -123,13 +143,7 @@ def generate_dataflow(
         data_flow = DataFlow(
             cluster=cluster, id=id, orchestrator=orchestrator, description=description
         )
-        flow_property_bag: Dict[str, str] = {
-            key: repr(value)
-            for (key, value) in SerializedDAG.serialize_dag(dag).items()
-        }
-        for key in dag.get_serialized_fields():
-            if key not in flow_property_bag:
-                flow_property_bag[key] = repr(getattr(dag, key))
+        flow_property_bag: Dict[str, str] = {}
 
         allowed_flow_keys = [
             "_access_control",
@@ -142,9 +156,10 @@ def generate_dataflow(
             "tags",
             "timezone",
         ]
-        flow_property_bag = {
-            k: v for (k, v) in flow_property_bag.items() if k in allowed_flow_keys
-        }
+
+        for key in allowed_flow_keys:
+            if hasattr(dag, key):
+                flow_property_bag[key] = repr(getattr(dag, key))
 
         data_flow.properties = flow_property_bag
         base_url = conf.get("webserver", "base_url")
@@ -191,21 +206,13 @@ def generate_datajob(
         :param capture_tags: bool - whether to set tags automatically from airflow task
        :return: DataJob - returns the generated DataJob object
         """
-        from airflow.serialization.serialized_objects import SerializedBaseOperator
-
         dataflow_urn = DataFlowUrn.create_from_ids(
            orchestrator="airflow", env=cluster, flow_id=dag.dag_id
         )
 
         datajob = DataJob(id=task.task_id, flow_urn=dataflow_urn)
         datajob.description = AirflowGenerator._get_description(task)
-        job_property_bag: Dict[str, str] = {
-            key: repr(value)
-            for (key, value) in SerializedBaseOperator.serialize_operator(task).items()
-        }
-        for key in task.get_serialized_fields():
-            if key not in job_property_bag:
-                job_property_bag[key] = repr(getattr(task, key))
+        job_property_bag: Dict[str, str] = {}
 
         allowed_task_keys = [
             "_downstream_task_ids",
@@ -223,9 +230,10 @@ def generate_datajob(
             "trigger_rule",
             "wait_for_downstream",
         ]
-        job_property_bag = {
-            k: v for (k, v) in job_property_bag.items() if k in allowed_task_keys
-        }
+
+        for key in allowed_task_keys:
+            if hasattr(task, key):
+                job_property_bag[key] = repr(getattr(task, key))
 
         datajob.properties = job_property_bag
         base_url = conf.get("webserver", "base_url")
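
Reviewer note, not part of the patch: a minimal sketch of a DAG that exercises both fixes, assuming Airflow 2.x; every DAG and task id below is hypothetical. A Param on the DAG previously made property capture fail because the properties were built through SerializedDAG.serialize_dag / SerializedBaseOperator.serialize_operator, and the sensor's external_dag_id/external_task_id pair is what _get_dependencies now resolves into an upstream DataJobUrn.

# Hypothetical example (not from the repository): a DAG with a Param and an
# ExternalTaskSensor, the two cases this patch addresses.
from datetime import datetime

from airflow import DAG
from airflow.models.param import Param
from airflow.operators.dummy import DummyOperator
from airflow.sensors.external_task_sensor import ExternalTaskSensor

with DAG(
    dag_id="reporting",  # hypothetical dag_id
    start_date=datetime(2022, 1, 1),
    schedule_interval="@daily",
    # A Param on the DAG no longer breaks emission of the flow properties.
    params={"lookback_days": Param(7, type="integer")},
) as dag:
    # Sensor on a task in another DAG; _get_dependencies maps this to an upstream
    # DataJobUrn for dag_id "ingestion", task_id "load_to_warehouse".
    wait_for_ingest = ExternalTaskSensor(
        task_id="wait_for_ingest",
        external_dag_id="ingestion",  # hypothetical upstream DAG id
        external_task_id="load_to_warehouse",  # hypothetical upstream task id
    )
    build_report = DummyOperator(task_id="build_report")
    wait_for_ingest >> build_report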