openlineage, bigquery: add openlineage method support for BigQueryInsertJobOperator #31293

Merged
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/bigquery.py
@@ -2245,7 +2245,7 @@ def run_query(
self.running_job_id = job.job_id
return job.job_id

def generate_job_id(self, job_id, dag_id, task_id, logical_date, configuration, force_rerun=False):
def generate_job_id(self, job_id, dag_id, task_id, logical_date, configuration, force_rerun=False) -> str:
if force_rerun:
hash_base = str(uuid.uuid4())
else:
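Only the force_rerun branch of generate_job_id is visible above; the deterministic branch is collapsed. The following standalone sketch illustrates the general idea, not the provider's exact implementation: a random hash base yields a fresh job id on every invocation, while a parameter-derived base yields a stable, re-attachable one.

import hashlib
import uuid

def sketch_generate_job_id(dag_id, task_id, logical_date, configuration, force_rerun=False) -> str:
    # Random base -> unique id per run; deterministic base -> same id for the
    # same task parameters, enabling reattachment to an already-running job.
    if force_rerun:
        hash_base = str(uuid.uuid4())
    else:
        hash_base = f"{dag_id}:{task_id}:{logical_date}:{configuration}"
    return f"airflow_{hashlib.md5(hash_base.encode()).hexdigest()}"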
92 changes: 82 additions & 10 deletions airflow/providers/google/cloud/operators/bigquery.py
@@ -133,6 +133,68 @@ def get_db_hook(self: BigQueryCheckOperator) -> BigQueryHook:  # type:ignore[misc]
)


class _BigQueryOpenLineageMixin:
def get_openlineage_facets_on_complete(self, task_instance):
"""
Retrieve OpenLineage data for a COMPLETE BigQuery job.
This method retrieves statistics for the specified job_ids using the BigQueryDatasetsProvider.
It calls the BigQuery API, retrieving input and output dataset info from it, as well as run-level
usage statistics.
Run facets should contain:
- ExternalQueryRunFacet
- BigQueryJobRunFacet
Job facets should contain:
- SqlJobFacet if operator has self.sql
Input datasets should contain facets:
- DataSourceDatasetFacet
- SchemaDatasetFacet
Output datasets should contain facets:
- DataSourceDatasetFacet
- SchemaDatasetFacet
- OutputStatisticsOutputDatasetFacet
"""
from openlineage.client.facet import SqlJobFacet
from openlineage.common.provider.bigquery import BigQueryDatasetsProvider

from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.utils.utils import normalize_sql

if not self.job_id:
return OperatorLineage()

client = self.hook.get_client(project_id=self.hook.project_id)
job_ids = self.job_id
if isinstance(self.job_id, str):
job_ids = [self.job_id]
inputs, outputs, run_facets = {}, {}, {}
for job_id in job_ids:
stats = BigQueryDatasetsProvider(client=client).get_facets(job_id=job_id)
for input in stats.inputs:
input = input.to_openlineage_dataset()
inputs[input.name] = input
if stats.output:
output = stats.output.to_openlineage_dataset()
outputs[output.name] = output
for key, value in stats.run_facets.items():
run_facets[key] = value

job_facets = {}
if hasattr(self, "sql"):
job_facets["sql"] = SqlJobFacet(query=normalize_sql(self.sql))

return OperatorLineage(
inputs=list(inputs.values()),
outputs=list(outputs.values()),
run_facets=run_facets,
job_facets=job_facets,
)
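
The loop above keys datasets by name, so inputs and outputs repeated across several job_ids collapse to a single entry each. A self-contained illustration of that deduplication, with a namedtuple standing in for the OpenLineage Dataset class:

from collections import namedtuple

Dataset = namedtuple("Dataset", ["name", "facets"])  # stand-in for openlineage.client.run.Dataset

inputs = {}
for ds in [
    Dataset("bq.project.dataset.table_a", {}),
    Dataset("bq.project.dataset.table_a", {}),  # duplicate across job_ids
    Dataset("bq.project.dataset.table_b", {}),
]:
    inputs[ds.name] = ds  # later occurrences overwrite earlier ones

assert sorted(inputs) == ["bq.project.dataset.table_a", "bq.project.dataset.table_b"]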


class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
"""Performs checks against BigQuery.
@@ -1153,6 +1215,7 @@ def __init__(
self.encryption_configuration = encryption_configuration
self.hook: BigQueryHook | None = None
self.impersonation_chain = impersonation_chain
self.job_id: str | list[str] | None = None

def execute(self, context: Context):
if self.hook is None:
@@ -1164,7 +1227,7 @@ def execute(self, context: Context):
impersonation_chain=self.impersonation_chain,
)
if isinstance(self.sql, str):
job_id: str | list[str] = self.hook.run_query(
self.job_id = self.hook.run_query(
sql=self.sql,
destination_dataset_table=self.destination_dataset_table,
write_disposition=self.write_disposition,
@@ -1184,7 +1247,7 @@
encryption_configuration=self.encryption_configuration,
)
elif isinstance(self.sql, Iterable):
job_id = [
self.job_id = [
self.hook.run_query(
sql=s,
destination_dataset_table=self.destination_dataset_table,
@@ -1210,9 +1273,9 @@
raise AirflowException(f"argument 'sql' of type {type(self.sql)} is neither a string nor an iterable")
project_id = self.hook.project_id
if project_id:
job_id_path = convert_job_id(job_id=job_id, project_id=project_id, location=self.location)
job_id_path = convert_job_id(job_id=self.job_id, project_id=project_id, location=self.location)
context["task_instance"].xcom_push(key="job_id_path", value=job_id_path)
return job_id
return self.job_id

def on_kill(self) -> None:
super().on_kill()
@@ -2562,7 +2625,7 @@ def execute(self, context: Context):
return table


class BigQueryInsertJobOperator(GoogleCloudBaseOperator):
class BigQueryInsertJobOperator(GoogleCloudBaseOperator, _BigQueryOpenLineageMixin):
"""Execute a BigQuery job.
Waits for the job to complete and returns job id.
@@ -2663,6 +2726,13 @@ def __init__(
self.deferrable = deferrable
self.poll_interval = poll_interval

@property
def sql(self) -> str | None:
try:
return self.configuration["query"]["query"]
except KeyError:
return None
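
This property is what makes the mixin's hasattr(self, "sql") check succeed for this operator. A quick illustration of its contract, with plain dicts standing in for self.configuration:

def extract_sql(configuration: dict) -> str | None:
    # Mirrors the property above: only query-type job configurations carry SQL.
    try:
        return configuration["query"]["query"]
    except KeyError:
        return None

assert extract_sql({"query": {"query": "SELECT 1"}}) == "SELECT 1"
assert extract_sql({"load": {"sourceUris": ["gs://bucket/file.csv"]}}) is None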

def prepare_template(self) -> None:
# If .json is passed then we have to read the file
if isinstance(self.configuration, str) and self.configuration.endswith(".json"):
@@ -2697,7 +2767,7 @@ def execute(self, context: Any):
)
self.hook = hook

job_id = hook.generate_job_id(
self.job_id = hook.generate_job_id(
job_id=self.job_id,
dag_id=self.dag_id,
task_id=self.task_id,
@@ -2708,13 +2778,13 @@

try:
self.log.info("Executing: %s", self.configuration)
job: BigQueryJob | UnknownJob = self._submit_job(hook, job_id)
job: BigQueryJob | UnknownJob = self._submit_job(hook, self.job_id)
except Conflict:
# If the job already exists retrieve it
job = hook.get_job(
project_id=self.project_id,
location=self.location,
job_id=job_id,
job_id=self.job_id,
)
if job.state in self.reattach_states:
# We are reattaching to a job
@@ -2723,7 +2793,7 @@
else:
# Same job configuration so we need force_rerun
raise AirflowException(
f"Job with id: {job_id} already exists and is in {job.state} state. If you "
f"Job with id: {self.job_id} already exists and is in {job.state} state. If you "
f"want to force rerun it consider setting `force_rerun=True`."
f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`"
)
@@ -2757,7 +2827,9 @@ def execute(self, context: Any):
self.job_id = job.job_id
project_id = self.project_id or self.hook.project_id
if project_id:
job_id_path = convert_job_id(job_id=job_id, project_id=project_id, location=self.location)
job_id_path = convert_job_id(
job_id=self.job_id, project_id=project_id, location=self.location # type: ignore[arg-type]
)
context["ti"].xcom_push(key="job_id_path", value=job_id_path)
# Wait for the job to complete
if not self.deferrable:
6 changes: 6 additions & 0 deletions airflow/providers/openlineage/extractors/base.py
@@ -86,6 +86,12 @@ def extract(self) -> OperatorLineage | None:
# OpenLineage methods are optional - if there's no method, return None
try:
return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start) # type: ignore
except ImportError:
self.log.error(
"OpenLineage provider method failed to import OpenLineage integration. "
"This should not happen. Please report this bug to developers."
)
return None
except AttributeError:
return None

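For context, operators opt into lineage by defining these optional methods; the extractor probes for them and treats AttributeError as "no lineage available". A minimal sketch of an operator implementing the on-start hook (the method name is the real extension point; the body is illustrative):

from airflow.models.baseoperator import BaseOperator
from airflow.providers.openlineage.extractors import OperatorLineage

class MyLineageAwareOperator(BaseOperator):
    def execute(self, context):
        ...

    def get_openlineage_facets_on_start(self) -> OperatorLineage:
        # Populate inputs/outputs/run_facets/job_facets as appropriate.
        return OperatorLineage()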
9 changes: 8 additions & 1 deletion airflow/providers/openlineage/utils/utils.py
@@ -23,7 +23,7 @@
import os
from contextlib import suppress
from functools import wraps
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Iterable
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse

import attrs
@@ -414,3 +414,10 @@ def is_source_enabled() -> bool:
def get_filtered_unknown_operator_keys(operator: BaseOperator) -> dict:
not_required_keys = {"dag", "task_group"}
return {attr: value for attr, value in operator.__dict__.items() if attr not in not_required_keys}


def normalize_sql(sql: str | Iterable[str]):
if isinstance(sql, str):
sql = [stmt for stmt in sql.split(";") if stmt != ""]
sql = [obj for stmt in sql for obj in stmt.split(";") if obj != ""]
return ";\n".join(sql)