Skip to content

Commit

Permalink
refactor to address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mayurinehate committed Sep 28, 2022
1 parent 7f78234 commit 75e71c7
Showing 1 changed file with 102 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, FrozenSet, List, Optional, Set, Tuple

from pydantic.error_wrappers import ValidationError
from snowflake.connector import SnowflakeConnection
Expand Down Expand Up @@ -34,25 +34,51 @@
class SnowflakeColumnWithLineage(SnowflakeColumnReference):
directSourceColumns: Optional[List[SnowflakeColumnReference]] = None

def __hash__(self):
return hash(self.__members())

def __members(self):
members = (
self.columnName,
tuple(
sorted((c.columnName, c.objectName) for c in self.directSourceColumns)
)
if self.directSourceColumns
else None,
)
return members
@dataclass(frozen=True)
class SnowflakeColumnId:
    """Immutable, hashable identifier for a Snowflake column.

    frozen=True makes instances hashable, so they can be members of the
    frozensets used for fine-grained lineage input columns.
    """

    columnName: str
    # Name of the object (table/view) the column belongs to.
    objectName: str
    # Domain of the owning object — presumably e.g. "Table" vs "View";
    # TODO confirm possible values against Snowflake access history.
    objectDomain: Optional[str] = None


@dataclass(frozen=True)
class SnowflakeColumnFineGrainedLineage:
    """Fine-grained upstream of a column: the set of input columns of one
    transformation that produced the column's value."""

    # Immutable so the whole (frozen) dataclass stays hashable and can be
    # stored in a set.
    inputColumns: FrozenSet[SnowflakeColumnId]
    # Transform function, query etc can be added here

