diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index bfaef219adf45..2ac6c2645e20c 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -2245,7 +2245,7 @@ def run_query(
         self.running_job_id = job.job_id
         return job.job_id

-    def generate_job_id(self, job_id, dag_id, task_id, logical_date, configuration, force_rerun=False):
+    def generate_job_id(self, job_id, dag_id, task_id, logical_date, configuration, force_rerun=False) -> str:
         if force_rerun:
             hash_base = str(uuid.uuid4())
         else:
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index f5e5a9634f72d..ca6f2900043d5 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -133,6 +133,68 @@ def get_db_hook(self: BigQueryCheckOperator) -> BigQueryHook:  # type:ignore[mis
         )


+class _BigQueryOpenLineageMixin:
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """
+        Retrieve OpenLineage data for a COMPLETE BigQuery job.
+
+        This method retrieves statistics for the specified job_ids using the BigQueryDatasetsProvider.
+        It calls the BigQuery API, retrieving input and output dataset info as well as run-level
+        usage statistics.
+
+        Run facets should contain:
+            - ExternalQueryRunFacet
+            - BigQueryJobRunFacet
+
+        Job facets should contain:
+            - SqlJobFacet if the operator has self.sql
+
+        Input datasets should contain facets:
+            - DataSourceDatasetFacet
+            - SchemaDatasetFacet
+
+        Output datasets should contain facets:
+            - DataSourceDatasetFacet
+            - SchemaDatasetFacet
+            - OutputStatisticsOutputDatasetFacet
+        """
+        from openlineage.client.facet import SqlJobFacet
+        from openlineage.common.provider.bigquery import BigQueryDatasetsProvider
+
+        from airflow.providers.openlineage.extractors import OperatorLineage
+        from airflow.providers.openlineage.utils.utils import normalize_sql
+
+        if not self.job_id:
+            return OperatorLineage()
+
+        client = self.hook.get_client(project_id=self.hook.project_id)
+        job_ids = self.job_id
+        if isinstance(self.job_id, str):
+            job_ids = [self.job_id]
+        inputs, outputs, run_facets = {}, {}, {}
+        for job_id in job_ids:
+            stats = BigQueryDatasetsProvider(client=client).get_facets(job_id=job_id)
+            for input in stats.inputs:
+                input = input.to_openlineage_dataset()
+                inputs[input.name] = input
+            if stats.output:
+                output = stats.output.to_openlineage_dataset()
+                outputs[output.name] = output
+            for key, value in stats.run_facets.items():
+                run_facets[key] = value
+
+        job_facets = {}
+        if hasattr(self, "sql"):
+            job_facets["sql"] = SqlJobFacet(query=normalize_sql(self.sql))
+
+        return OperatorLineage(
+            inputs=list(inputs.values()),
+            outputs=list(outputs.values()),
+            run_facets=run_facets,
+            job_facets=job_facets,
+        )
+
+
 class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
     """Performs checks against BigQuery.
@@ -1153,6 +1215,7 @@ def __init__(
         self.encryption_configuration = encryption_configuration
         self.hook: BigQueryHook | None = None
         self.impersonation_chain = impersonation_chain
+        self.job_id: str | list[str] | None = None

     def execute(self, context: Context):
         if self.hook is None:
@@ -1164,7 +1227,7 @@ def execute(self, context: Context):
             impersonation_chain=self.impersonation_chain,
         )
         if isinstance(self.sql, str):
-            job_id: str | list[str] = self.hook.run_query(
+            self.job_id = self.hook.run_query(
                 sql=self.sql,
                 destination_dataset_table=self.destination_dataset_table,
                 write_disposition=self.write_disposition,
@@ -1184,7 +1247,7 @@ def execute(self, context: Context):
                 encryption_configuration=self.encryption_configuration,
             )
         elif isinstance(self.sql, Iterable):
-            job_id = [
+            self.job_id = [
                 self.hook.run_query(
                     sql=s,
                     destination_dataset_table=self.destination_dataset_table,
@@ -1210,9 +1273,9 @@ def execute(self, context: Context):
             raise AirflowException(f"argument 'sql' of type {type(str)} is neither a string nor an iterable")
         project_id = self.hook.project_id
         if project_id:
-            job_id_path = convert_job_id(job_id=job_id, project_id=project_id, location=self.location)
+            job_id_path = convert_job_id(job_id=self.job_id, project_id=project_id, location=self.location)
             context["task_instance"].xcom_push(key="job_id_path", value=job_id_path)
-        return job_id
+        return self.job_id

     def on_kill(self) -> None:
         super().on_kill()
@@ -2562,7 +2625,7 @@ def execute(self, context: Context):
         return table


-class BigQueryInsertJobOperator(GoogleCloudBaseOperator):
+class BigQueryInsertJobOperator(GoogleCloudBaseOperator, _BigQueryOpenLineageMixin):
     """Execute a BigQuery job.

     Waits for the job to complete and returns job id.
