From 024a878ee7f027ca2f9c635c9398ba59653f1a4e Mon Sep 17 00:00:00 2001 From: Yuval Shkolar <85674443+yuval-illumex@users.noreply.github.com> Date: Tue, 24 Dec 2024 17:00:59 +0200 Subject: [PATCH] Support Snowflake Update-From-Select (#1604) Co-authored-by: Ifeanyi Ubah --- src/ast/mod.rs | 11 +++++++---- src/ast/query.rs | 13 +++++++++++++ src/ast/spans.rs | 13 +++++++++++-- src/keywords.rs | 1 + src/parser/mod.rs | 15 ++++++++++----- tests/sqlparser_common.rs | 20 ++++++++++++++++---- 6 files changed, 58 insertions(+), 15 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 9fb2bb9c9..5bdce21ef 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -72,8 +72,8 @@ pub use self::query::{ TableAlias, TableAliasColumnDef, TableFactor, TableFunctionArgs, TableSample, TableSampleBucket, TableSampleKind, TableSampleMethod, TableSampleModifier, TableSampleQuantity, TableSampleSeed, TableSampleSeedModifier, TableSampleUnit, TableVersion, - TableWithJoins, Top, TopQuantity, ValueTableMode, Values, WildcardAdditionalOptions, With, - WithFill, + TableWithJoins, Top, TopQuantity, UpdateTableFromKind, ValueTableMode, Values, + WildcardAdditionalOptions, With, WithFill, }; pub use self::trigger::{ @@ -2473,7 +2473,7 @@ pub enum Statement { /// Column assignments assignments: Vec, /// Table which provide value to be set - from: Option, + from: Option, /// WHERE selection: Option, /// RETURNING @@ -3745,10 +3745,13 @@ impl fmt::Display for Statement { write!(f, "{or} ")?; } write!(f, "{table}")?; + if let Some(UpdateTableFromKind::BeforeSet(from)) = from { + write!(f, " FROM {from}")?; + } if !assignments.is_empty() { write!(f, " SET {}", display_comma_separated(assignments))?; } - if let Some(from) = from { + if let Some(UpdateTableFromKind::AfterSet(from)) = from { write!(f, " FROM {from}")?; } if let Some(selection) = selection { diff --git a/src/ast/query.rs b/src/ast/query.rs index 69b7ea1c1..9e4e9e2ef 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -2790,3 +2790,16 @@ impl fmt::Display for ValueTableMode { } } } + +/// The `FROM` clause of an `UPDATE TABLE` statement +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum UpdateTableFromKind { + /// Update Statment where the 'FROM' clause is before the 'SET' keyword (Supported by Snowflake) + /// For Example: `UPDATE FROM t1 SET t1.name='aaa'` + BeforeSet(TableWithJoins), + /// Update Statment where the 'FROM' clause is after the 'SET' keyword (Which is the standard way) + /// For Example: `UPDATE SET t1.name='aaa' FROM t1` + AfterSet(TableWithJoins), +} diff --git a/src/ast/spans.rs b/src/ast/spans.rs index 9ba3bdd9b..521b5399a 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -32,8 +32,8 @@ use super::{ OrderBy, OrderByExpr, Partition, PivotValueSource, ProjectionSelect, Query, ReferentialAction, RenameSelectItem, ReplaceSelectElement, ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript, SymbolDefinition, TableAlias, TableAliasColumnDef, - TableConstraint, TableFactor, TableOptionsClustered, TableWithJoins, Use, Value, Values, - ViewColumnDef, WildcardAdditionalOptions, With, WithFill, + TableConstraint, TableFactor, TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use, + Value, Values, ViewColumnDef, WildcardAdditionalOptions, With, WithFill, }; /// Given an iterator of spans, return the [Span::union] of all spans. @@ -2106,6 +2106,15 @@ impl Spanned for SelectInto { } } +impl Spanned for UpdateTableFromKind { + fn span(&self) -> Span { + match self { + UpdateTableFromKind::BeforeSet(from) => from.span(), + UpdateTableFromKind::AfterSet(from) => from.span(), + } + } +} + #[cfg(test)] pub mod tests { use crate::dialect::{Dialect, GenericDialect, SnowflakeDialect}; diff --git a/src/keywords.rs b/src/keywords.rs index bbfd00ca0..43abc2b03 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -941,6 +941,7 @@ pub const RESERVED_FOR_TABLE_ALIAS: &[Keyword] = &[ // Reserved for Snowflake table sample Keyword::SAMPLE, Keyword::TABLESAMPLE, + Keyword::FROM, ]; /// Can't be used as a column alias, so that `SELECT alias` diff --git a/src/parser/mod.rs b/src/parser/mod.rs index cc0a57e4d..57c4dc6e7 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -11791,14 +11791,19 @@ impl<'a> Parser<'a> { pub fn parse_update(&mut self) -> Result { let or = self.parse_conflict_clause(); let table = self.parse_table_and_joins()?; + let from_before_set = if self.parse_keyword(Keyword::FROM) { + Some(UpdateTableFromKind::BeforeSet( + self.parse_table_and_joins()?, + )) + } else { + None + }; self.expect_keyword(Keyword::SET)?; let assignments = self.parse_comma_separated(Parser::parse_assignment)?; - let from = if self.parse_keyword(Keyword::FROM) - && dialect_of!(self is GenericDialect | PostgreSqlDialect | DuckDbDialect | BigQueryDialect | SnowflakeDialect | RedshiftSqlDialect | MsSqlDialect | SQLiteDialect ) - { - Some(self.parse_table_and_joins()?) + let from = if from_before_set.is_none() && self.parse_keyword(Keyword::FROM) { + Some(UpdateTableFromKind::AfterSet(self.parse_table_and_joins()?)) } else { - None + from_before_set }; let selection = if self.parse_keyword(Keyword::WHERE) { Some(self.parse_expr()?) diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 79f5c8d32..cbbbb45f9 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -366,7 +366,7 @@ fn parse_update_set_from() { target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new("name")])), value: Expr::CompoundIdentifier(vec![Ident::new("t2"), Ident::new("name")]) }], - from: Some(TableWithJoins { + from: Some(UpdateTableFromKind::AfterSet(TableWithJoins { relation: TableFactor::Derived { lateral: false, subquery: Box::new(Query { @@ -417,8 +417,8 @@ fn parse_update_set_from() { columns: vec![], }) }, - joins: vec![], - }), + joins: vec![] + })), selection: Some(Expr::BinaryOp { left: Box::new(Expr::CompoundIdentifier(vec![ Ident::new("t1"), @@ -12577,9 +12577,21 @@ fn overflow() { let statement = statements.pop().unwrap(); assert_eq!(statement.to_string(), sql); } - #[test] fn parse_select_without_projection() { let dialects = all_dialects_where(|d| d.supports_empty_projections()); dialects.verified_stmt("SELECT FROM users"); } + +#[test] +fn parse_update_from_before_select() { + all_dialects() + .verified_stmt("UPDATE t1 FROM (SELECT name, id FROM t1 GROUP BY id) AS t2 SET name = t2.name WHERE t1.id = t2.id"); + + let query = + "UPDATE t1 FROM (SELECT name, id FROM t1 GROUP BY id) AS t2 SET name = t2.name FROM (SELECT name from t2) AS t2"; + assert_eq!( + ParserError::ParserError("Expected: end of statement, found: FROM".to_string()), + parse_sql_statements(query).unwrap_err() + ); +}