diff --git a/UPDATING.md b/UPDATING.md index f2613cd3b935b..a2e98ae93ca37 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -40,6 +40,27 @@ assists users migrating to a new version. ## Airflow Master +### Changes to Google PubSub Operators, Hook and Sensor +In the `PubSubPublishOperator` and `PubSubHook.publsh` method the data field in a message should be bytestring (utf-8 encoded) rather than base64 encoded string. + +Due to the normalization of the parameters within GCP operators and hooks a parameters like `project` or `topic_project` +are deprecated and will be substituted by parameter `project_id`. +In `PubSubHook.create_subscription` hook method in the parameter `subscription_project` is replaced by `subscription_project_id`. +Template fields are updated accordingly and old ones may not work. + +It is required now to pass key-word only arguments to `PubSub` hook. + +These changes are not backward compatible. + +Affected components: + * airflow.gcp.hooks.pubsub.PubSubHook + * airflow.gcp.operators.pubsub.PubSubTopicCreateOperator + * airflow.gcp.operators.pubsub.PubSubSubscriptionCreateOperator + * airflow.gcp.operators.pubsub.PubSubTopicDeleteOperator + * airflow.gcp.operators.pubsub.PubSubSubscriptionDeleteOperator + * airflow.gcp.operators.pubsub.PubSubPublishOperator + * airflow.gcp.sensors.pubsub.PubSubPullSensor + ### Changes to `aws_default` Connection's default region The region of Airflow's default connection to AWS (`aws_default`) was previously diff --git a/airflow/gcp/example_dags/example_pubsub.py b/airflow/gcp/example_dags/example_pubsub.py index 40df801d5665a..cd909ec891ebb 100644 --- a/airflow/gcp/example_dags/example_pubsub.py +++ b/airflow/gcp/example_dags/example_pubsub.py @@ -25,7 +25,6 @@ import airflow from airflow import models - from airflow.gcp.operators.pubsub import ( PubSubTopicCreateOperator, PubSubSubscriptionDeleteOperator, @@ -38,7 +37,7 @@ GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") TOPIC = "PubSubTestTopic" -MESSAGE = {"attributes": {"name": "wrench", "mass": "1.3kg", "count": "3"}} +MESSAGE = {"data": b"Tool", "attributes": {"name": "wrench", "mass": "1.3kg", "count": "3"}} default_args = {"start_date": airflow.utils.dates.days_ago(1)} @@ -54,18 +53,18 @@ schedule_interval=None, # Override to match your needs ) as example_dag: create_topic = PubSubTopicCreateOperator( - task_id="create_topic", topic=TOPIC, project=GCP_PROJECT_ID + task_id="create_topic", topic=TOPIC, project_id=GCP_PROJECT_ID ) subscribe_task = PubSubSubscriptionCreateOperator( - task_id="subscribe_task", topic_project=GCP_PROJECT_ID, topic=TOPIC + task_id="subscribe_task", project_id=GCP_PROJECT_ID, topic=TOPIC ) subscription = "{{ task_instance.xcom_pull('subscribe_task') }}" pull_messages = PubSubPullSensor( task_id="pull_messages", ack_messages=True, - project=GCP_PROJECT_ID, + project_id=GCP_PROJECT_ID, subscription=subscription, ) @@ -75,19 +74,18 @@ publish_task = PubSubPublishOperator( task_id="publish_task", - project=GCP_PROJECT_ID, + project_id=GCP_PROJECT_ID, topic=TOPIC, messages=[MESSAGE, MESSAGE, MESSAGE], ) - unsubscribe_task = PubSubSubscriptionDeleteOperator( task_id="unsubscribe_task", - project=GCP_PROJECT_ID, + project_id=GCP_PROJECT_ID, subscription="{{ task_instance.xcom_pull('subscribe_task') }}", ) delete_topic = PubSubTopicDeleteOperator( - task_id="delete_topic", topic=TOPIC, project=GCP_PROJECT_ID + task_id="delete_topic", topic=TOPIC, project_id=GCP_PROJECT_ID ) create_topic >> subscribe_task >> publish_task diff --git a/airflow/gcp/hooks/pubsub.py b/airflow/gcp/hooks/pubsub.py index 9b865ecd5dc61..d8f612665b64e 100644 --- a/airflow/gcp/hooks/pubsub.py +++ b/airflow/gcp/hooks/pubsub.py @@ -19,29 +19,30 @@ """ This module contains a Google Pub/Sub Hook. """ -from typing import Any, List, Dict, Optional +import warnings +from base64 import b64decode +from typing import List, Dict, Optional, Sequence, Tuple, Union from uuid import uuid4 -from googleapiclient.discovery import build +from cached_property import cached_property +from google.api_core.retry import Retry +from google.api_core.exceptions import AlreadyExists, GoogleAPICallError +from google.cloud.exceptions import NotFound +from google.cloud.pubsub_v1 import PublisherClient, SubscriberClient +from google.cloud.pubsub_v1.types import Duration, PushConfig, MessageStoragePolicy from googleapiclient.errors import HttpError +from airflow.version import version from airflow.gcp.hooks.base import GoogleCloudBaseHook -def _format_subscription(project, subscription): - return 'projects/{}/subscriptions/{}'.format(project, subscription) - - -def _format_topic(project, topic): - return 'projects/{}/topics/{}'.format(project, topic) - - class PubSubException(Exception): """ Alias for Exception. """ +# noinspection PyAbstractClass class PubSubHook(GoogleCloudBaseHook): """ Hook for accessing Google Pub/Sub. @@ -52,254 +53,510 @@ class PubSubHook(GoogleCloudBaseHook): def __init__(self, gcp_conn_id: str = 'google_cloud_default', delegate_to: Optional[str] = None) -> None: super().__init__(gcp_conn_id, delegate_to=delegate_to) + self._client = None + + def get_conn(self) -> PublisherClient: + """ + Retrieves connection to Google Cloud Pub/Sub. - def get_conn(self) -> Any: + :return: Google Cloud Pub/Sub client object. + :rtype: google.cloud.pubsub_v1.PublisherClient """ - Returns a Pub/Sub service object. + if not self._client: + self._client = PublisherClient( + credentials=self._get_credentials(), + client_info=self.client_info + ) + return self._client - :rtype: googleapiclient.discovery.Resource + @cached_property + def subscriber_client(self) -> SubscriberClient: """ - http_authorized = self._authorize() - return build( - 'pubsub', 'v1', http=http_authorized, cache_discovery=False) + Creates SubscriberClient. - def publish(self, project: str, topic: str, messages: List[Dict]) -> None: + :return: Google Cloud Pub/Sub client object. + :rtype: google.cloud.pubsub_v1.SubscriberClient + """ + return SubscriberClient( + credentials=self._get_credentials(), + client_info=self.client_info + ) + + @GoogleCloudBaseHook.fallback_to_default_project_id + def publish( + self, + topic: str, + messages: List[Dict], + project_id: Optional[str] = None, + ) -> None: """ Publishes messages to a Pub/Sub topic. - :param project: the GCP project ID in which to publish - :type project: str :param topic: the Pub/Sub topic to which to publish; do not include the ``projects/{project}/topics/`` prefix. :type topic: str :param messages: messages to publish; if the data field in a - message is set, it should already be base64 encoded. + message is set, it should be a bytestring (utf-8 encoded) :type messages: list of PubSub messages; see http://cloud.google.com/pubsub/docs/reference/rest/v1/PubsubMessage + :param project_id: Optional, the GCP project ID in which to publish. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str """ - body = {'messages': messages} - full_topic = _format_topic(project, topic) - request = self.get_conn().projects().topics().publish( # pylint: disable=no-member - topic=full_topic, body=body) + assert project_id is not None + self._validate_messages(messages) + + publisher = self.get_conn() + topic_path = PublisherClient.topic_path(project_id, topic) # pylint: disable=no-member + + self.log.info("Publish %d messages to topic (path) %s", len(messages), topic_path) try: - request.execute(num_retries=self.num_retries) - except HttpError as e: - raise PubSubException( - 'Error publishing to topic {}'.format(full_topic), e) + for message in messages: + publisher.publish( + topic=topic_path, + data=message.get("data", b''), + **message.get('attributes', {}) + ) + except GoogleAPICallError as e: + raise PubSubException('Error publishing to topic {}'.format(topic_path), e) + + self.log.info("Published %d messages to topic (path) %s", len(messages), topic_path) + + @staticmethod + def _validate_messages(messages) -> None: + for message in messages: + # To warn about broken backward compatibility + # TODO: remove one day + if "data" in message and isinstance(message["data"], str): + try: + b64decode(message["data"]) + warnings.warn( + "The base 64 encoded string as 'data' field has been deprecated. " + "You should pass bytestring (utf-8 encoded).", DeprecationWarning, stacklevel=4 + ) + except ValueError: + pass - def create_topic(self, project: str, topic: str, fail_if_exists: bool = False) -> None: + if not isinstance(message, dict): + raise PubSubException("Wrong message type. Must be a dictionary.") + if "data" not in message and "attributes" not in message: + raise PubSubException("Wrong message. Dictionary must contain 'data' or 'attributes'.") + if "data" in message and not isinstance(message["data"], bytes): + raise PubSubException("Wrong message. 'data' must be send as a bytestring") + if ("data" not in message and "attributes" in message and not message["attributes"]) \ + or ("attributes" in message and not isinstance(message["attributes"], dict)): + raise PubSubException( + "Wrong message. If 'data' is not provided 'attributes' must be a non empty dictionary.") + + # pylint: disable=too-many-arguments + @GoogleCloudBaseHook.fallback_to_default_project_id + def create_topic( + self, + topic: str, + project_id: Optional[str] = None, + fail_if_exists: bool = False, + labels: Optional[Dict[str, str]] = None, + message_storage_policy: Union[Dict, MessageStoragePolicy] = None, + kms_key_name: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: """ Creates a Pub/Sub topic, if it does not already exist. - :param project: the GCP project ID in which to create - the topic - :type project: str :param topic: the Pub/Sub topic name to create; do not include the ``projects/{project}/topics/`` prefix. :type topic: str + :param project_id: Optional, the GCP project ID in which to create the topic + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str :param fail_if_exists: if set, raise an exception if the topic already exists :type fail_if_exists: bool + :param labels: Client-assigned labels; see + https://cloud.google.com/pubsub/docs/labels + :type labels: Dict[str, str] + :param message_storage_policy: Policy constraining the set + of Google Cloud Platform regions where messages published to + the topic may be stored. If not present, then no constraints + are in effect. + :type message_storage_policy: + Union[Dict, google.cloud.pubsub_v1.types.MessageStoragePolicy] + :param kms_key_name: The resource name of the Cloud KMS CryptoKey + to be used to protect access to messages published on this topic. + The expected format is + ``projects/*/locations/*/keyRings/*/cryptoKeys/*``. + :type kms_key_name: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]]] """ - service = self.get_conn() - full_topic = _format_topic(project, topic) + assert project_id is not None + publisher = self.get_conn() + topic_path = PublisherClient.topic_path(project_id, topic) # pylint: disable=no-member + + # Add airflow-version label to the topic + labels = labels or {} + labels['airflow-version'] = 'v' + version.replace('.', '-').replace('+', '-') + + self.log.info("Creating topic (path) %s", topic_path) try: - service.projects().topics().create( # pylint: disable=no-member - name=full_topic, body={}).execute(num_retries=self.num_retries) - except HttpError as e: - # Status code 409 indicates that the topic already exists. - if str(e.resp['status']) == '409': - message = 'Topic already exists: {}'.format(full_topic) - self.log.warning(message) - if fail_if_exists: - raise PubSubException(message) - else: - raise PubSubException( - 'Error creating topic {}'.format(full_topic), e) + # pylint: disable=no-member + publisher.create_topic( + name=topic_path, + labels=labels, + message_storage_policy=message_storage_policy, + kms_key_name=kms_key_name, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + except AlreadyExists: + self.log.warning('Topic already exists: %s', topic) + if fail_if_exists: + raise PubSubException('Topic already exists: {}'.format(topic)) + except GoogleAPICallError as e: + raise PubSubException('Error creating topic {}'.format(topic), e) + + self.log.info("Created topic (path) %s", topic_path) - def delete_topic(self, project: str, topic: str, fail_if_not_exists: bool = False) -> None: + @GoogleCloudBaseHook.fallback_to_default_project_id + def delete_topic( + self, + topic: str, + project_id: Optional[str] = None, + fail_if_not_exists: bool = False, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: """ Deletes a Pub/Sub topic if it exists. - :param project: the GCP project ID in which to delete the topic - :type project: str :param topic: the Pub/Sub topic name to delete; do not include the ``projects/{project}/topics/`` prefix. :type topic: str + :param project_id: Optional, the GCP project ID in which to delete the topic. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str :param fail_if_not_exists: if set, raise an exception if the topic does not exist :type fail_if_not_exists: bool + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]]] """ - service = self.get_conn() - full_topic = _format_topic(project, topic) + assert project_id is not None + publisher = self.get_conn() + topic_path = PublisherClient.topic_path(project_id, topic) # pylint: disable=no-member + + self.log.info("Deleting topic (path) %s", topic_path) try: - service.projects().topics().delete( # pylint: disable=no-member - topic=full_topic).execute(num_retries=self.num_retries) - except HttpError as e: - # Status code 409 indicates that the topic was not found - if str(e.resp['status']) == '404': - message = 'Topic does not exist: {}'.format(full_topic) - self.log.warning(message) - if fail_if_not_exists: - raise PubSubException(message) - else: - raise PubSubException( - 'Error deleting topic {}'.format(full_topic), e) + # pylint: disable=no-member + publisher.delete_topic( + topic=topic_path, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + except NotFound: + self.log.warning('Topic does not exist: %s', topic_path) + if fail_if_not_exists: + raise PubSubException('Topic does not exist: {}'.format(topic_path)) + except GoogleAPICallError as e: + raise PubSubException('Error deleting topic {}'.format(topic), e) + self.log.info("Deleted topic (path) %s", topic_path) + # pylint: disable=too-many-arguments + @GoogleCloudBaseHook.fallback_to_default_project_id def create_subscription( self, - topic_project: str, topic: str, + project_id: Optional[str] = None, subscription: Optional[str] = None, - subscription_project: Optional[str] = None, + subscription_project_id: Optional[str] = None, ack_deadline_secs: int = 10, fail_if_exists: bool = False, + push_config: Optional[Union[Dict, PushConfig]] = None, + retain_acked_messages: Optional[bool] = None, + message_retention_duration: Optional[Union[Dict, Duration]] = None, + labels: Optional[Dict[str, str]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> str: """ Creates a Pub/Sub subscription, if it does not already exist. - :param topic_project: the GCP project ID of the topic that the - subscription will be bound to. - :type topic_project: str :param topic: the Pub/Sub topic name that the subscription will be bound - to create; do not include the ``projects/{project}/subscriptions/`` - prefix. + to create; do not include the ``projects/{project}/subscriptions/`` prefix. :type topic: str + :param project_id: Optional, the GCP project ID of the topic that the subscription will be bound to. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str :param subscription: the Pub/Sub subscription name. If empty, a random name will be generated using the uuid module :type subscription: str - :param subscription_project: the GCP project ID where the subscription - will be created. If unspecified, ``topic_project`` will be used. - :type subscription_project: str + :param subscription_project_id: the GCP project ID where the subscription + will be created. If unspecified, ``project_id`` will be used. + :type subscription_project_id: str :param ack_deadline_secs: Number of seconds that a subscriber has to acknowledge each message pulled from the subscription :type ack_deadline_secs: int :param fail_if_exists: if set, raise an exception if the topic already exists :type fail_if_exists: bool + :param push_config: If push delivery is used with this subscription, + this field is used to configure it. An empty ``pushConfig`` signifies + that the subscriber will pull and ack messages using API methods. + :type push_config: Union[Dict, google.cloud.pubsub_v1.types.PushConfig] + :param retain_acked_messages: Indicates whether to retain acknowledged + messages. If true, then messages are not expunged from the subscription's + backlog, even if they are acknowledged, until they fall out of the + ``message_retention_duration`` window. This must be true if you would + like to Seek to a timestamp. + :type retain_acked_messages: bool + :param message_retention_duration: How long to retain unacknowledged messages + in the subscription's backlog, from the moment a message is published. If + ``retain_acked_messages`` is true, then this also configures the + retention of acknowledged messages, and thus configures how far back in + time a ``Seek`` can be done. Defaults to 7 days. Cannot be more than 7 + days or less than 10 minutes. + :type message_retention_duration: Union[Dict, google.cloud.pubsub_v1.types.Duration] + :param labels: Client-assigned labels; see + https://cloud.google.com/pubsub/docs/labels + :type labels: Dict[str, str] + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]]] :return: subscription name which will be the system-generated value if the ``subscription`` parameter is not supplied :rtype: str """ - service = self.get_conn() - full_topic = _format_topic(topic_project, topic) + assert project_id is not None + subscriber = self.subscriber_client + if not subscription: subscription = 'sub-{}'.format(uuid4()) - if not subscription_project: - subscription_project = topic_project - full_subscription = _format_subscription(subscription_project, - subscription) - body = { - 'topic': full_topic, - 'ackDeadlineSeconds': ack_deadline_secs - } + if not subscription_project_id: + subscription_project_id = project_id + + # Add airflow-version label to the subscription + labels = labels or {} + labels['airflow-version'] = 'v' + version.replace('.', '-').replace('+', '-') + + # pylint: disable=no-member + subscription_path = SubscriberClient.subscription_path(subscription_project_id, subscription) + topic_path = SubscriberClient.topic_path(project_id, topic) + + self.log.info("Creating subscription (path) %s for topic (path) %a", subscription_path, topic_path) try: - service.projects().subscriptions().create( # pylint: disable=no-member - name=full_subscription, body=body).execute(num_retries=self.num_retries) - except HttpError as e: - # Status code 409 indicates that the subscription already exists. - if str(e.resp['status']) == '409': - message = 'Subscription already exists: {}'.format( - full_subscription) - self.log.warning(message) - if fail_if_exists: - raise PubSubException(message) - else: - raise PubSubException( - 'Error creating subscription {}'.format(full_subscription), - e) + subscriber.create_subscription( + name=subscription_path, + topic=topic_path, + push_config=push_config, + ack_deadline_seconds=ack_deadline_secs, + retain_acked_messages=retain_acked_messages, + message_retention_duration=message_retention_duration, + labels=labels, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + except AlreadyExists: + self.log.warning('Subscription already exists: %s', subscription_path) + if fail_if_exists: + raise PubSubException('Subscription already exists: {}'.format(subscription_path)) + except GoogleAPICallError as e: + raise PubSubException('Error creating subscription {}'.format(subscription_path), e) + + self.log.info("Created subscription (path) %s for topic (path) %s", subscription_path, topic_path) return subscription - def delete_subscription(self, project: str, subscription: str, fail_if_not_exists: bool = False) -> None: + @GoogleCloudBaseHook.fallback_to_default_project_id + def delete_subscription( + self, + subscription: str, + project_id: Optional[str] = None, + fail_if_not_exists: bool = False, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: """ Deletes a Pub/Sub subscription, if it exists. - :param project: the GCP project ID where the subscription exists - :type project: str :param subscription: the Pub/Sub subscription name to delete; do not include the ``projects/{project}/subscriptions/`` prefix. + :param project_id: Optional, the GCP project ID where the subscription exists + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str :type subscription: str - :param fail_if_not_exists: if set, raise an exception if the topic - does not exist + :param fail_if_not_exists: if set, raise an exception if the topic does not exist :type fail_if_not_exists: bool + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]]] """ - service = self.get_conn() - full_subscription = _format_subscription(project, subscription) + assert project_id is not None + subscriber = self.subscriber_client + subscription_path = SubscriberClient.subscription_path(project_id, subscription) # noqa E501 # pylint: disable=no-member,line-too-long + + self.log.info("Deleting subscription (path) %s", subscription_path) try: - service.projects().subscriptions().delete( # pylint: disable=no-member - subscription=full_subscription).execute(num_retries=self.num_retries) - except HttpError as e: - # Status code 404 indicates that the subscription was not found - if str(e.resp['status']) == '404': - message = 'Subscription does not exist: {}'.format( - full_subscription) - self.log.warning(message) - if fail_if_not_exists: - raise PubSubException(message) - else: - raise PubSubException( - 'Error deleting subscription {}'.format(full_subscription), - e) + # pylint: disable=no-member + subscriber.delete_subscription( + subscription=subscription_path, + retry=retry, + timeout=timeout, + metadata=metadata + ) + except NotFound: + self.log.warning('Subscription does not exist: %s', subscription_path) + if fail_if_not_exists: + raise PubSubException('Subscription does not exist: {}'.format(subscription_path)) + except GoogleAPICallError as e: + raise PubSubException('Error deleting subscription {}'.format(subscription_path), e) + + self.log.info("Deleted subscription (path) %s", subscription_path) + + @GoogleCloudBaseHook.fallback_to_default_project_id def pull( - self, project: str, subscription: str, max_messages: int, return_immediately: bool = False + self, + subscription: str, + max_messages: int, + project_id: Optional[str] = None, + return_immediately: bool = False, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, ) -> List[Dict]: """ Pulls up to ``max_messages`` messages from Pub/Sub subscription. - :param project: the GCP project ID where the subscription exists - :type project: str :param subscription: the Pub/Sub subscription name to pull from; do not include the 'projects/{project}/topics/' prefix. :type subscription: str :param max_messages: The maximum number of messages to return from the Pub/Sub API. :type max_messages: int + :param project_id: Optional, the GCP project ID where the subscription exists. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str :param return_immediately: If set, the Pub/Sub API will immediately return if no messages are available. Otherwise, the request will block for an undisclosed, but bounded period of time :type return_immediately: bool + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]]] :return: A list of Pub/Sub ReceivedMessage objects each containing an ``ackId`` property and a ``message`` property, which includes the base64-encoded message content. See https://cloud.google.com/pubsub/docs/reference/rest/v1/projects.subscriptions/pull#ReceivedMessage """ - service = self.get_conn() - full_subscription = _format_subscription(project, subscription) - body = { - 'maxMessages': max_messages, - 'returnImmediately': return_immediately - } + assert project_id is not None + subscriber = self.subscriber_client + subscription_path = SubscriberClient.subscription_path(project_id, subscription) # noqa E501 # pylint: disable=no-member,line-too-long + + self.log.info("Pulling mex %d messages from subscription (path) %s", max_messages, subscription_path) try: - response = service.projects().subscriptions().pull( # pylint: disable=no-member - subscription=full_subscription, body=body).execute(num_retries=self.num_retries) - return response.get('receivedMessages', []) - except HttpError as e: - raise PubSubException( - 'Error pulling messages from subscription {}'.format( - full_subscription), e) + # pylint: disable=no-member + response = subscriber.pull( + subscription=subscription_path, + max_messages=max_messages, + return_immediately=return_immediately, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + result = getattr(response, 'received_messages', []) + self.log.info("Pulled %d messages from subscription (path) %s", len(result), subscription_path) + return result + except (HttpError, GoogleAPICallError) as e: + raise PubSubException('Error pulling messages from subscription {}'.format(subscription_path), e) - def acknowledge(self, project: str, subscription: str, ack_ids: List) -> None: + @GoogleCloudBaseHook.fallback_to_default_project_id + def acknowledge( + self, + subscription: str, + ack_ids: List[str], + project_id: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + ) -> None: """ - Pulls up to ``max_messages`` messages from Pub/Sub subscription. + Acknowledges the messages associated with the ``ack_ids`` from Pub/Sub subscription. - :param project: the GCP project name or ID in which to create - the topic - :type project: str :param subscription: the Pub/Sub subscription name to delete; do not include the 'projects/{project}/topics/' prefix. :type subscription: str :param ack_ids: List of ReceivedMessage ackIds from a previous pull response :type ack_ids: list + :param project_id: Optional, the GCP project name or ID in which to create the topic + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]]] """ - service = self.get_conn() - full_subscription = _format_subscription(project, subscription) + assert project_id is not None + subscriber = self.subscriber_client + subscription_path = SubscriberClient.subscription_path(project_id, subscription) # noqa E501 # pylint: disable=no-member,line-too-long + + self.log.info("Acknowledging %d ack_ids from subscription (path) %s", len(ack_ids), subscription_path) try: - service.projects().subscriptions().acknowledge( # pylint: disable=no-member - subscription=full_subscription, body={'ackIds': ack_ids} - ).execute(num_retries=self.num_retries) - except HttpError as e: + # pylint: disable=no-member + subscriber.acknowledge( + subscription=subscription_path, + ack_ids=ack_ids, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + except (HttpError, GoogleAPICallError) as e: raise PubSubException( 'Error acknowledging {} messages pulled from subscription {}' - .format(len(ack_ids), full_subscription), e) + .format(len(ack_ids), subscription_path), e) + + self.log.info("Acknowledged ack_ids from subscription (path) %s", subscription_path) diff --git a/airflow/gcp/operators/pubsub.py b/airflow/gcp/operators/pubsub.py index e9a1df3cddde1..67c9ee72f1d38 100644 --- a/airflow/gcp/operators/pubsub.py +++ b/airflow/gcp/operators/pubsub.py @@ -19,7 +19,11 @@ """ This module contains Google PubSub operators. """ -from typing import List, Optional +import warnings +from typing import List, Optional, Sequence, Tuple, Dict, Union + +from google.api_core.retry import Retry +from google.cloud.pubsub_v1.types import Duration, PushConfig, MessageStoragePolicy from airflow.gcp.hooks.pubsub import PubSubHook from airflow.models import BaseOperator @@ -56,8 +60,9 @@ class PubSubTopicCreateOperator(BaseOperator): Both ``project`` and ``topic`` are templated so you can use variables in them. - :param project: the GCP project ID where the topic will be created - :type project: str + :param project_id: Optional, the GCP project ID where the topic will be created. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str :param topic: the topic to create. Do not include the full topic path. In other words, instead of ``projects/{project}/topics/{topic}``, provide only @@ -70,34 +75,92 @@ class PubSubTopicCreateOperator(BaseOperator): For this to work, the service account making the request must have domain-wide delegation enabled. :type delegate_to: str + :param labels: Client-assigned labels; see + https://cloud.google.com/pubsub/docs/labels + :type labels: Dict[str, str] + :param message_storage_policy: Policy constraining the set + of Google Cloud Platform regions where messages published to + the topic may be stored. If not present, then no constraints + are in effect. + :type message_storage_policy: + Union[Dict, google.cloud.pubsub_v1.types.MessageStoragePolicy] + :param kms_key_name: The resource name of the Cloud KMS CryptoKey + to be used to protect access to messages published on this topic. + The expected format is + ``projects/*/locations/*/keyRings/*/cryptoKeys/*``. + :type kms_key_name: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]]] + :param project: (Deprecated) the GCP project ID where the topic will be created + :type project: str """ - template_fields = ['project', 'topic'] + template_fields = ['project_id', 'topic'] ui_color = '#0273d4' + # pylint: disable=too-many-arguments @apply_defaults def __init__( self, - project: str, topic: str, + project_id: Optional[str] = None, fail_if_exists: bool = False, gcp_conn_id: str = 'google_cloud_default', delegate_to: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + message_storage_policy: Union[Dict, MessageStoragePolicy] = None, + kms_key_name: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + project: Optional[str] = None, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.project = project + # To preserve backward compatibility + # TODO: remove one day + if project: + warnings.warn( + "The project parameter has been deprecated. You should pass " + "the project_id parameter.", DeprecationWarning, stacklevel=2) + project_id = project + + super().__init__(*args, **kwargs) + self.project_id = project_id self.topic = topic self.fail_if_exists = fail_if_exists self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to + self.labels = labels + self.message_storage_policy = message_storage_policy + self.kms_key_name = kms_key_name + self.retry = retry + self.timeout = timeout + self.metadata = metadata def execute(self, context): hook = PubSubHook(gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to) - hook.create_topic(self.project, self.topic, - fail_if_exists=self.fail_if_exists) + self.log.info("Creating topic %s", self.topic) + hook.create_topic( + project_id=self.project_id, + topic=self.topic, + fail_if_exists=self.fail_if_exists, + labels=self.labels, + message_storage_policy=self.message_storage_policy, + kms_key_name=self.kms_key_name, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata + ) + self.log.info("Created topic %s", self.topic) class PubSubSubscriptionCreateOperator(BaseOperator): @@ -147,8 +210,9 @@ class PubSubSubscriptionCreateOperator(BaseOperator): ``topic_project``, ``topic``, ``subscription``, and ``subscription`` are templated so you can use variables in them. - :param topic_project: the GCP project ID where the topic exists - :type topic_project: str + :param project_id: Optional, the GCP project ID where the topic exists. + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str :param topic: the topic to create. Do not include the full topic path. In other words, instead of ``projects/{project}/topics/{topic}``, provide only @@ -157,9 +221,9 @@ class PubSubSubscriptionCreateOperator(BaseOperator): :param subscription: the Pub/Sub subscription name. If empty, a random name will be generated using the uuid module :type subscription: str - :param subscription_project: the GCP project ID where the subscription + :param subscription_project_id: the GCP project ID where the subscription will be created. If empty, ``topic_project`` will be used. - :type subscription_project: str + :type subscription_project_id: str :param ack_deadline_secs: Number of seconds that a subscriber has to acknowledge each message pulled from the subscription :type ack_deadline_secs: int @@ -170,42 +234,121 @@ class PubSubSubscriptionCreateOperator(BaseOperator): For this to work, the service account making the request must have domain-wide delegation enabled. :type delegate_to: str + :param push_config: If push delivery is used with this subscription, + this field is used to configure it. An empty ``pushConfig`` signifies + that the subscriber will pull and ack messages using API methods. + :type push_config: Union[Dict, google.cloud.pubsub_v1.types.PushConfig] + :param retain_acked_messages: Indicates whether to retain acknowledged + messages. If true, then messages are not expunged from the subscription's + backlog, even if they are acknowledged, until they fall out of the + ``message_retention_duration`` window. This must be true if you would + like to Seek to a timestamp. + :type retain_acked_messages: bool + :param message_retention_duration: How long to retain unacknowledged messages + in the subscription's backlog, from the moment a message is published. If + ``retain_acked_messages`` is true, then this also configures the + retention of acknowledged messages, and thus configures how far back in + time a ``Seek`` can be done. Defaults to 7 days. Cannot be more than 7 + days or less than 10 minutes. + :type message_retention_duration: Union[Dict, google.cloud.pubsub_v1.types.Duration] + :param labels: Client-assigned labels; see + https://cloud.google.com/pubsub/docs/labels + :type labels: Dict[str, str] + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]]] + :param topic_project: (Deprecated) the GCP project ID where the topic exists + :type topic_project: str + :param subscription_project: (Deprecated) the GCP project ID where the subscription + will be created. If empty, ``topic_project`` will be used. + :type subscription_project: str """ - template_fields = ['topic_project', 'topic', 'subscription', - 'subscription_project'] + template_fields = ['project_id', 'topic', 'subscription', 'subscription_project_id'] ui_color = '#0273d4' + # pylint: disable=too-many-arguments @apply_defaults def __init__( self, - topic_project, topic: str, - subscription=None, - subscription_project=None, + project_id: Optional[str] = None, + subscription: Optional[str] = None, + subscription_project_id: Optional[str] = None, ack_deadline_secs: int = 10, fail_if_exists: bool = False, gcp_conn_id: str = 'google_cloud_default', delegate_to: Optional[str] = None, + push_config: Optional[Union[Dict, PushConfig]] = None, + retain_acked_messages: Optional[bool] = None, + message_retention_duration: Optional[Union[Dict, Duration]] = None, + labels: Optional[Dict[str, str]] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + topic_project: Optional[str] = None, + subscription_project: Optional[str] = None, *args, **kwargs) -> None: + + # To preserve backward compatibility + # TODO: remove one day + if topic_project: + warnings.warn( + "The topic_project parameter has been deprecated. You should pass " + "the project_id parameter.", DeprecationWarning, stacklevel=2) + project_id = topic_project + if subscription_project: + warnings.warn( + "The project_id parameter has been deprecated. You should pass " + "the subscription_project parameter.", DeprecationWarning, stacklevel=2) + subscription_project_id = subscription_project + super().__init__(*args, **kwargs) - self.topic_project = topic_project + self.project_id = project_id self.topic = topic self.subscription = subscription - self.subscription_project = subscription_project + self.subscription_project_id = subscription_project_id self.ack_deadline_secs = ack_deadline_secs self.fail_if_exists = fail_if_exists self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to + self.push_config = push_config + self.retain_acked_messages = retain_acked_messages + self.message_retention_duration = message_retention_duration + self.labels = labels + self.retry = retry + self.timeout = timeout + self.metadata = metadata def execute(self, context): hook = PubSubHook(gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to) - return hook.create_subscription( - self.topic_project, self.topic, self.subscription, - self.subscription_project, self.ack_deadline_secs, - self.fail_if_exists) + self.log.info("Creating subscription for topic %s", self.topic) + result = hook.create_subscription( + project_id=self.project_id, + topic=self.topic, + subscription=self.subscription, + subscription_project_id=self.subscription_project_id, + ack_deadline_secs=self.ack_deadline_secs, + fail_if_exists=self.fail_if_exists, + push_config=self.push_config, + retain_acked_messages=self.retain_acked_messages, + message_retention_duration=self.message_retention_duration, + labels=self.labels, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata + ) + + self.log.info("Created subscription for topic %s", self.topic) + return result class PubSubTopicDeleteOperator(BaseOperator): @@ -234,8 +377,9 @@ class PubSubTopicDeleteOperator(BaseOperator): Both ``project`` and ``topic`` are templated so you can use variables in them. - :param project: the GCP project ID in which to work (templated) - :type project: str + :param project_id: Optional, the GCP project ID in which to work (templated). + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str :param topic: the topic to delete. Do not include the full topic path. In other words, instead of ``projects/{project}/topics/{topic}``, provide only @@ -251,34 +395,68 @@ class PubSubTopicDeleteOperator(BaseOperator): For this to work, the service account making the request must have domain-wide delegation enabled. :type delegate_to: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]]] + :param project: (Deprecated) the GCP project ID where the topic will be created + :type project: str """ - template_fields = ['project', 'topic'] + template_fields = ['project_id', 'topic'] ui_color = '#cb4335' @apply_defaults def __init__( self, - project: str, topic: str, + project_id: Optional[str] = None, fail_if_not_exists=False, gcp_conn_id: str = 'google_cloud_default', delegate_to: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + project: Optional[str] = None, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.project = project + # To preserve backward compatibility + # TODO: remove one day + if project: + warnings.warn( + "The project parameter has been deprecated. You should pass " + "the project_id parameter.", DeprecationWarning, stacklevel=2) + project_id = project + + super().__init__(*args, **kwargs) + self.project_id = project_id self.topic = topic self.fail_if_not_exists = fail_if_not_exists self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to + self.retry = retry + self.timeout = timeout + self.metadata = metadata def execute(self, context): hook = PubSubHook(gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to) - hook.delete_topic(self.project, self.topic, - fail_if_not_exists=self.fail_if_not_exists) + self.log.info("Deleting topic %s", self.topic) + hook.delete_topic( + project_id=self.project_id, + topic=self.topic, + fail_if_not_exists=self.fail_if_not_exists, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata + ) + self.log.info("Deleted topic %s", self.topic) class PubSubSubscriptionDeleteOperator(BaseOperator): @@ -309,8 +487,9 @@ class PubSubSubscriptionDeleteOperator(BaseOperator): ``project``, and ``subscription`` are templated so you can use variables in them. - :param project: the GCP project ID in which to work (templated) - :type project: str + :param project_id: Optional, the GCP project ID in which to work (templated). + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str :param subscription: the subscription to delete. Do not include the full subscription path. In other words, instead of ``projects/{project}/subscription/{subscription}``, provide only @@ -326,34 +505,68 @@ class PubSubSubscriptionDeleteOperator(BaseOperator): For this to work, the service account making the request must have domain-wide delegation enabled. :type delegate_to: str + :param retry: (Optional) A retry object used to retry requests. + If None is specified, requests will not be retried. + :type retry: google.api_core.retry.Retry + :param timeout: (Optional) The amount of time, in seconds, to wait for the request + to complete. Note that if retry is specified, the timeout applies to each + individual attempt. + :type timeout: float + :param metadata: (Optional) Additional metadata that is provided to the method. + :type metadata: Sequence[Tuple[str, str]]] + :param project: (Deprecated) the GCP project ID where the topic will be created + :type project: str """ - template_fields = ['project', 'subscription'] + template_fields = ['project_id', 'subscription'] ui_color = '#cb4335' @apply_defaults def __init__( self, - project: str, subscription: str, + project_id: Optional[str] = None, fail_if_not_exists=False, gcp_conn_id: str = 'google_cloud_default', delegate_to: Optional[str] = None, + retry: Optional[Retry] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + project: Optional[str] = None, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.project = project + # To preserve backward compatibility + # TODO: remove one day + if project: + warnings.warn( + "The project parameter has been deprecated. You should pass " + "the project_id parameter.", DeprecationWarning, stacklevel=2) + project_id = project + + super().__init__(*args, **kwargs) + self.project_id = project_id self.subscription = subscription self.fail_if_not_exists = fail_if_not_exists self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to + self.retry = retry + self.timeout = timeout + self.metadata = metadata def execute(self, context): hook = PubSubHook(gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to) - hook.delete_subscription(self.project, self.subscription, - fail_if_not_exists=self.fail_if_not_exists) + self.log.info("Deleting subscription %s", self.subscription) + hook.delete_subscription( + project_id=self.project_id, + subscription=self.subscription, + fail_if_not_exists=self.fail_if_not_exists, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata + ) + self.log.info("Deleted subscription %s", self.subscription) class PubSubPublishOperator(BaseOperator): @@ -363,12 +576,10 @@ class PubSubPublishOperator(BaseOperator): in a single GCP project. If the topic does not exist, this task will fail. :: - from base64 import b64encode as b64e - - m1 = {'data': b64e('Hello, World!'), + m1 = {'data': b'Hello, World!', 'attributes': {'type': 'greeting'} } - m2 = {'data': b64e('Knock, knock')} + m2 = {'data': b'Knock, knock'} m3 = {'attributes': {'foo': ''}} t1 = PubSubPublishOperator( @@ -380,8 +591,9 @@ class PubSubPublishOperator(BaseOperator): ``project`` , ``topic``, and ``messages`` are templated so you can use variables in them. - :param project: the GCP project ID in which to work (templated) - :type project: str + :param project_id: Optional, the GCP project ID in which to work (templated). + If set to None or missing, the default project_id from the GCP connection is used. + :type project_id: str :param topic: the topic to which to publish. Do not include the full topic path. In other words, instead of ``projects/{project}/topics/{topic}``, provide only @@ -390,7 +602,7 @@ class PubSubPublishOperator(BaseOperator): :param messages: a list of messages to be published to the topic. Each message is a dict with one or more of the following keys-value mappings: - * 'data': a base64-encoded string + * 'data': a bytestring (utf-8 encoded) * 'attributes': {'key1': 'value1', ...} Each message must contain at least a non-empty 'data' value or an attribute dict with at least one key (templated). See @@ -403,29 +615,43 @@ class PubSubPublishOperator(BaseOperator): For this to work, the service account making the request must have domain-wide delegation enabled. :type delegate_to: str + :param project: (Deprecated) the GCP project ID where the topic will be created + :type project: str """ - template_fields = ['project', 'topic', 'messages'] + template_fields = ['project_id', 'topic', 'messages'] ui_color = '#0273d4' @apply_defaults def __init__( self, - project: str, topic: str, messages: List, + project_id: Optional[str] = None, gcp_conn_id: str = 'google_cloud_default', delegate_to: Optional[str] = None, + project: Optional[str] = None, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.gcp_conn_id = gcp_conn_id - self.delegate_to = delegate_to - self.project = project + # To preserve backward compatibility + # TODO: remove one day + if project: + warnings.warn( + "The project parameter has been deprecated. You should pass " + "the project_id parameter.", DeprecationWarning, stacklevel=2) + project_id = project + + super().__init__(*args, **kwargs) + self.project_id = project_id self.topic = topic self.messages = messages + self.gcp_conn_id = gcp_conn_id + self.delegate_to = delegate_to def execute(self, context): hook = PubSubHook(gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to) - hook.publish(self.project, self.topic, self.messages) + + self.log.info("Publishing to topic %s", self.topic) + hook.publish(project_id=self.project_id, topic=self.topic, messages=self.messages) + self.log.info("Published to topic %s", self.topic) diff --git a/airflow/gcp/sensors/pubsub.py b/airflow/gcp/sensors/pubsub.py index 45af7c901c485..9323eea95e905 100644 --- a/airflow/gcp/sensors/pubsub.py +++ b/airflow/gcp/sensors/pubsub.py @@ -19,8 +19,11 @@ """ This module contains a Google PubSub sensor. """ +import warnings from typing import Optional +from google.protobuf.json_format import MessageToDict + from airflow.gcp.hooks.pubsub import PubSubHook from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults @@ -63,26 +66,35 @@ class PubSubPullSensor(BaseSensorOperator): must have domain-wide delegation enabled. :type delegate_to: str """ - template_fields = ['project', 'subscription'] + template_fields = ['project_id', 'subscription'] ui_color = '#ff7f50' @apply_defaults def __init__( self, - project: str, + project_id: str, subscription: str, max_messages: int = 5, return_immediately: bool = False, ack_messages: bool = False, gcp_conn_id: str = 'google_cloud_default', delegate_to: Optional[str] = None, + project: Optional[str] = None, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + # To preserve backward compatibility + # TODO: remove one day + if project: + warnings.warn( + "The project parameter has been deprecated. You should pass " + "the project_id parameter.", DeprecationWarning, stacklevel=2) + project_id = project + + super().__init__(*args, **kwargs) self.gcp_conn_id = gcp_conn_id self.delegate_to = delegate_to - self.project = project + self.project_id = project_id self.subscription = subscription self.max_messages = max_messages self.return_immediately = return_immediately @@ -98,11 +110,16 @@ def execute(self, context): def poke(self, context): hook = PubSubHook(gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to) - self._messages = hook.pull( - self.project, self.subscription, self.max_messages, - self.return_immediately) + pulled_messages = hook.pull( + project_id=self.project_id, + subscription=self.subscription, + max_messages=self.max_messages, + return_immediately=self.return_immediately + ) + + self._messages = [MessageToDict(m) for m in pulled_messages] + if self._messages and self.ack_messages: - if self.ack_messages: - ack_ids = [m['ackId'] for m in self._messages if m.get('ackId')] - hook.acknowledge(self.project, self.subscription, ack_ids) + ack_ids = [m['ackId'] for m in self._messages if m.get('ackId')] + hook.acknowledge(project_id=self.project_id, subscription=self.subscription, ack_ids=ack_ids) return self._messages diff --git a/setup.py b/setup.py index 49e2b11739b31..e0fd073448c9a 100644 --- a/setup.py +++ b/setup.py @@ -200,6 +200,7 @@ def write_version(filename: str = os.path.join(*["airflow", "git_version"])): 'google-cloud-dlp>=0.11.0', 'google-cloud-kms>=1.2.1', 'google-cloud-language>=1.1.1', + 'google-cloud-pubsub==1.0.0', 'google-cloud-redis>=0.3.0', 'google-cloud-spanner>=1.10.0', 'google-cloud-speech>=0.36.3', diff --git a/tests/gcp/hooks/test_pubsub.py b/tests/gcp/hooks/test_pubsub.py index 664595e3a51f5..e4b985423319e 100644 --- a/tests/gcp/hooks/test_pubsub.py +++ b/tests/gcp/hooks/test_pubsub.py @@ -17,11 +17,14 @@ # specific language governing permissions and limitations # under the License. -from base64 import b64encode as b64e import unittest +from google.api_core.exceptions import AlreadyExists, GoogleAPICallError +from google.cloud.exceptions import NotFound from googleapiclient.errors import HttpError +from parameterized import parameterized +from airflow.version import version from airflow.gcp.hooks.pubsub import PubSubException, PubSubHook from tests.compat import mock @@ -35,15 +38,15 @@ TEST_UUID = 'abc123-xzy789' TEST_MESSAGES = [ { - 'data': b64e(b'Hello, World!'), + 'data': b'Hello, World!', 'attributes': {'type': 'greeting'} }, - {'data': b64e(b'Knock, knock')}, + {'data': b'Knock, knock'}, {'attributes': {'foo': ''}}] EXPANDED_TOPIC = 'projects/{}/topics/{}'.format(TEST_PROJECT, TEST_TOPIC) -EXPANDED_SUBSCRIPTION = 'projects/{}/subscriptions/{}'.format( - TEST_PROJECT, TEST_SUBSCRIPTION) +EXPANDED_SUBSCRIPTION = 'projects/{}/subscriptions/{}'.format(TEST_PROJECT, TEST_SUBSCRIPTION) +LABELS = {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')} def mock_init(self, gcp_conn_id, delegate_to=None): # pylint: disable=unused-argument @@ -56,257 +59,402 @@ def setUp(self): new=mock_init): self.pubsub_hook = PubSubHook(gcp_conn_id='test') - @mock.patch("airflow.gcp.hooks.pubsub.PubSubHook._authorize") - @mock.patch("airflow.gcp.hooks.pubsub.build") - def test_pubsub_client_creation(self, mock_build, mock_authorize): + @mock.patch("airflow.gcp.hooks.pubsub.PubSubHook.client_info", new_callable=mock.PropertyMock) + @mock.patch("airflow.gcp.hooks.pubsub.PubSubHook._get_credentials") + @mock.patch("airflow.gcp.hooks.pubsub.PublisherClient") + def test_publisher_client_creation(self, mock_client, mock_get_creds, mock_client_info): + self.assertIsNone(self.pubsub_hook._client) result = self.pubsub_hook.get_conn() - mock_build.assert_called_once_with( - 'pubsub', 'v1', http=mock_authorize.return_value, cache_discovery=False + mock_client.assert_called_once_with( + credentials=mock_get_creds.return_value, + client_info=mock_client_info.return_value ) - self.assertEqual(mock_build.return_value, result) + self.assertEqual(mock_client.return_value, result) + self.assertEqual(self.pubsub_hook._client, result) + + @mock.patch("airflow.gcp.hooks.pubsub.PubSubHook.client_info", new_callable=mock.PropertyMock) + @mock.patch("airflow.gcp.hooks.pubsub.PubSubHook._get_credentials") + @mock.patch("airflow.gcp.hooks.pubsub.SubscriberClient") + def test_subscriber_client_creation(self, mock_client, mock_get_creds, mock_client_info): + self.assertIsNone(self.pubsub_hook._client) + result = self.pubsub_hook.subscriber_client + mock_client.assert_called_once_with( + credentials=mock_get_creds.return_value, + client_info=mock_client_info.return_value + ) + self.assertEqual(mock_client.return_value, result) @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) def test_create_nonexistent_topic(self, mock_service): - self.pubsub_hook.create_topic(TEST_PROJECT, TEST_TOPIC) - - create_method = (mock_service.return_value.projects.return_value.topics - .return_value.create) - create_method.assert_called_once_with(body={}, name=EXPANDED_TOPIC) - create_method.return_value.execute.assert_called_once_with(num_retries=mock.ANY) + create_method = mock_service.return_value.create_topic + self.pubsub_hook.create_topic(project_id=TEST_PROJECT, topic=TEST_TOPIC) + create_method.assert_called_once_with( + name=EXPANDED_TOPIC, + labels=LABELS, + message_storage_policy=None, + kms_key_name=None, + retry=None, + timeout=None, + metadata=None + ) @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) def test_delete_topic(self, mock_service): - self.pubsub_hook.delete_topic(TEST_PROJECT, TEST_TOPIC) - - delete_method = (mock_service.return_value.projects.return_value.topics - .return_value.delete) - delete_method.assert_called_once_with(topic=EXPANDED_TOPIC) - delete_method.return_value.execute.assert_called_once_with(num_retries=mock.ANY) + delete_method = mock_service.return_value.delete_topic + self.pubsub_hook.delete_topic(project_id=TEST_PROJECT, topic=TEST_TOPIC) + delete_method.assert_called_once_with( + topic=EXPANDED_TOPIC, + retry=None, + timeout=None, + metadata=None + ) @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) def test_delete_nonexisting_topic_failifnotexists(self, mock_service): - (mock_service.return_value.projects.return_value.topics - .return_value.delete.return_value.execute.side_effect) = HttpError( - resp={'status': '404'}, content=EMPTY_CONTENT) - + mock_service.return_value.delete_topic.side_effect = NotFound( + 'Topic does not exists: %s' % EXPANDED_TOPIC + ) with self.assertRaises(PubSubException) as e: - self.pubsub_hook.delete_topic(TEST_PROJECT, TEST_TOPIC, True) + self.pubsub_hook.delete_topic(project_id=TEST_PROJECT, topic=TEST_TOPIC, fail_if_not_exists=True) - self.assertEqual(str(e.exception), - 'Topic does not exist: %s' % EXPANDED_TOPIC) + self.assertEqual(str(e.exception), 'Topic does not exist: %s' % EXPANDED_TOPIC) @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) - def test_create_preexisting_topic_failifexists(self, mock_service): - (mock_service.return_value.projects.return_value.topics.return_value - .create.return_value.execute.side_effect) = HttpError( - resp={'status': '409'}, content=EMPTY_CONTENT) + def test_delete_topic_api_call_error(self, mock_service): + mock_service.return_value.delete_topic.side_effect = GoogleAPICallError( + 'Error deleting topic: %s' % EXPANDED_TOPIC + ) + with self.assertRaises(PubSubException): + self.pubsub_hook.delete_topic(project_id=TEST_PROJECT, topic=TEST_TOPIC, fail_if_not_exists=True) + @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) + def test_create_preexisting_topic_failifexists(self, mock_service): + mock_service.return_value.create_topic.side_effect = AlreadyExists( + 'Topic already exists: %s' % TEST_TOPIC + ) with self.assertRaises(PubSubException) as e: - self.pubsub_hook.create_topic(TEST_PROJECT, TEST_TOPIC, True) - self.assertEqual(str(e.exception), - 'Topic already exists: %s' % EXPANDED_TOPIC) + self.pubsub_hook.create_topic(project_id=TEST_PROJECT, topic=TEST_TOPIC, fail_if_exists=True) + self.assertEqual(str(e.exception), 'Topic already exists: %s' % TEST_TOPIC) @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) def test_create_preexisting_topic_nofailifexists(self, mock_service): - (mock_service.return_value.projects.return_value.topics.return_value - .get.return_value.execute.side_effect) = HttpError( - resp={'status': '409'}, content=EMPTY_CONTENT) - - self.pubsub_hook.create_topic(TEST_PROJECT, TEST_TOPIC) + mock_service.return_value.create_topic.side_effect = AlreadyExists( + 'Topic already exists: %s' % EXPANDED_TOPIC + ) + self.pubsub_hook.create_topic(project_id=TEST_PROJECT, topic=TEST_TOPIC) @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) + def test_create_topic_api_call_error(self, mock_service): + mock_service.return_value.create_topic.side_effect = GoogleAPICallError( + 'Error creating topic: %s' % TEST_TOPIC + ) + with self.assertRaises(PubSubException): + self.pubsub_hook.create_topic(project_id=TEST_PROJECT, topic=TEST_TOPIC, fail_if_exists=True) + + @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) def test_create_nonexistent_subscription(self, mock_service): + create_method = mock_service.create_subscription + response = self.pubsub_hook.create_subscription( - TEST_PROJECT, TEST_TOPIC, TEST_SUBSCRIPTION) - - create_method = ( - mock_service.return_value.projects.return_value.subscriptions. - return_value.create) - expected_body = { - 'topic': EXPANDED_TOPIC, - 'ackDeadlineSeconds': 10 - } - create_method.assert_called_once_with(name=EXPANDED_SUBSCRIPTION, body=expected_body) - create_method.return_value.execute.assert_called_once_with(num_retries=mock.ANY) + project_id=TEST_PROJECT, topic=TEST_TOPIC, subscription=TEST_SUBSCRIPTION + ) + create_method.assert_called_once_with( + name=EXPANDED_SUBSCRIPTION, + topic=EXPANDED_TOPIC, + push_config=None, + ack_deadline_seconds=10, + retain_acked_messages=None, + message_retention_duration=None, + labels=LABELS, + retry=None, + timeout=None, + metadata=None, + ) self.assertEqual(TEST_SUBSCRIPTION, response) - @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) + @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) def test_create_subscription_different_project_topic(self, mock_service): + create_method = mock_service.create_subscription response = self.pubsub_hook.create_subscription( - TEST_PROJECT, TEST_TOPIC, TEST_SUBSCRIPTION, 'a-different-project') - - create_method = ( - mock_service.return_value.projects.return_value.subscriptions. - return_value.create) - - expected_subscription = 'projects/%s/subscriptions/%s' % ( - 'a-different-project', TEST_SUBSCRIPTION) - expected_body = { - 'topic': EXPANDED_TOPIC, - 'ackDeadlineSeconds': 10 - } - create_method.assert_called_once_with(name=expected_subscription, body=expected_body) - create_method.return_value.execute.assert_called_once_with(num_retries=mock.ANY) + project_id=TEST_PROJECT, + topic=TEST_TOPIC, + subscription=TEST_SUBSCRIPTION, + subscription_project_id='a-different-project' + ) + expected_subscription = 'projects/{}/subscriptions/{}'.format( + 'a-different-project', TEST_SUBSCRIPTION + ) + create_method.assert_called_once_with( + name=expected_subscription, + topic=EXPANDED_TOPIC, + push_config=None, + ack_deadline_seconds=10, + retain_acked_messages=None, + message_retention_duration=None, + labels=LABELS, + retry=None, + timeout=None, + metadata=None, + ) + self.assertEqual(TEST_SUBSCRIPTION, response) - @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) + @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) def test_delete_subscription(self, mock_service): - self.pubsub_hook.delete_subscription(TEST_PROJECT, TEST_SUBSCRIPTION) - - delete_method = (mock_service.return_value.projects - .return_value.subscriptions.return_value.delete) - delete_method.assert_called_once_with(subscription=EXPANDED_SUBSCRIPTION) - delete_method.return_value.execute.assert_called_once_with(num_retries=mock.ANY) - - @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) - def test_delete_nonexisting_subscription_failifnotexists(self, - mock_service): - (mock_service.return_value.projects.return_value.subscriptions. - return_value.delete.return_value.execute.side_effect) = HttpError( - resp={'status': '404'}, content=EMPTY_CONTENT) + self.pubsub_hook.delete_subscription(project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION) + delete_method = mock_service.delete_subscription + delete_method.assert_called_once_with( + subscription=EXPANDED_SUBSCRIPTION, + retry=None, + timeout=None, + metadata=None + ) + @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) + def test_delete_nonexisting_subscription_failifnotexists(self, mock_service): + mock_service.delete_subscription.side_effect = NotFound( + 'Subscription does not exists: %s' % EXPANDED_SUBSCRIPTION + ) with self.assertRaises(PubSubException) as e: self.pubsub_hook.delete_subscription( - TEST_PROJECT, TEST_SUBSCRIPTION, fail_if_not_exists=True) - - self.assertEqual(str(e.exception), - 'Subscription does not exist: %s' % - EXPANDED_SUBSCRIPTION) - - @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) - @mock.patch(PUBSUB_STRING.format('uuid4'), - new_callable=mock.Mock(return_value=lambda: TEST_UUID)) - def test_create_subscription_without_name(self, mock_uuid, mock_service): # noqa # pylint: disable=unused-argument,line-too-long - response = self.pubsub_hook.create_subscription(TEST_PROJECT, - TEST_TOPIC) - create_method = ( - mock_service.return_value.projects.return_value.subscriptions. - return_value.create) - expected_body = { - 'topic': EXPANDED_TOPIC, - 'ackDeadlineSeconds': 10 - } - expected_name = EXPANDED_SUBSCRIPTION.replace( - TEST_SUBSCRIPTION, 'sub-%s' % TEST_UUID) - create_method.assert_called_once_with(name=expected_name, body=expected_body) - create_method.return_value.execute.assert_called_once_with(num_retries=mock.ANY) + project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, fail_if_not_exists=True + ) + self.assertEqual(str(e.exception), 'Subscription does not exist: %s' % EXPANDED_SUBSCRIPTION) + + @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) + def test_delete_subscription_api_call_error(self, mock_service): + mock_service.delete_subscription.side_effect = GoogleAPICallError( + 'Error deleting subscription %s' % EXPANDED_SUBSCRIPTION + ) + with self.assertRaises(PubSubException): + self.pubsub_hook.delete_subscription( + project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, fail_if_not_exists=True + ) + + @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) + @mock.patch(PUBSUB_STRING.format('uuid4'), new_callable=mock.Mock(return_value=lambda: TEST_UUID)) + def test_create_subscription_without_subscription_name(self, mock_uuid, mock_service): # noqa # pylint: disable=unused-argument,line-too-long + create_method = mock_service.create_subscription + expected_name = EXPANDED_SUBSCRIPTION.replace(TEST_SUBSCRIPTION, 'sub-%s' % TEST_UUID) + + response = self.pubsub_hook.create_subscription(project_id=TEST_PROJECT, topic=TEST_TOPIC) + create_method.assert_called_once_with( + name=expected_name, + topic=EXPANDED_TOPIC, + push_config=None, + ack_deadline_seconds=10, + retain_acked_messages=None, + message_retention_duration=None, + labels=LABELS, + retry=None, + timeout=None, + metadata=None, + ) self.assertEqual('sub-%s' % TEST_UUID, response) - @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) + @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) def test_create_subscription_with_ack_deadline(self, mock_service): + create_method = mock_service.create_subscription + response = self.pubsub_hook.create_subscription( - TEST_PROJECT, TEST_TOPIC, TEST_SUBSCRIPTION, ack_deadline_secs=30) - - create_method = ( - mock_service.return_value.projects.return_value.subscriptions. - return_value.create) - expected_body = { - 'topic': EXPANDED_TOPIC, - 'ackDeadlineSeconds': 30 - } - create_method.assert_called_once_with(name=EXPANDED_SUBSCRIPTION, body=expected_body) - create_method.return_value.execute.assert_called_once_with(num_retries=mock.ANY) + project_id=TEST_PROJECT, topic=TEST_TOPIC, subscription=TEST_SUBSCRIPTION, ack_deadline_secs=30 + ) + create_method.assert_called_once_with( + name=EXPANDED_SUBSCRIPTION, + topic=EXPANDED_TOPIC, + push_config=None, + ack_deadline_seconds=30, + retain_acked_messages=None, + message_retention_duration=None, + labels=LABELS, + retry=None, + timeout=None, + metadata=None, + ) self.assertEqual(TEST_SUBSCRIPTION, response) - @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) + @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) def test_create_subscription_failifexists(self, mock_service): - (mock_service.return_value.projects.return_value. - subscriptions.return_value.create.return_value - .execute.side_effect) = HttpError(resp={'status': '409'}, - content=EMPTY_CONTENT) - + mock_service.create_subscription.side_effect = AlreadyExists( + 'Subscription already exists: %s' % EXPANDED_SUBSCRIPTION + ) with self.assertRaises(PubSubException) as e: self.pubsub_hook.create_subscription( - TEST_PROJECT, TEST_TOPIC, TEST_SUBSCRIPTION, - fail_if_exists=True) - - self.assertEqual(str(e.exception), - 'Subscription already exists: %s' % - EXPANDED_SUBSCRIPTION) + project_id=TEST_PROJECT, topic=TEST_TOPIC, subscription=TEST_SUBSCRIPTION, fail_if_exists=True + ) + self.assertEqual(str(e.exception), 'Subscription already exists: %s' % EXPANDED_SUBSCRIPTION) + + @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) + def test_create_subscription_api_call_error(self, mock_service): + mock_service.create_subscription.side_effect = GoogleAPICallError( + 'Error creating subscription %s' % EXPANDED_SUBSCRIPTION + ) + with self.assertRaises(PubSubException): + self.pubsub_hook.create_subscription( + project_id=TEST_PROJECT, topic=TEST_TOPIC, subscription=TEST_SUBSCRIPTION, fail_if_exists=True + ) - @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) + @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) def test_create_subscription_nofailifexists(self, mock_service): - (mock_service.return_value.projects.return_value.topics.return_value - .get.return_value.execute.side_effect) = HttpError( - resp={'status': '409'}, content=EMPTY_CONTENT) - + mock_service.create_subscription.side_effect = AlreadyExists( + 'Subscription already exists: %s' % EXPANDED_SUBSCRIPTION + ) response = self.pubsub_hook.create_subscription( - TEST_PROJECT, TEST_TOPIC, TEST_SUBSCRIPTION + project_id=TEST_PROJECT, topic=TEST_TOPIC, subscription=TEST_SUBSCRIPTION ) self.assertEqual(TEST_SUBSCRIPTION, response) @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) def test_publish(self, mock_service): - self.pubsub_hook.publish(TEST_PROJECT, TEST_TOPIC, TEST_MESSAGES) + publish_method = mock_service.return_value.publish - publish_method = (mock_service.return_value.projects.return_value - .topics.return_value.publish) - publish_method.assert_called_once_with( - topic=EXPANDED_TOPIC, body={'messages': TEST_MESSAGES}) + self.pubsub_hook.publish(project_id=TEST_PROJECT, topic=TEST_TOPIC, messages=TEST_MESSAGES) + calls = [ + mock.call(topic=EXPANDED_TOPIC, data=message.get("data", b''), **message.get('attributes', {})) + for message in TEST_MESSAGES + ] + publish_method.has_calls(calls) @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) + def test_publish_api_call_error(self, mock_service): + publish_method = mock_service.return_value.publish + publish_method.side_effect = GoogleAPICallError( + 'Error publishing to topic {}'.format(EXPANDED_SUBSCRIPTION) + ) + + with self.assertRaises(PubSubException): + self.pubsub_hook.publish(project_id=TEST_PROJECT, topic=TEST_TOPIC, messages=TEST_MESSAGES) + + @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) def test_pull(self, mock_service): - pull_method = (mock_service.return_value.projects.return_value - .subscriptions.return_value.pull) + pull_method = mock_service.pull pulled_messages = [] for i, msg in enumerate(TEST_MESSAGES): pulled_messages.append({'ackId': i, 'message': msg}) - pull_method.return_value.execute.return_value = { - 'receivedMessages': pulled_messages} + pull_method.return_value.received_messages = pulled_messages - response = self.pubsub_hook.pull(TEST_PROJECT, TEST_SUBSCRIPTION, 10) + response = self.pubsub_hook.pull( + project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, max_messages=10 + ) pull_method.assert_called_once_with( subscription=EXPANDED_SUBSCRIPTION, - body={'maxMessages': 10, 'returnImmediately': False}) + max_messages=10, + return_immediately=False, + retry=None, + timeout=None, + metadata=None, + ) self.assertEqual(pulled_messages, response) - @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) + @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) def test_pull_no_messages(self, mock_service): - pull_method = (mock_service.return_value.projects.return_value - .subscriptions.return_value.pull) - pull_method.return_value.execute.return_value = { - 'receivedMessages': []} + pull_method = mock_service.pull + pull_method.return_value.received_messages = [] - response = self.pubsub_hook.pull(TEST_PROJECT, TEST_SUBSCRIPTION, 10) + response = self.pubsub_hook.pull( + project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, max_messages=10 + ) pull_method.assert_called_once_with( subscription=EXPANDED_SUBSCRIPTION, - body={'maxMessages': 10, 'returnImmediately': False}) + max_messages=10, + return_immediately=False, + retry=None, + timeout=None, + metadata=None, + ) self.assertListEqual([], response) - @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) - def test_pull_fails_on_exception(self, mock_service): - pull_method = (mock_service.return_value.projects.return_value - .subscriptions.return_value.pull) - pull_method.return_value.execute.side_effect = HttpError( - resp={'status': '404'}, content=EMPTY_CONTENT) - - with self.assertRaises(Exception): - self.pubsub_hook.pull(TEST_PROJECT, TEST_SUBSCRIPTION, 10) + @parameterized.expand([ + (exception, ) for exception in [ + HttpError(resp={'status': '404'}, content=EMPTY_CONTENT), + GoogleAPICallError("API Call Error") + ] + ]) + @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) + def test_pull_fails_on_exception(self, exception, mock_service): + pull_method = mock_service.pull + pull_method.side_effect = exception + + with self.assertRaises(PubSubException): + self.pubsub_hook.pull(project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, max_messages=10) pull_method.assert_called_once_with( subscription=EXPANDED_SUBSCRIPTION, - body={'maxMessages': 10, 'returnImmediately': False}) - - @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) + max_messages=10, + return_immediately=False, + retry=None, + timeout=None, + metadata=None, + ) + + @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) def test_acknowledge(self, mock_service): - ack_method = (mock_service.return_value.projects.return_value - .subscriptions.return_value.acknowledge) + ack_method = mock_service.acknowledge + self.pubsub_hook.acknowledge( - TEST_PROJECT, TEST_SUBSCRIPTION, ['1', '2', '3']) + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + ack_ids=['1', '2', '3'] + ) ack_method.assert_called_once_with( subscription=EXPANDED_SUBSCRIPTION, - body={'ackIds': ['1', '2', '3']}) - - @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn')) - def test_acknowledge_fails_on_exception(self, mock_service): - ack_method = (mock_service.return_value.projects.return_value - .subscriptions.return_value.acknowledge) - ack_method.return_value.execute.side_effect = HttpError( - resp={'status': '404'}, content=EMPTY_CONTENT) + ack_ids=['1', '2', '3'], + retry=None, + timeout=None, + metadata=None + ) - with self.assertRaises(Exception) as e: + @parameterized.expand([ + (exception, ) for exception in [ + HttpError(resp={'status': '404'}, content=EMPTY_CONTENT), + GoogleAPICallError("API Call Error") + ] + ]) + @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client')) + def test_acknowledge_fails_on_exception(self, exception, mock_service): + ack_method = mock_service.acknowledge + ack_method.side_effect = exception + + with self.assertRaises(PubSubException): self.pubsub_hook.acknowledge( - TEST_PROJECT, TEST_SUBSCRIPTION, ['1', '2', '3']) + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + ack_ids=['1', '2', '3'] + ) ack_method.assert_called_once_with( subscription=EXPANDED_SUBSCRIPTION, - body={'ackIds': ['1', '2', '3']}) - print(e) + ack_ids=['1', '2', '3'], + retry=None, + timeout=None, + metadata=None + ) + + @parameterized.expand([ + (messages, ) for messages in [ + [{"data": b'test'}], + [{"data": b''}], + [{"data": b'test', "attributes": {"weight": "100kg"}}], + [{"data": b'', "attributes": {"weight": "100kg"}}], + [{"attributes": {"weight": "100kg"}}], + ] + ]) + def test_messages_validation_positive(self, messages): + PubSubHook._validate_messages(messages) + + @parameterized.expand([ + ([("wrong type", )], "Wrong message type. Must be a dictionary."), + ([{"wrong_key": b'test'}], "Wrong message. Dictionary must contain 'data' or 'attributes'."), + ([{"data": 'wrong string'}], "Wrong message. 'data' must be send as a bytestring"), + ([{"data": None}], "Wrong message. 'data' must be send as a bytestring"), + ( + [{"attributes": None}], + "Wrong message. If 'data' is not provided 'attributes' must be a non empty dictionary." + ), + ( + [{"attributes": "wrong string"}], + "Wrong message. If 'data' is not provided 'attributes' must be a non empty dictionary." + ) + ]) + def test_messages_validation_negative(self, messages, error_message): + with self.assertRaises(PubSubException) as e: + PubSubHook._validate_messages(messages) + self.assertEqual(str(e.exception), error_message) diff --git a/tests/gcp/operators/test_pubsub.py b/tests/gcp/operators/test_pubsub.py index 065e7aa8b1756..4d0ad33950cb6 100644 --- a/tests/gcp/operators/test_pubsub.py +++ b/tests/gcp/operators/test_pubsub.py @@ -17,7 +17,6 @@ # specific language governing permissions and limitations # under the License. -from base64 import b64encode as b64e import unittest from airflow.gcp.operators.pubsub import ( @@ -32,10 +31,10 @@ TEST_SUBSCRIPTION = 'test-subscription' TEST_MESSAGES = [ { - 'data': b64e(b'Hello, World!'), + 'data': b'Hello, World!', 'attributes': {'type': 'greeting'} }, - {'data': b64e(b'Knock, knock')}, + {'data': b'Knock, knock'}, {'attributes': {'foo': ''}}] TEST_POKE_INTERVAl = 0 @@ -44,38 +43,68 @@ class TestPubSubTopicCreateOperator(unittest.TestCase): @mock.patch('airflow.gcp.operators.pubsub.PubSubHook') def test_failifexists(self, mock_hook): - operator = PubSubTopicCreateOperator(task_id=TASK_ID, - project=TEST_PROJECT, - topic=TEST_TOPIC, - fail_if_exists=True) + operator = PubSubTopicCreateOperator( + task_id=TASK_ID, + project_id=TEST_PROJECT, + topic=TEST_TOPIC, + fail_if_exists=True + ) operator.execute(None) mock_hook.return_value.create_topic.assert_called_once_with( - TEST_PROJECT, TEST_TOPIC, fail_if_exists=True) + project_id=TEST_PROJECT, + topic=TEST_TOPIC, + fail_if_exists=True, + labels=None, + message_storage_policy=None, + kms_key_name=None, + retry=None, + timeout=None, + metadata=None, + ) @mock.patch('airflow.gcp.operators.pubsub.PubSubHook') def test_succeedifexists(self, mock_hook): - operator = PubSubTopicCreateOperator(task_id=TASK_ID, - project=TEST_PROJECT, - topic=TEST_TOPIC, - fail_if_exists=False) + operator = PubSubTopicCreateOperator( + task_id=TASK_ID, + project_id=TEST_PROJECT, + topic=TEST_TOPIC, + fail_if_exists=False + ) operator.execute(None) mock_hook.return_value.create_topic.assert_called_once_with( - TEST_PROJECT, TEST_TOPIC, fail_if_exists=False) + project_id=TEST_PROJECT, + topic=TEST_TOPIC, + fail_if_exists=False, + labels=None, + message_storage_policy=None, + kms_key_name=None, + retry=None, + timeout=None, + metadata=None + ) class TestPubSubTopicDeleteOperator(unittest.TestCase): @mock.patch('airflow.gcp.operators.pubsub.PubSubHook') def test_execute(self, mock_hook): - operator = PubSubTopicDeleteOperator(task_id=TASK_ID, - project=TEST_PROJECT, - topic=TEST_TOPIC) + operator = PubSubTopicDeleteOperator( + task_id=TASK_ID, + project_id=TEST_PROJECT, + topic=TEST_TOPIC + ) operator.execute(None) mock_hook.return_value.delete_topic.assert_called_once_with( - TEST_PROJECT, TEST_TOPIC, fail_if_not_exists=False) + project_id=TEST_PROJECT, + topic=TEST_TOPIC, + fail_if_not_exists=False, + retry=None, + timeout=None, + metadata=None + ) class TestPubSubSubscriptionCreateOperator(unittest.TestCase): @@ -83,40 +112,83 @@ class TestPubSubSubscriptionCreateOperator(unittest.TestCase): @mock.patch('airflow.gcp.operators.pubsub.PubSubHook') def test_execute(self, mock_hook): operator = PubSubSubscriptionCreateOperator( - task_id=TASK_ID, topic_project=TEST_PROJECT, topic=TEST_TOPIC, - subscription=TEST_SUBSCRIPTION) - mock_hook.return_value.create_subscription.return_value = ( - TEST_SUBSCRIPTION) + task_id=TASK_ID, + project_id=TEST_PROJECT, + topic=TEST_TOPIC, + subscription=TEST_SUBSCRIPTION + ) + mock_hook.return_value.create_subscription.return_value = TEST_SUBSCRIPTION response = operator.execute(None) mock_hook.return_value.create_subscription.assert_called_once_with( - TEST_PROJECT, TEST_TOPIC, TEST_SUBSCRIPTION, None, - 10, False) + project_id=TEST_PROJECT, + topic=TEST_TOPIC, + subscription=TEST_SUBSCRIPTION, + subscription_project_id=None, + ack_deadline_secs=10, + fail_if_exists=False, + push_config=None, + retain_acked_messages=None, + message_retention_duration=None, + labels=None, + retry=None, + timeout=None, + metadata=None, + ) self.assertEqual(response, TEST_SUBSCRIPTION) @mock.patch('airflow.gcp.operators.pubsub.PubSubHook') def test_execute_different_project_ids(self, mock_hook): another_project = 'another-project' operator = PubSubSubscriptionCreateOperator( - task_id=TASK_ID, topic_project=TEST_PROJECT, topic=TEST_TOPIC, + project_id=TEST_PROJECT, + topic=TEST_TOPIC, subscription=TEST_SUBSCRIPTION, - subscription_project=another_project) - mock_hook.return_value.create_subscription.return_value = ( - TEST_SUBSCRIPTION) + subscription_project_id=another_project, + task_id=TASK_ID + ) + mock_hook.return_value.create_subscription.return_value = TEST_SUBSCRIPTION response = operator.execute(None) mock_hook.return_value.create_subscription.assert_called_once_with( - TEST_PROJECT, TEST_TOPIC, TEST_SUBSCRIPTION, another_project, - 10, False) + project_id=TEST_PROJECT, + topic=TEST_TOPIC, + subscription=TEST_SUBSCRIPTION, + subscription_project_id=another_project, + ack_deadline_secs=10, + fail_if_exists=False, + push_config=None, + retain_acked_messages=None, + message_retention_duration=None, + labels=None, + retry=None, + timeout=None, + metadata=None + ) self.assertEqual(response, TEST_SUBSCRIPTION) @mock.patch('airflow.gcp.operators.pubsub.PubSubHook') def test_execute_no_subscription(self, mock_hook): operator = PubSubSubscriptionCreateOperator( - task_id=TASK_ID, topic_project=TEST_PROJECT, topic=TEST_TOPIC) - mock_hook.return_value.create_subscription.return_value = ( - TEST_SUBSCRIPTION) + task_id=TASK_ID, + project_id=TEST_PROJECT, + topic=TEST_TOPIC + ) + mock_hook.return_value.create_subscription.return_value = TEST_SUBSCRIPTION response = operator.execute(None) mock_hook.return_value.create_subscription.assert_called_once_with( - TEST_PROJECT, TEST_TOPIC, None, None, 10, False) + project_id=TEST_PROJECT, + topic=TEST_TOPIC, + subscription=None, + subscription_project_id=None, + ack_deadline_secs=10, + fail_if_exists=False, + push_config=None, + retain_acked_messages=None, + message_retention_duration=None, + labels=None, + retry=None, + timeout=None, + metadata=None, + ) self.assertEqual(response, TEST_SUBSCRIPTION) @@ -125,12 +197,20 @@ class TestPubSubSubscriptionDeleteOperator(unittest.TestCase): @mock.patch('airflow.gcp.operators.pubsub.PubSubHook') def test_execute(self, mock_hook): operator = PubSubSubscriptionDeleteOperator( - task_id=TASK_ID, project=TEST_PROJECT, - subscription=TEST_SUBSCRIPTION) + task_id=TASK_ID, + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION + ) operator.execute(None) mock_hook.return_value.delete_subscription.assert_called_once_with( - TEST_PROJECT, TEST_SUBSCRIPTION, fail_if_not_exists=False) + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + fail_if_not_exists=False, + retry=None, + timeout=None, + metadata=None + ) class TestPubSubPublishOperator(unittest.TestCase): @@ -138,10 +218,11 @@ class TestPubSubPublishOperator(unittest.TestCase): @mock.patch('airflow.gcp.operators.pubsub.PubSubHook') def test_publish(self, mock_hook): operator = PubSubPublishOperator(task_id=TASK_ID, - project=TEST_PROJECT, + project_id=TEST_PROJECT, topic=TEST_TOPIC, messages=TEST_MESSAGES) operator.execute(None) mock_hook.return_value.publish.assert_called_once_with( - TEST_PROJECT, TEST_TOPIC, TEST_MESSAGES) + project_id=TEST_PROJECT, topic=TEST_TOPIC, messages=TEST_MESSAGES + ) diff --git a/tests/gcp/sensors/test_pubsub.py b/tests/gcp/sensors/test_pubsub.py index 701e32bb14843..12261e380afe2 100644 --- a/tests/gcp/sensors/test_pubsub.py +++ b/tests/gcp/sensors/test_pubsub.py @@ -19,7 +19,8 @@ import unittest -from base64 import b64encode as b64e +from google.cloud.pubsub_v1.types import ReceivedMessage +from google.protobuf.json_format import ParseDict, MessageToDict from airflow.gcp.sensors.pubsub import PubSubPullSensor from airflow.exceptions import AirflowSensorTimeout @@ -27,69 +28,81 @@ TASK_ID = 'test-task-id' TEST_PROJECT = 'test-project' -TEST_TOPIC = 'test-topic' TEST_SUBSCRIPTION = 'test-subscription' -TEST_MESSAGES = [ - { - 'data': b64e(b'Hello, World!'), - 'attributes': {'type': 'greeting'} - }, - {'data': b64e(b'Knock, knock')}, - {'attributes': {'foo': ''}}] class TestPubSubPullSensor(unittest.TestCase): - def _generate_messages(self, count): - messages = [] - for i in range(1, count + 1): - messages.append({ - 'ackId': '%s' % i, - 'message': { - 'data': b64e('Message {}'.format(i).encode('utf8')), - 'attributes': {'type': 'generated message'} - } - }) - return messages + return [ + ParseDict( + { + "ack_id": "%s" % i, + "message": { + "data": 'Message {}'.format(i).encode('utf8'), + "attributes": {"type": "generated message"}, + }, + }, + ReceivedMessage(), + ) + for i in range(1, count + 1) + ] + + def _generate_dicts(self, count): + return [MessageToDict(m) for m in self._generate_messages(count)] @mock.patch('airflow.gcp.sensors.pubsub.PubSubHook') def test_poke_no_messages(self, mock_hook): - operator = PubSubPullSensor(task_id=TASK_ID, project=TEST_PROJECT, + operator = PubSubPullSensor(task_id=TASK_ID, project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION) mock_hook.return_value.pull.return_value = [] self.assertEqual([], operator.poke(None)) @mock.patch('airflow.gcp.sensors.pubsub.PubSubHook') def test_poke_with_ack_messages(self, mock_hook): - operator = PubSubPullSensor(task_id=TASK_ID, project=TEST_PROJECT, + operator = PubSubPullSensor(task_id=TASK_ID, project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, ack_messages=True) generated_messages = self._generate_messages(5) + generated_dicts = self._generate_dicts(5) mock_hook.return_value.pull.return_value = generated_messages - self.assertEqual(generated_messages, operator.poke(None)) + self.assertEqual(generated_dicts, operator.poke(None)) mock_hook.return_value.acknowledge.assert_called_once_with( - TEST_PROJECT, TEST_SUBSCRIPTION, ['1', '2', '3', '4', '5'] + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + ack_ids=['1', '2', '3', '4', '5'] ) @mock.patch('airflow.gcp.sensors.pubsub.PubSubHook') def test_execute(self, mock_hook): - operator = PubSubPullSensor(task_id=TASK_ID, project=TEST_PROJECT, - subscription=TEST_SUBSCRIPTION, - poke_interval=0) + operator = PubSubPullSensor( + task_id=TASK_ID, + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + poke_interval=0 + ) generated_messages = self._generate_messages(5) + generated_dicts = self._generate_dicts(5) mock_hook.return_value.pull.return_value = generated_messages response = operator.execute(None) mock_hook.return_value.pull.assert_called_once_with( - TEST_PROJECT, TEST_SUBSCRIPTION, 5, False) - self.assertEqual(response, generated_messages) + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + max_messages=5, + return_immediately=False + ) + self.assertEqual(generated_dicts, response) @mock.patch('airflow.gcp.sensors.pubsub.PubSubHook') def test_execute_timeout(self, mock_hook): - operator = PubSubPullSensor(task_id=TASK_ID, project=TEST_PROJECT, + operator = PubSubPullSensor(task_id=TASK_ID, project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, poke_interval=0, timeout=1) mock_hook.return_value.pull.return_value = [] with self.assertRaises(AirflowSensorTimeout): operator.execute(None) mock_hook.return_value.pull.assert_called_once_with( - TEST_PROJECT, TEST_SUBSCRIPTION, 5, False) + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + max_messages=5, + return_immediately=False + )