Add deferrable mode to google cloud storage transfer sensor and operators (apache#45754)

* Add deferrable mode to cloud storage transfer sensor and operators

* Fix spell check

* Add system tests
tnk-ysk authored and niklasr22 committed Feb 8, 2025
1 parent 14604fc commit 132c0c2
Showing 11 changed files with 689 additions and 15 deletions.
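The commit adds a deferrable flag to the transfer operators and to the job-status sensor. As a quick orientation before the per-file changes, here is a sketch of how a DAG author might opt in; the DAG id, bucket names, and project are placeholders, and the source_bucket/destination_bucket arguments are assumed from the operator's existing signature rather than taken from this diff:

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import (
    CloudDataTransferServiceGCSToGCSOperator,
)

with DAG(
    dag_id="example_gcs_to_gcs_transfer_deferrable",  # placeholder DAG id
    start_date=datetime(2025, 1, 1),
    schedule=None,
    catchup=False,
):
    # With wait=True and deferrable=True the task creates the transfer job,
    # then defers to the triggerer instead of blocking a worker while polling.
    copy_bucket = CloudDataTransferServiceGCSToGCSOperator(
        task_id="copy_bucket",
        source_bucket="my-source-bucket",  # placeholder
        destination_bucket="my-destination-bucket",  # placeholder
        project_id="my-gcp-project",  # placeholder
        wait=True,
        deferrable=True,
    )

With wait=False the flag has no effect, since the operator returns as soon as the job is created.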
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
@@ -1759,6 +1759,7 @@ traceback
tracebacks
tracemalloc
TrainingPipeline
TransferOperation
TranslationServiceClient
travis
triage
@@ -36,6 +36,7 @@
from datetime import timedelta
from typing import TYPE_CHECKING, Any

from google.api_core import protobuf_helpers
from google.cloud.storage_transfer_v1 import (
ListTransferJobsRequest,
StorageTransferServiceAsyncClient,
@@ -57,6 +58,7 @@
from google.cloud.storage_transfer_v1.services.storage_transfer_service.pagers import (
ListTransferJobsAsyncPager,
)
from google.longrunning import operations_pb2 # type: ignore[attr-defined]
from proto import Message

log = logging.getLogger(__name__)
@@ -596,3 +598,112 @@ async def get_latest_operation(self, job: TransferJob) -> Message | None:
operation = TransferOperation.deserialize(response_operation.metadata.value)
return operation
return None

async def list_transfer_operations(
self,
request_filter: dict | None = None,
**kwargs,
) -> list[TransferOperation]:
"""
List transfer operations in Google Storage Transfer Service.

:param request_filter: (Required) A request filter, as described in
    https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/list#body.QUERY_PARAMETERS.filter
    With one additional improvement: the ``project_id`` key in the filter is optional
    if you have a project ID defined in the connection. See: :doc:`/connections/gcp`
:return: list of transfer operations
"""
# To preserve backward compatibility
# TODO: remove one day
if request_filter is None:
if "filter" in kwargs:
request_filter = kwargs["filter"]
if not isinstance(request_filter, dict):
raise ValueError(f"The request_filter should be dict and is {type(request_filter)}")
warnings.warn(
"Use 'request_filter' instead of 'filter'",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
else:
raise TypeError(
"list_transfer_operations missing 1 required positional argument: 'request_filter'"
)

conn = await self.get_conn()

request_filter = await self._inject_project_id(request_filter, FILTER, FILTER_PROJECT_ID)

operations: list[operations_pb2.Operation] = []

response = await conn.list_operations(
request={
"name": TRANSFER_OPERATIONS,
"filter": json.dumps(request_filter),
}
)

while response is not None:
operations.extend(response.operations)
response = (
await conn.list_operations(
request={
"name": TRANSFER_OPERATIONS,
"filter": json.dumps(request_filter),
"page_token": response.next_page_token,
}
)
if response.next_page_token
else None
)

transfer_operations = [
protobuf_helpers.from_any_pb(TransferOperation, op.metadata) for op in operations
]

return transfer_operations

async def _inject_project_id(self, body: dict, param_name: str, target_key: str) -> dict:
body = deepcopy(body)
body[target_key] = body.get(target_key, self.project_id)
if not body.get(target_key):
raise AirflowException(
f"The project id must be passed either as `{target_key}` key in `{param_name}` "
f"parameter or as project_id extra in Google Cloud connection definition. Both are not set!"
)
return body

@staticmethod
async def operations_contain_expected_statuses(
operations: list[TransferOperation], expected_statuses: set[str] | str
) -> bool:
"""
Check whether an operation exists with the expected status.

:param operations: (Required) List of transfer operations to check.
:param expected_statuses: (Required) The expected status, or set of statuses. See:
    https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferOperations#Status
:return: True if at least one operation with an expected status is present in the
    operation list, False otherwise.
:raises AirflowException: If it encounters operations with the status FAILED
    or ABORTED in the list.
"""
expected_statuses_set = (
{expected_statuses} if isinstance(expected_statuses, str) else set(expected_statuses)
)
if not operations:
return False

current_statuses = {operation.status.name for operation in operations}

if len(current_statuses - expected_statuses_set) != len(current_statuses):
return True

if len(NEGATIVE_STATUSES - current_statuses) != len(NEGATIVE_STATUSES):
raise AirflowException(
f"An unexpected operation status was encountered. "
f"Expected: {', '.join(expected_statuses_set)}"
)
return False
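Together, list_transfer_operations and operations_contain_expected_statuses give the async hook everything a trigger needs to poll a transfer job without tying up a worker. A minimal sketch of calling them directly, assuming the async hook class is CloudDataTransferServiceAsyncHook and that the GcpTransferOperationStatus constants and the "project_id"/"job_names" filter keys behave as in the synchronous hook (none of which appears in this excerpt):

from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
    CloudDataTransferServiceAsyncHook,  # assumed class name, not shown in this diff
    GcpTransferOperationStatus,  # assumed constants (SUCCESS, FAILED, ...)
)


async def transfer_job_succeeded(job_name: str, project_id: str) -> bool:
    hook = CloudDataTransferServiceAsyncHook(project_id=project_id)
    # Filter keys here are assumed to match the synchronous hook's request_filter format.
    operations = await hook.list_transfer_operations(
        request_filter={"project_id": project_id, "job_names": [job_name]}
    )
    # Returns True once any operation has reached an expected status; raises
    # AirflowException if a FAILED or ABORTED operation is found.
    return await hook.operations_contain_expected_statuses(
        operations, {GcpTransferOperationStatus.SUCCESS}
    )


# Example (hypothetical IDs):
#   import asyncio
#   asyncio.run(transfer_job_succeeded("transferJobs/123456", "my-gcp-project"))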
@@ -21,9 +21,10 @@

from collections.abc import Sequence
from copy import deepcopy
from datetime import date, time
from typing import TYPE_CHECKING
from datetime import date, time, timedelta
from typing import TYPE_CHECKING, Any

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
@@ -63,6 +64,9 @@
CloudStorageTransferListLink,
)
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.providers.google.cloud.triggers.cloud_storage_transfer_service import (
CloudStorageTransferServiceCheckJobStatusTrigger,
)
from airflow.providers.google.cloud.utils.helpers import normalize_directory_path
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID

@@ -908,6 +912,7 @@ class CloudDataTransferServiceS3ToGCSOperator(GoogleCloudBaseOperator):
:param aws_role_arn: Optional AWS role ARN for workload identity federation. This will
override the `aws_conn_id` for authentication between GCP and AWS; see
https://cloud.google.com/storage-transfer/docs/reference/rest/v1/TransferSpec#AwsS3Data
:param deferrable: Run the operator in deferrable mode.
"""

template_fields: Sequence[str] = (
@@ -942,6 +947,7 @@ def __init__(
google_impersonation_chain: str | Sequence[str] | None = None,
delete_job_after_completion: bool = False,
aws_role_arn: str | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -961,6 +967,7 @@ def __init__(
self.google_impersonation_chain = google_impersonation_chain
self.delete_job_after_completion = delete_job_after_completion
self.aws_role_arn = aws_role_arn
self.deferrable = deferrable
self._validate_inputs()

def _validate_inputs(self) -> None:
@@ -979,9 +986,31 @@ def execute(self, context: Context) -> None:
job = hook.create_transfer_job(body=body)

if self.wait:
hook.wait_for_transfer_job(job, timeout=self.timeout)
if self.delete_job_after_completion:
hook.delete_transfer_job(job_name=job[NAME], project_id=self.project_id)
if not self.deferrable:
hook.wait_for_transfer_job(job, timeout=self.timeout)
if self.delete_job_after_completion:
hook.delete_transfer_job(job_name=job[NAME], project_id=self.project_id)
else:
self.defer(
timeout=timedelta(seconds=self.timeout or 60),
trigger=CloudStorageTransferServiceCheckJobStatusTrigger(
job_name=job[NAME],
project_id=job[PROJECT_ID],
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.google_impersonation_chain,
),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
"""
Act as a callback for when the trigger fires.

This returns immediately. It relies on the trigger to throw an exception;
otherwise it assumes execution was successful.
"""
if event["status"] == "error":
raise AirflowException(event["message"])

def _create_body(self) -> dict:
body = {
@@ -1079,6 +1108,7 @@ class CloudDataTransferServiceGCSToGCSOperator(GoogleCloudBaseOperator):
account from the list granting this role to the originating account (templated).
:param delete_job_after_completion: If True, delete the job after complete.
If set to True, 'wait' must be set to True.
:param deferrable: Run the operator in deferrable mode.
"""

# [START gcp_transfer_gcs_to_gcs_template_fields]
@@ -1113,6 +1143,7 @@ def __init__(
timeout: float | None = None,
google_impersonation_chain: str | Sequence[str] | None = None,
delete_job_after_completion: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -1130,6 +1161,7 @@ def __init__(
self.timeout = timeout
self.google_impersonation_chain = google_impersonation_chain
self.delete_job_after_completion = delete_job_after_completion
self.deferrable = deferrable
self._validate_inputs()

def _validate_inputs(self) -> None:
@@ -1149,9 +1181,31 @@ def execute(self, context: Context) -> None:
job = hook.create_transfer_job(body=body)

if self.wait:
hook.wait_for_transfer_job(job, timeout=self.timeout)
if self.delete_job_after_completion:
hook.delete_transfer_job(job_name=job[NAME], project_id=self.project_id)
if not self.deferrable:
hook.wait_for_transfer_job(job, timeout=self.timeout)
if self.delete_job_after_completion:
hook.delete_transfer_job(job_name=job[NAME], project_id=self.project_id)
else:
self.defer(
timeout=timedelta(seconds=self.timeout or 60),
trigger=CloudStorageTransferServiceCheckJobStatusTrigger(
job_name=job[NAME],
project_id=job[PROJECT_ID],
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.google_impersonation_chain,
),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
"""
Act as a callback for when the trigger fires.

This returns immediately. It relies on the trigger to throw an exception;
otherwise it assumes execution was successful.
"""
if event["status"] == "error":
raise AirflowException(event["message"])

def _create_body(self) -> dict:
body = {
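Both execute() paths above defer to CloudStorageTransferServiceCheckJobStatusTrigger, which is added in the triggers module of this commit but not shown in this excerpt. Purely to illustrate the deferral mechanics, here is a sketch of what such a trigger's run() loop could look like if it were built on the async hook methods above; the class name, module path, and event payload below are stand-ins, not the committed implementation:

from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator
from typing import Any

from airflow.triggers.base import BaseTrigger, TriggerEvent


class ExampleTransferJobStatusTrigger(BaseTrigger):
    """Illustrative stand-in for CloudStorageTransferServiceCheckJobStatusTrigger."""

    def __init__(
        self,
        job_name: str,
        project_id: str,
        expected_statuses: str | set[str] = "SUCCESS",
        poke_interval: float = 10.0,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | list[str] | None = None,
    ) -> None:
        super().__init__()
        self.job_name = job_name
        self.project_id = project_id
        self.expected_statuses = expected_statuses
        self.poke_interval = poke_interval
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def serialize(self) -> tuple[str, dict[str, Any]]:
        # The triggerer re-creates the trigger from this classpath and kwargs.
        return (
            "example_triggers.ExampleTransferJobStatusTrigger",  # placeholder module path
            {
                "job_name": self.job_name,
                "project_id": self.project_id,
                "expected_statuses": self.expected_statuses,
                "poke_interval": self.poke_interval,
                "gcp_conn_id": self.gcp_conn_id,
                "impersonation_chain": self.impersonation_chain,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        # Assumed async hook class, as in the hook sketch earlier on this page.
        from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
            CloudDataTransferServiceAsyncHook,
        )

        hook = CloudDataTransferServiceAsyncHook(
            project_id=self.project_id,
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        try:
            while True:
                operations = await hook.list_transfer_operations(
                    request_filter={"project_id": self.project_id, "job_names": [self.job_name]}
                )
                if await hook.operations_contain_expected_statuses(operations, self.expected_statuses):
                    yield TriggerEvent({"status": "success", "message": "Transfer completed"})
                    return
                await asyncio.sleep(self.poke_interval)
        except Exception as exc:
            # operations_contain_expected_statuses raises on FAILED/ABORTED operations,
            # which the deferring task surfaces via execute_complete().
            yield TriggerEvent({"status": "error", "message": str(exc)})

The event dictionary is what execute_complete above inspects: a status of "error" raises, anything else is treated as success.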
@@ -20,15 +20,20 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
COUNTERS,
METADATA,
NAME,
CloudDataTransferServiceHook,
)
from airflow.providers.google.cloud.links.cloud_storage_transfer import CloudStorageTransferJobLink
from airflow.providers.google.cloud.triggers.cloud_storage_transfer_service import (
CloudStorageTransferServiceCheckJobStatusTrigger,
)
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
from airflow.sensors.base import BaseSensorOperator

@@ -60,6 +65,7 @@ class CloudDataTransferServiceJobStatusSensor(BaseSensorOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: Run the sensor in deferrable mode.
"""

# [START gcp_transfer_job_sensor_template_fields]
@@ -78,6 +84,7 @@ def __init__(
project_id: str = PROVIDE_PROJECT_ID,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -88,6 +95,7 @@
self.project_id = project_id
self.gcp_cloud_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable

def poke(self, context: Context) -> bool:
hook = CloudDataTransferServiceHook(
@@ -117,3 +125,33 @@ def poke(self, context: Context) -> bool:
)

return check

def execute(self, context: Context) -> None:
"""Run on the worker and defer using the triggers if deferrable is set to True."""
if not self.deferrable:
super().execute(context)
elif not self.poke(context=context):
self.defer(
timeout=self.execution_timeout,
trigger=CloudStorageTransferServiceCheckJobStatusTrigger(
job_name=self.job_name,
expected_statuses=self.expected_statuses,
project_id=self.project_id,
poke_interval=self.poke_interval,
gcp_conn_id=self.gcp_cloud_conn_id,
impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
"""
Act as a callback for when the trigger fires.

This returns immediately. It relies on the trigger to throw an exception;
otherwise it assumes execution was successful.
"""
if event["status"] == "error":
raise AirflowException(event["message"])

self.xcom_push(key="sensed_operations", value=event["operations"], context=context)
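With deferrable=True the sensor pokes once on the worker and, if the job has not yet reached an expected status, hands the wait to the same trigger. A sketch of using it in a DAG; the DAG id, project, and the upstream task feeding job_name are placeholders, and GcpTransferOperationStatus is assumed from the hooks module:

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
    GcpTransferOperationStatus,  # assumed constants (SUCCESS, FAILED, ...)
)
from airflow.providers.google.cloud.sensors.cloud_storage_transfer_service import (
    CloudDataTransferServiceJobStatusSensor,
)

with DAG(
    dag_id="example_transfer_job_sensor_deferrable",  # placeholder DAG id
    start_date=datetime(2025, 1, 1),
    schedule=None,
    catchup=False,
):
    # One synchronous poke on the worker; if the job is still running, the wait
    # moves to the triggerer. On completion the sensed operations are pushed to
    # XCom under "sensed_operations", matching execute_complete() above.
    wait_for_transfer = CloudDataTransferServiceJobStatusSensor(
        task_id="wait_for_transfer",
        job_name="{{ ti.xcom_pull('create_transfer_job')['name'] }}",  # placeholder upstream task
        project_id="my-gcp-project",  # placeholder
        expected_statuses={GcpTransferOperationStatus.SUCCESS},
        poke_interval=30,
        deferrable=True,
    )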