diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index 9eec5dff66f5..956d055a52c2 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -764,7 +764,6 @@ impl Display for DataType { } pub fn merge_dtypes(left: &DataType, right: &DataType) -> PolarsResult { - // TODO! add struct use DataType::*; Ok(match (left, right) { #[cfg(feature = "dtype-categorical")] @@ -794,6 +793,16 @@ pub fn merge_dtypes(left: &DataType, right: &DataType) -> PolarsResult let merged = merge_dtypes(inner_l, inner_r)?; List(Box::new(merged)) }, + #[cfg(feature = "dtype-struct")] + (Struct(inner_l), Struct(inner_r)) => { + polars_ensure!(inner_l.len() == inner_r.len(), ComputeError: "cannot combine structs with differing amounts of fields ({} != {})", inner_l.len(), inner_r.len()); + let fields = inner_l.iter().zip(inner_r.iter()).map(|(l, r)| { + polars_ensure!(l.name() == r.name(), ComputeError: "cannot combine structs with different fields ({} != {})", l.name(), r.name()); + let merged = merge_dtypes(l.dtype(), r.dtype())?; + Ok(Field::new(l.name().clone(), merged)) + }).collect::>>()?; + Struct(fields) + }, #[cfg(feature = "dtype-array")] (Array(inner_l, width_l), Array(inner_r, width_r)) => { polars_ensure!(width_l == width_r, ComputeError: "widths of FixedSizeWidth Series are not equal"); diff --git a/py-polars/tests/unit/datatypes/test_categorical.py b/py-polars/tests/unit/datatypes/test_categorical.py index b898c7c07999..c5888abee67e 100644 --- a/py-polars/tests/unit/datatypes/test_categorical.py +++ b/py-polars/tests/unit/datatypes/test_categorical.py @@ -845,3 +845,22 @@ def test_get_cat_categories_multiple_chunks() -> None: ) df_cat = df.lazy().select(pl.col("e").cat.get_categories()).collect() assert len(df_cat) == 2 + + +@pytest.mark.parametrize( + "f", + [ + lambda x: (pl.List(pl.Categorical), [x]), + lambda x: (pl.Struct({"a": pl.Categorical}), {"a": x}), + ], +) +def test_nested_categorical_concat( + f: Callable[[str], tuple[pl.DataType, list[str] | dict[str, str]]], +) -> None: + dt, va = f("a") + _, vb = f("b") + a = pl.DataFrame({"x": [va]}, schema={"x": dt}) + b = pl.DataFrame({"x": [vb]}, schema={"x": dt}) + + with pytest.raises(pl.exceptions.StringCacheMismatchError): + pl.concat([a, b])