feat(ingest): add column-level lineage support for snowflake (datahub…

mayurinehate authored and shirshanka committed Sep 29, 2022
1 parent f95ef89 commit e417ccb

Showing 4 changed files with 261 additions and 57 deletions.
@@ -1,29 +1,139 @@
import json
import logging
from collections import defaultdict
-from typing import Dict, List, Optional, Set, Tuple
+from dataclasses import dataclass, field
+from typing import Dict, FrozenSet, List, Optional, Set, Tuple

from pydantic.error_wrappers import ValidationError
from snowflake.connector import SnowflakeConnection

import datahub.emitter.mce_builder as builder
from datahub.ingestion.source.aws.s3_util import make_s3_urn
from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config
from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery
from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report
from datahub.ingestion.source.snowflake.snowflake_usage_v2 import (
    SnowflakeColumnReference,
)
from datahub.ingestion.source.snowflake.snowflake_utils import (
    SnowflakeCommonMixin,
    SnowflakeQueryMixin,
)
-from datahub.metadata.com.linkedin.pegasus2avro.dataset import UpstreamLineage
+from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
+    FineGrainedLineage,
+    FineGrainedLineageDownstreamType,
+    FineGrainedLineageUpstreamType,
+    UpstreamLineage,
+)
from datahub.metadata.schema_classes import DatasetLineageTypeClass, UpstreamClass
from datahub.utilities.perf_timer import PerfTimer

logger: logging.Logger = logging.getLogger(__name__)


class SnowflakeColumnWithLineage(SnowflakeColumnReference):
    directSourceColumns: Optional[List[SnowflakeColumnReference]] = None


@dataclass(frozen=True)
class SnowflakeColumnId:
    columnName: str
    objectName: str
    objectDomain: Optional[str] = None


@dataclass(frozen=True)
class SnowflakeColumnFineGrainedLineage:
"""
Fie grained upstream of column,
which represents a transformation applied on input columns"""

    inputColumns: FrozenSet[SnowflakeColumnId]
    # Transform function, query etc can be added here


@dataclass
class SnowflakeColumnUpstreams:
"""All upstreams of a column"""

    upstreams: Set[SnowflakeColumnFineGrainedLineage] = field(
        default_factory=set, init=False
    )

    def update_column_lineage(
        self, directSourceColumns: List[SnowflakeColumnReference]
    ) -> None:
        input_columns = frozenset(
            [
                SnowflakeColumnId(
                    upstream_col.columnName,
                    upstream_col.objectName,
                    upstream_col.objectDomain,
                )
                for upstream_col in directSourceColumns
                if upstream_col.objectName
            ]
        )
        if not input_columns:
            return
        upstream = SnowflakeColumnFineGrainedLineage(inputColumns=input_columns)
        if upstream not in self.upstreams:
            self.upstreams.add(upstream)
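
Editor's note, not part of the commit: the `if upstream not in self.upstreams` check works because SnowflakeColumnFineGrainedLineage is a frozen dataclass hashed by its frozenset of inputs, so two entries built from the same source columns collapse to a single set member. A minimal sketch with hypothetical names:

    # Two transformations with identical input columns are one set member.
    a = SnowflakeColumnFineGrainedLineage(
        inputColumns=frozenset({SnowflakeColumnId("ID", "DB.SCHEMA.SRC", "Table")})
    )
    b = SnowflakeColumnFineGrainedLineage(
        inputColumns=frozenset({SnowflakeColumnId("ID", "DB.SCHEMA.SRC", "Table")})
    )
    assert a == b and len({a, b}) == 1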


@dataclass
class SnowflakeUpstreamTable:
    upstreamDataset: str
    upstreamColumns: List[SnowflakeColumnReference]
    downstreamColumns: List[SnowflakeColumnWithLineage]

    @classmethod
    def from_dict(cls, dataset, upstreams_columns_dict, downstream_columns_dict):
        try:
            table_with_upstreams = cls(
                dataset,
                [
                    SnowflakeColumnReference.parse_obj(col)
                    for col in upstreams_columns_dict
                ],
                [
                    SnowflakeColumnWithLineage.parse_obj(col)
                    for col in downstream_columns_dict
                ],
            )
        except ValidationError:
            # Earlier versions of column lineage did not include columnName, only columnId
            table_with_upstreams = cls(dataset, [], [])
        return table_with_upstreams
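
Illustration only (not in the commit): per the comment above, a legacy payload carrying only columnId should fail SnowflakeColumnReference validation, and from_dict degrades to a table-level-only entry instead of raising. The dict keys and values here are hypothetical:

    t = SnowflakeUpstreamTable.from_dict("db.schema.src", [{"columnId": 123}], [])
    # Expected: the ValidationError is swallowed and both column lists stay
    # empty, so the edge still contributes table-level lineage.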


@dataclass
class SnowflakeTableLineage:
    # key: upstream table name
    upstreamTables: Dict[str, SnowflakeUpstreamTable] = field(
        default_factory=dict, init=False
    )

    # key: downstream column name
    columnLineages: Dict[str, SnowflakeColumnUpstreams] = field(
        default_factory=lambda: defaultdict(SnowflakeColumnUpstreams), init=False
    )

    def update_lineage(self, table: SnowflakeUpstreamTable) -> None:
        if table.upstreamDataset not in self.upstreamTables.keys():
            self.upstreamTables[table.upstreamDataset] = table

        if table.downstreamColumns:
            for col in table.downstreamColumns:
                if col.directSourceColumns:
                    self.columnLineages[col.columnName].update_column_lineage(
                        col.directSourceColumns
                    )
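
As a reading aid (not part of the diff): update_lineage registers each upstream dataset once, and records column lineage only when downstream columns carry directSourceColumns. Hypothetical names:

    tl = SnowflakeTableLineage()
    tl.update_lineage(SnowflakeUpstreamTable("db.schema.src", [], []))
    tl.update_lineage(SnowflakeUpstreamTable("db.schema.src", [], []))
    assert list(tl.upstreamTables) == ["db.schema.src"]  # registered once
    assert not tl.columnLineages  # no directSourceColumns were supplied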


class SnowflakeLineageExtractor(SnowflakeQueryMixin, SnowflakeCommonMixin):
    def __init__(self, config: SnowflakeV2Config, report: SnowflakeV2Report) -> None:
-        self._lineage_map: Optional[Dict[str, List[Tuple[str, str, str]]]] = None
+        self._lineage_map: Optional[Dict[str, SnowflakeTableLineage]] = None
        self._external_lineage_map: Optional[Dict[str, Set[str]]] = None
        self.config = config
        self.platform = "snowflake"
@@ -54,49 +164,103 @@ def _get_upstream_lineage_info(

        lineage = self._lineage_map[dataset_name]
        external_lineage = self._external_lineage_map[dataset_name]
-        if not (lineage or external_lineage):
+        if not (lineage.upstreamTables or lineage.columnLineages or external_lineage):
            logger.debug(f"No lineage found for {dataset_name}")
            return None
        upstream_tables: List[UpstreamClass] = []
+        finegrained_lineages: List[FineGrainedLineage] = []
+        fieldset_finegrained_lineages: List[FineGrainedLineage] = []
        column_lineage: Dict[str, str] = {}
-        for lineage_entry in lineage:
+        dataset_urn = builder.make_dataset_urn_with_platform_instance(
+            self.platform,
+            dataset_name,
+            self.config.platform_instance,
+            self.config.env,
+        )
+        for lineage_entry in sorted(
+            lineage.upstreamTables.values(), key=lambda x: x.upstreamDataset
+        ):
            # Update the table-lineage
-            upstream_table_name = lineage_entry[0]
-            if not self._is_dataset_pattern_allowed(upstream_table_name, "table"):
-                continue
+            upstream_table_name = lineage_entry.upstreamDataset
+            upstream_table_urn = builder.make_dataset_urn_with_platform_instance(
+                self.platform,
+                upstream_table_name,
+                self.config.platform_instance,
+                self.config.env,
+            )
            upstream_table = UpstreamClass(
-                dataset=builder.make_dataset_urn_with_platform_instance(
-                    self.platform,
-                    upstream_table_name,
-                    self.config.platform_instance,
-                    self.config.env,
-                ),
+                dataset=upstream_table_urn,
                type=DatasetLineageTypeClass.TRANSFORMED,
            )
            upstream_tables.append(upstream_table)
-            # Update column-lineage for each down-stream column.
-            upstream_columns = [
-                self.snowflake_identifier(d["columnName"])
-                for d in json.loads(lineage_entry[1])
-            ]
-            downstream_columns = [
-                self.snowflake_identifier(d["columnName"])
-                for d in json.loads(lineage_entry[2])
-            ]
-            upstream_column_str = (
-                f"{upstream_table_name}({', '.join(sorted(upstream_columns))})"
-            )
-            downstream_column_str = (
-                f"{dataset_name}({', '.join(sorted(downstream_columns))})"
-            )
-            column_lineage_key = f"column_lineage[{upstream_table_name}]"
-            column_lineage_value = (
-                f"{{{upstream_column_str} -> {downstream_column_str}}}"
-            )
-            column_lineage[column_lineage_key] = column_lineage_value
-            logger.debug(f"{column_lineage_key}:{column_lineage_value}")

-        for external_lineage_entry in external_lineage:
+            if lineage_entry.upstreamColumns and lineage_entry.downstreamColumns:
+                # This is not used currently. This indicates same column lineage as was set
+                # in customProperties earlier - not accurate.
+                fieldset_finegrained_lineage = FineGrainedLineage(
+                    upstreamType=FineGrainedLineageUpstreamType.FIELD_SET,
+                    downstreamType=FineGrainedLineageDownstreamType.FIELD_SET
+                    if len(lineage_entry.downstreamColumns) > 1
+                    else FineGrainedLineageDownstreamType.FIELD,
+                    upstreams=sorted(
+                        [
+                            builder.make_schema_field_urn(
+                                upstream_table_urn,
+                                self.snowflake_identifier(d.columnName),
+                            )
+                            for d in lineage_entry.upstreamColumns
+                        ]
+                    ),
+                    downstreams=sorted(
+                        [
+                            builder.make_schema_field_urn(
+                                dataset_urn, self.snowflake_identifier(d.columnName)
+                            )
+                            for d in lineage_entry.downstreamColumns
+                        ]
+                    ),
+                )
+                fieldset_finegrained_lineages.append(fieldset_finegrained_lineage)

+        for col, col_upstreams in lineage.columnLineages.items():
+            for fine_upstream in col_upstreams.upstreams:
+                fieldPath = col
+                finegrained_lineage_entry = FineGrainedLineage(
+                    upstreamType=FineGrainedLineageUpstreamType.FIELD_SET,
+                    upstreams=sorted(
+                        [
+                            builder.make_schema_field_urn(
+                                builder.make_dataset_urn_with_platform_instance(
+                                    self.platform,
+                                    self.get_dataset_identifier_from_qualified_name(
+                                        upstream_col.objectName
+                                    ),
+                                    self.config.platform_instance,
+                                    self.config.env,
+                                ),
+                                self.snowflake_identifier(upstream_col.columnName),
+                            )
+                            for upstream_col in fine_upstream.inputColumns  # type:ignore
+                            if upstream_col.objectName
+                            and upstream_col.columnName
+                            and self._is_dataset_pattern_allowed(
+                                upstream_col.objectName, upstream_col.objectDomain
+                            )
+                        ]
+                    ),
+                    downstreamType=FineGrainedLineageDownstreamType.FIELD,
+                    downstreams=sorted(
+                        [
+                            builder.make_schema_field_urn(
+                                dataset_urn, self.snowflake_identifier(fieldPath)
+                            )
+                        ]
+                    ),
+                )
+                if finegrained_lineage_entry.upstreams:
+                    finegrained_lineages.append(finegrained_lineage_entry)

+        for external_lineage_entry in sorted(external_lineage):
            # For now, populate only for S3
            if external_lineage_entry.startswith("s3://"):
                external_upstream_table = UpstreamClass(
@@ -113,7 +277,16 @@ def _get_upstream_lineage_info(
        self.report.upstream_lineage[dataset_name] = [
            u.dataset for u in upstream_tables
        ]
-        return UpstreamLineage(upstreams=upstream_tables), column_lineage
+        return (
+            UpstreamLineage(
+                upstreams=upstream_tables,
+                fineGrainedLineages=sorted(
+                    finegrained_lineages, key=lambda x: (x.downstreams, x.upstreams)
+                )
+                or None,
+            ),
+            column_lineage,
+        )
        return None
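
For orientation (illustrative, not emitted verbatim by this code): each FineGrainedLineage built above pairs schema-field URNs, for example a FIELD_SET of upstream columns flowing into a single downstream FIELD. The dataset and column names below are hypothetical; the URN shape is what builder.make_schema_field_urn produces:

    example = FineGrainedLineage(
        upstreamType=FineGrainedLineageUpstreamType.FIELD_SET,
        upstreams=[
            "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,db.schema.src,PROD),id)"
        ],
        downstreamType=FineGrainedLineageDownstreamType.FIELD,
        downstreams=[
            "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,db.schema.dst,PROD),customer_id)"
        ],
    )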

    def _populate_view_lineage(self, conn: SnowflakeConnection) -> None:
@@ -189,7 +362,7 @@ def _populate_lineage(self, conn: SnowflakeConnection) -> None:
            end_time_millis=int(self.config.end_time.timestamp() * 1000),
        )
        num_edges: int = 0
-        self._lineage_map = defaultdict(list)
+        self._lineage_map = defaultdict(SnowflakeTableLineage)
        try:
            for db_row in self.query(conn, query):
                # key is the down-stream table name
@@ -204,19 +377,21 @@ def _populate_lineage(self, conn: SnowflakeConnection) -> None:
                    or self._is_dataset_pattern_allowed(upstream_table_name, "table")
                ):
                    continue
-                self._lineage_map[key].append(
-                    # (<upstream_table_name>, <json_list_of_upstream_columns>, <json_list_of_downstream_columns>)
-                    (
-                        upstream_table_name,
-                        db_row["UPSTREAM_TABLE_COLUMNS"],
-                        db_row["DOWNSTREAM_TABLE_COLUMNS"],
-                    )
-                )
+                self._lineage_map[key].update_lineage(
+                    SnowflakeUpstreamTable.from_dict(
+                        upstream_table_name,
+                        json.loads(db_row["UPSTREAM_TABLE_COLUMNS"]),
+                        json.loads(db_row["DOWNSTREAM_TABLE_COLUMNS"]),
+                    ),
+                )
                num_edges += 1
                logger.debug(
                    f"Lineage[Table(Down)={key}]:Table(Up)={self._lineage_map[key]}"
                )
        except Exception as e:
            logger.error(e, exc_info=e)
            self.warn(
                "lineage",
                f"Extracting lineage from Snowflake failed."
@@ -246,15 +421,19 @@ def _populate_view_upstream_lineage(self, conn: SnowflakeConnection) -> None:
            view_name: str = self.get_dataset_identifier_from_qualified_name(
                db_row["DOWNSTREAM_VIEW"]
            )

+            if not self._is_dataset_pattern_allowed(
+                dataset_name=view_name,
+                dataset_type=db_row["REFERENCING_OBJECT_DOMAIN"],
+            ) or not self._is_dataset_pattern_allowed(
+                view_upstream, db_row["REFERENCED_OBJECT_DOMAIN"]
+            ):
+                continue

            # key is the downstream view name
-            self._lineage_map[view_name].append(
-                # (<upstream_table_name>, <empty_json_list_of_upstream_table_columns>, <empty_json_list_of_downstream_view_columns>)
-                (view_upstream, "[]", "[]")
+            self._lineage_map[view_name].update_lineage(
+                SnowflakeUpstreamTable.from_dict(view_upstream, [], [])
            )
            num_edges += 1
            logger.debug(
@@ -297,20 +476,23 @@ def _populate_view_downstream_lineage(self, conn: SnowflakeConnection) -> None:
            view_name: str = self.get_dataset_identifier_from_qualified_name(
                db_row["VIEW_NAME"]
            )
+            downstream_table: str = self.get_dataset_identifier_from_qualified_name(
+                db_row["DOWNSTREAM_TABLE_NAME"]
+            )
+            if not self._is_dataset_pattern_allowed(
+                view_name, db_row["VIEW_DOMAIN"]
+            ) or not self._is_dataset_pattern_allowed(
+                downstream_table, db_row["DOWNSTREAM_TABLE_DOMAIN"]
+            ):
+                continue
-            downstream_table: str = self.get_dataset_identifier_from_qualified_name(
-                db_row["DOWNSTREAM_TABLE_NAME"]
-            )

            # Capture view->downstream table lineage.
-            self._lineage_map[downstream_table].append(
-                # (<upstream_view_name>, <json_list_of_upstream_view_columns>, <json_list_of_downstream_columns>)
-                (
-                    view_name,
-                    db_row["VIEW_COLUMNS"],
-                    db_row["DOWNSTREAM_TABLE_COLUMNS"],
-                )
-            )
+            self._lineage_map[downstream_table].update_lineage(
+                SnowflakeUpstreamTable.from_dict(
+                    view_name,
+                    json.loads(db_row["VIEW_COLUMNS"]),
+                    json.loads(db_row["DOWNSTREAM_TABLE_COLUMNS"]),
+                )
+            )
            self.report.num_view_to_table_edges_scanned += 1