Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix row index disappearing after projection pushdown in NDJSON #17631

Merged
merged 3 commits
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions crates/polars-io/src/ndjson/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,15 +316,15 @@ impl<'a> CoreJsonReader<'a> {
)?;

let prepredicate_height = local_df.height() as IdxSize;
if let Some(projection) = &self.projection {
local_df = local_df.select(projection.as_ref())?;
}

if let Some(row_index) = row_index {
local_df = local_df
.with_row_index(row_index.name.as_ref(), Some(row_index.offset))?;
}

if let Some(projection) = &self.projection {
local_df = local_df.select(projection.as_ref())?;
}

if let Some(predicate) = &self.predicate {
let s = predicate.evaluate_io(&local_df)?;
let mask = s.bool()?;
Expand Down
73 changes: 38 additions & 35 deletions py-polars/tests/unit/io/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@ def _enable_force_async(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("POLARS_FORCE_ASYNC", "1")


def _assert_force_async(capfd: Any, data_file_extension: str) -> None:
    """Calls `capfd.readouterr`, consuming the captured output so far.

    Asserts that the forced-async reading path was taken exactly once.
    NDJSON scans do not emit the "ASYNC READING FORCED" marker (no async
    reader for that format in this PR), so the check is skipped for
    `.ndjson` files.
    """
    if data_file_extension == ".ndjson":
        return

    captured = capfd.readouterr().err
    assert captured.count("ASYNC READING FORCED") == 1

Expand All @@ -42,47 +45,47 @@ def _scan(
row_index_name = None if row_index is None else row_index.name
row_index_offset = 0 if row_index is None else row_index.offset

if suffix == ".ipc":
result = pl.scan_ipc(
if (
scan_func := {
".ipc" : pl.scan_ipc,
".parquet" : pl.scan_parquet,
".csv" : pl.scan_csv,
".ndjson" : pl.scan_ndjson,
}.get(suffix)
) is not None: # fmt: skip
result = scan_func(
file_path,
row_index_name=row_index_name,
row_index_offset=row_index_offset,
)
elif suffix == ".parquet":
result = pl.scan_parquet(
file_path,
row_index_name=row_index_name,
row_index_offset=row_index_offset,
)
elif suffix == ".csv":
result = pl.scan_csv(
file_path,
schema=schema,
row_index_name=row_index_name,
row_index_offset=row_index_offset,
)
) # type: ignore[operator]

else:
msg = f"Unknown suffix {suffix}"
raise NotImplementedError(msg)

return result
return result # type: ignore[no-any-return]


def _write(df: pl.DataFrame, file_path: Path) -> None:
    """Write `df` to `file_path`, dispatching on the file extension.

    Supported suffixes: `.ipc`, `.parquet`, `.csv`, `.ndjson`.

    Raises
    ------
    NotImplementedError
        If the suffix is not one of the supported formats.
    """
    suffix = file_path.suffix

    if (
        write_func := {
            ".ipc": pl.DataFrame.write_ipc,
            ".parquet": pl.DataFrame.write_parquet,
            ".csv": pl.DataFrame.write_csv,
            ".ndjson": pl.DataFrame.write_ndjson,
        }.get(suffix)
    ) is not None:
        # Unbound-method dispatch: write_func(df, ...) == df.write_<fmt>(...)
        return write_func(df, file_path)  # type: ignore[operator, no-any-return]

    msg = f"Unknown suffix {suffix}"
    raise NotImplementedError(msg)


@pytest.fixture(
    scope="session",
    params=["csv", "ipc", "parquet", "ndjson"],
)
def data_file_extension(request: pytest.FixtureRequest) -> str:
    """Session-scoped parametrized fixture yielding each supported scan
    file suffix, dot included (e.g. ".csv")."""
    return f".{request.param}"
Expand Down Expand Up @@ -197,7 +200,7 @@ def test_scan(
df = _scan(data_file.path, data_file.df.schema).collect()

if force_async:
_assert_force_async(capfd)
_assert_force_async(capfd, data_file.path.suffix)

assert_frame_equal(df, data_file.df)

Expand All @@ -212,7 +215,7 @@ def test_scan_with_limit(
df = _scan(data_file.path, data_file.df.schema).limit(4483).collect()

if force_async:
_assert_force_async(capfd)
_assert_force_async(capfd, data_file.path.suffix)

assert_frame_equal(
df,
Expand All @@ -238,7 +241,7 @@ def test_scan_with_filter(
)

if force_async:
_assert_force_async(capfd)
_assert_force_async(capfd, data_file.path.suffix)

assert_frame_equal(
df,
Expand All @@ -265,7 +268,7 @@ def test_scan_with_filter_and_limit(
)

if force_async:
_assert_force_async(capfd)
_assert_force_async(capfd, data_file.path.suffix)

assert_frame_equal(
df,
Expand All @@ -292,7 +295,7 @@ def test_scan_with_limit_and_filter(
)

if force_async:
_assert_force_async(capfd)
_assert_force_async(capfd, data_file.path.suffix)

assert_frame_equal(
df,
Expand All @@ -318,7 +321,7 @@ def test_scan_with_row_index_and_limit(
)

if force_async:
_assert_force_async(capfd)
_assert_force_async(capfd, data_file.path.suffix)

assert_frame_equal(
df,
Expand Down Expand Up @@ -346,7 +349,7 @@ def test_scan_with_row_index_and_filter(
)

if force_async:
_assert_force_async(capfd)
_assert_force_async(capfd, data_file.path.suffix)

assert_frame_equal(
df,
Expand Down Expand Up @@ -375,7 +378,7 @@ def test_scan_with_row_index_limit_and_filter(
)

if force_async:
_assert_force_async(capfd)
_assert_force_async(capfd, data_file.path.suffix)

assert_frame_equal(
df,
Expand Down Expand Up @@ -407,7 +410,7 @@ def test_scan_with_row_index_projected_out(
)

if force_async:
_assert_force_async(capfd)
_assert_force_async(capfd, data_file.path.suffix)

assert_frame_equal(df, data_file.df.select(subset))

Expand All @@ -430,7 +433,7 @@ def test_scan_with_row_index_filter_and_limit(
)

if force_async:
_assert_force_async(capfd)
_assert_force_async(capfd, data_file.path.suffix)

assert_frame_equal(
df,
Expand Down