diff --git a/airflow/providers/google/cloud/example_dags/example_datacatalog.py b/airflow/providers/google/cloud/example_dags/example_datacatalog.py index 08f8da278a90a..c8597a6e4f0ab 100644 --- a/airflow/providers/google/cloud/example_dags/example_datacatalog.py +++ b/airflow/providers/google/cloud/example_dags/example_datacatalog.py @@ -438,6 +438,7 @@ lookup_entry_linked_resource >> lookup_entry_result # Rename + update_tag >> rename_tag_template_field create_tag_template_field >> rename_tag_template_field >> delete_tag_template_field # Search diff --git a/airflow/providers/google/cloud/hooks/datacatalog.py b/airflow/providers/google/cloud/hooks/datacatalog.py index 3bdfa50c521f4..9c689c3ca9d9e 100644 --- a/airflow/providers/google/cloud/hooks/datacatalog.py +++ b/airflow/providers/google/cloud/hooks/datacatalog.py @@ -818,7 +818,7 @@ def lookup_entry( result = client.lookup_entry( sql_resource=sql_resource, retry=retry, timeout=timeout, metadata=metadata ) - self.log.info('Received entry. name=%s.', result.name) + self.log.info('Received entry. name=%s', result.name) return result @@ -1018,10 +1018,12 @@ def update_entry( "You must provide all the parameters (project_id, location, entry_group, entry_id) " "contained in the name, or do not specify any parameters and pass the name on the object " ) - name = entry.name if isinstance(entry, Entry) else entry["name"] self.log.info("Updating entry: name=%s", name) + # HACK: google-cloud-datacatalog has a problem with dictionaries for update methods. + if isinstance(entry, dict): + entry = Entry(**entry) result = client.update_entry( entry=entry, update_mask=update_mask, retry=retry, timeout=timeout, metadata=metadata ) @@ -1096,6 +1098,9 @@ def update_tag( # pylint: disable=too-many-arguments name = tag.name if isinstance(tag, Tag) else tag["name"] self.log.info("Updating tag: name=%s", name) + # HACK: google-cloud-datacatalog has a problem with dictionaries for update methods. 
+ if isinstance(tag, dict): + tag = Tag(**tag) result = client.update_tag( tag=tag, update_mask=update_mask, retry=retry, timeout=timeout, metadata=metadata ) @@ -1170,6 +1175,9 @@ def update_tag_template( name = tag_template.name if isinstance(tag_template, TagTemplate) else tag_template["name"] self.log.info("Updating tag template: name=%s", name) + # HACK: google-cloud-datacatalog has a problem with dictionaries for update methods. + if isinstance(tag_template, dict): + tag_template = TagTemplate(**tag_template) result = client.update_tag_template( tag_template=tag_template, update_mask=update_mask, diff --git a/tests/providers/google/cloud/hooks/test_datacatalog.py b/tests/providers/google/cloud/hooks/test_datacatalog.py index 7acbbc82baeb8..cc1a60a211b05 100644 --- a/tests/providers/google/cloud/hooks/test_datacatalog.py +++ b/tests/providers/google/cloud/hooks/test_datacatalog.py @@ -21,7 +21,7 @@ from unittest import TestCase, mock from google.api_core.retry import Retry -from google.cloud.datacatalog_v1beta1.types import Tag +from google.cloud.datacatalog_v1beta1.types import Entry, Tag, TagTemplate from airflow import AirflowException from airflow.providers.google.cloud.hooks.datacatalog import CloudDataCatalogHook @@ -571,7 +571,7 @@ def test_update_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> Non metadata=TEST_METADATA, ) mock_get_conn.return_value.update_entry.assert_called_once_with( - entry=TEST_ENTRY, + entry=Entry(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1)), update_mask=TEST_UPDATE_MASK, retry=TEST_RETRY, timeout=TEST_TIMEOUT, @@ -596,7 +596,7 @@ def test_update_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: metadata=TEST_METADATA, ) mock_get_conn.return_value.update_tag.assert_called_once_with( - tag={"name": TEST_TAG_PATH.format(TEST_PROJECT_ID_1)}, + tag=Tag(name=TEST_TAG_PATH.format(TEST_PROJECT_ID_1)), update_mask=TEST_UPDATE_MASK, retry=TEST_RETRY, timeout=TEST_TIMEOUT, @@ -619,7 +619,7 @@ def 
test_update_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) metadata=TEST_METADATA, ) mock_get_conn.return_value.update_tag_template.assert_called_once_with( - tag_template=TEST_TAG_TEMPLATE, + tag_template=TagTemplate(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)), update_mask=TEST_UPDATE_MASK, retry=TEST_RETRY, timeout=TEST_TIMEOUT, @@ -1084,7 +1084,7 @@ def test_update_entry(self, mock_get_conn, mock_get_creds_and_project_id) -> Non metadata=TEST_METADATA, ) mock_get_conn.return_value.update_entry.assert_called_once_with( - entry=TEST_ENTRY, + entry=Entry(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2)), update_mask=TEST_UPDATE_MASK, retry=TEST_RETRY, timeout=TEST_TIMEOUT, @@ -1110,7 +1110,7 @@ def test_update_tag(self, mock_get_conn, mock_get_creds_and_project_id) -> None: metadata=TEST_METADATA, ) mock_get_conn.return_value.update_tag.assert_called_once_with( - tag={"name": TEST_TAG_PATH.format(TEST_PROJECT_ID_2)}, + tag=Tag(name=TEST_TAG_PATH.format(TEST_PROJECT_ID_2)), update_mask=TEST_UPDATE_MASK, retry=TEST_RETRY, timeout=TEST_TIMEOUT, @@ -1134,7 +1134,7 @@ def test_update_tag_template(self, mock_get_conn, mock_get_creds_and_project_id) metadata=TEST_METADATA, ) mock_get_conn.return_value.update_tag_template.assert_called_once_with( - tag_template=TEST_TAG_TEMPLATE, + tag_template=TagTemplate(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)), update_mask=TEST_UPDATE_MASK, retry=TEST_RETRY, timeout=TEST_TIMEOUT, diff --git a/tests/providers/google/cloud/operators/test_datacatalog_system.py b/tests/providers/google/cloud/operators/test_datacatalog_system.py new file mode 100644 index 0000000000000..724faf8dcc896 --- /dev/null +++ b/tests/providers/google/cloud/operators/test_datacatalog_system.py @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+from tests.providers.google.cloud.utils.gcp_authenticator import GCP_DATACATALOG_KEY
+from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context
+
+
+@pytest.mark.credential_file(GCP_DATACATALOG_KEY)
+class CloudDataflowExampleDagsSystemTest(GoogleSystemTest):  # NOTE(review): name says Dataflow; confirm
+    @provide_gcp_context(GCP_DATACATALOG_KEY)
+    def test_run_example_gcp_dataflow_native_java(self):  # NOTE(review): name mentions Dataflow/Java
+        self.run_dag('example_gcp_datacatalog', CLOUD_DAG_FOLDER)  # presumably (dag_id, dag_folder)
diff --git a/tests/providers/google/cloud/utils/gcp_authenticator.py b/tests/providers/google/cloud/utils/gcp_authenticator.py
index e893d8a81321a..945dd3353bf4f 100644
--- a/tests/providers/google/cloud/utils/gcp_authenticator.py
+++ b/tests/providers/google/cloud/utils/gcp_authenticator.py
@@ -36,6 +36,7 @@
 GCP_CLOUDSQL_KEY = 'gcp_cloudsql.json'
 GCP_COMPUTE_KEY = 'gcp_compute.json'
 GCP_COMPUTE_SSH_KEY = 'gcp_compute_ssh.json'
+GCP_DATACATALOG_KEY = 'gcp_datacatalog.json'
 GCP_DATAFLOW_KEY = 'gcp_dataflow.json'
 GCP_DATAFUSION_KEY = 'gcp_datafusion.json'
 GCP_DATAPROC_KEY = 'gcp_dataproc.json'