From 78520afcf0ff504722554d696fb560ab25a7a7d9 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 17 Jun 2024 13:30:06 +0200 Subject: [PATCH 1/6] refactor: Make sure xcoms work correctly in multi-threaded environment by taking the map_index into account --- .../microsoft/azure/operators/msgraph.py | 36 +++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/airflow/providers/microsoft/azure/operators/msgraph.py b/airflow/providers/microsoft/azure/operators/msgraph.py index 39ca32d2b6106..73be47fb50383 100644 --- a/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/airflow/providers/microsoft/azure/operators/msgraph.py @@ -178,8 +178,10 @@ def execute_complete( event["response"] = result try: - self.trigger_next_link(response, method_name=self.pull_execute_complete.__name__) + self.trigger_next_link(response=response, method_name=self.execute_complete.__name__) except TaskDeferred as exception: + self.results = self.pull_xcom(context=context) + self.log.debug("value: %s", result) self.append_result( result=result, append_result_as_list_if_absent=True, @@ -198,8 +200,6 @@ def append_result( result: Any, append_result_as_list_if_absent: bool = False, ): - self.log.debug("value: %s", result) - if isinstance(self.results, list): if isinstance(result, list): self.results.extend(result) @@ -214,30 +214,38 @@ def append_result( else: self.results = result - def push_xcom(self, context: Context, value) -> None: - self.log.debug("do_xcom_push: %s", self.do_xcom_push) - if self.do_xcom_push: - self.log.info("Pushing XCom with key '%s': %s", self.key, value) - self.xcom_push(context=context, key=self.key, value=value) + def xcom_key(self, context: Context) -> str: + map_index = context["ti"].map_index + return f"{self.key}_{map_index}" if map_index else self.key - def pull_execute_complete(self, context: Context, event: dict[Any, Any] | None = None) -> Any: - self.results = list( + def pull_xcom(self, context: Context) -> list: + key = self.xcom_key(context=context) + value = list( self.xcom_pull( context=context, task_ids=self.task_id, dag_id=self.dag_id, - key=self.key, + key=key, ) or [] ) + self.log.info( "Pulled XCom with task_id '%s' and dag_id '%s' and key '%s': %s", self.task_id, self.dag_id, - self.key, - self.results, + key, + value, ) - return self.execute_complete(context, event) + + return value + + def push_xcom(self, context: Context, value) -> None: + self.log.debug("do_xcom_push: %s", self.do_xcom_push) + if self.do_xcom_push: + key = self.xcom_key(context=context) + self.log.info("Pushing XCom with key '%s': %s", key, value) + self.xcom_push(context=context, key=key, value=value) @staticmethod def paginate(operator: MSGraphAsyncOperator, response: dict) -> tuple[Any, dict[str, Any] | None]: From f89ce4fb67062518f9ec3ebe85f592bdcca5f6f4 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 18 Jun 2024 16:35:50 +0200 Subject: [PATCH 2/6] fix: Test if map_index is not None instead of just doing an if otherwise we will have false positive as map_index is an integer --- airflow/providers/microsoft/azure/operators/msgraph.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/airflow/providers/microsoft/azure/operators/msgraph.py b/airflow/providers/microsoft/azure/operators/msgraph.py index 73be47fb50383..06883c68f2076 100644 --- a/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/airflow/providers/microsoft/azure/operators/msgraph.py @@ -180,8 +180,8 @@ def execute_complete( try: self.trigger_next_link(response=response, method_name=self.execute_complete.__name__) except TaskDeferred as exception: - self.results = self.pull_xcom(context=context) - self.log.debug("value: %s", result) + results = self.pull_xcom(context=context) + self.log.debug("result: %s", result) self.append_result( result=result, append_result_as_list_if_absent=True, @@ -216,7 +216,7 @@ def append_result( def xcom_key(self, context: Context) -> str: map_index = context["ti"].map_index - return f"{self.key}_{map_index}" if map_index else self.key + return f"{self.key}_{map_index}" if map_index is not None else self.key def pull_xcom(self, context: Context) -> list: key = self.xcom_key(context=context) @@ -227,7 +227,7 @@ def pull_xcom(self, context: Context) -> list: dag_id=self.dag_id, key=key, ) - or [] + or [] # noqa: W503 ) self.log.info( From 67a3f9d3558e7233a16ce1287507ecf0383ee9db Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 18 Jun 2024 17:04:42 +0200 Subject: [PATCH 3/6] refactor: Also test if operator is running in dynamic task mapping or not when pulling XCom --- airflow/providers/microsoft/azure/operators/msgraph.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/airflow/providers/microsoft/azure/operators/msgraph.py b/airflow/providers/microsoft/azure/operators/msgraph.py index 728f8f851205b..d4da8f03c9a5b 100644 --- a/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/airflow/providers/microsoft/azure/operators/msgraph.py @@ -181,7 +181,6 @@ def execute_complete( self.trigger_next_link(response=response, method_name=self.execute_complete.__name__) except TaskDeferred as exception: self.results = self.pull_xcom(context=context) - self.log.debug("result: %s", result) self.append_result( result=result, append_result_as_list_if_absent=True, @@ -190,7 +189,6 @@ def execute_complete( raise exception self.append_result(result=result) - self.log.debug("results: %s", self.results) return self.results return None @@ -200,6 +198,8 @@ def append_result( result: Any, append_result_as_list_if_absent: bool = False, ): + self.log.debug("result: %s", result) + if isinstance(self.results, list): if isinstance(result, list): self.results.extend(result) @@ -214,6 +214,8 @@ def append_result( else: self.results = result + self.log.debug("results: %s", self.results) + def xcom_key(self, context: Context) -> str: map_index = context["ti"].map_index return f"{self.key}_{map_index}" if map_index is not None else self.key @@ -227,7 +229,7 @@ def pull_xcom(self, context: Context) -> list: dag_id=self.dag_id, key=key, ) - or [] # noqa: W503 + or [] ) self.log.info( @@ -238,7 +240,7 @@ def pull_xcom(self, context: Context) -> list: value, ) - return value + return value[0] if value and context["ti"].map_index is not None else value def push_xcom(self, context: Context, value) -> None: self.log.debug("do_xcom_push: %s", self.do_xcom_push) From 716e669c5f06f669c582ec9d9fb6409ad83d2c60 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 18 Jun 2024 18:11:54 +0200 Subject: [PATCH 4/6] refactor: Try using xcom_pull directly from TaskInstance as there you can specify the map_index from which you want to pull --- .../microsoft/azure/operators/msgraph.py | 33 ++++++++----------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/airflow/providers/microsoft/azure/operators/msgraph.py b/airflow/providers/microsoft/azure/operators/msgraph.py index d4da8f03c9a5b..7628532b48731 100644 --- a/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/airflow/providers/microsoft/azure/operators/msgraph.py @@ -181,6 +181,7 @@ def execute_complete( self.trigger_next_link(response=response, method_name=self.execute_complete.__name__) except TaskDeferred as exception: self.results = self.pull_xcom(context=context) + self.log.debug("result: %s", self.results) self.append_result( result=result, append_result_as_list_if_absent=True, @@ -189,6 +190,7 @@ def execute_complete( raise exception self.append_result(result=result) + self.log.debug("results: %s", self.results) return self.results return None @@ -198,8 +200,6 @@ def append_result( result: Any, append_result_as_list_if_absent: bool = False, ): - self.log.debug("result: %s", result) - if isinstance(self.results, list): if isinstance(result, list): self.results.extend(result) @@ -214,40 +214,33 @@ def append_result( else: self.results = result - self.log.debug("results: %s", self.results) - - def xcom_key(self, context: Context) -> str: - map_index = context["ti"].map_index - return f"{self.key}_{map_index}" if map_index is not None else self.key - def pull_xcom(self, context: Context) -> list: - key = self.xcom_key(context=context) + map_index = context["ti"].map_index value = list( - self.xcom_pull( - context=context, + context["ti"].xcom_pull( + key=self.key, task_ids=self.task_id, dag_id=self.dag_id, - key=key, + map_indexes=map_index, ) - or [] + or [] # noqa: W503 ) self.log.info( - "Pulled XCom with task_id '%s' and dag_id '%s' and key '%s': %s", + "Pulled XCom with task_id '%s' and dag_id '%s' and key '%s' and map_index %s: %s", self.task_id, self.dag_id, - key, + self.key, + map_index, value, ) - - return value[0] if value and context["ti"].map_index is not None else value + return value def push_xcom(self, context: Context, value) -> None: self.log.debug("do_xcom_push: %s", self.do_xcom_push) if self.do_xcom_push: - key = self.xcom_key(context=context) - self.log.info("Pushing XCom with key '%s': %s", key, value) - self.xcom_push(context=context, key=key, value=value) + self.log.info("Pushing XCom with key '%s': %s", self.key, value) + self.xcom_push(context=context, key=self.key, value=value) @staticmethod def paginate(operator: MSGraphAsyncOperator, response: dict) -> tuple[Any, dict[str, Any] | None]: From ef380f73701ba8bc62ae8f98e7213a0bb3665ce9 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 19 Jun 2024 14:52:36 +0200 Subject: [PATCH 5/6] refactor: Refactored logging of xcom_pull --- .../microsoft/azure/operators/msgraph.py | 30 ++++++++++++------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/airflow/providers/microsoft/azure/operators/msgraph.py b/airflow/providers/microsoft/azure/operators/msgraph.py index 7628532b48731..cd387954737ab 100644 --- a/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/airflow/providers/microsoft/azure/operators/msgraph.py @@ -181,7 +181,6 @@ def execute_complete( self.trigger_next_link(response=response, method_name=self.execute_complete.__name__) except TaskDeferred as exception: self.results = self.pull_xcom(context=context) - self.log.debug("result: %s", self.results) self.append_result( result=result, append_result_as_list_if_absent=True, @@ -190,7 +189,6 @@ def execute_complete( raise exception self.append_result(result=result) - self.log.debug("results: %s", self.results) return self.results return None @@ -223,17 +221,27 @@ def pull_xcom(self, context: Context) -> list: dag_id=self.dag_id, map_indexes=map_index, ) - or [] # noqa: W503 + or [] ) - self.log.info( - "Pulled XCom with task_id '%s' and dag_id '%s' and key '%s' and map_index %s: %s", - self.task_id, - self.dag_id, - self.key, - map_index, - value, - ) + if map_index: + self.log.info( + "Pulled XCom with task_id '%s' and dag_id '%s' and key '%s' and map_index %s: %s", + self.task_id, + self.dag_id, + self.key, + map_index, + value, + ) + else: + self.log.info( + "Pulled XCom with task_id '%s' and dag_id '%s' and key '%s': %s", + self.task_id, + self.dag_id, + self.key, + value, + ) + return value def push_xcom(self, context: Context, value) -> None: From f9d351c0047095e1f0980d9f797a3e79c46fbd4e Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 19 Jun 2024 15:02:38 +0200 Subject: [PATCH 6/6] refactor: Now the mocked context also takes into account the map_index when running tests --- tests/providers/microsoft/conftest.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py index b2db8c44ba13a..ecd19d8865c68 100644 --- a/tests/providers/microsoft/conftest.py +++ b/tests/providers/microsoft/conftest.py @@ -143,6 +143,8 @@ def xcom_pull( map_indexes: Iterable[int] | int | None = None, default: Any | None = None, ) -> Any: + if map_indexes: + return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}") return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}") def xcom_push( @@ -152,7 +154,7 @@ def xcom_push( execution_date: datetime | None = None, session: Session = NEW_SESSION, ) -> None: - values[f"{self.task_id}_{self.dag_id}_{key}"] = value + values[f"{self.task_id}_{self.dag_id}_{key}_{self.map_index}"] = value values["ti"] = MockedTaskInstance(task=task)