feat: add OpenLineage support for RedshiftToS3Operator #41632

Merged 1 commit on Oct 22, 2024
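This PR gives RedshiftToS3Operator a get_openlineage_facets_on_complete method, so lineage is collected once the OpenLineage provider is installed. As a rough sketch (not part of the PR; connection ids, bucket, schema, and table names below are placeholders), a task like the following now reports the Redshift source table as an input dataset and the S3 UNLOAD target as an output dataset:

```python
from __future__ import annotations

import pendulum

from airflow import DAG
from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator

with DAG(
    dag_id="redshift_to_s3_lineage_example",
    start_date=pendulum.datetime(2024, 10, 1, tz="UTC"),
    schedule=None,
) as dag:
    # With the default select query (SELECT * FROM <schema>.<table>), the new
    # get_openlineage_facets_on_complete() reports the Redshift table as the input
    # dataset and s3://<s3_bucket>/<s3_key> as the output dataset.
    unload_orders = RedshiftToS3Operator(
        task_id="unload_orders",
        schema="public",
        table="orders",
        s3_bucket="my-lineage-bucket",
        s3_key="exports/orders",
        redshift_conn_id="redshift_default",
        aws_conn_id="aws_default",
        unload_options=["CSV"],
    )
```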
113 changes: 106 additions & 7 deletions providers/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -152,6 +152,10 @@ def default_select_query(self) -> str | None:
table = self.table
return f"SELECT * FROM {table}"

@property
def use_redshift_data(self):
return bool(self.redshift_data_api_kwargs)

def execute(self, context: Context) -> None:
if self.table and self.table_as_file_name:
self.s3_key = f"{self.s3_key}/{self.table}_"
@@ -164,14 +168,13 @@ def execute(self, context: Context) -> None:
if self.include_header and "HEADER" not in [uo.upper().strip() for uo in self.unload_options]:
self.unload_options = [*self.unload_options, "HEADER"]

- redshift_hook: RedshiftDataHook | RedshiftSQLHook
- if self.redshift_data_api_kwargs:
- redshift_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
+ if self.use_redshift_data:
+ redshift_data_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
for arg in ["sql", "parameters"]:
if arg in self.redshift_data_api_kwargs:
raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs")
else:
- redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
+ redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
conn = S3Hook.get_connection(conn_id=self.aws_conn_id) if self.aws_conn_id else None
if conn and conn.extra_dejson.get("role_arn", False):
credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
@@ -187,10 +190,106 @@ def execute(self, context: Context) -> None:
)

self.log.info("Executing UNLOAD command...")
- if isinstance(redshift_hook, RedshiftDataHook):
-     redshift_hook.execute_query(
+ if self.use_redshift_data:
+     redshift_data_hook.execute_query(
sql=unload_query, parameters=self.parameters, **self.redshift_data_api_kwargs
)
else:
- redshift_hook.run(unload_query, self.autocommit, parameters=self.parameters)
+ redshift_sql_hook.run(unload_query, self.autocommit, parameters=self.parameters)
self.log.info("UNLOAD command complete...")

def get_openlineage_facets_on_complete(self, task_instance):
"""Implement on_complete as we may query for table details."""
from airflow.providers.amazon.aws.utils.openlineage import (
get_facets_from_redshift_table,
get_identity_column_lineage_facet,
)
from airflow.providers.common.compat.openlineage.facet import (
Dataset,
Error,
ExtractionErrorRunFacet,
)
from airflow.providers.openlineage.extractors import OperatorLineage

output_dataset = Dataset(
namespace=f"s3://{self.s3_bucket}",
name=self.s3_key,
)

if self.use_redshift_data:
redshift_data_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
database = self.redshift_data_api_kwargs.get("database")
identifier = self.redshift_data_api_kwargs.get(
"cluster_identifier", self.redshift_data_api_kwargs.get("workgroup_name")
)
port = self.redshift_data_api_kwargs.get("port", "5439")
authority = f"{identifier}.{redshift_data_hook.region_name}:{port}"
else:
redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
database = redshift_sql_hook.conn.schema
authority = redshift_sql_hook.get_openlineage_database_info(redshift_sql_hook.conn).authority

if self.select_query == self.default_select_query:
if self.use_redshift_data:
input_dataset_facets = get_facets_from_redshift_table(
redshift_data_hook, self.table, self.redshift_data_api_kwargs, self.schema
)
else:
input_dataset_facets = get_facets_from_redshift_table(
redshift_sql_hook, self.table, {}, self.schema
)

input_dataset = Dataset(
namespace=f"redshift://{authority}",
name=f"{database}.{self.schema}.{self.table}" if database else f"{self.schema}.{self.table}",
facets=input_dataset_facets,
)

# If default select query is used (SELECT *) output file matches the input table.
output_dataset.facets = {
"schema": input_dataset_facets["schema"],
"columnLineage": get_identity_column_lineage_facet(
field_names=[field.name for field in input_dataset_facets["schema"].fields],
input_datasets=[input_dataset],
),
}

return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])

try:
from airflow.providers.openlineage.sqlparser import SQLParser, from_table_meta
except ImportError:
return OperatorLineage(outputs=[output_dataset])

run_facets = {}
parse_result = SQLParser(dialect="redshift", default_schema=self.schema).parse(self.select_query)
if parse_result.errors:
run_facets["extractionError"] = ExtractionErrorRunFacet(
totalTasks=1,
failedTasks=1,
errors=[
Error(
errorMessage=error.message,
stackTrace=None,
task=error.origin_statement,
taskNumber=error.index,
)
for error in parse_result.errors
],
)

input_datasets = []
for in_tb in parse_result.in_tables:
ds = from_table_meta(in_tb, database, f"redshift://{authority}", False)
schema, table = ds.name.split(".")[-2:]
if self.use_redshift_data:
input_dataset_facets = get_facets_from_redshift_table(
redshift_data_hook, table, self.redshift_data_api_kwargs, schema
)
else:
input_dataset_facets = get_facets_from_redshift_table(redshift_sql_hook, table, {}, schema)

ds.facets = input_dataset_facets
input_datasets.append(ds)

return OperatorLineage(inputs=input_datasets, outputs=[output_dataset], run_facets=run_facets)
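A minimal sketch of the dataset naming the new method produces, assuming the Redshift Data API path; this is plain Python mirroring the logic above, with placeholder cluster, region, database, schema, and table values rather than anything taken from the PR:

```python
# Hedged illustration (no Airflow calls) of how the namespace and name of the input
# dataset are assembled when redshift_data_api_kwargs are supplied.
redshift_data_api_kwargs = {"database": "dev", "cluster_identifier": "my-cluster"}

identifier = redshift_data_api_kwargs.get(
    "cluster_identifier", redshift_data_api_kwargs.get("workgroup_name")
)
port = redshift_data_api_kwargs.get("port", "5439")
region_name = "eu-central-1"  # in the operator this comes from RedshiftDataHook.region_name
authority = f"{identifier}.{region_name}:{port}"

database, schema, table = redshift_data_api_kwargs["database"], "public", "orders"
input_namespace = f"redshift://{authority}"
input_name = f"{database}.{schema}.{table}" if database else f"{schema}.{table}"

print(input_namespace, input_name)
# redshift://my-cluster.eu-central-1:5439 dev.public.orders
# The output dataset is the UNLOAD target: namespace s3://<s3_bucket>, name <s3_key>.
```

For a non-default select_query, the input tables instead come from SQLParser, but they are reported under the same redshift://{authority} namespace.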
@@ -241,7 +241,7 @@ def get_openlineage_facets_on_complete(self, task_instance):

output_dataset = Dataset(
namespace=f"redshift://{authority}",
name=f"{database}.{self.schema}.{self.table}",
name=f"{database}.{self.schema}.{self.table}" if database else f"{self.schema}.{self.table}",
facets=output_dataset_facets,
)
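The hunk above adds a fallback for connections where no database can be resolved; a tiny sketch (placeholder values, not from the PR) of what it avoids:

```python
# Hedged illustration of the fallback: without it, a missing database would leave a
# literal "None" prefix in the dataset name.
database, schema, table = None, "public", "orders"
name = f"{database}.{schema}.{table}" if database else f"{schema}.{table}"
print(name)  # public.orders, not None.public.orders
```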
