feat(ingest): standardize sql type mappings #11982

Merged: 1 commit, Nov 28, 2024
68 changes: 7 additions & 61 deletions metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py
@@ -53,19 +53,7 @@
make_assertion_from_test,
make_assertion_result_from_test,
)
from datahub.ingestion.source.sql.sql_types import (
ATHENA_SQL_TYPES_MAP,
BIGQUERY_TYPES_MAP,
POSTGRES_TYPES_MAP,
SNOWFLAKE_TYPES_MAP,
SPARK_SQL_TYPES_MAP,
TRINO_SQL_TYPES_MAP,
VERTICA_SQL_TYPES_MAP,
resolve_athena_modified_type,
resolve_postgres_modified_type,
resolve_trino_modified_type,
resolve_vertica_modified_type,
)
from datahub.ingestion.source.sql.sql_types import resolve_sql_type
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StaleEntityRemovalHandler,
StaleEntityRemovalSourceReport,
@@ -89,17 +77,11 @@
from datahub.metadata.com.linkedin.pegasus2avro.metadata.snapshot import DatasetSnapshot
from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
from datahub.metadata.com.linkedin.pegasus2avro.schema import (
BooleanTypeClass,
DateTypeClass,
MySqlDDL,
NullTypeClass,
NumberTypeClass,
RecordType,
SchemaField,
SchemaFieldDataType,
SchemaMetadata,
StringTypeClass,
TimeTypeClass,
)
from datahub.metadata.schema_classes import (
DataPlatformInstanceClass,
@@ -804,28 +786,6 @@ def make_mapping_upstream_lineage(
)


# See https://github.com/fishtown-analytics/dbt/blob/master/core/dbt/adapters/sql/impl.py
_field_type_mapping = {
"boolean": BooleanTypeClass,
"date": DateTypeClass,
"time": TimeTypeClass,
"numeric": NumberTypeClass,
"text": StringTypeClass,
"timestamp with time zone": DateTypeClass,
"timestamp without time zone": DateTypeClass,
"integer": NumberTypeClass,
"float8": NumberTypeClass,
"struct": RecordType,
**POSTGRES_TYPES_MAP,
**SNOWFLAKE_TYPES_MAP,
**BIGQUERY_TYPES_MAP,
**SPARK_SQL_TYPES_MAP,
**TRINO_SQL_TYPES_MAP,
**ATHENA_SQL_TYPES_MAP,
**VERTICA_SQL_TYPES_MAP,
}


def get_column_type(
report: DBTSourceReport,
dataset_name: str,
@@ -835,24 +795,10 @@
"""
Maps known DBT types to datahub types
"""
TypeClass: Any = _field_type_mapping.get(column_type) if column_type else None

if TypeClass is None and column_type:
# resolve a modified type
if dbt_adapter == "trino":
TypeClass = resolve_trino_modified_type(column_type)
elif dbt_adapter == "athena":
TypeClass = resolve_athena_modified_type(column_type)
elif dbt_adapter == "postgres" or dbt_adapter == "redshift":
# Redshift uses a variant of Postgres, so we can use the same logic.
TypeClass = resolve_postgres_modified_type(column_type)
elif dbt_adapter == "vertica":
TypeClass = resolve_vertica_modified_type(column_type)
elif dbt_adapter == "snowflake":
# Snowflake types are uppercase, so we check that.
TypeClass = _field_type_mapping.get(column_type.upper())

# if still not found, report the warning

TypeClass = resolve_sql_type(column_type, dbt_adapter)

# if still not found, report a warning
if TypeClass is None:
if column_type:
report.info(
@@ -861,9 +807,9 @@
context=f"{dataset_name} - {column_type}",
log=False,
)
TypeClass = NullTypeClass
TypeClass = NullTypeClass()

return SchemaFieldDataType(type=TypeClass())
return SchemaFieldDataType(type=TypeClass)


@platform_name("dbt")
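A note on the fallback change above: `resolve_sql_type` returns a type *instance* (or `None`) rather than a type class, which is why the fallback became `NullTypeClass()` and the final line no longer instantiates. A minimal sketch of the resulting logic, with the reporting plumbing omitted (the standalone helper name is illustrative, not part of the PR):

```python
from typing import Optional

from datahub.ingestion.source.sql.sql_types import resolve_sql_type
from datahub.metadata.com.linkedin.pegasus2avro.schema import (
    NullTypeClass,
    SchemaFieldDataType,
)


def column_to_field_type(
    column_type: Optional[str], dbt_adapter: str
) -> SchemaFieldDataType:
    # One call now covers the merged mapping plus the per-adapter fallbacks
    # (trino, athena, postgres/redshift, vertica, snowflake).
    resolved = resolve_sql_type(column_type, dbt_adapter)
    # resolve_sql_type returns an instance, so no trailing () here.
    return SchemaFieldDataType(
        type=resolved if resolved is not None else NullTypeClass()
    )
```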
@@ -15,6 +15,7 @@
TimeType,
)

# TODO: Replace with standardized types in sql_types.py
FIELD_TYPE_MAPPING: Dict[
str,
Type[
@@ -222,6 +222,7 @@ class RedshiftSource(StatefulIngestionSourceBase, TestableSource):
```
"""

# TODO: Replace with standardized types in sql_types.py
REDSHIFT_FIELD_TYPE_MAPPINGS: Dict[
str,
Type[
@@ -103,6 +103,7 @@
logger = logging.getLogger(__name__)

# https://docs.snowflake.com/en/sql-reference/intro-summary-data-types.html
# TODO: Move to the standardized types in sql_types.py
SNOWFLAKE_FIELD_TYPE_MAPPINGS = {
"DATE": DateType,
"BIGINT": NumberType,
79 changes: 72 additions & 7 deletions metadata-ingestion/src/datahub/ingestion/source/sql/sql_types.py
@@ -1,5 +1,5 @@
import re
from typing import Any, Dict, ValuesView
from typing import Any, Dict, Optional, Type, Union, ValuesView

from datahub.metadata.com.linkedin.pegasus2avro.schema import (
ArrayType,
@@ -16,14 +16,28 @@
UnionType,
)

# these can be obtained by running `select format_type(oid, null),* from pg_type;`
# we've omitted the types without a meaningful DataHub type (e.g. postgres-specific types, index vectors, etc.)
# (run `\copy (select format_type(oid, null),* from pg_type) to 'pg_type.csv' csv header;` to get a CSV)
DATAHUB_FIELD_TYPE = Union[
ArrayType,
BooleanType,
BytesType,
DateType,
EnumType,
MapType,
NullType,
NumberType,
RecordType,
StringType,
TimeType,
UnionType,
]

# we map from format_type since this is what dbt uses
# see https://github.com/fishtown-analytics/dbt/blob/master/plugins/postgres/dbt/include/postgres/macros/catalog.sql#L22

# see https://www.npgsql.org/dev/types.html for helpful type annotations
# These can be obtained by running `select format_type(oid, null),* from pg_type;`
# We've omitted the types without a meaningful DataHub type (e.g. postgres-specific types, index vectors, etc.)
# (run `\copy (select format_type(oid, null),* from pg_type) to 'pg_type.csv' csv header;` to get a CSV)
# We map from format_type since this is what dbt uses.
# See https://github.com/fishtown-analytics/dbt/blob/master/plugins/postgres/dbt/include/postgres/macros/catalog.sql#L22
# See https://www.npgsql.org/dev/types.html for helpful type annotations
POSTGRES_TYPES_MAP: Dict[str, Any] = {
"boolean": BooleanType,
"bytea": BytesType,
@@ -430,3 +444,54 @@ def resolve_vertica_modified_type(type_string: str) -> Any:
"geography": None,
"uuid": StringType,
}


_merged_mapping = {
"boolean": BooleanType,
"date": DateType,
"time": TimeType,
"numeric": NumberType,
"text": StringType,
"timestamp with time zone": DateType,
"timestamp without time zone": DateType,
"integer": NumberType,
"float8": NumberType,
"struct": RecordType,
**POSTGRES_TYPES_MAP,
**SNOWFLAKE_TYPES_MAP,
**BIGQUERY_TYPES_MAP,
**SPARK_SQL_TYPES_MAP,
**TRINO_SQL_TYPES_MAP,
**ATHENA_SQL_TYPES_MAP,
**VERTICA_SQL_TYPES_MAP,
}


def resolve_sql_type(
column_type: Optional[str],
platform: Optional[str] = None,
) -> Optional[DATAHUB_FIELD_TYPE]:
# In theory, we should use the platform-specific mapping where available.
# However, the types don't ever conflict, so the merged mapping is fine.
TypeClass: Optional[Type[DATAHUB_FIELD_TYPE]] = (
_merged_mapping.get(column_type) if column_type else None
)

if TypeClass is None and column_type:
# resolve a modified type
if platform == "trino":
TypeClass = resolve_trino_modified_type(column_type)
elif platform == "athena":
TypeClass = resolve_athena_modified_type(column_type)
elif platform == "postgres" or platform == "redshift":
# Redshift uses a variant of Postgres, so we can use the same logic.
TypeClass = resolve_postgres_modified_type(column_type)
elif platform == "vertica":
TypeClass = resolve_vertica_modified_type(column_type)
elif platform == "snowflake":
# Snowflake types are uppercase, so we check that.
TypeClass = _merged_mapping.get(column_type.upper())

if TypeClass:
return TypeClass()
return None
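
For reviewers, a few illustrative calls against the new helper (the expected classes follow from the mappings above; `decimal(10,0)` is assumed to hit the Trino modified-type fallback and land on a number type, consistent with the parametrized tests in this PR):

```python
from datahub.ingestion.source.sql.sql_types import resolve_sql_type
from datahub.metadata.schema_classes import BooleanTypeClass, NumberTypeClass

# Direct hit in the merged mapping; no platform hint required.
assert resolve_sql_type("boolean") == BooleanTypeClass()

# Parameterized types miss the merged mapping, so the platform-specific
# resolver strips the precision/scale before looking up the base type.
assert resolve_sql_type("decimal(10,0)", platform="trino") == NumberTypeClass()

# Unknown types resolve to None instead of raising.
assert resolve_sql_type("definitely-not-a-sql-type") is None
```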
@@ -33,6 +33,7 @@

logger = logging.getLogger(__name__)

# TODO: (maybe) Replace with standardized types in sql_types.py
DATA_TYPE_REGISTRY: dict = {
ColumnTypeName.BOOLEAN: BooleanTypeClass,
ColumnTypeName.BYTE: BytesTypeClass,
69 changes: 0 additions & 69 deletions metadata-ingestion/tests/integration/dbt/test_dbt.py
@@ -11,12 +11,6 @@
from datahub.ingestion.run.pipeline_config import PipelineConfig, SourceConfig
from datahub.ingestion.source.dbt.dbt_common import DBTEntitiesEnabled, EmitDirective
from datahub.ingestion.source.dbt.dbt_core import DBTCoreConfig, DBTCoreSource
from datahub.ingestion.source.sql.sql_types import (
ATHENA_SQL_TYPES_MAP,
TRINO_SQL_TYPES_MAP,
resolve_athena_modified_type,
resolve_trino_modified_type,
)
from tests.test_helpers import mce_helpers, test_connection_helpers

FROZEN_TIME = "2022-02-03 07:00:00"
@@ -362,69 +356,6 @@ def test_dbt_tests(test_resources_dir, pytestconfig, tmp_path, mock_time, **kwar
)


@pytest.mark.parametrize(
"data_type, expected_data_type",
[
("boolean", "boolean"),
("tinyint", "tinyint"),
("smallint", "smallint"),
("int", "int"),
("integer", "integer"),
("bigint", "bigint"),
("real", "real"),
("double", "double"),
("decimal(10,0)", "decimal"),
("varchar(20)", "varchar"),
("char", "char"),
("varbinary", "varbinary"),
("json", "json"),
("date", "date"),
("time", "time"),
("time(12)", "time"),
("timestamp", "timestamp"),
("timestamp(3)", "timestamp"),
("row(x bigint, y double)", "row"),
("array(row(x bigint, y double))", "array"),
("map(varchar, varchar)", "map"),
],
)
def test_resolve_trino_modified_type(data_type, expected_data_type):
assert (
resolve_trino_modified_type(data_type)
== TRINO_SQL_TYPES_MAP[expected_data_type]
)


@pytest.mark.parametrize(
"data_type, expected_data_type",
[
("boolean", "boolean"),
("tinyint", "tinyint"),
("smallint", "smallint"),
("int", "int"),
("integer", "integer"),
("bigint", "bigint"),
("float", "float"),
("double", "double"),
("decimal(10,0)", "decimal"),
("varchar(20)", "varchar"),
("char", "char"),
("binary", "binary"),
("date", "date"),
("timestamp", "timestamp"),
("timestamp(3)", "timestamp"),
("struct<x timestamp(3), y timestamp>", "struct"),
("array<struct<x bigint, y double>>", "array"),
("map<varchar, varchar>", "map"),
],
)
def test_resolve_athena_modified_type(data_type, expected_data_type):
assert (
resolve_athena_modified_type(data_type)
== ATHENA_SQL_TYPES_MAP[expected_data_type]
)


@pytest.mark.integration
@freeze_time(FROZEN_TIME)
def test_dbt_tests_only_assertions(
78 changes: 78 additions & 0 deletions metadata-ingestion/tests/unit/test_sql_types.py
@@ -0,0 +1,78 @@
import pytest

from datahub.ingestion.source.sql.sql_types import (
ATHENA_SQL_TYPES_MAP,
TRINO_SQL_TYPES_MAP,
resolve_athena_modified_type,
resolve_sql_type,
resolve_trino_modified_type,
)
from datahub.metadata.schema_classes import BooleanTypeClass, StringTypeClass


@pytest.mark.parametrize(
"data_type, expected_data_type",
[
("boolean", "boolean"),
("tinyint", "tinyint"),
("smallint", "smallint"),
("int", "int"),
("integer", "integer"),
("bigint", "bigint"),
("real", "real"),
("double", "double"),
("decimal(10,0)", "decimal"),
("varchar(20)", "varchar"),
("char", "char"),
("varbinary", "varbinary"),
("json", "json"),
("date", "date"),
("time", "time"),
("time(12)", "time"),
("timestamp", "timestamp"),
("timestamp(3)", "timestamp"),
("row(x bigint, y double)", "row"),
("array(row(x bigint, y double))", "array"),
("map(varchar, varchar)", "map"),
],
)
def test_resolve_trino_modified_type(data_type, expected_data_type):
assert (
resolve_trino_modified_type(data_type)
== TRINO_SQL_TYPES_MAP[expected_data_type]
)


@pytest.mark.parametrize(
"data_type, expected_data_type",
[
("boolean", "boolean"),
("tinyint", "tinyint"),
("smallint", "smallint"),
("int", "int"),
("integer", "integer"),
("bigint", "bigint"),
("float", "float"),
("double", "double"),
("decimal(10,0)", "decimal"),
("varchar(20)", "varchar"),
("char", "char"),
("binary", "binary"),
("date", "date"),
("timestamp", "timestamp"),
("timestamp(3)", "timestamp"),
("struct<x timestamp(3), y timestamp>", "struct"),
("array<struct<x bigint, y double>>", "array"),
("map<varchar, varchar>", "map"),
],
)
def test_resolve_athena_modified_type(data_type, expected_data_type):
assert (
resolve_athena_modified_type(data_type)
== ATHENA_SQL_TYPES_MAP[expected_data_type]
)


def test_resolve_sql_type() -> None:
assert resolve_sql_type("boolean") == BooleanTypeClass()
assert resolve_sql_type("varchar") == StringTypeClass()