Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

parser: refactor Parse() interface to make it extensible #28975

Merged
merged 12 commits into from
Oct 21, 2021
Merged
2 changes: 1 addition & 1 deletion br/pkg/lightning/restore/tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ func createDatabaseIfNotExistStmt(dbName string) string {
}

func createTableIfNotExistsStmt(p *parser.Parser, createTable, dbName, tblName string) ([]string, error) {
stmts, _, err := p.Parse(createTable, "", "")
stmts, _, err := p.ParseSQL(createTable)
if err != nil {
return []string{}, err
}
Expand Down
2 changes: 1 addition & 1 deletion ddl/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -1729,7 +1729,7 @@ func findColumnByName(colName string, tblInfo *model.TableInfo) *model.ColumnInf

func extractPartitionColumns(partExpr string, tblInfo *model.TableInfo) ([]*model.ColumnInfo, error) {
partExpr = "select " + partExpr
stmts, _, err := parser.New().Parse(partExpr, "", "")
stmts, _, err := parser.New().ParseSQL(partExpr)
if err != nil {
return nil, errors.Trace(err)
}
Expand Down
2 changes: 1 addition & 1 deletion executor/index_advise.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (e *IndexAdviseInfo) getStmtNodes(data []byte) error {
e.StmtNodes = make([][]ast.StmtNode, len(sqls))
sqlParser := parser.New()
for i, sql := range sqls {
stmtNodes, warns, err := sqlParser.Parse(sql, "", "")
stmtNodes, warns, err := sqlParser.ParseSQL(sql)
if err != nil {
return err
}
Expand Down
5 changes: 2 additions & 3 deletions executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,19 +117,18 @@ func (e *PrepareExec) Next(ctx context.Context, req *chunk.Chunk) error {
return nil
}
}
charset, collation := vars.GetCharsetInfo()
var (
stmts []ast.StmtNode
err error
)
if sqlParser, ok := e.ctx.(sqlexec.SQLParser); ok {
// FIXME: ok... yet another parse API, may need some api interface clean.
stmts, _, err = sqlParser.ParseSQL(ctx, e.sqlText, charset, collation)
stmts, _, err = sqlParser.ParseSQL(ctx, e.sqlText, vars.GetParseParams()...)
} else {
p := parser.New()
p.SetParserConfig(vars.BuildParserConfig())
var warns []error
stmts, warns, err = p.Parse(e.sqlText, charset, collation)
stmts, warns, err = p.ParseSQL(e.sqlText, vars.GetParseParams()...)
for _, warn := range warns {
e.ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn))
}
Expand Down
12 changes: 6 additions & 6 deletions expression/simple_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ func ParseSimpleExprWithTableInfo(ctx sessionctx.Context, exprStr string, tableI
var err error
var warns []error
if p, ok := ctx.(interface {
ParseSQL(context.Context, string, string, string) ([]ast.StmtNode, []error, error)
ParseSQL(context.Context, string, ...parser.ParseParam) ([]ast.StmtNode, []error, error)
}); ok {
stmts, warns, err = p.ParseSQL(context.Background(), exprStr, "", "")
stmts, warns, err = p.ParseSQL(context.Background(), exprStr)
} else {
stmts, warns, err = parser.New().Parse(exprStr, "", "")
stmts, warns, err = parser.New().ParseSQL(exprStr)
}
for _, warn := range warns {
ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn))
Expand Down Expand Up @@ -84,11 +84,11 @@ func ParseSimpleExprsWithNames(ctx sessionctx.Context, exprStr string, schema *S
var err error
var warns []error
if p, ok := ctx.(interface {
ParseSQL(context.Context, string, string, string) ([]ast.StmtNode, []error, error)
ParseSQL(context.Context, string, ...parser.ParseParam) ([]ast.StmtNode, []error, error)
}); ok {
stmts, warns, err = p.ParseSQL(context.Background(), exprStr, "", "")
stmts, warns, err = p.ParseSQL(context.Background(), exprStr)
} else {
stmts, warns, err = parser.New().Parse(exprStr, "", "")
stmts, warns, err = parser.New().ParseSQL(exprStr)
}
if err != nil {
return nil, util.SyntaxWarn(err)
Expand Down
39 changes: 35 additions & 4 deletions parser/charset/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ package charset
import (
"bytes"
"fmt"
"reflect"
"strings"
"unicode"
"unsafe"

"github.com/cznic/mathutil"
"github.com/pingcap/tidb/parser/mysql"
Expand Down Expand Up @@ -50,8 +52,8 @@ type Encoding struct {
specialCase unicode.SpecialCase
}

// Enabled indicates whether the non-utf8 encoding is used.
func (e *Encoding) Enabled() bool {
// enabled indicates whether the non-utf8 encoding is used.
func (e *Encoding) enabled() bool {
return e.enc != nil && e.charLength != nil
}

Expand Down Expand Up @@ -93,20 +95,38 @@ func (e *Encoding) UpdateEncoding(label EncodingLabel) {

// Encode converts bytes from utf-8 charset to a specific charset.
func (e *Encoding) Encode(dest, src []byte) ([]byte, error) {
if !e.Enabled() {
if !e.enabled() {
return src, nil
}
return e.transform(e.enc.NewEncoder(), dest, src, false)
}

// EncodeString converts a utf-8 string into the encoding's target charset.
// When the encoding is not enabled the input is returned untouched with a
// nil error. Otherwise the bytes produced by transform are returned as a
// string together with whatever error transform reported.
func (e *Encoding) EncodeString(src string) (string, error) {
	if e.enabled() {
		// Slice avoids copying src before handing it to the transformer.
		encoded, err := e.transform(e.enc.NewEncoder(), nil, Slice(src), false)
		return string(encoded), err
	}
	return src, nil
}

// Decode converts bytes from a specific charset to utf-8 charset.
func (e *Encoding) Decode(dest, src []byte) ([]byte, error) {
if !e.Enabled() {
if !e.enabled() {
return src, nil
}
return e.transform(e.enc.NewDecoder(), dest, src, true)
}

// DecodeString converts a string in the encoding's charset into utf-8.
// When the encoding is not enabled the input is returned untouched with a
// nil error. Otherwise the bytes produced by transform are returned as a
// string together with whatever error transform reported.
func (e *Encoding) DecodeString(src string) (string, error) {
	if e.enabled() {
		// Slice avoids copying src before handing it to the transformer.
		decoded, err := e.transform(e.enc.NewDecoder(), nil, Slice(src), true)
		return string(decoded), err
	}
	return src, nil
}

func (e *Encoding) transform(transformer transform.Transformer, dest, src []byte, isDecoding bool) ([]byte, error) {
if len(dest) < len(src) {
dest = make([]byte, len(src)*2)
Expand Down Expand Up @@ -164,3 +184,14 @@ var replacementBytes = []byte{0xEF, 0xBF, 0xBD}
func beginWithReplacementChar(dst []byte) bool {
	// dst qualifies only if it is at least as long as replacementBytes and
	// its leading bytes match them exactly.
	n := len(replacementBytes)
	return len(dst) >= n && bytes.Equal(dst[:n], replacementBytes)
}

// Slice returns a []byte that aliases the bytes of s; nothing is copied.
// The result has length and capacity equal to len(s).
// Use at your own risk: the slice shares memory with the string, so it
// should be treated as read-only.
func Slice(s string) (b []byte) {
	strHdr := (*reflect.StringHeader)(unsafe.Pointer(&s))
	sliceHdr := (*reflect.SliceHeader)(unsafe.Pointer(&b))
	// Point the slice at the string's backing array.
	sliceHdr.Data = strHdr.Data
	sliceHdr.Len, sliceHdr.Cap = strHdr.Len, strHdr.Len
	return b
}
2 changes: 0 additions & 2 deletions parser/charset/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,10 @@ type testEncodingSuite struct {
func (s *testEncodingSuite) TestEncoding(c *C) {
enc := charset.NewEncoding("gbk")
c.Assert(enc.Name(), Equals, "gbk")
c.Assert(enc.Enabled(), IsTrue)
enc.UpdateEncoding("utf-8")
c.Assert(enc.Name(), Equals, "utf-8")
enc.UpdateEncoding("gbk")
c.Assert(enc.Name(), Equals, "gbk")
c.Assert(enc.Enabled(), IsTrue)

txt := []byte("一二三四")
e, _ := charset.Lookup("gbk")
Expand Down
4 changes: 2 additions & 2 deletions parser/lexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,12 @@ func (s *Scanner) AppendError(err error) {
}

func (s *Scanner) tryDecodeToUTF8String(sql string) string {
utf8Lit, err := s.encoding.Decode(nil, Slice(sql))
utf8Lit, err := s.encoding.DecodeString(sql)
if err != nil {
s.AppendError(err)
s.lastErrorAsWarn()
}
return string(utf8Lit)
return utf8Lit
}

func (s *Scanner) getNextToken() int {
Expand Down
16 changes: 0 additions & 16 deletions parser/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@

package parser

import (
"reflect"
"unsafe"
)

// isLetter reports whether ch is an ASCII letter (a-z or A-Z).
func isLetter(ch rune) bool {
	switch {
	case 'a' <= ch && ch <= 'z':
		return true
	case 'A' <= ch && ch <= 'Z':
		return true
	default:
		return false
	}
}
Expand Down Expand Up @@ -991,14 +986,3 @@ func (s *Scanner) isTokenIdentifier(lit string, offset int) int {
}
return tok
}

// Slice returns a []byte that aliases the bytes of s; nothing is copied.
// The result has length and capacity equal to len(s).
// Use at your own risk: the slice shares memory with the string, so it
// should be treated as read-only.
func Slice(s string) (b []byte) {
	strHdr := (*reflect.StringHeader)(unsafe.Pointer(&s))
	sliceHdr := (*reflect.SliceHeader)(unsafe.Pointer(&b))
	// Point the slice at the string's backing array.
	sliceHdr.Data = strHdr.Data
	sliceHdr.Len, sliceHdr.Cap = strHdr.Len, strHdr.Len
	return b
}
19 changes: 10 additions & 9 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6402,31 +6402,32 @@ func (s *testParserSuite) TestGBKEncoding(c *C) {
sql, err := encoder.String("create table 测试表 (测试列 varchar(255) default 'GBK测试用例');")
c.Assert(err, IsNil)

stmt, err := p.ParseOneStmt(sql, "", "")
stmt, _, err := p.ParseSQL(sql)
c.Assert(err, IsNil)
checker := &gbkEncodingChecker{}
_, _ = stmt.Accept(checker)
_, _ = stmt[0].Accept(checker)
c.Assert(checker.tblName, Not(Equals), "测试表")
c.Assert(checker.colName, Not(Equals), "测试列")

p.SetParserConfig(parser.ParserConfig{CharsetClient: "gbk"})
stmt, err = p.ParseOneStmt(sql, "", "")
gbkOpt := parser.CharsetClient("gbk")
stmt, _, err = p.ParseSQL(sql, gbkOpt)
c.Assert(err, IsNil)
_, _ = stmt.Accept(checker)
_, _ = stmt[0].Accept(checker)
c.Assert(checker.tblName, Equals, "测试表")
c.Assert(checker.colName, Equals, "测试列")
c.Assert(checker.expr, Equals, "GBK测试用例")

utf8SQL := "select '芢' from `玚`;"
sql, err = encoder.String(utf8SQL)
c.Assert(err, IsNil)
stmt, err = p.ParseOneStmt(sql, "", "")
stmt, _, err = p.ParseSQL(sql, gbkOpt)
c.Assert(err, IsNil)
stmt, err = p.ParseOneStmt("select '\xc6\x5c' from `\xab\x60`;", "", "")
stmt, _, err = p.ParseSQL("select '\xc6\x5c' from `\xab\x60`;", gbkOpt)
c.Assert(err, IsNil)
stmt, _, err = p.ParseSQL(`prepare p1 from "insert into t values ('中文');";`, gbkOpt)
c.Assert(err, IsNil)

p.SetParserConfig(parser.ParserConfig{CharsetClient: ""})
stmt, err = p.ParseOneStmt("select _gbk '\xc6\x5c' from dual;", "", "")
stmt, _, err = p.ParseSQL("select _gbk '\xc6\x5c' from dual;")
c.Assert(err, NotNil)
}

Expand Down
82 changes: 68 additions & 14 deletions parser/yy_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ type ParserConfig struct {
EnableWindowFunction bool
EnableStrictDoubleTypeCheck bool
SkipPositionRecording bool
CharsetClient string // CharsetClient indicates how to decode the original SQL.
}

// Parser represents a parser instance. Some temporary objects are stored in it to reduce object allocation during Parse function.
Expand Down Expand Up @@ -134,21 +133,17 @@ func (parser *Parser) SetParserConfig(config ParserConfig) {
parser.EnableWindowFunc(config.EnableWindowFunction)
parser.SetStrictDoubleTypeCheck(config.EnableStrictDoubleTypeCheck)
parser.lexer.skipPositionRecording = config.SkipPositionRecording
parser.lexer.encoding = *charset.NewEncoding(config.CharsetClient)
}

// Parse parses a query string to raw ast.StmtNode.
// If charset or collation is "", default charset and collation will be used.
func (parser *Parser) Parse(sql, charset, collation string) (stmt []ast.StmtNode, warns []error, err error) {
sql = parser.lexer.tryDecodeToUTF8String(sql)
if charset == "" {
charset = mysql.DefaultCharset
}
if collation == "" {
collation = mysql.DefaultCollationName
// ParseSQL parses a query string to raw ast.StmtNode.
func (parser *Parser) ParseSQL(sql string, params ...ParseParam) (stmt []ast.StmtNode, warns []error, err error) {
resetParams(parser)
for _, p := range params {
if err := p.ApplyOn(parser); err != nil {
return nil, nil, err
}
}
parser.charset = charset
parser.collation = collation
sql = parser.lexer.tryDecodeToUTF8String(sql)
parser.src = sql
parser.result = parser.result[:0]

Expand All @@ -172,14 +167,20 @@ func (parser *Parser) Parse(sql, charset, collation string) (stmt []ast.StmtNode
return parser.result, warns, nil
}

// Parse parses a query string to raw ast.StmtNode.
// It forwards to ParseSQL, passing charset and collation as the
// CharsetConnection and CollationConnection parameters. An empty charset or
// collation selects the corresponding default.
func (parser *Parser) Parse(sql, charset, collation string) (stmt []ast.StmtNode, warns []error, err error) {
	connCharset := CharsetConnection(charset)
	connCollation := CollationConnection(collation)
	return parser.ParseSQL(sql, connCharset, connCollation)
}

func (parser *Parser) lastErrorAsWarn() {
parser.lexer.lastErrorAsWarn()
}

// ParseOneStmt parses a query and returns an ast.StmtNode.
// The query must have one statement, otherwise ErrSyntax is returned.
func (parser *Parser) ParseOneStmt(sql, charset, collation string) (ast.StmtNode, error) {
stmts, _, err := parser.Parse(sql, charset, collation)
stmts, _, err := parser.ParseSQL(sql, CharsetConnection(charset), CollationConnection(collation))
if err != nil {
return nil, errors.Trace(err)
}
Expand Down Expand Up @@ -374,3 +375,56 @@ func convertToPriv(roleOrPrivList []*ast.RoleOrPriv) ([]*ast.PrivElem, error) {
}
return privileges, nil
}

// Compile-time checks that each option type implements ParseParam.
var (
_ ParseParam = CharsetConnection("")
_ ParseParam = CollationConnection("")
_ ParseParam = CharsetClient("")
)

// resetParams restores the per-parse settings on p to their defaults: the
// default connection charset and collation, and a zero-value lexer encoding.
func resetParams(p *Parser) {
	var noEncoding charset.Encoding
	p.lexer.encoding = noEncoding
	p.charset, p.collation = mysql.DefaultCharset, mysql.DefaultCollationName
}

// ParseParam represents the parameter of parsing.
// ParseSQL accepts any number of them: it first resets the parser's charset,
// collation and lexer encoding, then calls ApplyOn on each parameter in the
// order given and aborts with the first non-nil error.
type ParseParam interface {
// ApplyOn applies this parameter to the given parser.
ApplyOn(*Parser) error
}

// CharsetConnection is a ParseParam naming the charset used for literals
// specified without a character set.
type CharsetConnection string

// ApplyOn implements the ParseParam interface. It stores the charset on p,
// substituting mysql.DefaultCharset when c is empty, and always returns nil.
func (c CharsetConnection) ApplyOn(p *Parser) error {
	name := string(c)
	if name == "" {
		name = mysql.DefaultCharset
	}
	p.charset = name
	return nil
}

// CollationConnection is a ParseParam naming the collation used for literals
// specified without a collation.
type CollationConnection string

// ApplyOn implements the ParseParam interface. It stores the collation on p,
// substituting mysql.DefaultCollationName when c is empty, and always
// returns nil.
func (c CollationConnection) ApplyOn(p *Parser) error {
	name := string(c)
	if name == "" {
		name = mysql.DefaultCollationName
	}
	p.collation = name
	return nil
}

// CharsetClient is a ParseParam that specifies the charset of a SQL text.
// This is used to decode the SQL into a utf-8 string.
type CharsetClient string

// ApplyOn implements the ParseParam interface. It replaces the lexer's
// encoding with the one charset.NewEncoding builds for c, and always
// returns nil.
func (c CharsetClient) ApplyOn(p *Parser) error {
	enc := charset.NewEncoding(string(c))
	p.lexer.encoding = *enc
	return nil
}
Loading