Skip to content

Commit

Permalink
fix: Proper dtype casting for struct embedded categoricals in chunked…
Browse files Browse the repository at this point in the history
… categoricals (#18815)
  • Loading branch information
coastalwhite authored Sep 20, 2024
1 parent b6263e2 commit aff3a86
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
11 changes: 10 additions & 1 deletion crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,6 @@ impl Display for DataType {
}

pub fn merge_dtypes(left: &DataType, right: &DataType) -> PolarsResult<DataType> {
// TODO! add struct
use DataType::*;
Ok(match (left, right) {
#[cfg(feature = "dtype-categorical")]
Expand Down Expand Up @@ -794,6 +793,16 @@ pub fn merge_dtypes(left: &DataType, right: &DataType) -> PolarsResult<DataType>
let merged = merge_dtypes(inner_l, inner_r)?;
List(Box::new(merged))
},
#[cfg(feature = "dtype-struct")]
(Struct(inner_l), Struct(inner_r)) => {
polars_ensure!(inner_l.len() == inner_r.len(), ComputeError: "cannot combine structs with differing amounts of fields ({} != {})", inner_l.len(), inner_r.len());
let fields = inner_l.iter().zip(inner_r.iter()).map(|(l, r)| {
polars_ensure!(l.name() == r.name(), ComputeError: "cannot combine structs with different fields ({} != {})", l.name(), r.name());
let merged = merge_dtypes(l.dtype(), r.dtype())?;
Ok(Field::new(l.name().clone(), merged))
}).collect::<PolarsResult<Vec<_>>>()?;
Struct(fields)
},
#[cfg(feature = "dtype-array")]
(Array(inner_l, width_l), Array(inner_r, width_r)) => {
polars_ensure!(width_l == width_r, ComputeError: "widths of FixedSizeWidth Series are not equal");
Expand Down
19 changes: 19 additions & 0 deletions py-polars/tests/unit/datatypes/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,3 +845,22 @@ def test_get_cat_categories_multiple_chunks() -> None:
)
df_cat = df.lazy().select(pl.col("e").cat.get_categories()).collect()
assert len(df_cat) == 2


@pytest.mark.parametrize(
"f",
[
lambda x: (pl.List(pl.Categorical), [x]),
lambda x: (pl.Struct({"a": pl.Categorical}), {"a": x}),
],
)
def test_nested_categorical_concat(
f: Callable[[str], tuple[pl.DataType, list[str] | dict[str, str]]],
) -> None:
dt, va = f("a")
_, vb = f("b")
a = pl.DataFrame({"x": [va]}, schema={"x": dt})
b = pl.DataFrame({"x": [vb]}, schema={"x": dt})

with pytest.raises(pl.exceptions.StringCacheMismatchError):
pl.concat([a, b])

0 comments on commit aff3a86

Please sign in to comment.