Skip to content

Commit

Permalink
Fix the parsing error in MSSQL for multiple statements that include `…
Browse files Browse the repository at this point in the history
…DECLARE` statements (#1497)
  • Loading branch information
wugeer authored Nov 13, 2024
1 parent 3a8369a commit 632ba4c
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 46 deletions.
94 changes: 50 additions & 44 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5321,55 +5321,61 @@ impl<'a> Parser<'a> {
/// ```
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/language-elements/declare-local-variable-transact-sql?view=sql-server-ver16
pub fn parse_mssql_declare(&mut self) -> Result<Statement, ParserError> {
let mut stmts = vec![];

loop {
let name = {
let ident = self.parse_identifier(false)?;
if !ident.value.starts_with('@') {
Err(ParserError::TokenizerError(
"Invalid MsSql variable declaration.".to_string(),
))
} else {
Ok(ident)
}
}?;
let stmts = self.parse_comma_separated(Parser::parse_mssql_declare_stmt)?;

let (declare_type, data_type) = match self.peek_token().token {
Token::Word(w) => match w.keyword {
Keyword::CURSOR => {
self.next_token();
(Some(DeclareType::Cursor), None)
}
Keyword::AS => {
self.next_token();
(None, Some(self.parse_data_type()?))
}
_ => (None, Some(self.parse_data_type()?)),
},
_ => (None, Some(self.parse_data_type()?)),
};
Ok(Statement::Declare { stmts })
}

let assignment = self.parse_mssql_variable_declaration_expression()?;
/// Parse the body of a [MsSql] `DECLARE`statement.
///
/// Syntax:
/// ```text
// {
// { @local_variable [AS] data_type [ = value ] }
// | { @cursor_variable_name CURSOR }
// } [ ,...n ]
/// ```
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/language-elements/declare-local-variable-transact-sql?view=sql-server-ver16
pub fn parse_mssql_declare_stmt(&mut self) -> Result<Declare, ParserError> {
let name = {
let ident = self.parse_identifier(false)?;
if !ident.value.starts_with('@') {
Err(ParserError::TokenizerError(
"Invalid MsSql variable declaration.".to_string(),
))
} else {
Ok(ident)
}
}?;

stmts.push(Declare {
names: vec![name],
data_type,
assignment,
declare_type,
binary: None,
sensitive: None,
scroll: None,
hold: None,
for_query: None,
});
let (declare_type, data_type) = match self.peek_token().token {
Token::Word(w) => match w.keyword {
Keyword::CURSOR => {
self.next_token();
(Some(DeclareType::Cursor), None)
}
Keyword::AS => {
self.next_token();
(None, Some(self.parse_data_type()?))
}
_ => (None, Some(self.parse_data_type()?)),
},
_ => (None, Some(self.parse_data_type()?)),
};

if self.next_token() != Token::Comma {
break;
}
}
let assignment = self.parse_mssql_variable_declaration_expression()?;

Ok(Statement::Declare { stmts })
Ok(Declare {
names: vec![name],
data_type,
assignment,
declare_type,
binary: None,
sensitive: None,
scroll: None,
hold: None,
for_query: None,
})
}

/// Parses the assigned expression in a variable declaration.
Expand Down
71 changes: 69 additions & 2 deletions tests/sqlparser_mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use sqlparser::ast::DeclareAssignment::MsSqlAssignment;
use sqlparser::ast::Value::SingleQuotedString;
use sqlparser::ast::*;
use sqlparser::dialect::{GenericDialect, MsSqlDialect};
use sqlparser::parser::{Parser, ParserError};
use sqlparser::parser::ParserError;

#[test]
fn parse_mssql_identifiers() {
Expand Down Expand Up @@ -910,7 +910,7 @@ fn parse_substring_in_select() {
#[test]
fn parse_mssql_declare() {
let sql = "DECLARE @foo CURSOR, @bar INT, @baz AS TEXT = 'foobar';";
let ast = Parser::parse_sql(&MsSqlDialect {}, sql).unwrap();
let ast = ms().parse_sql_statements(sql).unwrap();

assert_eq!(
vec![Statement::Declare {
Expand Down Expand Up @@ -963,6 +963,73 @@ fn parse_mssql_declare() {
}],
ast
);

let sql = "DECLARE @bar INT;SET @bar = 2;SELECT @bar * 4";
let ast = ms().parse_sql_statements(sql).unwrap();
assert_eq!(
vec![
Statement::Declare {
stmts: vec![Declare {
names: vec![Ident {
value: "@bar".to_string(),
quote_style: None
}],
data_type: Some(Int(None)),
assignment: None,
declare_type: None,
binary: None,
sensitive: None,
scroll: None,
hold: None,
for_query: None
}]
},
Statement::SetVariable {
local: false,
hivevar: false,
variables: OneOrManyWithParens::One(ObjectName(vec![Ident::new("@bar")])),
value: vec![Expr::Value(Value::Number("2".parse().unwrap(), false))],
},
Statement::Query(Box::new(Query {
with: None,
limit: None,
limit_by: vec![],
offset: None,
fetch: None,
locks: vec![],
for_clause: None,
order_by: None,
settings: None,
format_clause: None,
body: Box::new(SetExpr::Select(Box::new(Select {
distinct: None,
top: None,
top_before_distinct: false,
projection: vec![SelectItem::UnnamedExpr(Expr::BinaryOp {
left: Box::new(Expr::Identifier(Ident::new("@bar"))),
op: BinaryOperator::Multiply,
right: Box::new(Expr::Value(Value::Number("4".parse().unwrap(), false))),
})],
into: None,
from: vec![],
lateral_views: vec![],
prewhere: None,
selection: None,
group_by: GroupByExpr::Expressions(vec![], vec![]),
cluster_by: vec![],
distribute_by: vec![],
sort_by: vec![],
having: None,
named_window: vec![],
window_before_qualify: false,
qualify: None,
value_table_mode: None,
connect_by: None,
})))
}))
],
ast
);
}

#[test]
Expand Down

0 comments on commit 632ba4c

Please sign in to comment.