Add deferrable mode to google cloud storage transfer sensor and operators #45754

Merged
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
@@ -1755,6 +1755,7 @@ traceback
tracebacks
tracemalloc
TrainingPipeline
TransferOperation
TranslationServiceClient
travis
triage
airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
@@ -36,6 +36,7 @@
from datetime import timedelta
from typing import TYPE_CHECKING, Any

from google.api_core import protobuf_helpers
from google.cloud.storage_transfer_v1 import (
ListTransferJobsRequest,
StorageTransferServiceAsyncClient,
@@ -57,6 +58,7 @@
from google.cloud.storage_transfer_v1.services.storage_transfer_service.pagers import (
ListTransferJobsAsyncPager,
)
from google.longrunning import operations_pb2 # type: ignore[attr-defined]
from proto import Message

log = logging.getLogger(__name__)
@@ -596,3 +598,112 @@ async def get_latest_operation(self, job: TransferJob) -> Message | None:
            operation = TransferOperation.deserialize(response_operation.metadata.value)
            return operation
        return None

    async def list_transfer_operations(
        self,
        request_filter: dict | None = None,
        **kwargs,
    ) -> list[TransferOperation]:
"""
Get a transfer operation in Google Storage Transfer Service.

:param request_filter: (Required) A request filter, as described in
https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/list#body.QUERY_PARAMETERS.filter
With one additional improvement:
:return: transfer operation

The ``project_id`` parameter is optional if you have a project ID
defined in the connection. See: :doc:`/connections/gcp`
"""
        # To preserve backward compatibility
        # TODO: remove one day
        if request_filter is None:
            if "filter" in kwargs:
                request_filter = kwargs["filter"]
                if not isinstance(request_filter, dict):
                    raise ValueError(f"The request_filter should be dict and is {type(request_filter)}")
                warnings.warn(
                    "Use 'request_filter' instead of 'filter'",
                    AirflowProviderDeprecationWarning,
                    stacklevel=2,
                )
            else:
                raise TypeError(
                    "list_transfer_operations missing 1 required positional argument: 'request_filter'"
                )

        conn = await self.get_conn()

        request_filter = await self._inject_project_id(request_filter, FILTER, FILTER_PROJECT_ID)

        operations: list[operations_pb2.Operation] = []

        response = await conn.list_operations(
            request={
                "name": TRANSFER_OPERATIONS,
                "filter": json.dumps(request_filter),
            }
        )

        while response is not None:
            operations.extend(response.operations)
            response = (
                await conn.list_operations(
                    request={
                        "name": TRANSFER_OPERATIONS,
                        "filter": json.dumps(request_filter),
                        "page_token": response.next_page_token,
                    }
                )
                if response.next_page_token
                else None
            )

        transfer_operations = [
            protobuf_helpers.from_any_pb(TransferOperation, op.metadata) for op in operations
        ]

        return transfer_operations

    async def _inject_project_id(self, body: dict, param_name: str, target_key: str) -> dict:
        body = deepcopy(body)
        body[target_key] = body.get(target_key, self.project_id)
        if not body.get(target_key):
            raise AirflowException(
                f"The project id must be passed either as `{target_key}` key in `{param_name}` "
                f"parameter or as project_id extra in Google Cloud connection definition. Both are not set!"
            )
        return body

    @staticmethod
    async def operations_contain_expected_statuses(
        operations: list[TransferOperation], expected_statuses: set[str] | str
    ) -> bool:
"""
Check whether an operation exists with the expected status.

:param operations: (Required) List of transfer operations to check.
:param expected_statuses: (Required) The expected status. See:
https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferOperations#Status
:return: If there is an operation with the expected state in the
operation list, returns true,
:raises AirflowException: If it encounters operations with state FAILED
or ABORTED in the list.
"""
        expected_statuses_set = (
            {expected_statuses} if isinstance(expected_statuses, str) else set(expected_statuses)
        )
        if not operations:
            return False

        current_statuses = {operation.status.name for operation in operations}

        if len(current_statuses - expected_statuses_set) != len(current_statuses):
            return True

        if len(NEGATIVE_STATUSES - current_statuses) != len(NEGATIVE_STATUSES):
            raise AirflowException(
                f"An unexpected operation status was encountered. "
                f"Expected: {', '.join(expected_statuses_set)}"
            )
        return False
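
For orientation, the sketch below shows how a deferral trigger might poll a transfer job using the two async hook methods added above. It is a minimal illustration, not the trigger shipped in this PR: the CloudDataTransferServiceAsyncHook class name, its project_id constructor argument, and the "SUCCESS" status literal are assumptions based on the surrounding imports and the existing sync hook, not code from this diff.

import asyncio

from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
    CloudDataTransferServiceAsyncHook,  # assumed name of the async hook shown above
)


async def _poll_transfer_job(job_name: str, project_id: str, poke_interval: float = 10.0) -> None:
    """Illustrative sketch: poll until an operation of the given job reaches SUCCESS."""
    hook = CloudDataTransferServiceAsyncHook(project_id=project_id)
    while True:
        # List the operations spawned by the job, then check them against the expected status.
        operations = await hook.list_transfer_operations(
            request_filter={"project_id": project_id, "job_names": [job_name]}
        )
        if await hook.operations_contain_expected_statuses(operations, {"SUCCESS"}):
            return
        await asyncio.sleep(poke_interval)

A loop like this keeps the waiting off the worker slot; FAILED or ABORTED operations surface as the AirflowException raised by operations_contain_expected_statuses.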
airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
@@ -21,9 +21,10 @@

from collections.abc import Sequence
from copy import deepcopy
from datetime import date, time
from typing import TYPE_CHECKING
from datetime import date, time, timedelta
from typing import TYPE_CHECKING, Any

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
@@ -63,6 +64,9 @@
CloudStorageTransferListLink,
)
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.providers.google.cloud.triggers.cloud_storage_transfer_service import (
CloudStorageTransferServiceCheckJobStatusTrigger,
)
from airflow.providers.google.cloud.utils.helpers import normalize_directory_path
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID

@@ -908,6 +912,7 @@ class CloudDataTransferServiceS3ToGCSOperator(GoogleCloudBaseOperator):
:param aws_role_arn: Optional AWS role ARN for workload identity federation. This will
override the `aws_conn_id` for authentication between GCP and AWS; see
https://cloud.google.com/storage-transfer/docs/reference/rest/v1/TransferSpec#AwsS3Data
:param deferrable: Run operator in deferrable mode.
"""

template_fields: Sequence[str] = (
@@ -942,6 +947,7 @@ def __init__(
google_impersonation_chain: str | Sequence[str] | None = None,
delete_job_after_completion: bool = False,
aws_role_arn: str | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -961,6 +967,7 @@ def __init__(
self.google_impersonation_chain = google_impersonation_chain
self.delete_job_after_completion = delete_job_after_completion
self.aws_role_arn = aws_role_arn
self.deferrable = deferrable
self._validate_inputs()

def _validate_inputs(self) -> None:
@@ -979,9 +986,31 @@ def execute(self, context: Context) -> None:
        job = hook.create_transfer_job(body=body)

        if self.wait:
            hook.wait_for_transfer_job(job, timeout=self.timeout)
            if self.delete_job_after_completion:
                hook.delete_transfer_job(job_name=job[NAME], project_id=self.project_id)
            if not self.deferrable:
                hook.wait_for_transfer_job(job, timeout=self.timeout)
                if self.delete_job_after_completion:
                    hook.delete_transfer_job(job_name=job[NAME], project_id=self.project_id)
            else:
                self.defer(
                    timeout=timedelta(seconds=self.timeout or 60),
                    trigger=CloudStorageTransferServiceCheckJobStatusTrigger(
                        job_name=job[NAME],
                        project_id=job[PROJECT_ID],
                        gcp_conn_id=self.gcp_conn_id,
                        impersonation_chain=self.google_impersonation_chain,
                    ),
                    method_name="execute_complete",
                )

    def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
        """
        Act as a callback for when the trigger fires.

        This returns immediately. It relies on the trigger to throw an exception;
        otherwise it assumes execution was successful.
        """
        if event["status"] == "error":
            raise AirflowException(event["message"])

    def _create_body(self) -> dict:
        body = {
@@ -1079,6 +1108,7 @@ class CloudDataTransferServiceGCSToGCSOperator(GoogleCloudBaseOperator):
account from the list granting this role to the originating account (templated).
:param delete_job_after_completion: If True, delete the job after complete.
If set to True, 'wait' must be set to True.
:param deferrable: Run operator in deferrable mode.
"""

# [START gcp_transfer_gcs_to_gcs_template_fields]
@@ -1113,6 +1143,7 @@ def __init__(
timeout: float | None = None,
google_impersonation_chain: str | Sequence[str] | None = None,
delete_job_after_completion: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -1130,6 +1161,7 @@ def __init__(
self.timeout = timeout
self.google_impersonation_chain = google_impersonation_chain
self.delete_job_after_completion = delete_job_after_completion
self.deferrable = deferrable
self._validate_inputs()

def _validate_inputs(self) -> None:
@@ -1149,9 +1181,31 @@ def execute(self, context: Context) -> None:
        job = hook.create_transfer_job(body=body)

        if self.wait:
            hook.wait_for_transfer_job(job, timeout=self.timeout)
            if self.delete_job_after_completion:
                hook.delete_transfer_job(job_name=job[NAME], project_id=self.project_id)
            if not self.deferrable:
                hook.wait_for_transfer_job(job, timeout=self.timeout)
                if self.delete_job_after_completion:
                    hook.delete_transfer_job(job_name=job[NAME], project_id=self.project_id)
            else:
                self.defer(
                    timeout=timedelta(seconds=self.timeout or 60),
                    trigger=CloudStorageTransferServiceCheckJobStatusTrigger(
                        job_name=job[NAME],
                        project_id=job[PROJECT_ID],
                        gcp_conn_id=self.gcp_conn_id,
                        impersonation_chain=self.google_impersonation_chain,
                    ),
                    method_name="execute_complete",
                )

    def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
        """
        Act as a callback for when the trigger fires.

        This returns immediately. It relies on the trigger to throw an exception;
        otherwise it assumes execution was successful.
        """
        if event["status"] == "error":
            raise AirflowException(event["message"])

    def _create_body(self) -> dict:
        body = {
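
To show where the new flag fits, here is a hedged usage sketch of the GCS-to-GCS operator in deferrable mode. The DAG id and bucket names are placeholders, and the source_bucket/destination_bucket arguments are assumed from the operator's existing signature rather than taken from this diff.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import (
    CloudDataTransferServiceGCSToGCSOperator,
)

with DAG(dag_id="example_gcs_to_gcs_transfer", start_date=datetime(2025, 1, 1), schedule=None):
    transfer = CloudDataTransferServiceGCSToGCSOperator(
        task_id="transfer_gcs_to_gcs",
        source_bucket="source-bucket",            # placeholder bucket names
        destination_bucket="destination-bucket",
        wait=True,          # still wait for the transfer job to finish...
        timeout=3600,
        deferrable=True,    # ...but hand the waiting off to the triggerer instead of the worker
    )

With deferrable=True the operator creates the job and then defers to CloudStorageTransferServiceCheckJobStatusTrigger instead of blocking in wait_for_transfer_job; note that, as the diff stands, delete_job_after_completion is only honoured on the non-deferrable path.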
airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py
@@ -20,15 +20,20 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
COUNTERS,
METADATA,
NAME,
CloudDataTransferServiceHook,
)
from airflow.providers.google.cloud.links.cloud_storage_transfer import CloudStorageTransferJobLink
from airflow.providers.google.cloud.triggers.cloud_storage_transfer_service import (
CloudStorageTransferServiceCheckJobStatusTrigger,
)
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
from airflow.sensors.base import BaseSensorOperator

@@ -60,6 +65,7 @@ class CloudDataTransferServiceJobStatusSensor(BaseSensorOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: Run sensor in deferrable mode.
"""

# [START gcp_transfer_job_sensor_template_fields]
@@ -78,6 +84,7 @@ def __init__(
project_id: str = PROVIDE_PROJECT_ID,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -88,6 +95,7 @@ def __init__(
self.project_id = project_id
self.gcp_cloud_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable

def poke(self, context: Context) -> bool:
hook = CloudDataTransferServiceHook(
@@ -117,3 +125,33 @@ def poke(self, context: Context) -> bool:
        )

        return check

    def execute(self, context: Context) -> None:
        """Run on the worker and defer using the trigger if deferrable is set to True."""
        if not self.deferrable:
            super().execute(context)
        elif not self.poke(context=context):
            self.defer(
                timeout=self.execution_timeout,
                trigger=CloudStorageTransferServiceCheckJobStatusTrigger(
                    job_name=self.job_name,
                    expected_statuses=self.expected_statuses,
                    project_id=self.project_id,
                    poke_interval=self.poke_interval,
                    gcp_conn_id=self.gcp_cloud_conn_id,
                    impersonation_chain=self.impersonation_chain,
                ),
                method_name="execute_complete",
            )

    def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
        """
        Act as a callback for when the trigger fires.

        This returns immediately. It relies on the trigger to throw an exception;
        otherwise it assumes execution was successful.
        """
        if event["status"] == "error":
            raise AirflowException(event["message"])

        self.xcom_push(key="sensed_operations", value=event["operations"], context=context)
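
And a matching sketch for the sensor. The templated job_name value is a placeholder showing a typical upstream XCom pull, not something defined in this PR.

from airflow.providers.google.cloud.sensors.cloud_storage_transfer_service import (
    CloudDataTransferServiceJobStatusSensor,
)

wait_for_transfer = CloudDataTransferServiceJobStatusSensor(
    task_id="wait_for_transfer",
    job_name="{{ task_instance.xcom_pull('create_transfer_job')['name'] }}",  # placeholder upstream task
    expected_statuses={"SUCCESS"},
    deferrable=True,
)

When the trigger reports success, execute_complete pushes the matched operations to XCom under the sensed_operations key (last line of the diff above), so downstream tasks can read the operation metadata.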