Skip to content

Commit

Permalink
feat(ingest): replace base85's pickle with json (datahub-project#6178)
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 authored and cccs-tom committed Nov 18, 2022
1 parent ce6c3a7 commit 00f89bf
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

logger: logging.Logger = logging.getLogger(__name__)

DEFAULT_MAX_STATE_SIZE = 2**22 # 4MB


class CheckpointStateBase(ConfigModel):
"""
Expand All @@ -28,17 +30,14 @@ class CheckpointStateBase(ConfigModel):
"""

version: str = pydantic.Field(default="1.0")
serde: str = pydantic.Field(default="base85")
serde: str = pydantic.Field(default="base85-bz2-json")

def to_bytes(
self,
compressor: Callable[[bytes], bytes] = functools.partial(
bz2.compress, compresslevel=9
),
# fmt: off
# 4 MB
max_allowed_state_size: int = 2**22,
# fmt: on
max_allowed_state_size: int = DEFAULT_MAX_STATE_SIZE,
) -> bytes:
"""
NOTE: Binary compression cannot be turned on yet as the current MCPs encode the GeneralizedAspect
Expand All @@ -50,7 +49,13 @@ def to_bytes(
if self.serde == "utf-8":
encoded_bytes = CheckpointStateBase._to_bytes_utf8(self)
elif self.serde == "base85":
encoded_bytes = CheckpointStateBase._to_bytes_base85(self, compressor)
# The original base85 implementation used pickle, which would cause
# issues with deserialization if we ever changed the state class definition.
raise ValueError(
"Cannot write base85 encoded bytes. Use base85-bz2-json instead."
)
elif self.serde == "base85-bz2-json":
encoded_bytes = CheckpointStateBase._to_bytes_base85_json(self, compressor)
else:
raise ValueError(f"Unknown serde: {self.serde}")

Expand All @@ -66,10 +71,10 @@ def _to_bytes_utf8(model: ConfigModel) -> bytes:
return model.json(exclude={"version", "serde"}).encode("utf-8")

@staticmethod
def _to_bytes_base85(
def _to_bytes_base85_json(
model: ConfigModel, compressor: Callable[[bytes], bytes]
) -> bytes:
return base64.b85encode(compressor(pickle.dumps(model)))
return base64.b85encode(compressor(CheckpointStateBase._to_bytes_utf8(model)))

def prepare_for_commit(self) -> None:
"""
Expand Down Expand Up @@ -125,6 +130,12 @@ def create_from_checkpoint_aspect(
state_obj = Checkpoint._from_base85_bytes(
checkpoint_aspect, functools.partial(bz2.decompress)
)
elif checkpoint_aspect.state.serde == "base85-bz2-json":
state_obj = Checkpoint._from_base85_json_bytes(
checkpoint_aspect,
functools.partial(bz2.decompress),
state_class,
)
else:
raise ValueError(f"Unknown serde: {checkpoint_aspect.state.serde}")
except Exception as e:
Expand Down Expand Up @@ -167,10 +178,32 @@ def _from_base85_bytes(
checkpoint_aspect: DatahubIngestionCheckpointClass,
decompressor: Callable[[bytes], bytes],
) -> StateType:
return pickle.loads(
state: StateType = pickle.loads(
decompressor(base64.b85decode(checkpoint_aspect.state.payload)) # type: ignore
)

# Because the base85 method is deprecated in favor of base85-bz2-json,
# we will automatically switch the serde.
state.serde = "base85-bz2-json"

return state

@staticmethod
def _from_base85_json_bytes(
checkpoint_aspect: DatahubIngestionCheckpointClass,
decompressor: Callable[[bytes], bytes],
state_class: Type[StateType],
) -> StateType:
state_uncompressed = decompressor(
base64.b85decode(checkpoint_aspect.state.payload)
if checkpoint_aspect.state.payload is not None
else b"{}"
)
state_as_dict = json.loads(state_uncompressed.decode("utf-8"))
state_as_dict["version"] = checkpoint_aspect.state.formatVersion
state_as_dict["serde"] = checkpoint_aspect.state.serde
return state_class.parse_obj(state_as_dict)

def to_checkpoint_aspect(
self, max_allowed_state_size: int
) -> Optional[DatahubIngestionCheckpointClass]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import datetime
from typing import Dict
from typing import Dict, List

import pydantic
import pytest

from datahub.emitter.mce_builder import make_dataset_urn
Expand All @@ -23,62 +24,26 @@
test_run_id: str = "test_run_1"
test_source_config: BasicSQLAlchemyConfig = PostgresConfig(host_port="test_host:1234")

# 2. Create the params for parametrized tests.

# 2.1 Create and add an instance of BaseSQLAlchemyCheckpointState.
test_checkpoint_serde_params: Dict[str, CheckpointStateBase] = {}
base_sql_alchemy_checkpoint_state_obj = BaseSQLAlchemyCheckpointState()
base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn(
type="table", urn=make_dataset_urn("mysql", "db1.t1", "prod")
)
base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn(
type="view", urn=make_dataset_urn("mysql", "db1.v1", "prod")
)
test_checkpoint_serde_params[
"BaseSQLAlchemyCheckpointState"
] = base_sql_alchemy_checkpoint_state_obj

# 2.2 Create and add an instance of BaseUsageCheckpointState.
base_usage_checkpoint_state_obj = BaseUsageCheckpointState(
version="2.0", begin_timestamp_millis=1, end_timestamp_millis=100
)
test_checkpoint_serde_params[
"BaseUsageCheckpointState"
] = base_usage_checkpoint_state_obj


# 3. Define the test with the params


@pytest.mark.parametrize(
"state_obj",
test_checkpoint_serde_params.values(),
ids=test_checkpoint_serde_params.keys(),
)
def test_create_from_checkpoint_aspect(state_obj):
"""
Tests the Checkpoint class API 'create_from_checkpoint_aspect' with the state_obj parameter as the state.
"""
# 1. Construct the raw aspect object with the state
checkpoint_state = IngestionCheckpointStateClass(
formatVersion=state_obj.version,
serde=state_obj.serde,
payload=state_obj.to_bytes(),
)
def _assert_checkpoint_deserialization(
serialized_checkpoint_state: IngestionCheckpointStateClass,
expected_checkpoint_state: CheckpointStateBase,
) -> Checkpoint:
# Serialize a checkpoint aspect with the previous state.
checkpoint_aspect = DatahubIngestionCheckpointClass(
timestampMillis=int(datetime.utcnow().timestamp() * 1000),
pipelineName=test_pipeline_name,
platformInstanceId=test_platform_instance_id,
config=test_source_config.json(),
state=checkpoint_state,
state=serialized_checkpoint_state,
runId=test_run_id,
)

# 2. Create the checkpoint from the raw checkpoint aspect and validate.
checkpoint_obj = Checkpoint.create_from_checkpoint_aspect(
job_name=test_job_name,
checkpoint_aspect=checkpoint_aspect,
state_class=type(state_obj),
state_class=type(expected_checkpoint_state),
config_class=PostgresConfig,
)

Expand All @@ -88,15 +53,69 @@ def test_create_from_checkpoint_aspect(state_obj):
platform_instance_id=test_platform_instance_id,
run_id=test_run_id,
config=test_source_config,
state=state_obj,
state=expected_checkpoint_state,
)
assert checkpoint_obj == expected_checkpoint_obj

return checkpoint_obj


# 2. Create the params for parametrized tests.


def _make_sql_alchemy_checkpoint_state() -> BaseSQLAlchemyCheckpointState:
base_sql_alchemy_checkpoint_state_obj = BaseSQLAlchemyCheckpointState()
base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn(
type="table", urn=make_dataset_urn("mysql", "db1.t1", "prod")
)
base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn(
type="view", urn=make_dataset_urn("mysql", "db1.v1", "prod")
)
return base_sql_alchemy_checkpoint_state_obj


def _make_usage_checkpoint_state() -> BaseUsageCheckpointState:
base_usage_checkpoint_state_obj = BaseUsageCheckpointState(
version="2.0", begin_timestamp_millis=1, end_timestamp_millis=100
)
return base_usage_checkpoint_state_obj


_checkpoint_aspect_test_cases: Dict[str, CheckpointStateBase] = {
# An instance of BaseSQLAlchemyCheckpointState.
"BaseSQLAlchemyCheckpointState": _make_sql_alchemy_checkpoint_state(),
# An instance of BaseUsageCheckpointState.
"BaseUsageCheckpointState": _make_usage_checkpoint_state(),
}


# 3. Define the test with the params


@pytest.mark.parametrize(
"state_obj",
_checkpoint_aspect_test_cases.values(),
ids=_checkpoint_aspect_test_cases.keys(),
)
def test_checkpoint_serde(state_obj: CheckpointStateBase) -> None:
"""
Tests CheckpointStateBase.to_bytes() and Checkpoint.create_from_checkpoint_aspect().
"""

# 1. Construct the raw aspect object with the state
checkpoint_state = IngestionCheckpointStateClass(
formatVersion=state_obj.version,
serde=state_obj.serde,
payload=state_obj.to_bytes(),
)

_assert_checkpoint_deserialization(checkpoint_state, state_obj)


@pytest.mark.parametrize(
"state_obj",
test_checkpoint_serde_params.values(),
ids=test_checkpoint_serde_params.keys(),
_checkpoint_aspect_test_cases.values(),
ids=_checkpoint_aspect_test_cases.keys(),
)
def test_serde_idempotence(state_obj):
"""
Expand All @@ -114,9 +133,7 @@ def test_serde_idempotence(state_obj):

# 2. Convert it to the aspect form.
checkpoint_aspect = orig_checkpoint_obj.to_checkpoint_aspect(
# fmt: off
max_allowed_state_size=2**20
# fmt: on
)
assert checkpoint_aspect is not None

Expand All @@ -132,7 +149,7 @@ def test_serde_idempotence(state_obj):

def test_supported_encodings():
"""
Tests utf-8 and base85 encodings
Tests utf-8 and base85-bz2-json encodings
"""
test_state = BaseUsageCheckpointState(
version="1.0", begin_timestamp_millis=1, end_timestamp_millis=100
Expand All @@ -143,5 +160,51 @@ def test_supported_encodings():
test_serde_idempotence(test_state)

# 2. Test Base85 encoding
test_state.serde = "base85"
test_state.serde = "base85-bz2-json"
test_serde_idempotence(test_state)


def test_base85_upgrade_pickle_to_json():
"""Verify that base85 (pickle) encoding is transitioned to base85-bz2-json."""

base85_payload = b"LRx4!F+o`-Q&~9zyaE6Km;c~@!8ry1Vd6kI1ULe}@BgM?1daeO0O_j`RP>&v5Eub8X^>>mqalb7C^byc8UsjrKmgDKAR1|q0#p(YC>k_rkk9}C0g>tf5XN6Ukbt0I-PV9G8w@zi7T+Sfbo$@HCtElKF-WJ9s~2<3(ryuxT}MN0DW*v>5|o${#bF{|bU_>|0pOAXZ$h9H+K5Hnfao<V0t4|A&l|ECl%3a~3snn}%ap>6Y<yIr$4eZIcxS2Ig`q(J&`QRF$0_OwQfa!>g3#ELVd4P5nvyX?j>N&ZHgqcR1Zc?#LWa^1m=n<!NpoAI5xrS(_*3yB*fiuZ44Funf%Sq?N|V|85WFwtbQE8kLB%FHC-}RPDZ+$-$Q9ra"
checkpoint_state = IngestionCheckpointStateClass(
formatVersion="1.0", serde="base85", payload=base85_payload
)

checkpoint = _assert_checkpoint_deserialization(
checkpoint_state, _checkpoint_aspect_test_cases["BaseSQLAlchemyCheckpointState"]
)
assert checkpoint.state.serde == "base85-bz2-json"
assert len(checkpoint.state.to_bytes()) < len(base85_payload)


@pytest.mark.parametrize(
"serde",
["utf-8", "base85-bz2-json"],
)
def test_state_forward_compatibility(serde: str) -> None:
class PrevState(CheckpointStateBase):
list_a: List[str]
list_b: List[str]

class NextState(CheckpointStateBase):
list_stuff: List[str]

@pydantic.root_validator(pre=True, allow_reuse=True)
def _migrate(cls, values: dict) -> dict:
values.setdefault("list_stuff", [])
values["list_stuff"] += values.pop("list_a", [])
values["list_stuff"] += values.pop("list_b", [])
return values

prev_state = PrevState(list_a=["a", "b"], list_b=["c", "d"], serde=serde)
expected_next_state = NextState(list_stuff=["a", "b", "c", "d"], serde=serde)

checkpoint_state = IngestionCheckpointStateClass(
formatVersion=prev_state.version,
serde=prev_state.serde,
payload=prev_state.to_bytes(),
)

_assert_checkpoint_deserialization(checkpoint_state, expected_next_state)

0 comments on commit 00f89bf

Please sign in to comment.