Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ingestion): pull metabase database, schema names from raw query and api #7039

Merged
merged 2 commits into the base branch from the contributor's branch on
Jan 18, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 31 additions & 16 deletions metadata-ingestion/src/datahub/ingestion/source/metabase.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime, timezone
from functools import lru_cache
from typing import Dict, Iterable, Optional
from typing import Dict, Iterable, List, Optional

import dateutil.parser as dp
import pydantic
Expand Down Expand Up @@ -436,12 +436,15 @@ def construct_card_custom_properties(self, card_details: dict) -> Dict:

return custom_properties

def get_datasource_urn(self, card_details):
platform, database_name, platform_instance = self.get_datasource_from_id(
card_details.get("database_id", "")
)
def get_datasource_urn(self, card_details: dict) -> Optional[List]:
(
platform,
database_name,
database_schema,
platform_instance,
) = self.get_datasource_from_id(card_details.get("database_id", ""))
query_type = card_details.get("dataset_query", {}).get("type", {})
source_paths = set()
source_tables = set()

if query_type == "query":
source_table_id = (
Expand All @@ -452,8 +455,8 @@ def get_datasource_urn(self, card_details):
if source_table_id is not None:
schema_name, table_name = self.get_source_table_from_id(source_table_id)
if table_name:
source_paths.add(
f"{f'{schema_name}.' if schema_name else ''}{table_name}"
source_tables.add(
f"{database_name + '.' if database_name else ''}{schema_name + '.' if schema_name else ''}{table_name}"
)
else:
try:
Expand All @@ -466,11 +469,19 @@ def get_datasource_urn(self, card_details):

for table in parser.source_tables:
sources = str(table).split(".")

source_db = sources[-3] if len(sources) > 2 else database_name
source_schema, source_table = sources[-2], sources[-1]
if source_schema == "<default>":
source_schema = str(self.config.default_schema)

source_paths.add(f"{source_schema}.{source_table}")
source_schema = (
database_schema
if database_schema is not None
else str(self.config.default_schema)
)

source_tables.add(
f"{source_db + '.' if source_db else ''}{source_schema}.{source_table}"
)
except Exception as e:
self.report.report_failure(
key="metabase-query",
Expand All @@ -480,10 +491,10 @@ def get_datasource_urn(self, card_details):
)
return None

if platform == "snowflake":
source_tables = set(i.lower() for i in source_tables)

# Create dataset URNs
dataset_urn = []
dbname = f"{f'{database_name}.' if database_name else ''}"
source_tables = list(map(lambda tbl: f"{dbname}{tbl}", source_paths))
dataset_urn = [
builder.make_dataset_urn_with_platform_instance(
platform=platform,
Expand Down Expand Up @@ -535,7 +546,6 @@ def get_datasource_from_id(self, datasource_id):
# Map engine names to what datahub expects in
# https://github.com/datahub-project/datahub/blob/master/metadata-service/war/src/main/resources/boot/data_platforms.json
engine = dataset_json.get("engine", "")
platform = engine

engine_mapping = {
"sparksql": "spark",
Expand All @@ -551,10 +561,13 @@ def get_datasource_from_id(self, datasource_id):
if engine in engine_mapping:
platform = engine_mapping[engine]
else:
platform = engine

self.report.report_warning(
key=f"metabase-platform-{datasource_id}",
reason=f"Platform was not found in DataHub. Using {platform} name as is",
)

# Set platform_instance if configuration provides a mapping from platform name to instance
platform_instance = (
self.config.platform_instance_map.get(platform)
Expand All @@ -580,6 +593,8 @@ def get_datasource_from_id(self, datasource_id):
else None
)

schema = dataset_json.get("details", {}).get("schema")

if (
self.config.database_alias_map is not None
and platform in self.config.database_alias_map
Expand All @@ -591,7 +606,7 @@ def get_datasource_from_id(self, datasource_id):
reason=f"Cannot determine database name for platform: {platform}",
)

return platform, dbname, platform_instance
return platform, dbname, schema, platform_instance

@classmethod
def create(cls, config_dict: dict, ctx: PipelineContext) -> Source:
Expand Down