From 3bb80507d250f6fd30202c03ed23538a89184fb4 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Fri, 21 Jun 2024 13:58:18 +0400 Subject: [PATCH 1/7] feat: Support `Struct` field selection in the SQL engine --- crates/polars-sql/Cargo.toml | 2 +- crates/polars-sql/src/context.rs | 12 +++- crates/polars-sql/src/sql_expr.rs | 74 +++++++++++++----------- py-polars/tests/unit/sql/test_structs.py | 60 +++++++++++++++++++ 4 files changed, 112 insertions(+), 36 deletions(-) create mode 100644 py-polars/tests/unit/sql/test_structs.py diff --git a/crates/polars-sql/Cargo.toml b/crates/polars-sql/Cargo.toml index 83dcd5b98a3c..0c8f883daf50 100644 --- a/crates/polars-sql/Cargo.toml +++ b/crates/polars-sql/Cargo.toml @@ -12,7 +12,7 @@ description = "SQL transpiler for Polars. Converts SQL to Polars logical plans" arrow = { workspace = true } polars-core = { workspace = true, features = ["rows"] } polars-error = { workspace = true } -polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_str", "cross_join", "cum_agg", "dtype-date", "dtype-decimal", "is_in", "list_eval", "log", "meta", "regex", "round_series", "sign", "string_reverse", "strings", "timezones", "trigonometry"] } +polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_str", "cross_join", "cum_agg", "dtype-date", "dtype-decimal", "dtype-struct", "is_in", "list_eval", "log", "meta", "regex", "round_series", "sign", "string_reverse", "strings", "timezones", "trigonometry"] } polars-ops = { workspace = true } polars-plan = { workspace = true } polars-time = { workspace = true } diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 496b0d0070d0..5e44f1956332 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -1179,9 +1179,17 @@ impl SQLContext { options: &WildcardAdditionalOptions, contains_wildcard_exclude: &mut bool, ) -> PolarsResult { - if options.opt_except.is_some() { - polars_bail!(SQLSyntax: "EXCEPT not supported (use EXCLUDE instead)") + // bail on unsupported wildcard options + if options.opt_ilike.is_some() { + polars_bail!(SQLSyntax: "ILIKE wildcard option is unsupported") + } else if options.opt_rename.is_some() { + polars_bail!(SQLSyntax: "RENAME wildcard option is unsupported") + } else if options.opt_replace.is_some() { + polars_bail!(SQLSyntax: "REPLACE wildcard option is unsupported") + } else if options.opt_except.is_some() { + polars_bail!(SQLSyntax: "EXCEPT wildcard option is unsupported (use EXCLUDE instead)") } + Ok(match &options.opt_exclude { Some(ExcludeSelectItem::Single(ident)) => { *contains_wildcard_exclude = true; diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 447f0486baab..39a9eca5c818 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -375,42 +375,50 @@ impl SQLExprVisitor<'_> { /// Visit a compound SQL identifier /// - /// e.g. df.column or "df"."column" + /// e.g. tbl.column, struct.field, tbl.struct.field (inc. nested struct fields) fn visit_compound_identifier(&mut self, idents: &[Ident]) -> PolarsResult { - match idents { - [tbl_name, column_name] => { - let mut lf = self - .ctx - .get_table_from_current_scope(&tbl_name.value) - .ok_or_else(|| { - polars_err!( - SQLInterface: "no table or alias named '{}' found", - tbl_name - ) - })?; - - let schema = - lf.schema_with_arenas(&mut self.ctx.lp_arena, &mut self.ctx.expr_arena)?; - if let Some((_, name, _)) = schema.get_full(&column_name.value) { - let resolved = &self.ctx.resolve_name(&tbl_name.value, &column_name.value); - Ok(if name != resolved { - col(resolved).alias(name) - } else { - col(name) - }) + // inference priority: table > struct > column + let ident_root = &idents[0]; + let mut remaining_idents = idents.iter().skip(1); + let mut lf = self.ctx.get_table_from_current_scope(&ident_root.value); + + let schema = if let Some(ref mut lf) = lf { + lf.schema_with_arenas(&mut self.ctx.lp_arena, &mut self.ctx.expr_arena) + } else { + Ok(Arc::new(if let Some(active_schema) = self.active_schema { + active_schema.clone() + } else { + Schema::new() + })) + }; + + let mut column: PolarsResult = if lf.is_none() { + Ok(col(&ident_root.value)) + } else { + let col_name = &remaining_idents.next().unwrap().value; + if let Some((_, name, _)) = schema?.get_full(col_name) { + let resolved = &self.ctx.resolve_name(&ident_root.value, col_name); + Ok(if name != resolved { + col(resolved).alias(name) } else { - polars_bail!( - SQLInterface: "no column named '{}' found in table '{}'", - column_name, - tbl_name - ) - } - }, - _ => polars_bail!( - SQLInterface: "invalid identifier {:?}", - idents - ), + col(name) + }) + } else { + polars_bail!( + SQLInterface: "no column named '{}' found in table '{}'", + col_name, + ident_root + ) + } + }; + // additional ident levels index into struct fields + for ident in remaining_idents { + column = Ok(column + .unwrap() + .struct_() + .field_by_name(ident.value.as_str())); } + column } fn visit_interval(&self, interval: &Interval) -> PolarsResult { diff --git a/py-polars/tests/unit/sql/test_structs.py b/py-polars/tests/unit/sql/test_structs.py new file mode 100644 index 000000000000..6f1cad494ac6 --- /dev/null +++ b/py-polars/tests/unit/sql/test_structs.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import pytest + +import polars as pl +from polars.exceptions import StructFieldNotFoundError +from polars.testing import assert_frame_equal + + +@pytest.fixture() +def struct_df() -> pl.DataFrame: + return pl.DataFrame( + { + "id": [100, 200, 300, 400], + "name": ["Alice", "Bob", "David", "Zoe"], + "age": [32, 27, 19, 45], + "other": [{"n": 1.5}, {"n": None}, {"n": -0.5}, {"n": 2.0}], + } + ).select(pl.struct(pl.all()).alias("json_msg")) + + +def test_struct_field_selection(struct_df: pl.DataFrame) -> None: + res = struct_df.sql( + """ + SELECT + -- validate table alias resolution + frame.json_msg.id AS ID, + self.json_msg.name AS NAME, + json_msg.age AS AGE + FROM + self AS frame + WHERE + json_msg.age > 20 AND + json_msg.other.n IS NOT NULL -- note: nested struct field + ORDER BY + json_msg.name DESC + """ + ) + + expected = pl.DataFrame( + { + "ID": [400, 100], + "NAME": ["Zoe", "Alice"], + "AGE": [45, 32], + } + ) + assert_frame_equal(expected, res) + + +@pytest.mark.parametrize( + "invalid_column", + [ + "json_msg.invalid_column", + "json_msg.other.invalid_column", + "self.json_msg.other.invalid_column", + ], +) +def test_struct_indexing_errors(invalid_column: str, struct_df: pl.DataFrame) -> None: + with pytest.raises(StructFieldNotFoundError, match="invalid_column"): + struct_df.sql(f"SELECT {invalid_column} FROM self") From 2cfdfb2dfb7ec0707e54d3b900bf7a2a40f2df44 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sat, 22 Jun 2024 00:01:40 +0400 Subject: [PATCH 2/7] add `struct.*` wildcard selection support --- .../src/plans/conversion/expr_expansion.rs | 2 +- .../polars-plan/src/plans/conversion/mod.rs | 2 +- crates/polars-plan/src/plans/mod.rs | 2 +- crates/polars-plan/src/prelude.rs | 1 + crates/polars-sql/src/context.rs | 56 ++++--- crates/polars-sql/src/sql_expr.rs | 17 +- crates/polars-sql/tests/simple_exprs.rs | 147 ++++++++++++++---- py-polars/tests/unit/sql/test_structs.py | 44 ++++-- 8 files changed, 187 insertions(+), 84 deletions(-) diff --git a/crates/polars-plan/src/plans/conversion/expr_expansion.rs b/crates/polars-plan/src/plans/conversion/expr_expansion.rs index 4f168d1080d1..fee7913dede7 100644 --- a/crates/polars-plan/src/plans/conversion/expr_expansion.rs +++ b/crates/polars-plan/src/plans/conversion/expr_expansion.rs @@ -634,7 +634,7 @@ fn find_flags(expr: &Expr) -> PolarsResult { /// In case of single col(*) -> do nothing, no selection is the same as select all /// In other cases replace the wildcard with an expression with all columns -pub(crate) fn rewrite_projections( +pub fn rewrite_projections( exprs: Vec, schema: &Schema, keys: &[Expr], diff --git a/crates/polars-plan/src/plans/conversion/mod.rs b/crates/polars-plan/src/plans/conversion/mod.rs index d0d6a41e9fb9..afdac2d300fc 100644 --- a/crates/polars-plan/src/plans/conversion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/mod.rs @@ -1,6 +1,6 @@ mod convert_utils; mod dsl_to_ir; -mod expr_expansion; +pub(crate) mod expr_expansion; mod expr_to_ir; mod ir_to_dsl; #[cfg(any(feature = "ipc", feature = "parquet", feature = "csv"))] diff --git a/crates/polars-plan/src/plans/mod.rs b/crates/polars-plan/src/plans/mod.rs index ca9acc44cf53..9255c811e489 100644 --- a/crates/polars-plan/src/plans/mod.rs +++ b/crates/polars-plan/src/plans/mod.rs @@ -16,7 +16,7 @@ pub(crate) mod ir; mod apply; mod builder_dsl; mod builder_ir; -pub(crate) mod conversion; +pub mod conversion; #[cfg(feature = "debugging")] pub(crate) mod debug; pub mod expr_ir; diff --git a/crates/polars-plan/src/prelude.rs b/crates/polars-plan/src/prelude.rs index d90e032cc925..34c38cefbdab 100644 --- a/crates/polars-plan/src/prelude.rs +++ b/crates/polars-plan/src/prelude.rs @@ -11,6 +11,7 @@ pub(crate) use polars_time::prelude::*; pub use polars_utils::arena::{Arena, Node}; pub use crate::dsl::*; +pub use crate::plans::conversion::expr_expansion::rewrite_projections; #[cfg(feature = "debugging")] pub use crate::plans::debug::*; pub use crate::plans::options::*; diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 5e44f1956332..cce84fcc9596 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -6,7 +6,7 @@ use polars_lazy::prelude::*; use polars_ops::frame::JoinCoalesce; use polars_plan::prelude::*; use sqlparser::ast::{ - Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, JoinConstraint, + Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, Ident, JoinConstraint, JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr, SetOperator, SetQuantifier, Statement, TableAlias, TableFactor, TableWithJoins, UnaryOperator, Value as SQLValue, Values, WildcardAdditionalOptions, @@ -600,36 +600,44 @@ impl SQLContext { lf = self.process_where(lf, &select_stmt.selection)?; // Column projections. - let projections: Vec<_> = select_stmt + let projections: Vec = select_stmt .projection .iter() .map(|select_item| { Ok(match select_item { - SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, self, schema.as_deref())?, + SelectItem::UnnamedExpr(expr) => { + vec![parse_sql_expr(expr, self, schema.as_deref())?] + }, SelectItem::ExprWithAlias { expr, alias } => { let expr = parse_sql_expr(expr, self, schema.as_deref())?; - expr.alias(&alias.value) + vec![expr.alias(&alias.value)] }, - SelectItem::QualifiedWildcard(oname, wildcard_options) => self - .process_qualified_wildcard( - oname, + SelectItem::QualifiedWildcard(obj_name, wildcard_options) => { + let expanded = self.process_qualified_wildcard( + obj_name, wildcard_options, &mut contains_wildcard_exclude, - )?, + schema.as_deref(), + )?; + rewrite_projections(vec![expanded], &(schema.clone().unwrap()), &[])? + }, SelectItem::Wildcard(wildcard_options) => { contains_wildcard = true; let e = col("*"); - self.process_wildcard_additional_options( + vec![self.process_wildcard_additional_options( e, wildcard_options, &mut contains_wildcard_exclude, - )? + )?] }, }) }) - .collect::>()?; + .collect::>>>()? + .into_iter() + .flatten() + .collect(); - // Check for "GROUP BY ..." (after projections, as there may be ordinal/position ints). + // Check for "GROUP BY ..." (after determining projections) let mut group_by_keys: Vec = Vec::new(); match &select_stmt.group_by { // Standard "GROUP BY x, y, z" syntax (also recognising ordinal values) @@ -1152,25 +1160,13 @@ impl SQLContext { ObjectName(idents): &ObjectName, options: &WildcardAdditionalOptions, contains_wildcard_exclude: &mut bool, + schema: Option<&Schema>, ) -> PolarsResult { - let idents = idents.as_slice(); - let e = match idents { - [tbl_name] => { - let lf = self.table_map.get_mut(&tbl_name.value).ok_or_else(|| { - polars_err!( - SQLInterface: "no table named '{}' found", - tbl_name - ) - })?; - let schema = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; - cols(schema.iter_names()) - }, - e => polars_bail!( - SQLSyntax: "invalid wildcard expression ({:?})", - e - ), - }; - self.process_wildcard_additional_options(e, options, contains_wildcard_exclude) + let mut new_idents = idents.clone(); + new_idents.push(Ident::new("*")); + let identifier = SQLExpr::CompoundIdentifier(new_idents); + let expr = parse_sql_expr(&identifier, self, schema)?; + self.process_wildcard_additional_options(expr, options, contains_wildcard_exclude) } fn process_wildcard_additional_options( diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 39a9eca5c818..a4ae10997715 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -390,23 +390,28 @@ impl SQLExprVisitor<'_> { } else { Schema::new() })) - }; + }?; - let mut column: PolarsResult = if lf.is_none() { + let mut column: PolarsResult = if lf.is_none() && schema.is_empty() { Ok(col(&ident_root.value)) } else { - let col_name = &remaining_idents.next().unwrap().value; - if let Some((_, name, _)) = schema?.get_full(col_name) { - let resolved = &self.ctx.resolve_name(&ident_root.value, col_name); + let name = &remaining_idents.next().unwrap().value; + if lf.is_some() && name == "*" { + Ok(cols(schema.iter_names())) + } else if let Some((_, name, _)) = schema.get_full(name) { + let resolved = &self.ctx.resolve_name(&ident_root.value, name); Ok(if name != resolved { col(resolved).alias(name) } else { col(name) }) + } else if lf.is_none() { + remaining_idents = idents.iter().skip(1); + Ok(col(&ident_root.value)) } else { polars_bail!( SQLInterface: "no column named '{}' found in table '{}'", - col_name, + name, ident_root ) } diff --git a/crates/polars-sql/tests/simple_exprs.rs b/crates/polars-sql/tests/simple_exprs.rs index 0a60b9dc7aca..77980bf54f77 100644 --- a/crates/polars-sql/tests/simple_exprs.rs +++ b/crates/polars-sql/tests/simple_exprs.rs @@ -326,21 +326,21 @@ fn test_binary_functions() { SELECT a, b, - a + b as add, - a - b as sub, - a * b as mul, - a / b as div, - a % b as rem, - a <> b as neq, - a = b as eq, - a > b as gt, - a < b as lt, - a >= b as gte, - a <= b as lte, - a and b as and, - a or b as or, - a xor b as xor, - a || b as concat + a + b AS add, + a - b AS sub, + a * b AS mul, + a / b AS div, + a % b AS rem, + a <> b AS neq, + a = b AS eq, + a > b AS gt, + a < b AS lt, + a >= b AS gte, + a <= b AS lte, + a and b AS and, + a or b AS or, + a xor b AS xor, + a || b AS concat FROM df"#; let df_sql = context.execute(sql).unwrap().collect().unwrap(); let df_pl = df.lazy().select(&[ @@ -374,18 +374,18 @@ fn test_agg_functions() { context.register("df", df.clone().lazy()); let sql = r#" SELECT - sum(a) as sum_a, - first(a) as first_a, - last(a) as last_a, - avg(a) as avg_a, - max(a) as max_a, - min(a) as min_a, - atan(a) as atan_a, - stddev(a) as stddev_a, - variance(a) as variance_a, - count(a) as count_a, - count(distinct a) as count_distinct_a, - count(*) as count_all + sum(a) AS sum_a, + first(a) AS first_a, + last(a) AS last_a, + avg(a) AS avg_a, + max(a) AS max_a, + min(a) AS min_a, + atan(a) AS atan_a, + stddev(a) AS stddev_a, + variance(a) AS variance_a, + count(a) AS count_a, + count(distinct a) AS count_distinct_a, + count(*) AS count_all FROM df"#; let df_sql = context.execute(sql).unwrap().collect().unwrap(); let df_pl = df @@ -414,6 +414,7 @@ fn test_create_table() { let df = create_sample_df().unwrap(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); + let sql = r#" CREATE TABLE df2 AS SELECT a @@ -423,14 +424,15 @@ fn test_create_table() { "Response" => ["CREATE TABLE"] } .unwrap(); + assert!(df_sql.equals(&create_tbl_res)); let df_2 = context .execute(r#"SELECT a FROM df2"#) .unwrap() .collect() .unwrap(); - let expected = df.lazy().select(&[col("a")]).collect().unwrap(); + let expected = df.lazy().select(&[col("a")]).collect().unwrap(); assert!(df_2.equals(&expected)); } @@ -450,6 +452,7 @@ fn test_unary_minus_0() { .filter(col("value").lt(lit(-1))) .collect() .unwrap(); + assert!(df_sql.equals(&df_pl)); } @@ -478,7 +481,7 @@ fn test_arr_agg() { vec![col("a").implode().alias("a")], ), ( - "SELECT ARRAY_AGG(a) AS a, ARRAY_AGG(b) as b FROM df", + "SELECT ARRAY_AGG(a) AS a, ARRAY_AGG(b) AS b FROM df", vec![col("a").implode().alias("a"), col("b").implode().alias("b")], ), ( @@ -530,6 +533,23 @@ fn test_ctes() -> PolarsResult<()> { Ok(()) } +#[test] +fn test_cte_values() -> PolarsResult<()> { + let sql = r#" + WITH + x AS (SELECT w.* FROM (VALUES(1,2), (3,4)) AS w(a, b)), + y (m, n) AS ( + WITH z(c, d) AS (SELECT a, b FROM x) + SELECT d*2 AS d2, c*3 AS c3 FROM z + ) + SELECT n, m FROM y + "#; + let mut context = SQLContext::new(); + assert!(context.execute(sql).is_ok()); + + Ok(()) +} + #[test] #[cfg(feature = "ipc")] fn test_group_by_2() -> PolarsResult<()> { @@ -543,7 +563,7 @@ fn test_group_by_2() -> PolarsResult<()> { let sql = r#" SELECT category, - count(category) as count, + count(category) AS count, max(calories), min(fats_g) FROM foods @@ -566,6 +586,7 @@ fn test_group_by_2() -> PolarsResult<()> { SortMultipleOptions::default().with_order_descending_multi([false, true]), ) .limit(2); + let expected = expected.collect()?; assert!(df_sql.equals(&expected)); Ok(()) @@ -591,6 +612,7 @@ fn test_case_expr() { .then(lit("lteq_5")) .otherwise(lit("no match")) .alias("sign"); + let df_pl = df.lazy().select(&[case_expr]).collect().unwrap(); assert!(df_sql.equals(&df_pl)); } @@ -600,6 +622,7 @@ fn test_case_expr_with_expression() { let df = create_sample_df().unwrap(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); + let sql = r#" SELECT CASE b%2 @@ -615,6 +638,7 @@ fn test_case_expr_with_expression() { .then(lit("odd")) .otherwise(lit("No?")) .alias("parity"); + let df_pl = df.lazy().select(&[case_expr]).collect().unwrap(); assert!(df_sql.equals(&df_pl)); } @@ -630,17 +654,72 @@ fn test_sql_expr() { #[test] fn test_iss_9471() { - let sql = r#" - SELECT - ABS(a,a,a,a,1,2,3,XYZRandomLetters,"XYZRandomLetters") as "abs", - FROM df"#; let df = df! { "a" => [-4, -3, -2, -1, 0, 1, 2, 3, 4], } .unwrap() .lazy(); + let mut context = SQLContext::new(); context.register("df", df); + + let sql = r#" + SELECT + ABS(a,a,a,a,1,2,3,XYZRandomLetters,"XYZRandomLetters") AS "abs", + FROM df"#; let res = context.execute(sql); + assert!(res.is_err()) } + +#[test] +fn test_order_by_excluded_column() { + let df = df! { + "x" => [0, 1, 2, 3], + "y" => [4, 2, 0, 8], + } + .unwrap() + .lazy(); + + let mut context = SQLContext::new(); + context.register("df", df); + + for sql in [ + "SELECT * EXCLUDE y FROM df ORDER BY y", + "SELECT df.* EXCLUDE y FROM df ORDER BY y", + ] { + let df_sorted = context.execute(sql).unwrap().collect().unwrap(); + + let expected = df! {"x" => [2, 1, 0, 3],}.unwrap(); + assert!(df_sorted.equals(&expected)); + } +} + +#[test] +fn test_struct_wildcards() { + let struct_cols = vec![col("num"), col("str"), col("val")]; + let df_original = df! { + "num" => [100, 200, 300, 400], + "str" => ["d", "c", "b", "a"], + "val" => [0.0, 5.0, 3.0, 4.0], + } + .unwrap(); + + let df_struct = df_original + .clone() + .lazy() + .select([as_struct(struct_cols).alias("json_msg")]); + + let mut context = SQLContext::new(); + context.register("df", df_struct.clone().lazy()); + + for sql in [ + r#"SELECT json_msg.* FROM df"#, + r#"SELECT df.json_msg.* FROM df"#, + r#"SELECT json_msg.* FROM df ORDER BY json_msg.num"#, + r#"SELECT df.json_msg.* FROM df ORDER BY json_msg.str DESC"#, + ] { + let df_sql = context.execute(sql).unwrap().collect().unwrap(); + assert!(df_sql.equals(&df_original)); + } +} diff --git a/py-polars/tests/unit/sql/test_structs.py b/py-polars/tests/unit/sql/test_structs.py index 6f1cad494ac6..9ed6bd1a2cb0 100644 --- a/py-polars/tests/unit/sql/test_structs.py +++ b/py-polars/tests/unit/sql/test_structs.py @@ -8,7 +8,7 @@ @pytest.fixture() -def struct_df() -> pl.DataFrame: +def df_struct() -> pl.DataFrame: return pl.DataFrame( { "id": [100, 200, 300, 400], @@ -19,8 +19,8 @@ def struct_df() -> pl.DataFrame: ).select(pl.struct(pl.all()).alias("json_msg")) -def test_struct_field_selection(struct_df: pl.DataFrame) -> None: - res = struct_df.sql( +def test_struct_field_selection(df_struct: pl.DataFrame) -> None: + res = df_struct.sql( """ SELECT -- validate table alias resolution @@ -36,17 +36,39 @@ def test_struct_field_selection(struct_df: pl.DataFrame) -> None: json_msg.name DESC """ ) - expected = pl.DataFrame( - { - "ID": [400, 100], - "NAME": ["Zoe", "Alice"], - "AGE": [45, 32], - } + {"ID": [400, 100], "NAME": ["Zoe", "Alice"], "AGE": [45, 32]} ) assert_frame_equal(expected, res) +@pytest.mark.parametrize( + ("fields", "excluding"), + [ + ("json_msg.*", ""), + ("self.json_msg.*", ""), + ("json_msg.other.*", ""), + ("self.json_msg.other.*", ""), + ], +) +def test_struct_field_wildcard_selection( + fields: str, + excluding: str, + df_struct: pl.DataFrame, +) -> None: + query = f"SELECT {fields} {excluding} FROM df_struct ORDER BY json_msg.id" + print(query) + res = pl.sql(query).collect() + + expected = df_struct.unnest("json_msg") + if fields.endswith(".other.*"): + expected = expected["other"].struct.unnest() + if excluding: + expected = expected.drop(excluding.split(",")) + + assert_frame_equal(expected, res) + + @pytest.mark.parametrize( "invalid_column", [ @@ -55,6 +77,6 @@ def test_struct_field_selection(struct_df: pl.DataFrame) -> None: "self.json_msg.other.invalid_column", ], ) -def test_struct_indexing_errors(invalid_column: str, struct_df: pl.DataFrame) -> None: +def test_struct_indexing_errors(invalid_column: str, df_struct: pl.DataFrame) -> None: with pytest.raises(StructFieldNotFoundError, match="invalid_column"): - struct_df.sql(f"SELECT {invalid_column} FROM self") + df_struct.sql(f"SELECT {invalid_column} FROM self") From 1e29be3805866a6770448c1524cb49288392bd8f Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sat, 22 Jun 2024 01:56:33 +0400 Subject: [PATCH 3/7] fix issue with `order by` ordinal-based column selection --- crates/polars-sql/src/context.rs | 38 ++++++++---- crates/polars-sql/tests/simple_exprs.rs | 77 ++++++++++++++----------- 2 files changed, 69 insertions(+), 46 deletions(-) diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index cce84fcc9596..420207053aef 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -200,8 +200,9 @@ impl SQLContext { fn expr_or_ordinal( &mut self, e: &SQLExpr, - schema: Option<&Schema>, exprs: &[Expr], + selected: Option<&[Expr]>, + schema: Option<&Schema>, clause: &str, ) -> PolarsResult { match e { @@ -230,7 +231,12 @@ impl SQLContext { idx ) })?; - Ok(exprs + let cols = if let Some(cols) = selected { + cols + } else { + exprs + }; + Ok(cols .get(idx - 1) .ok_or_else(|| { polars_err!( @@ -645,7 +651,9 @@ impl SQLContext { // translate the group expressions, allowing ordinal values group_by_keys = group_by_exprs .iter() - .map(|e| self.expr_or_ordinal(e, schema.as_deref(), &projections, "GROUP BY")) + .map(|e| { + self.expr_or_ordinal(e, &projections, None, schema.as_deref(), "GROUP BY") + }) .collect::>()? }, // "GROUP BY ALL" syntax; automatically adds expressions that do not contain @@ -712,8 +720,9 @@ impl SQLContext { }); let retained_columns: Vec<_> = retained_names.into_iter().map(|name| col(&name)).collect(); + lf = lf.with_columns(projections); - lf = self.process_order_by(lf, &query.order_by)?; + lf = self.process_order_by(lf, &query.order_by, Some(retained_columns.as_ref()))?; lf.select(&retained_columns) } else if contains_wildcard_exclude { let mut dropped_names = Vec::with_capacity(projections.len()); @@ -731,19 +740,19 @@ impl SQLContext { }); if exclude_expr.is_some() { lf = lf.with_columns(projections); - lf = self.process_order_by(lf, &query.order_by)?; + lf = self.process_order_by(lf, &query.order_by, None)?; lf.drop(dropped_names) } else { lf = lf.select(projections); - self.process_order_by(lf, &query.order_by)? + self.process_order_by(lf, &query.order_by, None)? } } else { lf = lf.select(projections); - self.process_order_by(lf, &query.order_by)? + self.process_order_by(lf, &query.order_by, None)? } } else { lf = self.process_group_by(lf, contains_wildcard, &group_by_keys, &projections)?; - lf = self.process_order_by(lf, &query.order_by)?; + lf = self.process_order_by(lf, &query.order_by, None)?; // Apply optional 'having' clause, post-aggregation. let schema = Some(lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?); @@ -773,7 +782,7 @@ impl SQLContext { // DISTINCT ON applies the ORDER BY before the operation. if !query.order_by.is_empty() { - lf = self.process_order_by(lf, &query.order_by)?; + lf = self.process_order_by(lf, &query.order_by, None)?; } return Ok(lf.unique_stable(Some(cols), UniqueKeepStrategy::First)); }, @@ -1002,13 +1011,14 @@ impl SQLContext { &mut self, mut lf: LazyFrame, order_by: &[OrderByExpr], + selected: Option<&[Expr]>, ) -> PolarsResult { let mut by = Vec::with_capacity(order_by.len()); let mut descending = Vec::with_capacity(order_by.len()); let mut nulls_last = Vec::with_capacity(order_by.len()); let schema = Some(lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?); - let column_names = schema + let columns = schema .clone() .unwrap() .iter_names() @@ -1023,7 +1033,13 @@ impl SQLContext { descending.push(desc_order); // translate order expression, allowing ordinal values - by.push(self.expr_or_ordinal(&ob.expr, schema.as_deref(), &column_names, "ORDER BY")?) + by.push(self.expr_or_ordinal( + &ob.expr, + &columns, + selected, + schema.as_deref(), + "ORDER BY", + )?) } Ok(lf.sort_by_exprs( &by, diff --git a/crates/polars-sql/tests/simple_exprs.rs b/crates/polars-sql/tests/simple_exprs.rs index 77980bf54f77..e32126f0ae14 100644 --- a/crates/polars-sql/tests/simple_exprs.rs +++ b/crates/polars-sql/tests/simple_exprs.rs @@ -3,10 +3,29 @@ use polars_lazy::prelude::*; use polars_sql::*; use polars_time::Duration; -fn create_sample_df() -> PolarsResult { +fn create_sample_df() -> DataFrame { let a = Series::new("a", (1..10000i64).map(|i| i / 100).collect::>()); let b = Series::new("b", 1..10000i64); - DataFrame::new(vec![a, b]) + DataFrame::new(vec![a, b]).unwrap() +} + +fn create_struct_df() -> (DataFrame, DataFrame) { + let struct_cols = vec![col("num"), col("str"), col("val")]; + let df = df! { + "num" => [100, 250, 300, 350], + "str" => ["b", "a", "b", "a"], + "val" => [4.0, 3.5, 2.0, 1.5], + } + .unwrap(); + + ( + df.clone() + .lazy() + .select([as_struct(struct_cols).alias("json_msg")]) + .collect() + .unwrap(), + df, + ) } fn assert_sql_to_polars(df: &DataFrame, sql: &str, f: impl FnOnce(LazyFrame) -> LazyFrame) { @@ -19,7 +38,7 @@ fn assert_sql_to_polars(df: &DataFrame, sql: &str, f: impl FnOnce(LazyFrame) -> #[test] fn test_simple_select() -> PolarsResult<()> { - let df = create_sample_df()?; + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let df_sql = context @@ -44,7 +63,7 @@ fn test_simple_select() -> PolarsResult<()> { #[test] fn test_nested_expr() -> PolarsResult<()> { - let df = create_sample_df()?; + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let df_sql = context @@ -57,7 +76,7 @@ fn test_nested_expr() -> PolarsResult<()> { #[test] fn test_group_by_simple() -> PolarsResult<()> { - let df = create_sample_df()?; + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let df_sql = context @@ -134,7 +153,7 @@ fn test_group_by_expression_key() -> PolarsResult<()> { #[test] fn test_cast_exprs() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let sql = r#" @@ -164,7 +183,7 @@ fn test_cast_exprs() { #[test] fn test_literal_exprs() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let sql = r#" @@ -225,7 +244,7 @@ fn test_implicit_date_string() { #[test] fn test_prefixed_column_names() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let sql = r#" @@ -244,7 +263,7 @@ fn test_prefixed_column_names() { #[test] fn test_prefixed_column_names_2() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let sql = r#" @@ -263,7 +282,7 @@ fn test_prefixed_column_names_2() { #[test] fn test_null_exprs() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let sql = r#" @@ -319,7 +338,7 @@ fn test_null_exprs_in_where() { #[test] fn test_binary_functions() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let sql = r#" @@ -369,7 +388,7 @@ fn test_binary_functions() { #[test] #[ignore = "TODO: non deterministic"] fn test_agg_functions() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let sql = r#" @@ -411,7 +430,7 @@ fn test_agg_functions() { #[test] fn test_create_table() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); @@ -474,7 +493,7 @@ fn test_unary_minus_1() { #[test] fn test_arr_agg() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let exprs = vec![ ( "SELECT ARRAY_AGG(a) AS a FROM df", @@ -515,7 +534,7 @@ fn test_arr_agg() { #[test] fn test_ctes() -> PolarsResult<()> { - let df = create_sample_df()?; + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.lazy()); @@ -594,7 +613,7 @@ fn test_group_by_2() -> PolarsResult<()> { #[test] fn test_case_expr() { - let df = create_sample_df().unwrap().head(Some(10)); + let df = create_sample_df().head(Some(10)); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let sql = r#" @@ -619,7 +638,7 @@ fn test_case_expr() { #[test] fn test_case_expr_with_expression() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); @@ -645,7 +664,7 @@ fn test_case_expr_with_expression() { #[test] fn test_sql_expr() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let expr = sql_expr("MIN(a)").unwrap(); let actual = df.clone().lazy().select(&[expr]).collect().unwrap(); let expected = df.lazy().select(&[col("a").min()]).collect().unwrap(); @@ -689,35 +708,23 @@ fn test_order_by_excluded_column() { "SELECT df.* EXCLUDE y FROM df ORDER BY y", ] { let df_sorted = context.execute(sql).unwrap().collect().unwrap(); - let expected = df! {"x" => [2, 1, 0, 3],}.unwrap(); assert!(df_sorted.equals(&expected)); } } #[test] -fn test_struct_wildcards() { - let struct_cols = vec![col("num"), col("str"), col("val")]; - let df_original = df! { - "num" => [100, 200, 300, 400], - "str" => ["d", "c", "b", "a"], - "val" => [0.0, 5.0, 3.0, 4.0], - } - .unwrap(); - - let df_struct = df_original - .clone() - .lazy() - .select([as_struct(struct_cols).alias("json_msg")]); +fn test_struct_field_selection() { + let (df_struct, df_original) = create_struct_df(); let mut context = SQLContext::new(); context.register("df", df_struct.clone().lazy()); for sql in [ - r#"SELECT json_msg.* FROM df"#, - r#"SELECT df.json_msg.* FROM df"#, + r#"SELECT json_msg.* FROM df ORDER BY 1"#, + r#"SELECT df.json_msg.* FROM df ORDER BY 3 DESC"#, r#"SELECT json_msg.* FROM df ORDER BY json_msg.num"#, - r#"SELECT df.json_msg.* FROM df ORDER BY json_msg.str DESC"#, + r#"SELECT df.json_msg.* FROM df ORDER BY json_msg.val DESC"#, ] { let df_sql = context.execute(sql).unwrap().collect().unwrap(); assert!(df_sql.equals(&df_original)); From 2d2e9c4c7643e40300af12609c7c712f548b279d Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sat, 22 Jun 2024 22:20:14 +0400 Subject: [PATCH 4/7] fix issue with `group by` and struct fields --- .../polars-plan/src/dsl/function_expr/mod.rs | 2 +- crates/polars-sql/src/context.rs | 22 ++++++++- crates/polars-sql/tests/simple_exprs.rs | 16 ++++++ py-polars/tests/unit/sql/test_structs.py | 49 ++++++++++++++++--- 4 files changed, 79 insertions(+), 10 deletions(-) diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 0dd79141ea02..d596f8f95c15 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -103,7 +103,7 @@ pub(super) use self::rolling_by::RollingFunctionBy; #[cfg(feature = "strings")] pub use self::strings::StringFunction; #[cfg(feature = "dtype-struct")] -pub(crate) use self::struct_::StructFunction; +pub use self::struct_::StructFunction; #[cfg(feature = "trigonometry")] pub(super) use self::trigonometry::TrigonometricFunction; use super::*; diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 420207053aef..638d317456aa 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -1,9 +1,11 @@ use std::cell::RefCell; +use std::ops::Deref; use polars_core::frame::row::Row; use polars_core::prelude::*; use polars_lazy::prelude::*; use polars_ops::frame::JoinCoalesce; +use polars_plan::dsl::function_expr::StructFunction; use polars_plan::prelude::*; use sqlparser::ast::{ Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, Ident, JoinConstraint, @@ -231,6 +233,8 @@ impl SQLContext { idx ) })?; + // note: "selected" cols represent final projection order, so we use those for + // ordinal resolution. "exprs" may include cols that are subsequently dropped. let cols = if let Some(cols) = selected { cols } else { @@ -1067,6 +1071,7 @@ impl SQLContext { // Remove the group_by keys as polars adds those implicitly. let mut aggregation_projection = Vec::with_capacity(projections.len()); + let mut projection_overrides = PlHashMap::with_capacity(projections.len()); let mut projection_aliases = PlHashSet::new(); let mut group_key_aliases = PlHashSet::new(); @@ -1081,6 +1086,12 @@ impl SQLContext { if e.clone().meta().is_simple_projection() { group_key_aliases.insert(alias.as_ref()); e = expr + } else if let Expr::Function { + function: FunctionExpr::StructExpr(StructFunction::FieldByName(name)), + .. + } = expr.deref() + { + projection_overrides.insert(alias.as_ref(), col(name).alias(alias)); } else if !is_agg_or_window && !group_by_keys_schema.contains(alias) { projection_aliases.insert(alias.as_ref()); } @@ -1096,7 +1107,12 @@ impl SQLContext { } } aggregation_projection.push(e); - } else if let Expr::Column(_) = e { + } else if let Expr::Column(_) + | Expr::Function { + function: FunctionExpr::StructExpr(StructFunction::FieldByName(_)), + .. + } = e + { // Non-aggregated columns must be part of the GROUP BY clause if !group_by_keys_schema.contains(&field.name) { polars_bail!(SQLSyntax: "'{}' should participate in the GROUP BY clause or an aggregate function", &field.name); @@ -1112,7 +1128,9 @@ impl SQLContext { .iter_names() .zip(projections) .map(|(name, projection_expr)| { - if group_by_keys_schema.get(name).is_some() + if let Some(expr) = projection_overrides.get(name.as_str()) { + expr.clone() + } else if group_by_keys_schema.get(name).is_some() || projection_aliases.contains(name.as_str()) || group_key_aliases.contains(name.as_str()) { diff --git a/crates/polars-sql/tests/simple_exprs.rs b/crates/polars-sql/tests/simple_exprs.rs index e32126f0ae14..274eb1a7af6f 100644 --- a/crates/polars-sql/tests/simple_exprs.rs +++ b/crates/polars-sql/tests/simple_exprs.rs @@ -729,4 +729,20 @@ fn test_struct_field_selection() { let df_sql = context.execute(sql).unwrap().collect().unwrap(); assert!(df_sql.equals(&df_original)); } + + let sql = r#" + SELECT + json_msg.str AS id, + SUM(json_msg.num) AS sum_n + FROM df + GROUP BY json_msg.str + ORDER BY 1 + "#; + let df_sql = context.execute(sql).unwrap().collect().unwrap(); + let df_expected = df! { + "id" => ["a", "b"], + "sum_n" => [600, 400], + } + .unwrap(); + assert!(df_sql.equals(&df_expected)); } diff --git a/py-polars/tests/unit/sql/test_structs.py b/py-polars/tests/unit/sql/test_structs.py index 9ed6bd1a2cb0..db965efcd86d 100644 --- a/py-polars/tests/unit/sql/test_structs.py +++ b/py-polars/tests/unit/sql/test_structs.py @@ -3,7 +3,7 @@ import pytest import polars as pl -from polars.exceptions import StructFieldNotFoundError +from polars.exceptions import SQLSyntaxError, StructFieldNotFoundError from polars.testing import assert_frame_equal @@ -11,10 +11,10 @@ def df_struct() -> pl.DataFrame: return pl.DataFrame( { - "id": [100, 200, 300, 400], - "name": ["Alice", "Bob", "David", "Zoe"], - "age": [32, 27, 19, 45], - "other": [{"n": 1.5}, {"n": None}, {"n": -0.5}, {"n": 2.0}], + "id": [200, 300, 400], + "name": ["Bob", "David", "Zoe"], + "age": [45, 19, 45], + "other": [{"n": 1.5}, {"n": None}, {"n": -0.5}], } ).select(pl.struct(pl.all()).alias("json_msg")) @@ -36,12 +36,45 @@ def test_struct_field_selection(df_struct: pl.DataFrame) -> None: json_msg.name DESC """ ) + expected = pl.DataFrame({"ID": [400, 200], "NAME": ["Zoe", "Bob"], "AGE": [45, 45]}) + assert_frame_equal(expected, res) + + +def test_struct_field_group_by(df_struct: pl.DataFrame) -> None: + res = pl.sql( + """ + SELECT + COUNT(json_msg.age) AS n, + ARRAY_AGG(json_msg.name) AS names + FROM df_struct + GROUP BY json_msg.age + ORDER BY 1 DESC + """ + ).collect() + expected = pl.DataFrame( - {"ID": [400, 100], "NAME": ["Zoe", "Alice"], "AGE": [45, 32]} + data={"n": [2, 1], "names": [["Bob", "Zoe"], ["David"]]}, + schema_overrides={"n": pl.UInt32}, ) assert_frame_equal(expected, res) +def test_struct_field_group_by_errors(df_struct: pl.DataFrame) -> None: + with pytest.raises( + SQLSyntaxError, + match="'name' should participate in the GROUP BY clause or an aggregate function", + ): + pl.sql( + """ + SELECT + json_msg.name, + SUM(json_msg.age) AS sum_age + FROM df_struct + GROUP BY json_msg.age + """ + ).collect() + + @pytest.mark.parametrize( ("fields", "excluding"), [ @@ -77,6 +110,8 @@ def test_struct_field_wildcard_selection( "self.json_msg.other.invalid_column", ], ) -def test_struct_indexing_errors(invalid_column: str, df_struct: pl.DataFrame) -> None: +def test_struct_field_selection_errors( + invalid_column: str, df_struct: pl.DataFrame +) -> None: with pytest.raises(StructFieldNotFoundError, match="invalid_column"): df_struct.sql(f"SELECT {invalid_column} FROM self") From dd6b2c4c782bd3f2066e42a79f4dc47a27a15066 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sat, 22 Jun 2024 23:13:39 +0400 Subject: [PATCH 5/7] avoid use of `rewrite_projections` --- .../src/plans/conversion/expr_expansion.rs | 2 +- .../polars-plan/src/plans/conversion/mod.rs | 2 +- crates/polars-plan/src/plans/mod.rs | 2 +- crates/polars-plan/src/prelude.rs | 1 - crates/polars-sql/src/context.rs | 178 +++++++----------- crates/polars-sql/src/sql_expr.rs | 128 ++++++++----- py-polars/tests/unit/sql/test_order_by.py | 28 +-- 7 files changed, 170 insertions(+), 171 deletions(-) diff --git a/crates/polars-plan/src/plans/conversion/expr_expansion.rs b/crates/polars-plan/src/plans/conversion/expr_expansion.rs index fee7913dede7..4f168d1080d1 100644 --- a/crates/polars-plan/src/plans/conversion/expr_expansion.rs +++ b/crates/polars-plan/src/plans/conversion/expr_expansion.rs @@ -634,7 +634,7 @@ fn find_flags(expr: &Expr) -> PolarsResult { /// In case of single col(*) -> do nothing, no selection is the same as select all /// In other cases replace the wildcard with an expression with all columns -pub fn rewrite_projections( +pub(crate) fn rewrite_projections( exprs: Vec, schema: &Schema, keys: &[Expr], diff --git a/crates/polars-plan/src/plans/conversion/mod.rs b/crates/polars-plan/src/plans/conversion/mod.rs index afdac2d300fc..d0d6a41e9fb9 100644 --- a/crates/polars-plan/src/plans/conversion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/mod.rs @@ -1,6 +1,6 @@ mod convert_utils; mod dsl_to_ir; -pub(crate) mod expr_expansion; +mod expr_expansion; mod expr_to_ir; mod ir_to_dsl; #[cfg(any(feature = "ipc", feature = "parquet", feature = "csv"))] diff --git a/crates/polars-plan/src/plans/mod.rs b/crates/polars-plan/src/plans/mod.rs index 9255c811e489..ca9acc44cf53 100644 --- a/crates/polars-plan/src/plans/mod.rs +++ b/crates/polars-plan/src/plans/mod.rs @@ -16,7 +16,7 @@ pub(crate) mod ir; mod apply; mod builder_dsl; mod builder_ir; -pub mod conversion; +pub(crate) mod conversion; #[cfg(feature = "debugging")] pub(crate) mod debug; pub mod expr_ir; diff --git a/crates/polars-plan/src/prelude.rs b/crates/polars-plan/src/prelude.rs index 34c38cefbdab..d90e032cc925 100644 --- a/crates/polars-plan/src/prelude.rs +++ b/crates/polars-plan/src/prelude.rs @@ -11,7 +11,6 @@ pub(crate) use polars_time::prelude::*; pub use polars_utils::arena::{Arena, Node}; pub use crate::dsl::*; -pub use crate::plans::conversion::expr_expansion::rewrite_projections; #[cfg(feature = "debugging")] pub use crate::plans::debug::*; pub use crate::plans::options::*; diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 638d317456aa..0fc38925c2bb 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -18,7 +18,8 @@ use sqlparser::parser::{Parser, ParserOptions}; use crate::function_registry::{DefaultFunctionRegistry, FunctionRegistry}; use crate::sql_expr::{ - parse_sql_array, parse_sql_expr, process_join_constraint, to_sql_interface_err, + parse_sql_array, parse_sql_expr, process_join_constraint, resolve_compound_identifier, + to_sql_interface_err, }; use crate::table_functions::PolarsTableFunctions; @@ -602,43 +603,43 @@ impl SQLContext { } self.execute_from_statement(from.first().unwrap())? }; - let mut contains_wildcard = false; - let mut contains_wildcard_exclude = false; // Filter expression. - let schema = Some(lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?); + let schema = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; lf = self.process_where(lf, &select_stmt.selection)?; // Column projections. + let mut excluded_cols = Vec::new(); let projections: Vec = select_stmt .projection .iter() .map(|select_item| { Ok(match select_item { SelectItem::UnnamedExpr(expr) => { - vec![parse_sql_expr(expr, self, schema.as_deref())?] + vec![parse_sql_expr(expr, self, Some(schema.deref()))?] }, SelectItem::ExprWithAlias { expr, alias } => { - let expr = parse_sql_expr(expr, self, schema.as_deref())?; + let expr = parse_sql_expr(expr, self, Some(schema.deref()))?; vec![expr.alias(&alias.value)] }, - SelectItem::QualifiedWildcard(obj_name, wildcard_options) => { - let expanded = self.process_qualified_wildcard( + SelectItem::QualifiedWildcard(obj_name, wildcard_options) => self + .process_qualified_wildcard( obj_name, wildcard_options, - &mut contains_wildcard_exclude, - schema.as_deref(), - )?; - rewrite_projections(vec![expanded], &(schema.clone().unwrap()), &[])? - }, + &mut excluded_cols, + Some(schema.deref()), + )?, SelectItem::Wildcard(wildcard_options) => { - contains_wildcard = true; - let e = col("*"); - vec![self.process_wildcard_additional_options( - e, + let cols = schema + .iter_names() + .map(|name| col(name)) + .collect::>(); + + self.process_wildcard_additional_options( + cols, wildcard_options, - &mut contains_wildcard_exclude, - )?] + &mut excluded_cols, + )? }, }) }) @@ -656,7 +657,13 @@ impl SQLContext { group_by_keys = group_by_exprs .iter() .map(|e| { - self.expr_or_ordinal(e, &projections, None, schema.as_deref(), "GROUP BY") + self.expr_or_ordinal( + e, + &projections, + None, + Some(schema.deref()), + "GROUP BY", + ) }) .collect::>()? }, @@ -689,73 +696,37 @@ impl SQLContext { }; lf = if group_by_keys.is_empty() { - if query.order_by.is_empty() { + lf = if query.order_by.is_empty() { + // No sort, select cols as given lf.select(projections) - } else if !contains_wildcard { - let mut retained_names = PlIndexSet::with_capacity(projections.len()); - let schema = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; + } else { + // Add all projections to the base frame as any of + // the original columns may be required for the sort + lf = lf.with_columns(projections.clone()); - projections.iter().for_each(|expr| match expr { - Expr::Alias(_, name) => { - retained_names.insert(name.clone()); - }, - Expr::Column(name) => { - retained_names.insert(name.clone()); - }, - Expr::Columns(names) => names.iter().for_each(|name| { - retained_names.insert(name.clone()); - }), - Expr::Exclude(inner_expr, excludes) => { - if let Expr::Columns(names) = (*inner_expr).as_ref() { - names.iter().for_each(|name| { - retained_names.insert(name.clone()); - }) - } - excludes.iter().for_each(|excluded| { - if let Excluded::Name(name) = excluded { - retained_names.shift_remove(name); - } - }) - }, - _ => { - let field = expr.to_field(&schema, Context::Default).unwrap(); - retained_names.insert(ColumnName::from(field.name.as_str())); - }, - }); - let retained_columns: Vec<_> = - retained_names.into_iter().map(|name| col(&name)).collect(); - - lf = lf.with_columns(projections); - lf = self.process_order_by(lf, &query.order_by, Some(retained_columns.as_ref()))?; - lf.select(&retained_columns) - } else if contains_wildcard_exclude { - let mut dropped_names = Vec::with_capacity(projections.len()); - let exclude_expr = projections.iter().find(|expr| { - if let Expr::Exclude(_, excludes) = expr { - for excluded in excludes.iter() { - if let Excluded::Name(name) = excluded { - dropped_names.push(name.to_string()); - } - } - true - } else { - false - } - }); - if exclude_expr.is_some() { - lf = lf.with_columns(projections); - lf = self.process_order_by(lf, &query.order_by, None)?; - lf.drop(dropped_names) - } else { - lf = lf.select(projections); - self.process_order_by(lf, &query.order_by, None)? - } + // Final/selected cols (also ensures accurate ordinal position refs) + let retained_cols = projections + .iter() + .map(|e| { + col(e + .to_field(schema.deref(), Context::Default) + .unwrap() + .name + .as_str()) + }) + .collect::>(); + + lf = self.process_order_by(lf, &query.order_by, Some(&retained_cols))?; + lf.select(retained_cols) + }; + // Discard any excluded cols + if !excluded_cols.is_empty() { + lf.drop(excluded_cols) } else { - lf = lf.select(projections); - self.process_order_by(lf, &query.order_by, None)? + lf } } else { - lf = self.process_group_by(lf, contains_wildcard, &group_by_keys, &projections)?; + lf = self.process_group_by(lf, &group_by_keys, &projections)?; lf = self.process_order_by(lf, &query.order_by, None)?; // Apply optional 'having' clause, post-aggregation. @@ -784,7 +755,7 @@ impl SQLContext { }) .collect::>>()?; - // DISTINCT ON applies the ORDER BY before the operation. + // DISTINCT ON has to apply the ORDER BY before the operation. if !query.order_by.is_empty() { lf = self.process_order_by(lf, &query.order_by, None)?; } @@ -1057,14 +1028,9 @@ impl SQLContext { fn process_group_by( &mut self, mut lf: LazyFrame, - contains_wildcard: bool, group_by_keys: &[Expr], projections: &[Expr], ) -> PolarsResult { - polars_ensure!( - !contains_wildcard, - SQLSyntax: "GROUP BY error (cannot process wildcard in group_by)" - ); let schema_before = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; let group_by_keys_schema = expressions_to_schema(group_by_keys, &schema_before, Context::Default)?; @@ -1193,22 +1159,22 @@ impl SQLContext { &mut self, ObjectName(idents): &ObjectName, options: &WildcardAdditionalOptions, - contains_wildcard_exclude: &mut bool, + excluded_cols: &mut Vec, schema: Option<&Schema>, - ) -> PolarsResult { + ) -> PolarsResult> { let mut new_idents = idents.clone(); new_idents.push(Ident::new("*")); - let identifier = SQLExpr::CompoundIdentifier(new_idents); - let expr = parse_sql_expr(&identifier, self, schema)?; - self.process_wildcard_additional_options(expr, options, contains_wildcard_exclude) + + let expr = resolve_compound_identifier(self, new_idents.deref(), schema); + self.process_wildcard_additional_options(expr?, options, excluded_cols) } fn process_wildcard_additional_options( &mut self, - expr: Expr, + exprs: Vec, options: &WildcardAdditionalOptions, - contains_wildcard_exclude: &mut bool, - ) -> PolarsResult { + excluded_cols: &mut Vec, + ) -> PolarsResult> { // bail on unsupported wildcard options if options.opt_ilike.is_some() { polars_bail!(SQLSyntax: "ILIKE wildcard option is unsupported") @@ -1220,17 +1186,15 @@ impl SQLContext { polars_bail!(SQLSyntax: "EXCEPT wildcard option is unsupported (use EXCLUDE instead)") } - Ok(match &options.opt_exclude { - Some(ExcludeSelectItem::Single(ident)) => { - *contains_wildcard_exclude = true; - expr.exclude(vec![&ident.value]) - }, - Some(ExcludeSelectItem::Multiple(idents)) => { - *contains_wildcard_exclude = true; - expr.exclude(idents.iter().map(|i| &i.value)) - }, - _ => expr, - }) + if let Some(exc_items) = &options.opt_exclude { + *excluded_cols = match exc_items { + ExcludeSelectItem::Single(ident) => vec![ident.value.clone()], + ExcludeSelectItem::Multiple(idents) => { + idents.iter().map(|i| i.value.clone()).collect() + }, + }; + } + Ok(exprs) } fn rename_columns_from_table_alias( diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index a4ae10997715..bcf95d75f8d4 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -377,53 +377,7 @@ impl SQLExprVisitor<'_> { /// /// e.g. tbl.column, struct.field, tbl.struct.field (inc. nested struct fields) fn visit_compound_identifier(&mut self, idents: &[Ident]) -> PolarsResult { - // inference priority: table > struct > column - let ident_root = &idents[0]; - let mut remaining_idents = idents.iter().skip(1); - let mut lf = self.ctx.get_table_from_current_scope(&ident_root.value); - - let schema = if let Some(ref mut lf) = lf { - lf.schema_with_arenas(&mut self.ctx.lp_arena, &mut self.ctx.expr_arena) - } else { - Ok(Arc::new(if let Some(active_schema) = self.active_schema { - active_schema.clone() - } else { - Schema::new() - })) - }?; - - let mut column: PolarsResult = if lf.is_none() && schema.is_empty() { - Ok(col(&ident_root.value)) - } else { - let name = &remaining_idents.next().unwrap().value; - if lf.is_some() && name == "*" { - Ok(cols(schema.iter_names())) - } else if let Some((_, name, _)) = schema.get_full(name) { - let resolved = &self.ctx.resolve_name(&ident_root.value, name); - Ok(if name != resolved { - col(resolved).alias(name) - } else { - col(name) - }) - } else if lf.is_none() { - remaining_idents = idents.iter().skip(1); - Ok(col(&ident_root.value)) - } else { - polars_bail!( - SQLInterface: "no column named '{}' found in table '{}'", - name, - ident_root - ) - } - }; - // additional ident levels index into struct fields - for ident in remaining_idents { - column = Ok(column - .unwrap() - .struct_() - .field_by_name(ident.value.as_str())); - } - column + Ok(resolve_compound_identifier(self.ctx, idents, self.active_schema)?[0].clone()) } fn visit_interval(&self, interval: &Interval) -> PolarsResult { @@ -1253,3 +1207,83 @@ fn bitstring_to_bytes_literal(b: &String) -> PolarsResult { _ => u64::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), })) } + +pub(crate) fn resolve_compound_identifier( + ctx: &mut SQLContext, + idents: &[Ident], + active_schema: Option<&Schema>, +) -> PolarsResult> { + // inference priority: table > struct > column + let ident_root = &idents[0]; + let mut remaining_idents = idents.iter().skip(1); + let mut lf = ctx.get_table_from_current_scope(&ident_root.value); + + let schema = if let Some(ref mut lf) = lf { + lf.schema_with_arenas(&mut ctx.lp_arena, &mut ctx.expr_arena) + } else { + Ok(Arc::new(if let Some(active_schema) = active_schema { + active_schema.clone() + } else { + Schema::new() + })) + }?; + + let col_dtype: PolarsResult<(Expr, Option<&DataType>)> = if lf.is_none() && schema.is_empty() { + Ok((col(&ident_root.value), None)) + } else { + let name = &remaining_idents.next().unwrap().value; + if lf.is_some() && name == "*" { + return Ok(schema + .iter_names() + .map(|name| col(name)) + .collect::>()); + } else if let Some((_, name, dtype)) = schema.get_full(name) { + let resolved = &ctx.resolve_name(&ident_root.value, name); + Ok(( + if name != resolved { + col(resolved).alias(name) + } else { + col(name) + }, + Some(dtype), + )) + } else if lf.is_none() { + remaining_idents = idents.iter().skip(1); + Ok((col(&ident_root.value), schema.get(&ident_root.value))) + } else { + polars_bail!( + SQLInterface: "no column named '{}' found in table '{}'", + name, + ident_root + ) + } + }; + + // additional ident levels index into struct fields + let (mut column, mut dtype) = col_dtype?; + for ident in remaining_idents { + let name = ident.value.as_str(); + match dtype { + Some(DataType::Struct(fields)) if name == "*" => { + return Ok(fields + .iter() + .map(|fld| column.clone().struct_().field_by_name(&fld.name)) + .collect()) + }, + Some(DataType::Struct(fields)) => { + dtype = fields + .iter() + .find(|fld| fld.name == name) + .map(|fld| &fld.dtype); + }, + Some(dtype) if name == "*" => { + polars_bail!(SQLSyntax: "cannot expand '*' on non-Struct dtype; found {:?}", dtype) + }, + _ => { + dtype = None; + }, + } + column = column.struct_().field_by_name(name); + } + Ok(vec![column]) +} diff --git a/py-polars/tests/unit/sql/test_order_by.py b/py-polars/tests/unit/sql/test_order_by.py index 691d6895be7b..8fb470508f31 100644 --- a/py-polars/tests/unit/sql/test_order_by.py +++ b/py-polars/tests/unit/sql/test_order_by.py @@ -27,17 +27,19 @@ def test_order_by_basic(foods_ipc_path: Path) -> None: "category": ["vegetables", "seafood", "meat", "fruit"] } - order_by_group_by_res = foods.sql( - """ - SELECT category - FROM self - GROUP BY category - ORDER BY category DESC - """ - ).collect() - assert order_by_group_by_res.to_dict(as_series=False) == { - "category": ["vegetables", "seafood", "meat", "fruit"] - } + for category in ("category", "category AS cat"): + category_col = category.split(" ")[-1] + order_by_group_by_res = foods.sql( + f""" + SELECT {category} + FROM self + GROUP BY category + ORDER BY {category_col} DESC + """ + ).collect() + assert order_by_group_by_res.to_dict(as_series=False) == { + category_col: ["vegetables", "seafood", "meat", "fruit"] + } order_by_constructed_group_by_res = foods.sql( """ @@ -108,8 +110,8 @@ def test_order_by_misc_selection() -> None: assert res.to_dict(as_series=False) == {"x": [1, None, 3, 2]} # order by expression - res = df.sql("SELECT (x % y) AS xmy FROM self ORDER BY x % y") - assert res.to_dict(as_series=False) == {"xmy": [1, 3, None, None]} + res = df.sql("SELECT (x % y) AS xmy FROM self ORDER BY -(x % y)") + assert res.to_dict(as_series=False) == {"xmy": [3, 1, None, None]} res = df.sql("SELECT (x % y) AS xmy FROM self ORDER BY x % y NULLS FIRST") assert res.to_dict(as_series=False) == {"xmy": [None, None, 1, 3]} From 5c12ba454f560b88f427f4c25775c21168857395 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sun, 23 Jun 2024 23:52:32 +0400 Subject: [PATCH 6/7] support `replace` & `rename` select wildcard options --- .../src/executors/projection_utils.rs | 2 +- .../src/plans/conversion/dsl_to_ir.rs | 2 +- crates/polars-sql/src/context.rs | 95 ++++++++++++++----- crates/polars-sql/tests/statements.rs | 26 +++-- .../tests/unit/sql/test_wildcard_opts.py | 79 +++++++++++++++ 5 files changed, 167 insertions(+), 37 deletions(-) create mode 100644 py-polars/tests/unit/sql/test_wildcard_opts.py diff --git a/crates/polars-mem-engine/src/executors/projection_utils.rs b/crates/polars-mem-engine/src/executors/projection_utils.rs index 125b774cf935..1ca7e085bfa3 100644 --- a/crates/polars-mem-engine/src/executors/projection_utils.rs +++ b/crates/polars-mem-engine/src/executors/projection_utils.rs @@ -263,7 +263,7 @@ pub(super) fn check_expand_literals( if duplicate_check && !names.insert(name) { let msg = format!( - "the name: '{}' is duplicate\n\n\ + "the name '{}' is duplicate\n\n\ It's possible that multiple expressions are returning the same default column \ name. If this is the case, try renaming the columns with \ `.alias(\"new_name\")` to avoid duplicate column names.", diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index cd273a6c6ddc..a05417dbfd24 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -707,7 +707,7 @@ fn resolve_with_columns( if !output_names.insert(field.name().clone()) { let msg = format!( - "the name: '{}' passed to `LazyFrame.with_columns` is duplicate\n\n\ + "the name '{}' passed to `LazyFrame.with_columns` is duplicate\n\n\ It's possible that multiple expressions are returning the same default column name. \ If this is the case, try renaming the columns with `.alias(\"new_name\")` to avoid \ duplicate column names.", diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 0fc38925c2bb..ea86de06f4e5 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -9,9 +9,9 @@ use polars_plan::dsl::function_expr::StructFunction; use polars_plan::prelude::*; use sqlparser::ast::{ Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, Ident, JoinConstraint, - JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr, - SetOperator, SetQuantifier, Statement, TableAlias, TableFactor, TableWithJoins, UnaryOperator, - Value as SQLValue, Values, WildcardAdditionalOptions, + JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, RenameSelectItem, Select, + SelectItem, SetExpr, SetOperator, SetQuantifier, Statement, TableAlias, TableFactor, + TableWithJoins, UnaryOperator, Value as SQLValue, Values, WildcardAdditionalOptions, }; use sqlparser::dialect::GenericDialect; use sqlparser::parser::{Parser, ParserOptions}; @@ -590,13 +590,11 @@ impl SQLContext { /// Execute the 'SELECT' part of the query. fn execute_select(&mut self, select_stmt: &Select, query: &Query) -> PolarsResult { - // Determine involved dataframes. - // Note: implicit joins require more work in query parsing, - // explicit joins are preferred for now (ref: #16662) - let mut lf = if select_stmt.from.is_empty() { DataFrame::empty().lazy() } else { + // Note: implicit joins need more work to support properly, + // explicit joins are preferred for now (ref: #16662) let from = select_stmt.clone().from; if from.len() > 1 { polars_bail!(SQLInterface: "multiple tables in FROM clause are not currently supported (found {}); use explicit JOIN syntax instead", from.len()) @@ -604,12 +602,16 @@ impl SQLContext { self.execute_from_statement(from.first().unwrap())? }; - // Filter expression. + // Filter expression (WHERE clause) let schema = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; lf = self.process_where(lf, &select_stmt.selection)?; - // Column projections. - let mut excluded_cols = Vec::new(); + // 'SELECT *' modifiers + let mut excluded_cols = vec![]; + let mut replace_exprs = vec![]; + let mut rename_cols = (&mut vec![], &mut vec![]); + + // Column projections (SELECT clause) let projections: Vec = select_stmt .projection .iter() @@ -627,6 +629,8 @@ impl SQLContext { obj_name, wildcard_options, &mut excluded_cols, + &mut rename_cols, + &mut replace_exprs, Some(schema.deref()), )?, SelectItem::Wildcard(wildcard_options) => { @@ -639,6 +643,9 @@ impl SQLContext { cols, wildcard_options, &mut excluded_cols, + &mut rename_cols, + &mut replace_exprs, + Some(schema.deref()), )? }, }) @@ -700,8 +707,8 @@ impl SQLContext { // No sort, select cols as given lf.select(projections) } else { - // Add all projections to the base frame as any of - // the original columns may be required for the sort + // Add projections to the base frame as any of the + // original columns may be required for the sort lf = lf.with_columns(projections.clone()); // Final/selected cols (also ensures accurate ordinal position refs) @@ -737,7 +744,7 @@ impl SQLContext { } }; - // Apply optional 'distinct' clause. + // Apply optional DISTINCT clause. lf = match &select_stmt.distinct { Some(Distinct::Distinct) => lf.unique_stable(None, UniqueKeepStrategy::Any), Some(Distinct::On(exprs)) => { @@ -764,6 +771,13 @@ impl SQLContext { None => lf, }; + // Apply final 'SELECT *' modifiers + if !replace_exprs.is_empty() { + lf = lf.with_columns(replace_exprs); + } + if !rename_cols.0.is_empty() { + lf = lf.rename(rename_cols.0, rename_cols.1); + } Ok(lf) } @@ -1160,13 +1174,22 @@ impl SQLContext { ObjectName(idents): &ObjectName, options: &WildcardAdditionalOptions, excluded_cols: &mut Vec, + rename_cols: &mut (&mut Vec, &mut Vec), + replace_exprs: &mut Vec, schema: Option<&Schema>, ) -> PolarsResult> { let mut new_idents = idents.clone(); new_idents.push(Ident::new("*")); let expr = resolve_compound_identifier(self, new_idents.deref(), schema); - self.process_wildcard_additional_options(expr?, options, excluded_cols) + self.process_wildcard_additional_options( + expr?, + options, + excluded_cols, + rename_cols, + replace_exprs, + schema, + ) } fn process_wildcard_additional_options( @@ -1174,26 +1197,48 @@ impl SQLContext { exprs: Vec, options: &WildcardAdditionalOptions, excluded_cols: &mut Vec, + rename_cols: &mut (&mut Vec, &mut Vec), + replace_exprs: &mut Vec, + schema: Option<&Schema>, ) -> PolarsResult> { - // bail on unsupported wildcard options - if options.opt_ilike.is_some() { - polars_bail!(SQLSyntax: "ILIKE wildcard option is unsupported") - } else if options.opt_rename.is_some() { - polars_bail!(SQLSyntax: "RENAME wildcard option is unsupported") - } else if options.opt_replace.is_some() { - polars_bail!(SQLSyntax: "REPLACE wildcard option is unsupported") - } else if options.opt_except.is_some() { - polars_bail!(SQLSyntax: "EXCEPT wildcard option is unsupported (use EXCLUDE instead)") + // bail on (currently) unsupported wildcard options + if options.opt_except.is_some() { + polars_bail!(SQLInterface: "EXCEPT wildcard option is unsupported (use EXCLUDE instead)") + } else if options.opt_ilike.is_some() { + polars_bail!(SQLInterface: "ILIKE wildcard option is currently unsupported") + } else if options.opt_rename.is_some() && options.opt_replace.is_some() { + // pending an upstream fix: https://github.com/sqlparser-rs/sqlparser-rs/pull/1321 + polars_bail!(SQLInterface: "RENAME and REPLACE wildcard options cannot (yet) be used simultaneously") } - if let Some(exc_items) = &options.opt_exclude { - *excluded_cols = match exc_items { + if let Some(items) = &options.opt_exclude { + *excluded_cols = match items { ExcludeSelectItem::Single(ident) => vec![ident.value.clone()], ExcludeSelectItem::Multiple(idents) => { idents.iter().map(|i| i.value.clone()).collect() }, }; } + if let Some(items) = &options.opt_rename { + match items { + RenameSelectItem::Single(rename) => { + rename_cols.0.push(rename.ident.value.clone()); + rename_cols.1.push(rename.alias.value.clone()); + }, + RenameSelectItem::Multiple(renames) => { + for rn in renames { + rename_cols.0.push(rn.ident.value.clone()); + rename_cols.1.push(rn.alias.value.clone()); + } + }, + } + } + if let Some(replacements) = &options.opt_replace { + for rp in &replacements.items { + let replacement_expr = parse_sql_expr(&rp.expr, self, schema); + replace_exprs.push(replacement_expr?.alias(rp.column_name.value.as_str())); + } + } Ok(exprs) } diff --git a/crates/polars-sql/tests/statements.rs b/crates/polars-sql/tests/statements.rs index 0af4ae64fa86..dd1f89027c46 100644 --- a/crates/polars-sql/tests/statements.rs +++ b/crates/polars-sql/tests/statements.rs @@ -419,7 +419,7 @@ fn test_resolve_join_column_select_13618() { } #[test] -fn test_compound_join_nested_and_with_brackets() { +fn test_compound_join_and_select_exclude_rename_replace() { let df1 = df! { "a" => [1, 2, 3, 4, 5], "b" => [1, 2, 3, 4, 5], @@ -442,10 +442,13 @@ fn test_compound_join_nested_and_with_brackets() { ctx.register("df2", df2.lazy()); let sql = r#" - SELECT df1.* EXCLUDE "e", df2.e - FROM df1 - INNER JOIN df2 ON df1.a = df2.a AND - ((df1.b = df2.b AND df1.c = df2.c) AND df1.d = df2.d) + SELECT * RENAME ("ee" AS "e") + FROM ( + SELECT df1.* EXCLUDE "e", df2.e AS "ee" + FROM df1 + INNER JOIN df2 ON df1.a = df2.a AND + ((df1.b = df2.b AND df1.c = df2.c) AND df1.d = df2.d) + ) tbl "#; let actual = ctx.execute(sql).unwrap().collect().unwrap(); let expected = df! { @@ -465,10 +468,13 @@ fn test_compound_join_nested_and_with_brackets() { ); let sql = r#" - SELECT * EXCLUDE ("e", "e:df2"), df1.e - FROM df1 - INNER JOIN df2 ON df1.a = df2.a AND - ((df1.b = df2.b AND df1.c = df2.c) AND df1.d = df2.d) + SELECT * REPLACE ("ee" || "ee" AS "ee") + FROM ( + SELECT * EXCLUDE ("e", "e:df2"), df1.e AS "ee" + FROM df1 + INNER JOIN df2 ON df1.a = df2.a AND + ((df1.b = df2.b AND df1.c = df2.c) AND df1.d = df2.d) + ) tbl "#; let actual = ctx.execute(sql).unwrap().collect().unwrap(); @@ -481,7 +487,7 @@ fn test_compound_join_nested_and_with_brackets() { "b:df2" => [1, 3], "c:df2" => [0, 4], "d:df2" => [0, 4], - "e" => ["a", "c"], + "ee" => ["aa", "cc"], } .unwrap(); diff --git a/py-polars/tests/unit/sql/test_wildcard_opts.py b/py-polars/tests/unit/sql/test_wildcard_opts.py new file mode 100644 index 000000000000..ad17a215f7da --- /dev/null +++ b/py-polars/tests/unit/sql/test_wildcard_opts.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +from polars.exceptions import DuplicateError + + +@pytest.fixture() +def df() -> pl.DataFrame: + return pl.DataFrame({"num": [999, 666], "str": ["b", "a"], "val": [2.0, 0.5]}) + + +@pytest.mark.parametrize( + ("excluded", "expected"), + [ + ("num", ["str", "val"]), + ("(val, num)", ["str"]), + ("(str, num)", ["val"]), + ("(str, val, num)", []), + ], +) +def test_select_exclude( + excluded: str, + expected: list[str], + df: pl.DataFrame, +) -> None: + assert df.sql(f"SELECT * EXCLUDE {excluded} FROM self").columns == expected + + +def test_select_exclude_error(df: pl.DataFrame) -> None: + with pytest.raises(DuplicateError, match="the name 'num' is duplicate"): + # note: missing "()" around the exclude option results in dupe col + assert df.sql("SELECT * EXCLUDE val, num FROM self") + + +@pytest.mark.parametrize( + ("renames", "expected"), + [ + ("val AS value", ["num", "str", "value"]), + ("(num AS flt)", ["flt", "str", "val"]), + ("(val AS value, num AS flt)", ["flt", "str", "value"]), + ], +) +def test_select_rename( + renames: str, + expected: list[str], + df: pl.DataFrame, +) -> None: + assert df.sql(f"SELECT * RENAME {renames} FROM self").columns == expected + + +@pytest.mark.parametrize( + ("replacements", "check_cols", "expected"), + [ + ( + "(num // 3 AS num)", + ["num"], + [(333,), (222,)], + ), + ( + "((str || str) AS str, num / 3 AS num)", + ["num", "str"], + [(333, "bb"), (222, "aa")], + ), + ], +) +def test_select_replace( + replacements: str, + check_cols: list[str], + expected: list[tuple[Any]], + df: pl.DataFrame, +) -> None: + res = df.sql(f"SELECT * REPLACE {replacements} FROM self") + + assert res.select(check_cols).rows() == expected + assert res.columns == df.columns From 64982d7b1c56b9f22b6556edb0939a8f8e44f081 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Mon, 24 Jun 2024 00:37:02 +0400 Subject: [PATCH 7/7] fix test --- py-polars/tests/unit/test_errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index ebd3ec4e73ff..dc89f0e43f2e 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -474,7 +474,7 @@ def test_with_column_duplicates() -> None: df = pl.DataFrame({"a": [0, None, 2, 3, None], "b": [None, 1, 2, 3, None]}) with pytest.raises( ComputeError, - match=r"the name: 'same' passed to `LazyFrame.with_columns` is duplicate.*", + match=r"the name 'same' passed to `LazyFrame.with_columns` is duplicate.*", ): assert df.with_columns([pl.all().alias("same")]).columns == ["a", "b", "same"]