diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py
index fb48fcd1906096..c67addd17dac45 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -43,6 +43,7 @@
 from requests import Session
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
 from airflow.providers.google.cloud.utils.helpers import normalize_directory_path
 from airflow.providers.google.common.consts import CLIENT_INFO
 from airflow.providers.google.common.hooks.base_google import (
@@ -214,6 +215,16 @@ def copy(
         destination_object = source_bucket.copy_blob(  # type: ignore[attr-defined]
             blob=source_object, destination_bucket=destination_bucket, new_name=destination_object
         )
+        get_hook_lineage_collector().add_input_dataset(
+            context=self,
+            scheme="gs",
+            dataset_kwargs={"bucket": source_bucket.name, "key": source_object.name},  # type: ignore[attr-defined]
+        )
+        get_hook_lineage_collector().add_output_dataset(
+            context=self,
+            scheme="gs",
+            dataset_kwargs={"bucket": destination_bucket.name, "key": destination_object.name},  # type: ignore[union-attr]
+        )
 
         self.log.info(
             "Object %s in bucket %s copied to object %s in bucket %s",
@@ -267,6 +278,16 @@ def rewrite(
             ).rewrite(source=source_object, token=token)
 
             self.log.info("Total Bytes: %s | Bytes Written: %s", total_bytes, bytes_rewritten)
+        get_hook_lineage_collector().add_input_dataset(
+            context=self,
+            scheme="gs",
+            dataset_kwargs={"bucket": source_bucket.name, "key": source_object.name},  # type: ignore[attr-defined]
+        )
+        get_hook_lineage_collector().add_output_dataset(
+            context=self,
+            scheme="gs",
+            dataset_kwargs={"bucket": destination_bucket.name, "key": destination_object},  # type: ignore[attr-defined]
+        )
         self.log.info(
             "Object %s in bucket %s rewritten to object %s in bucket %s",
             source_object.name,  # type: ignore[attr-defined]
@@ -345,9 +366,18 @@ def download(
 
             if filename:
                 blob.download_to_filename(filename, timeout=timeout)
+                get_hook_lineage_collector().add_input_dataset(
+                    context=self, scheme="gs", dataset_kwargs={"bucket": bucket.name, "key": blob.name}
+                )
+                get_hook_lineage_collector().add_output_dataset(
+                    context=self, scheme="file", dataset_kwargs={"path": filename}
+                )
                 self.log.info("File downloaded to %s", filename)
                 return filename
             else:
+                get_hook_lineage_collector().add_input_dataset(
+                    context=self, scheme="gs", dataset_kwargs={"bucket": bucket.name, "key": blob.name}
+                )
                 return blob.download_as_bytes()
 
         except GoogleCloudError:
@@ -555,6 +585,9 @@ def _call_with_retry(f: Callable[[], None]) -> None:
             _call_with_retry(
                 partial(blob.upload_from_filename, filename=filename, content_type=mime_type, timeout=timeout)
             )
+            get_hook_lineage_collector().add_input_dataset(
+                context=self, scheme="file", dataset_kwargs={"path": filename}
+            )
 
             if gzip:
                 os.remove(filename)
@@ -576,6 +609,10 @@ def _call_with_retry(f: Callable[[], None]) -> None:
         else:
             raise ValueError("'filename' and 'data' parameter missing. One is required to upload to gcs.")
 
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="gs", dataset_kwargs={"bucket": bucket.name, "key": blob.name}
+        )
+
     def exists(self, bucket_name: str, object_name: str, retry: Retry = DEFAULT_RETRY) -> bool:
         """
         Check for the existence of a file in Google Cloud Storage.
@@ -691,6 +728,9 @@ def delete(self, bucket_name: str, object_name: str) -> None:
         bucket = client.bucket(bucket_name)
         blob = bucket.blob(blob_name=object_name)
         blob.delete()
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="gs", dataset_kwargs={"bucket": bucket.name, "key": blob.name}
+        )
 
         self.log.info("Blob %s deleted.", object_name)
 
@@ -1198,9 +1238,17 @@ def compose(self, bucket_name: str, source_objects: List[str], destination_objec
         client = self.get_conn()
         bucket = client.bucket(bucket_name)
         destination_blob = bucket.blob(destination_object)
-        destination_blob.compose(
-            sources=[bucket.blob(blob_name=source_object) for source_object in source_objects]
+        source_blobs = [bucket.blob(blob_name=source_object) for source_object in source_objects]
+        destination_blob.compose(sources=source_blobs)
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="gs", dataset_kwargs={"bucket": bucket.name, "key": destination_blob.name}
         )
+        for single_source_blob in source_blobs:
+            get_hook_lineage_collector().add_input_dataset(
+                context=self,
+                scheme="gs",
+                dataset_kwargs={"bucket": bucket.name, "key": single_source_blob.name},
+            )
 
         self.log.info("Completed successfully.")
diff --git a/airflow/providers/google/datasets/gcs.py b/airflow/providers/google/datasets/gcs.py
new file mode 100644
index 00000000000000..fc23532279f84c
--- /dev/null
+++ b/airflow/providers/google/datasets/gcs.py
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.datasets import Dataset
+from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url
+
+if TYPE_CHECKING:
+    from urllib.parse import SplitResult
+
+    from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset
+
+
+def create_dataset(*, bucket: str, key: str, extra: dict | None = None) -> Dataset:
+    return Dataset(uri=f"gs://{bucket}/{key}", extra=extra)
+
+
+def sanitize_uri(uri: SplitResult) -> SplitResult:
+    if not uri.netloc:
+        raise ValueError("URI format gs:// must contain a bucket name")
+    return uri
+
+
+def convert_dataset_to_openlineage(dataset: Dataset, lineage_context) -> OpenLineageDataset:
+    """Translate Dataset with valid AIP-60 uri to OpenLineage with assistance from the hook."""
+    from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset
+
+    bucket, key = _parse_gcs_url(dataset.uri)
+    return OpenLineageDataset(namespace=f"gs://{bucket}", name=key if key else "/")
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 612fd8e29bac39..99b66efe58dd63 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -773,6 +773,10 @@ dataset-uris:
     handler: null
   - schemes: [bigquery]
     handler: airflow.providers.google.datasets.bigquery.sanitize_uri
+  - schemes: [gs]
+    handler: airflow.providers.google.datasets.gcs.sanitize_uri
+    factory: airflow.providers.google.datasets.gcs.create_dataset
+    to_openlineage_converter: airflow.providers.google.datasets.gcs.convert_dataset_to_openlineage
 
 hooks:
   - integration-name: Google Ads
diff --git a/tests/providers/google/cloud/hooks/test_gcs.py b/tests/providers/google/cloud/hooks/test_gcs.py
index 5d2735834a9586..199cff9a1e863e 100644
--- a/tests/providers/google/cloud/hooks/test_gcs.py
+++ b/tests/providers/google/cloud/hooks/test_gcs.py
@@ -35,6 +35,7 @@
 from google.cloud import exceptions, storage  # type: ignore[attr-defined]
 from google.cloud.storage.retry import DEFAULT_RETRY
 
+from airflow.datasets import Dataset
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks import gcs
 from airflow.providers.google.cloud.hooks.gcs import _fallback_object_url_to_object_name_and_bucket_name
@@ -42,6 +43,7 @@
 from airflow.utils import timezone
 from airflow.version import version
 from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id
+from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS
 
 BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
 GCS_STRING = "airflow.providers.google.cloud.hooks.gcs.{}"
@@ -412,6 +414,41 @@ def test_copy_empty_source_object(self):
         assert str(ctx.value) == "source_bucket and source_object cannot be empty."
 
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+    @mock.patch("google.cloud.storage.Bucket.copy_blob")
+    @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
+    def test_copy_exposes_lineage(self, mock_service, mock_copy, hook_lineage_collector):
+        source_bucket_name = "test-source-bucket"
+        source_object_name = "test-source-object"
+        destination_bucket_name = "test-dest-bucket"
+        destination_object_name = "test-dest-object"
+
+        source_bucket = storage.Bucket(mock_service, source_bucket_name)
+        mock_copy.return_value = storage.Blob(
+            name=destination_object_name, bucket=storage.Bucket(mock_service, destination_bucket_name)
+        )
+        mock_service.return_value.bucket.side_effect = (
+            lambda name: source_bucket
+            if name == source_bucket_name
+            else storage.Bucket(mock_service, destination_bucket_name)
+        )
+
+        self.gcs_hook.copy(
+            source_bucket=source_bucket_name,
+            source_object=source_object_name,
+            destination_bucket=destination_bucket_name,
+            destination_object=destination_object_name,
+        )
+
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.inputs[0].dataset == Dataset(
+            uri=f"gs://{source_bucket_name}/{source_object_name}"
+        )
+        assert hook_lineage_collector.collected_datasets.outputs[0].dataset == Dataset(
+            uri=f"gs://{destination_bucket_name}/{destination_object_name}"
+        )
+
     @mock.patch("google.cloud.storage.Bucket")
     @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
     def test_rewrite(self, mock_service, mock_bucket):
@@ -473,6 +510,40 @@ def test_rewrite_empty_source_object(self):
         assert str(ctx.value) == "source_bucket and source_object cannot be empty."
 
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+    @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
+    def test_rewrite_exposes_lineage(self, mock_service, hook_lineage_collector):
+        source_bucket_name = "test-source-bucket"
+        source_object_name = "test-source-object"
+        destination_bucket_name = "test-dest-bucket"
+        destination_object_name = "test-dest-object"
+
+        dest_bucket = storage.Bucket(mock_service, destination_bucket_name)
+        blob = MagicMock(spec=storage.Blob)
+        blob.rewrite = MagicMock(return_value=(None, None, None))
+        dest_bucket.blob = MagicMock(return_value=blob)
+        mock_service.return_value.bucket.side_effect = (
+            lambda name: storage.Bucket(mock_service, source_bucket_name)
+            if name == source_bucket_name
+            else dest_bucket
+        )
+
+        self.gcs_hook.rewrite(
+            source_bucket=source_bucket_name,
+            source_object=source_object_name,
+            destination_bucket=destination_bucket_name,
+            destination_object=destination_object_name,
+        )
+
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.inputs[0].dataset == Dataset(
+            uri=f"gs://{source_bucket_name}/{source_object_name}"
+        )
+        assert hook_lineage_collector.collected_datasets.outputs[0].dataset == Dataset(
+            uri=f"gs://{destination_bucket_name}/{destination_object_name}"
+        )
+
     @mock.patch("google.cloud.storage.Bucket")
     @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
     def test_delete(self, mock_service, mock_bucket):
@@ -501,6 +572,22 @@ def test_delete_nonexisting_object(self, mock_service):
         with pytest.raises(exceptions.NotFound):
             self.gcs_hook.delete(bucket_name=test_bucket, object_name=test_object)
 
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+    @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
+    def test_delete_exposes_lineage(self, mock_service, hook_lineage_collector):
+        test_bucket = "test_bucket"
+        test_object = "test_object"
+
+        mock_service.return_value.bucket.return_value = storage.Bucket(mock_service, test_bucket)
+
+        self.gcs_hook.delete(bucket_name=test_bucket, object_name=test_object)
+
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 0
+        assert hook_lineage_collector.collected_datasets.inputs[0].dataset == Dataset(
+            uri=f"gs://{test_bucket}/{test_object}"
+        )
+
     @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
     def test_delete_bucket(self, mock_service):
         test_bucket = "test bucket"
@@ -728,6 +815,33 @@ def test_compose_without_destination_object(self, mock_service):
         assert str(ctx.value) == "bucket_name and destination_object cannot be empty."
 
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+    @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
+    def test_compose_exposes_lineage(self, mock_service, hook_lineage_collector):
+        test_bucket = "test_bucket"
+        source_object_names = ["test-source-object1", "test-source-object2"]
+        destination_object_name = "test-dest-object"
+
+        mock_service.return_value.bucket.return_value = storage.Bucket(mock_service, test_bucket)
+
+        self.gcs_hook.compose(
+            bucket_name=test_bucket,
+            source_objects=source_object_names,
+            destination_object=destination_object_name,
+        )
+
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 2
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.inputs[0].dataset == Dataset(
+            uri=f"gs://{test_bucket}/{source_object_names[0]}"
+        )
+        assert hook_lineage_collector.collected_datasets.inputs[1].dataset == Dataset(
+            uri=f"gs://{test_bucket}/{source_object_names[1]}"
+        )
+        assert hook_lineage_collector.collected_datasets.outputs[0].dataset == Dataset(
+            uri=f"gs://{test_bucket}/{destination_object_name}"
+        )
+
     @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
     def test_download_as_bytes(self, mock_service):
         test_bucket = "test_bucket"
@@ -742,6 +856,23 @@ def test_download_as_bytes(self, mock_service):
         assert response == test_object_bytes
         download_method.assert_called_once_with()
 
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+    @mock.patch("google.cloud.storage.Blob.download_as_bytes")
+    @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
+    def test_download_as_bytes_exposes_lineage(self, mock_service, mock_download, hook_lineage_collector):
+        source_bucket_name = "test-source-bucket"
+        source_object_name = "test-source-object"
+
+        mock_service.return_value.bucket.return_value = storage.Bucket(mock_service, source_bucket_name)
+
+        self.gcs_hook.download(bucket_name=source_bucket_name, object_name=source_object_name, filename=None)
+
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 0
+        assert hook_lineage_collector.collected_datasets.inputs[0].dataset == Dataset(
+            uri=f"gs://{source_bucket_name}/{source_object_name}"
+        )
+
     @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
     def test_download_to_file(self, mock_service):
         test_bucket = "test_bucket"
@@ -765,6 +896,29 @@ def test_download_to_file(self, mock_service):
         assert response == test_file
         download_filename_method.assert_called_once_with(test_file, timeout=60)
 
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+    @mock.patch("google.cloud.storage.Blob.download_to_filename")
+    @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
+    def test_download_to_file_exposes_lineage(self, mock_service, mock_download, hook_lineage_collector):
+        source_bucket_name = "test-source-bucket"
+        source_object_name = "test-source-object"
+        file_name = "test.txt"
+
+        mock_service.return_value.bucket.return_value = storage.Bucket(mock_service, source_bucket_name)
+
+        self.gcs_hook.download(
+            bucket_name=source_bucket_name, object_name=source_object_name, filename=file_name
+        )
+
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.inputs[0].dataset == Dataset(
+            uri=f"gs://{source_bucket_name}/{source_object_name}"
+        )
+        assert hook_lineage_collector.collected_datasets.outputs[0].dataset == Dataset(
+            uri=f"file://{file_name}"
+        )
+
     @mock.patch(GCS_STRING.format("NamedTemporaryFile"))
     @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
     def test_provide_file(self, mock_service, mock_temp_file):
@@ -998,6 +1152,29 @@ def test_upload_file(self, mock_service, testdata_file):
         assert metadata == blob_object.return_value.metadata
 
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+    @mock.patch("google.cloud.storage.Blob.upload_from_filename")
+    @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
+    def test_upload_file_exposes_lineage(self, mock_service, mock_upload, hook_lineage_collector):
+        source_bucket_name = "test-source-bucket"
+        source_object_name = "test-source-object"
+        file_name = "test.txt"
+
+        mock_service.return_value.bucket.return_value = storage.Bucket(mock_service, source_bucket_name)
+
+        self.gcs_hook.upload(
+            bucket_name=source_bucket_name, object_name=source_object_name, filename=file_name
+        )
+
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.outputs[0].dataset == Dataset(
+            uri=f"gs://{source_bucket_name}/{source_object_name}"
+        )
+        assert hook_lineage_collector.collected_datasets.inputs[0].dataset == Dataset(
+            uri=f"file://{file_name}"
+        )
+
     @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
     def test_upload_cache_control(self, mock_service, testdata_file):
         test_bucket = "test_bucket"
@@ -1077,6 +1254,23 @@ def test_upload_data_bytes_gzip(self, mock_service, mock_gzip, mock_bytes_io, te
         gzip_ctx.write.assert_called_once_with(testdata_bytes)
         upload_method.assert_called_once_with(data, content_type="text/plain", timeout=60)
 
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+    @mock.patch("google.cloud.storage.Blob.upload_from_string")
+    @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
+    def test_upload_data_exposes_lineage(self, mock_service, mock_upload, hook_lineage_collector):
+        source_bucket_name = "test-source-bucket"
+        source_object_name = "test-source-object"
+
+        mock_service.return_value.bucket.return_value = storage.Bucket(mock_service, source_bucket_name)
+
+        self.gcs_hook.upload(bucket_name=source_bucket_name, object_name=source_object_name, data="test")
+
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 0
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.outputs[0].dataset == Dataset(
+            uri=f"gs://{source_bucket_name}/{source_object_name}"
+        )
+
     @mock.patch(GCS_STRING.format("GCSHook.get_conn"))
     def test_upload_exceptions(self, mock_service, testdata_file, testdata_string):
         test_bucket = "test_bucket"
diff --git a/tests/providers/google/datasets/test_gcs.py b/tests/providers/google/datasets/test_gcs.py
new file mode 100644
index 00000000000000..d9893bd9cfffd6
--- /dev/null
+++ b/tests/providers/google/datasets/test_gcs.py
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import urllib.parse
+
+import pytest
+
+from airflow.datasets import Dataset
+from airflow.providers.google.datasets.gcs import convert_dataset_to_openlineage, create_dataset, sanitize_uri
+
+
+def test_sanitize_uri():
+    uri = sanitize_uri(urllib.parse.urlsplit("gs://bucket/dir/file.txt"))
+    result = sanitize_uri(uri)
+    assert result.scheme == "gs"
+    assert result.netloc == "bucket"
+    assert result.path == "/dir/file.txt"
+
+
+def test_sanitize_uri_no_netloc():
+    with pytest.raises(ValueError):
+        sanitize_uri(urllib.parse.urlsplit("gs://"))
+
+
+def test_sanitize_uri_no_path():
+    uri = sanitize_uri(urllib.parse.urlsplit("gs://bucket"))
+    result = sanitize_uri(uri)
+    assert result.scheme == "gs"
+    assert result.netloc == "bucket"
+    assert result.path == ""
+
+
+def test_create_dataset():
+    assert create_dataset(bucket="test-bucket", key="test-path") == Dataset(uri="gs://test-bucket/test-path")
+    assert create_dataset(bucket="test-bucket", key="test-dir/test-path") == Dataset(
+        uri="gs://test-bucket/test-dir/test-path"
+    )
+
+
+def test_sanitize_uri_trailing_slash():
+    uri = sanitize_uri(urllib.parse.urlsplit("gs://bucket/"))
+    result = sanitize_uri(uri)
+    assert result.scheme == "gs"
+    assert result.netloc == "bucket"
+    assert result.path == "/"
+
+
+def test_convert_dataset_to_openlineage_valid():
+    uri = "gs://bucket/dir/file.txt"
+    ol_dataset = convert_dataset_to_openlineage(dataset=Dataset(uri=uri), lineage_context=None)
+    assert ol_dataset.namespace == "gs://bucket"
+    assert ol_dataset.name == "dir/file.txt"
+
+
+@pytest.mark.parametrize("uri", ("gs://bucket", "gs://bucket/"))
+def test_convert_dataset_to_openlineage_no_path(uri):
+    ol_dataset = convert_dataset_to_openlineage(dataset=Dataset(uri=uri), lineage_context=None)
+    assert ol_dataset.namespace == "gs://bucket"
+    assert ol_dataset.name == "/"