Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ingest): fix handling of unions with aliases in post restli conversion #7058

Merged
merged 4 commits into from
Jan 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 40 additions & 9 deletions metadata-ingestion/src/datahub/emitter/serialization_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,54 @@
from typing import Any


def _json_transform(obj: Any, from_pattern: str, to_pattern: str) -> Any:
def _json_transform(obj: Any, from_pattern: str, to_pattern: str, pre: bool) -> Any:
if isinstance(obj, (dict, OrderedDict)):
if len(obj.keys()) == 1:
key: str = list(obj.keys())[0]
value = obj[key]
if key.startswith(from_pattern):
new_key = key.replace(from_pattern, to_pattern, 1)
return {new_key: _json_transform(value, from_pattern, to_pattern)}
return {
new_key: _json_transform(value, from_pattern, to_pattern, pre=pre)
}

if "fieldDiscriminator" in obj:
# Field discriminators are used for unions between primitive types.
# Avro uses "fieldDiscriminator" for unions between primitive types, while
# rest.li simply uses the field names. We have to add special handling for this.
if pre and "fieldDiscriminator" in obj:
# On the way out, we need to remove the field discriminator.
field = obj["fieldDiscriminator"]
return {field: _json_transform(obj[field], from_pattern, to_pattern)}
return {
field: _json_transform(obj[field], from_pattern, to_pattern, pre=pre)
}
if (
hsheth2 marked this conversation as resolved.
Show resolved Hide resolved
hsheth2 marked this conversation as resolved.
Show resolved Hide resolved
not pre
and set(obj.keys()) == {"cost", "costType"}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this check continue to work if we evolve the schema of the CostModel in the future?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it won't keep working.

I've added a bit more testing that will alert us any time the assumptions made here are violated by changes in the model.

and isinstance(obj["cost"], dict)
and len(obj["cost"].keys()) == 1
and list(obj["cost"].keys())[0] in {"costId", "costCode"}
):
# On the way in, we need to add back the field discriminator.
# Note that "CostCostClass" is the only usage of fieldDiscriminator in our models,
# so this works ok.
return {
"cost": {
**obj["cost"],
"fieldDiscriminator": list(obj["cost"].keys())[0],
},
"costType": obj["costType"],
}

new_obj: Any = {
key: _json_transform(value, from_pattern, to_pattern)
key: _json_transform(value, from_pattern, to_pattern, pre=pre)
for key, value in obj.items()
if value is not None
}

return new_obj
elif isinstance(obj, list):
new_obj = [_json_transform(item, from_pattern, to_pattern) for item in obj]
new_obj = [
_json_transform(item, from_pattern, to_pattern, pre=pre) for item in obj
]
return new_obj
elif isinstance(obj, bytes):
return obj.decode()
Expand All @@ -34,12 +59,18 @@ def _json_transform(obj: Any, from_pattern: str, to_pattern: str) -> Any:
def pre_json_transform(obj: Any) -> Any:
"""Usually called before sending avro-serialized json over to the rest.li server"""
return _json_transform(
obj, from_pattern="com.linkedin.pegasus2avro.", to_pattern="com.linkedin."
obj,
from_pattern="com.linkedin.pegasus2avro.",
to_pattern="com.linkedin.",
pre=True,
)


