diff --git a/providers/google/docs/operators/cloud/managed_kafka.rst b/providers/google/docs/operators/cloud/managed_kafka.rst index 0016076f183f7..a81f81592eeda 100644 --- a/providers/google/docs/operators/cloud/managed_kafka.rst +++ b/providers/google/docs/operators/cloud/managed_kafka.rst @@ -69,6 +69,54 @@ To update cluster you can use :start-after: [START how_to_cloud_managed_kafka_update_cluster_operator] :end-before: [END how_to_cloud_managed_kafka_update_cluster_operator] +Interacting with Apache Kafka Topics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To create an Apache Kafka topic you can use +:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaCreateTopicOperator`. + +.. exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_topic.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_managed_kafka_create_topic_operator] + :end-before: [END how_to_cloud_managed_kafka_create_topic_operator] + +To delete topic you can use +:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaDeleteTopicOperator`. + +.. exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_topic.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_managed_kafka_delete_topic_operator] + :end-before: [END how_to_cloud_managed_kafka_delete_topic_operator] + +To get topic you can use +:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaGetTopicOperator`. + +.. exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_topic.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_managed_kafka_get_topic_operator] + :end-before: [END how_to_cloud_managed_kafka_get_topic_operator] + +To get a list of topics you can use +:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaListTopicsOperator`. + +.. exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_topic.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_managed_kafka_list_topic_operator] + :end-before: [END how_to_cloud_managed_kafka_list_topic_operator] + +To update topic you can use +:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaUpdateTopicOperator`. + +.. exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_topic.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_managed_kafka_update_topic_operator] + :end-before: [END how_to_cloud_managed_kafka_update_topic_operator] + Reference ^^^^^^^^^ diff --git a/providers/google/provider.yaml b/providers/google/provider.yaml index 5575cfc747f18..5a67f96bb33f8 100644 --- a/providers/google/provider.yaml +++ b/providers/google/provider.yaml @@ -1229,6 +1229,7 @@ extra-links: - airflow.providers.google.cloud.links.translate.TranslationGlossariesListLink - airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterLink - airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterListLink + - airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaTopicLink secrets-backends: diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/managed_kafka.py b/providers/google/src/airflow/providers/google/cloud/hooks/managed_kafka.py index 48768666f8fe0..aec8d92f997a4 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/managed_kafka.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/managed_kafka.py @@ -27,12 +27,12 @@ from airflow.providers.google.common.consts import CLIENT_INFO from airflow.providers.google.common.hooks.base_google import GoogleBaseHook from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault -from google.cloud.managedkafka_v1 import Cluster, ManagedKafkaClient, types +from google.cloud.managedkafka_v1 import Cluster, ManagedKafkaClient, Topic, types if TYPE_CHECKING: from google.api_core.operation import Operation from google.api_core.retry import Retry - from google.cloud.managedkafka_v1.services.managed_kafka.pagers import ListClustersPager + from google.cloud.managedkafka_v1.services.managed_kafka.pagers import ListClustersPager, ListTopicsPager from google.protobuf.field_mask_pb2 import FieldMask @@ -286,3 +286,197 @@ def delete_cluster( metadata=metadata, ) return operation + + @GoogleBaseHook.fallback_to_default_project_id + def create_topic( + self, + project_id: str, + location: str, + cluster_id: str, + topic_id: str, + topic: types.Topic | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> types.Topic: + """ + Create a new topic in a given project and location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster in which to create the topic. + :param topic_id: Required. The ID to use for the topic, which will become the final component of the + topic's name. + :param topic: Required. Configuration of the topic to create. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_managed_kafka_client() + parent = client.cluster_path(project_id, location, cluster_id) + + result = client.create_topic( + request={ + "parent": parent, + "topic_id": topic_id, + "topic": topic, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def list_topics( + self, + project_id: str, + location: str, + cluster_id: str, + page_size: int | None = None, + page_token: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> ListTopicsPager: + """ + List the topics in a given cluster. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster whose topics are to be listed. + :param page_size: Optional. The maximum number of topics to return. The service may return fewer than + this value. If unset or zero, all topics for the parent is returned. + :param page_token: Optional. A page token, received from a previous ``ListTopics`` call. Provide this + to retrieve the subsequent page. When paginating, all other parameters provided to ``ListTopics`` + must match the call that provided the page token. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_managed_kafka_client() + parent = client.cluster_path(project_id, location, cluster_id) + + result = client.list_topics( + request={ + "parent": parent, + "page_size": page_size, + "page_token": page_token, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def get_topic( + self, + project_id: str, + location: str, + cluster_id: str, + topic_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> types.Topic: + """ + Return the properties of a single topic. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster whose topic is to be returned. + :param topic_id: Required. The ID of the topic whose configuration to return. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_managed_kafka_client() + name = client.topic_path(project_id, location, cluster_id, topic_id) + + result = client.get_topic( + request={ + "name": name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def update_topic( + self, + project_id: str, + location: str, + cluster_id: str, + topic_id: str, + topic: types.Topic | dict, + update_mask: FieldMask | dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> types.Topic: + """ + Update the properties of a single topic. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster whose topic is to be updated. + :param topic_id: Required. The ID of the topic whose configuration to update. + :param topic: Required. The topic to update. Its ``name`` field must be populated. + :param update_mask: Required. Field mask is used to specify the fields to be overwritten in the Topic + resource by the update. The fields specified in the update_mask are relative to the resource, not + the full request. A field will be overwritten if it is in the mask. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_managed_kafka_client() + _topic = deepcopy(topic) if isinstance(topic, dict) else Topic.to_dict(topic) + _topic["name"] = client.topic_path(project_id, location, cluster_id, topic_id) + + result = client.update_topic( + request={ + "update_mask": update_mask, + "topic": _topic, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_topic( + self, + project_id: str, + location: str, + cluster_id: str, + topic_id: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> None: + """ + Delete a single topic. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster whose topic is to be deleted. + :param topic_id: Required. The ID of the topic to delete. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_managed_kafka_client() + name = client.topic_path(project_id, location, cluster_id, topic_id) + + client.delete_topic( + request={ + "name": name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) diff --git a/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py b/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py index 00c626b3814a8..0aafe2f202daa 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py +++ b/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py @@ -28,6 +28,9 @@ MANAGED_KAFKA_BASE_LINK + "/{location}/clusters/{cluster_id}?project={project_id}" ) MANAGED_KAFKA_CLUSTER_LIST_LINK = MANAGED_KAFKA_BASE_LINK + "/clusters?project={project_id}" +MANAGED_KAFKA_TOPIC_LINK = ( + MANAGED_KAFKA_BASE_LINK + "/{location}/clusters/{cluster_id}/topics/{topic_id}?project={project_id}" +) class ApacheKafkaClusterLink(BaseGoogleLink): @@ -73,3 +76,29 @@ def persist( "project_id": task_instance.project_id, }, ) + + +class ApacheKafkaTopicLink(BaseGoogleLink): + """Helper class for constructing Apache Kafka Topic link.""" + + name = "Apache Kafka Topic" + key = "topic_conf" + format_str = MANAGED_KAFKA_TOPIC_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + cluster_id: str, + topic_id: str, + ): + task_instance.xcom_push( + context=context, + key=ApacheKafkaTopicLink.key, + value={ + "location": task_instance.location, + "cluster_id": cluster_id, + "topic_id": topic_id, + "project_id": task_instance.project_id, + }, + ) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py b/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py index ebf03856216dd..2afb30fede904 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py @@ -28,6 +28,7 @@ from airflow.providers.google.cloud.links.managed_kafka import ( ApacheKafkaClusterLink, ApacheKafkaClusterListLink, + ApacheKafkaTopicLink, ) from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator from google.api_core.exceptions import AlreadyExists, NotFound @@ -449,3 +450,339 @@ def execute(self, context: Context): except NotFound as not_found_err: self.log.info("The Apache Kafka cluster ID %s does not exist.", self.cluster_id) raise AirflowException(not_found_err) + + +class ManagedKafkaCreateTopicOperator(ManagedKafkaBaseOperator): + """ + Create a new topic in a given project and location. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster in which to create the topic. + :param topic_id: Required. The ID to use for the topic, which will become the final component of the + topic's name. + :param topic: Required. Configuration of the topic to create. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = tuple( + {"cluster_id", "topic_id", "topic"} | set(ManagedKafkaBaseOperator.template_fields) + ) + operator_extra_links = (ApacheKafkaTopicLink(),) + + def __init__( + self, + cluster_id: str, + topic_id: str, + topic: types.Topic | dict, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_id = cluster_id + self.topic_id = topic_id + self.topic = topic + + def execute(self, context: Context): + self.log.info("Creating an Apache Kafka topic.") + ApacheKafkaTopicLink.persist( + context=context, + task_instance=self, + cluster_id=self.cluster_id, + topic_id=self.topic_id, + ) + try: + topic_obj = self.hook.create_topic( + project_id=self.project_id, + location=self.location, + cluster_id=self.cluster_id, + topic_id=self.topic_id, + topic=self.topic, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Apache Kafka topic for %s cluster was created.", self.cluster_id) + return types.Topic.to_dict(topic_obj) + except AlreadyExists: + self.log.info("Apache Kafka topic %s already exists.", self.topic_id) + topic_obj = self.hook.get_topic( + project_id=self.project_id, + location=self.location, + cluster_id=self.cluster_id, + topic_id=self.topic_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + return types.Topic.to_dict(topic_obj) + + +class ManagedKafkaListTopicsOperator(ManagedKafkaBaseOperator): + """ + List the topics in a given cluster. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster whose topics are to be listed. + :param page_size: Optional. The maximum number of topics to return. The service may return fewer than + this value. If unset or zero, all topics for the parent is returned. + :param page_token: Optional. A page token, received from a previous ``ListTopics`` call. Provide this + to retrieve the subsequent page. When paginating, all other parameters provided to ``ListTopics`` + must match the call that provided the page token. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = tuple({"cluster_id"} | set(ManagedKafkaBaseOperator.template_fields)) + operator_extra_links = (ApacheKafkaClusterLink(),) + + def __init__( + self, + cluster_id: str, + page_size: int | None = None, + page_token: str | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_id = cluster_id + self.page_size = page_size + self.page_token = page_token + + def execute(self, context: Context): + ApacheKafkaClusterLink.persist(context=context, task_instance=self, cluster_id=self.cluster_id) + self.log.info("Listing Topics for cluster %s.", self.cluster_id) + try: + topic_list_pager = self.hook.list_topics( + project_id=self.project_id, + location=self.location, + cluster_id=self.cluster_id, + page_size=self.page_size, + page_token=self.page_token, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.xcom_push( + context=context, + key="topic_page", + value=types.ListTopicsResponse.to_dict(topic_list_pager._response), + ) + except Exception as error: + raise AirflowException(error) + return [types.Topic.to_dict(topic) for topic in topic_list_pager] + + +class ManagedKafkaGetTopicOperator(ManagedKafkaBaseOperator): + """ + Return the properties of a single topic. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster whose topic is to be returned. + :param topic_id: Required. The ID of the topic whose configuration to return. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = tuple( + {"cluster_id", "topic_id"} | set(ManagedKafkaBaseOperator.template_fields) + ) + operator_extra_links = (ApacheKafkaTopicLink(),) + + def __init__( + self, + cluster_id: str, + topic_id: str, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_id = cluster_id + self.topic_id = topic_id + + def execute(self, context: Context): + ApacheKafkaTopicLink.persist( + context=context, + task_instance=self, + cluster_id=self.cluster_id, + topic_id=self.topic_id, + ) + self.log.info("Getting Topic: %s", self.topic_id) + try: + topic = self.hook.get_topic( + project_id=self.project_id, + location=self.location, + cluster_id=self.cluster_id, + topic_id=self.topic_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("The topic %s from cluster %s was retrieved.", self.topic_id, self.cluster_id) + return types.Topic.to_dict(topic) + except NotFound as not_found_err: + self.log.info("The Topic %s does not exist.", self.topic_id) + raise AirflowException(not_found_err) + + +class ManagedKafkaUpdateTopicOperator(ManagedKafkaBaseOperator): + """ + Update the properties of a single topic. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster whose topic is to be updated. + :param topic_id: Required. The ID of the topic whose configuration to update. + :param topic: Required. The topic to update. Its ``name`` field must be populated. + :param update_mask: Required. Field mask is used to specify the fields to be overwritten in the Topic + resource by the update. The fields specified in the update_mask are relative to the resource, not + the full request. A field will be overwritten if it is in the mask. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = tuple( + {"cluster_id", "topic_id", "topic", "update_mask"} | set(ManagedKafkaBaseOperator.template_fields) + ) + operator_extra_links = (ApacheKafkaTopicLink(),) + + def __init__( + self, + cluster_id: str, + topic_id: str, + topic: types.Topic | dict, + update_mask: FieldMask | dict, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_id = cluster_id + self.topic_id = topic_id + self.topic = topic + self.update_mask = update_mask + + def execute(self, context: Context): + ApacheKafkaTopicLink.persist( + context=context, + task_instance=self, + cluster_id=self.cluster_id, + topic_id=self.topic_id, + ) + self.log.info("Updating an Apache Kafka topic.") + try: + topic_obj = self.hook.update_topic( + project_id=self.project_id, + location=self.location, + cluster_id=self.cluster_id, + topic_id=self.topic_id, + topic=self.topic, + update_mask=self.update_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Apache Kafka topic %s was updated.", self.topic_id) + return types.Topic.to_dict(topic_obj) + except NotFound as not_found_err: + self.log.info("The Topic %s does not exist.", self.topic_id) + raise AirflowException(not_found_err) + except Exception as error: + raise AirflowException(error) + + +class ManagedKafkaDeleteTopicOperator(ManagedKafkaBaseOperator): + """ + Delete a single topic. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param cluster_id: Required. The ID of the cluster whose topic is to be deleted. + :param topic_id: Required. The ID of the topic to delete. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = tuple( + {"cluster_id", "topic_id"} | set(ManagedKafkaBaseOperator.template_fields) + ) + + def __init__( + self, + cluster_id: str, + topic_id: str, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_id = cluster_id + self.topic_id = topic_id + + def execute(self, context: Context): + try: + self.log.info("Deleting Apache Kafka topic: %s", self.topic_id) + self.hook.delete_topic( + project_id=self.project_id, + location=self.location, + cluster_id=self.cluster_id, + topic_id=self.topic_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.log.info("Apache Kafka topic was deleted.") + except NotFound as not_found_err: + self.log.info("The Apache Kafka topic ID %s does not exist.", self.topic_id) + raise AirflowException(not_found_err) diff --git a/providers/google/src/airflow/providers/google/get_provider_info.py b/providers/google/src/airflow/providers/google/get_provider_info.py index 64e191a3280de..45316c69a39e5 100644 --- a/providers/google/src/airflow/providers/google/get_provider_info.py +++ b/providers/google/src/airflow/providers/google/get_provider_info.py @@ -1568,6 +1568,7 @@ def get_provider_info(): "airflow.providers.google.cloud.links.translate.TranslationGlossariesListLink", "airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterLink", "airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterListLink", + "airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaTopicLink", ], "secrets-backends": [ "airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend" diff --git a/providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_topic.py b/providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_topic.py new file mode 100644 index 0000000000000..719891600b64c --- /dev/null +++ b/providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_topic.py @@ -0,0 +1,172 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +""" +Example Airflow DAG for Google Cloud Managed Service for Apache Kafka testing Topic operations. +""" + +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.google.cloud.operators.managed_kafka import ( + ManagedKafkaCreateClusterOperator, + ManagedKafkaCreateTopicOperator, + ManagedKafkaDeleteClusterOperator, + ManagedKafkaDeleteTopicOperator, + ManagedKafkaGetTopicOperator, + ManagedKafkaListTopicsOperator, + ManagedKafkaUpdateTopicOperator, +) +from airflow.utils.trigger_rule import TriggerRule + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") +DAG_ID = "managed_kafka_topic_operations" +LOCATION = "us-central1" + +CLUSTER_ID = f"cluster_{DAG_ID}_{ENV_ID}".replace("_", "-") +CLUSTER_CONF = { + "gcp_config": { + "access_config": { + "network_configs": [ + {"subnet": f"projects/{PROJECT_ID}/regions/{LOCATION}/subnetworks/default"}, + ], + }, + }, + "capacity_config": { + "vcpu_count": 3, + "memory_bytes": 3221225472, + }, +} +TOPIC_ID = f"topic_{DAG_ID}_{ENV_ID}".replace("_", "-") +TOPIC_CONF = { + "partition_count": 3, + "replication_factor": 3, +} +TOPIC_TO_UPDATE = { + "partition_count": 30, + "replication_factor": 3, +} +TOPIC_UPDATE_MASK: dict = {"paths": ["partition_count"]} + + +with DAG( + DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, + tags=["example", "managed_kafka", "topic"], +) as dag: + create_cluster = ManagedKafkaCreateClusterOperator( + task_id="create_cluster", + project_id=PROJECT_ID, + location=LOCATION, + cluster=CLUSTER_CONF, + cluster_id=CLUSTER_ID, + ) + + # [START how_to_cloud_managed_kafka_create_topic_operator] + create_topic = ManagedKafkaCreateTopicOperator( + task_id="create_topic", + project_id=PROJECT_ID, + location=LOCATION, + cluster_id=CLUSTER_ID, + topic_id=TOPIC_ID, + topic=TOPIC_CONF, + ) + # [END how_to_cloud_managed_kafka_create_topic_operator] + + # [START how_to_cloud_managed_kafka_update_topic_operator] + update_topic = ManagedKafkaUpdateTopicOperator( + task_id="update_topic", + project_id=PROJECT_ID, + location=LOCATION, + cluster_id=CLUSTER_ID, + topic_id=TOPIC_ID, + topic=TOPIC_TO_UPDATE, + update_mask=TOPIC_UPDATE_MASK, + ) + # [END how_to_cloud_managed_kafka_update_topic_operator] + + # [START how_to_cloud_managed_kafka_get_topic_operator] + get_topic = ManagedKafkaGetTopicOperator( + task_id="get_topic", + project_id=PROJECT_ID, + location=LOCATION, + cluster_id=CLUSTER_ID, + topic_id=TOPIC_ID, + ) + # [END how_to_cloud_managed_kafka_get_topic_operator] + + # [START how_to_cloud_managed_kafka_delete_topic_operator] + delete_topic = ManagedKafkaDeleteTopicOperator( + task_id="delete_topic", + project_id=PROJECT_ID, + location=LOCATION, + cluster_id=CLUSTER_ID, + topic_id=TOPIC_ID, + trigger_rule=TriggerRule.ALL_DONE, + ) + # [END how_to_cloud_managed_kafka_delete_topic_operator] + + delete_cluster = ManagedKafkaDeleteClusterOperator( + task_id="delete_cluster", + project_id=PROJECT_ID, + location=LOCATION, + cluster_id=CLUSTER_ID, + trigger_rule=TriggerRule.ALL_DONE, + ) + + # [START how_to_cloud_managed_kafka_list_topic_operator] + list_topics = ManagedKafkaListTopicsOperator( + task_id="list_topics", + project_id=PROJECT_ID, + location=LOCATION, + cluster_id=CLUSTER_ID, + ) + # [END how_to_cloud_managed_kafka_list_topic_operator] + + ( + # TEST SETUP + create_cluster + # TEST BODY + >> create_topic + >> update_topic + >> get_topic + >> list_topics + >> delete_topic + # TEST TEARDOWN + >> delete_cluster + ) + + # ### Everything below this line is not part of example ### + # ### Just for system tests purpose ### + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/providers/google/tests/unit/google/cloud/hooks/test_managed_kafka.py b/providers/google/tests/unit/google/cloud/hooks/test_managed_kafka.py index 16cb0d35cb9f1..7261f079555cb 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_managed_kafka.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_managed_kafka.py @@ -55,6 +55,17 @@ }, } +TEST_TOPIC_ID: str = "test-topic-id" +TEST_TOPIC: dict = { + "partition_count": 1634, + "replication_factor": 1912, +} +TEST_TOPIC_UPDATE_MASK: dict = {"paths": ["partition_count"]} +TEST_UPDATED_TOPIC: dict = { + "partition_count": 2000, + "replication_factor": 1912, +} + BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}" MANAGED_KAFKA_STRING = "airflow.providers.google.cloud.hooks.managed_kafka.{}" @@ -174,6 +185,122 @@ def test_list_clusters(self, mock_client) -> None: ) mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_LOCATION) + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_create_topic(self, mock_client) -> None: + self.hook.create_topic( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + topic_id=TEST_TOPIC_ID, + topic=TEST_TOPIC, + ) + mock_client.assert_called_once() + mock_client.return_value.create_topic.assert_called_once_with( + request=dict( + parent=mock_client.return_value.cluster_path.return_value, + topic_id=TEST_TOPIC_ID, + topic=TEST_TOPIC, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.cluster_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID + ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_delete_topic(self, mock_client) -> None: + self.hook.delete_topic( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + topic_id=TEST_TOPIC_ID, + ) + mock_client.assert_called_once() + mock_client.return_value.delete_topic.assert_called_once_with( + request=dict(name=mock_client.return_value.topic_path.return_value), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.topic_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID, TEST_TOPIC_ID + ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_get_topic(self, mock_client) -> None: + self.hook.get_topic( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + topic_id=TEST_TOPIC_ID, + ) + mock_client.assert_called_once() + mock_client.return_value.get_topic.assert_called_once_with( + request=dict( + name=mock_client.return_value.topic_path.return_value, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.topic_path.assert_called_once_with( + TEST_PROJECT_ID, + TEST_LOCATION, + TEST_CLUSTER_ID, + TEST_TOPIC_ID, + ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_update_topic(self, mock_client) -> None: + self.hook.update_topic( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + topic_id=TEST_TOPIC_ID, + topic=TEST_UPDATED_TOPIC, + update_mask=TEST_TOPIC_UPDATE_MASK, + ) + mock_client.assert_called_once() + mock_client.return_value.update_topic.assert_called_once_with( + request=dict( + update_mask=TEST_TOPIC_UPDATE_MASK, + topic={ + "name": mock_client.return_value.topic_path.return_value, + **TEST_UPDATED_TOPIC, + }, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.topic_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID, TEST_TOPIC_ID + ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_list_topics(self, mock_client) -> None: + self.hook.list_topics( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + ) + mock_client.assert_called_once() + mock_client.return_value.list_topics.assert_called_once_with( + request=dict( + parent=mock_client.return_value.cluster_path.return_value, + page_size=None, + page_token=None, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.cluster_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID + ) + class TestManagedKafkaWithoutDefaultProjectIdHook: def setup_method(self): @@ -289,3 +416,122 @@ def test_list_clusters(self, mock_client) -> None: timeout=None, ) mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_LOCATION) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_create_topic(self, mock_client) -> None: + self.hook.create_topic( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + topic_id=TEST_TOPIC_ID, + topic=TEST_TOPIC, + ) + mock_client.assert_called_once() + mock_client.return_value.create_topic.assert_called_once_with( + request=dict( + parent=mock_client.return_value.cluster_path.return_value, + topic_id=TEST_TOPIC_ID, + topic=TEST_TOPIC, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.cluster_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID + ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_delete_topic(self, mock_client) -> None: + self.hook.delete_topic( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + topic_id=TEST_TOPIC_ID, + ) + mock_client.assert_called_once() + mock_client.return_value.delete_topic.assert_called_once_with( + request=dict(name=mock_client.return_value.topic_path.return_value), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.topic_path.assert_called_once_with( + TEST_PROJECT_ID, + TEST_LOCATION, + TEST_CLUSTER_ID, + TEST_TOPIC_ID, + ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_get_topic(self, mock_client) -> None: + self.hook.get_topic( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + topic_id=TEST_TOPIC_ID, + ) + mock_client.assert_called_once() + mock_client.return_value.get_topic.assert_called_once_with( + request=dict( + name=mock_client.return_value.topic_path.return_value, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.topic_path.assert_called_once_with( + TEST_PROJECT_ID, + TEST_LOCATION, + TEST_CLUSTER_ID, + TEST_TOPIC_ID, + ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_update_topic(self, mock_client) -> None: + self.hook.update_topic( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + topic_id=TEST_TOPIC_ID, + topic=TEST_UPDATED_TOPIC, + update_mask=TEST_TOPIC_UPDATE_MASK, + ) + mock_client.assert_called_once() + mock_client.return_value.update_topic.assert_called_once_with( + request=dict( + update_mask=TEST_TOPIC_UPDATE_MASK, + topic={ + "name": mock_client.return_value.topic_path.return_value, + **TEST_UPDATED_TOPIC, + }, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.topic_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID, TEST_TOPIC_ID + ) + + @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client")) + def test_list_topics(self, mock_client) -> None: + self.hook.list_topics( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + ) + mock_client.assert_called_once() + mock_client.return_value.list_topics.assert_called_once_with( + request=dict( + parent=mock_client.return_value.cluster_path.return_value, + page_size=None, + page_token=None, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.cluster_path.assert_called_once_with( + TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID + ) diff --git a/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py b/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py index add83f74d56bf..7bf671c68e669 100644 --- a/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py +++ b/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py @@ -22,11 +22,13 @@ from airflow.providers.google.cloud.links.managed_kafka import ( ApacheKafkaClusterLink, ApacheKafkaClusterListLink, + ApacheKafkaTopicLink, ) TEST_LOCATION = "test-location" TEST_CLUSTER_ID = "test-cluster-id" TEST_PROJECT_ID = "test-project-id" +TEST_TOPIC_ID = "test-topic-id" EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_NAME = "Apache Kafka Cluster" EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_KEY = "cluster_conf" EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_FORMAT_STR = ( @@ -35,6 +37,11 @@ EXPECTED_MANAGED_KAFKA_CLUSTER_LIST_LINK_NAME = "Apache Kafka Cluster List" EXPECTED_MANAGED_KAFKA_CLUSTER_LIST_LINK_KEY = "cluster_list_conf" EXPECTED_MANAGED_KAFKA_CLUSTER_LIST_LINK_FORMAT_STR = "/managedkafka/clusters?project={project_id}" +EXPECTED_MANAGED_KAFKA_TOPIC_LINK_NAME = "Apache Kafka Topic" +EXPECTED_MANAGED_KAFKA_TOPIC_LINK_KEY = "topic_conf" +EXPECTED_MANAGED_KAFKA_TOPIC_LINK_FORMAT_STR = ( + "/managedkafka/{location}/clusters/{cluster_id}/topics/{topic_id}?project={project_id}" +) class TestApacheKafkaClusterLink: @@ -87,3 +94,34 @@ def test_persist(self): "project_id": TEST_PROJECT_ID, }, ) + + +class TestApacheKafkaTopicLink: + def test_class_attributes(self): + assert ApacheKafkaTopicLink.key == EXPECTED_MANAGED_KAFKA_TOPIC_LINK_KEY + assert ApacheKafkaTopicLink.name == EXPECTED_MANAGED_KAFKA_TOPIC_LINK_NAME + assert ApacheKafkaTopicLink.format_str == EXPECTED_MANAGED_KAFKA_TOPIC_LINK_FORMAT_STR + + def test_persist(self): + mock_context, mock_task_instance = ( + mock.MagicMock(), + mock.MagicMock(location=TEST_LOCATION, project_id=TEST_PROJECT_ID), + ) + + ApacheKafkaTopicLink.persist( + context=mock_context, + task_instance=mock_task_instance, + cluster_id=TEST_CLUSTER_ID, + topic_id=TEST_TOPIC_ID, + ) + + mock_task_instance.xcom_push.assert_called_once_with( + context=mock_context, + key=EXPECTED_MANAGED_KAFKA_TOPIC_LINK_KEY, + value={ + "location": TEST_LOCATION, + "cluster_id": TEST_CLUSTER_ID, + "topic_id": TEST_TOPIC_ID, + "project_id": TEST_PROJECT_ID, + }, + ) diff --git a/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py b/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py index 4b5bc5c71257d..e9407cc0a50ca 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py +++ b/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py @@ -22,10 +22,15 @@ from airflow.providers.google.cloud.operators.managed_kafka import ( ManagedKafkaCreateClusterOperator, + ManagedKafkaCreateTopicOperator, ManagedKafkaDeleteClusterOperator, + ManagedKafkaDeleteTopicOperator, ManagedKafkaGetClusterOperator, + ManagedKafkaGetTopicOperator, ManagedKafkaListClustersOperator, + ManagedKafkaListTopicsOperator, ManagedKafkaUpdateClusterOperator, + ManagedKafkaUpdateTopicOperator, ) MANAGED_KAFKA_PATH = "airflow.providers.google.cloud.operators.managed_kafka.{}" @@ -64,6 +69,17 @@ }, } +TEST_TOPIC_ID: str = "test-topic-id" +TEST_TOPIC: dict = { + "partition_count": 1634, + "replication_factor": 1912, +} +TEST_TOPIC_UPDATE_MASK: dict = {"paths": ["partition_count"]} +TEST_UPDATED_TOPIC: dict = { + "partition_count": 2000, + "replication_factor": 1912, +} + class TestManagedKafkaCreateClusterOperator: @mock.patch(MANAGED_KAFKA_PATH.format("types.Cluster.to_dict")) @@ -221,3 +237,159 @@ def test_execute(self, mock_hook): timeout=TIMEOUT, metadata=METADATA, ) + + +class TestManagedKafkaCreateTopicOperator: + @mock.patch(MANAGED_KAFKA_PATH.format("types.Topic.to_dict")) + @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook")) + def test_execute(self, mock_hook, to_dict_mock): + op = ManagedKafkaCreateTopicOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + location=GCP_LOCATION, + project_id=GCP_PROJECT, + cluster_id=TEST_CLUSTER_ID, + topic_id=TEST_TOPIC_ID, + topic=TEST_TOPIC, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={"ti": mock.MagicMock()}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.create_topic.assert_called_once_with( + location=GCP_LOCATION, + project_id=GCP_PROJECT, + cluster_id=TEST_CLUSTER_ID, + topic_id=TEST_TOPIC_ID, + topic=TEST_TOPIC, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestManagedKafkaListTopicsOperator: + @mock.patch(MANAGED_KAFKA_PATH.format("types.ListTopicsResponse.to_dict")) + @mock.patch(MANAGED_KAFKA_PATH.format("types.Topic.to_dict")) + @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook")) + def test_execute(self, mock_hook, to_cluster_dict_mock, to_clusters_dict_mock): + page_token = "page_token" + page_size = 42 + + op = ManagedKafkaListTopicsOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + location=GCP_LOCATION, + project_id=GCP_PROJECT, + cluster_id=TEST_CLUSTER_ID, + page_size=page_size, + page_token=page_token, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={"ti": mock.MagicMock()}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.list_topics.assert_called_once_with( + location=GCP_LOCATION, + project_id=GCP_PROJECT, + cluster_id=TEST_CLUSTER_ID, + page_size=page_size, + page_token=page_token, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestManagedKafkaGetTopicOperator: + @mock.patch(MANAGED_KAFKA_PATH.format("types.Topic.to_dict")) + @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook")) + def test_execute(self, mock_hook, to_dict_mock): + op = ManagedKafkaGetTopicOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + location=GCP_LOCATION, + project_id=GCP_PROJECT, + cluster_id=TEST_CLUSTER_ID, + topic_id=TEST_TOPIC_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={"ti": mock.MagicMock()}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.get_topic.assert_called_once_with( + location=GCP_LOCATION, + project_id=GCP_PROJECT, + cluster_id=TEST_CLUSTER_ID, + topic_id=TEST_TOPIC_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestManagedKafkaUpdateTopicOperator: + @mock.patch(MANAGED_KAFKA_PATH.format("types.Topic.to_dict")) + @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook")) + def test_execute(self, mock_hook, to_dict_mock): + op = ManagedKafkaUpdateTopicOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + cluster_id=TEST_CLUSTER_ID, + topic_id=TEST_TOPIC_ID, + topic=TEST_UPDATED_TOPIC, + update_mask=TEST_TOPIC_UPDATE_MASK, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={"ti": mock.MagicMock()}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.update_topic.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + cluster_id=TEST_CLUSTER_ID, + topic_id=TEST_TOPIC_ID, + topic=TEST_UPDATED_TOPIC, + update_mask=TEST_TOPIC_UPDATE_MASK, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + +class TestManagedKafkaDeleteTopicOperator: + @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook")) + def test_execute(self, mock_hook): + op = ManagedKafkaDeleteTopicOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + location=GCP_LOCATION, + project_id=GCP_PROJECT, + cluster_id=TEST_CLUSTER_ID, + topic_id=TEST_TOPIC_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.delete_topic.assert_called_once_with( + location=GCP_LOCATION, + project_id=GCP_PROJECT, + cluster_id=TEST_CLUSTER_ID, + topic_id=TEST_TOPIC_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + )