Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ingest): replace base85's pickle with json #6178

Merged
merged 3 commits into from
Oct 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

logger: logging.Logger = logging.getLogger(__name__)

DEFAULT_MAX_STATE_SIZE = 2**22 # 4MB


class CheckpointStateBase(ConfigModel):
"""
Expand All @@ -28,17 +30,14 @@ class CheckpointStateBase(ConfigModel):
"""

version: str = pydantic.Field(default="1.0")
serde: str = pydantic.Field(default="base85")
serde: str = pydantic.Field(default="base85-bz2-json")

def to_bytes(
self,
compressor: Callable[[bytes], bytes] = functools.partial(
bz2.compress, compresslevel=9
),
# fmt: off
# 4 MB
max_allowed_state_size: int = 2**22,
# fmt: on
max_allowed_state_size: int = DEFAULT_MAX_STATE_SIZE,
) -> bytes:
"""
NOTE: Binary compression cannot be turned on yet as the current MCPs encode the GeneralizedAspect
Expand All @@ -50,7 +49,13 @@ def to_bytes(
if self.serde == "utf-8":
encoded_bytes = CheckpointStateBase._to_bytes_utf8(self)
elif self.serde == "base85":
encoded_bytes = CheckpointStateBase._to_bytes_base85(self, compressor)
# The original base85 implementation used pickle, which would cause
# issues with deserialization if we ever changed the state class definition.
raise ValueError(
hsheth2 marked this conversation as resolved.
Show resolved Hide resolved
"Cannot write base85 encoded bytes. Use base85-bz2-json instead."
)
elif self.serde == "base85-bz2-json":
encoded_bytes = CheckpointStateBase._to_bytes_base85_json(self, compressor)
else:
raise ValueError(f"Unknown serde: {self.serde}")

Expand All @@ -66,10 +71,10 @@ def _to_bytes_utf8(model: ConfigModel) -> bytes:
return model.json(exclude={"version", "serde"}).encode("utf-8")

@staticmethod
def _to_bytes_base85(
def _to_bytes_base85_json(
model: ConfigModel, compressor: Callable[[bytes], bytes]
) -> bytes:
return base64.b85encode(compressor(pickle.dumps(model)))
return base64.b85encode(compressor(CheckpointStateBase._to_bytes_utf8(model)))

def prepare_for_commit(self) -> None:
"""
Expand Down Expand Up @@ -125,6 +130,12 @@ def create_from_checkpoint_aspect(
state_obj = Checkpoint._from_base85_bytes(
checkpoint_aspect, functools.partial(bz2.decompress)
)
elif checkpoint_aspect.state.serde == "base85-bz2-json":
state_obj = Checkpoint._from_base85_json_bytes(
checkpoint_aspect,
functools.partial(bz2.decompress),
state_class,
)
else:
raise ValueError(f"Unknown serde: {checkpoint_aspect.state.serde}")
except Exception as e:
Expand Down Expand Up @@ -167,10 +178,32 @@ def _from_base85_bytes(
checkpoint_aspect: DatahubIngestionCheckpointClass,
decompressor: Callable[[bytes], bytes],
) -> StateType:
return pickle.loads(
state: StateType = pickle.loads(
decompressor(base64.b85decode(checkpoint_aspect.state.payload)) # type: ignore
)

# Because the base85 method is deprecated in favor of base85-bz2-json,
# we will automatically switch the serde.
state.serde = "base85-bz2-json"

return state

@staticmethod
def _from_base85_json_bytes(
checkpoint_aspect: DatahubIngestionCheckpointClass,
decompressor: Callable[[bytes], bytes],
state_class: Type[StateType],
) -> StateType:
state_uncompressed = decompressor(
base64.b85decode(checkpoint_aspect.state.payload)
if checkpoint_aspect.state.payload is not None
else b"{}"
)
state_as_dict = json.loads(state_uncompressed.decode("utf-8"))
state_as_dict["version"] = checkpoint_aspect.state.formatVersion
state_as_dict["serde"] = checkpoint_aspect.state.serde
return state_class.parse_obj(state_as_dict)

def to_checkpoint_aspect(
self, max_allowed_state_size: int
) -> Optional[DatahubIngestionCheckpointClass]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import datetime
from typing import Dict
from typing import Dict, List

import pydantic
import pytest

from datahub.emitter.mce_builder import make_dataset_urn
Expand All @@ -23,62 +24,26 @@
test_run_id: str = "test_run_1"
test_source_config: BasicSQLAlchemyConfig = PostgresConfig(host_port="test_host:1234")

# 2. Create the params for parametrized tests.

# 2.1 Create and add an instance of BaseSQLAlchemyCheckpointState.
test_checkpoint_serde_params: Dict[str, CheckpointStateBase] = {}
base_sql_alchemy_checkpoint_state_obj = BaseSQLAlchemyCheckpointState()
base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn(
type="table", urn=make_dataset_urn("mysql", "db1.t1", "prod")
)
base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn(
type="view", urn=make_dataset_urn("mysql", "db1.v1", "prod")
)
test_checkpoint_serde_params[
"BaseSQLAlchemyCheckpointState"
] = base_sql_alchemy_checkpoint_state_obj

# 2.2 Create and add an instance of BaseUsageCheckpointState.
base_usage_checkpoint_state_obj = BaseUsageCheckpointState(
version="2.0", begin_timestamp_millis=1, end_timestamp_millis=100
)
test_checkpoint_serde_params[
"BaseUsageCheckpointState"
] = base_usage_checkpoint_state_obj


# 3. Define the test with the params


@pytest.mark.parametrize(
"state_obj",
test_checkpoint_serde_params.values(),
ids=test_checkpoint_serde_params.keys(),
)
def test_create_from_checkpoint_aspect(state_obj):
"""
Tests the Checkpoint class API 'create_from_checkpoint_aspect' with the state_obj parameter as the state.
"""
# 1. Construct the raw aspect object with the state
checkpoint_state = IngestionCheckpointStateClass(
formatVersion=state_obj.version,
serde=state_obj.serde,
payload=state_obj.to_bytes(),
)
def _assert_checkpoint_deserialization(
serialized_checkpoint_state: IngestionCheckpointStateClass,
expected_checkpoint_state: CheckpointStateBase,
) -> Checkpoint:
# Serialize a checkpoint aspect with the previous state.
checkpoint_aspect = DatahubIngestionCheckpointClass(
timestampMillis=int(datetime.utcnow().timestamp() * 1000),
pipelineName=test_pipeline_name,
platformInstanceId=test_platform_instance_id,
config=test_source_config.json(),
state=checkpoint_state,
state=serialized_checkpoint_state,
runId=test_run_id,
)

# 2. Create the checkpoint from the raw checkpoint aspect and validate.
checkpoint_obj = Checkpoint.create_from_checkpoint_aspect(
job_name=test_job_name,
checkpoint_aspect=checkpoint_aspect,
state_class=type(state_obj),
state_class=type(expected_checkpoint_state),
config_class=PostgresConfig,
)

Expand All @@ -88,15 +53,69 @@ def test_create_from_checkpoint_aspect(state_obj):
platform_instance_id=test_platform_instance_id,
run_id=test_run_id,
config=test_source_config,
state=state_obj,
state=expected_checkpoint_state,
)
assert checkpoint_obj == expected_checkpoint_obj

return checkpoint_obj


# 2. Create the params for parametrized tests.


def _make_sql_alchemy_checkpoint_state() -> BaseSQLAlchemyCheckpointState:
base_sql_alchemy_checkpoint_state_obj = BaseSQLAlchemyCheckpointState()
base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn(
type="table", urn=make_dataset_urn("mysql", "db1.t1", "prod")
)
base_sql_alchemy_checkpoint_state_obj.add_checkpoint_urn(
type="view", urn=make_dataset_urn("mysql", "db1.v1", "prod")
)
return base_sql_alchemy_checkpoint_state_obj


def _make_usage_checkpoint_state() -> BaseUsageCheckpointState:
base_usage_checkpoint_state_obj = BaseUsageCheckpointState(
version="2.0", begin_timestamp_millis=1, end_timestamp_millis=100
)
return base_usage_checkpoint_state_obj


_checkpoint_aspect_test_cases: Dict[str, CheckpointStateBase] = {
# An instance of BaseSQLAlchemyCheckpointState.
"BaseSQLAlchemyCheckpointState": _make_sql_alchemy_checkpoint_state(),
# An instance of BaseUsageCheckpointState.
"BaseUsageCheckpointState": _make_usage_checkpoint_state(),
}


# 3. Define the test with the params


@pytest.mark.parametrize(
"state_obj",
_checkpoint_aspect_test_cases.values(),
ids=_checkpoint_aspect_test_cases.keys(),
)
def test_checkpoint_serde(state_obj: CheckpointStateBase) -> None:
"""
Tests CheckpointStateBase.to_bytes() and Checkpoint.create_from_checkpoint_aspect().
"""

# 1. Construct the raw aspect object with the state
checkpoint_state = IngestionCheckpointStateClass(
formatVersion=state_obj.version,
serde=state_obj.serde,
payload=state_obj.to_bytes(),
)

_assert_checkpoint_deserialization(checkpoint_state, state_obj)


@pytest.mark.parametrize(
"state_obj",
test_checkpoint_serde_params.values(),
ids=test_checkpoint_serde_params.keys(),
_checkpoint_aspect_test_cases.values(),
ids=_checkpoint_aspect_test_cases.keys(),
)
def test_serde_idempotence(state_obj):
"""
Expand All @@ -114,9 +133,7 @@ def test_serde_idempotence(state_obj):

# 2. Convert it to the aspect form.
checkpoint_aspect = orig_checkpoint_obj.to_checkpoint_aspect(
# fmt: off
max_allowed_state_size=2**20
# fmt: on
)
assert checkpoint_aspect is not None

Expand All @@ -132,7 +149,7 @@ def test_serde_idempotence(state_obj):

def test_supported_encodings():
"""
Tests utf-8 and base85 encodings
Tests utf-8 and base85-bz2-json encodings
"""
test_state = BaseUsageCheckpointState(
version="1.0", begin_timestamp_millis=1, end_timestamp_millis=100
Expand All @@ -143,5 +160,51 @@ def test_supported_encodings():
test_serde_idempotence(test_state)

# 2. Test Base85 encoding
test_state.serde = "base85"
test_state.serde = "base85-bz2-json"
test_serde_idempotence(test_state)


def test_base85_upgrade_pickle_to_json():
"""Verify that base85 (pickle) encoding is transitioned to base85-bz2-json."""

base85_payload = b"LRx4!F+o`-Q&~9zyaE6Km;c~@!8ry1Vd6kI1ULe}@BgM?1daeO0O_j`RP>&v5Eub8X^>>mqalb7C^byc8UsjrKmgDKAR1|q0#p(YC>k_rkk9}C0g>tf5XN6Ukbt0I-PV9G8w@zi7T+Sfbo$@HCtElKF-WJ9s~2<3(ryuxT}MN0DW*v>5|o${#bF{|bU_>|0pOAXZ$h9H+K5Hnfao<V0t4|A&l|ECl%3a~3snn}%ap>6Y<yIr$4eZIcxS2Ig`q(J&`QRF$0_OwQfa!>g3#ELVd4P5nvyX?j>N&ZHgqcR1Zc?#LWa^1m=n<!NpoAI5xrS(_*3yB*fiuZ44Funf%Sq?N|V|85WFwtbQE8kLB%FHC-}RPDZ+$-$Q9ra"
checkpoint_state = IngestionCheckpointStateClass(
formatVersion="1.0", serde="base85", payload=base85_payload
)

checkpoint = _assert_checkpoint_deserialization(
checkpoint_state, _checkpoint_aspect_test_cases["BaseSQLAlchemyCheckpointState"]
)
assert checkpoint.state.serde == "base85-bz2-json"
assert len(checkpoint.state.to_bytes()) < len(base85_payload)


@pytest.mark.parametrize(
"serde",
["utf-8", "base85-bz2-json"],
)
def test_state_forward_compatibility(serde: str) -> None:
class PrevState(CheckpointStateBase):
list_a: List[str]
list_b: List[str]

class NextState(CheckpointStateBase):
list_stuff: List[str]

@pydantic.root_validator(pre=True, allow_reuse=True)
def _migrate(cls, values: dict) -> dict:
values.setdefault("list_stuff", [])
values["list_stuff"] += values.pop("list_a", [])
values["list_stuff"] += values.pop("list_b", [])
return values

prev_state = PrevState(list_a=["a", "b"], list_b=["c", "d"], serde=serde)
expected_next_state = NextState(list_stuff=["a", "b", "c", "d"], serde=serde)

checkpoint_state = IngestionCheckpointStateClass(
formatVersion=prev_state.version,
serde=prev_state.serde,
payload=prev_state.to_bytes(),
)

_assert_checkpoint_deserialization(checkpoint_state, expected_next_state)