Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Include file path option for NDJSON #17681

Merged
merged 2 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,8 @@ where
}

/// This takes ownership of the DataFrame so that drop is called earlier.
/// # Panics
/// Panics if `dfs` is empty.
pub fn accumulate_dataframes_vertical<I>(dfs: I) -> PolarsResult<DataFrame>
where
I: IntoIterator<Item = DataFrame>,
Expand Down
9 changes: 8 additions & 1 deletion crates/polars-lazy/src/scan/ndjson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub struct LazyJsonLineReader {
pub(crate) infer_schema_length: Option<NonZeroUsize>,
pub(crate) n_rows: Option<usize>,
pub(crate) ignore_errors: bool,
pub(crate) include_file_paths: Option<Arc<str>>,
}

impl LazyJsonLineReader {
Expand All @@ -39,6 +40,7 @@ impl LazyJsonLineReader {
infer_schema_length: NonZeroUsize::new(100),
ignore_errors: false,
n_rows: None,
include_file_paths: None,
}
}
/// Add a row index column.
Expand Down Expand Up @@ -89,6 +91,11 @@ impl LazyJsonLineReader {
self.batch_size = batch_size;
self
}

/// Set the name of the column that holds the source file path of each row.
///
/// The value is stored on the reader and forwarded to the scan options when
/// the reader is finished. With `Some(name)`, a string column called `name`
/// containing the path of the originating file is added to the scanned
/// frames; with `None`, no such column is added.
pub fn with_include_file_paths(mut self, include_file_paths: Option<Arc<str>>) -> Self {
self.include_file_paths = include_file_paths;
self
}
}

impl LazyFileListReader for LazyJsonLineReader {
Expand All @@ -108,7 +115,7 @@ impl LazyFileListReader for LazyJsonLineReader {
file_counter: 0,
hive_options: Default::default(),
glob: true,
include_file_paths: None,
include_file_paths: self.include_file_paths,
};

let options = NDJsonReadOptions {
Expand Down
23 changes: 22 additions & 1 deletion crates/polars-mem-engine/src/executors/scan/ndjson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ impl JsonExec {

let mut n_rows = self.file_scan_options.n_rows;

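// Zero rows requested: return an empty frame (plus the optional file-path and
// row-index columns) up front. This avoids a panic on `scan_ndjson().head(0)`.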
if n_rows == Some(0) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drive-by - fix panic on scan_ndjson().head(0)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test for this one?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a test below

let mut df = DataFrame::empty_with_schema(schema);
if let Some(col) = &self.file_scan_options.include_file_paths {
unsafe { df.with_column_unchecked(StringChunked::full_null(col, 0).into_series()) };
}
if let Some(row_index) = &self.file_scan_options.row_index {
df.with_row_index_mut(row_index.name.as_ref(), Some(row_index.offset));
}
return Ok(df);
}

let dfs = self
.paths
.iter()
Expand Down Expand Up @@ -67,7 +79,7 @@ impl JsonExec {
.with_ignore_errors(self.options.ignore_errors)
.finish();

let df = match df {
let mut df = match df {
Ok(df) => df,
Err(e) => return Some(Err(e)),
};
Expand All @@ -76,6 +88,15 @@ impl JsonExec {
*n_rows -= df.height();
}

if let Some(col) = &self.file_scan_options.include_file_paths {
let path = p.to_str().unwrap();
unsafe {
df.with_column_unchecked(
StringChunked::full(col, path, df.height()).into_series(),
)
};
}

Some(Ok(df))
})
.collect::<PolarsResult<Vec<_>>>()?;
Expand Down
13 changes: 13 additions & 0 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,19 @@ pub fn to_alp_impl(
None
};

file_options.include_file_paths =
file_options.include_file_paths.filter(|_| match scan_type {
#[cfg(feature = "parquet")]
FileScan::Parquet { .. } => true,
#[cfg(feature = "ipc")]
FileScan::Ipc { .. } => true,
#[cfg(feature = "csv")]
FileScan::Csv { .. } => true,
#[cfg(feature = "json")]
FileScan::NDJson { .. } => true,
FileScan::Anonymous { .. } => false,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make sure the lazy schema is correct

});

// Only if we have a writing file handle we must resolve hive partitions
// update schema's etc.
if let Some(lock) = &mut _file_info_write {
Expand Down
10 changes: 8 additions & 2 deletions py-polars/polars/io/ndjson.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def scan_ndjson(
row_index_name: str | None = None,
row_index_offset: int = 0,
ignore_errors: bool = False,
include_file_paths: str | None = None,
) -> LazyFrame:
"""
Lazily read from a newline delimited JSON file or multiple files via glob patterns.
Expand Down Expand Up @@ -138,12 +139,16 @@ def scan_ndjson(
Offset to start the row index column (only use if the name is set)
ignore_errors
Return `Null` if parsing fails because of schema mismatches.
include_file_paths
Include the path of the source file(s) as a column with this name.
"""
if isinstance(source, (str, Path)):
source = normalize_filepath(source)
source = normalize_filepath(source, check_not_directory=False)
ritchie46 marked this conversation as resolved.
Show resolved Hide resolved
sources = []
else:
sources = [normalize_filepath(source) for source in source]
sources = [
normalize_filepath(source, check_not_directory=False) for source in source
]
source = None # type: ignore[assignment]
if infer_schema_length == 0:
msg = "'infer_schema_length' should be positive"
Expand All @@ -160,5 +165,6 @@ def scan_ndjson(
rechunk,
parse_row_index_args(row_index_name, row_index_offset),
ignore_errors,
include_file_paths=include_file_paths,
)
return wrap_ldf(pylf)
4 changes: 3 additions & 1 deletion py-polars/src/lazyframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl PyLazyFrame {
#[staticmethod]
#[cfg(feature = "json")]
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (path, paths, infer_schema_length, schema, batch_size, n_rows, low_memory, rechunk, row_index, ignore_errors))]
#[pyo3(signature = (path, paths, infer_schema_length, schema, batch_size, n_rows, low_memory, rechunk, row_index, ignore_errors, include_file_paths))]
fn new_from_ndjson(
path: Option<PathBuf>,
paths: Vec<PathBuf>,
Expand All @@ -56,6 +56,7 @@ impl PyLazyFrame {
rechunk: bool,
row_index: Option<(String, IdxSize)>,
ignore_errors: bool,
include_file_paths: Option<String>,
) -> PyResult<Self> {
let row_index = row_index.map(|(name, offset)| RowIndex {
name: Arc::from(name.as_str()),
Expand All @@ -77,6 +78,7 @@ impl PyLazyFrame {
.with_schema(schema.map(|schema| Arc::new(schema.0)))
.with_row_index(row_index)
.with_ignore_errors(ignore_errors)
.with_include_file_paths(include_file_paths.map(Arc::from))
.finish()
.map_err(PyPolarsErr::from)?;

Expand Down
30 changes: 29 additions & 1 deletion py-polars/tests/unit/io/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,33 @@ def test_scan_with_row_index_filter_and_limit(
)


# Regression test: applying a zero-row limit (`head(0)`) to a lazy scan must
# collect without panicking. Runs once per (scan, write) pair listed below and
# once each with streaming enabled and disabled.
@pytest.mark.write_disk()
@pytest.mark.parametrize(
("scan_func", "write_func"),
[
(pl.scan_parquet, pl.DataFrame.write_parquet),
(pl.scan_ipc, pl.DataFrame.write_ipc),
(pl.scan_csv, pl.DataFrame.write_csv),
(pl.scan_ndjson, pl.DataFrame.write_ndjson),
],
)
@pytest.mark.parametrize(
"streaming",
[True, False],
)
def test_scan_limit_0_does_not_panic(
tmp_path: Path,
scan_func: Callable[[Any], pl.LazyFrame],
write_func: Callable[[pl.DataFrame, Path], None],
streaming: bool,
) -> None:
tmp_path.mkdir(exist_ok=True)
# The same "data.bin" file name is used regardless of the format written.
path = tmp_path / "data.bin"
df = pl.DataFrame({"x": 1})
write_func(df, path)
# Expected result is `df.clear()` — presumably the zero-row frame with the
# same schema as `df`; confirm against the polars documentation.
assert_frame_equal(scan_func(path).head(0).collect(streaming=streaming), df.clear())


@pytest.mark.write_disk()
@pytest.mark.parametrize(
("scan_func", "write_func"),
Expand Down Expand Up @@ -598,6 +625,7 @@ def test_scan_nonexistent_path(format: str) -> None:
(pl.scan_parquet, pl.DataFrame.write_parquet),
(pl.scan_ipc, pl.DataFrame.write_ipc),
(pl.scan_csv, pl.DataFrame.write_csv),
(pl.scan_ndjson, pl.DataFrame.write_ndjson),
],
)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -639,7 +667,7 @@ def test_scan_include_file_name(
assert_frame_equal(lf.collect(streaming=streaming), df)

# TODO: Support this with CSV and NDJSON
if scan_func is not pl.scan_csv:
if scan_func not in [pl.scan_csv, pl.scan_ndjson]:
# Test projecting only the path column
assert_frame_equal(
lf.select("path").collect(streaming=streaming),
Expand Down