From 50d57adbc8a433710a81e15fd14dc298e30a0931 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 14 Jun 2024 17:38:00 -0700 Subject: [PATCH 1/3] fix --- src/daft-plan/src/builder.rs | 42 ++++++++++++++++++----------- tests/dataframe/test_with_column.py | 9 +++++++ 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 7541559ad5..7088042c85 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -1,4 +1,7 @@ -use std::{collections::HashSet, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; use crate::{ logical_ops, @@ -215,22 +218,31 @@ impl LogicalPlanBuilder { ) -> DaftResult { err_if_agg("with_columns", &columns)?; - let new_col_names = columns.iter().map(|e| e.name()).collect::>(); - - let mut exprs = self - .schema() - .fields + let fields = &self.schema().fields; + let current_col_names = fields .iter() - .filter_map(|(name, _)| { - if new_col_names.contains(name.as_str()) { - None - } else { - Some(col(name.clone())) - } - }) - .collect::>(); + .map(|(name, _)| name.as_str()) + .collect::>(); + let new_col_name_and_exprs = columns + .iter() + .map(|e| (e.name(), e.clone())) + .collect::>(); + + let mut exprs = vec![]; + for (name, _) in fields.iter() { + if let Some(expr) = new_col_name_and_exprs.get(name.as_str()) { + exprs.push(expr.clone()); + } else { + exprs.push(col(name.clone())); + } + } - exprs.extend(columns); + exprs.extend( + columns + .iter() + .filter(|e| !current_col_names.contains(e.name())) + .cloned(), + ); let logical_plan: LogicalPlan = logical_ops::Project::try_new(self.plan.clone(), exprs, resource_request)?.into(); diff --git a/tests/dataframe/test_with_column.py b/tests/dataframe/test_with_column.py index 8e2e410464..0432b10b0d 100644 --- a/tests/dataframe/test_with_column.py +++ b/tests/dataframe/test_with_column.py @@ -9,6 +9,15 @@ def test_with_column(make_df, valid_data: list[dict[str, float]]) -> None: assert data["bar"] == [sw + pl for sw, pl in zip(data["sepal_width"], data["petal_length"])] +def test_with_column_same_name(make_df, valid_data: list[dict[str, float]]) -> None: + df = make_df(valid_data) + expanded_df = df.with_column("sepal_width", df["sepal_width"] + df["petal_length"]) + data = expanded_df.to_pydict() + assert expanded_df.column_names == df.column_names + expected = [valid_data[i]["sepal_width"] + valid_data[i]["petal_length"] for i in range(len(valid_data))] + assert data["sepal_width"] == expected + + def test_stacked_with_columns(make_df, valid_data: list[dict[str, float]]): df = make_df(valid_data) df = df.select(df["sepal_length"]) From b85e540e044d4e8348be271d8936e7a05d7e707e Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 14 Jun 2024 17:53:48 -0700 Subject: [PATCH 2/3] clean --- src/daft-plan/src/builder.rs | 17 +++++++++-------- tests/dataframe/test_with_columns.py | 17 +++++++++++++++++ 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 7088042c85..d29ec01bfb 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -228,14 +228,15 @@ impl LogicalPlanBuilder { .map(|e| (e.name(), e.clone())) .collect::>(); - let mut exprs = vec![]; - for (name, _) in fields.iter() { - if let Some(expr) = new_col_name_and_exprs.get(name.as_str()) { - exprs.push(expr.clone()); - } else { - exprs.push(col(name.clone())); - } - } + let mut exprs = fields + .iter() + .map(|(name, _)| { + new_col_name_and_exprs + .get(name.as_str()) + .cloned() + .unwrap_or_else(|| col(name.clone())) + }) + .collect::>(); exprs.extend( columns diff --git a/tests/dataframe/test_with_columns.py b/tests/dataframe/test_with_columns.py index e86c675a41..3d4800ff00 100644 --- a/tests/dataframe/test_with_columns.py +++ b/tests/dataframe/test_with_columns.py @@ -10,6 +10,23 @@ def test_with_columns(make_df, valid_data: list[dict[str, float]]) -> None: assert data["bar"] == [sw + pl for sw, pl in zip(data["sepal_width"], data["petal_length"])] +def test_with_columns_same_name(make_df, valid_data: list[dict[str, float]]) -> None: + df = make_df(valid_data) + expanded_df = df.with_columns( + {"sepal_length": df["sepal_length"] + df["sepal_width"], "petal_length": df["petal_length"] + df["petal_width"]} + ) + data = expanded_df.to_pydict() + assert expanded_df.column_names == df.column_names + expected_sepal_length = [ + valid_data[i]["sepal_length"] + valid_data[i]["sepal_width"] for i in range(len(valid_data)) + ] + expected_petal_length = [ + valid_data[i]["petal_length"] + valid_data[i]["petal_width"] for i in range(len(valid_data)) + ] + assert data["sepal_length"] == expected_sepal_length + assert data["petal_length"] == expected_petal_length + + def test_with_columns_empty(make_df, valid_data: list[dict[str, float]]) -> None: df = make_df(valid_data) expanded_df = df.with_columns({}) From e965d44cd9c40b8567c97e67bda3362e6c4190c7 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 17 Jun 2024 15:58:23 -0700 Subject: [PATCH 3/3] select tests --- tests/dataframe/test_select.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/dataframe/test_select.py b/tests/dataframe/test_select.py index 24b562f043..7b317dbb9b 100644 --- a/tests/dataframe/test_select.py +++ b/tests/dataframe/test_select.py @@ -22,3 +22,12 @@ def test_multiple_select_same_col(make_df, valid_data: list[dict[str, float]]): pdf = df.to_pandas() assert len(pdf.columns) == 2 assert pdf.columns.to_list() == ["sepal_length", "sepal_length_2"] + + +def test_select_ordering(make_df, valid_data: list[dict[str, float]]): + df = make_df(valid_data) + df = df.select( + df["variety"], df["petal_length"].alias("foo"), df["sepal_length"], df["sepal_width"], df["petal_width"] + ) + df = df.collect() + assert df.column_names == ["variety", "foo", "sepal_length", "sepal_width", "petal_width"]