diff --git a/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py b/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py index f1a39dd6a24cc..96cf63546fdbc 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py @@ -28,9 +28,12 @@ from collections.abc import Sequence from typing import TYPE_CHECKING, Any, Callable +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.pubsub import PubSubHook from airflow.providers.google.cloud.links.pubsub import PubSubSubscriptionLink, PubSubTopicLink from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator +from airflow.providers.google.cloud.triggers.pubsub import PubsubPullTrigger +from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.cloud.pubsub_v1.types import ( @@ -744,6 +747,9 @@ class PubSubPullOperator(GoogleCloudBaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param deferrable: If True, run the task in the deferrable mode. + :param poll_interval: Time (seconds) to wait between two consecutive calls to check the job. + The default is 300 seconds. """ template_fields: Sequence[str] = ( @@ -762,6 +768,8 @@ def __init__( messages_callback: Callable[[list[ReceivedMessage], Context], Any] | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, + deferrable: bool = False, + poll_interval: int = 300, **kwargs, ) -> None: super().__init__(**kwargs) @@ -772,8 +780,23 @@ def __init__( self.ack_messages = ack_messages self.messages_callback = messages_callback self.impersonation_chain = impersonation_chain + self.deferrable = deferrable + self.poll_interval = poll_interval def execute(self, context: Context) -> list: + if self.deferrable: + self.defer( + trigger=PubsubPullTrigger( + subscription=self.subscription, + project_id=self.project_id, + max_messages=self.max_messages, + ack_messages=self.ack_messages, + gcp_conn_id=self.gcp_conn_id, + poke_interval=self.poll_interval, + impersonation_chain=self.impersonation_chain, + ), + method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME, + ) hook = PubSubHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, @@ -799,6 +822,17 @@ def execute(self, context: Context) -> list: return ret + def execute_complete(self, context: Context, event: dict[str, Any]): + """If messages_callback is provided, execute it; otherwise, return immediately with trigger event message.""" + if event["status"] == "success": + self.log.info("Sensor pulls messages: %s", event["message"]) + messages_callback = self.messages_callback or self._default_message_callback + _return_value = messages_callback(event["message"], context) + return _return_value + + self.log.info("Sensor failed: %s", event["message"]) + raise AirflowException(event["message"]) + def _default_message_callback( self, pulled_messages: list[ReceivedMessage], diff --git a/providers/google/tests/provider_tests/google/cloud/operators/test_pubsub.py b/providers/google/tests/provider_tests/google/cloud/operators/test_pubsub.py index 9b6f0dce393c6..7a5ceab93b568 100644 --- a/providers/google/tests/provider_tests/google/cloud/operators/test_pubsub.py +++ b/providers/google/tests/provider_tests/google/cloud/operators/test_pubsub.py @@ -20,9 +20,11 @@ from typing import Any from unittest import mock +import pytest from google.api_core.gapic_v1.method import DEFAULT from google.cloud.pubsub_v1.types import ReceivedMessage +from airflow.exceptions import TaskDeferred from airflow.providers.google.cloud.operators.pubsub import ( PubSubCreateSubscriptionOperator, PubSubCreateTopicOperator, @@ -337,3 +339,20 @@ def messages_callback( messages_callback.assert_called_once() assert response == messages_callback_return_value + + @mock.patch("airflow.providers.google.cloud.operators.pubsub.PubSubHook") + def test_execute_deferred(self, mock_hook, create_task_instance_of_operator): + """ + Asserts that a task is deferred and a PubSubPullOperator will be fired + when the PubSubPullOperator is executed with deferrable=True. + """ + ti = create_task_instance_of_operator( + PubSubPullOperator, + dag_id="dag_id", + task_id=TASK_ID, + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + deferrable=True, + ) + with pytest.raises(TaskDeferred) as _: + ti.task.execute(mock.MagicMock())