fix(airflow): fix for failing serialisation when Param was specified + support for external task sensor (datahub-project#5368)

fixes datahub-project#4546
treff7es authored and maggiehays committed Aug 1, 2022
1 parent a151f19 commit 2610255
Showing 1 changed file with 33 additions and 25 deletions.
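For context, here is a minimal sketch of the failure mode this commit addresses: a DAG that declares a Param default. All names and values below are illustrative, assuming Airflow 2.x; before this change the generator round-tripped the DAG through SerializedDAG.serialize_dag() (see the deleted lines below), which is where emission broke (datahub-project#4546).

# Illustrative repro only; dag/task names are made up.
from datetime import datetime

from airflow import DAG
from airflow.models.param import Param
from airflow.operators.dummy import DummyOperator

with DAG(
    dag_id="example_dag_with_param",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    # Declaring a Param default is the trigger: the old code pushed the
    # whole DAG through Airflow's serializer before building properties.
    params={"table_name": Param("my_table", type="string")},
) as dag:
    noop = DummyOperator(task_id="noop")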
metadata-ingestion/src/datahub_provider/client/airflow_generator.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
 
 from airflow.configuration import conf
 
@@ -87,6 +87,27 @@ def _get_dependencies(
                 if subdag_task_id in upstream_task._downstream_task_ids:
                     upstream_subdag_triggers.append(upstream_task_urn)
 
+        # If the operator is an ExternalTaskSensor then we set the remote task as upstream.
+        # It is possible to tie an external sensor to a DAG if external_task_id is omitted, but currently we
+        # can't tie a jobflow to another jobflow.
+        external_task_upstreams = []
+        if task.task_type == "ExternalTaskSensor":
+            from airflow.sensors.external_task_sensor import ExternalTaskSensor
+
+            task = cast(ExternalTaskSensor, task)
+            if hasattr(task, "external_task_id") and task.external_task_id is not None:
+                external_task_upstreams = [
+                    DataJobUrn.create_from_ids(
+                        job_id=task.external_task_id,
+                        data_flow_urn=str(
+                            DataFlowUrn.create_from_ids(
+                                orchestrator=flow_urn.get_orchestrator_name(),
+                                flow_id=task.external_dag_id,
+                                env=flow_urn.get_env(),
+                            )
+                        ),
+                    )
+                ]
         # exclude subdag operator tasks since these are not emitted, resulting in empty metadata
         upstream_tasks = (
             [
@@ -96,6 +117,7 @@ def _get_dependencies(
             ]
             + upstream_subdag_task_urns
             + upstream_subdag_triggers
+            + external_task_upstreams
         )
         return upstream_tasks
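The new branch above only fires when the sensor's external_task_id is set; as the added comment notes, a sensor pointed at a whole DAG is skipped because a jobflow-to-jobflow edge can't be expressed yet. A hedged sketch of a sensor this change now links, with the rough shape of the upstream urn it should resolve to (all ids and the env are illustrative):

# Illustrative only; dag ids and task ids are made up.
from airflow.sensors.external_task_sensor import ExternalTaskSensor

wait_for_load = ExternalTaskSensor(
    task_id="wait_for_load",
    external_dag_id="upstream_etl",  # the flow being waited on
    external_task_id="load_table",   # required for this linkage to apply
)

# _get_dependencies should now emit an upstream DataJobUrn roughly like:
#   urn:li:dataJob:(urn:li:dataFlow:(airflow,upstream_etl,prod),load_table)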

@@ -114,22 +136,14 @@ def generate_dataflow(
         :param capture_owner:
         :return: DataFlow - Data generated dataflow
         """
-        from airflow.serialization.serialized_objects import SerializedDAG
-
         id = dag.dag_id
         orchestrator = "airflow"
         description = f"{dag.description}\n\n{dag.doc_md or ''}"
         data_flow = DataFlow(
             cluster=cluster, id=id, orchestrator=orchestrator, description=description
         )
 
-        flow_property_bag: Dict[str, str] = {
-            key: repr(value)
-            for (key, value) in SerializedDAG.serialize_dag(dag).items()
-        }
-        for key in dag.get_serialized_fields():
-            if key not in flow_property_bag:
-                flow_property_bag[key] = repr(getattr(dag, key))
+        flow_property_bag: Dict[str, str] = {}
 
         allowed_flow_keys = [
             "_access_control",
@@ -142,9 +156,10 @@
             "tags",
             "timezone",
         ]
-        flow_property_bag = {
-            k: v for (k, v) in flow_property_bag.items() if k in allowed_flow_keys
-        }
+
+        for key in allowed_flow_keys:
+            if hasattr(dag, key):
+                flow_property_bag[key] = repr(getattr(dag, key))
 
         data_flow.properties = flow_property_bag
         base_url = conf.get("webserver", "base_url")
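The net effect of the rewrite above is that DAG attributes are stringified one by one with repr() instead of being pulled from a full SerializedDAG round trip, and the hasattr() guard keeps the allowlist safe across Airflow versions. A sketch of what the resulting property bag might look like (keys abridged, values illustrative):

# Illustrative contents only; exact keys depend on the Airflow version
# (the hasattr() guard skips missing ones) and exact values on the DAG.
# Every value is a repr() string, so str attributes keep inner quotes.
data_flow.properties = {
    "fileloc": "'/opt/airflow/dags/example_dag_with_param.py'",
    "start_date": "DateTime(2022, 1, 1, 0, 0, 0, tzinfo=Timezone('UTC'))",
    "tags": "None",
    "timezone": "Timezone('UTC')",
}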
@@ -191,21 +206,13 @@ def generate_datajob(
         :param capture_tags: bool - whether to set tags automatically from airflow task
         :return: DataJob - returns the generated DataJob object
         """
-        from airflow.serialization.serialized_objects import SerializedBaseOperator
-
         dataflow_urn = DataFlowUrn.create_from_ids(
             orchestrator="airflow", env=cluster, flow_id=dag.dag_id
         )
         datajob = DataJob(id=task.task_id, flow_urn=dataflow_urn)
         datajob.description = AirflowGenerator._get_description(task)
 
-        job_property_bag: Dict[str, str] = {
-            key: repr(value)
-            for (key, value) in SerializedBaseOperator.serialize_operator(task).items()
-        }
-        for key in task.get_serialized_fields():
-            if key not in job_property_bag:
-                job_property_bag[key] = repr(getattr(task, key))
+        job_property_bag: Dict[str, str] = {}
 
         allowed_task_keys = [
             "_downstream_task_ids",
@@ -223,9 +230,10 @@
             "trigger_rule",
             "wait_for_downstream",
         ]
-        job_property_bag = {
-            k: v for (k, v) in job_property_bag.items() if k in allowed_task_keys
-        }
+
+        for key in allowed_task_keys:
+            if hasattr(task, key):
+                job_property_bag[key] = repr(getattr(task, key))
 
         datajob.properties = job_property_bag
         base_url = conf.get("webserver", "base_url")
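The task-side change mirrors the flow-side one: only allowlisted operator attributes survive, each flattened with repr(), so neither a Param nor any other non-serialisable attribute reaches Airflow's serializer. An end-to-end sketch of how the two generators might be driven; the keyword arguments are assumed from the docstrings above, not verified against the full file:

# Illustrative wiring only; signatures are assumptions.
from datahub_provider.client.airflow_generator import AirflowGenerator

data_flow = AirflowGenerator.generate_dataflow(cluster="prod", dag=dag)
datajob = AirflowGenerator.generate_datajob(cluster="prod", task=wait_for_load, dag=dag)

# Both property bags now hold only allowlisted, repr()-flattened values.
print(data_flow.properties.get("fileloc"))     # e.g. "'/opt/airflow/dags/...'"
print(datajob.properties.get("trigger_rule"))  # e.g. "'all_success'"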
