Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ingest): bigquery-beta - handling complex types properly #6062

Merged
merged 4 commits into from
Sep 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@
GlobalTagsClass,
TagAssociationClass,
)
from datahub.utilities.hive_schema_to_avro import (
HiveColumnToAvroConverter,
get_schema_fields_for_hive_column,
)
from datahub.utilities.mapping import Constants
from datahub.utilities.perf_timer import PerfTimer
from datahub.utilities.registries.domain_registry import DomainRegistry
Expand Down Expand Up @@ -785,20 +789,35 @@ def gen_dataset_urn(self, dataset_name: str, project_id: str, table: str) -> str
)
return dataset_urn

def gen_schema_metadata(
self,
dataset_urn: str,
table: Union[BigqueryTable, BigqueryView],
dataset_name: str,
) -> MetadataWorkUnit:
schema_metadata = SchemaMetadata(
schemaName=dataset_name,
platform=make_data_platform_urn(self.platform),
version=0,
hash="",
platformSchema=MySqlDDL(tableSchema=""),
fields=[
SchemaField(
def gen_schema_fields(self, columns: List[BigqueryColumn]) -> List[SchemaField]:
schema_fields: List[SchemaField] = []

HiveColumnToAvroConverter._STRUCT_TYPE_SEPARATOR = " "
_COMPLEX_TYPE = re.compile("^(struct|array)")
last_id = -1
for col in columns:

if _COMPLEX_TYPE.match(col.data_type.lower()):
# If we have seen this ordinal position before, it most likely means we already processed this complex type
if last_id != col.ordinal_position:
schema_fields.extend(
get_schema_fields_for_hive_column(
col.name, col.data_type.lower(), description=col.comment
)
)

# We have to add complex type comments to the correct level
if col.comment:
for idx, field in enumerate(schema_fields):
# Remove all the [version=2.0].[type=struct]. tags to get the field path
if (
re.sub(r"\[.*?\]\.", "", field.fieldPath, 0, re.MULTILINE)
== col.field_path
):
field.description = col.comment
schema_fields[idx] = field
else:
field = SchemaField(
fieldPath=col.name,
type=SchemaFieldDataType(
self.BIGQUERY_FIELD_TYPE_MAPPINGS.get(col.data_type, NullType)()
Expand All @@ -817,8 +836,24 @@ def gen_schema_metadata(
if col.is_partition_column
else GlobalTagsClass(tags=[]),
)
for col in table.columns
],
schema_fields.append(field)
last_id = col.ordinal_position
return schema_fields

def gen_schema_metadata(
self,
dataset_urn: str,
table: Union[BigqueryTable, BigqueryView],
dataset_name: str,
) -> MetadataWorkUnit:

schema_metadata = SchemaMetadata(
schemaName=dataset_name,
platform=make_data_platform_urn(self.platform),
version=0,
hash="",
platformSchema=MySqlDDL(tableSchema=""),
fields=self.gen_schema_fields(table.columns),
)
wu = wrap_aspect_as_workunit(
"dataset", dataset_urn, "schemaMetadata", schema_metadata
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class BigqueryTableIdentifier:
table: str

invalid_chars: ClassVar[Set[str]] = {"$", "@"}
_BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX: ClassVar[str] = "((.+)[_$])?(\\d{4,10})$"
_BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX: ClassVar[str] = "((.+)[_$])?(\\d{8})$"

@staticmethod
def get_table_and_shard(table_name: str) -> Tuple[str, Optional[str]]:
Expand All @@ -101,17 +101,10 @@ def from_string_name(cls, table: str) -> "BigqueryTableIdentifier":
def raw_table_name(self):
return f"{self.project_id}.{self.dataset}.{self.table}"

@staticmethod
def _remove_suffix(input_string: str, suffixes: List[str]) -> str:
    """Return ``input_string`` with the first matching suffix stripped.

    Suffixes are tried in order and only the first match is removed; if
    none match, the string is returned unchanged.

    Empty suffixes are skipped: ``"x".endswith("")`` is ``True`` and
    ``input_string[: -len("")]`` evaluates to ``input_string[:0]``, which
    would wrongly collapse the whole string to ``""``. The guard also
    keeps this helper consistent with ``BigQueryTableRef.remove_suffix``.
    """
    for suffix in suffixes:
        # `suffix and` guards the empty-string case described above.
        if suffix and input_string.endswith(suffix):
            return input_string[: -len(suffix)]
    return input_string

def get_table_display_name(self) -> str:
shortened_table_name = self.table
# if table name ends in _* or * then we strip it as that represents a query on a sharded table
shortened_table_name = self._remove_suffix(shortened_table_name, ["_*", "*"])
shortened_table_name = re.sub("(_(.+)?\\*)|\\*$", "", shortened_table_name)

table_name, _ = self.get_table_and_shard(shortened_table_name)
if not table_name:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
logger: logging.Logger = logging.getLogger(__name__)


@dataclass
@dataclass(frozen=True, eq=True)
class BigqueryColumn:
name: str
ordinal_position: int
field_path: str
is_nullable: bool
is_partition_column: bool
data_type: str
Expand Down Expand Up @@ -175,6 +176,7 @@ class BigqueryQuery:
c.table_name as table_name,
c.column_name as column_name,
c.ordinal_position as ordinal_position,
cfp.field_path as field_path,
c.is_nullable as is_nullable,
c.data_type as data_type,
description as comment,
Expand All @@ -194,6 +196,7 @@ class BigqueryQuery:
c.table_name as table_name,
c.column_name as column_name,
c.ordinal_position as ordinal_position,
cfp.field_path as field_path,
c.is_nullable as is_nullable,
c.data_type as data_type,
c.is_hidden as is_hidden,
Expand Down Expand Up @@ -355,6 +358,7 @@ def get_columns_for_dataset(
BigqueryColumn(
name=column.column_name,
ordinal_position=column.ordinal_position,
field_path=column.field_path,
is_nullable=column.is_nullable == "YES",
data_type=column.data_type,
comment=column.comment,
Expand All @@ -379,6 +383,7 @@ def get_columns_for_table(
name=column.column_name,
ordinal_position=column.ordinal_position,
is_nullable=column.is_nullable == "YES",
field_path=column.field_path,
data_type=column.data_type,
comment=column.comment,
is_partition_column=column.is_partitioning_column == "YES",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,19 +268,12 @@ def is_temporary_table(self, prefix: str) -> bool:
# Temporary tables will have a dataset that begins with an underscore.
return self.dataset.startswith(prefix)

@staticmethod
def remove_suffix(input_string, suffix):
if suffix and input_string.endswith(suffix):
return input_string[: -len(suffix)]
return input_string

def remove_extras(self, sharded_table_regex: str) -> "BigQueryTableRef":
# Handle partitioned and sharded tables.
table_name: Optional[str] = None
shortened_table_name = self.table
# if table name ends in _* or * then we strip it as that represents a query on a sharded table
shortened_table_name = self.remove_suffix(shortened_table_name, "_*")
shortened_table_name = self.remove_suffix(shortened_table_name, "*")
shortened_table_name = re.sub("(_(.+)?\\*)|\\*$", "", shortened_table_name)

matches = re.match(sharded_table_regex, shortened_table_name)
if matches:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from datahub.configuration.common import ConfigModel, ConfigurationError

_BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX: str = "((.+)[_$])?(\\d{4,10})$"
_BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX: str = "((.+)[_$])?(\\d{8})$"


class BigQueryBaseConfig(ConfigModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class HiveColumnToAvroConverter:
"float": "float",
"tinyint": "int",
"smallint": "int",
"int": "int",
"bigint": "long",
"varchar": "string",
"char": "string",
Expand All @@ -34,6 +33,8 @@ class HiveColumnToAvroConverter:

_FIXED_STRING = re.compile(r"(var)?char\(\s*(\d+)\s*\)")

_STRUCT_TYPE_SEPARATOR = ":"

@staticmethod
def _parse_datatype_string(
s: str, **kwargs: Any
Expand Down Expand Up @@ -103,7 +104,9 @@ def _parse_struct_fields_string(s: str, **kwargs: Any) -> Dict[str, object]:
parts = HiveColumnToAvroConverter._ignore_brackets_split(s, ",")
fields = []
for part in parts:
name_and_type = HiveColumnToAvroConverter._ignore_brackets_split(part, ":")
name_and_type = HiveColumnToAvroConverter._ignore_brackets_split(
part.strip(), HiveColumnToAvroConverter._STRUCT_TYPE_SEPARATOR
)
if len(name_and_type) != 2:
raise ValueError(
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_bigquery_ref_extra_removal():

table_ref = BigQueryTableRef("project-1234", "dataset-4567", "foo_2022")
new_table_ref = table_ref.remove_extras(_BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX)
assert new_table_ref.table == "foo"
assert new_table_ref.table == "foo_2022"
assert new_table_ref.project == table_ref.project
assert new_table_ref.dataset == table_ref.dataset

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def test_bigquery_table_sanitasitation():
new_table_ref = BigqueryTableIdentifier.from_string_name(
table_ref.table_identifier.get_table_name()
)
assert new_table_ref.table == "foo"
assert new_table_ref.table == "foo_2022"
assert new_table_ref.project_id == "project-1234"
assert new_table_ref.dataset == "dataset-4567"

Expand Down