def post_json_transform(obj: Any) -> Any:
"""Usually called after receiving restli-serialized json before instantiating into avro-generated Python classes"""
return _json_transform(
obj, from_pattern="com.linkedin.", to_pattern="com.linkedin.pegasus2avro."
obj,
from_pattern="com.linkedin.",
to_pattern="com.linkedin.pegasus2avro.",
pre=False,
)
57 changes: 57 additions & 0 deletions metadata-ingestion/tests/unit/serde/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import datahub.metadata.schema_classes as models
from datahub.cli.json_file import check_mce_file
from datahub.emitter import mce_builder
from datahub.emitter.serialization_helper import post_json_transform, pre_json_transform
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.file import FileSourceConfig, GenericFileSource
from datahub.metadata.schema_classes import (
Expand Down Expand Up @@ -199,6 +200,9 @@ def test_field_discriminator() -> None:

assert cost_object.validate()

redo = models.CostClass.from_obj(cost_object.to_obj())
assert redo == cost_object


def test_type_error() -> None:
dataflow = models.DataFlowSnapshotClass(
Expand Down Expand Up @@ -328,3 +332,56 @@ def test_write_optional_empty_dict() -> None:

out = json.dumps(model.to_obj())
assert out == '{"type": "SUCCESS", "nativeResults": {}}'


@pytest.mark.parametrize(
    "model,ref_server_obj",
    [
        (
            models.MLModelSnapshotClass(
                urn="urn:li:mlModel:(urn:li:dataPlatform:science,scienceModel,PROD)",
                aspects=[
                    models.CostClass(
                        costType=models.CostTypeClass.ORG_COST_TYPE,
                        cost=models.CostCostClass(
                            fieldDiscriminator=models.CostCostDiscriminatorClass.costCode,
                            costCode="sampleCostCode",
                        ),
                    )
                ],
            ),
            {
                "urn": "urn:li:mlModel:(urn:li:dataPlatform:science,scienceModel,PROD)",
                "aspects": [
                    {
                        "com.linkedin.common.Cost": {
                            "costType": "ORG_COST_TYPE",
                            "cost": {"costCode": "sampleCostCode"},
                        }
                    }
                ],
            },
        ),
    ],
)
def test_json_transforms(model, ref_server_obj):
    """Round-trip a model through the pre/post JSON transforms.

    The pre-transformed object must equal the reference server-side payload,
    and post-transforming that payload must yield an object from which an
    equal model instance can be rebuilt.
    """
    # Outbound direction: model object -> server-side JSON shape.
    outbound = pre_json_transform(model.to_obj())
    assert outbound == ref_server_obj

    # Inbound direction: server-side JSON shape -> object loadable by from_obj.
    inbound = post_json_transform(outbound)
    model_cls = type(model)
    assert model_cls.from_obj(inbound) == model


def test_only_cost_has_field_discriminator():
    """Check that CostCostClass is the only schema class with a fieldDiscriminator.

    The JSON serialization helpers have special handling for field
    discriminators that assumes no other class defines one, so this test
    fails as soon as another schema class gains that attribute.
    """
    for schema_cls in set(models.__SCHEMA_TYPES.values()):
        # CostCostClass is the one known, expected user of fieldDiscriminator.
        is_expected_user = schema_cls is models.CostCostClass
        if not is_expected_user and hasattr(schema_cls, "fieldDiscriminator"):
            raise ValueError(f"{schema_cls} has a fieldDiscriminator")
36 changes: 0 additions & 36 deletions metadata-ingestion/tests/unit/test_rest_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,42 +78,6 @@
"systemMetadata": {},
},
),
(
# Verify the behavior of the fieldDiscriminator for primitive enums.
models.MetadataChangeEventClass(
proposedSnapshot=models.MLModelSnapshotClass(
urn="urn:li:mlModel:(urn:li:dataPlatform:science,scienceModel,PROD)",
aspects=[
models.CostClass(
costType=models.CostTypeClass.ORG_COST_TYPE,
cost=models.CostCostClass(
fieldDiscriminator=models.CostCostDiscriminatorClass.costCode,
costCode="sampleCostCode",
),
)
],
)
),
"/entities?action=ingest",
{
"entity": {
"value": {
"com.linkedin.metadata.snapshot.MLModelSnapshot": {
"urn": "urn:li:mlModel:(urn:li:dataPlatform:science,scienceModel,PROD)",
"aspects": [
{
"com.linkedin.common.Cost": {
"costType": "ORG_COST_TYPE",
"cost": {"costCode": "sampleCostCode"},
}
}
],
}
}
},
"systemMetadata": {},
},
),
(
# Verify the serialization behavior with chart type enums.
models.MetadataChangeEventClass(
Expand Down