Skip to content

Commit

Permalink
feat(ingest): support ingesting from multiple snowflake dbs (#2793)
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 authored Jun 30, 2021
1 parent c7ce817 commit e51f86a
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 20 deletions.
8 changes: 7 additions & 1 deletion metadata-ingestion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,13 @@ source:
username: user
password: pass
host_port: account_name
database: db_name
database_pattern:
allow:
- ^regex$
- ^another_regex$
deny:
- ^SNOWFLAKE$
- ^SNOWFLAKE_SAMPLE_DATA$
warehouse: "COMPUTE_WH" # optional
role: "sysadmin" # optional
include_views: True # whether to include views, defaults to True
Expand Down
1 change: 0 additions & 1 deletion metadata-ingestion/src/datahub/entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def datahub(debug: bool) -> None:
logging.getLogger("datahub").setLevel(logging.INFO)
# loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
# print(loggers)
# breakpoint()


@datahub.command()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,6 @@ def from_entry(cls, entry: AuditLogEntry) -> "QueryEvent":
referencedTables = [
BigQueryTableRef.from_spec_obj(spec) for spec in rawRefTables
]
# if job['jobConfiguration']['query']['statementType'] != "SCRIPT" and not referencedTables:
# breakpoint()

queryEvent = QueryEvent(
timestamp=entry.timestamp,
Expand Down
50 changes: 47 additions & 3 deletions metadata-ingestion/src/datahub/ingestion/source/snowflake.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import logging
from typing import Optional
from typing import Iterable, Optional

import pydantic

# This import verifies that the dependencies are available.
import snowflake.sqlalchemy # noqa: F401
from snowflake.sqlalchemy import custom_types
from sqlalchemy import create_engine, inspect
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql import text
from sqlalchemy.sql.elements import quoted_name

from datahub.configuration.common import ConfigModel
from datahub.configuration.common import AllowDenyPattern, ConfigModel

from .sql_common import (
SQLAlchemyConfig,
Expand All @@ -18,6 +24,7 @@
register_custom_type(custom_types.TIMESTAMP_TZ, TimeTypeClass)
register_custom_type(custom_types.TIMESTAMP_LTZ, TimeTypeClass)
register_custom_type(custom_types.TIMESTAMP_NTZ, TimeTypeClass)
register_custom_type(custom_types.VARIANT)

logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -53,7 +60,23 @@ def get_sql_alchemy_url(self, database=None):


class SnowflakeConfig(BaseSnowflakeConfig, SQLAlchemyConfig):
database: str
database_pattern: AllowDenyPattern = AllowDenyPattern(
deny=[
r"^UTIL_DB$",
r"^SNOWFLAKE$",
r"^SNOWFLAKE_SAMPLE_DATA$",
]
)

database: str = ".*" # deprecated

@pydantic.validator("database")
def note_database_opt_deprecation(cls, v, values, **kwargs):
    """Handle the deprecated `database` option.

    Logs a deprecation warning, restricts ``database_pattern.allow`` to the
    single database named by ``v``, and returns ``None`` so that the
    deprecated field itself ends up cleared.

    Args:
        v: The value supplied for the `database` option.
        values: Mapping of the fields validated before this one; expected to
            hold ``database_pattern``.

    Returns:
        Always ``None``.
    """
    # `Logger.warn` is a deprecated alias; `warning` is the supported spelling.
    logger.warning(
        "snowflake's `database` option has been deprecated; use database_pattern instead"
    )

    # If `database_pattern` itself failed validation it is absent from
    # `values`; skip the rewrite instead of raising an unrelated KeyError.
    pattern = values.get("database_pattern")
    if pattern is not None:
        # The allow/deny entries are lists of regexes (see the `deny=[...]`
        # default), so keep that shape rather than assigning a bare string.
        pattern.allow = [f"^{v}$"]
    return None

def get_sql_alchemy_url(self):
    """Return the SQLAlchemy URL built by the parent config for `self.database`."""
    selected_database = self.database
    return super().get_sql_alchemy_url(selected_database)
Expand All @@ -64,10 +87,31 @@ def get_identifier(self, schema: str, table: str) -> str:


class SnowflakeSource(SQLAlchemySource):
    """SQLAlchemy-based source that inspects every allowed Snowflake database."""

    config: SnowflakeConfig

    def __init__(self, config, ctx):
        super().__init__(config, ctx, "snowflake")

    @classmethod
    def create(cls, config_dict, ctx):
        """Parse ``config_dict`` into a ``SnowflakeConfig`` and build the source."""
        config = SnowflakeConfig.parse_obj(config_dict)
        return cls(config, ctx)

    def get_inspectors(self) -> Iterable[Inspector]:
        """Yield one inspector per database accepted by ``database_pattern``.

        Databases are listed with ``SHOW DATABASES``. Rejected names are
        recorded through ``self.report.report_dropped``; for each accepted
        name a connection is switched to that database and an inspector bound
        to the connection is yielded.
        """
        url = self.config.get_sql_alchemy_url()
        logger.debug(f"sql_alchemy_url={url}")
        engine = create_engine(url, **self.config.options)

        for db_row in engine.execute(text("SHOW DATABASES")):
            db = db_row.name

            # Filter before connecting so that denied databases never cost
            # an extra connection.
            if not self.config.database_pattern.allowed(db):
                self.report.report_dropped(db)
                continue

            with engine.connect() as conn:
                # TRICKY: As we iterate through this loop, we modify the value of
                # self.config.database so that the get_identifier method can function
                # as intended.
                self.config.database = db
                # NOTE(review): the name is interpolated between double quotes
                # as-is; confirm whether names containing `"` can occur.
                conn.execute(f'USE DATABASE "{quoted_name(db, True)}"')
                # The connection stays open while the consumer uses the
                # inspector, and is closed when the generator resumes.
                yield inspect(conn)
33 changes: 21 additions & 12 deletions metadata-ingestion/src/datahub/ingestion/source/sql_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,26 +238,35 @@ def __init__(self, config: SQLAlchemyConfig, ctx: PipelineContext, platform: str
self.platform = platform
self.report = SQLSourceReport()

def get_inspectors(self) -> Iterable[Inspector]:
    """Yield a single inspector for the engine built from this source's config.

    Subclasses may override this to yield several inspectors when they need
    to run dynamically against multiple databases.
    """
    connection_url = self.config.get_sql_alchemy_url()
    logger.debug(f"sql_alchemy_url={connection_url}")
    yield inspect(create_engine(connection_url, **self.config.options))

def get_workunits(self) -> Iterable[SqlWorkUnit]:
# Walks schema names from an inspector, reports schemas rejected by
# `schema_pattern` as dropped, and yields from loop_tables / loop_views
# according to `include_tables` / `include_views`.
# NOTE(review): this span is a flattened diff view. The single-engine lines and
# the `get_inspectors()` loop below appear to be the before/after versions of
# the same logic - confirm against the real file before relying on it.
sql_config = self.config
if logger.isEnabledFor(logging.DEBUG):
# If debug logging is enabled, we also want to echo each SQL query issued.
sql_config.options["echo"] = True

# Variant that builds one engine and one inspector directly from the config URL.
url = sql_config.get_sql_alchemy_url()
logger.debug(f"sql_alchemy_url={url}")
engine = create_engine(url, **sql_config.options)
inspector = inspect(engine)
for schema in inspector.get_schema_names():
if not sql_config.schema_pattern.allowed(schema):
self.report.report_dropped(schema)
continue
# Variant that iterates over every inspector produced by get_inspectors().
for inspector in self.get_inspectors():
for schema in inspector.get_schema_names():
if not sql_config.schema_pattern.allowed(schema):
# Dropped schemas are reported under their standardized name joined with "*".
self.report.report_dropped(
".".join(sql_config.standardize_schema_table_names(schema, "*"))
)
continue

if sql_config.include_tables:
yield from self.loop_tables(inspector, schema, sql_config)
if sql_config.include_tables:
yield from self.loop_tables(inspector, schema, sql_config)

if sql_config.include_views:
yield from self.loop_views(inspector, schema, sql_config)
if sql_config.include_views:
yield from self.loop_views(inspector, schema, sql_config)

def loop_tables(
self,
Expand Down
2 changes: 1 addition & 1 deletion metadata-ingestion/tests/unit/test_snowflake_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ def test_snowflake_uri():

assert (
config.get_sql_alchemy_url()
== "snowflake://user:password@acctname/demo?warehouse=COMPUTE_WH&role=sysadmin"
== "snowflake://user:password@acctname/?warehouse=COMPUTE_WH&role=sysadmin"
)

0 comments on commit e51f86a

Please sign in to comment.