diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index eeb1978738d28..e9574a359266f 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1759,6 +1759,7 @@ traceback tracebacks tracemalloc TrainingPipeline +TransferOperation TranslationServiceClient travis triage diff --git a/providers/src/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py b/providers/src/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py index f62679a63fa53..b9d19e81e174f 100644 --- a/providers/src/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +++ b/providers/src/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py @@ -36,6 +36,7 @@ from datetime import timedelta from typing import TYPE_CHECKING, Any +from google.api_core import protobuf_helpers from google.cloud.storage_transfer_v1 import ( ListTransferJobsRequest, StorageTransferServiceAsyncClient, @@ -57,6 +58,7 @@ from google.cloud.storage_transfer_v1.services.storage_transfer_service.pagers import ( ListTransferJobsAsyncPager, ) + from google.longrunning import operations_pb2 # type: ignore[attr-defined] from proto import Message log = logging.getLogger(__name__) @@ -596,3 +598,112 @@ async def get_latest_operation(self, job: TransferJob) -> Message | None: operation = TransferOperation.deserialize(response_operation.metadata.value) return operation return None + + async def list_transfer_operations( + self, + request_filter: dict | None = None, + **kwargs, + ) -> list[TransferOperation]: + """ + Get a transfer operation in Google Storage Transfer Service. + + :param request_filter: (Required) A request filter, as described in + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/list#body.QUERY_PARAMETERS.filter + With one additional improvement: + :return: transfer operation + + The ``project_id`` parameter is optional if you have a project ID + defined in the connection. 
See: :doc:`/connections/gcp` + """ + # To preserve backward compatibility + # TODO: remove one day + if request_filter is None: + if "filter" in kwargs: + request_filter = kwargs["filter"] + if not isinstance(request_filter, dict): + raise ValueError(f"The request_filter should be dict and is {type(request_filter)}") + warnings.warn( + "Use 'request_filter' instead of 'filter'", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + else: + raise TypeError( + "list_transfer_operations missing 1 required positional argument: 'request_filter'" + ) + + conn = await self.get_conn() + + request_filter = await self._inject_project_id(request_filter, FILTER, FILTER_PROJECT_ID) + + operations: list[operations_pb2.Operation] = [] + + response = await conn.list_operations( + request={ + "name": TRANSFER_OPERATIONS, + "filter": json.dumps(request_filter), + } + ) + + while response is not None: + operations.extend(response.operations) + response = ( + await conn.list_operations( + request={ + "name": TRANSFER_OPERATIONS, + "filter": json.dumps(request_filter), + "page_token": response.next_page_token, + } + ) + if response.next_page_token + else None + ) + + transfer_operations = [ + protobuf_helpers.from_any_pb(TransferOperation, op.metadata) for op in operations + ] + + return transfer_operations + + async def _inject_project_id(self, body: dict, param_name: str, target_key: str) -> dict: + body = deepcopy(body) + body[target_key] = body.get(target_key, self.project_id) + if not body.get(target_key): + raise AirflowException( + f"The project id must be passed either as `{target_key}` key in `{param_name}` " + f"parameter or as project_id extra in Google Cloud connection definition. Both are not set!" + ) + return body + + @staticmethod + async def operations_contain_expected_statuses( + operations: list[TransferOperation], expected_statuses: set[str] | str + ) -> bool: + """ + Check whether an operation exists with the expected status. + + :param operations: (Required) List of transfer operations to check. + :param expected_statuses: (Required) The expected status. See: + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferOperations#Status + :return: If there is an operation with the expected state in the + operation list, returns true, + :raises AirflowException: If it encounters operations with state FAILED + or ABORTED in the list. + """ + expected_statuses_set = ( + {expected_statuses} if isinstance(expected_statuses, str) else set(expected_statuses) + ) + if not operations: + return False + + current_statuses = {operation.status.name for operation in operations} + + if len(current_statuses - expected_statuses_set) != len(current_statuses): + return True + + if len(NEGATIVE_STATUSES - current_statuses) != len(NEGATIVE_STATUSES): + raise AirflowException( + f"An unexpected operation status was encountered. 
" + f"Expected: {', '.join(expected_statuses_set)}" + ) + return False diff --git a/providers/src/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py b/providers/src/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py index 808c93147d8c6..85fef62036361 100644 --- a/providers/src/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +++ b/providers/src/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py @@ -21,9 +21,10 @@ from collections.abc import Sequence from copy import deepcopy -from datetime import date, time -from typing import TYPE_CHECKING +from datetime import date, time, timedelta +from typing import TYPE_CHECKING, Any +from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import ( @@ -63,6 +64,9 @@ CloudStorageTransferListLink, ) from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator +from airflow.providers.google.cloud.triggers.cloud_storage_transfer_service import ( + CloudStorageTransferServiceCheckJobStatusTrigger, +) from airflow.providers.google.cloud.utils.helpers import normalize_directory_path from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID @@ -908,6 +912,7 @@ class CloudDataTransferServiceS3ToGCSOperator(GoogleCloudBaseOperator): :param aws_role_arn: Optional AWS role ARN for workload identity federation. This will override the `aws_conn_id` for authentication between GCP and AWS; see https://cloud.google.com/storage-transfer/docs/reference/rest/v1/TransferSpec#AwsS3Data + :param deferrable: Run operator in the deferrable mode. """ template_fields: Sequence[str] = ( @@ -942,6 +947,7 @@ def __init__( google_impersonation_chain: str | Sequence[str] | None = None, delete_job_after_completion: bool = False, aws_role_arn: str | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) @@ -961,6 +967,7 @@ def __init__( self.google_impersonation_chain = google_impersonation_chain self.delete_job_after_completion = delete_job_after_completion self.aws_role_arn = aws_role_arn + self.deferrable = deferrable self._validate_inputs() def _validate_inputs(self) -> None: @@ -979,9 +986,31 @@ def execute(self, context: Context) -> None: job = hook.create_transfer_job(body=body) if self.wait: - hook.wait_for_transfer_job(job, timeout=self.timeout) - if self.delete_job_after_completion: - hook.delete_transfer_job(job_name=job[NAME], project_id=self.project_id) + if not self.deferrable: + hook.wait_for_transfer_job(job, timeout=self.timeout) + if self.delete_job_after_completion: + hook.delete_transfer_job(job_name=job[NAME], project_id=self.project_id) + else: + self.defer( + timeout=timedelta(seconds=self.timeout or 60), + trigger=CloudStorageTransferServiceCheckJobStatusTrigger( + job_name=job[NAME], + project_id=job[PROJECT_ID], + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.google_impersonation_chain, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict[str, Any]) -> None: + """ + Act as a callback for when the trigger fires. + + This returns immediately. It relies on trigger to throw an exception, + otherwise it assumes execution was successful. 
+ """ + if event["status"] == "error": + raise AirflowException(event["message"]) def _create_body(self) -> dict: body = { @@ -1079,6 +1108,7 @@ class CloudDataTransferServiceGCSToGCSOperator(GoogleCloudBaseOperator): account from the list granting this role to the originating account (templated). :param delete_job_after_completion: If True, delete the job after complete. If set to True, 'wait' must be set to True. + :param deferrable: Run operator in the deferrable mode. """ # [START gcp_transfer_gcs_to_gcs_template_fields] @@ -1113,6 +1143,7 @@ def __init__( timeout: float | None = None, google_impersonation_chain: str | Sequence[str] | None = None, delete_job_after_completion: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) @@ -1130,6 +1161,7 @@ def __init__( self.timeout = timeout self.google_impersonation_chain = google_impersonation_chain self.delete_job_after_completion = delete_job_after_completion + self.deferrable = deferrable self._validate_inputs() def _validate_inputs(self) -> None: @@ -1149,9 +1181,31 @@ def execute(self, context: Context) -> None: job = hook.create_transfer_job(body=body) if self.wait: - hook.wait_for_transfer_job(job, timeout=self.timeout) - if self.delete_job_after_completion: - hook.delete_transfer_job(job_name=job[NAME], project_id=self.project_id) + if not self.deferrable: + hook.wait_for_transfer_job(job, timeout=self.timeout) + if self.delete_job_after_completion: + hook.delete_transfer_job(job_name=job[NAME], project_id=self.project_id) + else: + self.defer( + timeout=timedelta(seconds=self.timeout or 60), + trigger=CloudStorageTransferServiceCheckJobStatusTrigger( + job_name=job[NAME], + project_id=job[PROJECT_ID], + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.google_impersonation_chain, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict[str, Any]) -> None: + """ + Act as a callback for when the trigger fires. + + This returns immediately. It relies on trigger to throw an exception, + otherwise it assumes execution was successful. 
+ """ + if event["status"] == "error": + raise AirflowException(event["message"]) def _create_body(self) -> dict: body = { diff --git a/providers/src/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py b/providers/src/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py index 575971c2e0c87..20d645caa393f 100644 --- a/providers/src/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +++ b/providers/src/airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py @@ -20,8 +20,10 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any +from airflow.configuration import conf +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import ( COUNTERS, METADATA, @@ -29,6 +31,9 @@ CloudDataTransferServiceHook, ) from airflow.providers.google.cloud.links.cloud_storage_transfer import CloudStorageTransferJobLink +from airflow.providers.google.cloud.triggers.cloud_storage_transfer_service import ( + CloudStorageTransferServiceCheckJobStatusTrigger, +) from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID from airflow.sensors.base import BaseSensorOperator @@ -60,6 +65,7 @@ class CloudDataTransferServiceJobStatusSensor(BaseSensorOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param deferrable: Run sensor in deferrable mode """ # [START gcp_transfer_job_sensor_template_fields] @@ -78,6 +84,7 @@ def __init__( project_id: str = PROVIDE_PROJECT_ID, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) @@ -88,6 +95,7 @@ def __init__( self.project_id = project_id self.gcp_cloud_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain + self.deferrable = deferrable def poke(self, context: Context) -> bool: hook = CloudDataTransferServiceHook( @@ -117,3 +125,33 @@ def poke(self, context: Context) -> bool: ) return check + + def execute(self, context: Context) -> None: + """Run on the worker and defer using the triggers if deferrable is set to True.""" + if not self.deferrable: + super().execute(context) + elif not self.poke(context=context): + self.defer( + timeout=self.execution_timeout, + trigger=CloudStorageTransferServiceCheckJobStatusTrigger( + job_name=self.job_name, + expected_statuses=self.expected_statuses, + project_id=self.project_id, + poke_interval=self.poke_interval, + gcp_conn_id=self.gcp_cloud_conn_id, + impersonation_chain=self.impersonation_chain, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict[str, Any]) -> None: + """ + Act as a callback for when the trigger fires. + + This returns immediately. It relies on trigger to throw an exception, + otherwise it assumes execution was successful. 
+ """ + if event["status"] == "error": + raise AirflowException(event["message"]) + + self.xcom_push(key="sensed_operations", value=event["operations"], context=context) diff --git a/providers/src/airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py b/providers/src/airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py index e95df81277902..9a28b960e6e83 100644 --- a/providers/src/airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +++ b/providers/src/airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py @@ -18,7 +18,7 @@ from __future__ import annotations import asyncio -from collections.abc import AsyncIterator, Iterable +from collections.abc import AsyncIterator, Iterable, Sequence from typing import Any from google.api_core.exceptions import GoogleAPIError @@ -27,6 +27,7 @@ from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import ( CloudDataTransferServiceAsyncHook, + GcpTransferOperationStatus, ) from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -132,3 +133,101 @@ def get_async_hook(self) -> CloudDataTransferServiceAsyncHook: project_id=self.project_id, gcp_conn_id=self.gcp_conn_id, ) + + +class CloudStorageTransferServiceCheckJobStatusTrigger(BaseTrigger): + """ + CloudStorageTransferServiceCheckJobStatusTrigger run on the trigger worker to check Cloud Storage Transfer job. + + :param job_name: The name of the transfer job + :param expected_statuses: The expected state of the operation. + See: + https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferOperations#Status + :param project_id: The ID of the project that owns the Transfer Job. + :param poke_interval: Polling period in seconds to check for the status + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + def __init__( + self, + job_name: str, + expected_statuses: set[str] | str | None = None, + project_id: str = PROVIDE_PROJECT_ID, + poke_interval: float = 10.0, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + ): + super().__init__() + self.job_name = job_name + self.expected_statuses = expected_statuses + self.project_id = project_id + self.poke_interval = poke_interval + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize CloudStorageTransferServiceCheckJobStatusTrigger arguments and classpath.""" + return ( + f"{self.__class__.__module__ }.{self.__class__.__qualname__}", + { + "job_name": self.job_name, + "expected_statuses": self.expected_statuses, + "project_id": self.project_id, + "poke_interval": self.poke_interval, + "gcp_conn_id": self.gcp_conn_id, + "impersonation_chain": self.impersonation_chain, + }, + ) + + def _get_async_hook(self) -> CloudDataTransferServiceAsyncHook: + return CloudDataTransferServiceAsyncHook( + project_id=self.project_id, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Check the status of the transfer job and yield a TriggerEvent.""" + hook = self._get_async_hook() + expected_statuses = ( + {GcpTransferOperationStatus.SUCCESS} if not self.expected_statuses else self.expected_statuses + ) + + try: + while True: + operations = await hook.list_transfer_operations( + request_filter={ + "project_id": self.project_id or hook.project_id, + "job_names": [self.job_name], + } + ) + check = await CloudDataTransferServiceAsyncHook.operations_contain_expected_statuses( + operations=operations, + expected_statuses=expected_statuses, + ) + if check: + yield TriggerEvent( + { + "status": "success", + "message": "Transfer operation completed successfully", + "operations": operations, + } + ) + return + + self.log.info( + "Sleeping for %s seconds.", + self.poke_interval, + ) + await asyncio.sleep(self.poke_interval) + except Exception as e: + self.log.exception("Exception occurred while checking for query completion") + yield TriggerEvent({"status": "error", "message": str(e)}) diff --git a/providers/tests/google/cloud/hooks/test_cloud_storage_transfer_service_async.py b/providers/tests/google/cloud/hooks/test_cloud_storage_transfer_service_async.py index 968c1a95efbdf..25bf06da8f607 100644 --- a/providers/tests/google/cloud/hooks/test_cloud_storage_transfer_service_async.py +++ b/providers/tests/google/cloud/hooks/test_cloud_storage_transfer_service_async.py @@ -23,8 +23,10 @@ import pytest +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import ( CloudDataTransferServiceAsyncHook, + GcpTransferOperationStatus, ) from providers.tests.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id @@ -119,3 +121,82 @@ async def test_get_last_operation_none(self, mock_deserialize, mock_conn, hook_a get_operation.assert_not_called() mock_deserialize.assert_not_called() assert operation == expected_operation + + @pytest.mark.asyncio + @mock.patch(f"{TRANSFER_HOOK_PATH}.CloudDataTransferServiceAsyncHook.get_conn") + @mock.patch("google.api_core.protobuf_helpers.from_any_pb") + async def test_list_transfer_operations(self, from_any_pb, mock_conn, hook_async): + expected_operations = [mock.MagicMock(), mock.MagicMock()] + 
from_any_pb.side_effect = expected_operations + + mock_conn.return_value.list_operations.side_effect = [ + mock.MagicMock(next_page_token="token", operations=[mock.MagicMock()]), + mock.MagicMock(next_page_token=None, operations=[mock.MagicMock()]), + ] + + actual_operations = await hook_async.list_transfer_operations( + request_filter={ + "project_id": TEST_PROJECT_ID, + }, + ) + assert actual_operations == expected_operations + assert mock_conn.return_value.list_operations.call_count == 2 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "statuses, expected_statuses", + [ + ([GcpTransferOperationStatus.ABORTED], (GcpTransferOperationStatus.IN_PROGRESS,)), + ( + [GcpTransferOperationStatus.SUCCESS, GcpTransferOperationStatus.ABORTED], + (GcpTransferOperationStatus.IN_PROGRESS,), + ), + ( + [GcpTransferOperationStatus.PAUSED, GcpTransferOperationStatus.ABORTED], + (GcpTransferOperationStatus.IN_PROGRESS,), + ), + ], + ) + async def test_operations_contain_expected_statuses_red_path(self, statuses, expected_statuses): + operations = [mock.MagicMock(**{"status.name": status}) for status in statuses] + + with pytest.raises( + AirflowException, + match=f"An unexpected operation status was encountered. Expected: {', '.join(expected_statuses)}", + ): + await CloudDataTransferServiceAsyncHook.operations_contain_expected_statuses( + operations, GcpTransferOperationStatus.IN_PROGRESS + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "statuses, expected_statuses", + [ + ([GcpTransferOperationStatus.ABORTED], GcpTransferOperationStatus.ABORTED), + ( + [GcpTransferOperationStatus.SUCCESS, GcpTransferOperationStatus.ABORTED], + GcpTransferOperationStatus.ABORTED, + ), + ( + [GcpTransferOperationStatus.PAUSED, GcpTransferOperationStatus.ABORTED], + GcpTransferOperationStatus.ABORTED, + ), + ([GcpTransferOperationStatus.ABORTED], (GcpTransferOperationStatus.ABORTED,)), + ( + [GcpTransferOperationStatus.SUCCESS, GcpTransferOperationStatus.ABORTED], + (GcpTransferOperationStatus.ABORTED,), + ), + ( + [GcpTransferOperationStatus.PAUSED, GcpTransferOperationStatus.ABORTED], + (GcpTransferOperationStatus.ABORTED,), + ), + ], + ) + async def test_operations_contain_expected_statuses_green_path(self, statuses, expected_statuses): + operations = [mock.MagicMock(**{"status.name": status}) for status in statuses] + + result = await CloudDataTransferServiceAsyncHook.operations_contain_expected_statuses( + operations, expected_statuses + ) + + assert result diff --git a/providers/tests/google/cloud/operators/test_cloud_storage_transfer_service.py b/providers/tests/google/cloud/operators/test_cloud_storage_transfer_service.py index 73c093a5c0d17..6ec153dca04a8 100644 --- a/providers/tests/google/cloud/operators/test_cloud_storage_transfer_service.py +++ b/providers/tests/google/cloud/operators/test_cloud_storage_transfer_service.py @@ -25,7 +25,7 @@ import time_machine from botocore.credentials import Credentials -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import ( ACCESS_KEY_ID, AWS_ACCESS_KEY, @@ -62,6 +62,9 @@ TransferJobPreprocessor, TransferJobValidator, ) +from airflow.providers.google.cloud.triggers.cloud_storage_transfer_service import ( + CloudStorageTransferServiceCheckJobStatusTrigger, +) from airflow.utils import timezone try: @@ -956,6 +959,65 @@ def test_execute_should_throw_ex_when_delete_job_without_wait(self, mock_aws_hoo 
delete_job_after_completion=True, ) + @mock.patch( + "airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook" + ) + @mock.patch("airflow.providers.google.cloud.operators.cloud_storage_transfer_service.AwsBaseHook") + def test_async_defer_successfully(self, mock_aws_hook, mock_transfer_hook): + mock_aws_hook.return_value.get_credentials.return_value = Credentials( + TEST_AWS_ACCESS_KEY_ID, TEST_AWS_ACCESS_SECRET, None + ) + + operator = CloudDataTransferServiceS3ToGCSOperator( + task_id=TASK_ID, + s3_bucket=AWS_BUCKET_NAME, + gcs_bucket=GCS_BUCKET_NAME, + project_id=GCP_PROJECT_ID, + description=DESCRIPTION, + schedule=SCHEDULE_DICT, + deferrable=True, + ) + with pytest.raises(TaskDeferred) as exc: + operator.execute({}) + assert isinstance(exc.value.trigger, CloudStorageTransferServiceCheckJobStatusTrigger) + + @mock.patch("airflow.providers.google.cloud.operators.cloud_storage_transfer_service.AwsBaseHook") + def test_async_execute_successfully(self, mock_aws_hook): + mock_aws_hook.return_value.get_credentials.return_value = Credentials( + TEST_AWS_ACCESS_KEY_ID, TEST_AWS_ACCESS_SECRET, None + ) + + operator = CloudDataTransferServiceS3ToGCSOperator( + task_id=TASK_ID, + s3_bucket=AWS_BUCKET_NAME, + gcs_bucket=GCS_BUCKET_NAME, + project_id=GCP_PROJECT_ID, + description=DESCRIPTION, + schedule=SCHEDULE_DICT, + deferrable=True, + ) + operator.execute_complete(context={}, event={"status": "success"}) + + @mock.patch("airflow.providers.google.cloud.operators.cloud_storage_transfer_service.AwsBaseHook") + def test_async_execute_error(self, mock_aws_hook): + mock_aws_hook.return_value.get_credentials.return_value = Credentials( + TEST_AWS_ACCESS_KEY_ID, TEST_AWS_ACCESS_SECRET, None + ) + + operator = CloudDataTransferServiceS3ToGCSOperator( + task_id=TASK_ID, + s3_bucket=AWS_BUCKET_NAME, + gcs_bucket=GCS_BUCKET_NAME, + project_id=GCP_PROJECT_ID, + description=DESCRIPTION, + schedule=SCHEDULE_DICT, + deferrable=True, + ) + with pytest.raises(AirflowException): + operator.execute_complete( + context={}, event={"status": "error", "message": "test failure message"} + ) + class TestGoogleCloudStorageToGoogleCloudStorageTransferOperator: def test_constructor(self): @@ -1073,12 +1135,56 @@ def test_execute_should_throw_ex_when_delete_job_without_wait(self, mock_transfe with pytest.raises( AirflowException, match="If 'delete_job_after_completion' is True, then 'wait' must also be True." 
): - CloudDataTransferServiceS3ToGCSOperator( + CloudDataTransferServiceGCSToGCSOperator( task_id=TASK_ID, - s3_bucket=AWS_BUCKET_NAME, - gcs_bucket=GCS_BUCKET_NAME, + source_bucket=GCS_BUCKET_NAME, + destination_bucket=GCS_BUCKET_NAME, description=DESCRIPTION, schedule=SCHEDULE_DICT, wait=False, delete_job_after_completion=True, ) + + @mock.patch( + "airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook" + ) + def test_async_defer_successfully(self, mock_transfer_hook): + operator = CloudDataTransferServiceGCSToGCSOperator( + task_id=TASK_ID, + source_bucket=GCS_BUCKET_NAME, + destination_bucket=GCS_BUCKET_NAME, + project_id=GCP_PROJECT_ID, + description=DESCRIPTION, + schedule=SCHEDULE_DICT, + deferrable=True, + ) + with pytest.raises(TaskDeferred) as exc: + operator.execute({}) + assert isinstance(exc.value.trigger, CloudStorageTransferServiceCheckJobStatusTrigger) + + def test_async_execute_successfully(self): + operator = CloudDataTransferServiceGCSToGCSOperator( + task_id=TASK_ID, + source_bucket=GCS_BUCKET_NAME, + destination_bucket=GCS_BUCKET_NAME, + project_id=GCP_PROJECT_ID, + description=DESCRIPTION, + schedule=SCHEDULE_DICT, + deferrable=True, + ) + operator.execute_complete(context={}, event={"status": "success"}) + + def test_async_execute_error(self): + operator = CloudDataTransferServiceGCSToGCSOperator( + task_id=TASK_ID, + source_bucket=GCS_BUCKET_NAME, + destination_bucket=GCS_BUCKET_NAME, + project_id=GCP_PROJECT_ID, + description=DESCRIPTION, + schedule=SCHEDULE_DICT, + deferrable=True, + ) + with pytest.raises(AirflowException): + operator.execute_complete( + context={}, event={"status": "error", "message": "test failure message"} + ) diff --git a/providers/tests/google/cloud/sensors/test_cloud_storage_transfer_service.py b/providers/tests/google/cloud/sensors/test_cloud_storage_transfer_service.py index 68517e17ea206..1bc464ee23b75 100644 --- a/providers/tests/google/cloud/sensors/test_cloud_storage_transfer_service.py +++ b/providers/tests/google/cloud/sensors/test_cloud_storage_transfer_service.py @@ -21,10 +21,14 @@ import pytest +from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import GcpTransferOperationStatus from airflow.providers.google.cloud.sensors.cloud_storage_transfer_service import ( CloudDataTransferServiceJobStatusSensor, ) +from airflow.providers.google.cloud.triggers.cloud_storage_transfer_service import ( + CloudStorageTransferServiceCheckJobStatusTrigger, +) TEST_NAME = "transferOperations/transferJobs-123-456" TEST_COUNTERS = { @@ -218,3 +222,71 @@ def test_wait_for_status_normalize_status(self, mock_tool, expected_status, rece mock_tool.operations_contain_expected_statuses.assert_called_once_with( operations=operations, expected_statuses=received_status ) + + @mock.patch( + "airflow.providers.google.cloud.sensors.cloud_storage_transfer_service.CloudDataTransferServiceHook" + ) + @mock.patch( + "airflow.providers.google.cloud.sensors.cloud_storage_transfer_service" + ".CloudDataTransferServiceJobStatusSensor.defer" + ) + def test_job_status_sensor_finish_before_deferred(self, mock_defer, mock_hook): + op = CloudDataTransferServiceJobStatusSensor( + task_id="task-id", + job_name=JOB_NAME, + project_id="project-id", + expected_statuses=GcpTransferOperationStatus.SUCCESS, + deferrable=True, + ) + + mock_hook.operations_contain_expected_statuses.return_value = True + context = {"ti": 
(mock.Mock(**{"xcom_push.return_value": None}))} + + op.execute(context) + assert not mock_defer.called + + @mock.patch( + "airflow.providers.google.cloud.sensors.cloud_storage_transfer_service.CloudDataTransferServiceHook" + ) + def test_execute_deferred(self, mock_hook): + op = CloudDataTransferServiceJobStatusSensor( + task_id="task-id", + job_name=JOB_NAME, + project_id="project-id", + expected_statuses=GcpTransferOperationStatus.SUCCESS, + deferrable=True, + ) + + mock_hook.operations_contain_expected_statuses.return_value = False + context = {"ti": (mock.Mock(**{"xcom_push.return_value": None}))} + + with pytest.raises(TaskDeferred) as exc: + op.execute(context) + assert isinstance(exc.value.trigger, CloudStorageTransferServiceCheckJobStatusTrigger) + + def test_execute_deferred_failure(self): + op = CloudDataTransferServiceJobStatusSensor( + task_id="task-id", + job_name=JOB_NAME, + project_id="project-id", + expected_statuses=GcpTransferOperationStatus.SUCCESS, + deferrable=True, + ) + + context = {"ti": (mock.Mock(**{"xcom_push.return_value": None}))} + + with pytest.raises(AirflowException): + op.execute_complete(context=context, event={"status": "error", "message": "test failure message"}) + + def test_execute_complete(self): + op = CloudDataTransferServiceJobStatusSensor( + task_id="task-id", + job_name=JOB_NAME, + project_id="project-id", + expected_statuses=GcpTransferOperationStatus.SUCCESS, + deferrable=True, + ) + + context = {"ti": (mock.Mock(**{"xcom_push.return_value": None}))} + + op.execute_complete(context=context, event={"status": "success", "operations": []}) diff --git a/providers/tests/google/cloud/triggers/test_cloud_storage_transfer_service.py b/providers/tests/google/cloud/triggers/test_cloud_storage_transfer_service.py index 072c6a5d7da9d..6bb9100b4236e 100644 --- a/providers/tests/google/cloud/triggers/test_cloud_storage_transfer_service.py +++ b/providers/tests/google/cloud/triggers/test_cloud_storage_transfer_service.py @@ -25,8 +25,10 @@ from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import ( CloudDataTransferServiceAsyncHook, + GcpTransferOperationStatus, ) from airflow.providers.google.cloud.triggers.cloud_storage_transfer_service import ( + CloudStorageTransferServiceCheckJobStatusTrigger, CloudStorageTransferServiceCreateJobsTrigger, ) from airflow.triggers.base import TriggerEvent @@ -47,6 +49,8 @@ ASYNC_HOOK_CLASS_PATH = ( "airflow.providers.google.cloud.hooks.cloud_storage_transfer_service.CloudDataTransferServiceAsyncHook" ) +EXPECTED_STATUSES = GcpTransferOperationStatus.SUCCESS +IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] @pytest.fixture(scope="session") @@ -300,3 +304,90 @@ async def test_run_get_latest_operation_google_api_call_error( actual_event = await generator.asend(None) assert actual_event == expected_event + + +class TestCloudStorageTransferServiceCheckJobStatusTrigger: + @pytest.fixture + def trigger(self): + return CloudStorageTransferServiceCheckJobStatusTrigger( + project_id=PROJECT_ID, + job_name=JOB_0, + expected_statuses=EXPECTED_STATUSES, + poke_interval=POLL_INTERVAL, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + def test_serialize(self, trigger): + class_path, serialized = trigger.serialize() + assert class_path == ( + "airflow.providers.google.cloud.triggers.cloud_storage_transfer_service" + ".CloudStorageTransferServiceCheckJobStatusTrigger" + ) + assert serialized == { + "project_id": PROJECT_ID, + 
"job_name": JOB_0, + "expected_statuses": EXPECTED_STATUSES, + "poke_interval": POLL_INTERVAL, + "gcp_conn_id": GCP_CONN_ID, + "impersonation_chain": IMPERSONATION_CHAIN, + } + + @pytest.mark.parametrize( + "attr, expected_value", + [ + ("gcp_conn_id", GCP_CONN_ID), + ("impersonation_chain", IMPERSONATION_CHAIN), + ], + ) + def test_get_async_hook(self, attr, expected_value, trigger): + hook = trigger._get_async_hook() + actual_value = hook._hook_kwargs.get(attr) + assert isinstance(hook, CloudDataTransferServiceAsyncHook) + assert hook._hook_kwargs is not None + assert actual_value == expected_value + + @pytest.mark.asyncio + @mock.patch(ASYNC_HOOK_CLASS_PATH + ".list_transfer_operations") + @mock.patch(ASYNC_HOOK_CLASS_PATH + ".operations_contain_expected_statuses") + async def test_run_returns_success_event( + self, + operations_contain_expected_statuses, + list_transfer_operations, + trigger, + ): + operations_contain_expected_statuses.side_effect = [ + False, + True, + ] + expected_event = TriggerEvent( + { + "status": "success", + "message": "Transfer operation completed successfully", + "operations": list_transfer_operations.return_value, + } + ) + + actual_event = await trigger.run().asend(None) + + assert actual_event == expected_event + assert operations_contain_expected_statuses.call_count == 2 + + @pytest.mark.asyncio + @mock.patch(ASYNC_HOOK_CLASS_PATH + ".list_transfer_operations") + async def test_run_returns_exception_event( + self, + list_transfer_operations, + trigger, + ): + list_transfer_operations.side_effect = Exception("Transfer operation failed") + expected_event = TriggerEvent( + { + "status": "error", + "message": "Transfer operation failed", + } + ) + + actual_event = await trigger.run().asend(None) + + assert actual_event == expected_event diff --git a/providers/tests/system/google/cloud/storage_transfer/example_cloud_storage_transfer_service_gcp.py b/providers/tests/system/google/cloud/storage_transfer/example_cloud_storage_transfer_service_gcp.py index 960a9c0679a9b..9f751fb219896 100644 --- a/providers/tests/system/google/cloud/storage_transfer/example_cloud_storage_transfer_service_gcp.py +++ b/providers/tests/system/google/cloud/storage_transfer/example_cloud_storage_transfer_service_gcp.py @@ -148,6 +148,14 @@ expected_statuses={GcpTransferOperationStatus.SUCCESS}, ) + wait_for_transfer_defered = CloudDataTransferServiceJobStatusSensor( + task_id="wait_for_transfer_defered", + job_name="{{task_instance.xcom_pull('create_transfer')['name']}}", + project_id=PROJECT_ID_TRANSFER, + expected_statuses={GcpTransferOperationStatus.SUCCESS}, + deferrable=True, + ) + # [START howto_operator_gcp_transfer_run_job] run_transfer = CloudDataTransferServiceRunJobOperator( task_id="run_transfer", @@ -187,7 +195,7 @@ [create_bucket_src, create_bucket_dst] >> upload_file >> create_transfer - >> wait_for_transfer + >> [wait_for_transfer, wait_for_transfer_defered] >> update_transfer >> run_transfer >> list_operations diff --git a/providers/tests/system/google/cloud/storage_transfer/example_cloud_storage_transfer_service_gcs_to_gcs.py b/providers/tests/system/google/cloud/storage_transfer/example_cloud_storage_transfer_service_gcs_to_gcs.py index 46289ea66ca23..1c1a0e3daa9e2 100644 --- a/providers/tests/system/google/cloud/storage_transfer/example_cloud_storage_transfer_service_gcs_to_gcs.py +++ b/providers/tests/system/google/cloud/storage_transfer/example_cloud_storage_transfer_service_gcs_to_gcs.py @@ -87,6 +87,18 @@ ) # [END howto_operator_transfer_gcs_to_gcs] + # [START 
howto_operator_transfer_gcs_to_gcs_defered] + transfer_gcs_to_gcs_defered = CloudDataTransferServiceGCSToGCSOperator( + task_id="transfer_gcs_to_gcs_defered", + source_bucket=BUCKET_NAME_SRC, + source_path=FILE_URI, + destination_bucket=BUCKET_NAME_DST, + destination_path=FILE_URI, + wait=True, + deferrable=True, + ) + # [END howto_operator_transfer_gcs_to_gcs_defered] + delete_bucket_dst = GCSDeleteBucketOperator( task_id="delete_bucket", bucket_name=BUCKET_NAME_DST, trigger_rule=TriggerRule.ALL_DONE ) @@ -102,6 +114,7 @@ [create_bucket_dst, create_bucket_src >> upload_file] # TEST BODY >> transfer_gcs_to_gcs + >> transfer_gcs_to_gcs_defered # TEST TEARDOWN >> [delete_bucket_src, delete_bucket_dst] )
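
For context, a minimal DAG sketch (Python) that exercises the deferrable paths added above, patterned on the example DAGs in this change; the DAG id, bucket names, and transfer job name are placeholders rather than values taken from this patch:

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
    GcpTransferOperationStatus,
)
from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import (
    CloudDataTransferServiceGCSToGCSOperator,
)
from airflow.providers.google.cloud.sensors.cloud_storage_transfer_service import (
    CloudDataTransferServiceJobStatusSensor,
)

with DAG(
    dag_id="storage_transfer_deferrable_sketch",  # placeholder DAG id
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
):
    # With wait=True and deferrable=True the operator creates the transfer job and
    # then defers to CloudStorageTransferServiceCheckJobStatusTrigger instead of
    # blocking a worker slot in wait_for_transfer_job().
    transfer_gcs_to_gcs = CloudDataTransferServiceGCSToGCSOperator(
        task_id="transfer_gcs_to_gcs",
        source_bucket="source-bucket",  # placeholder
        destination_bucket="destination-bucket",  # placeholder
        wait=True,
        deferrable=True,
    )

    # With deferrable=True the sensor pokes once on the worker and, if the job is not
    # yet in an expected state, hands polling over to the same trigger. In a real DAG
    # the job name usually comes from a create-job task's XCom, as in the example DAG
    # updated above; here it is a placeholder.
    wait_for_transfer = CloudDataTransferServiceJobStatusSensor(
        task_id="wait_for_transfer",
        job_name="transferJobs/123456789",  # placeholder
        expected_statuses={GcpTransferOperationStatus.SUCCESS},
        deferrable=True,
    )

    transfer_gcs_to_gcs >> wait_for_transfer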