Skip to content

Commit

Permalink
feat(mysql): Increased compatibility for MySQL (#1059)
Browse files Browse the repository at this point in the history
Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
zzzdong and alamb authored Dec 19, 2023
1 parent f46f147 commit d0fce12
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 15 deletions.
73 changes: 72 additions & 1 deletion src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2047,7 +2047,8 @@ pub enum Statement {
table_name: ObjectName,
if_exists: bool,
},
///CreateSequence -- define a new sequence
/// Define a new sequence:
///
/// CREATE [ { TEMPORARY | TEMP } ] SEQUENCE [ IF NOT EXISTS ] <sequence_name>
CreateSequence {
temporary: bool,
Expand All @@ -2068,6 +2069,15 @@ pub enum Statement {
value: Option<Value>,
is_eq: bool,
},
/// `LOCK TABLES <table_name> [READ [LOCAL] | [LOW_PRIORITY] WRITE]`
///
/// Note: this is a MySQL-specific statement. See <https://dev.mysql.com/doc/refman/8.0/en/lock-tables.html>
LockTables {
tables: Vec<LockTable>,
},
/// `UNLOCK TABLES`
/// Note: this is a MySQL-specific statement. See <https://dev.mysql.com/doc/refman/8.0/en/lock-tables.html>
UnlockTables,
}

impl fmt::Display for Statement {
Expand Down Expand Up @@ -3477,6 +3487,12 @@ impl fmt::Display for Statement {
}
Ok(())
}
Statement::LockTables { tables } => {
write!(f, "LOCK TABLES {}", display_comma_separated(tables))
}
Statement::UnlockTables => {
write!(f, "UNLOCK TABLES")
}
}
}
}
Expand Down Expand Up @@ -4979,6 +4995,61 @@ impl fmt::Display for SearchModifier {
}
}

#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct LockTable {
pub table: Ident,
pub alias: Option<Ident>,
pub lock_type: LockTableType,
}

impl fmt::Display for LockTable {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let Self {
table: tbl_name,
alias,
lock_type,
} = self;

write!(f, "{tbl_name} ")?;
if let Some(alias) = alias {
write!(f, "AS {alias} ")?;
}
write!(f, "{lock_type}")?;
Ok(())
}
}

#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum LockTableType {
Read { local: bool },
Write { low_priority: bool },
}

impl fmt::Display for LockTableType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Read { local } => {
write!(f, "READ")?;
if *local {
write!(f, " LOCAL")?;
}
}
Self::Write { low_priority } => {
if *low_priority {
write!(f, "LOW_PRIORITY ")?;
}
write!(f, "WRITE")?;
}
}

Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
61 changes: 59 additions & 2 deletions src/dialect/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
use alloc::boxed::Box;

use crate::{
ast::{BinaryOperator, Expr},
ast::{BinaryOperator, Expr, LockTable, LockTableType, Statement},
dialect::Dialect,
keywords::Keyword,
parser::{Parser, ParserError},
};

