feat(ingest): add column-level lineage support for snowflake #6034

Merged · merged 3 commits · Sep 28, 2022

Changes from 2 commits
@@ -1,29 +1,112 @@
import json
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Tuple

from pydantic.error_wrappers import ValidationError
from snowflake.connector import SnowflakeConnection

import datahub.emitter.mce_builder as builder
from datahub.ingestion.source.aws.s3_util import make_s3_urn
from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config
from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery
from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report
from datahub.ingestion.source.snowflake.snowflake_usage_v2 import (
SnowflakeColumnReference,
)
from datahub.ingestion.source.snowflake.snowflake_utils import (
SnowflakeCommonMixin,
SnowflakeQueryMixin,
)
from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
FineGrainedLineage,
FineGrainedLineageDownstreamType,
FineGrainedLineageUpstreamType,
UpstreamLineage,
)
from datahub.metadata.schema_classes import DatasetLineageTypeClass, UpstreamClass
from datahub.utilities.perf_timer import PerfTimer

logger: logging.Logger = logging.getLogger(__name__)


class SnowflakeColumnWithLineage(SnowflakeColumnReference):
directSourceColumns: Optional[List[SnowflakeColumnReference]] = None

def __hash__(self):
return hash(self.__members())

def __members(self):
members = (
self.columnName,
tuple(
sorted((c.columnName, c.objectName) for c in self.directSourceColumns)
)
if self.directSourceColumns
else None,
)
return members

def __eq__(self, instance):
return (
isinstance(instance, SnowflakeColumnWithLineage)
and self.__members() == instance.__members()
)


@dataclass
class SnowflakeUpstreamTable:
upstreamDataset: str
upstreamColumns: List[SnowflakeColumnReference]
downstreamColumns: List[SnowflakeColumnWithLineage]

@classmethod
def from_dict(cls, dataset, upstreams_columns_dict, downstream_columns_dict):
try:
table_with_upstreams = cls(
dataset,
[SnowflakeColumnReference(**col) for col in upstreams_columns_dict],
[SnowflakeColumnWithLineage(**col) for col in downstream_columns_dict],
)
except ValidationError:
# Earlier versions of column lineage did not include columnName, only columnId
table_with_upstreams = cls(dataset, [], [])
return table_with_upstreams

def __hash__(self):
return hash(self.__members())

def __members(self):
return (self.upstreamDataset,)
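# Note: hashing and equality key on upstreamDataset alone, so
# SnowflakeTableLineage.update_lineage() keeps a single entry per upstream
# table even when the column payloads differ between query-log rows.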

def __eq__(self, instance):
return (
isinstance(instance, SnowflakeUpstreamTable)
and self.__members() == instance.__members()
)


@dataclass
class SnowflakeTableLineage:
upstreamTables: Set[SnowflakeUpstreamTable] = field(default_factory=set, init=False)
columnLineages: Set[SnowflakeColumnWithLineage] = field(
default_factory=set, init=False
)

def update_lineage(self, table: SnowflakeUpstreamTable) -> None:
if table not in self.upstreamTables:
self.upstreamTables.add(table)

if table.downstreamColumns:
for col in table.downstreamColumns:
if col.directSourceColumns and col not in self.columnLineages:
self.columnLineages.add(col)


class SnowflakeLineageExtractor(SnowflakeQueryMixin, SnowflakeCommonMixin):
def __init__(self, config: SnowflakeV2Config, report: SnowflakeV2Report) -> None:
self._lineage_map: Optional[Dict[str, SnowflakeTableLineage]] = None
self._external_lineage_map: Optional[Dict[str, Set[str]]] = None
self.config = config
self.platform = "snowflake"
@@ -54,49 +137,102 @@ def _get_upstream_lineage_info(

lineage = self._lineage_map[dataset_name]
external_lineage = self._external_lineage_map[dataset_name]
if not (lineage.upstreamTables or lineage.columnLineages or external_lineage):
logger.debug(f"No lineage found for {dataset_name}")
return None
upstream_tables: List[UpstreamClass] = []
finegrained_lineages: List[FineGrainedLineage] = []
fieldset_finegrained_lineages: List[FineGrainedLineage] = []
column_lineage: Dict[str, str] = {}
dataset_urn = builder.make_dataset_urn_with_platform_instance(
self.platform,
dataset_name,
self.config.platform_instance,
self.config.env,
)
for lineage_entry in sorted(
lineage.upstreamTables, key=lambda x: x.upstreamDataset
):
# Update the table-lineage
upstream_table_name = lineage_entry.upstreamDataset
upstream_table_urn = builder.make_dataset_urn_with_platform_instance(
self.platform,
upstream_table_name,
self.config.platform_instance,
self.config.env,
)
upstream_table = UpstreamClass(
dataset=upstream_table_urn,
type=DatasetLineageTypeClass.TRANSFORMED,
)
upstream_tables.append(upstream_table)

if lineage_entry.upstreamColumns and lineage_entry.downstreamColumns:
# This is not used currently. It reproduces the same column lineage that
# was previously set in customProperties, which was not accurate.
fieldset_finegrained_lineage = FineGrainedLineage(
upstreamType=FineGrainedLineageUpstreamType.FIELD_SET,
downstreamType=FineGrainedLineageDownstreamType.FIELD_SET
if len(lineage_entry.downstreamColumns) > 1
else FineGrainedLineageDownstreamType.FIELD,
upstreams=sorted(
[
builder.make_schema_field_urn(
upstream_table_urn,
self.snowflake_identifier(d.columnName),
)
for d in lineage_entry.upstreamColumns
]
),
downstreams=sorted(
[
builder.make_schema_field_urn(
dataset_urn, self.snowflake_identifier(d.columnName)
)
for d in lineage_entry.downstreamColumns
]
),
)
fieldset_finegrained_lineages.append(fieldset_finegrained_lineage)

for col in lineage.columnLineages:
fieldPath = col.columnName
finegrained_lineage_entry = FineGrainedLineage(
upstreamType=FineGrainedLineageUpstreamType.FIELD_SET,
upstreams=sorted(
[
builder.make_schema_field_urn(
builder.make_dataset_urn_with_platform_instance(
self.platform,
self.get_dataset_identifier_from_qualified_name(
upstream_col.objectName
),
self.config.platform_instance,
self.config.env,
),
self.snowflake_identifier(upstream_col.columnName),
)
for upstream_col in col.directSourceColumns # type:ignore
if upstream_col.objectName
and upstream_col.columnName
and self._is_dataset_pattern_allowed(
upstream_col.objectName, upstream_col.objectDomain
)
]
),
downstreamType=FineGrainedLineageDownstreamType.FIELD,
downstreams=sorted(
[
builder.make_schema_field_urn(
dataset_urn, self.snowflake_identifier(fieldPath)
)
]
),
)
if finegrained_lineage_entry.upstreams:
finegrained_lineages.append(finegrained_lineage_entry)

for external_lineage_entry in sorted(external_lineage):
# For now, populate only for S3
if external_lineage_entry.startswith("s3://"):
external_upstream_table = UpstreamClass(
@@ -113,7 +249,16 @@ def _get_upstream_lineage_info(
self.report.upstream_lineage[dataset_name] = [
u.dataset for u in upstream_tables
]
return (
UpstreamLineage(
upstreams=upstream_tables,
fineGrainedLineages=sorted(
finegrained_lineages, key=lambda x: (x.downstreams, x.upstreams)
)
or None,
),
column_lineage,
)
return None

def _populate_view_lineage(self, conn: SnowflakeConnection) -> None:
@@ -189,7 +334,7 @@ def _populate_lineage(self, conn: SnowflakeConnection) -> None:
end_time_millis=int(self.config.end_time.timestamp() * 1000),
)
num_edges: int = 0
self._lineage_map = defaultdict(SnowflakeTableLineage)
try:
for db_row in self.query(conn, query):
# key is the down-stream table name
@@ -204,19 +349,21 @@ def _populate_lineage(self, conn: SnowflakeConnection) -> None:
or self._is_dataset_pattern_allowed(upstream_table_name, "table")
):
continue

self._lineage_map[key].update_lineage(
    SnowflakeUpstreamTable.from_dict(
        upstream_table_name,
        json.loads(db_row["UPSTREAM_TABLE_COLUMNS"]),
        json.loads(db_row["DOWNSTREAM_TABLE_COLUMNS"]),
    ),
)
num_edges += 1
logger.debug(
f"Lineage[Table(Down)={key}]:Table(Up)={self._lineage_map[key]}"
)
except Exception as e:
logger.error(e, exc_info=e)
self.warn(
"lineage",
f"Extracting lineage from Snowflake failed."
@@ -246,15 +393,19 @@ def _populate_view_upstream_lineage(self, conn: SnowflakeConnection) -> None:
view_name: str = self.get_dataset_identifier_from_qualified_name(
db_row["DOWNSTREAM_VIEW"]
)

if not self._is_dataset_pattern_allowed(
dataset_name=view_name,
dataset_type=db_row["REFERENCING_OBJECT_DOMAIN"],
) or not self._is_dataset_pattern_allowed(
view_upstream, db_row["REFERENCED_OBJECT_DOMAIN"]
):
continue

# key is the downstream view name
self._lineage_map[view_name].update_lineage(
    SnowflakeUpstreamTable.from_dict(view_upstream, [], [])
)
num_edges += 1
logger.debug(
@@ -297,20 +448,23 @@ def _populate_view_downstream_lineage(self, conn: SnowflakeConnection) -> None:
view_name: str = self.get_dataset_identifier_from_qualified_name(
db_row["VIEW_NAME"]
)
downstream_table: str = self.get_dataset_identifier_from_qualified_name(
db_row["DOWNSTREAM_TABLE_NAME"]
)
if not self._is_dataset_pattern_allowed(
view_name, db_row["VIEW_DOMAIN"]
) or not self._is_dataset_pattern_allowed(
downstream_table, db_row["DOWNSTREAM_TABLE_DOMAIN"]
):
continue

# Capture view->downstream table lineage.
self._lineage_map[downstream_table].update_lineage(
SnowflakeUpstreamTable.from_dict(
view_name,
db_row["VIEW_COLUMNS"],
db_row["DOWNSTREAM_TABLE_COLUMNS"],
json.loads(db_row["VIEW_COLUMNS"]),
Contributor: what happens if this bombs?

Collaborator (author): hmm, I don't expect it to bomb - the VIEW_COLUMNS column is equivalent to the columns JSON array per the Snowflake access history view documentation, and the dict-to-class conversion is protected already. (A minimal sketch of this fallback follows after this hunk.)
json.loads(db_row["DOWNSTREAM_TABLE_COLUMNS"]),
)
)
self.report.num_view_to_table_edges_scanned += 1
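
As the reply in the thread above notes, the dict-to-class conversion is guarded: each column entry is parsed through pydantic, and from_dict catches ValidationError and falls back to table-level lineage. A minimal sketch of that fallback, with a hypothetical payload and assuming columnName is a required field on SnowflakeColumnReference (which the except-ValidationError branch implies):

# Hypothetical old-style entry that carries only columnId and no columnName.
legacy_columns = [{"columnId": 123}]

table = SnowflakeUpstreamTable.from_dict(
    "db1.schema1.view1", legacy_columns, legacy_columns
)

# The ValidationError raised while building SnowflakeColumnReference objects
# is caught inside from_dict, so the column lists come back empty and the
# entry degrades to plain table-level lineage instead of raising.
assert table.upstreamDataset == "db1.schema1.view1"
assert table.upstreamColumns == [] and table.downstreamColumns == []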
@@ -250,7 +250,12 @@ def table_to_table_lineage_history(
downstream_table_columns AS "DOWNSTREAM_TABLE_COLUMNS"
FROM table_lineage_history
WHERE upstream_table_domain in ('Table', 'External table') and downstream_table_domain = 'Table'
QUALIFY ROW_NUMBER() OVER (
PARTITION BY downstream_table_name,
upstream_table_name,
downstream_table_columns
ORDER BY query_start_time DESC
) = 1"""

@staticmethod
def view_dependencies() -> str:
@@ -260,6 +265,7 @@ def view_dependencies() -> str:
referenced_database, '.', referenced_schema,
'.', referenced_object_name
) AS "VIEW_UPSTREAM",
referenced_object_domain as "REFERENCED_OBJECT_DOMAIN",
concat(
referencing_database, '.', referencing_schema,
'.', referencing_object_name
@@ -305,14 +311,16 @@ def view_lineage_history(start_time_millis: int, end_time_millis: int) -> str:
view_domain AS "VIEW_DOMAIN",
view_columns AS "VIEW_COLUMNS",
downstream_table_name AS "DOWNSTREAM_TABLE_NAME",
downstream_table_domain AS "DOWNSTREAM_TABLE_DOMAIN",
downstream_table_columns AS "DOWNSTREAM_TABLE_COLUMNS"
FROM
view_lineage_history
WHERE
view_domain in ('View', 'Materialized view')
QUALIFY ROW_NUMBER() OVER (
PARTITION BY view_name,
downstream_table_name,
downstream_table_columns
ORDER BY
query_start_time DESC
) = 1