Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(mysql): Increased compatibility for MySQL #1059

Merged
merged 4 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 72 additions & 1 deletion src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2047,7 +2047,8 @@ pub enum Statement {
table_name: ObjectName,
if_exists: bool,
},
///CreateSequence -- define a new sequence
/// Define a new sequence:
///
/// CREATE [ { TEMPORARY | TEMP } ] SEQUENCE [ IF NOT EXISTS ] <sequence_name>
CreateSequence {
temporary: bool,
Expand All @@ -2068,6 +2069,15 @@ pub enum Statement {
value: Option<Value>,
is_eq: bool,
},
/// `LOCK TABLES <table_name> [READ [LOCAL] | [LOW_PRIORITY] WRITE]`
///
/// Note: this is a MySQL-specific statement. See <https://dev.mysql.com/doc/refman/8.0/en/lock-tables.html>
LockTables {
tables: Vec<LockTable>,
},
/// `UNLOCK TABLES`
/// Note: this is a MySQL-specific statement. See <https://dev.mysql.com/doc/refman/8.0/en/lock-tables.html>
UnlockTables,
}

impl fmt::Display for Statement {
Expand Down Expand Up @@ -3477,6 +3487,12 @@ impl fmt::Display for Statement {
}
Ok(())
}
Statement::LockTables { tables } => {
write!(f, "LOCK TABLES {}", display_comma_separated(tables))
}
Statement::UnlockTables => {
write!(f, "UNLOCK TABLES")
}
}
}
}
Expand Down Expand Up @@ -4979,6 +4995,61 @@ impl fmt::Display for SearchModifier {
}
}

#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct LockTable {
pub table: Ident,
pub alias: Option<Ident>,
pub lock_type: LockTableType,
}

impl fmt::Display for LockTable {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let Self {
table: tbl_name,
alias,
lock_type,
} = self;

write!(f, "{tbl_name} ")?;
if let Some(alias) = alias {
write!(f, "AS {alias} ")?;
}
write!(f, "{lock_type}")?;
Ok(())
}
}

#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum LockTableType {
Read { local: bool },
Write { low_priority: bool },
}

impl fmt::Display for LockTableType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Read { local } => {
write!(f, "READ")?;
if *local {
write!(f, " LOCAL")?;
}
}
Self::Write { low_priority } => {
if *low_priority {
write!(f, "LOW_PRIORITY ")?;
}
write!(f, "WRITE")?;
}
}

Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
61 changes: 59 additions & 2 deletions src/dialect/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
use alloc::boxed::Box;

use crate::{
ast::{BinaryOperator, Expr},
ast::{BinaryOperator, Expr, LockTable, LockTableType, Statement},
dialect::Dialect,
keywords::Keyword,
parser::{Parser, ParserError},
};

