diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage.py index 448468718c1211..4956f41c74c290 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage.py @@ -95,17 +95,24 @@ class SnowflakeUpstreamTable: downstreamColumns: List[SnowflakeColumnWithLineage] @classmethod - def from_dict(cls, dataset, upstreams_columns_dict, downstream_columns_dict): + def from_dict(cls, dataset, upstreams_columns_json, downstream_columns_json): try: + upstreams_columns_list = [] + downstream_columns_list = [] + if upstreams_columns_json is not None: + upstreams_columns_list = json.loads(upstreams_columns_json) + if downstream_columns_json is not None: + downstream_columns_list = json.loads(downstream_columns_json) + table_with_upstreams = cls( dataset, [ SnowflakeColumnReference.parse_obj(col) - for col in upstreams_columns_dict + for col in upstreams_columns_list ], [ SnowflakeColumnWithLineage.parse_obj(col) - for col in downstream_columns_dict + for col in downstream_columns_list ], ) except ValidationError: @@ -390,8 +397,8 @@ def _populate_lineage(self, conn: SnowflakeConnection) -> None: # (, , ) SnowflakeUpstreamTable.from_dict( upstream_table_name, - json.loads(db_row["UPSTREAM_TABLE_COLUMNS"]), - json.loads(db_row["DOWNSTREAM_TABLE_COLUMNS"]), + db_row["UPSTREAM_TABLE_COLUMNS"], + db_row["DOWNSTREAM_TABLE_COLUMNS"], ), ) num_edges += 1 @@ -441,7 +448,7 @@ def _populate_view_upstream_lineage(self, conn: SnowflakeConnection) -> None: # key is the downstream view name self._lineage_map[view_name].update_lineage( # (, , ) - SnowflakeUpstreamTable.from_dict(view_upstream, [], []) + SnowflakeUpstreamTable.from_dict(view_upstream, None, None) ) num_edges += 1 logger.debug( @@ -499,8 +506,8 @@ def _populate_view_downstream_lineage(self, conn: SnowflakeConnection) -> None: # (, , ) SnowflakeUpstreamTable.from_dict( view_name, - json.loads(db_row["VIEW_COLUMNS"]), - json.loads(db_row["DOWNSTREAM_TABLE_COLUMNS"]), + db_row["VIEW_COLUMNS"], + db_row["DOWNSTREAM_TABLE_COLUMNS"], ) ) self.report.num_view_to_table_edges_scanned += 1