diff --git a/metadata-ingestion/src/datahub_provider/client/airflow_generator.py b/metadata-ingestion/src/datahub_provider/client/airflow_generator.py index b5c389d298969e..8f8096b62074ec 100644 --- a/metadata-ingestion/src/datahub_provider/client/airflow_generator.py +++ b/metadata-ingestion/src/datahub_provider/client/airflow_generator.py @@ -114,8 +114,6 @@ def generate_dataflow( :param capture_owner: :return: DataFlow - Data generated dataflow """ - from airflow.serialization.serialized_objects import SerializedDAG - id = dag.dag_id orchestrator = "airflow" description = f"{dag.description}\n\n{dag.doc_md or ''}" @@ -123,13 +121,7 @@ def generate_dataflow( cluster=cluster, id=id, orchestrator=orchestrator, description=description ) - flow_property_bag: Dict[str, str] = { - key: repr(value) - for (key, value) in SerializedDAG.serialize_dag(dag).items() - } - for key in dag.get_serialized_fields(): - if key not in flow_property_bag: - flow_property_bag[key] = repr(getattr(dag, key)) + flow_property_bag: Dict[str, str] = {} allowed_flow_keys = [ "_access_control", @@ -142,9 +134,10 @@ def generate_dataflow( "tags", "timezone", ] - flow_property_bag = { - k: v for (k, v) in flow_property_bag.items() if k in allowed_flow_keys - } + + for key in allowed_flow_keys: + if hasattr(dag, key): + flow_property_bag[key] = repr(getattr(dag, key)) data_flow.properties = flow_property_bag base_url = conf.get("webserver", "base_url") @@ -191,21 +184,13 @@ def generate_datajob( :param capture_tags: bool - whether to set tags automatically from airflow task :return: DataJob - returns the generated DataJob object """ - from airflow.serialization.serialized_objects import SerializedBaseOperator - dataflow_urn = DataFlowUrn.create_from_ids( orchestrator="airflow", env=cluster, flow_id=dag.dag_id ) datajob = DataJob(id=task.task_id, flow_urn=dataflow_urn) datajob.description = AirflowGenerator._get_description(task) - job_property_bag: Dict[str, str] = { - key: repr(value) - for (key, value) in SerializedBaseOperator.serialize_operator(task).items() - } - for key in task.get_serialized_fields(): - if key not in job_property_bag: - job_property_bag[key] = repr(getattr(task, key)) + job_property_bag: Dict[str, str] = {} allowed_task_keys = [ "_downstream_task_ids", @@ -223,9 +208,10 @@ def generate_datajob( "trigger_rule", "wait_for_downstream", ] - job_property_bag = { - k: v for (k, v) in job_property_bag.items() if k in allowed_task_keys - } + + for key in allowed_task_keys: + if hasattr(task, key): + job_property_bag[key] = repr(getattr(task, key)) datajob.properties = job_property_bag base_url = conf.get("webserver", "base_url")