From 1ed014647e7293d342d9d1c2706343a68f003655 Mon Sep 17 00:00:00 2001 From: Josh Fell <48934154+josh-fell@users.noreply.github.com> Date: Sun, 28 Aug 2022 13:07:38 -0400 Subject: [PATCH] Add `output` property to MappedOperator (#25604) Co-authored-by: Tzu-ping Chung --- airflow/example_dags/example_xcom.py | 9 ++-- airflow/models/baseoperator.py | 3 +- airflow/models/mappedoperator.py | 7 +++ .../amazon/aws/example_dags/example_dms.py | 13 +++--- .../amazon/aws/example_dags/example_ecs.py | 9 ++-- .../amazon/aws/example_dags/example_emr.py | 16 +++---- .../example_automl_nl_text_classification.py | 6 ++- .../example_automl_nl_text_sentiment.py | 6 ++- .../example_dags/example_automl_tables.py | 9 ++-- .../example_automl_translation.py | 6 ++- ...utoml_video_intelligence_classification.py | 6 ++- ...mple_automl_video_intelligence_tracking.py | 6 ++- .../example_automl_vision_object_detection.py | 6 ++- .../example_dags/example_bigquery_dts.py | 6 ++- .../cloud/example_dags/example_datafusion.py | 3 +- .../cloud/example_dags/example_looker.py | 3 +- .../cloud/example_dags/example_vision.py | 5 ++- .../example_dags/example_display_video.py | 5 ++- tests/models/test_mappedoperator.py | 45 +++++++++++++++++++ .../airbyte/example_airbyte_trigger_job.py | 3 +- .../providers/amazon/aws/example_athena.py | 3 +- .../providers/amazon/aws/example_batch.py | 3 +- .../amazon/aws/example_emr_serverless.py | 11 +++-- .../providers/amazon/aws/example_glue.py | 8 ++-- .../amazon/aws/example_step_functions.py | 9 ++-- .../providers/dbt/cloud/example_dbt_cloud.py | 5 ++- .../example_automl_nl_text_extraction.py | 6 ++- .../example_automl_vision_classification.py | 6 ++- .../cloud/cloud_build/example_cloud_build.py | 13 +++--- .../example_cloud_build_trigger.py | 13 +++--- .../cloud/cloud_sql/example_cloud_sql.py | 7 ++- .../dataproc/example_dataproc_spark_async.py | 3 +- .../datastore/example_datastore_rollback.py | 4 +- .../google/cloud/gcs/example_sheets.py | 3 +- .../google/cloud/pubsub/example_pubsub.py | 3 +- .../cloud/workflows/example_workflows.py | 6 ++- .../example_datacatalog_entries.py | 5 ++- .../example_datacatalog_search_catalog.py | 14 +++--- .../example_datacatalog_tag_templates.py | 5 ++- .../datacatalog/example_datacatalog_tags.py | 16 ++++--- .../example_campaign_manager.py | 6 ++- .../marketing_platform/example_search_ads.py | 4 +- .../azure/example_adf_run_pipeline.py | 4 +- .../providers/tableau/example_tableau.py | 3 +- .../example_tableau_refresh_workbook.py | 3 +- 45 files changed, 227 insertions(+), 108 deletions(-) diff --git a/airflow/example_dags/example_xcom.py b/airflow/example_dags/example_xcom.py index fa9cb685caa12..971c8ff58af76 100644 --- a/airflow/example_dags/example_xcom.py +++ b/airflow/example_dags/example_xcom.py @@ -19,7 +19,7 @@ """Example DAG demonstrating the usage of XComs.""" import pendulum -from airflow import DAG +from airflow import DAG, XComArg from airflow.decorators import task from airflow.operators.bash import BashOperator @@ -79,8 +79,8 @@ def pull_value_from_bash_push(ti=None): bash_pull = BashOperator( task_id='bash_pull', bash_command='echo "bash pull demo" && ' - f'echo "The xcom pushed manually is {bash_push.output["manually_pushed_value"]}" && ' - f'echo "The returned_value xcom is {bash_push.output}" && ' + f'echo "The xcom pushed manually is {XComArg(bash_push, key="manually_pushed_value")}" && ' + f'echo "The returned_value xcom is {XComArg(bash_push)}" && ' 'echo "finished"', do_xcom_push=False, ) @@ -90,6 +90,3 @@ def 
pull_value_from_bash_push(ti=None): [bash_pull, python_pull_from_bash] << bash_push puller(push_by_returning()) << push() - - # Task dependencies created via `XComArgs`: - # pull << push2 diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 77a92b4924322..7578d75efc3e4 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -98,6 +98,7 @@ from airflow.models.dag import DAG from airflow.models.taskinstance import TaskInstanceKey + from airflow.models.xcom_arg import XComArg from airflow.utils.task_group import TaskGroup ScheduleInterval = Union[str, timedelta, relativedelta] @@ -1365,7 +1366,7 @@ def leaves(self) -> List["BaseOperator"]: return [self] @property - def output(self): + def output(self) -> "XComArg": """Returns reference to XCom pushed by current operator""" from airflow.models.xcom_arg import XComArg diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 384799cc6cb21..82d7eea870161 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -530,6 +530,13 @@ def get_dag(self) -> Optional["DAG"]: """Implementing Operator.""" return self.dag + @property + def output(self) -> "XComArg": + """Returns reference to XCom pushed by current operator""" + from airflow.models.xcom_arg import XComArg + + return XComArg(operator=self) + def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]: """Implementing DAGNode.""" return DagAttributeTypes.OP, self.task_id diff --git a/airflow/providers/amazon/aws/example_dags/example_dms.py b/airflow/providers/amazon/aws/example_dags/example_dms.py index 77bcbe3ef22aa..0ce8a2e5e2053 100644 --- a/airflow/providers/amazon/aws/example_dags/example_dms.py +++ b/airflow/providers/amazon/aws/example_dags/example_dms.py @@ -23,6 +23,7 @@ import json import os from datetime import datetime +from typing import cast import boto3 from sqlalchemy import Column, MetaData, String, Table, create_engine @@ -256,10 +257,12 @@ def delete_dms_assets(): ) # [END howto_operator_dms_create_task] + task_arn = cast(str, create_task.output) + # [START howto_operator_dms_start_task] start_task = DmsStartTaskOperator( task_id='start_task', - replication_task_arn=create_task.output, + replication_task_arn=task_arn, ) # [END howto_operator_dms_start_task] @@ -280,7 +283,7 @@ def delete_dms_assets(): await_task_start = DmsTaskBaseSensor( task_id='await_task_start', - replication_task_arn=create_task.output, + replication_task_arn=task_arn, target_statuses=['running'], termination_statuses=['stopped', 'deleting', 'failed'], ) @@ -288,7 +291,7 @@ def delete_dms_assets(): # [START howto_operator_dms_stop_task] stop_task = DmsStopTaskOperator( task_id='stop_task', - replication_task_arn=create_task.output, + replication_task_arn=task_arn, ) # [END howto_operator_dms_stop_task] @@ -296,14 +299,14 @@ def delete_dms_assets(): # [START howto_sensor_dms_task_completed] await_task_stop = DmsTaskCompletedSensor( task_id='await_task_stop', - replication_task_arn=create_task.output, + replication_task_arn=task_arn, ) # [END howto_sensor_dms_task_completed] # [START howto_operator_dms_delete_task] delete_task = DmsDeleteTaskOperator( task_id='delete_task', - replication_task_arn=create_task.output, + replication_task_arn=task_arn, trigger_rule='all_done', ) # [END howto_operator_dms_delete_task] diff --git a/airflow/providers/amazon/aws/example_dags/example_ecs.py b/airflow/providers/amazon/aws/example_dags/example_ecs.py index 3f14c11c76970..a2d7ca2f6c7e0 100644 
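The change above is the heart of the PR: MappedOperator gains the same `output` property that BaseOperator already exposes, returning XComArg(operator=self), so a dynamically mapped task can be referenced downstream like any classic task. A minimal sketch of what this enables, with a hypothetical DAG and task ids, assuming Airflow 2.3+ with dynamic task mapping:

from datetime import datetime

from airflow import DAG
from airflow.decorators import task
from airflow.models.xcom_arg import XComArg
from airflow.operators.bash import BashOperator

with DAG(dag_id="mapped_output_sketch", start_date=datetime(2022, 1, 1), schedule_interval=None):
    # .expand() produces a MappedOperator; before this change it had no .output property.
    greet = BashOperator.partial(task_id="greet").expand(
        bash_command=["echo a", "echo b", "echo c"]
    )

    # The new property mirrors BaseOperator.output and is equal to XComArg(greet).
    assert greet.output == XComArg(greet)

    @task
    def summarize(results):
        # At runtime this resolves to the XComs pushed by all mapped task instances.
        print(list(results))

    # Passing the XComArg also creates the greet >> summarize dependency automatically.
    summarize(greet.output)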
--- a/airflow/providers/amazon/aws/example_dags/example_ecs.py +++ b/airflow/providers/amazon/aws/example_dags/example_ecs.py @@ -16,6 +16,7 @@ # under the License. from datetime import datetime +from typing import cast from airflow import DAG from airflow.models.baseoperator import chain @@ -99,10 +100,12 @@ ) # [END howto_operator_ecs_register_task_definition] + registered_task_definition = cast(str, register_task.output) + # [START howto_sensor_ecs_task_definition_state] await_task_definition = EcsTaskDefinitionStateSensor( task_id='await_task_definition', - task_definition=register_task.output, + task_definition=registered_task_definition, ) # [END howto_sensor_ecs_task_definition_state] @@ -110,7 +113,7 @@ run_task = EcsRunTaskOperator( task_id="run_task", cluster=EXISTING_CLUSTER_NAME, - task_definition=register_task.output, + task_definition=registered_task_definition, launch_type="EC2", overrides={ "containerOverrides": [ @@ -156,7 +159,7 @@ deregister_task = EcsDeregisterTaskDefinitionOperator( task_id='deregister_task', trigger_rule=TriggerRule.ALL_DONE, - task_definition=register_task.output, + task_definition=registered_task_definition, ) # [END howto_operator_ecs_deregister_task_definition] diff --git a/airflow/providers/amazon/aws/example_dags/example_emr.py b/airflow/providers/amazon/aws/example_dags/example_emr.py index 9ddeb5498d2fe..6390608e5067d 100644 --- a/airflow/providers/amazon/aws/example_dags/example_emr.py +++ b/airflow/providers/amazon/aws/example_dags/example_emr.py @@ -17,6 +17,7 @@ # under the License. import os from datetime import datetime +from typing import cast from airflow import DAG from airflow.models.baseoperator import chain @@ -79,23 +80,22 @@ ) # [END howto_operator_emr_create_job_flow] + job_flow_id = cast(str, job_flow_creator.output) + # [START howto_sensor_emr_job_flow] - job_sensor = EmrJobFlowSensor( - task_id='check_job_flow', - job_flow_id=job_flow_creator.output, - ) + job_sensor = EmrJobFlowSensor(task_id='check_job_flow', job_flow_id=job_flow_id) # [END howto_sensor_emr_job_flow] # [START howto_operator_emr_modify_cluster] cluster_modifier = EmrModifyClusterOperator( - task_id='modify_cluster', cluster_id=job_flow_creator.output, step_concurrency_level=1 + task_id='modify_cluster', cluster_id=job_flow_id, step_concurrency_level=1 ) # [END howto_operator_emr_modify_cluster] # [START howto_operator_emr_add_steps] step_adder = EmrAddStepsOperator( task_id='add_steps', - job_flow_id=job_flow_creator.output, + job_flow_id=job_flow_id, steps=SPARK_STEPS, ) # [END howto_operator_emr_add_steps] @@ -103,7 +103,7 @@ # [START howto_sensor_emr_step] step_checker = EmrStepSensor( task_id='watch_step', - job_flow_id=job_flow_creator.output, + job_flow_id=job_flow_id, step_id="{{ task_instance.xcom_pull(task_ids='add_steps', key='return_value')[0] }}", ) # [END howto_sensor_emr_step] @@ -111,7 +111,7 @@ # [START howto_operator_emr_terminate_job_flow] cluster_remover = EmrTerminateJobFlowOperator( task_id='remove_cluster', - job_flow_id=job_flow_creator.output, + job_flow_id=job_flow_id, ) # [END howto_operator_emr_terminate_job_flow] diff --git a/airflow/providers/google/cloud/example_dags/example_automl_nl_text_classification.py b/airflow/providers/google/cloud/example_dags/example_automl_nl_text_classification.py index cfa1eb17c88b8..82c96ec578c52 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_nl_text_classification.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_nl_text_classification.py @@ -21,8 +21,10 
@@ """ import os from datetime import datetime +from typing import cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( AutoMLCreateDatasetOperator, @@ -67,7 +69,7 @@ task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = create_dataset_task.output['dataset_id'] + dataset_id = cast(str, XComArg(create_dataset_task, key='dataset_id')) import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -80,7 +82,7 @@ create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) - model_id = create_model.output['model_id'] + model_id = cast(str, XComArg(create_model, key='model_id')) delete_model_task = AutoMLDeleteModelOperator( task_id="delete_model_task", diff --git a/airflow/providers/google/cloud/example_dags/example_automl_nl_text_sentiment.py b/airflow/providers/google/cloud/example_dags/example_automl_nl_text_sentiment.py index a85ce801417e4..07236ee75ee6a 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_nl_text_sentiment.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_nl_text_sentiment.py @@ -21,8 +21,10 @@ """ import os from datetime import datetime +from typing import cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( AutoMLCreateDatasetOperator, @@ -68,7 +70,7 @@ task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = create_dataset_task.output['dataset_id'] + dataset_id = cast(str, XComArg(create_dataset_task, key='dataset_id')) import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -81,7 +83,7 @@ create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) - model_id = create_model.output['model_id'] + model_id = cast(str, XComArg(create_model, key='model_id')) delete_model_task = AutoMLDeleteModelOperator( task_id="delete_model_task", diff --git a/airflow/providers/google/cloud/example_dags/example_automl_tables.py b/airflow/providers/google/cloud/example_dags/example_automl_tables.py index 53de60ca36a3b..30cfdb0baaab0 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_tables.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_tables.py @@ -22,9 +22,10 @@ import os from copy import deepcopy from datetime import datetime -from typing import Dict, List +from typing import Dict, List, cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( AutoMLBatchPredictOperator, @@ -103,7 +104,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: project_id=GCP_PROJECT_ID, ) - dataset_id = create_dataset_task.output['dataset_id'] + dataset_id = cast(str, XComArg(create_dataset_task, key='dataset_id')) # [END howto_operator_automl_create_dataset] MODEL["dataset_id"] = dataset_id @@ -158,7 +159,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: project_id=GCP_PROJECT_ID, ) - model_id = create_model_task.output['model_id'] + model_id = cast(str, XComArg(create_model_task, 
key='model_id')) # [END howto_operator_automl_create_model] # [START howto_operator_automl_delete_model] @@ -209,7 +210,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: project_id=GCP_PROJECT_ID, ) - dataset_id = create_dataset_task2.output['dataset_id'] + dataset_id = cast(str, XComArg(create_dataset_task2, key='dataset_id')) import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", diff --git a/airflow/providers/google/cloud/example_dags/example_automl_translation.py b/airflow/providers/google/cloud/example_dags/example_automl_translation.py index 96370c7b9564d..a4274c0ff3df3 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_translation.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_translation.py @@ -21,8 +21,10 @@ """ import os from datetime import datetime +from typing import cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( AutoMLCreateDatasetOperator, @@ -74,7 +76,7 @@ task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = create_dataset_task.output["dataset_id"] + dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id")) import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -87,7 +89,7 @@ create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) - model_id = create_model.output["model_id"] + model_id = cast(str, XComArg(create_model, key="model_id")) delete_model_task = AutoMLDeleteModelOperator( task_id="delete_model_task", diff --git a/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_classification.py b/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_classification.py index f98be55bb0e39..2fd10578c7b10 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_classification.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_classification.py @@ -21,8 +21,10 @@ """ import os from datetime import datetime +from typing import cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( AutoMLCreateDatasetOperator, @@ -71,7 +73,7 @@ task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = create_dataset_task.output["dataset_id"] + dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id")) import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -84,7 +86,7 @@ create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) - model_id = create_model.output["model_id"] + model_id = cast(str, XComArg(create_model, key="model_id")) delete_model_task = AutoMLDeleteModelOperator( task_id="delete_model_task", diff --git a/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_tracking.py b/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_tracking.py index 6c3aacdf3b62a..21c099ef082d5 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_tracking.py +++ 
b/airflow/providers/google/cloud/example_dags/example_automl_video_intelligence_tracking.py @@ -21,8 +21,10 @@ """ import os from datetime import datetime +from typing import cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( AutoMLCreateDatasetOperator, @@ -72,7 +74,7 @@ task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = create_dataset_task.output["dataset_id"] + dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id")) import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -85,7 +87,7 @@ create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) - model_id = create_model.output["model_id"] + model_id = cast(str, XComArg(create_model, key="model_id")) delete_model_task = AutoMLDeleteModelOperator( task_id="delete_model_task", diff --git a/airflow/providers/google/cloud/example_dags/example_automl_vision_object_detection.py b/airflow/providers/google/cloud/example_dags/example_automl_vision_object_detection.py index 5cf0299e823c6..8ab472b322d1f 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_vision_object_detection.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_vision_object_detection.py @@ -21,8 +21,10 @@ """ import os from datetime import datetime +from typing import cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( AutoMLCreateDatasetOperator, @@ -71,7 +73,7 @@ task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = create_dataset_task.output["dataset_id"] + dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id")) import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -84,7 +86,7 @@ create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) - model_id = create_model.output["model_id"] + model_id = cast(str, XComArg(create_model, key="model_id")) delete_model_task = AutoMLDeleteModelOperator( task_id="delete_model_task", diff --git a/airflow/providers/google/cloud/example_dags/example_bigquery_dts.py b/airflow/providers/google/cloud/example_dags/example_bigquery_dts.py index 06380827f5450..0a53682f36128 100644 --- a/airflow/providers/google/cloud/example_dags/example_bigquery_dts.py +++ b/airflow/providers/google/cloud/example_dags/example_bigquery_dts.py @@ -22,8 +22,10 @@ import os import time from datetime import datetime +from typing import cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.operators.bigquery_dts import ( BigQueryCreateDataTransferOperator, BigQueryDataTransferServiceStartTransferRunsOperator, @@ -75,7 +77,7 @@ task_id="gcp_bigquery_create_transfer", ) - transfer_config_id = gcp_bigquery_create_transfer.output["transfer_config_id"] + transfer_config_id = cast(str, XComArg(gcp_bigquery_create_transfer, key="transfer_config_id")) # [END howto_bigquery_create_data_transfer] # [START howto_bigquery_start_transfer] @@ -90,7 +92,7 @@ gcp_run_sensor = BigQueryDataTransferServiceTransferRunSensor( task_id="gcp_run_sensor", transfer_config_id=transfer_config_id, - 
run_id=gcp_bigquery_start_transfer.output["run_id"], + run_id=cast(str, XComArg(gcp_bigquery_start_transfer, key="run_id")), expected_statuses={"SUCCEEDED"}, ) # [END howto_bigquery_dts_sensor] diff --git a/airflow/providers/google/cloud/example_dags/example_datafusion.py b/airflow/providers/google/cloud/example_dags/example_datafusion.py index afadb1926279e..5160b28299c23 100644 --- a/airflow/providers/google/cloud/example_dags/example_datafusion.py +++ b/airflow/providers/google/cloud/example_dags/example_datafusion.py @@ -20,6 +20,7 @@ """ import os from datetime import datetime +from typing import cast from airflow import models from airflow.operators.bash import BashOperator @@ -221,7 +222,7 @@ start_pipeline_sensor = CloudDataFusionPipelineStateSensor( task_id="pipeline_state_sensor", pipeline_name=PIPELINE_NAME, - pipeline_id=start_pipeline_async.output, + pipeline_id=cast(str, start_pipeline_async.output), expected_statuses=["COMPLETED"], failure_statuses=["FAILED"], instance_name=INSTANCE_NAME, diff --git a/airflow/providers/google/cloud/example_dags/example_looker.py b/airflow/providers/google/cloud/example_dags/example_looker.py index ac583c6a1cd3f..86098f0afdf7f 100644 --- a/airflow/providers/google/cloud/example_dags/example_looker.py +++ b/airflow/providers/google/cloud/example_dags/example_looker.py @@ -21,6 +21,7 @@ """ from datetime import datetime +from typing import cast from airflow import models from airflow.providers.google.cloud.operators.looker import LookerStartPdtBuildOperator @@ -43,7 +44,7 @@ check_pdt_task_async_sensor = LookerCheckPdtBuildSensor( task_id='check_pdt_task_async_sensor', looker_conn_id='your_airflow_connection_for_looker', - materialization_id=start_pdt_task_async.output, + materialization_id=cast(str, start_pdt_task_async.output), poke_interval=10, ) # [END cloud_looker_async_start_pdt_sensor] diff --git a/airflow/providers/google/cloud/example_dags/example_vision.py b/airflow/providers/google/cloud/example_dags/example_vision.py index 24bc094e02845..7369c073ec19a 100644 --- a/airflow/providers/google/cloud/example_dags/example_vision.py +++ b/airflow/providers/google/cloud/example_dags/example_vision.py @@ -33,6 +33,7 @@ import os from datetime import datetime +from typing import cast from airflow import models from airflow.operators.bash import BashOperator @@ -135,7 +136,7 @@ ) # [END howto_operator_vision_product_set_create] - product_set_create_output = product_set_create.output + product_set_create_output = cast(str, product_set_create.output) # [START howto_operator_vision_product_set_get] product_set_get = CloudVisionGetProductSetOperator( @@ -172,7 +173,7 @@ ) # [END howto_operator_vision_product_create] - product_create_output = product_create.output + product_create_output = cast(str, product_create.output) # [START howto_operator_vision_product_get] product_get = CloudVisionGetProductOperator( diff --git a/airflow/providers/google/marketing_platform/example_dags/example_display_video.py b/airflow/providers/google/marketing_platform/example_dags/example_display_video.py index 4ddedf8f35826..1a48a2361783e 100644 --- a/airflow/providers/google/marketing_platform/example_dags/example_display_video.py +++ b/airflow/providers/google/marketing_platform/example_dags/example_display_video.py @@ -20,9 +20,10 @@ """ import os from datetime import datetime -from typing import Dict +from typing import Dict, cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.transfers.gcs_to_bigquery 
import GCSToBigQueryOperator from airflow.providers.google.marketing_platform.hooks.display_video import GoogleDisplayVideo360Hook from airflow.providers.google.marketing_platform.operators.display_video import ( @@ -91,7 +92,7 @@ ) as dag1: # [START howto_google_display_video_createquery_report_operator] create_report = GoogleDisplayVideo360CreateReportOperator(body=REPORT, task_id="create_report") - report_id = create_report.output["report_id"] + report_id = cast(str, XComArg(create_report, key="report_id")) # [END howto_google_display_video_createquery_report_operator] # [START howto_google_display_video_runquery_report_operator] diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index cf80e6ede4f85..c797cc06b273b 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -434,3 +434,48 @@ def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, ses assert isinstance(op, MockOperator) assert op.arg1 == expected assert op.arg2 == "a" + + +def test_xcomarg_property_of_mapped_operator(dag_maker): + with dag_maker("test_xcomarg_property_of_mapped_operator"): + op_a = MockOperator.partial(task_id="a").expand(arg1=["x", "y", "z"]) + dag_maker.create_dagrun() + + assert op_a.output == XComArg(op_a) + + +def test_set_xcomarg_dependencies_with_mapped_operator(dag_maker): + with dag_maker("test_set_xcomargs_dependencies_with_mapped_operator"): + op1 = MockOperator.partial(task_id="op1").expand(arg1=[1, 2, 3]) + op2 = MockOperator.partial(task_id="op2").expand(arg2=["a", "b", "c"]) + op3 = MockOperator(task_id="op3", arg1=op1.output) + op4 = MockOperator(task_id="op4", arg1=[op1.output, op2.output]) + op5 = MockOperator(task_id="op5", arg1={"op1": op1.output, "op2": op2.output}) + + assert op1 in op3.upstream_list + assert op1 in op4.upstream_list + assert op2 in op4.upstream_list + assert op1 in op5.upstream_list + assert op2 in op5.upstream_list + + +def test_all_xcomargs_from_mapped_tasks_are_consumable(dag_maker, session): + class PushXcomOperator(MockOperator): + def __init__(self, arg1, **kwargs): + super().__init__(arg1=arg1, **kwargs) + + def execute(self, context): + return self.arg1 + + class ConsumeXcomOperator(PushXcomOperator): + def execute(self, context): + assert {i for i in self.arg1} == {1, 2, 3} + + with dag_maker("test_all_xcomargs_from_mapped_tasks_are_consumable"): + op1 = PushXcomOperator.partial(task_id="op1").expand(arg1=[1, 2, 3]) + ConsumeXcomOperator(task_id="op2", arg1=op1.output) + + dr = dag_maker.create_dagrun() + tis = dr.get_task_instances(session=session) + for ti in tis: + ti.run() diff --git a/tests/system/providers/airbyte/example_airbyte_trigger_job.py b/tests/system/providers/airbyte/example_airbyte_trigger_job.py index e57c573ebe5b5..fecf9353ed1c9 100644 --- a/tests/system/providers/airbyte/example_airbyte_trigger_job.py +++ b/tests/system/providers/airbyte/example_airbyte_trigger_job.py @@ -20,6 +20,7 @@ import os from datetime import datetime, timedelta +from typing import cast from airflow import DAG from airflow.providers.airbyte.operators.airbyte import AirbyteTriggerSyncOperator @@ -54,7 +55,7 @@ airbyte_sensor = AirbyteJobSensor( task_id='airbyte_sensor_source_dest_example', - airbyte_job_id=async_source_destination.output, + airbyte_job_id=cast(int, async_source_destination.output), ) # [END howto_operator_airbyte_asynchronous] diff --git a/tests/system/providers/amazon/aws/example_athena.py b/tests/system/providers/amazon/aws/example_athena.py index 
ea9e6cc114051..34c8af6970fd4 100644 --- a/tests/system/providers/amazon/aws/example_athena.py +++ b/tests/system/providers/amazon/aws/example_athena.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime +from typing import cast import boto3 @@ -121,7 +122,7 @@ def read_results_from_s3(bucket_name, query_execution_id): # [START howto_sensor_athena] await_query = AthenaSensor( task_id='await_query', - query_execution_id=read_table.output, + query_execution_id=cast(str, read_table.output), ) # [END howto_sensor_athena] diff --git a/tests/system/providers/amazon/aws/example_batch.py b/tests/system/providers/amazon/aws/example_batch.py index 0f381ff1c1dc9..cd52567ce93fd 100644 --- a/tests/system/providers/amazon/aws/example_batch.py +++ b/tests/system/providers/amazon/aws/example_batch.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime +from typing import cast import boto3 @@ -193,7 +194,7 @@ def delete_job_queue(job_queue_name): # [START howto_sensor_batch] wait_for_batch_job = BatchSensor( task_id='wait_for_batch_job', - job_id=submit_batch_job.output, + job_id=cast(str, submit_batch_job.output), ) # [END howto_sensor_batch] diff --git a/tests/system/providers/amazon/aws/example_emr_serverless.py b/tests/system/providers/amazon/aws/example_emr_serverless.py index 2a5450b796ea2..227eb85cf18f2 100644 --- a/tests/system/providers/amazon/aws/example_emr_serverless.py +++ b/tests/system/providers/amazon/aws/example_emr_serverless.py @@ -17,6 +17,7 @@ from datetime import datetime +from typing import cast from airflow.models.baseoperator import chain from airflow.models.dag import DAG @@ -75,17 +76,19 @@ ) # [END howto_operator_emr_serverless_create_application] + emr_serverless_app_id = cast(str, emr_serverless_app.output) + # [START howto_sensor_emr_serverless_application] wait_for_app_creation = EmrServerlessApplicationSensor( task_id='wait_for_app_creation', - application_id=emr_serverless_app.output, + application_id=emr_serverless_app_id, ) # [END howto_sensor_emr_serverless_application] # [START howto_operator_emr_serverless_start_job] start_job = EmrServerlessStartJobOperator( task_id='start_emr_serverless_job', - application_id=emr_serverless_app.output, + application_id=emr_serverless_app_id, execution_role_arn=role_arn, job_driver=SPARK_JOB_DRIVER, configuration_overrides=SPARK_CONFIGURATION_OVERRIDES, @@ -94,14 +97,14 @@ # [START howto_sensor_emr_serverless_job] wait_for_job = EmrServerlessJobSensor( - task_id='wait_for_job', application_id=emr_serverless_app.output, job_run_id=start_job.output + task_id='wait_for_job', application_id=emr_serverless_app_id, job_run_id=cast(str, start_job.output) ) # [END howto_sensor_emr_serverless_job] # [START howto_operator_emr_serverless_delete_application] delete_app = EmrServerlessDeleteApplicationOperator( task_id='delete_application', - application_id=emr_serverless_app.output, + application_id=emr_serverless_app_id, trigger_rule=TriggerRule.ALL_DONE, ) # [END howto_operator_emr_serverless_delete_application] diff --git a/tests/system/providers/amazon/aws/example_glue.py b/tests/system/providers/amazon/aws/example_glue.py index 4072838be0271..c12b066367efc 100644 --- a/tests/system/providers/amazon/aws/example_glue.py +++ b/tests/system/providers/amazon/aws/example_glue.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
from datetime import datetime -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, cast import boto3 from botocore.client import BaseClient @@ -162,7 +162,7 @@ def glue_cleanup(crawler_name: str, job_name: str, db_name: str) -> None: job_name=glue_job_name, script_location=f's3://{bucket_name}/etl_script.py', s3_bucket=bucket_name, - iam_role_name=role_name, + iam_role_name=cast(str, role_name), create_job_kwargs={'GlueVersion': '3.0', 'NumberOfWorkers': 2, 'WorkerType': 'G.1X'}, # Waits by default, set False to test the Sensor below wait_for_completion=False, @@ -174,7 +174,7 @@ def glue_cleanup(crawler_name: str, job_name: str, db_name: str) -> None: task_id='wait_for_job', job_name=glue_job_name, # Job ID extracted from previous Glue Job Operator task - run_id=submit_glue_job.output, + run_id=cast(str, submit_glue_job.output), ) # [END howto_sensor_glue] @@ -199,7 +199,7 @@ def glue_cleanup(crawler_name: str, job_name: str, db_name: str) -> None: # TEST TEARDOWN glue_cleanup(glue_crawler_name, glue_job_name, glue_db_name), delete_bucket, - delete_logs(submit_glue_job.output, glue_crawler_name), + delete_logs(cast(str, submit_glue_job.output), glue_crawler_name), ) from tests.system.utils.watcher import watcher diff --git a/tests/system/providers/amazon/aws/example_step_functions.py b/tests/system/providers/amazon/aws/example_step_functions.py index fc0270b5afab1..760094ef3b3c1 100644 --- a/tests/system/providers/amazon/aws/example_step_functions.py +++ b/tests/system/providers/amazon/aws/example_step_functions.py @@ -16,6 +16,7 @@ # under the License. import json from datetime import datetime +from typing import cast from airflow import DAG from airflow.decorators import task @@ -79,19 +80,21 @@ def delete_state_machine(state_machine_arn): # [START howto_operator_step_function_start_execution] start_execution = StepFunctionStartExecutionOperator( - task_id='start_execution', state_machine_arn=state_machine_arn + task_id='start_execution', state_machine_arn=cast(str, state_machine_arn) ) # [END howto_operator_step_function_start_execution] + execution_arn = cast(str, start_execution.output) + # [START howto_sensor_step_function_execution] wait_for_execution = StepFunctionExecutionSensor( - task_id='wait_for_execution', execution_arn=start_execution.output + task_id='wait_for_execution', execution_arn=execution_arn ) # [END howto_sensor_step_function_execution] # [START howto_operator_step_function_get_execution_output] get_execution_output = StepFunctionGetExecutionOutputOperator( - task_id='get_execution_output', execution_arn=start_execution.output + task_id='get_execution_output', execution_arn=execution_arn ) # [END howto_operator_step_function_get_execution_output] diff --git a/tests/system/providers/dbt/cloud/example_dbt_cloud.py b/tests/system/providers/dbt/cloud/example_dbt_cloud.py index c6445e0f78313..5a97b41be5594 100644 --- a/tests/system/providers/dbt/cloud/example_dbt_cloud.py +++ b/tests/system/providers/dbt/cloud/example_dbt_cloud.py @@ -16,6 +16,7 @@ # under the License. 
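The typing.cast(...) wrappers introduced throughout these example DAGs are purely for static type checkers: .output is now annotated as XComArg, while the downstream operator parameters are annotated str or int. cast() is a no-op at runtime, and the XComArg is still resolved to the pushed value when the field is templated. A small sketch of the pattern, using a hypothetical EchoOperator rather than a real provider operator:

from datetime import datetime
from typing import cast

from airflow import DAG
from airflow.models.baseoperator import BaseOperator
from airflow.operators.bash import BashOperator


class EchoOperator(BaseOperator):
    """Toy operator whose constructor annotates ``message`` as a plain str."""

    template_fields = ("message",)

    def __init__(self, *, message: str, **kwargs):
        super().__init__(**kwargs)
        self.message = message

    def execute(self, context):
        print(self.message)


with DAG(dag_id="cast_output_sketch", start_date=datetime(2022, 1, 1), schedule_interval=None):
    produce = BashOperator(task_id="produce", bash_command="echo hello")

    # produce.output is an XComArg, while ``message`` expects str; cast() keeps mypy
    # quiet and is a no-op at runtime, where the XComArg is resolved via templating.
    consume = EchoOperator(task_id="consume", message=cast(str, produce.output))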
from datetime import datetime +from typing import cast from airflow.models import DAG @@ -56,7 +57,7 @@ # [START howto_operator_dbt_cloud_get_artifact] get_run_results_artifact = DbtCloudGetJobRunArtifactOperator( - task_id="get_run_results_artifact", run_id=trigger_job_run1.output, path="run_results.json" + task_id="get_run_results_artifact", run_id=cast(int, trigger_job_run1.output), path="run_results.json" ) # [END howto_operator_dbt_cloud_get_artifact] @@ -71,7 +72,7 @@ # [START howto_operator_dbt_cloud_run_job_sensor] job_run_sensor = DbtCloudJobRunSensor( - task_id="job_run_sensor", run_id=trigger_job_run2.output, timeout=20 + task_id="job_run_sensor", run_id=cast(int, trigger_job_run2.output), timeout=20 ) # [END howto_operator_dbt_cloud_run_job_sensor] diff --git a/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py b/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py index 6deb89be6dc4c..246afe1c1b0f1 100644 --- a/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py +++ b/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py @@ -21,8 +21,10 @@ """ import os from datetime import datetime +from typing import cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( AutoMLCreateDatasetOperator, @@ -72,7 +74,7 @@ task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = create_dataset_task.output['dataset_id'] + dataset_id = cast(str, XComArg(create_dataset_task, key='dataset_id')) import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -85,7 +87,7 @@ create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) - model_id = create_model.output['model_id'] + model_id = cast(str, XComArg(create_model, key='model_id')) delete_model_task = AutoMLDeleteModelOperator( task_id="delete_model_task", diff --git a/tests/system/providers/google/cloud/automl/example_automl_vision_classification.py b/tests/system/providers/google/cloud/automl/example_automl_vision_classification.py index 647e033216c94..e8cad1dfc0bf9 100644 --- a/tests/system/providers/google/cloud/automl/example_automl_vision_classification.py +++ b/tests/system/providers/google/cloud/automl/example_automl_vision_classification.py @@ -21,8 +21,10 @@ """ import os from datetime import datetime +from typing import cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( AutoMLCreateDatasetOperator, @@ -73,7 +75,7 @@ task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION ) - dataset_id = create_dataset_task.output["dataset_id"] + dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id")) import_dataset_task = AutoMLImportDataOperator( task_id="import_dataset_task", @@ -86,7 +88,7 @@ create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) - model_id = create_model.output["model_id"] + model_id = cast(str, XComArg(create_model, key="model_id")) delete_model_task = AutoMLDeleteModelOperator( task_id="delete_model_task", diff --git a/tests/system/providers/google/cloud/cloud_build/example_cloud_build.py 
b/tests/system/providers/google/cloud/cloud_build/example_cloud_build.py index 311ba953ca9fa..a81e00eda48e9 100644 --- a/tests/system/providers/google/cloud/cloud_build/example_cloud_build.py +++ b/tests/system/providers/google/cloud/cloud_build/example_cloud_build.py @@ -31,13 +31,14 @@ import os from datetime import datetime from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, cast import yaml from future.backports.urllib.parse import urlparse from airflow import models from airflow.models.baseoperator import chain +from airflow.models.xcom_arg import XComArg from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.cloud_build import ( CloudBuildCancelBuildOperator, @@ -106,7 +107,7 @@ # [START howto_operator_create_build_from_storage_result] create_build_from_storage_result = BashOperator( - bash_command=f"echo { create_build_from_storage.output['results'] }", + bash_command=f"echo {cast(str, XComArg(create_build_from_storage, key='results'))}", task_id="create_build_from_storage_result", ) # [END howto_operator_create_build_from_storage_result] @@ -119,7 +120,7 @@ # [START howto_operator_create_build_from_repo_result] create_build_from_repo_result = BashOperator( - bash_command=f"echo { create_build_from_repo.output['results'] }", + bash_command=f"echo {cast(str, XComArg(create_build_from_repo, key='results'))}", task_id="create_build_from_repo_result", ) # [END howto_operator_create_build_from_repo_result] @@ -142,7 +143,7 @@ # [START howto_operator_cancel_build] cancel_build = CloudBuildCancelBuildOperator( task_id="cancel_build", - id_=create_build_without_wait.output['id'], + id_=cast(str, XComArg(create_build_without_wait, key='id')), project_id=PROJECT_ID, ) # [END howto_operator_cancel_build] @@ -150,7 +151,7 @@ # [START howto_operator_retry_build] retry_build = CloudBuildRetryBuildOperator( task_id="retry_build", - id_=cancel_build.output['id'], + id_=cast(str, XComArg(cancel_build, key='id')), project_id=PROJECT_ID, ) # [END howto_operator_retry_build] @@ -158,7 +159,7 @@ # [START howto_operator_get_build] get_build = CloudBuildGetBuildOperator( task_id="get_build", - id_=retry_build.output['id'], + id_=cast(str, XComArg(retry_build, key='id')), project_id=PROJECT_ID, ) # [END howto_operator_get_build] diff --git a/tests/system/providers/google/cloud/cloud_build/example_cloud_build_trigger.py b/tests/system/providers/google/cloud/cloud_build/example_cloud_build_trigger.py index 07d7a911e41d1..bea49a3d728e3 100644 --- a/tests/system/providers/google/cloud/cloud_build/example_cloud_build_trigger.py +++ b/tests/system/providers/google/cloud/cloud_build/example_cloud_build_trigger.py @@ -24,10 +24,11 @@ import os from datetime import datetime -from typing import Any, Dict +from typing import Any, Dict, cast from airflow import models from airflow.models.baseoperator import chain +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.operators.cloud_build import ( CloudBuildCreateBuildTriggerOperator, CloudBuildDeleteBuildTriggerOperator, @@ -90,11 +91,13 @@ ) # [END howto_operator_create_build_trigger] + build_trigger_id = cast(str, XComArg(create_build_trigger, key="id")) + # [START howto_operator_run_build_trigger] run_build_trigger = CloudBuildRunBuildTriggerOperator( task_id="run_build_trigger", project_id=PROJECT_ID, - trigger_id=create_build_trigger.output['id'], + trigger_id=build_trigger_id, source=create_build_from_repo_body['source']['repo_source'], ) # [END 
howto_operator_run_build_trigger] @@ -103,7 +106,7 @@ update_build_trigger = CloudBuildUpdateBuildTriggerOperator( task_id="update_build_trigger", project_id=PROJECT_ID, - trigger_id=create_build_trigger.output['id'], + trigger_id=build_trigger_id, trigger=update_build_trigger_body, ) # [END howto_operator_update_build_trigger] @@ -112,7 +115,7 @@ get_build_trigger = CloudBuildGetBuildTriggerOperator( task_id="get_build_trigger", project_id=PROJECT_ID, - trigger_id=create_build_trigger.output['id'], + trigger_id=build_trigger_id, ) # [END howto_operator_get_build_trigger] @@ -120,7 +123,7 @@ delete_build_trigger = CloudBuildDeleteBuildTriggerOperator( task_id="delete_build_trigger", project_id=PROJECT_ID, - trigger_id=create_build_trigger.output['id'], + trigger_id=build_trigger_id, ) # [END howto_operator_delete_build_trigger] diff --git a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql.py b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql.py index 125be7ecb397b..f9a56cd3ecd07 100644 --- a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql.py +++ b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql.py @@ -32,6 +32,7 @@ from urllib.parse import urlsplit from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.operators.cloud_sql import ( CloudSQLCreateInstanceDatabaseOperator, CloudSQLCreateInstanceOperator, @@ -196,9 +197,11 @@ # For export & import to work we need to add the Cloud SQL instance's Service Account # write access to the destination GCS bucket. + service_account_email = XComArg(sql_instance_create_task, key='service_account_email') + # [START howto_operator_cloudsql_export_gcs_permissions] sql_gcp_add_bucket_permission_task = GCSBucketCreateAclEntryOperator( - entity=f"user-{sql_instance_create_task.output['service_account_email']}", + entity=f"user-{service_account_email}", role="WRITER", bucket=file_url_split[1], # netloc (bucket) task_id='sql_gcp_add_bucket_permission_task', @@ -215,7 +218,7 @@ # read access to the target GCS object. 
# [START howto_operator_cloudsql_import_gcs_permissions] sql_gcp_add_object_permission_task = GCSObjectCreateAclEntryOperator( - entity=f"user-{sql_instance_create_task.output['service_account_email']}", + entity=f"user-{service_account_email}", role="READER", bucket=file_url_split[1], # netloc (bucket) object_name=file_url_split[2][1:], # path (strip first '/') diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_spark_async.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_spark_async.py index 9890e553b179a..9cd9b6ad2bda7 100644 --- a/tests/system/providers/google/cloud/dataproc/example_dataproc_spark_async.py +++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_spark_async.py @@ -21,6 +21,7 @@ import os from datetime import datetime +from typing import cast from airflow import models from airflow.providers.google.cloud.operators.dataproc import ( @@ -90,7 +91,7 @@ task_id='spark_task_async_sensor_task', region=REGION, project_id=PROJECT_ID, - dataproc_job_id=spark_task_async.output, + dataproc_job_id=cast(str, spark_task_async.output), poke_interval=10, ) # [END cloud_dataproc_async_submit_sensor] diff --git a/tests/system/providers/google/cloud/datastore/example_datastore_rollback.py b/tests/system/providers/google/cloud/datastore/example_datastore_rollback.py index 32a1a017fcbcf..7316efc536d12 100644 --- a/tests/system/providers/google/cloud/datastore/example_datastore_rollback.py +++ b/tests/system/providers/google/cloud/datastore/example_datastore_rollback.py @@ -22,7 +22,7 @@ import os from datetime import datetime -from typing import Any, Dict +from typing import Any, Dict, cast from airflow import models from airflow.providers.google.cloud.operators.datastore import ( @@ -54,7 +54,7 @@ # [START how_to_rollback_transaction] rollback_transaction = CloudDatastoreRollbackOperator( task_id="rollback_transaction", - transaction=begin_transaction_to_rollback.output, + transaction=cast(str, begin_transaction_to_rollback.output), ) # [END how_to_rollback_transaction] diff --git a/tests/system/providers/google/cloud/gcs/example_sheets.py b/tests/system/providers/google/cloud/gcs/example_sheets.py index 415c50725dd5a..75d42400a8f64 100644 --- a/tests/system/providers/google/cloud/gcs/example_sheets.py +++ b/tests/system/providers/google/cloud/gcs/example_sheets.py @@ -20,6 +20,7 @@ from datetime import datetime from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator from airflow.providers.google.cloud.transfers.sheets_to_gcs import GoogleSheetsToGCSOperator @@ -68,7 +69,7 @@ # [START print_spreadsheet_url] print_spreadsheet_url = BashOperator( task_id="print_spreadsheet_url", - bash_command=f"echo {create_spreadsheet.output['spreadsheet_url']}", + bash_command=f"echo {XComArg(create_spreadsheet, key='spreadsheet_url')}", ) # [END print_spreadsheet_url] diff --git a/tests/system/providers/google/cloud/pubsub/example_pubsub.py b/tests/system/providers/google/cloud/pubsub/example_pubsub.py index 6e4c6d765857a..de00991324353 100644 --- a/tests/system/providers/google/cloud/pubsub/example_pubsub.py +++ b/tests/system/providers/google/cloud/pubsub/example_pubsub.py @@ -21,6 +21,7 @@ """ import os from datetime import datetime +from typing import cast from airflow import models from airflow.operators.bash import BashOperator @@ -71,7 +72,7 @@ # [END 
howto_operator_gcp_pubsub_create_subscription] # [START howto_operator_gcp_pubsub_pull_message_with_sensor] - subscription = subscribe_task.output + subscription = cast(str, subscribe_task.output) pull_messages = PubSubPullSensor( task_id="pull_messages", diff --git a/tests/system/providers/google/cloud/workflows/example_workflows.py b/tests/system/providers/google/cloud/workflows/example_workflows.py index 567b04de8d250..acfba431046ca 100644 --- a/tests/system/providers/google/cloud/workflows/example_workflows.py +++ b/tests/system/providers/google/cloud/workflows/example_workflows.py @@ -17,10 +17,12 @@ import os from datetime import datetime +from typing import cast from google.protobuf.field_mask_pb2 import FieldMask from airflow import DAG +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.operators.workflows import ( WorkflowsCancelExecutionOperator, WorkflowsCreateExecutionOperator, @@ -143,7 +145,7 @@ ) # [END how_to_create_execution] - create_execution_id = create_execution.output["execution_id"] + create_execution_id = cast(str, XComArg(create_execution, key="execution_id")) # [START how_to_wait_for_execution] wait_for_execution = WorkflowExecutionSensor( @@ -187,7 +189,7 @@ workflow_id=SLEEP_WORKFLOW_ID, ) - cancel_execution_id = create_execution_for_cancel.output["execution_id"] + cancel_execution_id = cast(str, XComArg(create_execution_for_cancel, key="execution_id")) # [START how_to_cancel_execution] cancel_execution = WorkflowsCancelExecutionOperator( diff --git a/tests/system/providers/google/datacatalog/example_datacatalog_entries.py b/tests/system/providers/google/datacatalog/example_datacatalog_entries.py index 18ce63eb6e610..99f27af115db4 100644 --- a/tests/system/providers/google/datacatalog/example_datacatalog_entries.py +++ b/tests/system/providers/google/datacatalog/example_datacatalog_entries.py @@ -22,6 +22,7 @@ from google.protobuf.field_mask_pb2 import FieldMask from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.datacatalog import ( CloudDataCatalogCreateEntryGroupOperator, @@ -69,7 +70,7 @@ # [START howto_operator_gcp_datacatalog_create_entry_group_result] create_entry_group_result = BashOperator( task_id="create_entry_group_result", - bash_command=f"echo {create_entry_group.output['entry_group_id']}", + bash_command=f"echo {XComArg(create_entry_group, key='entry_group_id')}", ) # [END howto_operator_gcp_datacatalog_create_entry_group_result] @@ -90,7 +91,7 @@ # [START howto_operator_gcp_datacatalog_create_entry_gcs_result] create_entry_gcs_result = BashOperator( task_id="create_entry_gcs_result", - bash_command=f"echo {create_entry_gcs.output['entry_id']}", + bash_command=f"echo {XComArg(create_entry_gcs, key='entry_id')}", ) # [END howto_operator_gcp_datacatalog_create_entry_gcs_result] diff --git a/tests/system/providers/google/datacatalog/example_datacatalog_search_catalog.py b/tests/system/providers/google/datacatalog/example_datacatalog_search_catalog.py index 7e93bb6bf4130..c3e1cc5112c9b 100644 --- a/tests/system/providers/google/datacatalog/example_datacatalog_search_catalog.py +++ b/tests/system/providers/google/datacatalog/example_datacatalog_search_catalog.py @@ -18,10 +18,12 @@ import os from datetime import datetime +from typing import cast from google.cloud.datacatalog import TagField, TagTemplateField from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.operators.bash 
import BashOperator from airflow.providers.google.cloud.operators.datacatalog import ( CloudDataCatalogCreateEntryGroupOperator, @@ -73,7 +75,7 @@ # [START howto_operator_gcp_datacatalog_create_entry_group_result] create_entry_group_result = BashOperator( task_id="create_entry_group_result", - bash_command=f"echo {create_entry_group.output['entry_group_id']}", + bash_command=f"echo {XComArg(create_entry_group, key='entry_group_id')}", ) # [END howto_operator_gcp_datacatalog_create_entry_group_result] @@ -94,7 +96,7 @@ # [START howto_operator_gcp_datacatalog_create_entry_gcs_result] create_entry_gcs_result = BashOperator( task_id="create_entry_gcs_result", - bash_command=f"echo {create_entry_gcs.output['entry_id']}", + bash_command=f"echo {XComArg(create_entry_gcs, key='entry_id')}", ) # [END howto_operator_gcp_datacatalog_create_entry_gcs_result] @@ -109,10 +111,12 @@ ) # [END howto_operator_gcp_datacatalog_create_tag] + tag_id = cast(str, XComArg(create_tag, key='tag_id')) + # [START howto_operator_gcp_datacatalog_create_tag_result] create_tag_result = BashOperator( task_id="create_tag_result", - bash_command=f"echo {create_tag.output['tag_id']}", + bash_command=f"echo {tag_id}", ) # [END howto_operator_gcp_datacatalog_create_tag_result] @@ -135,7 +139,7 @@ # [START howto_operator_gcp_datacatalog_create_tag_template_result] create_tag_template_result = BashOperator( task_id="create_tag_template_result", - bash_command=f"echo {create_tag_template.output['tag_template_id']}", + bash_command=f"echo {XComArg(create_tag_template, key='tag_template_id')}", ) # [END howto_operator_gcp_datacatalog_create_tag_template_result] @@ -174,7 +178,7 @@ location=LOCATION, entry_group=ENTRY_GROUP_ID, entry=ENTRY_ID, - tag=create_tag.output["tag_id"], + tag=tag_id, ) # [END howto_operator_gcp_datacatalog_delete_tag] delete_tag.trigger_rule = TriggerRule.ALL_DONE diff --git a/tests/system/providers/google/datacatalog/example_datacatalog_tag_templates.py b/tests/system/providers/google/datacatalog/example_datacatalog_tag_templates.py index ab5357bb51003..22a6e8ce7e707 100644 --- a/tests/system/providers/google/datacatalog/example_datacatalog_tag_templates.py +++ b/tests/system/providers/google/datacatalog/example_datacatalog_tag_templates.py @@ -22,6 +22,7 @@ from google.cloud.datacatalog import FieldType, TagTemplateField from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.datacatalog import ( CloudDataCatalogCreateTagTemplateFieldOperator, @@ -73,7 +74,7 @@ # [START howto_operator_gcp_datacatalog_create_tag_template_result] create_tag_template_result = BashOperator( task_id="create_tag_template_result", - bash_command=f"echo {create_tag_template.output['tag_template_id']}", + bash_command=f"echo {XComArg(create_tag_template, key='tag_template_id')}", ) # [END howto_operator_gcp_datacatalog_create_tag_template_result] @@ -92,7 +93,7 @@ # [START howto_operator_gcp_datacatalog_create_tag_template_field_result] create_tag_template_field_result = BashOperator( task_id="create_tag_template_field_result", - bash_command=f"echo {create_tag_template_field.output['tag_template_field_id']}", + bash_command=f"echo {XComArg(create_tag_template_field, key='tag_template_field_id')}", ) # [END howto_operator_gcp_datacatalog_create_tag_template_field_result] diff --git a/tests/system/providers/google/datacatalog/example_datacatalog_tags.py b/tests/system/providers/google/datacatalog/example_datacatalog_tags.py 
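The other recurring substitution in these examples is op.output['key'] becoming XComArg(op, key='key'), which references a named XCom instead of the default return_value. When interpolated into an f-string, the XComArg renders as a Jinja task_instance.xcom_pull(...) template, which is why it can be embedded directly in a bash_command; that string form does not create task dependencies by itself. A sketch with made-up task ids and XCom keys:

from datetime import datetime

from airflow import DAG
from airflow.decorators import task
from airflow.models.xcom_arg import XComArg
from airflow.operators.bash import BashOperator

with DAG(dag_id="xcomarg_key_sketch", start_date=datetime(2022, 1, 1), schedule_interval=None):

    @task(multiple_outputs=True)
    def create_entry():
        # Each dict key becomes its own XCom; entry_id is a made-up key for illustration.
        return {"entry_id": "e-123", "entry_name": "demo"}

    entry = create_entry()

    # Reference the named XCom rather than the default return_value.
    entry_id = XComArg(entry.operator, key="entry_id")

    # In an f-string the XComArg renders as "{{ task_instance.xcom_pull(...) }}", so it
    # can be embedded in bash_command; set the dependency explicitly in that case.
    show_id = BashOperator(task_id="show_id", bash_command=f"echo {entry_id}")
    entry.operator >> show_id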
index 277087ae07916..b841feb95887b 100644 --- a/tests/system/providers/google/datacatalog/example_datacatalog_tags.py +++ b/tests/system/providers/google/datacatalog/example_datacatalog_tags.py @@ -18,10 +18,12 @@ import os from datetime import datetime +from typing import cast from google.cloud.datacatalog import TagField, TagTemplateField from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.datacatalog import ( CloudDataCatalogCreateEntryGroupOperator, @@ -74,7 +76,7 @@ # [START howto_operator_gcp_datacatalog_create_entry_group_result] create_entry_group_result = BashOperator( task_id="create_entry_group_result", - bash_command=f"echo {create_entry_group.output['entry_group_id']}", + bash_command=f"echo {XComArg(create_entry_group, key='entry_group_id')}", ) # [END howto_operator_gcp_datacatalog_create_entry_group_result] @@ -95,7 +97,7 @@ # [START howto_operator_gcp_datacatalog_create_entry_gcs_result] create_entry_gcs_result = BashOperator( task_id="create_entry_gcs_result", - bash_command=f"echo {create_entry_gcs.output['entry_id']}", + bash_command=f"echo {XComArg(create_entry_gcs, key='entry_id')}", ) # [END howto_operator_gcp_datacatalog_create_entry_gcs_result] @@ -110,10 +112,12 @@ ) # [END howto_operator_gcp_datacatalog_create_tag] + tag_id = cast(str, XComArg(create_tag, key='tag_id')) + # [START howto_operator_gcp_datacatalog_create_tag_result] create_tag_result = BashOperator( task_id="create_tag_result", - bash_command=f"echo {create_tag.output['tag_id']}", + bash_command=f"echo {tag_id}", ) # [END howto_operator_gcp_datacatalog_create_tag_result] @@ -136,7 +140,7 @@ # [START howto_operator_gcp_datacatalog_create_tag_template_result] create_tag_template_result = BashOperator( task_id="create_tag_template_result", - bash_command=f"echo {create_tag_template.output['tag_template_id']}", + bash_command=f"echo {XComArg(create_tag_template, key='tag_template_id')}", ) # [END howto_operator_gcp_datacatalog_create_tag_template_result] @@ -160,7 +164,7 @@ location=LOCATION, entry_group=ENTRY_GROUP_ID, entry=ENTRY_ID, - tag_id=f"{create_tag.output['tag_id']}", + tag_id=tag_id, ) # [END howto_operator_gcp_datacatalog_update_tag] @@ -185,7 +189,7 @@ location=LOCATION, entry_group=ENTRY_GROUP_ID, entry=ENTRY_ID, - tag=create_tag.output["tag_id"], + tag=tag_id, ) # [END howto_operator_gcp_datacatalog_delete_tag] delete_tag.trigger_rule = TriggerRule.ALL_DONE diff --git a/tests/system/providers/google/marketing_platform/example_campaign_manager.py b/tests/system/providers/google/marketing_platform/example_campaign_manager.py index 23ac650908602..9c1ade2dfde50 100644 --- a/tests/system/providers/google/marketing_platform/example_campaign_manager.py +++ b/tests/system/providers/google/marketing_platform/example_campaign_manager.py @@ -21,8 +21,10 @@ import os import time from datetime import datetime +from typing import cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator from airflow.providers.google.marketing_platform.operators.campaign_manager import ( GoogleCampaignManagerBatchInsertConversionsOperator, @@ -105,14 +107,14 @@ create_report = GoogleCampaignManagerInsertReportOperator( profile_id=PROFILE_ID, report=REPORT, task_id="create_report" ) - report_id = create_report.output["report_id"] + report_id = cast(str, XComArg(create_report, 
key="report_id")) # [END howto_campaign_manager_insert_report_operator] # [START howto_campaign_manager_run_report_operator] run_report = GoogleCampaignManagerRunReportOperator( profile_id=PROFILE_ID, report_id=report_id, task_id="run_report" ) - file_id = run_report.output["file_id"] + file_id = cast(str, XComArg(run_report, key="file_id")) # [END howto_campaign_manager_run_report_operator] # [START howto_campaign_manager_wait_for_operation] diff --git a/tests/system/providers/google/marketing_platform/example_search_ads.py b/tests/system/providers/google/marketing_platform/example_search_ads.py index ac5d98b98e4ad..196845faaadc2 100644 --- a/tests/system/providers/google/marketing_platform/example_search_ads.py +++ b/tests/system/providers/google/marketing_platform/example_search_ads.py @@ -20,8 +20,10 @@ """ import os from datetime import datetime +from typing import cast from airflow import models +from airflow.models.xcom_arg import XComArg from airflow.providers.google.marketing_platform.operators.search_ads import ( GoogleSearchAdsDownloadReportOperator, GoogleSearchAdsInsertReportOperator, @@ -59,7 +61,7 @@ # [END howto_search_ads_generate_report_operator] # [START howto_search_ads_get_report_id] - report_id = generate_report.output["report_id"] + report_id = cast(str, XComArg(generate_report, key="report_id")) # [END howto_search_ads_get_report_id] # [START howto_search_ads_get_report_operator] diff --git a/tests/system/providers/microsoft/azure/example_adf_run_pipeline.py b/tests/system/providers/microsoft/azure/example_adf_run_pipeline.py index 6c8eb3757fed2..5ebbe36a267bf 100644 --- a/tests/system/providers/microsoft/azure/example_adf_run_pipeline.py +++ b/tests/system/providers/microsoft/azure/example_adf_run_pipeline.py @@ -16,8 +16,10 @@ # under the License. import os from datetime import datetime, timedelta +from typing import cast from airflow.models import DAG +from airflow.models.xcom_arg import XComArg try: from airflow.operators.empty import EmptyOperator @@ -65,7 +67,7 @@ pipeline_run_sensor = AzureDataFactoryPipelineRunStatusSensor( task_id="pipeline_run_sensor", - run_id=run_pipeline2.output["run_id"], + run_id=cast(str, XComArg(run_pipeline2, key="run_id")), ) # [END howto_operator_adf_run_pipeline_async] diff --git a/tests/system/providers/tableau/example_tableau.py b/tests/system/providers/tableau/example_tableau.py index 30e6583c5199a..734795514d9e4 100644 --- a/tests/system/providers/tableau/example_tableau.py +++ b/tests/system/providers/tableau/example_tableau.py @@ -22,6 +22,7 @@ """ import os from datetime import datetime, timedelta +from typing import cast from airflow import DAG from airflow.providers.tableau.operators.tableau import TableauOperator @@ -60,7 +61,7 @@ ) # The following task queries the status of the workbook refresh job until it succeeds. 
task_check_job_status = TableauJobStatusSensor( - job_id=task_refresh_workbook_non_blocking.output, + job_id=cast(str, task_refresh_workbook_non_blocking.output), task_id='check_tableau_job_status', ) diff --git a/tests/system/providers/tableau/example_tableau_refresh_workbook.py b/tests/system/providers/tableau/example_tableau_refresh_workbook.py index 999abcb73c369..d9c5ab3b9f96a 100644 --- a/tests/system/providers/tableau/example_tableau_refresh_workbook.py +++ b/tests/system/providers/tableau/example_tableau_refresh_workbook.py @@ -22,6 +22,7 @@ """ import os from datetime import datetime, timedelta +from typing import cast from airflow import DAG from airflow.providers.tableau.operators.tableau import TableauOperator @@ -58,7 +59,7 @@ ) # The following task queries the status of the workbook refresh job until it succeeds. task_check_job_status = TableauJobStatusSensor( - job_id=task_refresh_workbook_non_blocking.output, + job_id=cast(str, task_refresh_workbook_non_blocking.output), task_id='check_tableau_job_status', )
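The new tests added in tests/models/test_mappedoperator.py also check that XComArg dependency wiring works for mapped operators even when the references are nested inside lists or dicts passed to templated fields. A closing sketch under the same assumptions (hypothetical CollectOperator, Airflow 2.3+ with dynamic task mapping):

from datetime import datetime
from typing import Sequence

from airflow import DAG
from airflow.models.baseoperator import BaseOperator
from airflow.operators.bash import BashOperator


class CollectOperator(BaseOperator):
    """Toy operator that just logs whatever (templated) values it receives."""

    template_fields: Sequence[str] = ("sources",)

    def __init__(self, *, sources, **kwargs):
        super().__init__(**kwargs)
        self.sources = sources

    def execute(self, context):
        self.log.info("collected: %s", self.sources)


with DAG(dag_id="xcomarg_deps_sketch", start_date=datetime(2022, 1, 1), schedule_interval=None):
    op1 = BashOperator.partial(task_id="op1").expand(bash_command=["echo 1", "echo 2"])
    op2 = BashOperator.partial(task_id="op2").expand(bash_command=["echo 3", "echo 4"])

    # XComArgs from mapped operators are detected even when nested inside a dict (or
    # list) passed to a templated field, so op1 and op2 both become upstream of collect
    # without any explicit >> wiring, mirroring the new tests.
    collect = CollectOperator(
        task_id="collect",
        sources={"op1": op1.output, "op2": op2.output},
    )
    assert op1 in collect.upstream_list and op2 in collect.upstream_list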