Skip to content

Commit

Permalink
Add deferrable mode to the PubSubPullOperator
Browse files Browse the repository at this point in the history
  • Loading branch information
molcay authored and Oleg Kachur committed Feb 7, 2025
1 parent 764bf20 commit 838495c
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Callable

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.pubsub import PubSubHook
from airflow.providers.google.cloud.links.pubsub import PubSubSubscriptionLink, PubSubTopicLink
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.providers.google.cloud.triggers.pubsub import PubsubPullTrigger
from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.pubsub_v1.types import (
Expand Down Expand Up @@ -744,6 +747,9 @@ class PubSubPullOperator(GoogleCloudBaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: If True, run the task in the deferrable mode.
:param poll_interval: Time (seconds) to wait between two consecutive attempts to pull messages
from the subscription when running in deferrable mode. The default is 300 seconds.
"""

template_fields: Sequence[str] = (
Expand All @@ -762,6 +768,8 @@ def __init__(
messages_callback: Callable[[list[ReceivedMessage], Context], Any] | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = False,
poll_interval: int = 300,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -772,8 +780,23 @@ def __init__(
self.ack_messages = ack_messages
self.messages_callback = messages_callback
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.poll_interval = poll_interval

def execute(self, context: Context) -> list:
if self.deferrable:
self.defer(
trigger=PubsubPullTrigger(
subscription=self.subscription,
project_id=self.project_id,
max_messages=self.max_messages,
ack_messages=self.ack_messages,
gcp_conn_id=self.gcp_conn_id,
poke_interval=self.poll_interval,
impersonation_chain=self.impersonation_chain,
),
method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
)
hook = PubSubHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
Expand All @@ -799,6 +822,17 @@ def execute(self, context: Context) -> list:

return ret

def execute_complete(self, context: Context, event: dict[str, Any]):
    """
    Resume the task once the deferred pull trigger has emitted an event.

    On a successful event the pulled messages are handed to ``messages_callback`` (or to
    the default callback when none was supplied) and the callback's result is returned.

    :param context: Task context, forwarded unchanged to the messages callback.
    :param event: Trigger payload; its ``status`` and ``message`` keys are read here.
    :return: Whatever the selected messages callback returns.
    :raises AirflowException: If the event status is anything other than ``"success"``.
    """
    # Guard clause: any non-success status is a failure, surfaced at ERROR level
    # so it is not lost among routine INFO output.
    if event["status"] != "success":
        self.log.error("Pulling messages failed: %s", event["message"])
        raise AirflowException(event["message"])

    self.log.info("Pulled messages: %s", event["message"])
    messages_callback = self.messages_callback or self._default_message_callback
    return messages_callback(event["message"], context)

def _default_message_callback(
self,
pulled_messages: list[ReceivedMessage],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
from typing import Any
from unittest import mock

import pytest
from google.api_core.gapic_v1.method import DEFAULT
from google.cloud.pubsub_v1.types import ReceivedMessage

from airflow.exceptions import TaskDeferred
from airflow.providers.google.cloud.operators.pubsub import (
PubSubCreateSubscriptionOperator,
PubSubCreateTopicOperator,
Expand Down Expand Up @@ -337,3 +339,20 @@ def messages_callback(
messages_callback.assert_called_once()

assert response == messages_callback_return_value

@mock.patch("airflow.providers.google.cloud.operators.pubsub.PubSubHook")
def test_execute_deferred(self, mock_hook, create_task_instance_of_operator):
    """
    Asserts that a task is deferred and a PubsubPullTrigger will be fired
    when the PubSubPullOperator is executed with deferrable=True.
    """
    # Imported locally so this test carries its own dependency on the trigger class.
    from airflow.providers.google.cloud.triggers.pubsub import PubsubPullTrigger

    ti = create_task_instance_of_operator(
        PubSubPullOperator,
        dag_id="dag_id",
        task_id=TASK_ID,
        project_id=TEST_PROJECT,
        subscription=TEST_SUBSCRIPTION,
        deferrable=True,
    )
    with pytest.raises(TaskDeferred) as exc_info:
        ti.task.execute(mock.MagicMock())

    # The deferral must hand over to the Pub/Sub pull trigger and resume in execute_complete.
    assert isinstance(exc_info.value.trigger, PubsubPullTrigger)
    assert exc_info.value.method_name == "execute_complete"
    # Deferring happens before any hook is built, so no synchronous pull is attempted.
    mock_hook.assert_not_called()

0 comments on commit 838495c

Please sign in to comment.