diff --git a/metadata-ingestion/src/datahub/ingestion/run/pipeline.py b/metadata-ingestion/src/datahub/ingestion/run/pipeline.py index 75dabc4a7e02ae..dc7521f8f8c555 100644 --- a/metadata-ingestion/src/datahub/ingestion/run/pipeline.py +++ b/metadata-ingestion/src/datahub/ingestion/run/pipeline.py @@ -49,6 +49,7 @@ from datahub.utilities.lossy_collections import LossyDict, LossyList logger = logging.getLogger(__name__) +_REPORT_PRINT_INTERVAL_SECONDS = 60 class LoggingCallback(WriteCallback): @@ -403,7 +404,7 @@ def create( def _time_to_print(self) -> bool: self.num_intermediate_workunits += 1 current_time = int(time.time()) - if current_time - self.last_time_printed > 10: + if current_time - self.last_time_printed > _REPORT_PRINT_INTERVAL_SECONDS: # we print self.num_intermediate_workunits = 0 self.last_time_printed = current_time diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py index 4beb2684485694..588187e8e11c28 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py @@ -95,6 +95,15 @@ class SnowflakeV2Config( description="If enabled, populates the snowflake technical schema and descriptions.", ) + include_primary_keys: bool = Field( + default=True, + description="If enabled, populates the snowflake primary keys.", + ) + include_foreign_keys: bool = Field( + default=True, + description="If enabled, populates the snowflake foreign keys.", + ) + include_column_lineage: bool = Field( default=True, description="Populates table->table and view->table column lineage. Requires appropriate grants given to the role and the Snowflake Enterprise Edition or above.", @@ -105,6 +114,12 @@ class SnowflakeV2Config( description="Populates view->view and table->view column lineage using DataHub's sql parser.", ) + lazy_schema_resolver: bool = Field( + default=False, + description="If enabled, uses lazy schema resolver to resolve schemas for tables and views. " + "This is useful if you have a large number of schemas and want to avoid bulk fetching the schema for each table/view.", + ) + _check_role_grants_removed = pydantic_removed_field("check_role_grants") _provision_role_removed = pydantic_removed_field("provision_role") diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py index b3eb23b25e0a37..c4b6f597bbb7e5 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py @@ -4,6 +4,7 @@ from datahub.configuration.time_window_config import BucketDuration from datahub.ingestion.source.snowflake.constants import SnowflakeObjectDomain from datahub.ingestion.source.snowflake.snowflake_config import DEFAULT_TABLES_DENY_LIST +from datahub.utilities.prefix_batch_builder import PrefixGroup SHOW_VIEWS_MAX_PAGE_SIZE = 10000 @@ -228,50 +229,50 @@ def show_views_for_database( """ @staticmethod - def columns_for_schema(schema_name: str, db_name: Optional[str]) -> str: - db_clause = f'"{db_name}".' if db_name is not None else "" - return f""" - select - table_catalog AS "TABLE_CATALOG", - table_schema AS "TABLE_SCHEMA", - table_name AS "TABLE_NAME", - column_name AS "COLUMN_NAME", - ordinal_position AS "ORDINAL_POSITION", - is_nullable AS "IS_NULLABLE", - data_type AS "DATA_TYPE", - comment AS "COMMENT", - character_maximum_length AS "CHARACTER_MAXIMUM_LENGTH", - numeric_precision AS "NUMERIC_PRECISION", - numeric_scale AS "NUMERIC_SCALE", - column_default AS "COLUMN_DEFAULT", - is_identity AS "IS_IDENTITY" - from {db_clause}information_schema.columns - WHERE table_schema='{schema_name}' - ORDER BY ordinal_position""" - - @staticmethod - def columns_for_table( - table_name: str, schema_name: str, db_name: Optional[str] + def columns_for_schema( + schema_name: str, + db_name: str, + prefix_groups: Optional[List[PrefixGroup]] = None, ) -> str: - db_clause = f'"{db_name}".' if db_name is not None else "" - return f""" - select - table_catalog AS "TABLE_CATALOG", - table_schema AS "TABLE_SCHEMA", - table_name AS "TABLE_NAME", - column_name AS "COLUMN_NAME", - ordinal_position AS "ORDINAL_POSITION", - is_nullable AS "IS_NULLABLE", - data_type AS "DATA_TYPE", - comment AS "COMMENT", - character_maximum_length AS "CHARACTER_MAXIMUM_LENGTH", - numeric_precision AS "NUMERIC_PRECISION", - numeric_scale AS "NUMERIC_SCALE", - column_default AS "COLUMN_DEFAULT", - is_identity AS "IS_IDENTITY" - from {db_clause}information_schema.columns - WHERE table_schema='{schema_name}' and table_name='{table_name}' - ORDER BY ordinal_position""" + columns_template = """\ +SELECT + table_catalog AS "TABLE_CATALOG", + table_schema AS "TABLE_SCHEMA", + table_name AS "TABLE_NAME", + column_name AS "COLUMN_NAME", + ordinal_position AS "ORDINAL_POSITION", + is_nullable AS "IS_NULLABLE", + data_type AS "DATA_TYPE", + comment AS "COMMENT", + character_maximum_length AS "CHARACTER_MAXIMUM_LENGTH", + numeric_precision AS "NUMERIC_PRECISION", + numeric_scale AS "NUMERIC_SCALE", + column_default AS "COLUMN_DEFAULT", + is_identity AS "IS_IDENTITY" +FROM "{db_name}".information_schema.columns +WHERE table_schema='{schema_name}' AND {extra_clause}""" + + selects = [] + if prefix_groups is None: + prefix_groups = [PrefixGroup(prefix="", names=[])] + for prefix_group in prefix_groups: + if prefix_group.prefix == "": + extra_clause = "TRUE" + elif prefix_group.exact_match: + extra_clause = f"table_name = '{prefix_group.prefix}'" + else: + extra_clause = f"table_name LIKE '{prefix_group.prefix}%'" + + selects.append( + columns_template.format( + db_name=db_name, schema_name=schema_name, extra_clause=extra_clause + ) + ) + + return ( + "\nUNION ALL\n".join(selects) + + """\nORDER BY table_name, ordinal_position""" + ) @staticmethod def show_primary_keys_for_schema(schema_name: str, db_name: str) -> str: diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py index d84580a94ab4e4..4924546383aa43 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py @@ -117,7 +117,6 @@ class SnowflakeV2Report( # "Information schema query returned too much data. Please repeat query with more selective predicates."" # This will result in overall increase in time complexity num_get_tables_for_schema_queries: int = 0 - num_get_columns_for_table_queries: int = 0 # these will be non-zero if the user choses to enable the extract_tags = "with_lineage" option, which requires # individual queries per object (database, schema, table) and an extra query per table to get the tags on the columns. diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py index 3254224e437a6e..4bc684a22514c4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py @@ -3,7 +3,7 @@ from collections import defaultdict from dataclasses import dataclass, field from datetime import datetime -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, Iterable, List, MutableMapping, Optional from snowflake.connector import SnowflakeConnection @@ -15,6 +15,8 @@ ) from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeQueryMixin from datahub.ingestion.source.sql.sql_generic import BaseColumn, BaseTable, BaseView +from datahub.utilities.file_backed_collections import FileBackedDict +from datahub.utilities.prefix_batch_builder import build_prefix_batches from datahub.utilities.serialized_lru_cache import serialized_lru_cache logger: logging.Logger = logging.getLogger(__name__) @@ -379,58 +381,48 @@ def get_views_for_database(self, db_name: str) -> Dict[str, List[SnowflakeView]] @serialized_lru_cache(maxsize=SCHEMA_PARALLELISM) def get_columns_for_schema( - self, schema_name: str, db_name: str - ) -> Optional[Dict[str, List[SnowflakeColumn]]]: - columns: Dict[str, List[SnowflakeColumn]] = {} - try: - cur = self.query(SnowflakeQuery.columns_for_schema(schema_name, db_name)) - except Exception as e: - logger.debug( - f"Failed to get all columns for schema - {schema_name}", exc_info=e - ) - # Error - Information schema query returned too much data. - # Please repeat query with more selective predicates. - return None - - for column in cur: - if column["TABLE_NAME"] not in columns: - columns[column["TABLE_NAME"]] = [] - columns[column["TABLE_NAME"]].append( - SnowflakeColumn( - name=column["COLUMN_NAME"], - ordinal_position=column["ORDINAL_POSITION"], - is_nullable=column["IS_NULLABLE"] == "YES", - data_type=column["DATA_TYPE"], - comment=column["COMMENT"], - character_maximum_length=column["CHARACTER_MAXIMUM_LENGTH"], - numeric_precision=column["NUMERIC_PRECISION"], - numeric_scale=column["NUMERIC_SCALE"], + self, + schema_name: str, + db_name: str, + # HACK: This key is excluded from the cache key. + cache_exclude_all_objects: Iterable[str], + ) -> MutableMapping[str, List[SnowflakeColumn]]: + all_objects = list(cache_exclude_all_objects) + + columns: MutableMapping[str, List[SnowflakeColumn]] = {} + if len(all_objects) > 10000: + # For massive schemas, use a FileBackedDict to avoid memory issues. + columns = FileBackedDict() + + object_batches = build_prefix_batches( + all_objects, max_batch_size=10000, max_groups_in_batch=5 + ) + for batch_index, object_batch in enumerate(object_batches): + if batch_index > 0: + logger.info( + f"Still fetching columns for {db_name}.{schema_name} - batch {batch_index + 1} of {len(object_batches)}" ) + query = SnowflakeQuery.columns_for_schema( + schema_name, db_name, object_batch ) - return columns - - def get_columns_for_table( - self, table_name: str, schema_name: str, db_name: str - ) -> List[SnowflakeColumn]: - columns: List[SnowflakeColumn] = [] - cur = self.query( - SnowflakeQuery.columns_for_table(table_name, schema_name, db_name), - ) - - for column in cur: - columns.append( - SnowflakeColumn( - name=column["COLUMN_NAME"], - ordinal_position=column["ORDINAL_POSITION"], - is_nullable=column["IS_NULLABLE"] == "YES", - data_type=column["DATA_TYPE"], - comment=column["COMMENT"], - character_maximum_length=column["CHARACTER_MAXIMUM_LENGTH"], - numeric_precision=column["NUMERIC_PRECISION"], - numeric_scale=column["NUMERIC_SCALE"], + cur = self.query(query) + + for column in cur: + if column["TABLE_NAME"] not in columns: + columns[column["TABLE_NAME"]] = [] + columns[column["TABLE_NAME"]].append( + SnowflakeColumn( + name=column["COLUMN_NAME"], + ordinal_position=column["ORDINAL_POSITION"], + is_nullable=column["IS_NULLABLE"] == "YES", + data_type=column["DATA_TYPE"], + comment=column["COMMENT"], + character_maximum_length=column["CHARACTER_MAXIMUM_LENGTH"], + numeric_precision=column["NUMERIC_PRECISION"], + numeric_scale=column["NUMERIC_SCALE"], + ) ) - ) return columns @serialized_lru_cache(maxsize=SCHEMA_PARALLELISM) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py index 920cf741770c39..b6f16cd671b8d3 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py @@ -1,4 +1,5 @@ import concurrent.futures +import itertools import logging import queue from typing import Dict, Iterable, List, Optional, Union @@ -321,7 +322,7 @@ def _process_schema_worker(snowflake_schema: SnowflakeSchema) -> None: # Read from the queue and yield the work units until all futures are done. while True: - if q.empty(): + if not q.empty(): while not q.empty(): yield q.get_nowait() else: @@ -394,18 +395,24 @@ def _process_schema( if self.config.include_technical_schema: yield from self.gen_schema_containers(snowflake_schema, db_name) + # We need to do this first so that we can use it when fetching columns. if self.config.include_tables: tables = self.fetch_tables_for_schema( snowflake_schema, db_name, schema_name ) + if self.config.include_views: + views = self.fetch_views_for_schema(snowflake_schema, db_name, schema_name) + + if self.config.include_tables: db_tables[schema_name] = tables if self.config.include_technical_schema: data_reader = self.make_data_reader() for table in tables: table_wu_generator = self._process_table( - table, schema_name, db_name + table, snowflake_schema, db_name ) + yield from classification_workunit_processor( table_wu_generator, self.classification_handler, @@ -414,7 +421,6 @@ def _process_schema( ) if self.config.include_views: - views = self.fetch_views_for_schema(snowflake_schema, db_name, schema_name) if ( self.aggregator and self.config.include_view_lineage @@ -434,7 +440,7 @@ def _process_schema( if self.config.include_technical_schema: for view in views: - yield from self._process_view(view, schema_name, db_name) + yield from self._process_view(view, snowflake_schema, db_name) if self.config.include_technical_schema and snowflake_schema.tags: for tag in snowflake_schema.tags: @@ -522,16 +528,27 @@ def make_data_reader(self) -> Optional[SnowflakeDataReader]: def _process_table( self, table: SnowflakeTable, - schema_name: str, + snowflake_schema: SnowflakeSchema, db_name: str, ) -> Iterable[MetadataWorkUnit]: + schema_name = snowflake_schema.name table_identifier = self.get_dataset_identifier(table.name, schema_name, db_name) - self.fetch_columns_for_table(table, schema_name, db_name, table_identifier) - - self.fetch_pk_for_table(table, schema_name, db_name, table_identifier) - - self.fetch_foreign_keys_for_table(table, schema_name, db_name, table_identifier) + try: + table.columns = self.get_columns_for_table( + table.name, snowflake_schema, db_name + ) + table.column_count = len(table.columns) + if self.config.extract_tags != TagOption.skip: + table.column_tags = self.tag_extractor.get_column_tags_for_table( + table.name, schema_name, db_name + ) + except Exception as e: + logger.debug( + f"Failed to get columns for table {table_identifier} due to error {e}", + exc_info=e, + ) + self.report_warning("Failed to get columns for table", table_identifier) if self.config.extract_tags != TagOption.skip: table.tags = self.tag_extractor.get_tags_on_object( @@ -542,12 +559,13 @@ def _process_table( ) if self.config.include_technical_schema: - if table.tags: - for tag in table.tags: - yield from self._process_tag(tag) - for column_name in table.column_tags: - for tag in table.column_tags[column_name]: - yield from self._process_tag(tag) + if self.config.include_primary_keys: + self.fetch_pk_for_table(table, schema_name, db_name, table_identifier) + + if self.config.include_foreign_keys: + self.fetch_foreign_keys_for_table( + table, schema_name, db_name, table_identifier + ) yield from self.gen_dataset_workunits(table, schema_name, db_name) @@ -587,37 +605,19 @@ def fetch_pk_for_table( ) self.report_warning("Failed to get primary key for table", table_identifier) - def fetch_columns_for_table( - self, - table: SnowflakeTable, - schema_name: str, - db_name: str, - table_identifier: str, - ) -> None: - try: - table.columns = self.get_columns_for_table(table.name, schema_name, db_name) - table.column_count = len(table.columns) - if self.config.extract_tags != TagOption.skip: - table.column_tags = self.tag_extractor.get_column_tags_for_table( - table.name, schema_name, db_name - ) - except Exception as e: - logger.debug( - f"Failed to get columns for table {table_identifier} due to error {e}", - exc_info=e, - ) - self.report_warning("Failed to get columns for table", table_identifier) - def _process_view( self, view: SnowflakeView, - schema_name: str, + snowflake_schema: SnowflakeSchema, db_name: str, ) -> Iterable[MetadataWorkUnit]: + schema_name = snowflake_schema.name view_name = self.get_dataset_identifier(view.name, schema_name, db_name) try: - view.columns = self.get_columns_for_table(view.name, schema_name, db_name) + view.columns = self.get_columns_for_table( + view.name, snowflake_schema, db_name + ) if self.config.extract_tags != TagOption.skip: view.column_tags = self.tag_extractor.get_column_tags_for_table( view.name, schema_name, db_name @@ -638,13 +638,6 @@ def _process_view( ) if self.config.include_technical_schema: - if view.tags: - for tag in view.tags: - yield from self._process_tag(tag) - for column_name in view.column_tags: - for tag in view.column_tags[column_name]: - yield from self._process_tag(tag) - yield from self.gen_dataset_workunits(view, schema_name, db_name) def _process_tag(self, tag: SnowflakeTag) -> Iterable[MetadataWorkUnit]: @@ -663,6 +656,13 @@ def gen_dataset_workunits( schema_name: str, db_name: str, ) -> Iterable[MetadataWorkUnit]: + if table.tags: + for tag in table.tags: + yield from self._process_tag(tag) + for column_name in table.column_tags: + for tag in table.column_tags[column_name]: + yield from self._process_tag(tag) + dataset_name = self.get_dataset_identifier(table.name, schema_name, db_name) dataset_urn = self.gen_dataset_urn(dataset_name) @@ -1015,17 +1015,16 @@ def get_views_for_schema( return views.get(schema_name, []) def get_columns_for_table( - self, table_name: str, schema_name: str, db_name: str + self, table_name: str, snowflake_schema: SnowflakeSchema, db_name: str ) -> List[SnowflakeColumn]: - columns = self.data_dictionary.get_columns_for_schema(schema_name, db_name) - - # get all columns for schema failed, - # falling back to get columns for table - if columns is None: - self.report.num_get_columns_for_table_queries += 1 - return self.data_dictionary.get_columns_for_table( - table_name, schema_name, db_name - ) + schema_name = snowflake_schema.name + columns = self.data_dictionary.get_columns_for_schema( + schema_name, + db_name, + cache_exclude_all_objects=itertools.chain( + snowflake_schema.tables, snowflake_schema.views + ), + ) # Access to table but none of its columns - is this possible ? return columns.get(table_name, []) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py index 06d7042e02456c..f39620b79cfd43 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -158,6 +158,7 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): and self.config.include_tables and self.config.include_views ) + and not self.config.lazy_schema_resolver else None ), generate_usage_statistics=False, diff --git a/metadata-ingestion/src/datahub/utilities/prefix_batch_builder.py b/metadata-ingestion/src/datahub/utilities/prefix_batch_builder.py new file mode 100644 index 00000000000000..b6da7a1fbd1521 --- /dev/null +++ b/metadata-ingestion/src/datahub/utilities/prefix_batch_builder.py @@ -0,0 +1,81 @@ +import dataclasses +from collections import defaultdict +from typing import List + + +@dataclasses.dataclass +class PrefixGroup: + prefix: str + names: List[str] # every name in the list has the same prefix + exact_match: bool = False + + +def build_prefix_batches( + names: List[str], max_batch_size: int, max_groups_in_batch: int +) -> List[List[PrefixGroup]]: + """Split the names into a list of batches, where each batch is a list of groups and each group is a list of names with a common prefix.""" + + groups = _build_prefix_groups(names, max_batch_size=max_batch_size) + batches = _batch_prefix_groups( + groups, max_batch_size=max_batch_size, max_groups_in_batch=max_groups_in_batch + ) + return batches + + +def _build_prefix_groups(names: List[str], max_batch_size: int) -> List[PrefixGroup]: + """Given a list of names, group them by shared prefixes such that no group is larger than `max_batch_size`.""" + + def split_group(group: PrefixGroup) -> List[PrefixGroup]: + if len(group.names) <= max_batch_size: + return [group] + + result = [] + + # Split into subgroups by the next character. + prefix_length = len(group.prefix) + 1 + subgroups = defaultdict(list) + for name in group.names: + if len(name) <= prefix_length: + # Handle cases where a single name is also the prefix for a large number of names. + # For example, if NAME and NAME_{1..10000} are both in the list. + result.append(PrefixGroup(prefix=name, names=[name], exact_match=True)) + continue + + prefix = name[:prefix_length] + subgroups[prefix].append(name) + + for prefix, names in subgroups.items(): + result.extend(split_group(PrefixGroup(prefix=prefix, names=names))) + + return result + + return split_group(PrefixGroup(prefix="", names=sorted(names))) + + +def _batch_prefix_groups( + groups: List[PrefixGroup], max_batch_size: int, max_groups_in_batch: int +) -> List[List[PrefixGroup]]: + """Batch the groups together, so that no batch's total is larger than `max_batch_size` + and no group in a batch is larger than `max_group_size`.""" + + # A batch is a set of groups. + + # This is a variant of the 1D bin packing problem, which is actually NP-hard. + # However, we'll just use a greedy algorithm for simplicity. + + batches = [] + current_batch_size = 0 + batch: List[PrefixGroup] = [] + for group in groups: + if ( + current_batch_size + len(group.names) > max_batch_size + or len(batch) > max_groups_in_batch + ): + batches.append(batch) + batch = [] + current_batch_size = 0 + batch.append(group) + current_batch_size += len(group.names) + if batch: + batches.append(batch) + return batches diff --git a/metadata-ingestion/src/datahub/utilities/serialized_lru_cache.py b/metadata-ingestion/src/datahub/utilities/serialized_lru_cache.py index 23523501ee0b49..b5f490720340ce 100644 --- a/metadata-ingestion/src/datahub/utilities/serialized_lru_cache.py +++ b/metadata-ingestion/src/datahub/utilities/serialized_lru_cache.py @@ -41,7 +41,7 @@ def decorator(func: Callable[_F, _T]) -> Callable[_F, _T]: def wrapper(*args: _F.args, **kwargs: _F.kwargs) -> _T: # We need a type ignore here because there's no way for us to require that # the args and kwargs are hashable while using ParamSpec. - key: _Key = cachetools.keys.hashkey(*args, **kwargs) # type: ignore + key: _Key = cachetools.keys.hashkey(*args, **{k: v for k, v in kwargs.items() if "cache_exclude" not in k}) # type: ignore with cache_lock: if key in cache: diff --git a/metadata-ingestion/tests/integration/snowflake/common.py b/metadata-ingestion/tests/integration/snowflake/common.py index ea08a942674808..881fac96f82e8d 100644 --- a/metadata-ingestion/tests/integration/snowflake/common.py +++ b/metadata-ingestion/tests/integration/snowflake/common.py @@ -252,26 +252,11 @@ def default_query_results( # noqa: C901 for view_idx in range(1, num_views + 1) ] elif query == SnowflakeQuery.columns_for_schema("TEST_SCHEMA", "TEST_DB"): - raise Exception("Information schema query returned too much data") - elif query in [ - *[ - SnowflakeQuery.columns_for_table( - f"TABLE_{tbl_idx}", "TEST_SCHEMA", "TEST_DB" - ) - for tbl_idx in range(1, num_tables + 1) - ], - *[ - SnowflakeQuery.columns_for_table( - f"VIEW_{view_idx}", "TEST_SCHEMA", "TEST_DB" - ) - for view_idx in range(1, num_views + 1) - ], - ]: return [ { - # "TABLE_CATALOG": "TEST_DB", - # "TABLE_SCHEMA": "TEST_SCHEMA", - # "TABLE_NAME": "TABLE_{}".format(tbl_idx), + "TABLE_CATALOG": "TEST_DB", + "TABLE_SCHEMA": "TEST_SCHEMA", + "TABLE_NAME": table_name, "COLUMN_NAME": f"COL_{col_idx}", "ORDINAL_POSITION": col_idx, "IS_NULLABLE": "NO", @@ -281,6 +266,10 @@ def default_query_results( # noqa: C901 "NUMERIC_PRECISION": None if col_idx > 1 else 38, "NUMERIC_SCALE": None if col_idx > 1 else 0, } + for table_name in ( + [f"TABLE_{tbl_idx}" for tbl_idx in range(1, num_tables + 1)] + + [f"VIEW_{view_idx}" for view_idx in range(1, num_views + 1)] + ) for col_idx in range(1, num_cols + 1) ] elif query in ( diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake_failures.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake_failures.py index 3a37382de65b46..23f5c10b10f8e8 100644 --- a/metadata-ingestion/tests/integration/snowflake/test_snowflake_failures.py +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake_failures.py @@ -10,11 +10,7 @@ from datahub.ingestion.source.snowflake import snowflake_query from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery -from tests.integration.snowflake.common import ( - FROZEN_TIME, - NUM_TABLES, - default_query_results, -) +from tests.integration.snowflake.common import FROZEN_TIME, default_query_results def query_permission_error_override(fn, override_for_query, error_msg): @@ -168,10 +164,7 @@ def test_snowflake_list_columns_error_causes_pipeline_warning( sf_cursor.execute.side_effect = query_permission_error_override( default_query_results, [ - SnowflakeQuery.columns_for_table( - f"TABLE_{tbl_idx}", "TEST_SCHEMA", "TEST_DB" - ) - for tbl_idx in range(1, NUM_TABLES + 1) + SnowflakeQuery.columns_for_schema("TEST_SCHEMA", "TEST_DB"), ], "Database 'TEST_DB' does not exist or not authorized.", ) diff --git a/metadata-ingestion/tests/unit/utilities/test_prefix_patch_builder.py b/metadata-ingestion/tests/unit/utilities/test_prefix_patch_builder.py new file mode 100644 index 00000000000000..19af7e9f66c1ab --- /dev/null +++ b/metadata-ingestion/tests/unit/utilities/test_prefix_patch_builder.py @@ -0,0 +1,45 @@ +from datahub.utilities.prefix_batch_builder import PrefixGroup, build_prefix_batches + + +def test_build_prefix_batches_empty_input(): + assert build_prefix_batches([], 10, 5) == [[PrefixGroup(prefix="", names=[])]] + + +def test_build_prefix_batches_single_group(): + names = ["apple", "applet", "application"] + expected = [[PrefixGroup(prefix="", names=names)]] + assert build_prefix_batches(names, 10, 5) == expected + + +def test_build_prefix_batches_multiple_groups(): + names = ["apple", "applet", "banana", "band", "bandana"] + expected = [ + [PrefixGroup(prefix="a", names=["apple", "applet"])], + [PrefixGroup(prefix="b", names=["banana", "band", "bandana"])], + ] + assert build_prefix_batches(names, 4, 5) == expected + + +def test_build_prefix_batches_exceeds_max_batch_size(): + names = [ + "app", + "apple", + "applet", + "application", + "banana", + "band", + "bandana", + "candy", + "candle", + "dog", + ] + expected = [ + [PrefixGroup(prefix="app", names=["app"], exact_match=True)], + [PrefixGroup(prefix="app", names=["apple", "applet", "application"])], + [PrefixGroup(prefix="b", names=["banana", "band", "bandana"])], + [ + PrefixGroup(prefix="c", names=["candle", "candy"]), + PrefixGroup(prefix="d", names=["dog"]), + ], + ] + assert build_prefix_batches(names, 3, 2) == expected