feat: Improve schema inference for hive partitions (#17079)
nameexhaustion authored Jun 22, 2024
1 parent a613e5a commit 4d6eec1
Showing 4 changed files with 133 additions and 125 deletions.
43 changes: 20 additions & 23 deletions crates/polars-io/src/csv/read/schema_inference.rs
@@ -99,8 +99,24 @@ impl CsvReadOptions {
}
}

pub fn finish_infer_field_schema(possibilities: &PlHashSet<DataType>) -> DataType {
// determine data type based on possible types
// if there are incompatible types, use DataType::String
match possibilities.len() {
1 => possibilities.iter().next().unwrap().clone(),
2 if possibilities.contains(&DataType::Int64)
&& possibilities.contains(&DataType::Float64) =>
{
// we have an integer and double, fall down to double
DataType::Float64
},
// default to String for conflicting datatypes (e.g bool and int)
_ => DataType::String,
}
}

/// Infer the data type of a record
fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType {
pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType {
// when quoting is enabled in the reader, these quotes aren't escaped, we default to
// String for them
if string.starts_with('"') {
@@ -428,7 +444,6 @@ fn infer_file_schema_inner(

// build schema from inference results
for i in 0..header_length {
let possibilities = &column_types[i];
let field_name = &headers[i];

if let Some(schema_overwrite) = schema_overwrite {
@@ -447,27 +462,9 @@
}
}

// determine data type based on possible types
// if there are incompatible types, use DataType::String
match possibilities.len() {
1 => {
for dtype in possibilities.iter() {
fields.push(Field::new(field_name, dtype.clone()));
}
},
2 => {
if possibilities.contains(&DataType::Int64)
&& possibilities.contains(&DataType::Float64)
{
// we have an integer and double, fall down to double
fields.push(Field::new(field_name, DataType::Float64));
} else {
// default to String for conflicting datatypes (e.g bool and int)
fields.push(Field::new(field_name, DataType::String));
}
},
_ => fields.push(Field::new(field_name, DataType::String)),
}
let possibilities = &column_types[i];
let dtype = finish_infer_field_schema(possibilities);
fields.push(Field::new(field_name, dtype));
}
// if there is a single line after the header without an eol
// we copy the bytes add an eol and rerun this function
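This hunk extracts the dtype-reconciliation logic into a public `finish_infer_field_schema` helper (and makes `infer_field_schema` public) so the hive-partition code can reuse it: a single candidate type is taken as-is, an `Int64`/`Float64` pair widens to `Float64`, and any other conflict falls back to `String`. A minimal sketch, not part of this commit, showing the same rules through the Python CSV reader:

```python
import io

import polars as pl

# Int64 + Float64 candidates widen to Float64.
mixed_numeric = io.StringIO("a\n1\n2.5\n")
print(pl.read_csv(mixed_numeric).schema)  # expect {'a': Float64}

# Any other conflict (e.g. bool + int) falls back to String.
mixed_other = io.StringIO("b\ntrue\n3\n")
print(pl.read_csv(mixed_other).schema)  # expect {'b': String}
```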
2 changes: 1 addition & 1 deletion crates/polars-plan/Cargo.toml
@@ -15,7 +15,7 @@ doctest = false
libloading = { version = "0.8.0", optional = true }
polars-core = { workspace = true, features = ["lazy", "zip_with", "random"] }
polars-ffi = { workspace = true, optional = true }
polars-io = { workspace = true, features = ["lazy"] }
polars-io = { workspace = true, features = ["lazy", "csv"] }
polars-json = { workspace = true, optional = true }
polars-ops = { workspace = true, features = [] }
polars-parquet = { workspace = true, optional = true }
182 changes: 82 additions & 100 deletions crates/polars-plan/src/plans/hive.rs
@@ -3,6 +3,7 @@ use std::path::{Path, PathBuf};
use percent_encoding::percent_decode_str;
use polars_core::prelude::*;
use polars_io::predicates::{BatchStats, ColumnStats};
use polars_io::prelude::schema_inference::{finish_infer_field_schema, infer_field_schema};
use polars_io::utils::{BOOLEAN_RE, FLOAT_RE, INTEGER_RE};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
@@ -16,70 +17,6 @@ pub struct HivePartitions {
}

impl HivePartitions {
/// Constructs a new [`HivePartitions`] from a schema reference.
pub fn from_schema_ref(schema: SchemaRef) -> Self {
let column_stats = schema.iter_fields().map(ColumnStats::from_field).collect();
let stats = BatchStats::new(schema, column_stats, None);
Self { stats }
}

/// Constructs a new [`HivePartitions`] from a path.
///
/// Returns `None` if the path does not contain any Hive partitions.
/// Returns `Err` if the Hive partitions cannot be parsed correctly or do not match the given
/// [`Schema`].
pub fn try_from_path(path: &Path, schema: Option<SchemaRef>) -> PolarsResult<Option<Self>> {
let sep = separator(path);

let path_string = path.display().to_string();
let path_parts = path_string.split(sep);

// Last part is the file, which should be skipped.
let file_index = path_parts.clone().count() - 1;

let partitions = path_parts
.enumerate()
.filter_map(|(index, part)| {
if index == file_index {
return None;
}
parse_hive_string(part)
})
.map(|(name, value)| hive_info_to_series(name, value, schema.clone()))
.collect::<PolarsResult<Vec<_>>>()?;

if partitions.is_empty() {
return Ok(None);
}

let schema = match schema {
Some(schema) => Arc::new(
partitions
.iter()
.map(|s| {
let mut field = s.field().into_owned();
if let Some(dtype) = schema.get(field.name()) {
field.dtype = dtype.clone();
};
field
})
.collect::<Schema>(),
),
None => Arc::new(partitions.as_slice().into()),
};

let stats = BatchStats::new(
schema,
partitions
.into_iter()
.map(ColumnStats::from_column_literal)
.collect(),
None,
);

Ok(Some(HivePartitions { stats }))
}

pub fn get_statistics(&self) -> &BatchStats {
&self.stats
}
@@ -108,33 +45,96 @@ pub fn hive_partitions_from_paths(
return Ok(None);
};

let Some(hive_parts) = HivePartitions::try_from_path(
&PathBuf::from(&path.to_str().unwrap()[hive_start_idx..]),
schema.clone(),
)?
else {
return Ok(None);
let sep = separator(path);
let path_string = path.to_str().unwrap();

macro_rules! get_hive_parts_iter {
($e:expr) => {{
let path_parts = $e[hive_start_idx..].split(sep);
let file_index = path_parts.clone().count() - 1;

path_parts.enumerate().filter_map(move |(index, part)| {
if index == file_index {
return None;
}
parse_hive_string(part)
})
}};
}

let hive_schema = if let Some(v) = schema {
v
} else {
let mut hive_schema = Schema::with_capacity(16);
let mut schema_inference_map: PlHashMap<&str, PlHashSet<DataType>> =
PlHashMap::with_capacity(16);

for (name, _) in get_hive_parts_iter!(path_string) {
hive_schema.insert_at_index(hive_schema.len(), name.into(), DataType::String)?;
schema_inference_map.insert(name, PlHashSet::with_capacity(4));
}

if hive_schema.is_empty() && schema_inference_map.is_empty() {
return Ok(None);
}

if !schema_inference_map.is_empty() {
for path in paths {
for (name, value) in get_hive_parts_iter!(path.to_str().unwrap()) {
let Some(entry) = schema_inference_map.get_mut(name) else {
continue;
};

entry.insert(infer_field_schema(value, false, false));
}
}

for (name, ref possibilities) in schema_inference_map.drain() {
let dtype = finish_infer_field_schema(possibilities);
*hive_schema.try_get_mut(name).unwrap() = dtype;
}
}
Arc::new(hive_schema)
};

let mut results = Vec::with_capacity(paths.len());
results.push(Arc::new(hive_parts));
let mut hive_partitions = Vec::with_capacity(paths.len());

for path in paths {
let path = path.to_str().unwrap();

let column_stats = get_hive_parts_iter!(path)
.map(|(name, value)| {
let Some(dtype) = hive_schema.as_ref().get(name) else {
polars_bail!(
SchemaFieldNotFound:
"path contains column not present in the given Hive schema: {:?}, path = {:?}",
name,
path
)
};

Ok(ColumnStats::from_column_literal(value_to_series(
name,
value,
Some(dtype),
)?))
})
.collect::<PolarsResult<Vec<_>>>()?;

for path in &paths[1..] {
let Some(hive_parts) = HivePartitions::try_from_path(
&PathBuf::from(&path.to_str().unwrap()[hive_start_idx..]),
schema.clone(),
)?
else {
if column_stats.is_empty() {
polars_bail!(
ComputeError: "expected Hive partitioned path, got {}\n\n\
This error occurs if some paths are Hive partitioned and some paths are not.",
path.display()
path
)
};
results.push(Arc::new(hive_parts));
}

let stats = BatchStats::new(hive_schema.clone(), column_stats, None);

hive_partitions.push(Arc::new(HivePartitions { stats }));
}

Ok(Some(results))
Ok(Some(hive_partitions))
}

/// Determine the path separator for identifying Hive partitions.
@@ -174,24 +174,6 @@ fn parse_hive_string(part: &'_ str) -> Option<(&'_ str, &'_ str)> {
Some((name, value))
}

/// Convert Hive partition string information to a single-value [`Series`].
fn hive_info_to_series(name: &str, value: &str, schema: Option<SchemaRef>) -> PolarsResult<Series> {
let dtype = match schema {
Some(ref s) => {
let dtype = s.try_get(name).map_err(|_| {
polars_err!(
SchemaFieldNotFound:
"path contains column not present in the given Hive schema: {:?}", name
)
})?;
Some(dtype)
},
None => None,
};

value_to_series(name, value, dtype)
}

/// Parse a string value into a single-value [`Series`].
fn value_to_series(name: &str, value: &str, dtype: Option<&DataType>) -> PolarsResult<Series> {
let fn_err = || polars_err!(ComputeError: "unable to parse Hive partition value: {:?}", value);
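The rewrite in `hive.rs` replaces the per-path `HivePartitions::try_from_path` constructor with a single pass inside `hive_partitions_from_paths`: the hive schema is either the caller-supplied one or is inferred once by feeding every partition value from every path through `infer_field_schema` and resolving conflicts with `finish_infer_field_schema`; each path's column statistics are then built against that shared schema. A hedged usage sketch from the Python side (assuming `scan_parquet` accepts a `hive_schema` mapping; the directory layout is hypothetical):

```python
from pathlib import Path

import polars as pl

root = Path("/tmp/hive_demo")  # hypothetical location
for a in ("1", "2"):
    part = root / f"a={a}"
    part.mkdir(parents=True, exist_ok=True)
    pl.DataFrame({"x": [int(a)]}).write_parquet(part / "data.parquet")

# Without a user schema, the partition values "1" and "2" are inferred as Int64.
print(pl.scan_parquet(root / "**/*.parquet").collect().schema)

# An explicit hive schema (assumed parameter) skips inference and keeps strings.
print(
    pl.scan_parquet(root / "**/*.parquet", hive_schema={"a": pl.String})
    .collect()
    .schema
)
```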
31 changes: 30 additions & 1 deletion py-polars/tests/unit/io/test_hive.py
@@ -9,7 +9,7 @@

import polars as pl
from polars.exceptions import DuplicateError, SchemaFieldNotFoundError
from polars.testing import assert_frame_equal
from polars.testing import assert_frame_equal, assert_series_equal


@pytest.mark.skip(
@@ -352,3 +352,32 @@ def test_hive_partition_directory_scan(
out = scan(path, hive_partitioning=False).collect()

assert_frame_equal(out, df)


def test_hive_partition_schema_inference(tmp_path: Path) -> None:
tmp_path.mkdir(exist_ok=True)

dfs = [
pl.DataFrame({"x": 1}),
pl.DataFrame({"x": 2}),
pl.DataFrame({"x": 3}),
]

paths = [
tmp_path / "a=1/data.bin",
tmp_path / "a=1.5/data.bin",
tmp_path / "a=polars/data.bin",
]

expected = [
pl.Series("a", [1], dtype=pl.Int64),
pl.Series("a", [1.0, 1.5], dtype=pl.Float64),
pl.Series("a", ["1", "1.5", "polars"], dtype=pl.String),
]

for i in range(3):
paths[i].parent.mkdir(exist_ok=True, parents=True)
dfs[i].write_parquet(paths[i])
out = pl.scan_parquet(tmp_path / "**/*.bin").collect()

assert_series_equal(out["a"], expected[i])
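A hedged companion sketch (not in this commit) for the error path kept in `hive_partitions_from_paths`: when an explicit hive schema is given, a partition column missing from it should surface as `SchemaFieldNotFoundError` by the time the query is collected (assuming `scan_parquet` accepts a `hive_schema` mapping):

```python
from pathlib import Path

import pytest

import polars as pl
from polars.exceptions import SchemaFieldNotFoundError


def test_hive_partition_schema_mismatch(tmp_path: Path) -> None:
    path = tmp_path / "a=1/data.parquet"
    path.parent.mkdir(parents=True)
    pl.DataFrame({"x": [1]}).write_parquet(path)

    # The path carries partition column "a", but the supplied schema only knows "b".
    with pytest.raises(SchemaFieldNotFoundError):
        pl.scan_parquet(
            tmp_path / "**/*.parquet", hive_schema={"b": pl.Int64}
        ).collect()
```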
