diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index f2b5f5cceea66..c05b1dd62ecaa 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -569,6 +569,7 @@ def _xcom_pull( session: Session = NEW_SESSION, map_indexes: int | Iterable[int] | None = None, default: Any = None, + run_id: str | None = None, ) -> Any: """ Pull XComs that optionally meet certain criteria. @@ -588,6 +589,8 @@ def _xcom_pull( :param include_prior_dates: If False, only XComs from the current execution_date are returned. If *True*, XComs from previous dates are returned as well. + :param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id. + If *None* (default), the run_id of the calling task is used. When pulling one single task (``task_id`` is *None* or a str) without specifying ``map_indexes``, the return value is inferred from whether @@ -603,10 +606,12 @@ def _xcom_pull( """ if dag_id is None: dag_id = ti.dag_id + if run_id is None: + run_id = ti.run_id query = XCom.get_many( key=key, - run_id=ti.run_id, + run_id=run_id, dag_ids=dag_id, task_ids=task_ids, map_indexes=map_indexes, @@ -3472,6 +3477,7 @@ def xcom_pull( *, map_indexes: int | Iterable[int] | None = None, default: Any = None, + run_id: str | None = None, ) -> Any: """ Pull XComs that optionally meet certain criteria. @@ -3491,6 +3497,8 @@ def xcom_pull( :param include_prior_dates: If False, only XComs from the current execution_date are returned. If *True*, XComs from previous dates are returned as well. + :param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id. + If *None* (default), the run_id of the calling task is used. When pulling one single task (``task_id`` is *None* or a str) without specifying ``map_indexes``, the return value is inferred from whether @@ -3513,6 +3521,7 @@ def xcom_pull( session=session, map_indexes=map_indexes, default=default, + run_id=run_id, ) @provide_session diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index d749edce1f2ba..8f15005f43d13 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -1762,6 +1762,40 @@ def test_xcom_pull_different_execution_date(self, create_task_instance): # We *should* get a value using 'include_prior_dates' assert ti.xcom_pull(task_ids="test_xcom", key=key, include_prior_dates=True) == value + def test_xcom_pull_different_run_ids(self, create_task_instance): + """ + tests xcom fetch behavior w/different run ids + """ + key = "xcom_key" + task_id = "test_xcom" + diff_run_id = "diff_run_id" + same_run_id_value = "xcom_value_same_run_id" + diff_run_id_value = "xcom_value_different_run_id" + + ti_same_run_id = create_task_instance( + dag_id="test_xcom", + task_id=task_id, + ) + ti_same_run_id.run(mark_success=True) + ti_same_run_id.xcom_push(key=key, value=same_run_id_value) + + ti_diff_run_id = create_task_instance( + dag_id="test_xcom", + task_id=task_id, + run_id=diff_run_id, + ) + ti_diff_run_id.run(mark_success=True) + ti_diff_run_id.xcom_push(key=key, value=diff_run_id_value) + + assert ( + ti_same_run_id.xcom_pull(run_id=ti_same_run_id.dag_run.run_id, task_ids=task_id, key=key) + == same_run_id_value + ) + assert ( + ti_same_run_id.xcom_pull(run_id=ti_diff_run_id.dag_run.run_id, task_ids=task_id, key=key) + == diff_run_id_value + ) + def test_xcom_push_flag(self, dag_maker): """ Tests the option for Operators to push XComs