Skip to content

Commit

Permalink
update quicksight lineage (#1067)
Browse files Browse the repository at this point in the history
  • Loading branch information
alyiwang authored Jan 17, 2025
1 parent 62bb023 commit 6f6c3a7
Show file tree
Hide file tree
Showing 8 changed files with 246 additions and 205 deletions.
12 changes: 9 additions & 3 deletions metaphor/common/snowflake.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
PRIVATE_LINK_SUFFIX = ".privatelink"
SNOWFLAKE_HOST_SUFFIX = ".snowflakecomputing.com"
SNOWFLAKE_DEFAULT_REGION = ".us-west-2"


def normalize_snowflake_account(host: str) -> str:
def normalize_snowflake_account(host: str, remove_default_region: bool = False) -> str:
"""
Normalize different variations of Snowflake account.
See https://docs.snowflake.com/en/user-guide/admin-account-identifier
Expand All @@ -14,8 +15,13 @@ def normalize_snowflake_account(host: str) -> str:
if host.endswith(SNOWFLAKE_HOST_SUFFIX):
host = host[: -len(SNOWFLAKE_HOST_SUFFIX)]

# Strip PrivateLink suffix
# Strip PrivateLink suffix, e.g. account.privatelink.snowflakecomputing.com
if host.endswith(PRIVATE_LINK_SUFFIX):
return host[: -len(PRIVATE_LINK_SUFFIX)]
host = host[: -len(PRIVATE_LINK_SUFFIX)]

# Remove default region (us-west-2) if applicable, e.g. account.us-west-2.snowflakecomputing.com
# This is to keep the account name consistent with the results from the snowflake crawler
if remove_default_region and host.endswith(SNOWFLAKE_DEFAULT_REGION):
host = host[: -len(SNOWFLAKE_DEFAULT_REGION)]

return host
10 changes: 9 additions & 1 deletion metaphor/quick_sight/data_source_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,19 @@ def get_account(data_source: DataSource) -> Optional[str]:

if parameters.SnowflakeParameters:
return (
normalize_snowflake_account(parameters.SnowflakeParameters.Host)
normalize_snowflake_account(parameters.SnowflakeParameters.Host, True)
if parameters.SnowflakeParameters.Host
else None
)

if parameters.SqlServerParameters:
return parameters.SqlServerParameters.Host

return None


def get_id_from_arn(arn: str) -> str:
    """
    Return the id portion of an ARN, i.e. the text after its final "/".

    Example: "arn:aws:quicksight:us-west-2:1231048943:dataset/xxx" -> "xxx".
    An ARN that contains no "/" is returned unchanged.
    """
    # rpartition splits on the last "/" only; the tail is the resource id.
    _, _, resource_id = arn.rpartition("/")
    return resource_id
55 changes: 30 additions & 25 deletions metaphor/quick_sight/extractor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Collection, Dict, List
from typing import Collection, Dict, List, Optional

from func_timeout import FunctionTimedOut, func_set_timeout

Expand Down Expand Up @@ -26,6 +26,7 @@
)
from metaphor.quick_sight.client import Client
from metaphor.quick_sight.config import QuickSightRunConfig
from metaphor.quick_sight.data_source_utils import get_id_from_arn
from metaphor.quick_sight.folder import (
DASHBOARD_DIRECTORIES,
DATA_SET_DIRECTORIES,
Expand Down Expand Up @@ -55,10 +56,10 @@ def __init__(self, config: QuickSightRunConfig) -> None:
# Arn -> Resource
self._resources: Dict[str, ResourceType] = {}

# Arn -> VirtualView
# DataSetId -> VirtualView
self._virtual_views: Dict[str, VirtualView] = {}

# Arn -> Dashboard
# DashboardId -> Dashboard
self._dashboards: Dict[str, MetaphorDashboard] = {}

async def extract(self) -> Collection[ENTITY_TYPES]:
Expand All @@ -74,8 +75,7 @@ async def extract(self) -> Collection[ENTITY_TYPES]:

@func_set_timeout(20)
def _extract_virtual_view(self, data_set: DataSet) -> None:
assert data_set.Arn
view = self._init_virtual_view(data_set.Arn, data_set)
view = self._init_virtual_view(data_set)
output_logical_table_id = LineageProcessor(
self._resources, self._virtual_views, data_set
).run()
Expand All @@ -90,7 +90,7 @@ def _extract_virtual_view(self, data_set: DataSet) -> None:
def _extract_virtual_views(self):
count = 0
for data_set in self._resources.values():
if not isinstance(data_set, DataSet) or data_set.Arn is None:
if not isinstance(data_set, DataSet):
continue

try:
Expand All @@ -110,14 +110,10 @@ def _extract_virtual_views(self):
def _extract_dashboards(self) -> None:
count = 0
for dashboard in self._resources.values():
if (
not isinstance(dashboard, Dashboard)
or dashboard.Arn is None
or dashboard.Version is None
):
if not isinstance(dashboard, Dashboard) or dashboard.Version is None:
continue

metaphor_dashboard = self._init_dashboard(dashboard.Arn, dashboard)
metaphor_dashboard = self._init_dashboard(dashboard)
metaphor_dashboard.entity_upstream = self._get_dashboard_upstream(
dataset_arns=dashboard.Version.DataSetArns or []
)
Expand All @@ -135,10 +131,13 @@ def _make_entities_list(self) -> Collection[ENTITY_TYPES]:
entities.extend(create_top_level_folders())
return entities

def _init_virtual_view(self, arn: str, data_set: DataSet) -> VirtualView:
def _init_virtual_view(self, data_set: DataSet) -> VirtualView:
data_set_id = data_set.DataSetId
assert data_set_id

view = VirtualView(
logical_id=VirtualViewLogicalID(
name=arn,
name=data_set_id,
type=VirtualViewType.QUICK_SIGHT,
),
structure=AssetStructure(
Expand All @@ -150,16 +149,17 @@ def _init_virtual_view(self, arn: str, data_set: DataSet) -> VirtualView:
),
)

self._virtual_views[arn] = view

self._virtual_views[data_set_id] = view
return view

def _init_dashboard(self, arn: str, dashboard: Dashboard) -> MetaphorDashboard:
def _init_dashboard(self, dashboard: Dashboard) -> MetaphorDashboard:
dashboard_id = dashboard.DashboardId
assert dashboard_id
assert dashboard.Version

metaphor_dashboard = MetaphorDashboard(
logical_id=MetaphorDashboardLogicalId(
dashboard_id=arn,
dashboard_id=dashboard_id,
platform=MetaphorDashboardPlatform.QUICK_SIGHT,
),
source_info=SourceInfo(
Expand Down Expand Up @@ -187,21 +187,26 @@ def _init_dashboard(self, arn: str, dashboard: Dashboard) -> MetaphorDashboard:
],
)

self._dashboards[arn] = metaphor_dashboard

self._dashboards[dashboard_id] = metaphor_dashboard
return metaphor_dashboard

def _get_dashboard_upstream(self, dataset_arns: List[str]) -> EntityUpstream:
def _get_dashboard_upstream(
self, dataset_arns: List[str]
) -> Optional[EntityUpstream]:
source_entities: List[str] = []

for arn in dataset_arns:
virtual_view = self._virtual_views.get(arn)
dataset_id = get_id_from_arn(arn)
virtual_view = self._virtual_views.get(dataset_id)
if not virtual_view:
logger.warning(f"Virtual view not found for dataset {dataset_id}")
continue

source_entities.append(
str(to_entity_id_from_virtual_view_logical_id(virtual_view.logical_id))
)

return EntityUpstream(
source_entities=(unique_list(source_entities) if source_entities else None)
)
if not source_entities:
return None

return EntityUpstream(source_entities=(unique_list(source_entities)))
Loading

0 comments on commit 6f6c3a7

Please sign in to comment.