feat(model, ingest): sizeInBytes in datasetProfile, populate size in snowflake
mayurinehate committed Aug 30, 2022
1 parent ee43262 commit 6b3d679
Showing 2 changed files with 66 additions and 63 deletions.
@@ -1,6 +1,7 @@
import dataclasses
import datetime
import logging
from typing import Callable, Dict, Iterable, List, Optional
from typing import Callable, Dict, Iterable, List, Optional, Tuple, cast

from sqlalchemy import create_engine, inspect

@@ -19,10 +20,17 @@
)
from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeCommonMixin
from datahub.metadata.com.linkedin.pegasus2avro.dataset import DatasetProfile
from datahub.metadata.schema_classes import DatasetProfileClass

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class SnowflakeProfilerRequest(GEProfilerRequest):
table: SnowflakeTable
profile_table_level_only: bool = False


class SnowflakeProfiler(SnowflakeCommonMixin):
def __init__(self, config: SnowflakeV2Config, report: SnowflakeV2Report) -> None:
self.config = config
@@ -31,12 +39,6 @@ def __init__(self, config: SnowflakeV2Config, report: SnowflakeV2Report) -> None

def get_workunits(self, databases: List[SnowflakeDatabase]) -> Iterable[WorkUnit]:

# If only table level profiling is enabled, report table profile and exit
if self.config.profiling.profile_table_level_only:

yield from self.get_table_level_profile_workunits(databases)
return

# Extra default SQLAlchemy option for better connection pooling and threading.
# https://docs.sqlalchemy.org/en/14/core/pooling.html#sqlalchemy.pool.QueuePool.params.max_overflow
if self.config.profiling.enabled:
@@ -55,21 +57,22 @@ def get_workunits(self, databases: List[SnowflakeDatabase]) -> Iterable[WorkUnit
for table in schema.tables:

# Emit the profile work unit
profile_request = self.get_ge_profile_request(
profile_request = self.get_snowflake_profile_request(
table, schema.name, db.name
)
if profile_request is not None:
profile_requests.append(profile_request)

if len(profile_requests) == 0:
continue
ge_profiler = self.get_profiler_instance(db.name)
for request, profile in ge_profiler.generate_profiles(
for request, profile in self.generate_profiles(
db.name,
profile_requests,
self.config.profiling.max_workers,
platform=self.platform,
profiler_args=self.get_profile_args(),
):
if profile is None:
continue
profile.sizeInBytes = request.table.size_in_bytes  # type: ignore
dataset_name = request.pretty_name
@@ -86,67 +89,26 @@ def get_workunits(self, databases: List[SnowflakeDatabase]) -> Iterable[WorkUnit
profile,
)

def get_table_level_profile_workunits(
self, databases: List[SnowflakeDatabase]
) -> Iterable[WorkUnit]:
for db in databases:
if not self.config.database_pattern.allowed(db.name):
continue
for schema in db.schemas:
if not self.config.schema_pattern.allowed(schema.name):
continue
for table in schema.tables:
dataset_name = self.get_dataset_identifier(
table.name, schema.name, db.name
)
# no need to filter by size_in_bytes and row_count limits
# for table-level profiling, since it is not expensive
if not self.is_dataset_eligible_for_profiling(
dataset_name,
table.last_altered,
0,
0,
):
skip_profiling = True

if skip_profiling:
if self.config.profiling.report_dropped_profiles:
self.report.report_dropped(f"profile of {dataset_name}")
return None

self.report.report_entity_profiled(dataset_name)

dataset_urn = make_dataset_urn_with_platform_instance(
self.platform,
dataset_name,
self.config.platform_instance,
self.config.env,
)
yield self.wrap_aspect_as_workunit(
"dataset",
dataset_urn,
"datasetProfile",
DatasetProfile(
timestampMillis=round(
datetime.datetime.now().timestamp() * 1000
),
columnCount=len(table.columns),
rowCount=table.rows_count,
),
)

def get_ge_profile_request(
def get_snowflake_profile_request(
self,
table: SnowflakeTable,
schema_name: str,
db_name: str,
) -> Optional[GEProfilerRequest]:
) -> Optional[SnowflakeProfilerRequest]:
skip_profiling = False
profile_table_level_only = self.config.profiling.profile_table_level_only
dataset_name = self.get_dataset_identifier(table.name, schema_name, db_name)
if not self.is_dataset_eligible_for_profiling(
dataset_name, table.last_altered, table.size_in_bytes, table.rows_count
):
skip_profiling = True
# Profile at table level only, if the dataset was excluded from
# full profiling due to size limits alone
if self.is_dataset_eligible_for_profiling(
dataset_name, table.last_altered, 0, 0
):
profile_table_level_only = True
else:
skip_profiling = True

if len(table.columns) == 0:
skip_profiling = True
@@ -158,9 +120,11 @@ def get_ge_profile_request(

self.report.report_entity_profiled(dataset_name)
logger.debug(f"Preparing profiling request for {dataset_name}")
profile_request = GEProfilerRequest(
profile_request = SnowflakeProfilerRequest(
pretty_name=dataset_name,
batch_kwargs=dict(schema=schema_name, table=table.name),
table=table,
profile_table_level_only=profile_table_level_only,
)
return profile_request

@@ -236,3 +200,37 @@ def get_db_connection():
return conn

return get_db_connection

def generate_profiles(
self,
db_name: str,
requests: List[SnowflakeProfilerRequest],
max_workers: int,
platform: Optional[str] = None,
profiler_args: Optional[Dict] = None,
) -> Iterable[Tuple[GEProfilerRequest, Optional[DatasetProfileClass]]]:

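# Split the incoming requests: profiles for table-level-only requests
# are built below from metadata already fetched from Snowflake, while
# the remainder is delegated to the GE profiler.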
ge_profile_requests: List[GEProfilerRequest] = [
cast(GEProfilerRequest, request)
for request in requests
if not request.profile_table_level_only
]
table_level_profile_requests: List[SnowflakeProfilerRequest] = [
request for request in requests if request.profile_table_level_only
]
for request in table_level_profile_requests:
profile = DatasetProfile(
timestampMillis=round(datetime.datetime.now().timestamp() * 1000),
columnCount=len(request.table.columns),
rowCount=request.table.rows_count,
sizeInBytes=request.table.size_in_bytes,
)
yield (request, profile)

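# Delegate the remaining requests to the GE profiler for full
# column-level profiling.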
if len(ge_profile_requests) == 0:
return

ge_profiler = self.get_profiler_instance(db_name)
yield from ge_profiler.generate_profiles(
ge_profile_requests, max_workers, platform, profiler_args
)
@@ -14,5 +14,10 @@ record DatasetProfile includes TimeseriesAspectBase {

columnCount: optional long

/**
* Storage size in bytes
*/
sizeInBytes: optional long

fieldProfiles: optional array[DatasetFieldProfile]
}
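
With sizeInBytes added to the DatasetProfile timeseries aspect, sources can report storage size alongside row and column counts. The sketch below shows one way to emit such a profile through the DataHub Python emitter, assuming an SDK build that includes this schema change; the GMS URL, dataset name, and numeric values are illustrative, not taken from this commit.

import time

from datahub.emitter.mce_builder import make_dataset_urn
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.rest_emitter import DatahubRestEmitter
from datahub.metadata.schema_classes import ChangeTypeClass, DatasetProfileClass

# Build a table-level profile; sizeInBytes is the new optional field.
profile = DatasetProfileClass(
    timestampMillis=round(time.time() * 1000),
    rowCount=120_000,        # illustrative values
    columnCount=12,
    sizeInBytes=348_127_232, # storage size in bytes
)

# Emit the profile as a timeseries aspect against a hypothetical dataset URN.
emitter = DatahubRestEmitter("http://localhost:8080")
emitter.emit_mcp(
    MetadataChangeProposalWrapper(
        entityType="dataset",
        changeType=ChangeTypeClass.UPSERT,
        entityUrn=make_dataset_urn("snowflake", "mydb.myschema.mytable", "PROD"),
        aspectName="datasetProfile",
        aspect=profile,
    )
)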
