Skip to content

Commit

Permalink
feat(ingest/snowflake): performance improvements (#10746)
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 authored and yoonhyejin committed Jul 16, 2024
1 parent 5069411 commit d466909
Show file tree
Hide file tree
Showing 12 changed files with 294 additions and 178 deletions.
3 changes: 2 additions & 1 deletion metadata-ingestion/src/datahub/ingestion/run/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from datahub.utilities.lossy_collections import LossyDict, LossyList

logger = logging.getLogger(__name__)
_REPORT_PRINT_INTERVAL_SECONDS = 60


class LoggingCallback(WriteCallback):
Expand Down Expand Up @@ -403,7 +404,7 @@ def create(
def _time_to_print(self) -> bool:
self.num_intermediate_workunits += 1
current_time = int(time.time())
if current_time - self.last_time_printed > 10:
if current_time - self.last_time_printed > _REPORT_PRINT_INTERVAL_SECONDS:
# we print
self.num_intermediate_workunits = 0
self.last_time_printed = current_time
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ class SnowflakeV2Config(
description="If enabled, populates the snowflake technical schema and descriptions.",
)

include_primary_keys: bool = Field(
default=True,
description="If enabled, populates the snowflake primary keys.",
)
include_foreign_keys: bool = Field(
default=True,
description="If enabled, populates the snowflake foreign keys.",
)

include_column_lineage: bool = Field(
default=True,
description="Populates table->table and view->table column lineage. Requires appropriate grants given to the role and the Snowflake Enterprise Edition or above.",
Expand All @@ -105,6 +114,12 @@ class SnowflakeV2Config(
description="Populates view->view and table->view column lineage using DataHub's sql parser.",
)

lazy_schema_resolver: bool = Field(
default=False,
description="If enabled, uses lazy schema resolver to resolve schemas for tables and views. "
"This is useful if you have a large number of schemas and want to avoid bulk fetching the schema for each table/view.",
)

_check_role_grants_removed = pydantic_removed_field("check_role_grants")
_provision_role_removed = pydantic_removed_field("provision_role")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datahub.configuration.time_window_config import BucketDuration
from datahub.ingestion.source.snowflake.constants import SnowflakeObjectDomain
from datahub.ingestion.source.snowflake.snowflake_config import DEFAULT_TABLES_DENY_LIST
from datahub.utilities.prefix_batch_builder import PrefixGroup

SHOW_VIEWS_MAX_PAGE_SIZE = 10000

Expand Down Expand Up @@ -228,50 +229,50 @@ def show_views_for_database(
"""

@staticmethod
def columns_for_schema(schema_name: str, db_name: Optional[str]) -> str:
db_clause = f'"{db_name}".' if db_name is not None else ""
return f"""
select
table_catalog AS "TABLE_CATALOG",
table_schema AS "TABLE_SCHEMA",
table_name AS "TABLE_NAME",
column_name AS "COLUMN_NAME",
ordinal_position AS "ORDINAL_POSITION",
is_nullable AS "IS_NULLABLE",
data_type AS "DATA_TYPE",
comment AS "COMMENT",
character_maximum_length AS "CHARACTER_MAXIMUM_LENGTH",
numeric_precision AS "NUMERIC_PRECISION",
numeric_scale AS "NUMERIC_SCALE",
column_default AS "COLUMN_DEFAULT",
is_identity AS "IS_IDENTITY"
from {db_clause}information_schema.columns
WHERE table_schema='{schema_name}'
ORDER BY ordinal_position"""

@staticmethod
def columns_for_table(
table_name: str, schema_name: str, db_name: Optional[str]
def columns_for_schema(
schema_name: str,
db_name: str,
prefix_groups: Optional[List[PrefixGroup]] = None,
) -> str:
db_clause = f'"{db_name}".' if db_name is not None else ""
return f"""
select
table_catalog AS "TABLE_CATALOG",
table_schema AS "TABLE_SCHEMA",
table_name AS "TABLE_NAME",
column_name AS "COLUMN_NAME",
ordinal_position AS "ORDINAL_POSITION",
is_nullable AS "IS_NULLABLE",
data_type AS "DATA_TYPE",
comment AS "COMMENT",
character_maximum_length AS "CHARACTER_MAXIMUM_LENGTH",
numeric_precision AS "NUMERIC_PRECISION",
numeric_scale AS "NUMERIC_SCALE",
column_default AS "COLUMN_DEFAULT",
is_identity AS "IS_IDENTITY"
from {db_clause}information_schema.columns
WHERE table_schema='{schema_name}' and table_name='{table_name}'
ORDER BY ordinal_position"""
columns_template = """\
SELECT
table_catalog AS "TABLE_CATALOG",
table_schema AS "TABLE_SCHEMA",
table_name AS "TABLE_NAME",
column_name AS "COLUMN_NAME",
ordinal_position AS "ORDINAL_POSITION",
is_nullable AS "IS_NULLABLE",
data_type AS "DATA_TYPE",
comment AS "COMMENT",
character_maximum_length AS "CHARACTER_MAXIMUM_LENGTH",
numeric_precision AS "NUMERIC_PRECISION",
numeric_scale AS "NUMERIC_SCALE",
column_default AS "COLUMN_DEFAULT",
is_identity AS "IS_IDENTITY"
FROM "{db_name}".information_schema.columns
WHERE table_schema='{schema_name}' AND {extra_clause}"""

selects = []
if prefix_groups is None:
prefix_groups = [PrefixGroup(prefix="", names=[])]
for prefix_group in prefix_groups:
if prefix_group.prefix == "":
extra_clause = "TRUE"
elif prefix_group.exact_match:
extra_clause = f"table_name = '{prefix_group.prefix}'"
else:
extra_clause = f"table_name LIKE '{prefix_group.prefix}%'"

selects.append(
columns_template.format(
db_name=db_name, schema_name=schema_name, extra_clause=extra_clause
)
)

return (
"\nUNION ALL\n".join(selects)
+ """\nORDER BY table_name, ordinal_position"""
)

@staticmethod
def show_primary_keys_for_schema(schema_name: str, db_name: str) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ class SnowflakeV2Report(
# "Information schema query returned too much data. Please repeat query with more selective predicates."
# This will result in an overall increase in time complexity
num_get_tables_for_schema_queries: int = 0
num_get_columns_for_table_queries: int = 0

# these will be non-zero if the user chooses to enable the extract_tags = "with_lineage" option, which requires
# individual queries per object (database, schema, table) and an extra query per table to get the tags on the columns.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import Callable, Dict, List, Optional
from typing import Callable, Dict, Iterable, List, MutableMapping, Optional

from snowflake.connector import SnowflakeConnection

Expand All @@ -15,6 +15,8 @@
)
from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeQueryMixin
from datahub.ingestion.source.sql.sql_generic import BaseColumn, BaseTable, BaseView
from datahub.utilities.file_backed_collections import FileBackedDict
from datahub.utilities.prefix_batch_builder import build_prefix_batches
from datahub.utilities.serialized_lru_cache import serialized_lru_cache

logger: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -379,58 +381,48 @@ def get_views_for_database(self, db_name: str) -> Dict[str, List[SnowflakeView]]

@serialized_lru_cache(maxsize=SCHEMA_PARALLELISM)
def get_columns_for_schema(
self, schema_name: str, db_name: str
) -> Optional[Dict[str, List[SnowflakeColumn]]]:
columns: Dict[str, List[SnowflakeColumn]] = {}
try:
cur = self.query(SnowflakeQuery.columns_for_schema(schema_name, db_name))
except Exception as e:
logger.debug(
f"Failed to get all columns for schema - {schema_name}", exc_info=e
)
# Error - Information schema query returned too much data.
# Please repeat query with more selective predicates.
return None

for column in cur:
if column["TABLE_NAME"] not in columns:
columns[column["TABLE_NAME"]] = []
columns[column["TABLE_NAME"]].append(
SnowflakeColumn(
name=column["COLUMN_NAME"],
ordinal_position=column["ORDINAL_POSITION"],
is_nullable=column["IS_NULLABLE"] == "YES",
data_type=column["DATA_TYPE"],
comment=column["COMMENT"],
character_maximum_length=column["CHARACTER_MAXIMUM_LENGTH"],
numeric_precision=column["NUMERIC_PRECISION"],
numeric_scale=column["NUMERIC_SCALE"],
self,
schema_name: str,
db_name: str,
# HACK: This key is excluded from the cache key.
cache_exclude_all_objects: Iterable[str],
) -> MutableMapping[str, List[SnowflakeColumn]]:
all_objects = list(cache_exclude_all_objects)

columns: MutableMapping[str, List[SnowflakeColumn]] = {}
if len(all_objects) > 10000:
# For massive schemas, use a FileBackedDict to avoid memory issues.
columns = FileBackedDict()

object_batches = build_prefix_batches(
all_objects, max_batch_size=10000, max_groups_in_batch=5
)
for batch_index, object_batch in enumerate(object_batches):
if batch_index > 0:
logger.info(
f"Still fetching columns for {db_name}.{schema_name} - batch {batch_index + 1} of {len(object_batches)}"
)
query = SnowflakeQuery.columns_for_schema(
schema_name, db_name, object_batch
)
return columns

def get_columns_for_table(
self, table_name: str, schema_name: str, db_name: str
) -> List[SnowflakeColumn]:
columns: List[SnowflakeColumn] = []

cur = self.query(
SnowflakeQuery.columns_for_table(table_name, schema_name, db_name),
)

for column in cur:
columns.append(
SnowflakeColumn(
name=column["COLUMN_NAME"],
ordinal_position=column["ORDINAL_POSITION"],
is_nullable=column["IS_NULLABLE"] == "YES",
data_type=column["DATA_TYPE"],
comment=column["COMMENT"],
character_maximum_length=column["CHARACTER_MAXIMUM_LENGTH"],
numeric_precision=column["NUMERIC_PRECISION"],
numeric_scale=column["NUMERIC_SCALE"],
cur = self.query(query)

for column in cur:
if column["TABLE_NAME"] not in columns:
columns[column["TABLE_NAME"]] = []
columns[column["TABLE_NAME"]].append(
SnowflakeColumn(
name=column["COLUMN_NAME"],
ordinal_position=column["ORDINAL_POSITION"],
is_nullable=column["IS_NULLABLE"] == "YES",
data_type=column["DATA_TYPE"],
comment=column["COMMENT"],
character_maximum_length=column["CHARACTER_MAXIMUM_LENGTH"],
numeric_precision=column["NUMERIC_PRECISION"],
numeric_scale=column["NUMERIC_SCALE"],
)
)
)
return columns

@serialized_lru_cache(maxsize=SCHEMA_PARALLELISM)
Expand Down
Loading

0 comments on commit d466909

Please sign in to comment.