Compatibility with latest dask, pyarrow and deltalake #68

Merged: 6 commits, Feb 13, 2024
49 changes: 37 additions & 12 deletions dask_deltatable/core.py
@@ -3,12 +3,14 @@
import os
from collections.abc import Sequence
from functools import partial
from typing import Any, cast
from typing import Any, Callable, cast

import dask
import dask.dataframe as dd
import pyarrow as pa
import pyarrow.parquet as pq
from dask.base import tokenize
from dask.dataframe.io.parquet.arrow import ArrowDatasetEngine
Contributor comment: This import is a bit ugly, but this class exposes our type mappers, and this way we can align behavior with dask's own parquet reader.
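For context, a hedged sketch of what this alignment buys: `ArrowDatasetEngine._determine_type_mapper` builds the same pyarrow-to-pandas `types_mapper` that `dd.read_parquet` applies, so both readers produce matching dtypes. The snippet below is illustrative only and leans on private dask API, so the exact result may shift between dask releases.

```python
import pyarrow as pa
from dask.dataframe.io.parquet.arrow import ArrowDatasetEngine

# Build the mapper the reader now applies (same call as _get_type_mapper below).
mapper = ArrowDatasetEngine._determine_type_mapper(
    dtype_backend=None,
    convert_string=True,
    arrow_to_pandas={"types_mapper": None},
)

table = pa.table({"name": ["a", "b"], "value": [1, 2]})
df = table.to_pandas(types_mapper=mapper)
# With recent dask/pandas, "name" should come back as a pyarrow-backed string
# dtype rather than object, matching dd.read_parquet defaults.
print(df.dtypes)
```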

from dask.dataframe.utils import make_meta
from deltalake import DataCatalog, DeltaTable
from fsspec.core import get_fs_token_paths
@@ -64,17 +66,20 @@ def _read_delta_partition(
filter_expression = filters_to_expression(filter) if filter else None
if pyarrow_to_pandas is None:
pyarrow_to_pandas = {}
return (
pa_ds.dataset(
source=filename,
schema=schema,
filesystem=fs,
format="parquet",
partitioning="hive",
)
.to_table(filter=filter_expression, columns=columns)
.to_pandas(**pyarrow_to_pandas)
pyarrow_to_pandas["types_mapper"] = _get_type_mapper(
pyarrow_to_pandas.get("types_mapper")
)
pyarrow_to_pandas["ignore_metadata"] = pyarrow_to_pandas.get(
"ignore_metadata", False
)
table = pa_ds.dataset(
source=filename,
schema=schema,
filesystem=fs,
format="parquet",
partitioning="hive",
).to_table(filter=filter_expression, columns=columns)
return table.to_pandas(**pyarrow_to_pandas)


@@ -94,7 +99,7 @@ def _read_from_filesystem(
table_uri=path, version=version, storage_options=delta_storage_options
)
if datetime is not None:
dt.load_with_datetime(datetime)
dt.load_as_version(datetime)

schema = dt.schema().to_pyarrow()

@@ -104,6 +109,9 @@
raise RuntimeError("No Parquet files are available")

mapper_kwargs = kwargs.get("pyarrow_to_pandas", {})
mapper_kwargs["types_mapper"] = _get_type_mapper(
mapper_kwargs.get("types_mapper", None)
)
meta = make_meta(schema.empty_table().to_pandas(**mapper_kwargs))
if columns:
meta = meta[columns]
@@ -117,6 +125,22 @@
)


def _get_type_mapper(
user_types_mapper: dict[str, Any] | None
) -> Callable[[Any], Any] | None:
"""
Set the type mapper for the schema
"""
convert_string = dask.config.get("dataframe.convert-string", True)
if convert_string is None:
convert_string = True
return ArrowDatasetEngine._determine_type_mapper(
dtype_backend=None,
convert_string=convert_string,
arrow_to_pandas={"types_mapper": user_types_mapper},
)


def _read_from_catalog(
database_name: str, table_name: str, **kwargs
) -> dd.core.DataFrame:
@@ -128,6 +152,7 @@

session = Session()
credentials = session.get_credentials()
assert credentials is not None
current_credentials = credentials.get_frozen_credentials()
os.environ["AWS_ACCESS_KEY_ID"] = current_credentials.access_key
os.environ["AWS_SECRET_ACCESS_KEY"] = current_credentials.secret_key
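Beyond the dtype mapping, the reader now calls `DeltaTable.load_as_version` for time travel; in deltalake >= 0.15 it accepts either a version number or a datetime string. A rough usage sketch, assuming the README-style `read_deltalake` signature (the table path and timestamps are illustrative):

```python
import dask_deltatable as ddt

# Read the current state of the table.
ddf_latest = ddt.read_deltalake("path/to/delta_table")

# Time travel: pin a specific table version...
ddf_v3 = ddt.read_deltalake("path/to/delta_table", version=3)

# ...or resolve whichever version was current at a given timestamp.
ddf_old = ddt.read_deltalake(
    "path/to/delta_table", datetime="2023-06-01T00:00:00-00:00"
)
```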
10 changes: 7 additions & 3 deletions dask_deltatable/write.py
@@ -23,10 +23,10 @@
DeltaProtocolError,
DeltaStorageHandler,
__enforce_append_only,
_write_new_deltalake,
get_file_stats_from_metadata,
get_partitions_from_path,
try_get_table_and_table_uri,
write_deltalake_pyarrow,
)
from toolz.itertoolz import pluck

@@ -54,6 +54,7 @@ def to_deltalake(
storage_options: dict[str, str] | None = None,
partition_filters: list[tuple[str, str, Any]] | None = None,
compute: bool = True,
custom_metadata: dict[str, str] | None = None,
):
"""Write a given dask.DataFrame to a delta table. The returned value is a Dask Scalar,
and the writing operation is only triggered when calling ``.compute()``
@@ -216,9 +217,10 @@
configuration,
storage_options,
partition_filters,
custom_metadata,
)
}
graph = HighLevelGraph.from_collections(final_name, dsk, dependencies=(written,))
graph = HighLevelGraph.from_collections(final_name, dsk, dependencies=(written,)) # type: ignore
result = Scalar(graph, final_name, "")
Contributor comment on lines +223 to 224: I'll migrate to dask-expr in a follow-up PR once this is green.

if compute:
result = result.compute()
@@ -237,6 +239,7 @@ def _commit(
configuration,
storage_options,
partition_filters,
custom_metadata,
):
schemas = list(flatten(pluck(0, schemas_add_actions_nested)))
add_actions = list(flatten(pluck(1, schemas_add_actions_nested)))
@@ -250,7 +253,7 @@
schema = validate_compatible(schemas)
assert schema
if table is None:
_write_new_deltalake(
write_deltalake_pyarrow(
table_uri,
schema,
add_actions,
@@ -260,6 +263,7 @@
description,
configuration,
storage_options,
custom_metadata,
)
else:
table._table.create_write_transaction(
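On the write side, the new `custom_metadata` argument is threaded through `_commit` into `write_deltalake_pyarrow` and recorded with the Delta commit. A rough usage sketch (the path and metadata values are illustrative):

```python
import dask.dataframe as dd
import pandas as pd
from dask_deltatable import to_deltalake

ddf = dd.from_pandas(pd.DataFrame({"x": range(10)}), npartitions=2)

# Extra string key/value pairs are attached to the commit info alongside the
# usual write metadata.
to_deltalake(
    "/tmp/example_delta_table",
    ddf,
    custom_metadata={"pipeline": "nightly-load", "owner": "data-eng"},
)
```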
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
dask[dataframe]
deltalake
deltalake>=0.15
fsspec
pyarrow
4 changes: 4 additions & 0 deletions tests/test_acceptance.py
@@ -47,6 +47,10 @@ def test_reader_all_primitive_types():
expected_ddf = dd.read_parquet(
f"{DATA_DIR}/all_primitive_types/expected/latest/table_content/*parquet"
)
# Dask and delta go through different parquet parsers which read the
# timestamp differently. This is likely a bug in arrow but the delta result
# is "more correct".
expected_ddf["timestamp"] = expected_ddf["timestamp"].astype("datetime64[us]")
assert_eq(actual_ddf, expected_ddf)
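The cast to `datetime64[us]` above relies on pandas >= 2.0 supporting non-nanosecond datetime resolutions; a minimal standalone illustration (not part of the test):

```python
import pandas as pd

s = pd.Series(pd.to_datetime(["2021-01-01 00:00:00.123456"]))
print(s.dtype)                           # datetime64[ns]  (pandas default)
print(s.astype("datetime64[us]").dtype)  # datetime64[us]  (matches the delta reader here)
```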


11 changes: 4 additions & 7 deletions tests/test_write.py
@@ -54,14 +54,11 @@ def test_roundtrip(tmpdir, with_index, freq, partition_freq):
assert len(os.listdir(tmpdir)) > 0

ddf_read = read_deltalake(tmpdir)
# FIXME: The index is not recovered
if with_index:
ddf = ddf.reset_index()
ddf_dask = dd.read_parquet(tmpdir)

assert ddf.npartitions == ddf_read.npartitions
# By default, arrow reads with ns resolution
ddf.timestamp = ddf.timestamp.astype("datetime64[ns]")
assert_eq(ddf, ddf_read)
assert_eq(ddf_read, ddf_dask)


@pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"])
@@ -74,5 +71,5 @@ def test_datetime(tmpdir, unit):
ddf = dd.from_pandas(df, npartitions=2)
to_deltalake(tmpdir, ddf)
ddf_read = read_deltalake(tmpdir)
# arrow reads back with ns
assert ddf_read.ts.dtype == "datetime64[ns]"
ddf_dask = dd.read_parquet(tmpdir)
assert_eq(ddf_read, ddf_dask, check_index=False)