Skip to content

Commit

Permalink
Datasets schema should match the columns selection for Parquet (#18361)
Browse files Browse the repository at this point in the history
  • Loading branch information
jjyao authored Sep 7, 2021
1 parent f24ccf4 commit 64040a9
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def filter(self,
better performance (you can implement filter by dropping records).
Examples:
- >>> ds.flat_map(lambda x: x % 2 == 0)
+ >>> ds.filter(lambda x: x % 2 == 0)
Time complexity: O(dataset size / parallelism)
Expand Down
5 changes: 4 additions & 1 deletion python/ray/data/datasource/parquet_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def prepare_read(
# FileBasedDatasource's write side (do_write), however.
_check_pyarrow_version()
from ray import cloudpickle
+ import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np

Expand All @@ -54,6 +55,9 @@ def prepare_read(
use_legacy_dataset=False)
if schema is None:
schema = pq_ds.schema
+ if columns:
+     schema = pa.schema([schema.field(column) for column in columns],
+                        schema.metadata)
pieces = pq_ds.pieces

def read_pieces(serialized_pieces: List[str]):
Expand All @@ -69,7 +73,6 @@ def read_pieces(serialized_pieces: List[str]):
# Ensure that we're reading at least one dataset fragment.
assert len(pieces) > 0

- import pyarrow as pa
from pyarrow.dataset import _get_partition_keys

logger.debug(f"Reading {len(pieces)} parquet pieces")
Expand Down
1 change: 1 addition & 0 deletions python/ray/data/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,7 @@ def test_parquet_read(ray_start_regular_shared, fs, data_path):
ds = ray.data.read_parquet(data_path, columns=["one"], filesystem=fs)
values = [s["one"] for s in ds.take()]
assert sorted(values) == [1, 2, 3, 4, 5, 6]
+ assert ds.schema().names == ["one"]


@pytest.mark.parametrize(
Expand Down

0 comments on commit 64040a9

Please sign in to comment.