/// A [`Dialect`] for [MySQL](https://www.mysql.com/)
Expand Down Expand Up @@ -48,7 +49,7 @@ impl Dialect for MySqlDialect {
parser: &mut crate::parser::Parser,
expr: &crate::ast::Expr,
_precedence: u8,
) -> Option<Result<crate::ast::Expr, crate::parser::ParserError>> {
) -> Option<Result<crate::ast::Expr, ParserError>> {
// Parse DIV as an operator
if parser.parse_keyword(Keyword::DIV) {
Some(Ok(Expr::BinaryOp {
Expand All @@ -60,4 +61,60 @@ impl Dialect for MySqlDialect {
None
}
}

fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
if parser.parse_keywords(&[Keyword::LOCK, Keyword::TABLES]) {
Some(parse_lock_tables(parser))
} else if parser.parse_keywords(&[Keyword::UNLOCK, Keyword::TABLES]) {
Some(parse_unlock_tables(parser))
} else {
None
}
}
}

/// `LOCK TABLES`
/// <https://dev.mysql.com/doc/refman/8.0/en/lock-tables.html>
fn parse_lock_tables(parser: &mut Parser) -> Result<Statement, ParserError> {
let tables = parser.parse_comma_separated(parse_lock_table)?;
Ok(Statement::LockTables { tables })
}

// tbl_name [[AS] alias] lock_type
fn parse_lock_table(parser: &mut Parser) -> Result<LockTable, ParserError> {
let table = parser.parse_identifier()?;
let alias =
parser.parse_optional_alias(&[Keyword::READ, Keyword::WRITE, Keyword::LOW_PRIORITY])?;
let lock_type = parse_lock_tables_type(parser)?;

Ok(LockTable {
table,
alias,
lock_type,
})
}

// READ [LOCAL] | [LOW_PRIORITY] WRITE
fn parse_lock_tables_type(parser: &mut Parser) -> Result<LockTableType, ParserError> {
if parser.parse_keyword(Keyword::READ) {
if parser.parse_keyword(Keyword::LOCAL) {
Ok(LockTableType::Read { local: true })
} else {
Ok(LockTableType::Read { local: false })
}
} else if parser.parse_keyword(Keyword::WRITE) {
Ok(LockTableType::Write {
low_priority: false,
})
} else if parser.parse_keywords(&[Keyword::LOW_PRIORITY, Keyword::WRITE]) {
Ok(LockTableType::Write { low_priority: true })
} else {
parser.expected("an lock type in LOCK TABLES", parser.peek_token())
}
}

/// UNLOCK TABLES
/// <https://dev.mysql.com/doc/refman/8.0/en/lock-tables.html>
fn parse_unlock_tables(_parser: &mut Parser) -> Result<Statement, ParserError> {
Ok(Statement::UnlockTables)
}
3 changes: 3 additions & 0 deletions src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,9 +377,11 @@ define_keywords!(
LOCALTIME,
LOCALTIMESTAMP,
LOCATION,
LOCK,
LOCKED,
LOGIN,
LOWER,
LOW_PRIORITY,
MACRO,
MANAGEDLOCATION,
MATCH,
Expand Down Expand Up @@ -654,6 +656,7 @@ define_keywords!(
UNION,
UNIQUE,
UNKNOWN,
UNLOCK,
UNLOGGED,
UNNEST,
UNPIVOT,
Expand Down
29 changes: 17 additions & 12 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4012,17 +4012,6 @@ impl<'a> Parser<'a> {
None
};

let comment = if self.parse_keyword(Keyword::COMMENT) {
let _ = self.consume_token(&Token::Eq);
let next_token = self.next_token();
match next_token.token {
Token::SingleQuotedString(str) => Some(str),
_ => self.expected("comment", next_token)?,
}
} else {
None
};

let auto_increment_offset = if self.parse_keyword(Keyword::AUTO_INCREMENT) {
let _ = self.consume_token(&Token::Eq);
let next_token = self.next_token();
Expand Down Expand Up @@ -4097,6 +4086,18 @@ impl<'a> Parser<'a> {
};

let strict = self.parse_keyword(Keyword::STRICT);

let comment = if self.parse_keyword(Keyword::COMMENT) {
let _ = self.consume_token(&Token::Eq);
let next_token = self.next_token();
match next_token.token {
Token::SingleQuotedString(str) => Some(str),
_ => self.expected("comment", next_token)?,
}
} else {
None
};

Ok(CreateTableBuilder::new(table_name)
.temporary(temporary)
.columns(columns)
Expand Down Expand Up @@ -4183,7 +4184,7 @@ impl<'a> Parser<'a> {
pub fn parse_column_def(&mut self) -> Result<ColumnDef, ParserError> {
let name = self.parse_identifier()?;
let data_type = self.parse_data_type()?;
let collation = if self.parse_keyword(Keyword::COLLATE) {
let mut collation = if self.parse_keyword(Keyword::COLLATE) {
Some(self.parse_object_name()?)
} else {
None
Expand All @@ -4202,6 +4203,10 @@ impl<'a> Parser<'a> {
}
} else if let Some(option) = self.parse_optional_column_option()? {
options.push(ColumnOptionDef { name: None, option });
} else if dialect_of!(self is MySqlDialect | GenericDialect)
&& self.parse_keyword(Keyword::COLLATE)
{
collation = Some(self.parse_object_name()?);
} else {
break;
};
Expand Down
36 changes: 36 additions & 0 deletions tests/sqlparser_mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1871,6 +1871,42 @@ fn parse_convert_using() {
mysql().verified_only_select("SELECT CONVERT('test', CHAR CHARACTER SET utf8mb4)");
}

#[test]
fn parse_create_table_with_column_collate() {
let sql = "CREATE TABLE tb (id TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci)";
let canonical = "CREATE TABLE tb (id TEXT COLLATE utf8mb4_0900_ai_ci CHARACTER SET utf8mb4)";
match mysql().one_statement_parses_to(sql, canonical) {
Statement::CreateTable { name, columns, .. } => {
assert_eq!(name.to_string(), "tb");
assert_eq!(
vec![ColumnDef {
name: Ident::new("id"),
data_type: DataType::Text,
collation: Some(ObjectName(vec![Ident::new("utf8mb4_0900_ai_ci")])),
options: vec![ColumnOptionDef {
name: None,
option: ColumnOption::CharacterSet(ObjectName(vec![Ident::new("utf8mb4")]))
}],
},],
columns
);
}
_ => unreachable!(),
}
}

#[test]
fn parse_lock_tables() {
mysql().one_statement_parses_to(
"LOCK TABLES trans t READ, customer WRITE",
"LOCK TABLES trans AS t READ, customer WRITE",
);
mysql().verified_stmt("LOCK TABLES trans AS t READ, customer WRITE");
mysql().verified_stmt("LOCK TABLES trans AS t READ LOCAL, customer WRITE");
mysql().verified_stmt("LOCK TABLES trans AS t READ, customer LOW_PRIORITY WRITE");
mysql().verified_stmt("UNLOCK TABLES");
}

#[test]
fn parse_json_table() {
mysql().verified_only_select("SELECT * FROM JSON_TABLE('[[1, 2], [3, 4]]', '$[*]' COLUMNS(a INT PATH '$[0]', b INT PATH '$[1]')) AS t");
Expand Down
Loading