Skip to content

Commit

Permalink
fix: Raise on non-positive json schema inference (#16770)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jun 7, 2024
1 parent aed60b9 commit a0a577a
Show file tree
Hide file tree
Showing 10 changed files with 42 additions and 26 deletions.
12 changes: 7 additions & 5 deletions crates/polars-io/src/json/infer.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
use std::num::NonZeroUsize;

use polars_core::prelude::DataType;
use polars_core::utils::try_get_supertype;
use polars_error::PolarsResult;
use polars_error::{polars_bail, PolarsResult};
use simd_json::BorrowedValue;

pub(crate) fn json_values_to_supertype(
values: &[BorrowedValue],
infer_schema_len: usize,
infer_schema_len: NonZeroUsize,
) -> PolarsResult<DataType> {
// struct types may have missing fields so find supertype
values
.iter()
.take(infer_schema_len)
.take(infer_schema_len.into())
.map(|value| polars_json::json::infer(value).map(|dt| DataType::from(&dt)))
.reduce(|l, r| {
let l = l?;
let r = r?;
try_get_supertype(&l, &r)
})
.unwrap()
.unwrap_or_else(|| polars_bail!(ComputeError: "could not infer data-type"))
}

pub(crate) fn data_types_to_supertype<I: Iterator<Item = DataType>>(
Expand All @@ -30,5 +32,5 @@ pub(crate) fn data_types_to_supertype<I: Iterator<Item = DataType>>(
let r = r?;
try_get_supertype(&l, &r)
})
.unwrap()
.unwrap_or_else(|| polars_bail!(ComputeError: "could not infer data-type"))
}
11 changes: 6 additions & 5 deletions crates/polars-io/src/json/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
//! let file = Cursor::new(basic_json);
//! let df = JsonReader::new(file)
//! .with_json_format(JsonFormat::JsonLines)
//! .infer_schema_len(Some(3))
//! .infer_schema_len(NonZeroUsize::new(3))
//! .with_batch_size(NonZeroUsize::new(3).unwrap())
//! .finish()
//! .unwrap();
Expand Down Expand Up @@ -206,7 +206,7 @@ where
reader: R,
rechunk: bool,
ignore_errors: bool,
infer_schema_len: Option<usize>,
infer_schema_len: Option<NonZeroUsize>,
batch_size: NonZeroUsize,
projection: Option<Vec<String>>,
schema: Option<SchemaRef>,
Expand All @@ -223,7 +223,7 @@ where
reader,
rechunk: true,
ignore_errors: false,
infer_schema_len: Some(100),
infer_schema_len: Some(NonZeroUsize::new(100).unwrap()),
batch_size: NonZeroUsize::new(8192).unwrap(),
projection: None,
schema: None,
Expand Down Expand Up @@ -265,7 +265,8 @@ where
let inner_dtype = if let BorrowedValue::Array(values) = &json_value {
infer::json_values_to_supertype(
values,
self.infer_schema_len.unwrap_or(usize::MAX),
self.infer_schema_len
.unwrap_or(NonZeroUsize::new(usize::MAX).unwrap()),
)?
.to_arrow(true)
} else {
Expand Down Expand Up @@ -360,7 +361,7 @@ where
///
/// `max_records` can never be `Some(0)`: the `NonZeroUsize` type rules that out, because a schema cannot be inferred
/// from 0 records when deserializing from JSON (unlike CSVs, there is no header row to inspect for column names).
pub fn infer_schema_len(mut self, max_records: Option<usize>) -> Self {
pub fn infer_schema_len(mut self, max_records: Option<NonZeroUsize>) -> Self {
self.infer_schema_len = max_records;
self
}
Expand Down
8 changes: 4 additions & 4 deletions crates/polars-io/src/ndjson/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ where
rechunk: bool,
n_rows: Option<usize>,
n_threads: Option<usize>,
infer_schema_len: Option<usize>,
infer_schema_len: Option<NonZeroUsize>,
chunk_size: NonZeroUsize,
schema: Option<SchemaRef>,
schema_overwrite: Option<&'a Schema>,
Expand Down Expand Up @@ -58,7 +58,7 @@ where
self
}

pub fn infer_schema_len(mut self, infer_schema_len: Option<usize>) -> Self {
pub fn infer_schema_len(mut self, infer_schema_len: Option<NonZeroUsize>) -> Self {
self.infer_schema_len = infer_schema_len;
self
}
Expand Down Expand Up @@ -112,7 +112,7 @@ where
rechunk: true,
n_rows: None,
n_threads: None,
infer_schema_len: Some(128),
infer_schema_len: Some(NonZeroUsize::new(100).unwrap()),
schema: None,
schema_overwrite: None,
path: None,
Expand Down Expand Up @@ -166,7 +166,7 @@ impl<'a> CoreJsonReader<'a> {
sample_size: usize,
chunk_size: NonZeroUsize,
low_memory: bool,
infer_schema_len: Option<usize>,
infer_schema_len: Option<NonZeroUsize>,
ignore_errors: bool,
) -> PolarsResult<CoreJsonReader<'a>> {
let reader_bytes = reader_bytes;
Expand Down
4 changes: 3 additions & 1 deletion crates/polars-io/src/ndjson/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::num::NonZeroUsize;

use arrow::array::StructArray;
use polars_core::prelude::*;

Expand All @@ -6,7 +8,7 @@ pub mod core;

pub fn infer_schema<R: std::io::BufRead>(
reader: &mut R,
infer_schema_len: Option<usize>,
infer_schema_len: Option<NonZeroUsize>,
) -> PolarsResult<Schema> {
let data_types = polars_json::ndjson::iter_unique_dtypes(reader, infer_schema_len)?;
let data_type =
Expand Down
5 changes: 3 additions & 2 deletions crates/polars-json/src/ndjson/file.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::io::BufRead;
use std::num::NonZeroUsize;

use arrow::datatypes::ArrowDataType;
use fallible_streaming_iterator::FallibleStreamingIterator;
Expand Down Expand Up @@ -100,7 +101,7 @@ fn parse_value<'a>(scratch: &'a mut Vec<u8>, val: &[u8]) -> PolarsResult<Borrowe
/// It performs both `O(N)` IO and CPU-bounded operations where `N` is the number of rows.
pub fn iter_unique_dtypes<R: std::io::BufRead>(
reader: &mut R,
number_of_rows: Option<usize>,
number_of_rows: Option<NonZeroUsize>,
) -> PolarsResult<impl Iterator<Item = ArrowDataType>> {
if reader.fill_buf().map(|b| b.is_empty())? {
return Err(PolarsError::ComputeError(
Expand All @@ -109,7 +110,7 @@ pub fn iter_unique_dtypes<R: std::io::BufRead>(
}

let rows = vec!["".to_string(); 1]; // 1 <=> read row by row
let mut reader = FileReader::new(reader, rows, number_of_rows);
let mut reader = FileReader::new(reader, rows, number_of_rows.map(|v| v.into()));

let mut data_types = PlIndexSet::default();
let mut buf = vec![];
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::num::NonZeroUsize;

use super::*;

impl AnonymousScan for LazyJsonLineReader {
Expand All @@ -17,6 +19,7 @@ impl AnonymousScan for LazyJsonLineReader {
}

fn schema(&self, infer_schema_length: Option<usize>) -> PolarsResult<SchemaRef> {
polars_ensure!(infer_schema_length != Some(0), InvalidOperation: "JSON requires positive 'infer_schema_length'");
// Short-circuit schema inference if the schema has been explicitly provided,
// or already inferred
if let Some(schema) = &(*self.schema.read().unwrap()) {
Expand All @@ -28,7 +31,7 @@ impl AnonymousScan for LazyJsonLineReader {

let schema = Arc::new(polars_io::ndjson::infer_schema(
&mut reader,
infer_schema_length,
infer_schema_length.and_then(NonZeroUsize::new),
)?);
let mut guard = self.schema.write().unwrap();
*guard = Some(schema.clone());
Expand Down
8 changes: 4 additions & 4 deletions crates/polars/tests/it/io/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fn read_json() {
"#;
let file = Cursor::new(basic_json);
let df = JsonReader::new(file)
.infer_schema_len(Some(3))
.infer_schema_len(NonZeroUsize::new(3))
.with_json_format(JsonFormat::JsonLines)
.with_batch_size(NonZeroUsize::new(3).unwrap())
.finish()
Expand Down Expand Up @@ -48,7 +48,7 @@ fn read_json_with_whitespace() {
{ "a":100000000000000, "b":0.6, "c":false, "d":"text"}"#;
let file = Cursor::new(basic_json);
let df = JsonReader::new(file)
.infer_schema_len(Some(3))
.infer_schema_len(NonZeroUsize::new(3))
.with_json_format(JsonFormat::JsonLines)
.with_batch_size(NonZeroUsize::new(3).unwrap())
.finish()
Expand All @@ -73,7 +73,7 @@ fn read_json_with_escapes() {
"#;
let file = Cursor::new(escaped_json);
let df = JsonLineReader::new(file)
.infer_schema_len(Some(6))
.infer_schema_len(NonZeroUsize::new(6))
.finish()
.unwrap();
assert_eq!("id", df.get_columns()[0].name());
Expand Down Expand Up @@ -102,7 +102,7 @@ fn read_unordered_json() {
"#;
let file = Cursor::new(unordered_json);
let df = JsonReader::new(file)
.infer_schema_len(Some(3))
.infer_schema_len(NonZeroUsize::new(3))
.with_json_format(JsonFormat::JsonLines)
.with_batch_size(NonZeroUsize::new(3).unwrap())
.finish()
Expand Down
3 changes: 3 additions & 0 deletions py-polars/polars/io/ndjson.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ def scan_ndjson(
else:
sources = [normalize_filepath(source) for source in source]
source = None # type: ignore[assignment]
if infer_schema_length == 0:
msg = "'infer_schema_length' should be positive"
raise ValueError(msg)

pylf = PyLazyFrame.new_from_ndjson(
source,
Expand Down
7 changes: 3 additions & 4 deletions py-polars/src/dataframe/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,15 @@ impl PyDataFrame {
schema: Option<Wrap<Schema>>,
schema_overrides: Option<Wrap<Schema>>,
) -> PyResult<Self> {
assert!(infer_schema_length != Some(0));
use crate::file::read_if_bytesio;
py_f = read_if_bytesio(py_f);
let mmap_bytes_r = get_mmap_bytes_reader(&py_f)?;

py.allow_threads(move || {
let mut builder = JsonReader::new(mmap_bytes_r)
.with_json_format(JsonFormat::Json)
.infer_schema_len(infer_schema_length);
.infer_schema_len(infer_schema_length.and_then(NonZeroUsize::new));

if let Some(schema) = schema {
builder = builder.with_schema(Arc::new(schema.0));
Expand All @@ -232,9 +233,7 @@ impl PyDataFrame {
builder = builder.with_schema_overwrite(&schema.0);
}

let out = builder
.finish()
.map_err(|e| PyPolarsErr::Other(format!("{e}")))?;
let out = builder.finish().map_err(PyPolarsErr::from)?;
Ok(out.into())
})
}
Expand Down
5 changes: 5 additions & 0 deletions py-polars/tests/unit/io/test_lazy_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def test_scan_ndjson_with_schema(foods_ndjson_path: Path) -> None:
assert df["sugars_g"].dtype == pl.Float64


def test_scan_ndjson_infer_0(foods_ndjson_path: Path) -> None:
with pytest.raises(ValueError):
pl.scan_ndjson(foods_ndjson_path, infer_schema_length=0)


def test_scan_ndjson_batch_size_zero() -> None:
with pytest.raises(ValueError, match="invalid zero value"):
pl.scan_ndjson("test.ndjson", batch_size=0)
Expand Down

0 comments on commit a0a577a

Please sign in to comment.