Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added OL support for AzureBlobStorageToGCSOperator in google provider package #40290

Merged
merged 8 commits into from
Jun 20, 2024
13 changes: 13 additions & 0 deletions airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,16 @@ def execute(self, context: Context) -> str:
self.bucket_name,
)
return f"gs://{self.bucket_name}/{self.object_name}"

def get_openlineage_facets_on_start(self):
    """Describe this transfer as OpenLineage input and output datasets.

    The input dataset uses the namespace ``wasbs://<container_name>@<account_name>``
    with the blob name as its name; the output dataset uses the namespace
    ``gs://<bucket_name>`` with the object name as its name.

    :return: an ``OperatorLineage`` holding exactly one input and one output dataset.
    """
    from openlineage.client.run import Dataset

    from airflow.providers.openlineage.extractors import OperatorLineage

    # The storage account name is not an operator argument; it is read from the
    # client object returned by the hook built from ``wasb_conn_id``.
    storage_account = WasbHook(wasb_conn_id=self.wasb_conn_id).get_conn().account_name

    source_blob = Dataset(
        namespace=f"wasbs://{self.container_name}@{storage_account}",
        name=self.blob_name,
    )
    destination_object = Dataset(
        namespace=f"gs://{self.bucket_name}",
        name=self.object_name,
    )

    return OperatorLineage(inputs=[source_blob], outputs=[destination_object])
29 changes: 29 additions & 0 deletions tests/providers/google/cloud/transfers/test_azure_blob_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,32 @@ def test_execute(self, mock_temp, mock_hook_gcs, mock_hook_wasb):
gzip=GZIP,
filename=mock_temp.NamedTemporaryFile.return_value.__enter__.return_value.name,
)

@mock.patch("airflow.providers.google.cloud.transfers.azure_blob_to_gcs.WasbHook")
def test_execute_single_file_transfer_openlineage(self, mock_hook_wasb):
    """Check the lineage reported for a single blob-to-object transfer.

    ``WasbHook`` is patched so that the connection it returns exposes a known
    ``account_name``; the test then verifies that exactly one input and one
    output dataset are reported with the expected namespaces and names.
    """
    from openlineage.client.run import Dataset

    fake_account_name = "mock_account_name"
    mock_hook_wasb.return_value.get_conn.return_value.account_name = fake_account_name

    operator_kwargs = {
        "wasb_conn_id": WASB_CONN_ID,
        "gcp_conn_id": GCP_CONN_ID,
        "blob_name": BLOB_NAME,
        "container_name": CONTAINER_NAME,
        "bucket_name": BUCKET_NAME,
        "object_name": OBJECT_NAME,
        "filename": FILENAME,
        "gzip": GZIP,
        "impersonation_chain": IMPERSONATION_CHAIN,
        "task_id": TASK_ID,
    }
    transfer = AzureBlobStorageToGCSOperator(**operator_kwargs)

    facets = transfer.get_openlineage_facets_on_start()

    # One dataset on each side of the transfer, nothing more.
    assert len(facets.inputs) == 1
    assert len(facets.outputs) == 1

    expected_source = Dataset(
        namespace=f"wasbs://{CONTAINER_NAME}@{fake_account_name}",
        name=BLOB_NAME,
    )
    expected_destination = Dataset(namespace=f"gs://{BUCKET_NAME}", name=OBJECT_NAME)

    assert facets.inputs[0] == expected_source
    assert facets.outputs[0] == expected_destination
30 changes: 30 additions & 0 deletions tests/providers/microsoft/azure/hooks/test_wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,3 +615,33 @@ def test_connection_failure(self, mocked_blob_service_client):
status, msg = hook.test_connection()
assert status is False
assert msg == "Authentication failed."

@pytest.mark.parametrize(
    "conn_id_str",
    [
        "wasb_test_key",
        "pub_read_id",
        "pub_read_id_without_host",
        "azure_test_connection_string",
        "azure_shared_key_test",
        "ad_conn_id",
        "managed_identity_conn_id",
        "sas_conn_id",
        "extra__wasb__sas_conn_id",
        "http_sas_conn_id",
        "extra__wasb__http_sas_conn_id",
    ],
)
def test_extract_account_name_from_connection(self, conn_id_str, mocked_blob_service_client):
    """Check that ``get_conn().account_name`` is readable for every listed connection id.

    For the ``azure_test_connection_string`` connection id the expected name is set
    on the object returned by ``from_connection_string()``; for every other id it is
    set on the mocked client's ``return_value``.
    """
    expected_account_name = "testname"

    # Pick the mock object that should carry the account name, then configure it.
    if conn_id_str == "azure_test_connection_string":
        client_mock = mocked_blob_service_client.from_connection_string()
    else:
        client_mock = mocked_blob_service_client.return_value
    client_mock.account_name = expected_account_name

    hook = WasbHook(wasb_conn_id=conn_id_str)
    account_name = hook.get_conn().account_name

    assert (
        account_name == expected_account_name
    ), f"Expected account name {expected_account_name} but got {account_name}"