@@ -2663,6 +2726,13 @@ def __init__(
         self.deferrable = deferrable
         self.poll_interval = poll_interval

+    @property
+    def sql(self) -> str | None:
+        try:
+            return self.configuration["query"]["query"]
+        except KeyError:
+            return None
+
     def prepare_template(self) -> None:
         # If .json is passed then we have to read the file
         if isinstance(self.configuration, str) and self.configuration.endswith(".json"):
@@ -2697,7 +2767,7 @@ def execute(self, context: Any):
         )
         self.hook = hook

-        job_id = hook.generate_job_id(
+        self.job_id = hook.generate_job_id(
             job_id=self.job_id,
             dag_id=self.dag_id,
             task_id=self.task_id,
@@ -2708,13 +2778,13 @@

         try:
             self.log.info("Executing: %s'", self.configuration)
-            job: BigQueryJob | UnknownJob = self._submit_job(hook, job_id)
+            job: BigQueryJob | UnknownJob = self._submit_job(hook, self.job_id)
         except Conflict:
             # If the job already exists retrieve it
             job = hook.get_job(
                 project_id=self.project_id,
                 location=self.location,
-                job_id=job_id,
+                job_id=self.job_id,
             )
             if job.state in self.reattach_states:
                 # We are reattaching to a job
@@ -2723,7 +2793,7 @@
             else:
                 # Same job configuration so we need force_rerun
                 raise AirflowException(
-                    f"Job with id: {job_id} already exists and is in {job.state} state. If you "
+                    f"Job with id: {self.job_id} already exists and is in {job.state} state. If you "
                     f"want to force rerun it consider setting `force_rerun=True`."
f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`" ) @@ -2757,7 +2827,9 @@ def execute(self, context: Any): self.job_id = job.job_id project_id = self.project_id or self.hook.project_id if project_id: - job_id_path = convert_job_id(job_id=job_id, project_id=project_id, location=self.location) + job_id_path = convert_job_id( + job_id=self.job_id, project_id=project_id, location=self.location # type: ignore[arg-type] + ) context["ti"].xcom_push(key="job_id_path", value=job_id_path) # Wait for the job to complete if not self.deferrable: diff --git a/airflow/providers/openlineage/extractors/base.py b/airflow/providers/openlineage/extractors/base.py index 95d8fa6f2821e..0926489c0d92f 100644 --- a/airflow/providers/openlineage/extractors/base.py +++ b/airflow/providers/openlineage/extractors/base.py @@ -86,6 +86,12 @@ def extract(self) -> OperatorLineage | None: # OpenLineage methods are optional - if there's no method, return None try: return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start) # type: ignore + except ImportError: + self.log.error( + "OpenLineage provider method failed to import OpenLineage integration. " + "This should not happen. Please report this bug to developers." + ) + return None except AttributeError: return None diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py index ca8b559e3a499..20b9afef49294 100644 --- a/airflow/providers/openlineage/utils/utils.py +++ b/airflow/providers/openlineage/utils/utils.py @@ -23,7 +23,7 @@ import os from contextlib import suppress from functools import wraps -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Iterable from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse import attrs @@ -414,3 +414,10 @@ def is_source_enabled() -> bool: def get_filtered_unknown_operator_keys(operator: BaseOperator) -> dict: not_required_keys = {"dag", "task_group"} return {attr: value for attr, value in operator.__dict__.items() if attr not in not_required_keys} + + +def normalize_sql(sql: str | Iterable[str]): + if isinstance(sql, str): + sql = [stmt for stmt in sql.split(";") if stmt != ""] + sql = [obj for stmt in sql for obj in stmt.split(";") if obj != ""] + return ";\n".join(sql) diff --git a/tests/providers/google/cloud/operators/job_details.json b/tests/providers/google/cloud/operators/job_details.json new file mode 100644 index 0000000000000..f12ec1321d57f --- /dev/null +++ b/tests/providers/google/cloud/operators/job_details.json @@ -0,0 +1,240 @@ +{ + "kind": "bigquery#job", + "etag": "vd2aBaVVX6a4bUJW13+Tqg==", + "id": "airflow:US.job_IDnbVW6NACdFDkermznYm9o4mcVH", + "selfLink": "https://bigquery.googleapis.com/bigquery/v2/projects/airflow-openlineage/jobs/job_IDnbVW6NACdFDkermznYm9o4mcVH?location=US", + "user_email": "svc-account@airflow-openlineage.iam.gserviceaccount.com", + "configuration": { + "query": { + "query": "Select * from test_table", + "destinationTable": { + "projectId": "airflow-openlineage", + "datasetId": "new_dataset", + "tableId": "output_table" + }, + "createDisposition": "CREATE_IF_NEEDED", + "writeDisposition": "WRITE_TRUNCATE", + "priority": "INTERACTIVE", + "allowLargeResults": false, + "useLegacySql": false + }, + "jobType": "QUERY" + }, + "jobReference": { + "projectId": "airflow-openlineage", + "jobId": "job_IDnbVW6NACdFDkermznYm9o4mcVH", + "location": "US" + }, + "statistics": { + "creationTime": 1.60390893E12, + "startTime": 1.60390893E12, + "endTime": 
+    "totalBytesProcessed": "110355534",
+    "query": {
+      "queryPlan": [{
+        "name": "S00: Input",
+        "id": "0",
+        "startMs": "1603908925668",
+        "endMs": "1603908925880",
+        "waitRatioAvg": 0.0070422534,
+        "waitMsAvg": "2",
+        "waitRatioMax": 0.0070422534,
+        "waitMsMax": "2",
+        "readRatioAvg": 0.14084508,
+        "readMsAvg": "40",
+        "readRatioMax": 0.14084508,
+        "readMsMax": "40",
+        "computeRatioAvg": 1,
+        "computeMsAvg": "284",
+        "computeRatioMax": 1,
+        "computeMsMax": "284",
+        "writeRatioAvg": 0.017605634,
+        "writeMsAvg": "5",
+        "writeRatioMax": 0.017605634,
+        "writeMsMax": "5",
+        "shuffleOutputBytes": "439409",
+        "shuffleOutputBytesSpilled": "0",
+        "recordsRead": "5552452",
+        "recordsWritten": "16142",
+        "parallelInputs": "1",
+        "completedParallelInputs": "1",
+        "status": "COMPLETE",
+        "steps": [{
+            "kind": "READ",
+            "substeps": [
+              "$1:state, $2:name, $3:number",
+              "FROM bigquery-public-data.usa_names.usa_1910_2013",
+              "WHERE equal($1, 'TX')"
+            ]
+          },
+          {
+            "kind": "AGGREGATE",
+            "substeps": [
+              "GROUP BY $30 := $2, $31 := $1",
+              "$20 := SUM($3)"
+            ]
+          },
+          {
+            "kind": "WRITE",
+            "substeps": [
+              "$31, $30, $20",
+              "TO __stage00_output",
+              "BY HASH($30, $31)"
+            ]
+          }
+        ],
+        "slotMs": "448"
+      },
+      {
+        "name": "S01: Sort+",
+        "id": "1",
+        "startMs": "1603908925891",
+        "endMs": "1603908925911",
+        "inputStages": [
+          "0"
+        ],
+        "waitRatioAvg": 0.0070422534,
+        "waitMsAvg": "2",
+        "waitRatioMax": 0.0070422534,
+        "waitMsMax": "2",
+        "readRatioAvg": 0,
+        "readMsAvg": "0",
+        "readRatioMax": 0,
+        "readMsMax": "0",
+        "computeRatioAvg": 0.049295776,
+        "computeMsAvg": "14",
+        "computeRatioMax": 0.049295776,
+        "computeMsMax": "14",
+        "writeRatioAvg": 0.0070422534,
+        "writeMsAvg": "2",
+        "writeRatioMax": 0.0070422534,
+        "writeMsMax": "2",
+        "shuffleOutputBytes": "401",
+        "shuffleOutputBytesSpilled": "0",
+        "recordsRead": "16142",
+        "recordsWritten": "20",
+        "parallelInputs": "1",
+        "completedParallelInputs": "1",
+        "status": "COMPLETE",
+        "steps": [{
+            "kind": "READ",
+            "substeps": [
+              "$31, $30, $20",
+              "FROM __stage00_output"
+            ]
+          },
+          {
+            "kind": "SORT",
+            "substeps": [
+              "$10 DESC",
+              "LIMIT 20"
+            ]
+          },
+          {
+            "kind": "AGGREGATE",
+            "substeps": [
+              "GROUP BY $40 := $30, $41 := $31",
+              "$10 := SUM($20)"
+            ]
+          },
+          {
+            "kind": "WRITE",
+            "substeps": [
+              "$50, $51",
+              "TO __stage01_output"
+            ]
+          }
+        ],
+        "slotMs": "33"
+      },
+      {
+        "name": "S02: Output",
+        "id": "2",
+        "startMs": "1603908926017",
+        "endMs": "1603908926191",
+        "inputStages": [
+          "1"
+        ],
+        "waitRatioAvg": 0.4471831,
+        "waitMsAvg": "127",
+        "waitRatioMax": 0.4471831,
+        "waitMsMax": "127",
+        "readRatioAvg": 0,
+        "readMsAvg": "0",
+        "readRatioMax": 0,
+        "readMsMax": "0",
+        "computeRatioAvg": 0.03169014,
+        "computeMsAvg": "9",
+        "computeRatioMax": 0.03169014,
+        "computeMsMax": "9",
+        "writeRatioAvg": 0.5633803,
+        "writeMsAvg": "160",
+        "writeRatioMax": 0.5633803,
+        "writeMsMax": "160",
+        "shuffleOutputBytes": "321",
+        "shuffleOutputBytesSpilled": "0",
+        "recordsRead": "20",
+        "recordsWritten": "20",
+        "parallelInputs": "1",
+        "completedParallelInputs": "1",
+        "status": "COMPLETE",
+        "steps": [{
+            "kind": "READ",
+            "substeps": [
+              "$50, $51",
+              "FROM __stage01_output"
+            ]
+          },
+          {
+            "kind": "SORT",
+            "substeps": [
+              "$51 DESC",
+              "LIMIT 20"
+            ]
+          },
+          {
+            "kind": "WRITE",
+            "substeps": [
+              "$60, $61",
+              "TO __stage02_output"
+            ]
+          }
+        ],
+        "slotMs": "342"
+      }
+      ],
+      "estimatedBytesProcessed": "110355534",
+      "timeline": [{
+        "elapsedMs": "736",
+        "totalSlotMs": "482",
+        "pendingUnits": "1",
+        "completedUnits": "2",
+        "activeUnits": "1"
+      },
+ { + "elapsedMs": "1045", + "totalSlotMs": "825", + "pendingUnits": "0", + "completedUnits": "3", + "activeUnits": "1" + } + ], + "totalPartitionsProcessed": "0", + "totalBytesProcessed": "110355534", + "totalBytesBilled": "111149056", + "billingTier": 1, + "totalSlotMs": "825", + "cacheHit": false, + "referencedTables": [{ + "projectId": "airflow-openlineage", + "datasetId": "new_dataset", + "tableId": "test_table" + }], + "statementType": "SELECT" + }, + "totalSlotMs": "825" + }, + "status": { + "state": "DONE" + } +} diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index 4026b4ba45074..d17f5498e27c6 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -17,6 +17,8 @@ # under the License. from __future__ import annotations +import json +from contextlib import suppress from unittest import mock from unittest.mock import ANY, MagicMock @@ -24,6 +26,13 @@ import pytest from google.cloud.bigquery import DEFAULT_RETRY from google.cloud.exceptions import Conflict +from openlineage.client.facet import ( + DataSourceDatasetFacet, + ExternalQueryRunFacet, + SqlJobFacet, +) +from openlineage.client.run import Dataset +from openlineage.common.provider.bigquery import BigQueryErrorRunFacet from airflow.exceptions import AirflowException, AirflowSkipException, AirflowTaskTimeout, TaskDeferred from airflow.providers.google.cloud.operators.bigquery import ( @@ -1520,6 +1529,88 @@ def test_bigquery_insert_job_operator_with_job_id_generate( force_rerun=True, ) + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") + def test_execute_openlineage_events(self, mock_hook): + job_id = "123456" + hash_ = "hash" + real_job_id = f"{job_id}_{hash_}" + + configuration = { + "query": { + "query": "SELECT * FROM test_table", + "useLegacySql": False, + } + } + mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) + mock_hook.return_value.generate_job_id.return_value = real_job_id + + op = BigQueryInsertJobOperator( + task_id="insert_query_job", + configuration=configuration, + location=TEST_DATASET_LOCATION, + job_id=job_id, + project_id=TEST_GCP_PROJECT_ID, + ) + result = op.execute(context=MagicMock()) + + mock_hook.return_value.insert_job.assert_called_once_with( + configuration=configuration, + location=TEST_DATASET_LOCATION, + job_id=real_job_id, + nowait=True, + project_id=TEST_GCP_PROJECT_ID, + retry=DEFAULT_RETRY, + timeout=None, + ) + + assert result == real_job_id + + with open(file="tests/providers/google/cloud/operators/job_details.json") as f: + job_details = json.loads(f.read()) + mock_hook.return_value.get_client.return_value.get_job.return_value._properties = job_details + + lineage = op.get_openlineage_facets_on_complete(None) + assert lineage.inputs == [ + Dataset( + namespace="bigquery", + name="airflow-openlineage.new_dataset.test_table", + facets={"dataSource": DataSourceDatasetFacet(name="bigquery", uri="bigquery")}, + ) + ] + + assert lineage.run_facets == { + "bigQuery_job": mock.ANY, + "externalQuery": ExternalQueryRunFacet(externalQueryId=mock.ANY, source="bigquery"), + } + assert lineage.job_facets == {"sql": SqlJobFacet(query="SELECT * FROM test_table")} + + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") + def test_execute_fails_openlineage_events(self, mock_hook): + job_id = "1234" + + configuration = { + "query": { + "query": 
"SELECT * FROM test_table", + "useLegacySql": False, + } + } + operator = BigQueryInsertJobOperator( + task_id="insert_query_job_failed", + configuration=configuration, + location=TEST_DATASET_LOCATION, + job_id=job_id, + project_id=TEST_GCP_PROJECT_ID, + ) + mock_hook.return_value.generate_job_id.return_value = "1234" + mock_hook.return_value.get_client.return_value.get_job.side_effect = RuntimeError() + mock_hook.return_value.insert_job.side_effect = RuntimeError() + + with suppress(RuntimeError): + operator.execute(MagicMock()) + lineage = operator.get_openlineage_facets_on_complete(None) + + assert lineage.run_facets == {"bigQuery_error": BigQueryErrorRunFacet(clientError=mock.ANY)} + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") def test_execute_force_rerun_async(self, mock_hook, create_task_instance_of_operator): job_id = "123456"