Skip to content

Commit

Permalink
feat(ingest): implement compression for CheckpointState
Browse files Browse the repository at this point in the history
  • Loading branch information
alexey-kravtsov committed Sep 21, 2022
1 parent 7e5f44a commit 34288d6
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 22 deletions.
73 changes: 51 additions & 22 deletions metadata-ingestion/src/datahub/ingestion/source/state/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import base64
import bz2
import functools
import json
import logging
import pickle
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar
from typing import Callable, Generic, Optional, Type, TypeVar

import pydantic

Expand All @@ -26,7 +28,7 @@ class CheckpointStateBase(ConfigModel):
"""

version: str = pydantic.Field(default="1.0")
serde: str = pydantic.Field(default="utf-8")
serde: str = pydantic.Field(default="base85")

def to_bytes(
self,
Expand All @@ -45,8 +47,13 @@ def to_bytes(
binary state payload. Binary content-type needs to be supported for encoding the GenericAspect to do this.
"""

json_str_self = self.json(exclude={"version", "serde"})
encoded_bytes = json_str_self.encode("utf-8")
if self.serde == "utf-8":
encoded_bytes = CheckpointStateBase._to_bytes_utf8(self)
elif self.serde == "base85":
encoded_bytes = CheckpointStateBase._to_bytes_base85(self, compressor)
else:
raise ValueError(f"Unknown serde: {self.serde}")

if len(encoded_bytes) > max_allowed_state_size:
raise ValueError(
f"The state size has exceeded the max_allowed_state_size of {max_allowed_state_size}"
Expand All @@ -55,14 +62,14 @@ def to_bytes(
return encoded_bytes

@staticmethod
def from_bytes_to_dict(
    data_bytes: bytes, decompressor: Callable[[bytes], bytes] = bz2.decompress
) -> Dict[str, Any]:
    """Decode a UTF-8 encoded JSON payload into a Python object.

    Helper method for sub-classes to use.

    Args:
        data_bytes: UTF-8 encoded JSON text.
        decompressor: Accepted for interface compatibility only. It is
            currently NOT applied: the payload is read exactly as given,
            with no decompression step.

    Returns:
        Whatever ``json.loads`` produces for the decoded text.

    Raises:
        UnicodeDecodeError: If ``data_bytes`` is not valid UTF-8.
        json.JSONDecodeError: If the decoded text is not valid JSON.
    """
    return json.loads(data_bytes.decode("utf-8"))
def _to_bytes_utf8(model: ConfigModel) -> bytes:
    """Render ``model`` as JSON without its ``version``/``serde`` fields and encode it as UTF-8."""
    json_text = model.json(exclude={"version", "serde"})
    return json_text.encode("utf-8")

@staticmethod
def _to_bytes_base85(
    model: ConfigModel, compressor: Callable[[bytes], bytes]
) -> bytes:
    """Pickle ``model``, run the bytes through ``compressor``, and base85-encode the result."""
    pickled = pickle.dumps(model)
    compressed = compressor(pickled)
    return base64.b85encode(compressed)

def prepare_for_commit(self) -> None:
"""
Expand Down Expand Up @@ -110,17 +117,16 @@ def create_from_checkpoint_aspect(
)
else:
try:
# Construct the state
state_as_dict = (
CheckpointStateBase.from_bytes_to_dict(
checkpoint_aspect.state.payload
if checkpoint_aspect.state.serde == "utf-8":
state_obj = Checkpoint._from_utf8_bytes(
checkpoint_aspect, state_class
)
if checkpoint_aspect.state.payload is not None
else {}
)
state_as_dict["version"] = checkpoint_aspect.state.formatVersion
state_as_dict["serde"] = checkpoint_aspect.state.serde
state_obj = state_class.parse_obj(state_as_dict)
elif checkpoint_aspect.state.serde == "base85":
state_obj = Checkpoint._from_base85_bytes(
checkpoint_aspect, functools.partial(bz2.decompress)
)
else:
raise ValueError(f"Unknown serde: {checkpoint_aspect.state.serde}")
except Exception as e:
logger.error(
"Failed to construct checkpoint class from checkpoint aspect.", e
Expand All @@ -142,6 +148,29 @@ def create_from_checkpoint_aspect(
return checkpoint
return None

@staticmethod
def _from_utf8_bytes(
    checkpoint_aspect: DatahubIngestionCheckpointClass,
    state_class: Type[StateType],
) -> StateType:
    """Build a ``state_class`` instance from a UTF-8 JSON payload.

    A missing payload (``None``) is treated as an empty JSON object. The
    ``version`` and ``serde`` entries are always taken from the aspect's
    ``formatVersion`` and ``serde`` before the dict is handed to
    ``state_class.parse_obj``.
    """
    aspect_state = checkpoint_aspect.state
    raw_payload = aspect_state.payload
    fields = {} if raw_payload is None else json.loads(raw_payload.decode("utf-8"))
    fields["version"] = aspect_state.formatVersion
    fields["serde"] = aspect_state.serde
    return state_class.parse_obj(fields)

@staticmethod
def _from_base85_bytes(
    checkpoint_aspect: DatahubIngestionCheckpointClass,
    decompressor: Callable[[bytes], bytes],
) -> StateType:
    """Rebuild a state object from a base85-encoded, compressed pickle.

    Args:
        checkpoint_aspect: Aspect whose ``state.payload`` holds the
            base85 text to decode.
        decompressor: Applied to the base85-decoded bytes before
            unpickling.

    Returns:
        Whatever object the pickle stream contains; its type is not
        checked here.

    Raises:
        ValueError: If the aspect carries no payload.
    """
    payload = checkpoint_aspect.state.payload
    if payload is None:
        # Fail with a clear message instead of the opaque TypeError that
        # base64.b85decode raises when given None.
        raise ValueError(
            "Cannot deserialize a base85 checkpoint state without a payload"
        )
    # SECURITY: pickle.loads can execute arbitrary code while unpickling.
    # This is only safe if the stored payload cannot be written or altered
    # by an untrusted party -- NOTE(review): confirm that assumption, or
    # move this serde to a data-only format such as JSON.
    return pickle.loads(decompressor(base64.b85decode(payload)))

def to_checkpoint_aspect(
self, max_allowed_state_size: int
) -> Optional[DatahubIngestionCheckpointClass]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,20 @@ def test_serde_idempotence(state_obj):
config_class=MySQLConfig,
)
assert orig_checkpoint_obj == serde_checkpoint_obj


def test_supported_encodings():
    """
    Runs test_serde_idempotence with the utf-8 serde, then with base85
    """
    usage_state = BaseUsageCheckpointState(
        version="1.0", begin_timestamp_millis=1, end_timestamp_millis=100
    )

    # The same state object is reused; only its serde is switched per run.
    for serde_name in ("utf-8", "base85"):
        usage_state.serde = serde_name
        test_serde_idempotence(usage_state)

0 comments on commit 34288d6

Please sign in to comment.