/// A [`Dialect`] for [MySQL](https://www.mysql.com/)
Expand Down Expand Up @@ -48,7 +49,7 @@ impl Dialect for MySqlDialect {
parser: &mut crate::parser::Parser,
expr: &crate::ast::Expr,
_precedence: u8,
) -> Option<Result<crate::ast::Expr, crate::parser::ParserError>> {
) -> Option<Result<crate::ast::Expr, ParserError>> {
// Parse DIV as an operator
if parser.parse_keyword(Keyword::DIV) {
Some(Ok(Expr::BinaryOp {
Expand All @@ -60,4 +61,60 @@ impl Dialect for MySqlDialect {
None
}
}

fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
if parser.parse_keywords(&[Keyword::LOCK, Keyword::TABLES]) {
Some(parse_lock_tables(parser))
} else if parser.parse_keywords(&[Keyword::UNLOCK, Keyword::TABLES]) {
Some(parse_unlock_tables(parser))
} else {
None
}
}
}

/// `LOCK TABLES`
/// <https://dev.mysql.com/doc/refman/8.0/en/lock-tables.html>
fn parse_lock_tables(parser: &mut Parser) -> Result<Statement, ParserError> {
let tables = parser.parse_comma_separated(parse_lock_table)?;
Ok(Statement::LockTables { tables })
}

// tbl_name [[AS] alias] lock_type
fn parse_lock_table(parser: &mut Parser) -> Result<LockTable, ParserError> {
let table = parser.parse_identifier()?;
let alias =
parser.parse_optional_alias(&[Keyword::READ, Keyword::WRITE, Keyword::LOW_PRIORITY])?;
let lock_type = parse_lock_tables_type(parser)?;

Ok(LockTable {
table,
alias,
lock_type,
})
}

// READ [LOCAL] | [LOW_PRIORITY] WRITE
fn parse_lock_tables_type(parser: &mut Parser) -> Result<LockTableType, ParserError> {
if parser.parse_keyword(Keyword::READ) {
if parser.parse_keyword(Keyword::LOCAL) {
Ok(LockTableType::Read { local: true })
} else {
Ok(LockTableType::Read { local: false })
}
} else if parser.parse_keyword(Keyword::WRITE) {
Ok(LockTableType::Write {
low_priority: false,
})
} else if parser.parse_keywords(&[Keyword::LOW_PRIORITY, Keyword::WRITE]) {
Ok(LockTableType::Write { low_priority: true })
} else {
parser.expected("an lock type in LOCK TABLES", parser.peek_token())
}
}

/// UNLOCK TABLES
/// <https://dev.mysql.com/doc/refman/8.0/en/lock-tables.html>
fn parse_unlock_tables(_parser: &mut Parser) -> Result<Statement, ParserError> {
Ok(Statement::UnlockTables)
}
3 changes: 3 additions & 0 deletions src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,9 +377,11 @@ define_keywords!(
LOCALTIME,
LOCALTIMESTAMP,
LOCATION,
LOCK,
LOCKED,
LOGIN,
LOWER,
LOW_PRIORITY,
MACRO,
MANAGEDLOCATION,
MATCH,
Expand Down Expand Up @@ -654,6 +656,7 @@ define_keywords!(
UNION,
UNIQUE,
UNKNOWN,
UNLOCK,
UNLOGGED,
UNNEST,
UNPIVOT,
Expand Down
29 changes: 17 additions & 12 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4012,17 +4012,6 @@ impl<'a> Parser<'a> {
None
};

let comment = if self.parse_keyword(Keyword::COMMENT) {
let _ = self.consume_token(&Token::Eq);
let next_token = self.next_token();
match next_token.token {
Token::SingleQuotedString(str) => Some(str),
_ => self.expected("comment", next_token)?,
}
} else {
None
};

let auto_increment_offset = if self.parse_keyword(Keyword::AUTO_INCREMENT) {
let _ = self.consume_token(&Token::Eq);
let next_token = self.next_token();
Expand Down Expand Up @@ -4097,6 +4086,18 @@ impl<'a> Parser<'a> {
};

let strict = self.parse_keyword(Keyword::STRICT);

let comment = if self.parse_keyword(Keyword::COMMENT) {
let _ = self.consume_token(&Token::Eq);
let next_token = self.next_token();
match next_token.token {
Token::SingleQuotedString(str) => Some(str),
_ => self.expected("comment", next_token)?,
}
} else {
None
};

Ok(CreateTableBuilder::new(table_name)
.temporary(temporary)
.columns(columns)
Expand Down Expand Up @@ -4183,7 +4184,7 @@ impl<'a> Parser<'a> {
pub fn parse_column_def(&mut self) -> Result<ColumnDef, ParserError> {
let name = self.parse_identifier()?;
let data_type = self.parse_data_type()?;
let collation = if self.parse_keyword(Keyword::COLLATE) {
let mut collation = if self.parse_keyword(Keyword::COLLATE) {
Some(self.parse_object_name()?)
} else {
None
Expand All @@ -4202,6 +4203,10 @@ impl<'a> Parser<'a> {
}
} else if let Some(option) = self.parse_optional_column_option()? {
options.push(ColumnOptionDef { name: None, option });
} else if dialect_of!(self is MySqlDialect | GenericDialect)
&& self.parse_keyword(Keyword::COLLATE)
{
collation = Some(self.parse_object_name()?);
} else {
break;
};
Expand Down
36 changes: 36 additions & 0 deletions tests/sqlparser_mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1871,6 +1871,42 @@ fn parse_convert_using() {
mysql().verified_only_select("SELECT CONVERT('test', CHAR CHARACTER SET utf8mb4)");
}

#[test]
fn parse_create_table_with_column_collate() {
let sql = "CREATE TABLE tb (id TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci)";
let canonical = "CREATE TABLE tb (id TEXT COLLATE utf8mb4_0900_ai_ci CHARACTER SET utf8mb4)";
match mysql().one_statement_parses_to(sql, canonical) {
Statement::CreateTable { name, columns, .. } => {
assert_eq!(name.to_string(), "tb");
assert_eq!(
vec![ColumnDef {
name: Ident::new("id"),
data_type: DataType::Text,
collation: Some(ObjectName(vec![Ident::new("utf8mb4_0900_ai_ci")])),
options: vec![ColumnOptionDef {
name: None,
option: ColumnOption::CharacterSet(ObjectName(vec![Ident::new("utf8mb4")]))
}],
},],
columns
);
}
_ => unreachable!(),
}
}

#[test]
fn parse_lock_tables() {
mysql().one_statement_parses_to(
"LOCK TABLES trans t READ, customer WRITE",
"LOCK TABLES trans AS t READ, customer WRITE",
);
mysql().verified_stmt("LOCK TABLES trans AS t READ, customer WRITE");
mysql().verified_stmt("LOCK TABLES trans AS t READ LOCAL, customer WRITE");
mysql().verified_stmt("LOCK TABLES trans AS t READ, customer LOW_PRIORITY WRITE");
mysql().verified_stmt("UNLOCK TABLES");
}

#[test]
fn parse_json_table() {
mysql().verified_only_select("SELECT * FROM JSON_TABLE('[[1, 2], [3, 4]]', '$[*]' COLUMNS(a INT PATH '$[0]', b INT PATH '$[1]')) AS t");
Expand Down

0 comments on commit d0fce12

Please sign in to comment.