fix(dbt): fix issue of assertion error when stateful ingestion is used with dbt tests #5540

Merged
2 changes: 0 additions & 2 deletions metadata-ingestion/src/datahub/emitter/mce_builder.py

@@ -216,15 +216,13 @@ def make_domain_urn(domain: str) -> str:


 def make_ml_primary_key_urn(feature_table_name: str, primary_key_name: str) -> str:
-
     return f"urn:li:mlPrimaryKey:({feature_table_name},{primary_key_name})"


 def make_ml_feature_urn(
     feature_table_name: str,
     feature_name: str,
 ) -> str:
-
     return f"urn:li:mlFeature:({feature_table_name},{feature_name})"
113 changes: 80 additions & 33 deletions metadata-ingestion/src/datahub/ingestion/source/dbt.py

@@ -47,6 +47,7 @@
     resolve_trino_modified_type,
 )
 from datahub.ingestion.source.state.checkpoint import Checkpoint
+from datahub.ingestion.source.state.dbt_state import DbtCheckpointState
 from datahub.ingestion.source.state.sql_common_state import (
     BaseSQLAlchemyCheckpointState,
 )
@@ -1005,15 +1006,48 @@ def __init__(self, config: DBTConfig, ctx: PipelineContext, platform: str):
             self.config.owner_extraction_pattern
         )

+    def get_last_dbt_checkpoint(
+        self, job_id: JobId, checkpoint_state_class: Type[DbtCheckpointState]
+    ) -> Optional[Checkpoint]:
+        last_checkpoint: Optional[Checkpoint]
+        is_conversion_required: bool = False
+        try:
+            # Best case: the last checkpoint state is already a DbtCheckpointState.
+            last_checkpoint = self.get_last_checkpoint(job_id, checkpoint_state_class)
+        except Exception as e:
+            # Backward compatibility for the old dbt ingestion source, which saved
+            # dbt nodes in a BaseSQLAlchemyCheckpointState.
+            last_checkpoint = self.get_last_checkpoint(
+                job_id, BaseSQLAlchemyCheckpointState
+            )
+            logger.debug(
+                f"Found BaseSQLAlchemyCheckpointState as checkpoint state (got {e})."
+            )
+            is_conversion_required = True
+
+        if last_checkpoint is not None and is_conversion_required:
+            # Map the BaseSQLAlchemyCheckpointState to a DbtCheckpointState.
+            dbt_checkpoint_state: DbtCheckpointState = DbtCheckpointState()
+            dbt_checkpoint_state.encoded_node_urns = (
+                cast(BaseSQLAlchemyCheckpointState, last_checkpoint.state)
+            ).encoded_table_urns
+            # The old dbt source did not support assertions.
+            dbt_checkpoint_state.encoded_assertion_urns = []
+            last_checkpoint.state = dbt_checkpoint_state
+
+        return last_checkpoint
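Note: the conversion above only needs to copy the encoded node URNs across, since the old state class had nowhere to record assertions. A minimal standalone sketch of the same mapping, assuming a DataHub dev install (the encoded `platform||name||env` value is illustrative only):

```python
from datahub.ingestion.source.state.dbt_state import DbtCheckpointState
from datahub.ingestion.source.state.sql_common_state import (
    BaseSQLAlchemyCheckpointState,
)

# Hypothetical old-format state, as the pre-PR dbt source would have saved it.
old_state = BaseSQLAlchemyCheckpointState()
old_state.encoded_table_urns = ["dbt||jaffle_shop.customers||PROD"]

# Same mapping as get_last_dbt_checkpoint: nodes carry over unchanged,
# assertions start empty because the old state never tracked them.
new_state = DbtCheckpointState()
new_state.encoded_node_urns = old_state.encoded_table_urns
new_state.encoded_assertion_urns = []
```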

     # TODO: Consider refactoring this logic out for use across sources as it is leading to a significant amount of
     # code duplication.
     def gen_removed_entity_workunits(self) -> Iterable[MetadataWorkUnit]:
-        last_checkpoint = self.get_last_checkpoint(
-            self.get_default_ingestion_job_id(), BaseSQLAlchemyCheckpointState
+        last_checkpoint: Optional[Checkpoint] = self.get_last_dbt_checkpoint(
+            self.get_default_ingestion_job_id(), DbtCheckpointState
         )
         cur_checkpoint = self.get_current_checkpoint(
             self.get_default_ingestion_job_id()
         )

         if (
             self.config.stateful_ingestion
             and self.config.stateful_ingestion.remove_stale_metadata

@@ -1024,7 +1058,7 @@ def gen_removed_entity_workunits(self) -> Iterable[MetadataWorkUnit]:
         ):
             logger.debug("Checking for stale entity removal.")

-            def soft_delete_item(urn: str, type: str) -> Iterable[MetadataWorkUnit]:
+            def get_soft_delete_item_workunit(urn: str, type: str) -> MetadataWorkUnit:
                 logger.info(f"Soft-deleting stale entity of type {type} - {urn}.")
                 mcp = MetadataChangeProposalWrapper(
@@ -1037,19 +1071,28 @@ def soft_delete_item(urn: str, type: str) -> Iterable[MetadataWorkUnit]:
                 wu = MetadataWorkUnit(id=f"soft-delete-{type}-{urn}", mcp=mcp)
                 self.report.report_workunit(wu)
                 self.report.report_stale_entity_soft_deleted(urn)
-                yield wu
+                return wu

-            last_checkpoint_state = cast(
-                BaseSQLAlchemyCheckpointState, last_checkpoint.state
-            )
-            cur_checkpoint_state = cast(
-                BaseSQLAlchemyCheckpointState, cur_checkpoint.state
-            )
+            last_checkpoint_state = cast(DbtCheckpointState, last_checkpoint.state)
+            cur_checkpoint_state = cast(DbtCheckpointState, cur_checkpoint.state)

-            for table_urn in last_checkpoint_state.get_table_urns_not_in(
-                cur_checkpoint_state
-            ):
-                yield from soft_delete_item(table_urn, "dataset")
+            urns_to_soft_delete_by_type: Dict = {
+                "dataset": [
+                    node_urn
+                    for node_urn in last_checkpoint_state.get_node_urns_not_in(
+                        cur_checkpoint_state
+                    )
+                ],
+                "assertion": [
+                    assertion_urn
+                    for assertion_urn in last_checkpoint_state.get_assertion_urns_not_in(
+                        cur_checkpoint_state
+                    )
+                ],
+            }
+            for entity_type in urns_to_soft_delete_by_type:
+                for urn in urns_to_soft_delete_by_type[entity_type]:
+                    yield get_soft_delete_item_workunit(urn, entity_type)

     def load_file_as_json(self, uri: str) -> Any:
         if re.match("^https?://", uri):
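The stale-entity computation above reduces to a set difference per entity type: anything recorded in the last run's checkpoint but missing from the current run's checkpoint gets a soft-delete workunit. A toy illustration of that idea (URN strings hypothetical):

```python
# Last successful run saw three entities; the current run saw only two.
last_run = {"urn:li:dataset:a", "urn:li:dataset:b", "urn:li:assertion:t1"}
current_run = {"urn:li:dataset:a", "urn:li:assertion:t1"}

# "urn:li:dataset:b" is stale and would be emitted as a soft-delete workunit.
stale = last_run - current_run
print(stale)  # {'urn:li:dataset:b'}
```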
@@ -1155,7 +1198,7 @@ def string_map(input_map: Dict[str, Any]) -> Dict[str, str]:
                     }
                 )
             )
-            self.save_checkpoint(node_datahub_urn)
+            self.save_checkpoint(node_datahub_urn, "assertion")

             dpi_mcp = MetadataChangeProposalWrapper(
                 entityType="assertion",
@@ -1412,10 +1455,12 @@ def remove_duplicate_urns_from_checkpoint_state(self) -> None:
         )

         if cur_checkpoint is not None:
-            # Utilizing BaseSQLAlchemyCheckpointState class to save state
-            checkpoint_state = cast(BaseSQLAlchemyCheckpointState, cur_checkpoint.state)
-            checkpoint_state.encoded_table_urns = list(
-                set(checkpoint_state.encoded_table_urns)
+            checkpoint_state = cast(DbtCheckpointState, cur_checkpoint.state)
+            checkpoint_state.encoded_node_urns = list(
+                set(checkpoint_state.encoded_node_urns)
             )
+            checkpoint_state.encoded_assertion_urns = list(
+                set(checkpoint_state.encoded_assertion_urns)
+            )

def create_platform_mces(
Expand Down Expand Up @@ -1458,7 +1503,7 @@ def create_platform_mces(
self.config.env,
mce_platform_instance,
)
self.save_checkpoint(node_datahub_urn)
self.save_checkpoint(node_datahub_urn, "dataset")

meta_aspects: Dict[str, Any] = {}
if self.config.enable_meta_mapping and node.meta:
@@ -1534,18 +1579,21 @@ def create_platform_mces(
             self.report.report_workunit(wu)
             yield wu

-    def save_checkpoint(self, node_datahub_urn: str) -> None:
-        if self.is_stateful_ingestion_configured():
-            cur_checkpoint = self.get_current_checkpoint(
-                self.get_default_ingestion_job_id()
-            )
+    def save_checkpoint(self, urn: str, entity_type: str) -> None:
+        # If stateful ingestion is not configured, there is nothing to save.
+        if not self.is_stateful_ingestion_configured():
+            return

-            if cur_checkpoint is not None:
-                # Utilizing BaseSQLAlchemyCheckpointState class to save state
-                checkpoint_state = cast(
-                    BaseSQLAlchemyCheckpointState, cur_checkpoint.state
-                )
-                checkpoint_state.add_table_urn(node_datahub_urn)
+        cur_checkpoint = self.get_current_checkpoint(
+            self.get_default_ingestion_job_id()
+        )
+        # If no current checkpoint exists, there is likewise nothing to save.
+        if cur_checkpoint is None:
+            return
+
+        # Cast the state and record the URN under its entity type.
+        checkpoint_state = cast(DbtCheckpointState, cur_checkpoint.state)
+        checkpoint_state.set_checkpoint_urn(urn, entity_type)

     def extract_query_tag_aspects(
         self,
@@ -1900,8 +1948,7 @@ def create_checkpoint(self, job_id: JobId) -> Optional[Checkpoint]:
                 platform_instance_id=self.get_platform_instance_id(),
                 run_id=self.ctx.run_id,
                 config=self.config,
-                # Reusing BaseSQLAlchemyCheckpointState as it has needed functionality to support statefulness of DBT
-                state=BaseSQLAlchemyCheckpointState(),
+                state=DbtCheckpointState(),
             )
         return None
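Putting the dbt.py changes together: create_checkpoint now seeds a DbtCheckpointState, and each emitted node or test records itself under the matching entity type via save_checkpoint. A sketch of that dispatch in isolation, assuming a DataHub dev install (URN values illustrative):

```python
from datahub.ingestion.source.state.dbt_state import DbtCheckpointState

state = DbtCheckpointState()
# Mirrors save_checkpoint(node_datahub_urn, "dataset") for a dbt model ...
state.set_checkpoint_urn(
    "urn:li:dataset:(urn:li:dataPlatform:dbt,jaffle_shop.orders,PROD)", "dataset"
)
# ... and save_checkpoint(node_datahub_urn, "assertion") for a dbt test.
state.set_checkpoint_urn("urn:li:assertion:abc123", "assertion")

print(state.encoded_node_urns)       # encoded lightweight form of the dataset URN
print(state.encoded_assertion_urns)  # ['abc123']
```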

70 changes: 70 additions & 0 deletions metadata-ingestion/src/datahub/ingestion/source/state/dbt_state.py

@@ -0,0 +1,70 @@
import logging
from typing import Callable, Dict, Iterable, List

import pydantic

from datahub.emitter.mce_builder import make_assertion_urn
from datahub.ingestion.source.state.checkpoint import CheckpointStateBase
from datahub.utilities.checkpoint_state_util import CheckpointStateUtil
from datahub.utilities.urns.urn import Urn

logger = logging.getLogger(__name__)


class DbtCheckpointState(CheckpointStateBase):
    """
    Class for representing the checkpoint state for dbt sources.
    Stores all nodes and assertions being ingested and is used to remove any stale entities.
    """

    encoded_node_urns: List[str] = pydantic.Field(default_factory=list)
    encoded_assertion_urns: List[str] = pydantic.Field(default_factory=list)

    @staticmethod
    def _get_assertion_lightweight_repr(assertion_urn: str) -> str:
        """Reduces the amount of text in the URNs for a smaller state footprint."""
        urn = Urn.create_from_string(assertion_urn)
        key = urn.get_entity_id_as_string()
        assert key is not None
        return key

    def add_assertion_urn(self, assertion_urn: str) -> None:
        self.encoded_assertion_urns.append(
            self._get_assertion_lightweight_repr(assertion_urn)
        )

    def get_assertion_urns_not_in(
        self, checkpoint: "DbtCheckpointState"
    ) -> Iterable[str]:
        """dbt assertions are mapped to the DataHub assertion concept."""
        difference = CheckpointStateUtil.get_encoded_urns_not_in(
            self.encoded_assertion_urns, checkpoint.encoded_assertion_urns
        )
        for key in difference:
            yield make_assertion_urn(key)

    def get_node_urns_not_in(self, checkpoint: "DbtCheckpointState") -> Iterable[str]:
        """dbt nodes are mapped to the DataHub dataset concept."""
        yield from CheckpointStateUtil.get_dataset_urns_not_in(
            self.encoded_node_urns, checkpoint.encoded_node_urns
        )

    def add_node_urn(self, node_urn: str) -> None:
        self.encoded_node_urns.append(
            CheckpointStateUtil.get_dataset_lightweight_repr(node_urn)
        )

    def set_checkpoint_urn(self, urn: str, entity_type: str) -> None:
        supported_entities_add_handlers: Dict[str, Callable[[str], None]] = {
            "dataset": self.add_node_urn,
            "assertion": self.add_assertion_urn,
        }

        if entity_type not in supported_entities_add_handlers:
            logger.error(f"Cannot save unknown entity {entity_type} to checkpoint.")
            return

        supported_entities_add_handlers[entity_type](urn)
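A hedged usage sketch of this new state class, assuming a DataHub dev install; it shows how the source asks one checkpoint which of its URNs are missing from another (URNs illustrative):

```python
from datahub.ingestion.source.state.dbt_state import DbtCheckpointState

last = DbtCheckpointState()
last.add_node_urn(
    "urn:li:dataset:(urn:li:dataPlatform:dbt,jaffle_shop.customers,PROD)"
)
last.add_assertion_urn("urn:li:assertion:abc123")

cur = DbtCheckpointState()
cur.add_node_urn(
    "urn:li:dataset:(urn:li:dataPlatform:dbt,jaffle_shop.customers,PROD)"
)

# The model exists in both runs; the assertion was only in the last run,
# so it is reported as stale and soft-deleted upstream.
print(list(last.get_node_urns_not_in(cur)))       # []
print(list(last.get_assertion_urns_not_in(cur)))  # ['urn:li:assertion:abc123']
```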

metadata-ingestion/src/datahub/ingestion/source/state/sql_common_state.py

@@ -2,13 +2,9 @@

 import pydantic

-from datahub.emitter.mce_builder import (
-    container_urn_to_key,
-    dataset_urn_to_key,
-    make_container_urn,
-    make_dataset_urn,
-)
+from datahub.emitter.mce_builder import container_urn_to_key, make_container_urn
 from datahub.ingestion.source.state.checkpoint import CheckpointStateBase
+from datahub.utilities.checkpoint_state_util import CheckpointStateUtil


 class BaseSQLAlchemyCheckpointState(CheckpointStateBase):
@@ -21,19 +17,12 @@ class BaseSQLAlchemyCheckpointState(CheckpointStateBase):
     encoded_table_urns: List[str] = pydantic.Field(default_factory=list)
     encoded_view_urns: List[str] = pydantic.Field(default_factory=list)
     encoded_container_urns: List[str] = pydantic.Field(default_factory=list)
-
-    @staticmethod
-    def _get_separator() -> str:
-        # Unique small string not allowed in URNs.
-        return "||"
+    encoded_assertion_urns: List[str] = pydantic.Field(default_factory=list)

     @staticmethod
     def _get_lightweight_repr(dataset_urn: str) -> str:
         """Reduces the amount of text in the URNs for smaller state footprint."""
-        SEP = BaseSQLAlchemyCheckpointState._get_separator()
-        key = dataset_urn_to_key(dataset_urn)
-        assert key is not None
-        return f"{key.platform}{SEP}{key.name}{SEP}{key.origin}"
+        return CheckpointStateUtil.get_dataset_lightweight_repr(dataset_urn)

     @staticmethod
     def _get_container_lightweight_repr(container_urn: str) -> str:
@@ -42,36 +31,29 @@ def _get_container_lightweight_repr(container_urn: str) -> str:
         key = container_urn_to_key(container_urn)
         assert key is not None
         return f"{key.guid}"

-    @staticmethod
-    def _get_dataset_urns_not_in(
-        encoded_urns_1: List[str], encoded_urns_2: List[str]
-    ) -> Iterable[str]:
-        difference = set(encoded_urns_1) - set(encoded_urns_2)
-        for encoded_urn in difference:
-            platform, name, env = encoded_urn.split(
-                BaseSQLAlchemyCheckpointState._get_separator()
-            )
-            yield make_dataset_urn(platform, name, env)
-
     @staticmethod
     def _get_container_urns_not_in(
         encoded_urns_1: List[str], encoded_urns_2: List[str]
     ) -> Iterable[str]:
-        difference = set(encoded_urns_1) - set(encoded_urns_2)
+        difference = CheckpointStateUtil.get_encoded_urns_not_in(
+            encoded_urns_1, encoded_urns_2
+        )
         for guid in difference:
             yield make_container_urn(guid)

     def get_table_urns_not_in(
         self, checkpoint: "BaseSQLAlchemyCheckpointState"
     ) -> Iterable[str]:
-        yield from self._get_dataset_urns_not_in(
+        """Tables are mapped to the DataHub dataset concept."""
+        yield from CheckpointStateUtil.get_dataset_urns_not_in(
             self.encoded_table_urns, checkpoint.encoded_table_urns
         )

     def get_view_urns_not_in(
         self, checkpoint: "BaseSQLAlchemyCheckpointState"
     ) -> Iterable[str]:
-        yield from self._get_dataset_urns_not_in(
+        """Views are mapped to the DataHub dataset concept."""
+        yield from CheckpointStateUtil.get_dataset_urns_not_in(
             self.encoded_view_urns, checkpoint.encoded_view_urns
         )
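CheckpointStateUtil itself does not appear in this diff view. Inferred purely from the call sites above (the helper names match the diff, but the separator helper and exact bodies are assumptions), the shared utility plausibly looks like this:

```python
from typing import Iterable, List, Set

from datahub.emitter.mce_builder import dataset_urn_to_key, make_dataset_urn


class CheckpointStateUtil:
    """Shared helpers for encoding/decoding checkpoint state (sketch)."""

    @staticmethod
    def get_separator() -> str:
        # Unique small string not allowed in URNs, same trick as the
        # removed _get_separator above.
        return "||"

    @staticmethod
    def get_encoded_urns_not_in(
        encoded_urns_1: List[str], encoded_urns_2: List[str]
    ) -> Set[str]:
        return set(encoded_urns_1) - set(encoded_urns_2)

    @staticmethod
    def get_dataset_lightweight_repr(dataset_urn: str) -> str:
        SEP = CheckpointStateUtil.get_separator()
        key = dataset_urn_to_key(dataset_urn)
        assert key is not None
        return f"{key.platform}{SEP}{key.name}{SEP}{key.origin}"

    @staticmethod
    def get_dataset_urns_not_in(
        encoded_urns_1: List[str], encoded_urns_2: List[str]
    ) -> Iterable[str]:
        difference = CheckpointStateUtil.get_encoded_urns_not_in(
            encoded_urns_1, encoded_urns_2
        )
        for encoded_urn in difference:
            platform, name, env = encoded_urn.split(
                CheckpointStateUtil.get_separator()
            )
            yield make_dataset_urn(platform, name, env)
```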

@@ -232,7 +232,7 @@ def get_last_checkpoint(
         ):
             return None

-        if JobId not in self.last_checkpoints:
+        if job_id not in self.last_checkpoints:
             self.last_checkpoints[job_id] = self._get_last_checkpoint(
                 job_id, checkpoint_state_class
             )
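This one-token fix matters: `JobId` is the type, not the job's id, so the membership test could never hit the cache and the last checkpoint was re-fetched on every call. A minimal standalone illustration (the NewType definition and job name are assumptions for the sketch):

```python
from typing import Dict, NewType

JobId = NewType("JobId", str)  # assumed shape of DataHub's JobId type
last_checkpoints: Dict[JobId, object] = {}
job_id = JobId("dbt_stale_entity_removal")  # job name hypothetical

# Buggy check: asks whether the *type object* is a dict key, which is
# never true, so the cached checkpoint was never reused.
assert JobId not in last_checkpoints

# Fixed check: asks whether this particular job's checkpoint is cached.
if job_id not in last_checkpoints:
    last_checkpoints[job_id] = "expensive _get_last_checkpoint(...) result"
```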