From 00f89bf267fe98f3ac288c4bf65fbed085440666 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Fri, 14 Oct 2022 21:48:44 +0000 Subject: [PATCH] feat(ingest): replace base85's pickle with json (#6178) --- .../ingestion/source/state/checkpoint.py | 51 +++++- .../state/test_checkpoint.py | 165 ++++++++++++------ 2 files changed, 156 insertions(+), 60 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/checkpoint.py b/metadata-ingestion/src/datahub/ingestion/source/state/checkpoint.py index 36238842f9ddcc..7c233a3a39da73 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state/checkpoint.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state/checkpoint.py @@ -18,6 +18,8 @@ logger: logging.Logger = logging.getLogger(__name__) +DEFAULT_MAX_STATE_SIZE = 2**22 # 4MB + class CheckpointStateBase(ConfigModel): """ @@ -28,17 +30,14 @@ class CheckpointStateBase(ConfigModel): """ version: str = pydantic.Field(default="1.0") - serde: str = pydantic.Field(default="base85") + serde: str = pydantic.Field(default="base85-bz2-json") def to_bytes( self, compressor: Callable[[bytes], bytes] = functools.partial( bz2.compress, compresslevel=9 ), - # fmt: off - # 4 MB - max_allowed_state_size: int = 2**22, - # fmt: on + max_allowed_state_size: int = DEFAULT_MAX_STATE_SIZE, ) -> bytes: """ NOTE: Binary compression cannot be turned on yet as the current MCPs encode the GeneralizedAspect @@ -50,7 +49,13 @@ def to_bytes( if self.serde == "utf-8": encoded_bytes = CheckpointStateBase._to_bytes_utf8(self) elif self.serde == "base85": - encoded_bytes = CheckpointStateBase._to_bytes_base85(self, compressor) + # The original base85 implementation used pickle, which would cause + # issues with deserialization if we ever changed the state class definition. + raise ValueError( + "Cannot write base85 encoded bytes. Use base85-bz2-json instead." + ) + elif self.serde == "base85-bz2-json": + encoded_bytes = CheckpointStateBase._to_bytes_base85_json(self, compressor) else: raise ValueError(f"Unknown serde: {self.serde}") @@ -66,10 +71,10 @@ def _to_bytes_utf8(model: ConfigModel) -> bytes: return model.json(exclude={"version", "serde"}).encode("utf-8") @staticmethod - def _to_bytes_base85( + def _to_bytes_base85_json( model: ConfigModel, compressor: Callable[[bytes], bytes] ) -> bytes: - return base64.b85encode(compressor(pickle.dumps(model))) + return base64.b85encode(compressor(CheckpointStateBase._to_bytes_utf8(model))) def prepare_for_commit(self) -> None: """ @@ -125,6 +130,12 @@ def create_from_checkpoint_aspect( state_obj = Checkpoint._from_base85_bytes( checkpoint_aspect, functools.partial(bz2.decompress) ) + elif checkpoint_aspect.state.serde == "base85-bz2-json": + state_obj = Checkpoint._from_base85_json_bytes( + checkpoint_aspect, + functools.partial(bz2.decompress), + state_class, + ) else: raise ValueError(f"Unknown serde: {checkpoint_aspect.state.serde}") except Exception as e: @@ -167,10 +178,32 @@ def _from_base85_bytes( checkpoint_aspect: DatahubIngestionCheckpointClass, decompressor: Callable[[bytes], bytes], ) -> StateType: - return pickle.loads( + state: StateType = pickle.loads( decompressor(base64.b85decode(checkpoint_aspect.state.payload)) # type: ignore ) + # Because the base85 method is deprecated in favor of base85-bz2-json, + # we will automatically switch the serde. + state.serde = "base85-bz2-json" + + return state + + @staticmethod + def _from_base85_json_bytes( + checkpoint_aspect: DatahubIngestionCheckpointClass, + decompressor: Callable[[bytes], bytes], + state_class: Type[StateType], + ) -> StateType: + state_uncompressed = decompressor( + base64.b85decode(checkpoint_aspect.state.payload) + if checkpoint_aspect.state.payload is not None + else b"{}" + ) + state_as_dict = json.loads(state_uncompressed.decode("utf-8")) + state_as_dict["version"] = checkpoint_aspect.state.formatVersion + state_as_dict["serde"] = checkpoint_aspect.state.serde + return state_class.parse_obj(state_as_dict) + def to_checkpoint_aspect( self, max_allowed_state_size: int ) -> Optional[DatahubIngestionCheckpointClass]: diff --git a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_checkpoint.py b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_checkpoint.py index 1f44a5cfddc10a..b07eb9dca3d603 100644 --- a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_checkpoint.py +++ b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_checkpoint.py @@ -1,6 +1,7 @@ from datetime import datetime -from typing import Dict +from typing import Dict, List +import pydantic import pytest from datahub.emitter.mce_builder import make_dataset_urn @@ -23,54 +24,18 @@ test_run_id: str = "test_run_1" test_source_config: BasicSQLAlchemyConfig = PostgresConfig(host_port="test_host:1234") -# 2. Create the params for parametrized tests. - -# 2.1 Create and add an instance of BaseSQLAlchemyCheckpointState. -test_checkpoint_serde_params: Dict[str, CheckpointStateBase] = {} -base_sql_alchemy_checkpoint_state_obj = BaseSQLAlchemyCheckpointState() -base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn( - type="table", urn=make_dataset_urn("mysql", "db1.t1", "prod") -) -base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn( - type="view", urn=make_dataset_urn("mysql", "db1.v1", "prod") -) -test_checkpoint_serde_params[ - "BaseSQLAlchemyCheckpointState" -] = base_sql_alchemy_checkpoint_state_obj - -# 2.2 Create and add an instance of BaseUsageCheckpointState. -base_usage_checkpoint_state_obj = BaseUsageCheckpointState( - version="2.0", begin_timestamp_millis=1, end_timestamp_millis=100 -) -test_checkpoint_serde_params[ - "BaseUsageCheckpointState" -] = base_usage_checkpoint_state_obj - - -# 3. Define the test with the params - -@pytest.mark.parametrize( - "state_obj", - test_checkpoint_serde_params.values(), - ids=test_checkpoint_serde_params.keys(), -) -def test_create_from_checkpoint_aspect(state_obj): - """ - Tests the Checkpoint class API 'create_from_checkpoint_aspect' with the state_obj parameter as the state. - """ - # 1. Construct the raw aspect object with the state - checkpoint_state = IngestionCheckpointStateClass( - formatVersion=state_obj.version, - serde=state_obj.serde, - payload=state_obj.to_bytes(), - ) +def _assert_checkpoint_deserialization( + serialized_checkpoint_state: IngestionCheckpointStateClass, + expected_checkpoint_state: CheckpointStateBase, +) -> Checkpoint: + # Serialize a checkpoint aspect with the previous state. checkpoint_aspect = DatahubIngestionCheckpointClass( timestampMillis=int(datetime.utcnow().timestamp() * 1000), pipelineName=test_pipeline_name, platformInstanceId=test_platform_instance_id, config=test_source_config.json(), - state=checkpoint_state, + state=serialized_checkpoint_state, runId=test_run_id, ) @@ -78,7 +43,7 @@ def test_create_from_checkpoint_aspect(state_obj): checkpoint_obj = Checkpoint.create_from_checkpoint_aspect( job_name=test_job_name, checkpoint_aspect=checkpoint_aspect, - state_class=type(state_obj), + state_class=type(expected_checkpoint_state), config_class=PostgresConfig, ) @@ -88,15 +53,69 @@ def test_create_from_checkpoint_aspect(state_obj): platform_instance_id=test_platform_instance_id, run_id=test_run_id, config=test_source_config, - state=state_obj, + state=expected_checkpoint_state, ) assert checkpoint_obj == expected_checkpoint_obj + return checkpoint_obj + + +# 2. Create the params for parametrized tests. + + +def _make_sql_alchemy_checkpoint_state() -> BaseSQLAlchemyCheckpointState: + base_sql_alchemy_checkpoint_state_obj = BaseSQLAlchemyCheckpointState() + base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn( + type="table", urn=make_dataset_urn("mysql", "db1.t1", "prod") + ) + base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn( + type="view", urn=make_dataset_urn("mysql", "db1.v1", "prod") + ) + return base_sql_alchemy_checkpoint_state_obj + + +def _make_usage_checkpoint_state() -> BaseUsageCheckpointState: + base_usage_checkpoint_state_obj = BaseUsageCheckpointState( + version="2.0", begin_timestamp_millis=1, end_timestamp_millis=100 + ) + return base_usage_checkpoint_state_obj + + +_checkpoint_aspect_test_cases: Dict[str, CheckpointStateBase] = { + # An instance of BaseSQLAlchemyCheckpointState. + "BaseSQLAlchemyCheckpointState": _make_sql_alchemy_checkpoint_state(), + # An instance of BaseUsageCheckpointState. + "BaseUsageCheckpointState": _make_usage_checkpoint_state(), +} + + +# 3. Define the test with the params + + +@pytest.mark.parametrize( + "state_obj", + _checkpoint_aspect_test_cases.values(), + ids=_checkpoint_aspect_test_cases.keys(), +) +def test_checkpoint_serde(state_obj: CheckpointStateBase) -> None: + """ + Tests CheckpointStateBase.to_bytes() and Checkpoint.create_from_checkpoint_aspect(). + """ + + # 1. Construct the raw aspect object with the state + checkpoint_state = IngestionCheckpointStateClass( + formatVersion=state_obj.version, + serde=state_obj.serde, + payload=state_obj.to_bytes(), + ) + + _assert_checkpoint_deserialization(checkpoint_state, state_obj) + @pytest.mark.parametrize( "state_obj", - test_checkpoint_serde_params.values(), - ids=test_checkpoint_serde_params.keys(), + _checkpoint_aspect_test_cases.values(), + ids=_checkpoint_aspect_test_cases.keys(), ) def test_serde_idempotence(state_obj): """ @@ -114,9 +133,7 @@ def test_serde_idempotence(state_obj): # 2. Convert it to the aspect form. checkpoint_aspect = orig_checkpoint_obj.to_checkpoint_aspect( - # fmt: off max_allowed_state_size=2**20 - # fmt: on ) assert checkpoint_aspect is not None @@ -132,7 +149,7 @@ def test_serde_idempotence(state_obj): def test_supported_encodings(): """ - Tests utf-8 and base85 encodings + Tests utf-8 and base85-bz2-json encodings """ test_state = BaseUsageCheckpointState( version="1.0", begin_timestamp_millis=1, end_timestamp_millis=100 @@ -143,5 +160,51 @@ def test_supported_encodings(): test_serde_idempotence(test_state) # 2. Test Base85 encoding - test_state.serde = "base85" + test_state.serde = "base85-bz2-json" test_serde_idempotence(test_state) + + +def test_base85_upgrade_pickle_to_json(): + """Verify that base85 (pickle) encoding is transitioned to base85-bz2-json.""" + + base85_payload = b"LRx4!F+o`-Q&~9zyaE6Km;c~@!8ry1Vd6kI1ULe}@BgM?1daeO0O_j`RP>&v5Eub8X^>>mqalb7C^byc8UsjrKmgDKAR1|q0#p(YC>k_rkk9}C0g>tf5XN6Ukbt0I-PV9G8w@zi7T+Sfbo$@HCtElKF-WJ9s~2<3(ryuxT}MN0DW*v>5|o${#bF{|bU_>|0pOAXZ$h9H+K5Hnfao6Yg3#ELVd4P5nvyX?j>N&ZHgqcR1Zc?#LWa^1m=n None: + class PrevState(CheckpointStateBase): + list_a: List[str] + list_b: List[str] + + class NextState(CheckpointStateBase): + list_stuff: List[str] + + @pydantic.root_validator(pre=True, allow_reuse=True) + def _migrate(cls, values: dict) -> dict: + values.setdefault("list_stuff", []) + values["list_stuff"] += values.pop("list_a", []) + values["list_stuff"] += values.pop("list_b", []) + return values + + prev_state = PrevState(list_a=["a", "b"], list_b=["c", "d"], serde=serde) + expected_next_state = NextState(list_stuff=["a", "b", "c", "d"], serde=serde) + + checkpoint_state = IngestionCheckpointStateClass( + formatVersion=prev_state.version, + serde=prev_state.serde, + payload=prev_state.to_bytes(), + ) + + _assert_checkpoint_deserialization(checkpoint_state, expected_next_state)