diff --git a/crates/polars-io/src/csv/read/schema_inference.rs b/crates/polars-io/src/csv/read/schema_inference.rs
index 57de091247f6..189c54501c12 100644
--- a/crates/polars-io/src/csv/read/schema_inference.rs
+++ b/crates/polars-io/src/csv/read/schema_inference.rs
@@ -99,8 +99,24 @@ impl CsvReadOptions {
     }
 }
 
+pub fn finish_infer_field_schema(possibilities: &PlHashSet<DataType>) -> DataType {
+    // determine data type based on possible types
+    // if there are incompatible types, use DataType::String
+    match possibilities.len() {
+        1 => possibilities.iter().next().unwrap().clone(),
+        2 if possibilities.contains(&DataType::Int64)
+            && possibilities.contains(&DataType::Float64) =>
+        {
+            // we have an integer and double, fall down to double
+            DataType::Float64
+        },
+        // default to String for conflicting datatypes (e.g bool and int)
+        _ => DataType::String,
+    }
+}
+
 /// Infer the data type of a record
-fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType {
+pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType {
     // when quoting is enabled in the reader, these quotes aren't escaped, we default to
     // String for them
     if string.starts_with('"') {
@@ -428,7 +444,6 @@ fn infer_file_schema_inner(
 
     // build schema from inference results
     for i in 0..header_length {
-        let possibilities = &column_types[i];
         let field_name = &headers[i];
 
         if let Some(schema_overwrite) = schema_overwrite {
@@ -447,27 +462,9 @@
             }
         }
 
-        // determine data type based on possible types
-        // if there are incompatible types, use DataType::String
-        match possibilities.len() {
-            1 => {
-                for dtype in possibilities.iter() {
-                    fields.push(Field::new(field_name, dtype.clone()));
-                }
-            },
-            2 => {
-                if possibilities.contains(&DataType::Int64)
-                    && possibilities.contains(&DataType::Float64)
-                {
-                    // we have an integer and double, fall down to double
-                    fields.push(Field::new(field_name, DataType::Float64));
-                } else {
-                    // default to String for conflicting datatypes (e.g bool and int)
-                    fields.push(Field::new(field_name, DataType::String));
-                }
-            },
-            _ => fields.push(Field::new(field_name, DataType::String)),
-        }
+        let possibilities = &column_types[i];
+        let dtype = finish_infer_field_schema(possibilities);
+        fields.push(Field::new(field_name, dtype));
     }
     // if there is a single line after the header without an eol
     // we copy the bytes add an eol and rerun this function
diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml
index 42a499bb8550..2e26f1a392db 100644
--- a/crates/polars-plan/Cargo.toml
+++ b/crates/polars-plan/Cargo.toml
@@ -15,7 +15,7 @@ doctest = false
 libloading = { version = "0.8.0", optional = true }
 polars-core = { workspace = true, features = ["lazy", "zip_with", "random"] }
 polars-ffi = { workspace = true, optional = true }
-polars-io = { workspace = true, features = ["lazy"] }
+polars-io = { workspace = true, features = ["lazy", "csv"] }
 polars-json = { workspace = true, optional = true }
 polars-ops = { workspace = true, features = [] }
 polars-parquet = { workspace = true, optional = true }
diff --git a/crates/polars-plan/src/plans/hive.rs b/crates/polars-plan/src/plans/hive.rs
index bdb9fe3deecf..505e404fea36 100644
--- a/crates/polars-plan/src/plans/hive.rs
+++ b/crates/polars-plan/src/plans/hive.rs
@@ -3,6 +3,7 @@ use std::path::{Path, PathBuf};
 use percent_encoding::percent_decode_str;
 use polars_core::prelude::*;
 use polars_io::predicates::{BatchStats, ColumnStats};
+use polars_io::prelude::schema_inference::{finish_infer_field_schema, infer_field_schema};
 use polars_io::utils::{BOOLEAN_RE, FLOAT_RE, INTEGER_RE};
 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};
@@ -16,70 +17,6 @@ pub struct HivePartitions {
 }
 
 impl HivePartitions {
-    /// Constructs a new [`HivePartitions`] from a schema reference.
-    pub fn from_schema_ref(schema: SchemaRef) -> Self {
-        let column_stats = schema.iter_fields().map(ColumnStats::from_field).collect();
-        let stats = BatchStats::new(schema, column_stats, None);
-        Self { stats }
-    }
-
-    /// Constructs a new [`HivePartitions`] from a path.
-    ///
-    /// Returns `None` if the path does not contain any Hive partitions.
-    /// Returns `Err` if the Hive partitions cannot be parsed correctly or do not match the given
-    /// [`Schema`].
-    pub fn try_from_path(path: &Path, schema: Option<SchemaRef>) -> PolarsResult<Option<Self>> {
-        let sep = separator(path);
-
-        let path_string = path.display().to_string();
-        let path_parts = path_string.split(sep);
-
-        // Last part is the file, which should be skipped.
-        let file_index = path_parts.clone().count() - 1;
-
-        let partitions = path_parts
-            .enumerate()
-            .filter_map(|(index, part)| {
-                if index == file_index {
-                    return None;
-                }
-                parse_hive_string(part)
-            })
-            .map(|(name, value)| hive_info_to_series(name, value, schema.clone()))
-            .collect::<PolarsResult<Vec<Series>>>()?;
-
-        if partitions.is_empty() {
-            return Ok(None);
-        }
-
-        let schema = match schema {
-            Some(schema) => Arc::new(
-                partitions
-                    .iter()
-                    .map(|s| {
-                        let mut field = s.field().into_owned();
-                        if let Some(dtype) = schema.get(field.name()) {
-                            field.dtype = dtype.clone();
-                        };
-                        field
-                    })
-                    .collect::<Schema>(),
-            ),
-            None => Arc::new(partitions.as_slice().into()),
-        };
-
-        let stats = BatchStats::new(
-            schema,
-            partitions
-                .into_iter()
-                .map(ColumnStats::from_column_literal)
-                .collect(),
-            None,
-        );
-
-        Ok(Some(HivePartitions { stats }))
-    }
-
     pub fn get_statistics(&self) -> &BatchStats {
         &self.stats
     }
@@ -108,33 +45,96 @@ pub fn hive_partitions_from_paths(
         return Ok(None);
     };
 
-    let Some(hive_parts) = HivePartitions::try_from_path(
-        &PathBuf::from(&path.to_str().unwrap()[hive_start_idx..]),
-        schema.clone(),
-    )?
-    else {
-        return Ok(None);
+    let sep = separator(path);
+    let path_string = path.to_str().unwrap();
+
+    macro_rules! get_hive_parts_iter {
+        ($e:expr) => {{
+            let path_parts = $e[hive_start_idx..].split(sep);
+            let file_index = path_parts.clone().count() - 1;
+
+            path_parts.enumerate().filter_map(move |(index, part)| {
+                if index == file_index {
+                    return None;
+                }
+                parse_hive_string(part)
+            })
+        }};
+    }
+
+    let hive_schema = if let Some(v) = schema {
+        v
+    } else {
+        let mut hive_schema = Schema::with_capacity(16);
+        let mut schema_inference_map: PlHashMap<&str, PlHashSet<DataType>> =
+            PlHashMap::with_capacity(16);
+
+        for (name, _) in get_hive_parts_iter!(path_string) {
+            hive_schema.insert_at_index(hive_schema.len(), name.into(), DataType::String)?;
+            schema_inference_map.insert(name, PlHashSet::with_capacity(4));
+        }
+
+        if hive_schema.is_empty() && schema_inference_map.is_empty() {
+            return Ok(None);
+        }
+
+        if !schema_inference_map.is_empty() {
+            for path in paths {
+                for (name, value) in get_hive_parts_iter!(path.to_str().unwrap()) {
+                    let Some(entry) = schema_inference_map.get_mut(name) else {
+                        continue;
+                    };
+
+                    entry.insert(infer_field_schema(value, false, false));
+                }
+            }
+
+            for (name, ref possibilities) in schema_inference_map.drain() {
+                let dtype = finish_infer_field_schema(possibilities);
+                *hive_schema.try_get_mut(name).unwrap() = dtype;
+            }
+        }
+
+        Arc::new(hive_schema)
     };
 
-    let mut results = Vec::with_capacity(paths.len());
-    results.push(Arc::new(hive_parts));
+    let mut hive_partitions = Vec::with_capacity(paths.len());
+
+    for path in paths {
+        let path = path.to_str().unwrap();
+
+        let column_stats = get_hive_parts_iter!(path)
+            .map(|(name, value)| {
+                let Some(dtype) = hive_schema.as_ref().get(name) else {
+                    polars_bail!(
+                        SchemaFieldNotFound:
+                        "path contains column not present in the given Hive schema: {:?}, path = {:?}",
+                        name,
+                        path
+                    )
+                };
+
+                Ok(ColumnStats::from_column_literal(value_to_series(
+                    name,
+                    value,
+                    Some(dtype),
+                )?))
+            })
+            .collect::<PolarsResult<Vec<ColumnStats>>>()?;
 
-    for path in &paths[1..] {
-        let Some(hive_parts) = HivePartitions::try_from_path(
-            &PathBuf::from(&path.to_str().unwrap()[hive_start_idx..]),
-            schema.clone(),
-        )?
-        else {
+        if column_stats.is_empty() {
             polars_bail!(
                 ComputeError:
                 "expected Hive partitioned path, got {}\n\n\
                 This error occurs if some paths are Hive partitioned and some paths are not.",
-                path.display()
+                path
             )
-        };
-        results.push(Arc::new(hive_parts));
+        }
+
+        let stats = BatchStats::new(hive_schema.clone(), column_stats, None);
+
+        hive_partitions.push(Arc::new(HivePartitions { stats }));
     }
 
-    Ok(Some(results))
+    Ok(Some(hive_partitions))
 }
 
 /// Determine the path separator for identifying Hive partitions.
@@ -174,24 +174,6 @@ fn parse_hive_string(part: &'_ str) -> Option<(&'_ str, &'_ str)> {
     Some((name, value))
 }
 
-/// Convert Hive partition string information to a single-value [`Series`].
-fn hive_info_to_series(name: &str, value: &str, schema: Option<SchemaRef>) -> PolarsResult<Series> {
-    let dtype = match schema {
-        Some(ref s) => {
-            let dtype = s.try_get(name).map_err(|_| {
-                polars_err!(
-                    SchemaFieldNotFound:
-                    "path contains column not present in the given Hive schema: {:?}", name
-                )
-            })?;
-            Some(dtype)
-        },
-        None => None,
-    };
-
-    value_to_series(name, value, dtype)
-}
-
 /// Parse a string value into a single-value [`Series`].
 fn value_to_series(name: &str, value: &str, dtype: Option<&DataType>) -> PolarsResult<Series> {
     let fn_err = || polars_err!(ComputeError: "unable to parse Hive partition value: {:?}", value);
diff --git a/py-polars/tests/unit/io/test_hive.py b/py-polars/tests/unit/io/test_hive.py
index 23a7902c01ff..e73b6c357757 100644
--- a/py-polars/tests/unit/io/test_hive.py
+++ b/py-polars/tests/unit/io/test_hive.py
@@ -9,7 +9,7 @@
 
 import polars as pl
 from polars.exceptions import DuplicateError, SchemaFieldNotFoundError
-from polars.testing import assert_frame_equal
+from polars.testing import assert_frame_equal, assert_series_equal
 
 
 @pytest.mark.skip(
@@ -352,3 +352,32 @@ def test_hive_partition_directory_scan(
 
     out = scan(path, hive_partitioning=False).collect()
     assert_frame_equal(out, df)
+
+
+def test_hive_partition_schema_inference(tmp_path: Path) -> None:
+    tmp_path.mkdir(exist_ok=True)
+
+    dfs = [
+        pl.DataFrame({"x": 1}),
+        pl.DataFrame({"x": 2}),
+        pl.DataFrame({"x": 3}),
+    ]
+
+    paths = [
+        tmp_path / "a=1/data.bin",
+        tmp_path / "a=1.5/data.bin",
+        tmp_path / "a=polars/data.bin",
+    ]
+
+    expected = [
+        pl.Series("a", [1], dtype=pl.Int64),
+        pl.Series("a", [1.0, 1.5], dtype=pl.Float64),
+        pl.Series("a", ["1", "1.5", "polars"], dtype=pl.String),
+    ]
+
+    for i in range(3):
+        paths[i].parent.mkdir(exist_ok=True, parents=True)
+        dfs[i].write_parquet(paths[i])
+        out = pl.scan_parquet(tmp_path / "**/*.bin").collect()
+
+        assert_series_equal(out["a"], expected[i])
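
For reference, a minimal sketch of the user-facing behavior this patch enables (the directory layout below is hypothetical, chosen only for illustration): when no Hive schema is supplied, the partition column's dtype is now inferred across *all* paths using the same resolution rules as CSV inference, so an `Int64`/`Float64` mix widens to `Float64` and any other conflict falls back to `String`, mirroring the included test.

```python
import polars as pl

# Hypothetical layout -- two partitions of the same column `a`:
#   data/a=1/data.parquet
#   data/a=1.5/data.parquet
# No hive schema is passed, so `a` is inferred from both paths:
# {Int64, Float64} resolves to Float64 via finish_infer_field_schema.
out = pl.scan_parquet("data/**/*.parquet").collect()
assert out["a"].dtype == pl.Float64
```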