diff --git a/crates/polars-mem-engine/src/executors/projection_utils.rs b/crates/polars-mem-engine/src/executors/projection_utils.rs index 8287c923969a..6715796c5d30 100644 --- a/crates/polars-mem-engine/src/executors/projection_utils.rs +++ b/crates/polars-mem-engine/src/executors/projection_utils.rs @@ -291,6 +291,7 @@ pub(super) fn check_expand_literals( } } } + // If all series are the same length it is ok. If not we can broadcast Series of length one. if !all_equal_len && should_broadcast { selected_columns = selected_columns @@ -300,24 +301,26 @@ pub(super) fn check_expand_literals( Ok(match series.len() { 0 if df_height == 1 => series, 1 => { - if has_empty { - polars_ensure!(df_height == 1, - ComputeError: "Series length {} doesn't match the DataFrame height of {}", - series.len(), df_height - ); - series.slice(0, 0) - } else if df_height == 1 { + if !has_empty && df_height == 1 { series } else { + if has_empty { + polars_ensure!(df_height == 1, + ShapeMismatch: "Series length {} doesn't match the DataFrame height of {}", + series.len(), df_height + ); + + } + if verify_scalar { polars_ensure!(phys.is_scalar(), - InvalidOperation: "Series: {}, length {} doesn't match the DataFrame height of {}\n\n\ + ShapeMismatch: "Series: {}, length {} doesn't match the DataFrame height of {}\n\n\ If you want this Series to be broadcasted, ensure it is a scalar (for instance by adding '.first()').", - series.name(), series.len(), df_height + series.name(), series.len(), df_height *(!has_empty as usize) ); } - series.new_from_index(0, df_height) + series.new_from_index(0, df_height * (!has_empty as usize) ) } }, len if len == df_height => { @@ -325,7 +328,7 @@ pub(super) fn check_expand_literals( }, _ => { polars_bail!( - ComputeError: "Series length {} doesn't match the DataFrame height of {}", + ShapeMismatch: "Series length {} doesn't match the DataFrame height of {}", series.len(), df_height ) } diff --git a/py-polars/tests/unit/constructors/test_dataframe.py b/py-polars/tests/unit/constructors/test_dataframe.py index 5e56630e7552..251ec5e7bce2 100644 --- a/py-polars/tests/unit/constructors/test_dataframe.py +++ b/py-polars/tests/unit/constructors/test_dataframe.py @@ -151,18 +151,6 @@ def test_df_init_nested_mixed_types() -> None: assert df.to_dicts() == [{"key": [{"value": 1.0}, {"value": 1.0}]}] -def test_unit_and_empty_construction_15896() -> None: - # This is still incorrect. - # We should raise, but currently for len 1 dfs, - # we cannot tell if they come from a literal or expression. - assert "shape: (0, 2)" in str( - pl.DataFrame({"A": [0]}).select( - C="A", - A=pl.int_range("A"), # creates empty series - ) - ) - - class CustomSchema(Mapping[str, Any]): """Dummy schema object for testing compatibility with Mapping.""" diff --git a/py-polars/tests/unit/dataframe/test_shape.py b/py-polars/tests/unit/dataframe/test_shape.py new file mode 100644 index 000000000000..2409ee0c2f3f --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_shape.py @@ -0,0 +1,11 @@ +import pytest + +import polars as pl + + +# TODO: remove this skip when streaming raises +@pytest.mark.may_fail_auto_streaming +def test_raise_invalid_shape_19108() -> None: + df = pl.DataFrame({"foo": [1, 2], "bar": [3, 4]}) + with pytest.raises(pl.exceptions.ShapeError): + df.select(pl.col.foo.head(0), pl.col.bar.head(1)) diff --git a/py-polars/tests/unit/lazyframe/test_with_context.py b/py-polars/tests/unit/lazyframe/test_with_context.py index 1c084c7f7c34..5a224a16c0a9 100644 --- a/py-polars/tests/unit/lazyframe/test_with_context.py +++ b/py-polars/tests/unit/lazyframe/test_with_context.py @@ -3,7 +3,6 @@ import pytest import polars as pl -from polars.exceptions import ComputeError from polars.testing import assert_frame_equal @@ -19,7 +18,7 @@ def test_with_context() -> None: with pytest.deprecated_call(): context = df_a.with_context(df_b.lazy()) - with pytest.raises(ComputeError): + with pytest.raises(pl.exceptions.ShapeError): context.select("a", "c").collect() diff --git a/py-polars/tests/unit/test_scalar.py b/py-polars/tests/unit/test_scalar.py index d1f354d8e48e..6fa59a6c323d 100644 --- a/py-polars/tests/unit/test_scalar.py +++ b/py-polars/tests/unit/test_scalar.py @@ -11,7 +11,7 @@ def test_invalid_broadcast() -> None: "group": [0, 1], } ) - with pytest.raises(pl.exceptions.InvalidOperationError): + with pytest.raises(pl.exceptions.ShapeError): df.select(pl.col("group").filter(pl.col("group") == 0), "a")