Skip to content

Commit

Permalink
Fix for failing serialization when Param was specified
Browse files Browse the repository at this point in the history
  • Loading branch information
treff7es committed Jul 11, 2022
1 parent 5bb7fe3 commit 00590dc
Showing 1 changed file with 10 additions and 24 deletions.
34 changes: 10 additions & 24 deletions metadata-ingestion/src/datahub_provider/client/airflow_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,22 +114,14 @@ def generate_dataflow(
:param capture_owner:
:return: DataFlow - Data generated dataflow
"""
from airflow.serialization.serialized_objects import SerializedDAG

id = dag.dag_id
orchestrator = "airflow"
description = f"{dag.description}\n\n{dag.doc_md or ''}"
data_flow = DataFlow(
cluster=cluster, id=id, orchestrator=orchestrator, description=description
)

flow_property_bag: Dict[str, str] = {
key: repr(value)
for (key, value) in SerializedDAG.serialize_dag(dag).items()
}
for key in dag.get_serialized_fields():
if key not in flow_property_bag:
flow_property_bag[key] = repr(getattr(dag, key))
flow_property_bag: Dict[str, str] = {}

allowed_flow_keys = [
"_access_control",
Expand All @@ -142,9 +134,10 @@ def generate_dataflow(
"tags",
"timezone",
]
flow_property_bag = {
k: v for (k, v) in flow_property_bag.items() if k in allowed_flow_keys
}

for key in allowed_flow_keys:
if hasattr(dag, key):
flow_property_bag[key] = repr(getattr(dag, key))

data_flow.properties = flow_property_bag
base_url = conf.get("webserver", "base_url")
Expand Down Expand Up @@ -191,21 +184,13 @@ def generate_datajob(
:param capture_tags: bool - whether to set tags automatically from airflow task
:return: DataJob - returns the generated DataJob object
"""
from airflow.serialization.serialized_objects import SerializedBaseOperator

dataflow_urn = DataFlowUrn.create_from_ids(
orchestrator="airflow", env=cluster, flow_id=dag.dag_id
)
datajob = DataJob(id=task.task_id, flow_urn=dataflow_urn)
datajob.description = AirflowGenerator._get_description(task)

job_property_bag: Dict[str, str] = {
key: repr(value)
for (key, value) in SerializedBaseOperator.serialize_operator(task).items()
}
for key in task.get_serialized_fields():
if key not in job_property_bag:
job_property_bag[key] = repr(getattr(task, key))
job_property_bag: Dict[str, str] = {}

allowed_task_keys = [
"_downstream_task_ids",
Expand All @@ -223,9 +208,10 @@ def generate_datajob(
"trigger_rule",
"wait_for_downstream",
]
job_property_bag = {
k: v for (k, v) in job_property_bag.items() if k in allowed_task_keys
}

for key in allowed_task_keys:
if hasattr(task, key):
job_property_bag[key] = repr(getattr(task, key))

datajob.properties = job_property_bag
base_url = conf.get("webserver", "base_url")
Expand Down

0 comments on commit 00590dc

Please sign in to comment.