Commit
Add output property to MappedOperator (#25604)
Co-authored-by: Tzu-ping Chung <[email protected]>
josh-fell and uranusjr authored Aug 28, 2022
1 parent 05cbba3 commit 1ed0146
Showing 45 changed files with 227 additions and 108 deletions.
9 changes: 3 additions & 6 deletions airflow/example_dags/example_xcom.py
@@ -19,7 +19,7 @@
"""Example DAG demonstrating the usage of XComs."""
import pendulum

from airflow import DAG
from airflow import DAG, XComArg
from airflow.decorators import task
from airflow.operators.bash import BashOperator

@@ -79,8 +79,8 @@ def pull_value_from_bash_push(ti=None):
bash_pull = BashOperator(
task_id='bash_pull',
bash_command='echo "bash pull demo" && '
f'echo "The xcom pushed manually is {bash_push.output["manually_pushed_value"]}" && '
f'echo "The returned_value xcom is {bash_push.output}" && '
f'echo "The xcom pushed manually is {XComArg(bash_push, key="manually_pushed_value")}" && '
f'echo "The returned_value xcom is {XComArg(bash_push)}" && '
'echo "finished"',
do_xcom_push=False,
)
@@ -90,6 +90,3 @@ def pull_value_from_bash_push(ti=None):
[bash_pull, python_pull_from_bash] << bash_push

puller(push_by_returning()) << push()

# Task dependencies created via `XComArgs`:
# pull << push2
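Both forms touched in this hunk refer to the same pushed XCom: bash_push.output["manually_pushed_value"] and XComArg(bash_push, key="manually_pushed_value") build equivalent references, and interpolating either into an f-string embeds roughly the Jinja expression {{ task_instance.xcom_pull(...) }}, which is rendered when the task runs. Because only that rendered string ends up in bash_command, the example keeps the explicit bash_push dependency declared with << above. A minimal sketch of the two usage styles, assuming an illustrative DAG id and task ids that are not part of this commit:

import pendulum

from airflow import DAG, XComArg
from airflow.operators.bash import BashOperator

with DAG(dag_id="xcomarg_demo", start_date=pendulum.datetime(2022, 8, 1), schedule_interval=None) as dag:
    push = BashOperator(task_id="push", bash_command="echo hello")

    # Object form: the XComArg is kept in a templated field ("env"), so Airflow
    # renders the pushed value at run time and wires push >> pull automatically.
    pull = BashOperator(task_id="pull", bash_command="echo $MSG", env={"MSG": push.output})

    # String form, as in the diff above: only the rendered Jinja expression ends
    # up in bash_command, so the upstream relationship is set by hand.
    pull_str = BashOperator(
        task_id="pull_str",
        bash_command=f'echo "pulled: {XComArg(push, key="return_value")}"',
    )
    pull_str << push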
3 changes: 2 additions & 1 deletion airflow/models/baseoperator.py
@@ -98,6 +98,7 @@

from airflow.models.dag import DAG
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.xcom_arg import XComArg
from airflow.utils.task_group import TaskGroup

ScheduleInterval = Union[str, timedelta, relativedelta]
@@ -1365,7 +1366,7 @@ def leaves(self) -> List["BaseOperator"]:
return [self]

@property
def output(self):
def output(self) -> "XComArg":
"""Returns reference to XCom pushed by current operator"""
from airflow.models.xcom_arg import XComArg

7 changes: 7 additions & 0 deletions airflow/models/mappedoperator.py
@@ -530,6 +530,13 @@ def get_dag(self) -> Optional["DAG"]:
"""Implementing Operator."""
return self.dag

@property
def output(self) -> "XComArg":
"""Returns reference to XCom pushed by current operator"""
from airflow.models.xcom_arg import XComArg

return XComArg(operator=self)

def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]:
"""Implementing DAGNode."""
return DagAttributeTypes.OP, self.task_id
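This hunk is the core of the change: a MappedOperator created by .partial().expand() now exposes the same output property as BaseOperator (shown in the previous hunk), so downstream tasks can reference the XComs of all mapped task instances through an XComArg. A minimal sketch of how the property might be used, assuming Airflow 2.3+ dynamic task mapping; the DAG and task names are illustrative, not from the commit:

import pendulum

from airflow.decorators import dag, task
from airflow.operators.bash import BashOperator


@dag(start_date=pendulum.datetime(2022, 8, 1), schedule_interval=None)
def mapped_output_demo():
    @task
    def commands():
        return ["echo a", "echo b"]

    # .partial().expand() yields a MappedOperator at parse time; before this
    # commit it did not offer the .output shorthand that BaseOperator has.
    run_each = BashOperator.partial(task_id="run_each").expand(bash_command=commands())

    @task
    def collect(results):
        # Resolves at run time to the XComs pushed by every mapped instance.
        print(list(results))

    collect(run_each.output)


mapped_output_demo()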
13 changes: 8 additions & 5 deletions airflow/providers/amazon/aws/example_dags/example_dms.py
@@ -23,6 +23,7 @@
import json
import os
from datetime import datetime
from typing import cast

import boto3
from sqlalchemy import Column, MetaData, String, Table, create_engine
@@ -256,10 +257,12 @@ def delete_dms_assets():
)
# [END howto_operator_dms_create_task]

task_arn = cast(str, create_task.output)

# [START howto_operator_dms_start_task]
start_task = DmsStartTaskOperator(
task_id='start_task',
replication_task_arn=create_task.output,
replication_task_arn=task_arn,
)
# [END howto_operator_dms_start_task]

@@ -280,30 +283,30 @@ def delete_dms_assets():

await_task_start = DmsTaskBaseSensor(
task_id='await_task_start',
replication_task_arn=create_task.output,
replication_task_arn=task_arn,
target_statuses=['running'],
termination_statuses=['stopped', 'deleting', 'failed'],
)

# [START howto_operator_dms_stop_task]
stop_task = DmsStopTaskOperator(
task_id='stop_task',
replication_task_arn=create_task.output,
replication_task_arn=task_arn,
)
# [END howto_operator_dms_stop_task]

# TaskCompletedSensor actually waits until task reaches the "Stopped" state, so it will work here.
# [START howto_sensor_dms_task_completed]
await_task_stop = DmsTaskCompletedSensor(
task_id='await_task_stop',
replication_task_arn=create_task.output,
replication_task_arn=task_arn,
)
# [END howto_sensor_dms_task_completed]

# [START howto_operator_dms_delete_task]
delete_task = DmsDeleteTaskOperator(
task_id='delete_task',
replication_task_arn=create_task.output,
replication_task_arn=task_arn,
trigger_rule='all_done',
)
# [END howto_operator_dms_delete_task]
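The new task_arn variable hoists a single reference to create_task.output that the sensors and the stop/delete operators above all reuse. At parse time it is still an XComArg, not an ARN string; the cast(str, ...) appears to be there only so that static type checkers accept passing it to parameters annotated as str, and it changes nothing at run time, when the templated field is rendered to the actual value. A hedged, self-contained sketch of the same pattern, with BashOperator standing in for the DMS operators and illustrative names:

from typing import cast

import pendulum

from airflow import DAG
from airflow.operators.bash import BashOperator

with DAG(dag_id="cast_output_demo", start_date=pendulum.datetime(2022, 8, 1), schedule_interval=None) as dag:
    create = BashOperator(task_id="create", bash_command="echo arn:aws:dms:example")

    # Still an XComArg here; cast() only narrows the static type for checkers
    # such as mypy, since downstream parameters are annotated as str.
    task_arn = cast(str, create.output)

    # The XComArg is rendered to the pushed string when this task runs, and the
    # create >> use dependency is added automatically.
    use = BashOperator(task_id="use", bash_command="echo $ARN", env={"ARN": task_arn})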
9 changes: 6 additions & 3 deletions airflow/providers/amazon/aws/example_dags/example_ecs.py
@@ -16,6 +16,7 @@
# under the License.

from datetime import datetime
from typing import cast

from airflow import DAG
from airflow.models.baseoperator import chain
@@ -99,18 +100,20 @@
)
# [END howto_operator_ecs_register_task_definition]

registered_task_definition = cast(str, register_task.output)

# [START howto_sensor_ecs_task_definition_state]
await_task_definition = EcsTaskDefinitionStateSensor(
task_id='await_task_definition',
task_definition=register_task.output,
task_definition=registered_task_definition,
)
# [END howto_sensor_ecs_task_definition_state]

# [START howto_operator_ecs_run_task]
run_task = EcsRunTaskOperator(
task_id="run_task",
cluster=EXISTING_CLUSTER_NAME,
task_definition=register_task.output,
task_definition=registered_task_definition,
launch_type="EC2",
overrides={
"containerOverrides": [
@@ -156,7 +159,7 @@
deregister_task = EcsDeregisterTaskDefinitionOperator(
task_id='deregister_task',
trigger_rule=TriggerRule.ALL_DONE,
task_definition=register_task.output,
task_definition=registered_task_definition,
)
# [END howto_operator_ecs_deregister_task_definition]

16 changes: 8 additions & 8 deletions airflow/providers/amazon/aws/example_dags/example_emr.py
@@ -17,6 +17,7 @@
# under the License.
import os
from datetime import datetime
from typing import cast

from airflow import DAG
from airflow.models.baseoperator import chain
@@ -79,39 +80,38 @@
)
# [END howto_operator_emr_create_job_flow]

job_flow_id = cast(str, job_flow_creator.output)

# [START howto_sensor_emr_job_flow]
job_sensor = EmrJobFlowSensor(
task_id='check_job_flow',
job_flow_id=job_flow_creator.output,
)
job_sensor = EmrJobFlowSensor(task_id='check_job_flow', job_flow_id=job_flow_id)
# [END howto_sensor_emr_job_flow]

# [START howto_operator_emr_modify_cluster]
cluster_modifier = EmrModifyClusterOperator(
task_id='modify_cluster', cluster_id=job_flow_creator.output, step_concurrency_level=1
task_id='modify_cluster', cluster_id=job_flow_id, step_concurrency_level=1
)
# [END howto_operator_emr_modify_cluster]

# [START howto_operator_emr_add_steps]
step_adder = EmrAddStepsOperator(
task_id='add_steps',
job_flow_id=job_flow_creator.output,
job_flow_id=job_flow_id,
steps=SPARK_STEPS,
)
# [END howto_operator_emr_add_steps]

# [START howto_sensor_emr_step]
step_checker = EmrStepSensor(
task_id='watch_step',
job_flow_id=job_flow_creator.output,
job_flow_id=job_flow_id,
step_id="{{ task_instance.xcom_pull(task_ids='add_steps', key='return_value')[0] }}",
)
# [END howto_sensor_emr_step]

# [START howto_operator_emr_terminate_job_flow]
cluster_remover = EmrTerminateJobFlowOperator(
task_id='remove_cluster',
job_flow_id=job_flow_creator.output,
job_flow_id=job_flow_id,
)
# [END howto_operator_emr_terminate_job_flow]

(another changed file; path not captured)
@@ -21,8 +21,10 @@
"""
import os
from datetime import datetime
from typing import cast

from airflow import models
from airflow.models.xcom_arg import XComArg
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
from airflow.providers.google.cloud.operators.automl import (
AutoMLCreateDatasetOperator,
Expand Down Expand Up @@ -67,7 +69,7 @@
task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION
)

dataset_id = create_dataset_task.output['dataset_id']
dataset_id = cast(str, XComArg(create_dataset_task, key='dataset_id'))

import_dataset_task = AutoMLImportDataOperator(
task_id="import_dataset_task",
@@ -80,7 +82,7 @@

create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION)

model_id = create_model.output['model_id']
model_id = cast(str, XComArg(create_model, key='model_id'))

delete_model_task = AutoMLDeleteModelOperator(
task_id="delete_model_task",
(another changed file; path not captured)
@@ -21,8 +21,10 @@
"""
import os
from datetime import datetime
from typing import cast

from airflow import models
from airflow.models.xcom_arg import XComArg
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
from airflow.providers.google.cloud.operators.automl import (
AutoMLCreateDatasetOperator,
Expand Down Expand Up @@ -68,7 +70,7 @@
task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION
)

dataset_id = create_dataset_task.output['dataset_id']
dataset_id = cast(str, XComArg(create_dataset_task, key='dataset_id'))

import_dataset_task = AutoMLImportDataOperator(
task_id="import_dataset_task",
@@ -81,7 +83,7 @@

create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION)

model_id = create_model.output['model_id']
model_id = cast(str, XComArg(create_model, key='model_id'))

delete_model_task = AutoMLDeleteModelOperator(
task_id="delete_model_task",
(another changed file; path not captured)
@@ -22,9 +22,10 @@
import os
from copy import deepcopy
from datetime import datetime
from typing import Dict, List
from typing import Dict, List, cast

from airflow import models
from airflow.models.xcom_arg import XComArg
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
from airflow.providers.google.cloud.operators.automl import (
AutoMLBatchPredictOperator,
@@ -103,7 +104,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str:
project_id=GCP_PROJECT_ID,
)

dataset_id = create_dataset_task.output['dataset_id']
dataset_id = cast(str, XComArg(create_dataset_task, key='dataset_id'))
# [END howto_operator_automl_create_dataset]

MODEL["dataset_id"] = dataset_id
@@ -158,7 +159,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str:
project_id=GCP_PROJECT_ID,
)

model_id = create_model_task.output['model_id']
model_id = cast(str, XComArg(create_model_task, key='model_id'))
# [END howto_operator_automl_create_model]

# [START howto_operator_automl_delete_model]
@@ -209,7 +210,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str:
project_id=GCP_PROJECT_ID,
)

dataset_id = create_dataset_task2.output['dataset_id']
dataset_id = cast(str, XComArg(create_dataset_task2, key='dataset_id'))

import_dataset_task = AutoMLImportDataOperator(
task_id="import_dataset_task",
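In the hunks above, the keyed XComArg is also assigned into a plain dict (MODEL["dataset_id"] = dataset_id) that is later handed to another operator, relying on XComArg references nested inside ordinary Python containers being resolved when the consuming task's templated arguments are rendered. A hedged sketch of that nesting pattern, using a TaskFlow task as the consumer; the names are illustrative and not part of the commit:

import pendulum

from airflow import XComArg
from airflow.decorators import dag, task
from airflow.operators.bash import BashOperator


@dag(start_date=pendulum.datetime(2022, 8, 1), schedule_interval=None)
def nested_xcomarg_demo():
    create_dataset = BashOperator(task_id="create_dataset", bash_command="echo demo-dataset-id")

    # The XComArg is placed inside an ordinary dict at parse time, mirroring
    # MODEL["dataset_id"] = dataset_id above; it is resolved only when the
    # consuming task runs.
    model = {"display_name": "demo_model", "dataset_id": XComArg(create_dataset)}

    @task
    def train(model_spec: dict):
        print(model_spec["dataset_id"])

    train(model_spec=model)


nested_xcomarg_demo()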
(another changed file; path not captured)
@@ -21,8 +21,10 @@
"""
import os
from datetime import datetime
from typing import cast

from airflow import models
from airflow.models.xcom_arg import XComArg
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
from airflow.providers.google.cloud.operators.automl import (
AutoMLCreateDatasetOperator,
Expand Down Expand Up @@ -74,7 +76,7 @@
task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION
)

dataset_id = create_dataset_task.output["dataset_id"]
dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id"))

import_dataset_task = AutoMLImportDataOperator(
task_id="import_dataset_task",
@@ -87,7 +89,7 @@

create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION)

model_id = create_model.output["model_id"]
model_id = cast(str, XComArg(create_model, key="model_id"))

delete_model_task = AutoMLDeleteModelOperator(
task_id="delete_model_task",
(another changed file; path not captured)
@@ -21,8 +21,10 @@
"""
import os
from datetime import datetime
from typing import cast

from airflow import models
from airflow.models.xcom_arg import XComArg
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
from airflow.providers.google.cloud.operators.automl import (
AutoMLCreateDatasetOperator,
Expand Down Expand Up @@ -71,7 +73,7 @@
task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION
)

dataset_id = create_dataset_task.output["dataset_id"]
dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id"))

import_dataset_task = AutoMLImportDataOperator(
task_id="import_dataset_task",
@@ -84,7 +86,7 @@

create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION)

model_id = create_model.output["model_id"]
model_id = cast(str, XComArg(create_model, key="model_id"))

delete_model_task = AutoMLDeleteModelOperator(
task_id="delete_model_task",
(another changed file; path not captured)
@@ -21,8 +21,10 @@
"""
import os
from datetime import datetime
from typing import cast

from airflow import models
from airflow.models.xcom_arg import XComArg
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
from airflow.providers.google.cloud.operators.automl import (
AutoMLCreateDatasetOperator,
Expand Down Expand Up @@ -72,7 +74,7 @@
task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION
)

dataset_id = create_dataset_task.output["dataset_id"]
dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id"))

import_dataset_task = AutoMLImportDataOperator(
task_id="import_dataset_task",
@@ -85,7 +87,7 @@

create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION)

model_id = create_model.output["model_id"]
model_id = cast(str, XComArg(create_model, key="model_id"))

delete_model_task = AutoMLDeleteModelOperator(
task_id="delete_model_task",
(remaining changed files not shown)
