diff --git a/metadata-ingestion/src/datahub/emitter/serialization_helper.py b/metadata-ingestion/src/datahub/emitter/serialization_helper.py index cad4e9dd3270fc..ab9402ec891887 100644 --- a/metadata-ingestion/src/datahub/emitter/serialization_helper.py +++ b/metadata-ingestion/src/datahub/emitter/serialization_helper.py @@ -1,30 +1,92 @@ from collections import OrderedDict -from typing import Any +from typing import Any, Tuple -def _json_transform(obj: Any, from_pattern: str, to_pattern: str) -> Any: +def _pre_handle_union_with_aliases( + obj: Any, + from_pattern: str, + to_pattern: str, +) -> Tuple[bool, Any]: + # PDL supports "Unions with aliases", which are unions that have a field name + # that is different from the type name. + # See https://linkedin.github.io/rest.li/pdl_schema#union-with-aliases + # When generating the avro schema, Rest.li adds a "fieldDiscriminator" field + # to disambiguate between the different union types. + + if "fieldDiscriminator" in obj: + # On the way out, we need to remove the field discriminator. + field = obj["fieldDiscriminator"] + return True, { + field: _json_transform(obj[field], from_pattern, to_pattern, pre=True) + } + + return False, None + + +def _post_handle_unions_with_aliases( + obj: Any, + from_pattern: str, + to_pattern: str, +) -> Tuple[bool, Any]: + # Note that "CostCostClass" is the only usage of a union with aliases in our + # current PDL / metadata model, so the below hardcoding works. + # Because this method is brittle and prone to becoming stale, we validate + # the aforementioned assumption in our tests. + + if ( + set(obj.keys()) == {"cost", "costType"} + and isinstance(obj["cost"], dict) + and len(obj["cost"].keys()) == 1 + and list(obj["cost"].keys())[0] in {"costId", "costCode"} + ): + # On the way in, we need to add back the field discriminator. + return True, { + "cost": { + **obj["cost"], + "fieldDiscriminator": list(obj["cost"].keys())[0], + }, + "costType": obj["costType"], + } + + return False, None + + +def _json_transform(obj: Any, from_pattern: str, to_pattern: str, pre: bool) -> Any: if isinstance(obj, (dict, OrderedDict)): if len(obj.keys()) == 1: key: str = list(obj.keys())[0] value = obj[key] if key.startswith(from_pattern): new_key = key.replace(from_pattern, to_pattern, 1) - return {new_key: _json_transform(value, from_pattern, to_pattern)} + return { + new_key: _json_transform(value, from_pattern, to_pattern, pre=pre) + } + + if pre: + handled, new_obj = _pre_handle_union_with_aliases( + obj, from_pattern, to_pattern + ) + if handled: + return new_obj - if "fieldDiscriminator" in obj: - # Field discriminators are used for unions between primitive types. - field = obj["fieldDiscriminator"] - return {field: _json_transform(obj[field], from_pattern, to_pattern)} + if not pre: + handled, new_obj = _post_handle_unions_with_aliases( + obj, from_pattern, to_pattern + ) + if handled: + return new_obj - new_obj: Any = { - key: _json_transform(value, from_pattern, to_pattern) + new_obj = { + key: _json_transform(value, from_pattern, to_pattern, pre=pre) for key, value in obj.items() if value is not None } return new_obj elif isinstance(obj, list): - new_obj = [_json_transform(item, from_pattern, to_pattern) for item in obj] + new_obj = [ + _json_transform(item, from_pattern, to_pattern, pre=pre) for item in obj + ] return new_obj elif isinstance(obj, bytes): return obj.decode() @@ -34,12 +96,18 @@ def _json_transform(obj: Any, from_pattern: str, to_pattern: str) -> Any: def pre_json_transform(obj: Any) -> Any: """Usually called before sending avro-serialized json over to the rest.li server""" return _json_transform( - obj, from_pattern="com.linkedin.pegasus2avro.", to_pattern="com.linkedin." + obj, + from_pattern="com.linkedin.pegasus2avro.", + to_pattern="com.linkedin.", + pre=True, ) def post_json_transform(obj: Any) -> Any: """Usually called after receiving restli-serialized json before instantiating into avro-generated Python classes""" return _json_transform( - obj, from_pattern="com.linkedin.", to_pattern="com.linkedin.pegasus2avro." + obj, + from_pattern="com.linkedin.", + to_pattern="com.linkedin.pegasus2avro.", + pre=False, ) diff --git a/metadata-ingestion/tests/unit/serde/test_serde.py b/metadata-ingestion/tests/unit/serde/test_serde.py index 6e598656259949..badd4e50d7d06b 100644 --- a/metadata-ingestion/tests/unit/serde/test_serde.py +++ b/metadata-ingestion/tests/unit/serde/test_serde.py @@ -12,6 +12,7 @@ import datahub.metadata.schema_classes as models from datahub.cli.json_file import check_mce_file from datahub.emitter import mce_builder +from datahub.emitter.serialization_helper import post_json_transform, pre_json_transform from datahub.ingestion.run.pipeline import Pipeline from datahub.ingestion.source.file import FileSourceConfig, GenericFileSource from datahub.metadata.schema_classes import ( @@ -199,6 +200,9 @@ def test_field_discriminator() -> None: assert cost_object.validate() + redo = models.CostClass.from_obj(cost_object.to_obj()) + assert redo == cost_object + def test_type_error() -> None: dataflow = models.DataFlowSnapshotClass( @@ -328,3 +332,66 @@ def test_write_optional_empty_dict() -> None: out = json.dumps(model.to_obj()) assert out == '{"type": "SUCCESS", "nativeResults": {}}' + + +@pytest.mark.parametrize( + "model,ref_server_obj", + [ + ( + models.MLModelSnapshotClass( + urn="urn:li:mlModel:(urn:li:dataPlatform:science,scienceModel,PROD)", + aspects=[ + models.CostClass( + costType=models.CostTypeClass.ORG_COST_TYPE, + cost=models.CostCostClass( + fieldDiscriminator=models.CostCostDiscriminatorClass.costCode, + costCode="sampleCostCode", + ), + ) + ], + ), + { + "urn": "urn:li:mlModel:(urn:li:dataPlatform:science,scienceModel,PROD)", + "aspects": [ + { + "com.linkedin.common.Cost": { + "costType": "ORG_COST_TYPE", + "cost": {"costCode": "sampleCostCode"}, + } + } + ], + }, + ), + ], +) +def test_json_transforms(model, ref_server_obj): + server_obj = pre_json_transform(model.to_obj()) + assert server_obj == ref_server_obj + + post_obj = post_json_transform(server_obj) + + recovered = type(model).from_obj(post_obj) + assert recovered == model + + +def test_unions_with_aliases_assumptions(): + # We have special handling for unions with aliases in our json serialization helpers. + # Specifically, we assume that cost is the only instance of a union with alias. + # This test validates that assumption. + + for cls in set(models.__SCHEMA_TYPES.values()): + if cls is models.CostCostClass: + continue + + if hasattr(cls, "fieldDiscriminator"): + raise ValueError(f"{cls} has a fieldDiscriminator") + + assert set(models.CostClass.RECORD_SCHEMA.fields_dict.keys()) == { + "cost", + "costType", + } + assert set(models.CostCostClass.RECORD_SCHEMA.fields_dict.keys()) == { + "fieldDiscriminator", + "costId", + "costCode", + } diff --git a/metadata-ingestion/tests/unit/test_rest_sink.py b/metadata-ingestion/tests/unit/test_rest_sink.py index 8390e8be0ed375..82e02aced5a670 100644 --- a/metadata-ingestion/tests/unit/test_rest_sink.py +++ b/metadata-ingestion/tests/unit/test_rest_sink.py @@ -78,42 +78,6 @@ "systemMetadata": {}, }, ), - ( - # Verify the behavior of the fieldDiscriminator for primitive enums. - models.MetadataChangeEventClass( - proposedSnapshot=models.MLModelSnapshotClass( - urn="urn:li:mlModel:(urn:li:dataPlatform:science,scienceModel,PROD)", - aspects=[ - models.CostClass( - costType=models.CostTypeClass.ORG_COST_TYPE, - cost=models.CostCostClass( - fieldDiscriminator=models.CostCostDiscriminatorClass.costCode, - costCode="sampleCostCode", - ), - ) - ], - ) - ), - "/entities?action=ingest", - { - "entity": { - "value": { - "com.linkedin.metadata.snapshot.MLModelSnapshot": { - "urn": "urn:li:mlModel:(urn:li:dataPlatform:science,scienceModel,PROD)", - "aspects": [ - { - "com.linkedin.common.Cost": { - "costType": "ORG_COST_TYPE", - "cost": {"costCode": "sampleCostCode"}, - } - } - ], - } - } - }, - "systemMetadata": {}, - }, - ), ( # Verify the serialization behavior with chart type enums. models.MetadataChangeEventClass(