Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] with_column with existing column name should not reorder columns #2381

Merged
merged 3 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions src/daft-plan/src/builder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::{collections::HashSet, sync::Arc};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};

use crate::{
logical_ops,
Expand Down Expand Up @@ -215,22 +218,32 @@ impl LogicalPlanBuilder {
) -> DaftResult<Self> {
err_if_agg("with_columns", &columns)?;

let new_col_names = columns.iter().map(|e| e.name()).collect::<HashSet<&str>>();
let fields = &self.schema().fields;
let current_col_names = fields
.iter()
.map(|(name, _)| name.as_str())
.collect::<HashSet<_>>();
let new_col_name_and_exprs = columns
.iter()
.map(|e| (e.name(), e.clone()))
.collect::<HashMap<_, _>>();

let mut exprs = self
.schema()
.fields
let mut exprs = fields
.iter()
.filter_map(|(name, _)| {
if new_col_names.contains(name.as_str()) {
None
} else {
Some(col(name.clone()))
}
.map(|(name, _)| {
new_col_name_and_exprs
.get(name.as_str())
.cloned()
.unwrap_or_else(|| col(name.clone()))
})
.collect::<Vec<_>>();

exprs.extend(columns);
exprs.extend(
columns
.iter()
.filter(|e| !current_col_names.contains(e.name()))
.cloned(),
);

let logical_plan: LogicalPlan =
logical_ops::Project::try_new(self.plan.clone(), exprs, resource_request)?.into();
Expand Down
9 changes: 9 additions & 0 deletions tests/dataframe/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,12 @@ def test_multiple_select_same_col(make_df, valid_data: list[dict[str, float]]):
pdf = df.to_pandas()
assert len(pdf.columns) == 2
assert pdf.columns.to_list() == ["sepal_length", "sepal_length_2"]


def test_select_ordering(make_df, valid_data: list[dict[str, float]]):
df = make_df(valid_data)
df = df.select(
df["variety"], df["petal_length"].alias("foo"), df["sepal_length"], df["sepal_width"], df["petal_width"]
)
df = df.collect()
assert df.column_names == ["variety", "foo", "sepal_length", "sepal_width", "petal_width"]
9 changes: 9 additions & 0 deletions tests/dataframe/test_with_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@ def test_with_column(make_df, valid_data: list[dict[str, float]]) -> None:
assert data["bar"] == [sw + pl for sw, pl in zip(data["sepal_width"], data["petal_length"])]


def test_with_column_same_name(make_df, valid_data: list[dict[str, float]]) -> None:
df = make_df(valid_data)
expanded_df = df.with_column("sepal_width", df["sepal_width"] + df["petal_length"])
data = expanded_df.to_pydict()
assert expanded_df.column_names == df.column_names
expected = [valid_data[i]["sepal_width"] + valid_data[i]["petal_length"] for i in range(len(valid_data))]
assert data["sepal_width"] == expected


def test_stacked_with_columns(make_df, valid_data: list[dict[str, float]]):
df = make_df(valid_data)
df = df.select(df["sepal_length"])
Expand Down
17 changes: 17 additions & 0 deletions tests/dataframe/test_with_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,23 @@ def test_with_columns(make_df, valid_data: list[dict[str, float]]) -> None:
assert data["bar"] == [sw + pl for sw, pl in zip(data["sepal_width"], data["petal_length"])]


def test_with_columns_same_name(make_df, valid_data: list[dict[str, float]]) -> None:
df = make_df(valid_data)
expanded_df = df.with_columns(
{"sepal_length": df["sepal_length"] + df["sepal_width"], "petal_length": df["petal_length"] + df["petal_width"]}
)
data = expanded_df.to_pydict()
assert expanded_df.column_names == df.column_names
expected_sepal_length = [
valid_data[i]["sepal_length"] + valid_data[i]["sepal_width"] for i in range(len(valid_data))
]
expected_petal_length = [
valid_data[i]["petal_length"] + valid_data[i]["petal_width"] for i in range(len(valid_data))
]
assert data["sepal_length"] == expected_sepal_length
assert data["petal_length"] == expected_petal_length


def test_with_columns_empty(make_df, valid_data: list[dict[str, float]]) -> None:
df = make_df(valid_data)
expanded_df = df.with_columns({})
Expand Down
Loading