fix OpenLineage extraction for GCP deferrable operators (#40521)
Signed-off-by: Kacper Muda <[email protected]>
kacpermuda authored Jul 1, 2024
1 parent acdac24 commit 47e7e25
Showing 4 changed files with 76 additions and 15 deletions.
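
In deferrable mode these operators submit a BigQuery job in execute(), defer to a trigger, and finish in execute_complete() — which runs on a freshly rebuilt operator instance, so attributes assigned in execute() (the old private _job_id, the lazily created hook) are gone by the time the OpenLineage listener calls get_openlineage_facets_on_complete(). The diff below therefore promotes the job id to the public job_id attribute, restores it from the trigger event, and recreates the hook on demand. A minimal, self-contained sketch of the failure mode (a hypothetical toy class, not Airflow code):

# Hypothetical sketch of why state set in execute() does not survive deferral:
# execute_complete() runs on a fresh operator instance, so anything listeners
# need afterwards must be restored from the trigger event.

class ToyDeferrableOperator:
    def __init__(self, job_id=None):
        self.job_id = job_id  # constructor arg: survives re-instantiation
        self.hook = None      # assigned lazily in execute(): does NOT survive

    def execute(self, context):
        self.job_id = "job-123"  # lives only on this pre-deferral instance

    def execute_complete(self, context, event):
        # The fix: restore job_id from the event so listeners that run
        # afterwards (e.g. OpenLineage on-complete extraction) can see it.
        self.job_id = event.get("job_id")


pre = ToyDeferrableOperator()
pre.execute({})                                   # job id exists here...

post = ToyDeferrableOperator()                    # ...but the operator is rebuilt
assert post.job_id is None                        # after the trigger fires
post.execute_complete({}, {"job_id": "job-123"})
assert post.job_id == "job-123"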
airflow/providers/google/cloud/transfers/bigquery_to_gcs.py (19 additions, 11 deletions)
@@ -142,8 +142,6 @@ def __init__(
         self.hook: BigQueryHook | None = None
         self.deferrable = deferrable
 
-        self._job_id: str = ""
-
     @staticmethod
     def _handle_job_error(job: BigQueryJob | UnknownJob) -> None:
         if job.error_result:
@@ -212,7 +210,7 @@ def execute(self, context: Context):
         self.hook = hook
 
         configuration = self._prepare_configuration()
-        job_id = hook.generate_job_id(
+        self.job_id = hook.generate_job_id(
             job_id=self.job_id,
             dag_id=self.dag_id,
             task_id=self.task_id,
@@ -224,14 +222,14 @@ def execute(self, context: Context):
         try:
             self.log.info("Executing: %s", configuration)
             job: BigQueryJob | UnknownJob = self._submit_job(
-                hook=hook, job_id=job_id, configuration=configuration
+                hook=hook, job_id=self.job_id, configuration=configuration
             )
         except Conflict:
             # If the job already exists retrieve it
             job = hook.get_job(
                 project_id=self.project_id,
                 location=self.location,
-                job_id=job_id,
+                job_id=self.job_id,
             )
             if job.state in self.reattach_states:
                 # We are reattaching to a job
@@ -240,12 +238,12 @@ def execute(self, context: Context):
             else:
                 # Same job configuration so we need force_rerun
                 raise AirflowException(
-                    f"Job with id: {job_id} already exists and is in {job.state} state. If you "
+                    f"Job with id: {self.job_id} already exists and is in {job.state} state. If you "
                     f"want to force rerun it consider setting `force_rerun=True`."
                    f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`"
                 )
 
-        self._job_id = job.job_id
+        self.job_id = job.job_id
         conf = job.to_api_repr()["configuration"]["extract"]["sourceTable"]
         dataset_id, project_id, table_id = conf["datasetId"], conf["projectId"], conf["tableId"]
         BigQueryTableLink.persist(
@@ -261,7 +259,7 @@ def execute(self, context: Context):
                 timeout=self.execution_timeout,
                 trigger=BigQueryInsertJobTrigger(
                     conn_id=self.gcp_conn_id,
-                    job_id=self._job_id,
+                    job_id=self.job_id,
                     project_id=self.project_id or self.hook.project_id,
                     location=self.location or self.hook.location,
                     impersonation_chain=self.impersonation_chain,
@@ -284,6 +282,8 @@ def execute_complete(self, context: Context, event: dict[str, Any]):
                 self.task_id,
                 event["message"],
             )
+        # Save job_id as an attribute to be later used by listeners
+        self.job_id = event.get("job_id")
 
     def get_openlineage_facets_on_complete(self, task_instance):
         """Implement on_complete as we will include final BQ job id."""
@@ -303,7 +303,15 @@ def get_openlineage_facets_on_complete(self, task_instance):
         )
         from airflow.providers.openlineage.extractors import OperatorLineage
 
-        table_object = self.hook.get_client(self.hook.project_id).get_table(self.source_project_dataset_table)
+        if not self.hook:
+            self.hook = BigQueryHook(
+                gcp_conn_id=self.gcp_conn_id,
+                location=self.location,
+                impersonation_chain=self.impersonation_chain,
+            )
+
+        project_id = self.project_id or self.hook.project_id
+        table_object = self.hook.get_client(project_id).get_table(self.source_project_dataset_table)
 
         input_dataset = Dataset(
             namespace="bigquery",
@@ -347,9 +355,9 @@ def get_openlineage_facets_on_complete(self, task_instance):
             output_datasets.append(dataset)
 
         run_facets = {}
-        if self._job_id:
+        if self.job_id:
             run_facets = {
-                "externalQuery": ExternalQueryRunFacet(externalQueryId=self._job_id, source="bigquery"),
+                "externalQuery": ExternalQueryRunFacet(externalQueryId=self.job_id, source="bigquery"),
             }
 
         return OperatorLineage(inputs=[input_dataset], outputs=output_datasets, run_facets=run_facets)
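
With job_id preserved across deferral, the externalQuery run facet above is populated even on the execute_complete() path. A hedged sketch of what lineage consumers then see — the import follows the openlineage-python client (the provider-side import path is elided in the hunk above), and the id value is a placeholder:

# Sketch: the run facet this operator now emits after deferral.
from openlineage.client.facet import ExternalQueryRunFacet

facet = ExternalQueryRunFacet(externalQueryId="airflow_1234_abc", source="bigquery")

# Downstream consumers can join the lineage run back to the BigQuery job:
assert facet.externalQueryId == "airflow_1234_abc"  # placeholder job id
assert facet.source == "bigquery"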
airflow/providers/google/cloud/transfers/gcs_to_bigquery.py (15 additions, 4 deletions)
@@ -461,6 +461,8 @@ def execute_complete(self, context: Context, event: dict[str, Any]):
                 self.task_id,
                 event["message"],
             )
+        # Save job_id as an attribute to be later used by listeners
+        self.job_id = event.get("job_id")
         return self._find_max_value_in_column()
 
     def _find_max_value_in_column(self):
@@ -757,17 +759,26 @@ def get_openlineage_facets_on_complete(self, task_instance):
         )
         from airflow.providers.openlineage.extractors import OperatorLineage
 
-        table_object = self.hook.get_client(self.hook.project_id).get_table(
-            self.destination_project_dataset_table
-        )
+        if not self.hook:
+            self.hook = BigQueryHook(
+                gcp_conn_id=self.gcp_conn_id,
+                location=self.location,
+                impersonation_chain=self.impersonation_chain,
+            )
+
+        project_id = self.project_id or self.hook.project_id
+        table_object = self.hook.get_client(project_id).get_table(self.destination_project_dataset_table)
 
         output_dataset_facets = get_facets_from_bq_table(table_object)
 
+        source_objects = (
+            self.source_objects if isinstance(self.source_objects, list) else [self.source_objects]
+        )
         input_dataset_facets = {
             "schema": output_dataset_facets["schema"],
         }
         input_datasets = []
-        for blob in sorted(self.source_objects):
+        for blob in sorted(source_objects):
             additional_facets = {}
 
             if "*" in blob:
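
Besides the same hook and job_id handling, this file also normalizes source_objects before sorting: a user may pass a single object name as a plain string, and sorted() over a string iterates characters rather than blob names. A quick illustration (the path is a placeholder):

# Why the source_objects normalization above matters: sorted() over a bare
# string yields its characters, producing one bogus "blob" per character.
source_objects = "data/file.csv"  # user passed a single object as a str

normalized = source_objects if isinstance(source_objects, list) else [source_objects]

assert sorted(normalized) == ["data/file.csv"]        # one blob, as intended
assert sorted(source_objects)[:3] == [".", "/", "a"]  # what the old loop saw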
tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py (20 additions)
@@ -187,6 +187,26 @@ def test_execute_deferrable_mode(self, mock_hook):
             nowait=True,
         )
 
+    def test_execute_complete_reassigns_job_id(self):
+        """Assert that we use job_id from event after deferral."""
+
+        operator = BigQueryToGCSOperator(
+            project_id=JOB_PROJECT_ID,
+            task_id=TASK_ID,
+            source_project_dataset_table=f"{PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}",
+            destination_cloud_storage_uris=[f"gs://{TEST_BUCKET}/{TEST_FOLDER}/"],
+            deferrable=True,
+            job_id=None,
+        )
+        job_id = "123456"
+
+        assert operator.job_id is None
+        operator.execute_complete(
+            context=MagicMock(),
+            event={"status": "success", "message": "Job completed", "job_id": job_id},
+        )
+        assert operator.job_id == job_id
+
     @pytest.mark.parametrize(
         ("gcs_uri", "expected_dataset_name"),
         (
tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py (22 additions)
@@ -1913,6 +1913,28 @@ def test_schema_fields_int_without_external_table_async_should_execute_successfully
 
         bq_hook.return_value.insert_job.assert_has_calls(calls)
 
+    @mock.patch(GCS_TO_BQ_PATH.format("BigQueryHook"))
+    def test_execute_complete_reassigns_job_id(self, bq_hook):
+        """Assert that we use job_id from event after deferral."""
+
+        operator = GCSToBigQueryOperator(
+            task_id=TASK_ID,
+            bucket=TEST_BUCKET,
+            source_objects=TEST_SOURCE_OBJECTS,
+            destination_project_dataset_table=TEST_EXPLICIT_DEST,
+            deferrable=True,
+            job_id=None,
+        )
+        generated_job_id = "123456"
+
+        assert operator.job_id is None
+
+        operator.execute_complete(
+            context=MagicMock(),
+            event={"status": "success", "message": "Job completed", "job_id": generated_job_id},
+        )
+        assert operator.job_id == generated_job_id
+
     def create_context(self, task):
         dag = DAG(dag_id="dag")
         logical_date = datetime(2022, 1, 1, 0, 0, 0)
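
For reference, a hedged usage sketch of the behavior under test — a deferrable transfer whose lineage events now carry the final BigQuery job id (project, dataset, and bucket names are placeholders):

# Usage sketch: with this fix, execute() defers, execute_complete() restores
# job_id from the trigger event, and the OpenLineage listener can then emit
# the externalQuery facet. All resource names below are placeholders.
from airflow.providers.google.cloud.transfers.bigquery_to_gcs import BigQueryToGCSOperator

export = BigQueryToGCSOperator(
    task_id="export_table",
    source_project_dataset_table="my-project.my_dataset.my_table",
    destination_cloud_storage_uris=["gs://my-bucket/exports/part-*.csv"],
    deferrable=True,
)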
