Skip to content

Commit

Permalink
update code, add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mayurinehate committed Dec 19, 2022
1 parent 63180ab commit 7da9ad8
Show file tree
Hide file tree
Showing 9 changed files with 372 additions and 141 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ class SnowflakeCloudProvider(str, Enum):
AZURE = "azure"


SNOWFLAKE_DEFAULT_CLOUD_REGION_ID = "us-west-2"
SNOWFLAKE_DEFAULT_CLOUD = SnowflakeCloudProvider.AWS


Expand All @@ -16,3 +15,12 @@ class SnowflakeEdition(str, Enum):

# We use this to represent Enterprise Edition or higher
ENTERPRISE = "Enterprise or above"


# See https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#region-ids
# Includes only exceptions to format <provider>_<cloud region with hyphen replaced by _>
# Key: Snowflake region id, written as <provider>_<region>.
# Value: (cloud provider enum member, region string for that cloud) pair.
SNOWFLAKE_REGION_CLOUD_REGION_MAPPING = {
"aws_us_east_1_gov": (SnowflakeCloudProvider.AWS, "us-east-1"),
"azure_uksouth": (SnowflakeCloudProvider.AZURE, "uk-south"),
"azure_centralindia": (SnowflakeCloudProvider.AZURE, "central-india.azure"),
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class SnowflakeV2Config(SnowflakeConfig, SnowflakeUsageConfig):

include_column_lineage: bool = Field(
default=True,
description="If enabled, populates the column lineage. Supported only for snowflake table-to-table and view-to-table lineage edge (not supported in table-to-view or view-to-view lineage edge yet). Requires appropriate grants given to the role, include_table_lineage to be True and Snowflake Enterprise Edition or above.",
description="If enabled, populates the column lineage. Supported only for snowflake table-to-table and view-to-table lineage edge (not supported in table-to-view or view-to-view lineage edge yet). Requires appropriate grants given to the role.",
)

check_role_grants: bool = Field(
Expand All @@ -54,7 +54,7 @@ class SnowflakeV2Config(SnowflakeConfig, SnowflakeUsageConfig):
description="Whether to populate Snowsight url for Snowflake Objects",
)

match_fully_qualified_names: bool = Field(
# Controls whether `schema_pattern` is matched against `<catalog>.<schema>`.
# This must be an annotated declaration (`name: bool = Field(...)`); a chained
# assignment would rebind the name `bool` inside the class body and leave the
# field without a type annotation.
match_fully_qualified_names: bool = Field(
    default=False,
    description="Whether `schema_pattern` is matched against fully qualified schema name `<catalog>.<schema>`.",
)
Expand Down Expand Up @@ -118,11 +118,12 @@ def validate_unsupported_configs(cls, values: Dict) -> Dict:
and values["stateful_ingestion"].enabled
and values["stateful_ingestion"].remove_stale_metadata
)
include_table_lineage = values.get("include_table_lineage")

# TODO: Allow profiling irrespective of basic schema extraction,
# TODO: Allow lineage extraction and profiling irrespective of basic schema extraction,
# as it seems possible with some refactoring
if not include_technical_schema and any(
[include_profiles, delete_detection_enabled]
[include_profiles, delete_detection_enabled, include_table_lineage]
):
raise ValueError(
"Can not perform Deletion Detection, Lineage Extraction, Profiling without extracting snowflake technical schema. Set `include_technical_schema` to True or disable Deletion Detection, Lineage Extraction, Profiling."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,49 @@ def get_workunits(
self, discovered_tables: List[str], discovered_views: List[str]
) -> Iterable[MetadataWorkUnit]:

try:
conn = self.get_connection()
except Exception as e:
if isinstance(e, SnowflakePermissionError):
self.report_error("permission-error", str(e))
else:
logger.debug(e, exc_info=e)
self.report_error(
"snowflake-connection",
f"Failed to connect to snowflake instance due to error {e}.",
)
return

if self._lineage_map is None:
self._lineage_map = defaultdict(SnowflakeTableLineage)
if self.report.edition == SnowflakeEdition.STANDARD:
logger.info(
"Snowflake Account is Standard Edition. Table to Table Lineage Feature is not supported."
)
else:
with PerfTimer() as timer:
self._populate_lineage(conn)
self.report.table_lineage_query_secs = timer.elapsed_seconds()
if self.config.include_view_lineage:
if len(discovered_views) > 0:
self._populate_view_lineage(conn)
else:
logger.info("No views found. Skipping View Lineage Extraction.")

if self._external_lineage_map is None:
with PerfTimer() as timer:
self._populate_external_lineage(conn)
self.report.external_lineage_queries_secs = timer.elapsed_seconds()

assert self._lineage_map is not None
assert self._external_lineage_map is not None
if (
len(self._lineage_map.keys()) == 0
and len(self._external_lineage_map.keys()) == 0
):
logger.debug("No lineage found.")
return

if self.config.include_table_lineage:
for dataset_name in discovered_tables:
if self._is_dataset_pattern_allowed(dataset_name, "table"):
Expand Down Expand Up @@ -203,26 +246,6 @@ def get_workunits(
def _get_upstream_lineage_info(
self, dataset_name: str
) -> Optional[UpstreamLineage]:

if self._lineage_map is None or self._external_lineage_map is None:
conn = self.config.get_connection()
if self._lineage_map is None:
if self.report.edition == SnowflakeEdition.STANDARD:
logger.info(
"Snowflake Account is Standard Edition. Table to Table Lineage Feature is not supported."
)
else:
with PerfTimer() as timer:
self._populate_lineage(conn)
self.report.table_lineage_query_secs = timer.elapsed_seconds()
if self.config.include_view_lineage:
self._populate_view_lineage(conn)

if self._external_lineage_map is None:
with PerfTimer() as timer:
self._populate_external_lineage(conn)
self.report.external_lineage_queries_secs = timer.elapsed_seconds()

assert self._lineage_map is not None
assert self._external_lineage_map is not None

Expand Down Expand Up @@ -364,45 +387,46 @@ def _populate_view_lineage(self, conn: SnowflakeConnection) -> None:
self.report.view_downstream_lineage_query_secs = timer.elapsed_seconds()

def _populate_external_lineage(self, conn: SnowflakeConnection) -> None:
# Handles the case where a table is populated from an external location via copy.
# Eg: copy into category_english from 's3://acryl-snow-demo-olist/olist_raw_data/category_english'credentials=(aws_key_id='...' aws_secret_key='...') pattern='.*.csv';
query: str = SnowflakeQuery.external_table_lineage_history(
start_time_millis=int(self.config.start_time.timestamp() * 1000)
if not self.config.ignore_start_time_lineage
else 0,
end_time_millis=int(self.config.end_time.timestamp() * 1000),
)

num_edges: int = 0
self._external_lineage_map = defaultdict(set)
try:
for db_row in self.query(conn, query):
# key is the down-stream table name
key: str = self.get_dataset_identifier_from_qualified_name(
db_row["DOWNSTREAM_TABLE_NAME"]
)
if not self._is_dataset_pattern_allowed(key, "table"):
continue
self._external_lineage_map[key] |= {
*json.loads(db_row["UPSTREAM_LOCATIONS"])
}
logger.debug(
f"ExternalLineage[Table(Down)={key}]:External(Up)={self._external_lineage_map[key]} via access_history"
)
except SnowflakePermissionError:
if self.report.edition == SnowflakeEdition.STANDARD:
logger.info(
"Snowflake Account is Standard Edition. External Lineage Feature is not supported."
)
else:
error_msg = "Failed to get lineage. Please grant permissions for SNOWFLAKE database. "
self.warn_if_stateful_else_error("lineage-permission-error", error_msg)
except Exception as e:
logger.debug(e, exc_info=e)
self.report_warning(
"external_lineage",
f"Populating table external lineage from Snowflake failed due to error {e}.",
if self.report.edition == SnowflakeEdition.STANDARD:
logger.info(
"Snowflake Account is Standard Edition. External Lineage Feature via Access History is not supported."
)
else:
# Handles the case where a table is populated from an external location via copy.
# Eg: copy into category_english from 's3://acryl-snow-demo-olist/olist_raw_data/category_english'credentials=(aws_key_id='...' aws_secret_key='...') pattern='.*.csv';
query: str = SnowflakeQuery.external_table_lineage_history(
start_time_millis=int(self.config.start_time.timestamp() * 1000)
if not self.config.ignore_start_time_lineage
else 0,
end_time_millis=int(self.config.end_time.timestamp() * 1000),
)

try:
for db_row in self.query(conn, query):
# key is the down-stream table name
key: str = self.get_dataset_identifier_from_qualified_name(
db_row["DOWNSTREAM_TABLE_NAME"]
)
if not self._is_dataset_pattern_allowed(key, "table"):
continue
self._external_lineage_map[key] |= {
*json.loads(db_row["UPSTREAM_LOCATIONS"])
}
logger.debug(
f"ExternalLineage[Table(Down)={key}]:External(Up)={self._external_lineage_map[key]} via access_history"
)
except SnowflakePermissionError:
error_msg = "Failed to get external lineage. Please grant permissions for SNOWFLAKE database. "
self.warn_if_stateful_else_error("lineage-permission-error", error_msg)
except Exception as e:
logger.debug(e, exc_info=e)
self.report_warning(
"external_lineage",
f"Populating table external lineage from Snowflake failed due to error {e}.",
)

# Handles the case for explicitly created external tables.
# NOTE: Snowflake does not log this information to the access_history table.
Expand Down Expand Up @@ -430,6 +454,7 @@ def _populate_external_lineage(self, conn: SnowflakeConnection) -> None:
self.report.num_external_table_edges_scanned = num_edges

def _populate_lineage(self, conn: SnowflakeConnection) -> None:
assert self._lineage_map is not None
query: str = SnowflakeQuery.table_to_table_lineage_history(
start_time_millis=int(self.config.start_time.timestamp() * 1000)
if not self.config.ignore_start_time_lineage
Expand All @@ -438,7 +463,6 @@ def _populate_lineage(self, conn: SnowflakeConnection) -> None:
include_column_lineage=self.config.include_column_lineage,
)
num_edges: int = 0
self._lineage_map = defaultdict(SnowflakeTableLineage)
try:
for db_row in self.query(conn, query):
# key is the down-stream table name
Expand Down Expand Up @@ -467,7 +491,7 @@ def _populate_lineage(self, conn: SnowflakeConnection) -> None:
f"Lineage[Table(Down)={key}]:Table(Up)={self._lineage_map[key]}"
)
except SnowflakePermissionError:
error_msg = "Failed to get lineage. Please grant permissions for SNOWFLAKE database. "
error_msg = "Failed to get table to table lineage. Please grant permissions for SNOWFLAKE database. "
self.warn_if_stateful_else_error("lineage-permission-error", error_msg)
except Exception as e:
logger.debug(e, exc_info=e)
Expand Down Expand Up @@ -519,16 +543,16 @@ def _populate_view_upstream_lineage(self, conn: SnowflakeConnection) -> None:
f"Upstream->View: Lineage[View(Down)={view_name}]:Upstream={view_upstream}"
)
except SnowflakePermissionError:
error_msg = "Failed to get lineage. Please grant permissions for SNOWFLAKE database. "
error_msg = "Failed to get table to view lineage. Please grant permissions for SNOWFLAKE database."
self.warn_if_stateful_else_error("lineage-permission-error", error_msg)
except Exception as e:
logger.debug(e, exc_info=e)
self.report_warning(
"view-upstream-lineage",
f"Extracting the upstream view lineage from Snowflake failed due to error {e}.",
)

logger.info(f"A total of {num_edges} View upstream edges found.")
else:
logger.info(f"A total of {num_edges} View upstream edges found.")
self.report.num_table_to_view_edges_scanned = num_edges

def _populate_view_downstream_lineage(self, conn: SnowflakeConnection) -> None:
Expand All @@ -550,7 +574,7 @@ def _populate_view_downstream_lineage(self, conn: SnowflakeConnection) -> None:
try:
db_rows = self.query(conn, view_lineage_query)
except SnowflakePermissionError:
error_msg = "Failed to get lineage. Please grant permissions for SNOWFLAKE database. "
error_msg = "Failed to get view to table lineage. Please grant permissions for SNOWFLAKE database. "
self.warn_if_stateful_else_error("lineage-permission-error", error_msg)
except Exception as e:
logger.debug(e, exc_info=e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def get_profiler_instance(

def callable_for_db_connection(self, db_name: str) -> Callable:
def get_db_connection():
conn = self.config.get_connection()
conn = self.get_connection()
conn.cursor().execute(SnowflakeQuery.use_database(db_name))
return conn

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,30 @@ def __init__(self, config: SnowflakeV2Config, report: SnowflakeV2Report) -> None
def get_workunits(
self, discovered_datasets: List[str]
) -> Iterable[MetadataWorkUnit]:
conn = self.config.get_connection()
try:
conn = self.get_connection()
except Exception as e:
if isinstance(e, SnowflakePermissionError):
self.report_error("permission-error", str(e))
else:
logger.debug(e, exc_info=e)
self.report_error(
"snowflake-connection",
f"Failed to connect to snowflake instance due to error {e}.",
)
return

if self.report.edition == SnowflakeEdition.STANDARD.value:
logger.info(
"Snowflake Account is Standard Edition. Usage Feature is not supported."
)
return

logger.info("Checking usage date ranges")

self._check_usage_date_ranges(conn)

# If permission error, execution returns from here
if (
self.report.min_access_history_time is None
or self.report.max_access_history_time is None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
from enum import Enum
from typing import Any, Optional

from snowflake.connector import SnowflakeConnection
Expand All @@ -12,8 +11,7 @@
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.snowflake.constants import (
SNOWFLAKE_DEFAULT_CLOUD,
SNOWFLAKE_DEFAULT_CLOUD_REGION_ID,
SnowflakeCloudProvider,
SNOWFLAKE_REGION_CLOUD_REGION_MAPPING,
)
from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config
from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report
Expand All @@ -27,30 +25,6 @@ class SnowflakePermissionError(MetaError):
"""A permission error has happened"""


class SnowflakeEdition(str, Enum):
    """Snowflake account edition.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    STANDARD = "Standard"

    # We use this to represent Enterprise Edition or higher
    ENTERPRISE = "Enterprise or above"


class SnowflakeCloudProvider(str, Enum):
    """Cloud platform identifiers used for Snowflake accounts.

    Members are also ``str`` instances, so they compare equal to their
    lowercase string values.
    """

    AWS = "aws"
    GCP = "gcp"
    AZURE = "azure"


# See https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#region-ids
# Includes only exceptions to format <provider>_<cloud region with hyphen replaced by _>
# Key: Snowflake region id, written as <provider>_<region>.
# Value: (cloud provider enum member, region string for that cloud) pair.
SNOWFLAKE_REGION_CLOUD_REGION_MAPPING = {
"aws_us_east_1_gov": (SnowflakeCloudProvider.AWS, "us-east-1"),
"azure_uksouth": (SnowflakeCloudProvider.AZURE, "uk-south"),
"azure_centralindia": (SnowflakeCloudProvider.AZURE, "central-india.azure"),
}

SNOWFLAKE_DEFAULT_CLOUD = SnowflakeCloudProvider.AWS


# Required only for mypy, since we are using mixin classes, and not inheritance.
# Reference - https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes
class SnowflakeLoggingProtocol(Protocol):
Expand Down Expand Up @@ -125,6 +99,19 @@ def get_cloud_region_from_snowflake_region_id(region):
raise Exception(f"Unknown snowflake region {region}")
return cloud, cloud_region_id

def get_connection(self: SnowflakeCommonProtocol) -> SnowflakeConnection:
    """Open a Snowflake connection via ``self.config.get_connection()``.

    Returns:
        The connection object produced by the config.

    Raises:
        SnowflakePermissionError: when the connection attempt fails with a
            message containing "not granted to this user". The original
            exception is chained as the cause.
        Exception: any other failure is re-raised unchanged.
    """
    try:
        conn = self.config.get_connection()
    except Exception as e:
        # Sample message this branch is meant to recognise:
        # 250001 (08001): Failed to connect to DB: xxxx.snowflakecomputing.com:443. Role 'XXXXX' specified in the connect string is not granted to this user. Contact your local system administrator, or attempt to login with another role, e.g. PUBLIC.
        if "not granted to this user" in str(e):
            raise SnowflakePermissionError(
                f"Failed to connect with snowflake due to error {e}"
            ) from e
        # Not a role/permission problem: propagate the original exception as-is.
        raise
    else:
        return conn

def _is_dataset_pattern_allowed(
self: SnowflakeCommonProtocol,
dataset_name: Optional[str],
Expand Down Expand Up @@ -252,5 +239,6 @@ def report_error(self: SnowflakeCommonProtocol, key: str, reason: str) -> None:

def is_permission_error(e: Exception) -> bool:
    """Return True when the exception text looks like a permission failure.

    The check is a case-sensitive substring match on ``str(e)`` for either
    "Insufficient privileges" or "not authorized". Sample messages:

    * ``Database 'XXXX' does not exist or not authorized.``
    * ``002003 (02000): SQL compilation error: Database/SCHEMA 'XXXX' does not exist or not authorized.``
    * ``Insufficient privileges to operate on database 'XXXX'``
    """
    msg = str(e)
    return "Insufficient privileges" in msg or "not authorized" in msg
Loading

0 comments on commit 7da9ad8

Please sign in to comment.