diff --git a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
index d57de7e11efb2..2cdd0761a1344 100644
--- a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
@@ -80,6 +80,8 @@ class GCSToS3Operator(BaseOperator):
         on the bucket is recreated within path passed in dest_s3_key.
     :param match_glob: (Optional) filters objects based on the glob pattern given by the string
         (e.g, ``'**/*/.json'``)
+    :param gcp_user_project: (Optional) The identifier of the Google Cloud project to bill for this request.
+        Required for Requester Pays buckets.
     """
 
     template_fields: Sequence[str] = (
@@ -88,6 +90,7 @@ class GCSToS3Operator(BaseOperator):
         "delimiter",
         "dest_s3_key",
         "google_impersonation_chain",
+        "gcp_user_project",
     )
     ui_color = "#f0eee4"
 
@@ -107,6 +110,7 @@ def __init__(
         s3_acl_policy: str | None = None,
         keep_directory_structure: bool = True,
         match_glob: str | None = None,
+        gcp_user_project: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -130,10 +134,11 @@ def __init__(
         self.s3_acl_policy = s3_acl_policy
         self.keep_directory_structure = keep_directory_structure
         self.match_glob = match_glob
+        self.gcp_user_project = gcp_user_project
 
     def execute(self, context: Context) -> list[str]:
         # list all files in an Google Cloud Storage bucket
-        hook = GCSHook(
+        gcs_hook = GCSHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.google_impersonation_chain,
         )
@@ -145,8 +150,12 @@ def execute(self, context: Context) -> list[str]:
             self.prefix,
         )
 
-        files = hook.list(
-            bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter, match_glob=self.match_glob
+        gcs_files = gcs_hook.list(
+            bucket_name=self.bucket,
+            prefix=self.prefix,
+            delimiter=self.delimiter,
+            match_glob=self.match_glob,
+            user_project=self.gcp_user_project,
         )
 
         s3_hook = S3Hook(
@@ -173,24 +182,23 @@ def execute(self, context: Context) -> list[str]:
             existing_files = existing_files if existing_files is not None else []
             # remove the prefix for the existing files to allow the match
             existing_files = [file.replace(prefix, "", 1) for file in existing_files]
-            files = list(set(files) - set(existing_files))
+            gcs_files = list(set(gcs_files) - set(existing_files))
 
-        if files:
-
-            for file in files:
-                with hook.provide_file(object_name=file, bucket_name=self.bucket) as local_tmp_file:
+        if gcs_files:
+            for file in gcs_files:
+                with gcs_hook.provide_file(
+                    object_name=file, bucket_name=self.bucket, user_project=self.gcp_user_project
+                ) as local_tmp_file:
                     dest_key = os.path.join(self.dest_s3_key, file)
                     self.log.info("Saving file to %s", dest_key)
-
                     s3_hook.load_file(
                         filename=local_tmp_file.name,
                         key=dest_key,
                         replace=self.replace,
                         acl_policy=self.s3_acl_policy,
                     )
-
-            self.log.info("All done, uploaded %d files to S3", len(files))
+            self.log.info("All done, uploaded %d files to S3", len(gcs_files))
         else:
             self.log.info("In sync, no files needed to be uploaded to S3")
 
-        return files
+        return gcs_files
diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py
index a42f81c35506d..a01bf72259404 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -197,7 +197,6 @@ def copy(
         destination_object = destination_object or source_object
 
         if source_bucket == destination_bucket and source_object == destination_object:
-
             raise ValueError(
                 f"Either source/destination bucket or source/destination object must be different, "
f"not both the same: bucket={source_bucket}, object={source_object}" @@ -282,6 +281,7 @@ def download( chunk_size: int | None = None, timeout: int | None = DEFAULT_TIMEOUT, num_max_attempts: int | None = 1, + user_project: str | None = None, ) -> bytes: ... @@ -294,6 +294,7 @@ def download( chunk_size: int | None = None, timeout: int | None = DEFAULT_TIMEOUT, num_max_attempts: int | None = 1, + user_project: str | None = None, ) -> str: ... @@ -305,6 +306,7 @@ def download( chunk_size: int | None = None, timeout: int | None = DEFAULT_TIMEOUT, num_max_attempts: int | None = 1, + user_project: str | None = None, ) -> str | bytes: """ Downloads a file from Google Cloud Storage. @@ -320,6 +322,8 @@ def download( :param chunk_size: Blob chunk size. :param timeout: Request timeout in seconds. :param num_max_attempts: Number of attempts to download the file. + :param user_project: The identifier of the Google Cloud project to bill for the request. + Required for Requester Pays buckets. """ # TODO: future improvement check file size before downloading, # to check for local space availability @@ -330,7 +334,7 @@ def download( try: num_file_attempts += 1 client = self.get_conn() - bucket = client.bucket(bucket_name) + bucket = client.bucket(bucket_name, user_project=user_project) blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size) if filename: @@ -395,6 +399,7 @@ def provide_file( object_name: str | None = None, object_url: str | None = None, dir: str | None = None, + user_project: str | None = None, ) -> Generator[IO[bytes], None, None]: """ Downloads the file to a temporary directory and returns a file handle. @@ -406,13 +411,20 @@ def provide_file( :param object_name: The object to fetch. :param object_url: File reference url. Must start with "gs: //" :param dir: The tmp sub directory to download the file to. (passed to NamedTemporaryFile) + :param user_project: The identifier of the Google Cloud project to bill for the request. + Required for Requester Pays buckets. :return: File handler """ if object_name is None: raise ValueError("Object name can not be empty") _, _, file_name = object_name.rpartition("/") with NamedTemporaryFile(suffix=file_name, dir=dir) as tmp_file: - self.download(bucket_name=bucket_name, object_name=object_name, filename=tmp_file.name) + self.download( + bucket_name=bucket_name, + object_name=object_name, + filename=tmp_file.name, + user_project=user_project, + ) tmp_file.flush() yield tmp_file @@ -423,6 +435,7 @@ def provide_file_and_upload( bucket_name: str = PROVIDE_BUCKET, object_name: str | None = None, object_url: str | None = None, + user_project: str | None = None, ) -> Generator[IO[bytes], None, None]: """ Creates temporary file, returns a file handle and uploads the files content on close. @@ -433,6 +446,8 @@ def provide_file_and_upload( :param bucket_name: The bucket to fetch from. :param object_name: The object to fetch. :param object_url: File reference url. Must start with "gs: //" + :param user_project: The identifier of the Google Cloud project to bill for the request. + Required for Requester Pays buckets. 
         :return: File handler
         """
         if object_name is None:
@@ -442,7 +457,12 @@ def provide_file_and_upload(
         with NamedTemporaryFile(suffix=file_name) as tmp_file:
             yield tmp_file
             tmp_file.flush()
-            self.upload(bucket_name=bucket_name, object_name=object_name, filename=tmp_file.name)
+            self.upload(
+                bucket_name=bucket_name,
+                object_name=object_name,
+                filename=tmp_file.name,
+                user_project=user_project,
+            )
 
     def upload(
         self,
@@ -458,6 +478,7 @@ def upload(
         num_max_attempts: int = 1,
         metadata: dict | None = None,
         cache_control: str | None = None,
+        user_project: str | None = None,
     ) -> None:
         """
         Uploads a local file or file data as string or bytes to Google Cloud Storage.
@@ -474,6 +495,8 @@ def upload(
         :param num_max_attempts: Number of attempts to try to upload the file.
         :param metadata: The metadata to be uploaded with the file.
        :param cache_control: Cache-Control metadata field.
+        :param user_project: The identifier of the Google Cloud project to bill for the request.
+            Required for Requester Pays buckets.
         """
 
         def _call_with_retry(f: Callable[[], None]) -> None:
@@ -506,7 +529,7 @@ def _call_with_retry(f: Callable[[], None]) -> None:
                 continue
 
         client = self.get_conn()
-        bucket = client.bucket(bucket_name)
+        bucket = client.bucket(bucket_name, user_project=user_project)
         blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size)
 
         if metadata:
@@ -596,7 +619,6 @@ def is_updated_after(self, bucket_name: str, object_name: str, ts: datetime) ->
         """
         blob_update_time = self.get_blob_update_time(bucket_name, object_name)
         if blob_update_time is not None:
-
             if not ts.tzinfo:
                 ts = ts.replace(tzinfo=timezone.utc)
             self.log.info("Verify object date: %s > %s", blob_update_time, ts)
@@ -618,7 +640,6 @@ def is_updated_between(
         """
         blob_update_time = self.get_blob_update_time(bucket_name, object_name)
         if blob_update_time is not None:
-
             if not min_ts.tzinfo:
                 min_ts = min_ts.replace(tzinfo=timezone.utc)
             if not max_ts.tzinfo:
@@ -639,7 +660,6 @@ def is_updated_before(self, bucket_name: str, object_name: str, ts: datetime) ->
         """
         blob_update_time = self.get_blob_update_time(bucket_name, object_name)
         if blob_update_time is not None:
-
             if not ts.tzinfo:
                 ts = ts.replace(tzinfo=timezone.utc)
             self.log.info("Verify object date: %s < %s", blob_update_time, ts)
@@ -681,16 +701,18 @@ def delete(self, bucket_name: str, object_name: str) -> None:
 
         self.log.info("Blob %s deleted.", object_name)
 
-    def delete_bucket(self, bucket_name: str, force: bool = False) -> None:
+    def delete_bucket(self, bucket_name: str, force: bool = False, user_project: str | None = None) -> None:
         """
         Delete a bucket object from the Google Cloud Storage.
 
         :param bucket_name: name of the bucket which will be deleted
         :param force: false not allow to delete non empty bucket, set force=True
            allows to delete non empty bucket
+        :param user_project: The identifier of the Google Cloud project to bill for the request.
+            Required for Requester Pays buckets.
         """
         client = self.get_conn()
-        bucket = client.bucket(bucket_name)
+        bucket = client.bucket(bucket_name, user_project=user_project)
 
         self.log.info("Deleting %s bucket", bucket_name)
         try:
@@ -707,6 +729,7 @@ def list(
         prefix: str | List[str] | None = None,
         delimiter: str | None = None,
         match_glob: str | None = None,
+        user_project: str | None = None,
     ):
         """
         List all objects from the bucket with the given a single prefix or multiple prefixes.
@@ -718,6 +741,8 @@ def list(
         :param delimiter: (Deprecated) filters objects based on the delimiter (for e.g '.csv')
         :param match_glob: (Optional) filters objects based on the glob pattern given by the string
             (e.g, ``'**/*/.json'``).
+        :param user_project: The identifier of the Google Cloud project to bill for the request.
+            Required for Requester Pays buckets.
         :return: a stream of object names matching the filtering criteria
         """
         if delimiter and delimiter != "/":
@@ -739,6 +764,7 @@ def list(
                         prefix=prefix_item,
                         delimiter=delimiter,
                         match_glob=match_glob,
+                        user_project=user_project,
                     )
                 )
         else:
@@ -750,6 +776,7 @@ def list(
                     prefix=prefix,
                     delimiter=delimiter,
                     match_glob=match_glob,
+                    user_project=user_project,
                 )
             )
         return objects
@@ -762,6 +789,7 @@ def _list(
         prefix: str | None = None,
         delimiter: str | None = None,
         match_glob: str | None = None,
+        user_project: str | None = None,
     ) -> List:
         """
         List all objects from the bucket with the give string prefix in name.
@@ -773,10 +801,12 @@ def _list(
         :param delimiter: (Deprecated) filters objects based on the delimiter (for e.g '.csv')
         :param match_glob: (Optional) filters objects based on the glob pattern given by the string
             (e.g, ``'**/*/.json'``).
+        :param user_project: The identifier of the Google Cloud project to bill for the request.
+            Required for Requester Pays buckets.
         :return: a stream of object names matching the filtering criteria
         """
         client = self.get_conn()
-        bucket = client.bucket(bucket_name)
+        bucket = client.bucket(bucket_name, user_project=user_project)
 
         ids = []
         page_token = None
diff --git a/airflow/providers/google/cloud/operators/gcs.py b/airflow/providers/google/cloud/operators/gcs.py
index fd73af42fb666..9b95032b4264c 100644
--- a/airflow/providers/google/cloud/operators/gcs.py
+++ b/airflow/providers/google/cloud/operators/gcs.py
@@ -301,7 +301,6 @@ def __init__(
         impersonation_chain: str | Sequence[str] | None = None,
         **kwargs,
     ) -> None:
-
         self.bucket_name = bucket_name
         self.objects = objects
         self.prefix = prefix
@@ -875,12 +874,15 @@ class GCSDeleteBucketOperator(GoogleCloudBaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param user_project: (Optional) The identifier of the project to bill for this request.
+        Required for Requester Pays buckets.
""" template_fields: Sequence[str] = ( "bucket_name", "gcp_conn_id", "impersonation_chain", + "user_project", ) def __init__( @@ -890,6 +892,7 @@ def __init__( force: bool = True, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, + user_project: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -898,10 +901,11 @@ def __init__( self.force: bool = force self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain + self.user_project = user_project def execute(self, context: Context) -> None: hook = GCSHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) - hook.delete_bucket(bucket_name=self.bucket_name, force=self.force) + hook.delete_bucket(bucket_name=self.bucket_name, force=self.force, user_project=self.user_project) class GCSSynchronizeBucketsOperator(GoogleCloudBaseOperator): diff --git a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py index 5e64f167ba453..9d5c497de1413 100644 --- a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py @@ -69,7 +69,11 @@ def test_execute__match_glob(self, mock_hook): operator.execute(None) mock_hook.return_value.list.assert_called_once_with( - bucket_name=GCS_BUCKET, delimiter=None, match_glob=f"**/*{DELIMITER}", prefix=PREFIX + bucket_name=GCS_BUCKET, + delimiter=None, + match_glob=f"**/*{DELIMITER}", + prefix=PREFIX, + user_project=None, ) @mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook") diff --git a/tests/providers/google/cloud/hooks/test_gcs.py b/tests/providers/google/cloud/hooks/test_gcs.py index 61ce7a4162714..4f1839fc42c02 100644 --- a/tests/providers/google/cloud/hooks/test_gcs.py +++ b/tests/providers/google/cloud/hooks/test_gcs.py @@ -503,7 +503,7 @@ def test_delete_bucket(self, mock_service): self.gcs_hook.delete_bucket(bucket_name=test_bucket) - mock_service.return_value.bucket.assert_called_once_with(test_bucket) + mock_service.return_value.bucket.assert_called_once_with(test_bucket, user_project=None) mock_service.return_value.bucket.return_value.delete.assert_called_once() @mock.patch(GCS_STRING.format("GCSHook.get_conn")) @@ -514,7 +514,7 @@ def test_delete_nonexisting_bucket(self, mock_service, caplog): test_bucket = "test bucket" with caplog.at_level(logging.INFO): self.gcs_hook.delete_bucket(bucket_name=test_bucket) - mock_service.return_value.bucket.assert_called_once_with(test_bucket) + mock_service.return_value.bucket.assert_called_once_with(test_bucket, user_project=None) mock_service.return_value.bucket.return_value.delete.assert_called_once() assert "Bucket test bucket not exist" in caplog.text @@ -784,7 +784,7 @@ def test_provide_file_upload(self, mock_upload, mock_temp_file): fhandle.write() mock_upload.assert_called_once_with( - bucket_name=test_bucket, object_name=test_object, filename=test_file + bucket_name=test_bucket, object_name=test_object, filename=test_file, user_project=None ) mock_temp_file.assert_has_calls( [ diff --git a/tests/providers/google/cloud/operators/test_gcs.py b/tests/providers/google/cloud/operators/test_gcs.py index 4ceaa5292bb4d..815cad300df89 100644 --- a/tests/providers/google/cloud/operators/test_gcs.py +++ b/tests/providers/google/cloud/operators/test_gcs.py @@ -196,7 +196,6 @@ class TestGCSFileTransformOperator: @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess") 
@mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook") def test_execute(self, mock_hook, mock_subprocess, mock_tempfile): - source_bucket = TEST_BUCKET source_object = "test.txt" destination_bucket = TEST_BUCKET + "-dest" @@ -416,7 +415,9 @@ def test_delete_bucket(self, mock_hook): operator = GCSDeleteBucketOperator(task_id=TASK_ID, bucket_name=TEST_BUCKET) operator.execute(None) - mock_hook.return_value.delete_bucket.assert_called_once_with(bucket_name=TEST_BUCKET, force=True) + mock_hook.return_value.delete_bucket.assert_called_once_with( + bucket_name=TEST_BUCKET, force=True, user_project=None + ) class TestGoogleCloudStorageSync: diff --git a/tests/system/providers/amazon/aws/example_gcs_to_s3.py b/tests/system/providers/amazon/aws/example_gcs_to_s3.py index c0182f2d099ff..68db86f82a55f 100644 --- a/tests/system/providers/amazon/aws/example_gcs_to_s3.py +++ b/tests/system/providers/amazon/aws/example_gcs_to_s3.py @@ -19,13 +19,25 @@ from datetime import datetime from airflow import DAG +from airflow.decorators import task from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.operators.s3 import S3CreateBucketOperator, S3DeleteBucketOperator +from airflow.providers.amazon.aws.operators.s3 import ( + S3CreateBucketOperator, + S3DeleteBucketOperator, +) from airflow.providers.amazon.aws.transfers.gcs_to_s3 import GCSToS3Operator +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.cloud.operators.gcs import ( + GCSCreateBucketOperator, + GCSDeleteBucketOperator, +) from airflow.utils.trigger_rule import TriggerRule from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder -sys_test_context_task = SystemTestContextBuilder().build() +# Externally fetched variables: +GCP_PROJECT_ID = "GCP_PROJECT_ID" + +sys_test_context_task = SystemTestContextBuilder().add_variable(GCP_PROJECT_ID).build() DAG_ID = "example_gcs_to_s3" @@ -38,18 +50,40 @@ ) as dag: test_context = sys_test_context_task() env_id = test_context["ENV_ID"] + gcp_user_project = test_context[GCP_PROJECT_ID] s3_bucket = f"{env_id}-gcs-to-s3-bucket" s3_key = f"{env_id}-gcs-to-s3-key" create_s3_bucket = S3CreateBucketOperator(task_id="create_s3_bucket", bucket_name=s3_bucket) + gcs_bucket = f"{env_id}-gcs-to-s3-bucket" + gcs_key = f"{env_id}-gcs-to-s3-key" + + create_gcs_bucket = GCSCreateBucketOperator( + task_id="create_gcs_bucket", + bucket_name=gcs_bucket, + resource={"billing": {"requesterPays": True}}, + project_id=gcp_user_project, + ) + + @task + def upload_gcs_file(bucket_name: str, object_name: str, user_project: str): + hook = GCSHook() + with hook.provide_file_and_upload( + bucket_name=bucket_name, + object_name=object_name, + user_project=user_project, + ) as temp_file: + temp_file.write(b"test") + # [START howto_transfer_gcs_to_s3] gcs_to_s3 = GCSToS3Operator( task_id="gcs_to_s3", - bucket=s3_bucket, - dest_s3_key=s3_key, + bucket=gcs_bucket, + dest_s3_key=f"s3://{s3_bucket}/{s3_key}", replace=True, + gcp_user_project=gcp_user_project, ) # [END howto_transfer_gcs_to_s3] @@ -60,14 +94,24 @@ trigger_rule=TriggerRule.ALL_DONE, ) + delete_gcs_bucket = GCSDeleteBucketOperator( + task_id="delete_gcs_bucket", + bucket_name=gcs_bucket, + trigger_rule=TriggerRule.ALL_DONE, + user_project=gcp_user_project, + ) + chain( # TEST SETUP test_context, + create_gcs_bucket, + upload_gcs_file(gcs_bucket, gcs_key, gcp_user_project), create_s3_bucket, # TEST BODY gcs_to_s3, # TEST TEARDOWN delete_s3_bucket, + delete_gcs_bucket, ) from 
tests.system.utils.watcher import watcher
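
Below is a minimal usage sketch, not part of the patch, showing how the Requester Pays parameters added above might be wired into a DAG. The bucket names, project id, and task ids are placeholders; it only assumes the operators and parameters introduced in this diff.

# Usage sketch (assumption: bucket and project names below are placeholders).
from airflow.providers.amazon.aws.transfers.gcs_to_s3 import GCSToS3Operator
from airflow.providers.google.cloud.operators.gcs import GCSDeleteBucketOperator

# Copy objects out of a Requester Pays GCS bucket; the listed project is billed
# for the list/download requests issued by GCSHook under the hood.
copy_to_s3 = GCSToS3Operator(
    task_id="copy_requester_pays_bucket_to_s3",
    bucket="example-requester-pays-bucket",
    dest_s3_key="s3://example-destination-bucket/prefix/",
    replace=True,
    gcp_user_project="example-billing-project",
)

# Delete the source bucket afterwards, billing the same project for the request.
delete_source_bucket = GCSDeleteBucketOperator(
    task_id="delete_requester_pays_bucket",
    bucket_name="example-requester-pays-bucket",
    user_project="example-billing-project",
)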