diff --git a/metadata-ingestion/docs/sources/snowflake/README.md b/metadata-ingestion/docs/sources/snowflake/README.md
index 1ca05f5ec4887..eb65620a9c2b6 100644
--- a/metadata-ingestion/docs/sources/snowflake/README.md
+++ b/metadata-ingestion/docs/sources/snowflake/README.md
@@ -1,4 +1,4 @@
 To get all metadata from Snowflake you need to use two plugins `snowflake` and `snowflake-usage`. Both of them are described in this page. These will require 2 separate recipes.
 
-We encourage you to try out new `snowflake-beta` plugin as alternative to running both `snowflake` and `snowflake-usage` plugins and share feedback. `snowflake-beta` is much faster than `snowflake` for extracting metadata . Please note that, `snowflake-beta` plugin currently does not support column level profiling, unlike `snowflake` plugin.
\ No newline at end of file
+We encourage you to try out the new `snowflake-beta` plugin as an alternative to running both the `snowflake` and `snowflake-usage` plugins and share feedback. `snowflake-beta` is much faster than `snowflake` for extracting metadata.
\ No newline at end of file
diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py
index 2d7737bf8385c..551b5443c3d5d 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py
@@ -1,6 +1,7 @@
+import dataclasses
 import datetime
 import logging
-from typing import Callable, Dict, Iterable, List, Optional
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, cast
 
 from sqlalchemy import create_engine, inspect
 
@@ -19,10 +20,17 @@
 )
 from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeCommonMixin
 from datahub.metadata.com.linkedin.pegasus2avro.dataset import DatasetProfile
+from datahub.metadata.schema_classes import DatasetProfileClass
 
 logger = logging.getLogger(__name__)
 
 
+@dataclasses.dataclass
+class SnowflakeProfilerRequest(GEProfilerRequest):
+    table: SnowflakeTable
+    profile_table_level_only: bool = False
+
+
 class SnowflakeProfiler(SnowflakeCommonMixin):
     def __init__(self, config: SnowflakeV2Config, report: SnowflakeV2Report) -> None:
         self.config = config
@@ -31,12 +39,6 @@ def __init__(self, config: SnowflakeV2Config, report: SnowflakeV2Report) -> None
 
     def get_workunits(self, databases: List[SnowflakeDatabase]) -> Iterable[WorkUnit]:
 
-        # If only table level profiling is enabled, report table profile and exit
-        if self.config.profiling.profile_table_level_only:
-
-            yield from self.get_table_level_profile_workunits(databases)
-            return
-
         # Extra default SQLAlchemy option for better connection pooling and threading.
         # https://docs.sqlalchemy.org/en/14/core/pooling.html#sqlalchemy.pool.QueuePool.params.max_overflow
         if self.config.profiling.enabled:
@@ -55,7 +57,7 @@ def get_workunits(self, databases: List[SnowflakeDatabase]) -> Iterable[WorkUnit
                     continue
                 for table in schema.tables:
                     # Emit the profile work unit
-                    profile_request = self.get_ge_profile_request(
+                    profile_request = self.get_snowflake_profile_request(
                         table, schema.name, db.name
                     )
                     if profile_request is not None:
@@ -63,13 +65,14 @@ def get_workunits(self, databases: List[SnowflakeDatabase]) -> Iterable[WorkUnit
                 if len(profile_requests) == 0:
                     continue
 
-                ge_profiler = self.get_profiler_instance(db.name)
-                for request, profile in ge_profiler.generate_profiles(
+                for request, profile in self.generate_profiles(
+                    db.name,
                     profile_requests,
                     self.config.profiling.max_workers,
                     platform=self.platform,
                     profiler_args=self.get_profile_args(),
                 ):
                     if profile is None:
                         continue
+                    profile.sizeInBytes = request.table.size_in_bytes  # type:ignore
                     dataset_name = request.pretty_name
@@ -86,68 +89,26 @@ def get_workunits(self, databases: List[SnowflakeDatabase]) -> Iterable[WorkUnit
                         profile,
                     )
 
-    def get_table_level_profile_workunits(
-        self, databases: List[SnowflakeDatabase]
-    ) -> Iterable[WorkUnit]:
-        for db in databases:
-            if not self.config.database_pattern.allowed(db.name):
-                continue
-            for schema in db.schemas:
-                if not self.config.schema_pattern.allowed(schema.name):
-                    continue
-                for table in schema.tables:
-                    dataset_name = self.get_dataset_identifier(
-                        table.name, schema.name, db.name
-                    )
-                    skip_profiling = False
-                    # no need to filter by size_in_bytes and row_count limits,
-                    # if table level profilin, since its not expensive
-                    if not self.is_dataset_eligible_for_profiling(
-                        dataset_name,
-                        table.last_altered,
-                        0,
-                        0,
-                    ):
-                        skip_profiling = True
-
-                    if skip_profiling:
-                        if self.config.profiling.report_dropped_profiles:
-                            self.report.report_dropped(f"profile of {dataset_name}")
-                        return None
-
-                    self.report.report_entity_profiled(dataset_name)
-
-                    dataset_urn = make_dataset_urn_with_platform_instance(
-                        self.platform,
-                        dataset_name,
-                        self.config.platform_instance,
-                        self.config.env,
-                    )
-                    yield self.wrap_aspect_as_workunit(
-                        "dataset",
-                        dataset_urn,
-                        "datasetProfile",
-                        DatasetProfile(
-                            timestampMillis=round(
-                                datetime.datetime.now().timestamp() * 1000
-                            ),
-                            columnCount=len(table.columns),
-                            rowCount=table.rows_count,
-                        ),
-                    )
-
-    def get_ge_profile_request(
+    def get_snowflake_profile_request(
         self,
         table: SnowflakeTable,
        schema_name: str,
         db_name: str,
-    ) -> Optional[GEProfilerRequest]:
+    ) -> Optional[SnowflakeProfilerRequest]:
         skip_profiling = False
+        profile_table_level_only = self.config.profiling.profile_table_level_only
         dataset_name = self.get_dataset_identifier(table.name, schema_name, db_name)
         if not self.is_dataset_eligible_for_profiling(
             dataset_name, table.last_altered, table.size_in_bytes, table.rows_count
         ):
-            skip_profiling = True
+            # Profile only at the table level if the dataset is filtered from
+            # profiling due to size limits alone
+            if self.is_dataset_eligible_for_profiling(
+                dataset_name, table.last_altered, 0, 0
+            ):
+                profile_table_level_only = True
+            else:
+                skip_profiling = True
 
         if len(table.columns) == 0:
             skip_profiling = True
@@ -159,9 +120,11 @@ def get_ge_profile_request(
 
         self.report.report_entity_profiled(dataset_name)
         logger.debug(f"Preparing profiling request for {dataset_name}")
-        profile_request = GEProfilerRequest(
+        profile_request = SnowflakeProfilerRequest(
             pretty_name=dataset_name,
             batch_kwargs=dict(schema=schema_name, table=table.name),
+            table=table,
+            profile_table_level_only=profile_table_level_only,
         )
         return profile_request
 
@@ -237,3 +200,37 @@ def get_db_connection():
             return conn
 
         return get_db_connection
+
+    def generate_profiles(
+        self,
+        db_name: str,
+        requests: List[SnowflakeProfilerRequest],
+        max_workers: int,
+        platform: Optional[str] = None,
+        profiler_args: Optional[Dict] = None,
+    ) -> Iterable[Tuple[GEProfilerRequest, Optional[DatasetProfileClass]]]:
+
+        ge_profile_requests: List[GEProfilerRequest] = [
+            cast(GEProfilerRequest, request)
+            for request in requests
+            if not request.profile_table_level_only
+        ]
+        table_level_profile_requests: List[SnowflakeProfilerRequest] = [
+            request for request in requests if request.profile_table_level_only
+        ]
+        for request in table_level_profile_requests:
+            profile = DatasetProfile(
+                timestampMillis=round(datetime.datetime.now().timestamp() * 1000),
+                columnCount=len(request.table.columns),
+                rowCount=request.table.rows_count,
+                sizeInBytes=request.table.size_in_bytes,
+            )
+            yield (request, profile)
+
+        if len(ge_profile_requests) == 0:
+            return
+
+        ge_profiler = self.get_profiler_instance(db_name)
+        yield from ge_profiler.generate_profiles(
+            ge_profile_requests, max_workers, platform, profiler_args
+        )
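
For reference, a minimal standalone sketch of the request-splitting idea behind the new `generate_profiles` method above. The `Table`, `ProfileRequest`, and `Profile` classes here are simplified stand-ins for illustration only (not the DataHub `SnowflakeTable`, `SnowflakeProfilerRequest`, or `DatasetProfile` types): requests flagged `profile_table_level_only` get a cheap profile built from metadata that is already in hand, and only the remaining requests would be handed off to the GE profiler.

```python
import datetime
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple


@dataclass
class Table:  # simplified stand-in for SnowflakeTable
    name: str
    columns: List[str]
    rows_count: int
    size_in_bytes: int


@dataclass
class ProfileRequest:  # simplified stand-in for SnowflakeProfilerRequest
    pretty_name: str
    table: Table
    profile_table_level_only: bool = False


@dataclass
class Profile:  # simplified stand-in for DatasetProfile
    timestampMillis: int
    columnCount: int
    rowCount: int
    sizeInBytes: Optional[int] = None


def generate_profiles(
    requests: List[ProfileRequest],
) -> Iterable[Tuple[ProfileRequest, Optional[Profile]]]:
    # Cheap path: build table-level profiles from metadata we already have.
    for request in requests:
        if request.profile_table_level_only:
            yield request, Profile(
                timestampMillis=round(datetime.datetime.now().timestamp() * 1000),
                columnCount=len(request.table.columns),
                rowCount=request.table.rows_count,
                sizeInBytes=request.table.size_in_bytes,
            )

    # Expensive path: everything else would go to the GE profiler
    # (ge_profiler.generate_profiles in the real module); stubbed out here.
    for request in requests:
        if not request.profile_table_level_only:
            yield request, None  # placeholder for a full column-level profile


if __name__ == "__main__":
    reqs = [
        ProfileRequest("db.schema.small", Table("small", ["a", "b"], 10, 1024)),
        ProfileRequest(
            "db.schema.huge",
            Table("huge", ["a"], 10**9, 10**12),
            profile_table_level_only=True,  # too big for column-level profiling
        ),
    ]
    for req, prof in generate_profiles(reqs):
        print(req.pretty_name, prof)
```

The design keeps the expensive GE profiling path untouched while still emitting row and column counts for tables that would otherwise be dropped by the configured size or row-count limits.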