def __eq__(self, instance):
return (
isinstance(instance, SnowflakeColumnWithLineage)
and self.__members() == instance.__members()

@dataclass
class SnowflakeColumnUpstreams:
    """Accumulates all distinct fine-grained upstreams of a single column."""

    # Deduplicated fine-grained lineage entries. init=False: populated only
    # through update_column_lineage, never by the constructor.
    upstreams: Set[SnowflakeColumnFineGrainedLineage] = field(
        default_factory=set, init=False
    )

    def update_column_lineage(
        self, directSourceColumns: List[SnowflakeColumnReference]
    ) -> None:
        """Record one transformation's input columns as an upstream entry.

        Source columns lacking an objectName are skipped; if no usable
        input columns remain, nothing is recorded.
        """
        # Build directly from a generator — no intermediate list needed.
        input_columns = frozenset(
            SnowflakeColumnId(
                upstream_col.columnName,
                upstream_col.objectName,
                upstream_col.objectDomain,
            )
            for upstream_col in directSourceColumns
            if upstream_col.objectName
        )
        if not input_columns:
            return
        # set.add is a no-op for duplicates, so no membership pre-check is
        # needed (the original's `if upstream not in self.upstreams` was
        # redundant).
        self.upstreams.add(
            SnowflakeColumnFineGrainedLineage(inputColumns=input_columns)
        )


@dataclass
Expand All @@ -66,42 +92,43 @@ def from_dict(cls, dataset, upstreams_columns_dict, downstream_columns_dict):
try:
table_with_upstreams = cls(
dataset,
[SnowflakeColumnReference(**col) for col in upstreams_columns_dict],
[SnowflakeColumnWithLineage(**col) for col in downstream_columns_dict],
[
SnowflakeColumnReference.parse_obj(col)
for col in upstreams_columns_dict
],
[
SnowflakeColumnWithLineage.parse_obj(col)
for col in downstream_columns_dict
],
)
except ValidationError:
# Earlier versions of column lineage did not include columnName, only columnId
table_with_upstreams = cls(dataset, [], [])
return table_with_upstreams

def __hash__(self):
return hash(self.__members())

def __members(self):
return (self.upstreamDataset,)

def __eq__(self, instance):
return (
isinstance(instance, SnowflakeUpstreamTable)
and self.__members() == instance.__members()
)


@dataclass
class SnowflakeTableLineage:
    """Table-level and column-level lineage for a single downstream table."""

    # key: upstream table name
    upstreamTables: Dict[str, SnowflakeUpstreamTable] = field(
        default_factory=dict, init=False
    )

    # key: downstream column name
    columnLineages: Dict[str, SnowflakeColumnUpstreams] = field(
        default_factory=lambda: defaultdict(SnowflakeColumnUpstreams), init=False
    )

    def update_lineage(self, table: SnowflakeUpstreamTable) -> None:
        """Merge one upstream table (and its column lineage) into this record.

        The first entry seen for an upstream dataset wins; later entries
        for the same dataset are not stored, but their column lineage is
        still merged in.
        """
        # `not in dict` instead of `not in dict.keys()` — same check, idiomatic.
        if table.upstreamDataset not in self.upstreamTables:
            self.upstreamTables[table.upstreamDataset] = table

        if table.downstreamColumns:
            for col in table.downstreamColumns:
                if col.directSourceColumns:
                    # defaultdict creates the SnowflakeColumnUpstreams on
                    # first access for this downstream column.
                    self.columnLineages[col.columnName].update_column_lineage(
                        col.directSourceColumns
                    )


class SnowflakeLineageExtractor(SnowflakeQueryMixin, SnowflakeCommonMixin):
Expand Down Expand Up @@ -151,7 +178,7 @@ def _get_upstream_lineage_info(
self.config.env,
)
for lineage_entry in sorted(
lineage.upstreamTables, key=lambda x: x.upstreamDataset
lineage.upstreamTables.values(), key=lambda x: x.upstreamDataset
):
# Update the table-lineage
upstream_table_name = lineage_entry.upstreamDataset
Expand Down Expand Up @@ -195,42 +222,43 @@ def _get_upstream_lineage_info(
)
fieldset_finegrained_lineages.append(fieldset_finegrained_lineage)

for col in lineage.columnLineages:
fieldPath = col.columnName
finegrained_lineage_entry = FineGrainedLineage(
upstreamType=FineGrainedLineageUpstreamType.FIELD_SET,
upstreams=sorted(
[
builder.make_schema_field_urn(
builder.make_dataset_urn_with_platform_instance(
self.platform,
self.get_dataset_identifier_from_qualified_name(
upstream_col.objectName
for col, col_upstreams in lineage.columnLineages.items():
for fine_upstream in col_upstreams.upstreams:
fieldPath = col
finegrained_lineage_entry = FineGrainedLineage(
upstreamType=FineGrainedLineageUpstreamType.FIELD_SET,
upstreams=sorted(
[
builder.make_schema_field_urn(
builder.make_dataset_urn_with_platform_instance(
self.platform,
self.get_dataset_identifier_from_qualified_name(
upstream_col.objectName
),
self.config.platform_instance,
self.config.env,
),
self.config.platform_instance,
self.config.env,
),
self.snowflake_identifier(upstream_col.columnName),
)
for upstream_col in col.directSourceColumns # type:ignore
if upstream_col.objectName
and upstream_col.columnName
and self._is_dataset_pattern_allowed(
upstream_col.objectName, upstream_col.objectDomain
)
]
),
downstreamType=FineGrainedLineageDownstreamType.FIELD,
downstreams=sorted(
[
builder.make_schema_field_urn(
dataset_urn, self.snowflake_identifier(fieldPath)
)
]
),
)
if finegrained_lineage_entry.upstreams:
finegrained_lineages.append(finegrained_lineage_entry)
self.snowflake_identifier(upstream_col.columnName),
)
for upstream_col in fine_upstream.inputColumns # type:ignore
if upstream_col.objectName
and upstream_col.columnName
and self._is_dataset_pattern_allowed(
upstream_col.objectName, upstream_col.objectDomain
)
]
),
downstreamType=FineGrainedLineageDownstreamType.FIELD,
downstreams=sorted(
[
builder.make_schema_field_urn(
dataset_urn, self.snowflake_identifier(fieldPath)
)
]
),
)
if finegrained_lineage_entry.upstreams:
finegrained_lineages.append(finegrained_lineage_entry)

for external_lineage_entry in sorted(external_lineage):
# For now, populate only for S3
Expand Down

0 comments on commit 75e71c7

Please sign in to comment.