feat: refactor WriterProperties class (#2030)
# Description
- Add better type hints
- Add a better description of the compression levels
- Gracefully handle wrong compression levels, since the Rust parquet errors were
  not clear enough (see the usage sketch after this list)
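
For illustration, a minimal usage sketch (assuming a deltalake build that includes this change): an out-of-range level is now rejected with a clear Python `ValueError` before it ever reaches the Rust parquet writer.

```python
from deltalake import WriterProperties

# Valid: ZSTD accepts levels 1-22, so this constructs fine.
wp = WriterProperties(compression="ZSTD", compression_level=3)

# Invalid: level 0 is outside ZSTD's 1-22 range, so a clear ValueError
# is raised here in Python instead of an opaque Rust parquet error.
try:
    WriterProperties(compression="ZSTD", compression_level=0)
except ValueError as err:
    print(err)  # Compression level for ZSTD should fall between 1-22
```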
ion-elgreco authored Jan 4, 2024
1 parent 22e6fea commit 7981b95
Showing 3 changed files with 182 additions and 12 deletions.
99 changes: 88 additions & 11 deletions python/deltalake/table.py
@@ -3,6 +3,7 @@
import warnings
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from functools import reduce
from pathlib import Path
from typing import (
@@ -51,6 +52,59 @@
MAX_SUPPORTED_WRITER_VERSION = 2


class Compression(Enum):
UNCOMPRESSED = "UNCOMPRESSED"
SNAPPY = "SNAPPY"
GZIP = "GZIP"
BROTLI = "BROTLI"
LZ4 = "LZ4"
ZSTD = "ZSTD"
LZ4_RAW = "LZ4_RAW"

@classmethod
def from_str(cls, value: str) -> "Compression":
try:
return cls(value.upper())
except ValueError:
raise ValueError(
f"{value} is not a valid Compression. Valid values are: {[item.value for item in Compression]}"
)

def get_level_range(self) -> Tuple[int, int]:
if self == Compression.GZIP:
MIN_LEVEL = 0
MAX_LEVEL = 10
elif self == Compression.BROTLI:
MIN_LEVEL = 0
MAX_LEVEL = 11
elif self == Compression.ZSTD:
MIN_LEVEL = 1
MAX_LEVEL = 22
else:
raise KeyError(f"{self.value} does not have a compression level.")
return MIN_LEVEL, MAX_LEVEL

def get_default_level(self) -> int:
if self == Compression.GZIP:
DEFAULT = 6
elif self == Compression.BROTLI:
DEFAULT = 1
elif self == Compression.ZSTD:
DEFAULT = 1
else:
raise KeyError(f"{self.value} does not have a compression level.")
return DEFAULT

def check_valid_level(self, level: int) -> bool:
MIN_LEVEL, MAX_LEVEL = self.get_level_range()
if level < MIN_LEVEL or level > MAX_LEVEL:
raise ValueError(
f"Compression level for {self.value} should fall between {MIN_LEVEL}-{MAX_LEVEL}"
)
else:
return True

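As a side note, a small sketch of how the new Compression helpers above behave (illustrative only, not part of the diff; it assumes the enum is importable as deltalake.table.Compression):

```python
from deltalake.table import Compression

c = Compression.from_str("zstd")   # case-insensitive lookup -> Compression.ZSTD
print(c.get_level_range())         # (1, 22)
print(c.get_default_level())       # 1
print(c.check_valid_level(5))      # True

# SNAPPY takes no level, so the level helpers raise KeyError.
try:
    Compression.SNAPPY.get_default_level()
except KeyError as err:
    print(err)  # prints the KeyError message
```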

@dataclass(init=True)
class WriterProperties:
"""A Writer Properties instance for the Rust parquet writer."""
@@ -62,39 +116,62 @@ def __init__(
data_page_row_count_limit: Optional[int] = None,
write_batch_size: Optional[int] = None,
max_row_group_size: Optional[int] = None,
compression: Optional[str] = None,
compression: Optional[
Literal[
"UNCOMPRESSED",
"SNAPPY",
"GZIP",
"BROTLI",
"LZ4",
"ZSTD",
"LZ4_RAW",
]
] = None,
compression_level: Optional[int] = None,
):
"""Create a Writer Properties instance for the Rust parquet writer,
see options https://arrow.apache.org/rust/parquet/file/properties/struct.WriterProperties.html:
"""Create a Writer Properties instance for the Rust parquet writer:
Args:
data_page_size_limit: Limit DataPage size to this in bytes.
dictionary_page_size_limit: Limit the size of each DataPage to store dicts to this amount in bytes.
data_page_row_count_limit: Limit the number of rows in each DataPage.
write_batch_size: Splits internally to smaller batch size.
max_row_group_size: Max number of rows in row group.
compression: compression type
compression_level: level of compression, only relevant for subset of compression types
compression: compression type.
compression_level: Compression level; if None and the compression type takes a level, the default level is used. Only relevant for
GZIP: levels (0-10),
BROTLI: levels (0-11),
ZSTD: levels (1-22),
"""
self.data_page_size_limit = data_page_size_limit
self.dictionary_page_size_limit = dictionary_page_size_limit
self.data_page_row_count_limit = data_page_row_count_limit
self.write_batch_size = write_batch_size
self.max_row_group_size = max_row_group_size
self.compression = None

if compression_level is not None and compression is None:
raise ValueError(
"""Providing a compression level without the compression type is not possible,
please provide the compression as well."""
)

if compression in ["gzip", "brotli", "zstd"]:
if compression_level is not None:
compression = compression = f"{compression}({compression_level})"
if isinstance(compression, str):
compression_enum = Compression.from_str(compression)
if compression_enum in [
Compression.GZIP,
Compression.BROTLI,
Compression.ZSTD,
]:
if compression_level is not None:
if compression_enum.check_valid_level(compression_level):
parquet_compression = (
f"{compression_enum.value}({compression_level})"
)
else:
parquet_compression = f"{compression_enum.value}({compression_enum.get_default_level()})"
else:
raise ValueError("""Gzip, brotli, ztsd require a compression level""")
self.compression = compression
parquet_compression = compression_enum.value
self.compression = parquet_compression

def __str__(self) -> str:
return (
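To make the constructor logic above concrete, a hedged sketch of the codec strings it stores in self.compression, using the default levels defined in the Compression enum:

```python
from deltalake import WriterProperties

# Level given explicitly -> embedded in the codec string.
print(WriterProperties(compression="GZIP", compression_level=9).compression)  # GZIP(9)

# No level given -> the codec's default level is filled in.
print(WriterProperties(compression="ZSTD").compression)    # ZSTD(1)
print(WriterProperties(compression="BROTLI").compression)  # BROTLI(1)

# Codecs without a level are passed through unchanged.
print(WriterProperties(compression="SNAPPY").compression)  # SNAPPY
```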
7 changes: 6 additions & 1 deletion python/tests/conftest.py
@@ -9,7 +9,7 @@
import pyarrow as pa
import pytest

from deltalake import DeltaTable, write_deltalake
from deltalake import DeltaTable, WriterProperties, write_deltalake


def wait_till_host_is_available(host: str, timeout_sec: int = 0.5):
@@ -248,3 +248,8 @@ def sample_table():
"deleted": pa.array([False] * nrows),
}
)


@pytest.fixture()
def writer_properties():
return WriterProperties(compression="GZIP", compression_level=0)
88 changes: 88 additions & 0 deletions python/tests/test_writerproperties.py
@@ -0,0 +1,88 @@
import pathlib

import pyarrow as pa
import pyarrow.parquet as pq
import pytest

from deltalake import DeltaTable, WriterProperties, write_deltalake


def test_writer_properties_all_filled():
wp = WriterProperties(
data_page_size_limit=100,
dictionary_page_size_limit=200,
data_page_row_count_limit=300,
write_batch_size=400,
max_row_group_size=500,
compression="SNAPPY",
)

expected = {
"data_page_size_limit": "100",
"dictionary_page_size_limit": "200",
"data_page_row_count_limit": "300",
"write_batch_size": "400",
"max_row_group_size": "500",
"compression": "SNAPPY",
}

assert wp._to_dict() == expected


def test_writer_properties_lower_case_compression():
wp = WriterProperties(compression="snappy") # type: ignore

expected = {
"data_page_size_limit": None,
"dictionary_page_size_limit": None,
"data_page_row_count_limit": None,
"write_batch_size": None,
"max_row_group_size": None,
"compression": "SNAPPY",
}

assert wp._to_dict() == expected


@pytest.mark.parametrize(
"compression,expected",
[("GZIP", "GZIP(6)"), ("BROTLI", "BROTLI(1)"), ("ZSTD", "ZSTD(1)")],
)
def test_writer_properties_missing_compression_level(compression, expected):
wp = WriterProperties(compression=compression)

assert wp.compression == expected


@pytest.mark.parametrize(
"compression,wrong_level",
[
("GZIP", -1),
("GZIP", 11),
("BROTLI", -1),
("BROTLI", 12),
("ZSTD", 0),
("ZSTD", 23),
],
)
def test_writer_properties_incorrect_level_range(compression, wrong_level):
with pytest.raises(ValueError):
WriterProperties(compression=compression, compression_level=wrong_level)


def test_writer_properties_no_compression():
with pytest.raises(ValueError):
WriterProperties(compression_level=10)


def test_write_with_writerproperties(
tmp_path: pathlib.Path, sample_table: pa.Table, writer_properties: WriterProperties
):
write_deltalake(
tmp_path, sample_table, engine="rust", writer_properties=writer_properties
)

parquet_path = DeltaTable(tmp_path).file_uris()[0]
metadata = pq.read_metadata(parquet_path)

assert metadata.to_dict()["row_groups"][0]["columns"][0]["compression"] == "GZIP"
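
For completeness, a hypothetical variation of the test above that does the same round-trip check with ZSTD at an explicit level (the helper name and its tmp_path/sample_table parameters are illustrative):

```python
import pyarrow.parquet as pq
from deltalake import DeltaTable, WriterProperties, write_deltalake


def check_zstd_roundtrip(tmp_path, sample_table):
    # Write with ZSTD at an explicit level; the level must fall within 1-22.
    wp = WriterProperties(compression="ZSTD", compression_level=22)
    write_deltalake(tmp_path, sample_table, engine="rust", writer_properties=wp)

    # The codec (without its level) is recorded per column chunk in the footer.
    metadata = pq.read_metadata(DeltaTable(tmp_path).file_uris()[0])
    assert metadata.to_dict()["row_groups"][0]["columns"][0]["compression"] == "ZSTD"
```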
