diff --git a/Cargo.toml b/Cargo.toml index 022be42bca58..384a8d37fa30 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -88,6 +88,7 @@ arrow-data = { version = ">=41", default-features = false } arrow-schema = { version = ">=41", default-features = false } parquet2 = { version = "0.17.2", features = ["async"], default-features = false } avro-schema = { version = "0.3" } +zstd = "0.13" [workspace.dependencies.arrow] package = "polars-arrow" diff --git a/crates/polars-arrow/Cargo.toml b/crates/polars-arrow/Cargo.toml index d2bbce0b5d38..7dd8db37e4d1 100644 --- a/crates/polars-arrow/Cargo.toml +++ b/crates/polars-arrow/Cargo.toml @@ -46,7 +46,7 @@ hex = { workspace = true, optional = true } # for IPC compression lz4 = { version = "1.24", optional = true } -zstd = { version = "0.13", optional = true } +zstd = { workspace = true, optional = true } # to write to parquet as a stream futures = { workspace = true, optional = true } diff --git a/crates/polars-io/Cargo.toml b/crates/polars-io/Cargo.toml index 97999970e27a..6e10215b168a 100644 --- a/crates/polars-io/Cargo.toml +++ b/crates/polars-io/Cargo.toml @@ -46,6 +46,7 @@ smartstring = { workspace = true } tokio = { workspace = true, features = ["net", "rt-multi-thread", "time", "sync"], optional = true } tokio-util = { workspace = true, features = ["io", "io-util"], optional = true } url = { workspace = true, optional = true } +zstd = { workspace = true, optional = true } [target.'cfg(not(target_family = "wasm"))'.dependencies] home = "0.5.4" @@ -73,8 +74,8 @@ ipc_streaming = ["arrow/io_ipc", "arrow/io_ipc_compression"] # support for arrow avro parsing avro = ["arrow/io_avro", "arrow/io_avro_compression"] csv = ["lexical", "polars-core/rows", "itoa", "ryu", "fast-float", "simdutf8"] -decompress = ["flate2/rust_backend"] -decompress-fast = ["flate2/zlib-ng"] +decompress = ["flate2/rust_backend", "zstd"] +decompress-fast = ["flate2/zlib-ng", "zstd"] dtype-categorical = ["polars-core/dtype-categorical"] dtype-date = ["polars-core/dtype-date", "polars-time/dtype-date"] object = [] diff --git a/crates/polars-io/src/csv/utils.rs b/crates/polars-io/src/csv/utils.rs index e7f6ea9ea27e..db42b01a341b 100644 --- a/crates/polars-io/src/csv/utils.rs +++ b/crates/polars-io/src/csv/utils.rs @@ -507,6 +507,7 @@ const GZIP: [u8; 2] = [31, 139]; const ZLIB0: [u8; 2] = [0x78, 0x01]; const ZLIB1: [u8; 2] = [0x78, 0x9C]; const ZLIB2: [u8; 2] = [0x78, 0xDA]; +const ZSTD: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD]; /// check if csv file is compressed pub fn is_compressed(bytes: &[u8]) -> bool { @@ -514,6 +515,7 @@ pub fn is_compressed(bytes: &[u8]) -> bool { || bytes.starts_with(&ZLIB1) || bytes.starts_with(&ZLIB2) || bytes.starts_with(&GZIP) + || bytes.starts_with(&ZSTD) } #[cfg(any(feature = "decompress", feature = "decompress-fast"))] @@ -603,6 +605,9 @@ pub(crate) fn decompress( } else if bytes.starts_with(&ZLIB0) || bytes.starts_with(&ZLIB1) || bytes.starts_with(&ZLIB2) { let mut decoder = flate2::read::ZlibDecoder::new(bytes); decompress_impl(&mut decoder, n_rows, separator, quote_char, eol_char) + } else if bytes.starts_with(&ZSTD) { + let mut decoder = zstd::Decoder::new(bytes).ok()?; + decompress_impl(&mut decoder, n_rows, separator, quote_char, eol_char) } else { None } diff --git a/py-polars/Cargo.lock b/py-polars/Cargo.lock index 4c91a4ea7f52..f93324027776 100644 --- a/py-polars/Cargo.lock +++ b/py-polars/Cargo.lock @@ -1709,6 +1709,7 @@ dependencies = [ "tokio", "tokio-util", "url", + "zstd 0.13.0", ] [[package]] diff --git a/py-polars/requirements-dev.txt b/py-polars/requirements-dev.txt index f63137249749..bad0ea1630df 100644 --- a/py-polars/requirements-dev.txt +++ b/py-polars/requirements-dev.txt @@ -45,6 +45,8 @@ deltalake == 0.10.1 # Dataframe interchange protocol dataframe-api-compat >= 0.1.6 pyiceberg >= 0.5.0 +# Csv +zstandard # Other matplotlib gevent diff --git a/py-polars/tests/unit/io/files/zstd_compressed.csv.zst b/py-polars/tests/unit/io/files/zstd_compressed.csv.zst new file mode 100644 index 000000000000..ae05d77685f6 Binary files /dev/null and b/py-polars/tests/unit/io/files/zstd_compressed.csv.zst differ diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py index a47ac805847a..0f44f7754826 100644 --- a/py-polars/tests/unit/io/test_csv.py +++ b/py-polars/tests/unit/io/test_csv.py @@ -11,6 +11,7 @@ import numpy as np import pyarrow as pa import pytest +import zstandard import polars as pl from polars.exceptions import ComputeError, NoDataError @@ -500,6 +501,16 @@ def test_compressed_csv(io_files_path: Path) -> None: ) assert_frame_equal(out, expected) + # zstd compression + csv_bytes = zstandard.compress(csv.encode()) + out = pl.read_csv(csv_bytes) + assert_frame_equal(out, expected) + + # zstd compressed file + csv_file = io_files_path / "zstd_compressed.csv.zst" + out = pl.read_csv(str(csv_file), truncate_ragged_lines=True) + assert_frame_equal(out, expected) + # no compression f2 = io.BytesIO(b"a,b\n1,2\n") out2 = pl.read_csv(f2) @@ -517,6 +528,12 @@ def test_partial_decompression(foods_file_path: Path) -> None: out = pl.read_csv(csv_bytes, n_rows=n_rows) assert out.shape == (n_rows, 4) + # zstd compression + csv_bytes = zstandard.compress(foods_file_path.read_bytes()) + for n_rows in [1, 5, 26]: + out = pl.read_csv(csv_bytes, n_rows=n_rows) + assert out.shape == (n_rows, 4) + def test_empty_bytes() -> None: b = b""