Skip to content

Commit

Permalink
Merge branch 'master' into keep_order_hint
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-chi-bot authored Dec 20, 2022
2 parents 943853b + ae58fa1 commit ff8f21a
Show file tree
Hide file tree
Showing 17 changed files with 232 additions and 54 deletions.
2 changes: 1 addition & 1 deletion br/pkg/lightning/backend/kv/sql2kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func collectGeneratedColumns(se *session, meta *model.TableInfo, cols []*table.C
var genCols []genCol
for i, col := range cols {
if col.GeneratedExpr != nil {
expr, err := expression.RewriteAstExpr(se, col.GeneratedExpr, schema, names)
expr, err := expression.RewriteAstExpr(se, col.GeneratedExpr, schema, names, false)
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -6164,7 +6164,7 @@ func (d *ddl) CreatePrimaryKey(ctx sessionctx.Context, ti ast.Ident, indexName m
// After DDL job is put to the queue, and if the check fail, TiDB will run the DDL cancel logic.
// The recover step causes DDL wait a few seconds, makes the unit test painfully slow.
// For same reason, decide whether index is global here.
indexColumns, err := buildIndexColumns(ctx, tblInfo.Columns, indexPartSpecifications)
indexColumns, _, err := buildIndexColumns(ctx, tblInfo.Columns, indexPartSpecifications)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -6274,7 +6274,7 @@ func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*as
if err != nil {
return nil, errors.Trace(err)
}
expr, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, idxPart.Expr)
expr, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, idxPart.Expr, true)
if err != nil {
// TODO: refine the error message.
return nil, err
Expand Down Expand Up @@ -6389,7 +6389,7 @@ func (d *ddl) createIndex(ctx sessionctx.Context, ti ast.Ident, keyType ast.Inde
// After DDL job is put to the queue, and if the check fail, TiDB will run the DDL cancel logic.
// The recover step causes DDL wait a few seconds, makes the unit test painfully slow.
// For same reason, decide whether index is global here.
indexColumns, err := buildIndexColumns(ctx, finalColumns, indexPartSpecifications)
indexColumns, _, err := buildIndexColumns(ctx, finalColumns, indexPartSpecifications)
if err != nil {
return errors.Trace(err)
}
Expand Down
24 changes: 18 additions & 6 deletions ddl/generated_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,14 @@ func checkModifyGeneratedColumn(sctx sessionctx.Context, tbl table.Table, oldCol
}

type illegalFunctionChecker struct {
hasIllegalFunc bool
hasAggFunc bool
hasRowVal bool // hasRowVal checks whether the functional index refers to a row value
hasWindowFunc bool
hasNotGAFunc4ExprIdx bool
otherErr error
hasIllegalFunc bool
hasAggFunc bool
hasRowVal bool // hasRowVal checks whether the functional index refers to a row value
hasWindowFunc bool
hasNotGAFunc4ExprIdx bool
hasCastArrayFunc bool
disallowCastArrayFunc bool
otherErr error
}

func (c *illegalFunctionChecker) Enter(inNode ast.Node) (outNode ast.Node, skipChildren bool) {
Expand Down Expand Up @@ -308,7 +310,14 @@ func (c *illegalFunctionChecker) Enter(inNode ast.Node) (outNode ast.Node, skipC
case *ast.WindowFuncExpr:
c.hasWindowFunc = true
return inNode, true
case *ast.FuncCastExpr:
c.hasCastArrayFunc = c.hasCastArrayFunc || node.Tp.IsArray()
if c.disallowCastArrayFunc && node.Tp.IsArray() {
c.otherErr = expression.ErrNotSupportedYet.GenWithStackByArgs("Use of CAST( .. AS .. ARRAY) outside of functional index in CREATE(non-SELECT)/ALTER TABLE or in general expressions")
return inNode, true
}
}
c.disallowCastArrayFunc = true
return inNode, false
}

Expand Down Expand Up @@ -355,6 +364,9 @@ func checkIllegalFn4Generated(name string, genType int, expr ast.ExprNode) error
if genType == typeIndex && c.hasNotGAFunc4ExprIdx && !config.GetGlobalConfig().Experimental.AllowsExpressionIndex {
return dbterror.ErrUnsupportedExpressionIndex
}
if genType == typeColumn && c.hasCastArrayFunc {
return expression.ErrNotSupportedYet.GenWithStackByArgs("Use of CAST( .. AS .. ARRAY) outside of functional index in CREATE(non-SELECT)/ALTER TABLE or in general expressions")
}
return nil
}

Expand Down
21 changes: 12 additions & 9 deletions ddl/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,26 +64,28 @@ var (
telemetryAddIndexIngestUsage = metrics.TelemetryAddIndexIngestCnt
)

func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, indexPartSpecifications []*ast.IndexPartSpecification) ([]*model.IndexColumn, error) {
func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, indexPartSpecifications []*ast.IndexPartSpecification) ([]*model.IndexColumn, bool, error) {
// Build offsets.
idxParts := make([]*model.IndexColumn, 0, len(indexPartSpecifications))
var col *model.ColumnInfo
var mvIndex bool
maxIndexLength := config.GetGlobalConfig().MaxIndexLength
// The sum of length of all index columns.
sumLength := 0
for _, ip := range indexPartSpecifications {
col = model.FindColumnInfo(columns, ip.Column.Name.L)
if col == nil {
return nil, dbterror.ErrKeyColumnDoesNotExits.GenWithStack("column does not exist: %s", ip.Column.Name)
return nil, false, dbterror.ErrKeyColumnDoesNotExits.GenWithStack("column does not exist: %s", ip.Column.Name)
}

if err := checkIndexColumn(ctx, col, ip.Length); err != nil {
return nil, err
return nil, false, err
}
mvIndex = mvIndex || col.FieldType.IsArray()
indexColLen := ip.Length
indexColumnLength, err := getIndexColumnLength(col, ip.Length)
if err != nil {
return nil, err
return nil, false, err
}
sumLength += indexColumnLength

Expand All @@ -92,12 +94,12 @@ func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, inde
// The multiple column index and the unique index in which the length sum exceeds the maximum size
// will return an error instead produce a warning.
if ctx == nil || ctx.GetSessionVars().StrictSQLMode || mysql.HasUniKeyFlag(col.GetFlag()) || len(indexPartSpecifications) > 1 {
return nil, dbterror.ErrTooLongKey.GenWithStackByArgs(maxIndexLength)
return nil, false, dbterror.ErrTooLongKey.GenWithStackByArgs(maxIndexLength)
}
// truncate index length and produce warning message in non-restrict sql mode.
colLenPerUint, err := getIndexColumnLength(col, 1)
if err != nil {
return nil, err
return nil, false, err
}
indexColLen = maxIndexLength / colLenPerUint
// produce warning message
Expand All @@ -111,7 +113,7 @@ func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, inde
})
}

return idxParts, nil
return idxParts, mvIndex, nil
}

// CheckPKOnGeneratedColumn checks the specification of PK is valid.
Expand Down Expand Up @@ -154,7 +156,7 @@ func checkIndexColumn(ctx sessionctx.Context, col *model.ColumnInfo, indexColumn
}

// JSON column cannot index.
if col.FieldType.GetType() == mysql.TypeJSON {
if col.FieldType.GetType() == mysql.TypeJSON && !col.FieldType.IsArray() {
if col.Hidden {
return dbterror.ErrFunctionalIndexOnJSONOrGeometryFunction
}
Expand Down Expand Up @@ -263,7 +265,7 @@ func BuildIndexInfo(
return nil, errors.Trace(err)
}

idxColumns, err := buildIndexColumns(ctx, allTableColumns, indexPartSpecifications)
idxColumns, mvIndex, err := buildIndexColumns(ctx, allTableColumns, indexPartSpecifications)
if err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -276,6 +278,7 @@ func BuildIndexInfo(
Primary: isPrimary,
Unique: isUnique,
Global: isGlobal,
MVIndex: mvIndex,
}

if indexOption != nil {
Expand Down
2 changes: 1 addition & 1 deletion ddl/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -1375,7 +1375,7 @@ func checkPartitionFuncType(ctx sessionctx.Context, expr ast.ExprNode, tblInfo *
return nil
}

e, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, expr)
e, err := expression.RewriteSimpleExprWithTableInfo(ctx, tblInfo, expr, false)
if err != nil {
return errors.Trace(err)
}
Expand Down
1 change: 1 addition & 0 deletions expression/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ go_test(
"integration_serial_test.go",
"integration_test.go",
"main_test.go",
"multi_valued_index_test.go",
"scalar_function_test.go",
"schema_test.go",
"typeinfer_test.go",
Expand Down
83 changes: 79 additions & 4 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
package expression

import (
"fmt"
"math"
"strconv"
"strings"
Expand Down Expand Up @@ -407,6 +408,70 @@ func (c *castAsDurationFunctionClass) getFunction(ctx sessionctx.Context, args [
return sig, nil
}

type castAsArrayFunctionClass struct {
baseFunctionClass

tp *types.FieldType
}

func (c *castAsArrayFunctionClass) verifyArgs(args []Expression) error {
if err := c.baseFunctionClass.verifyArgs(args); err != nil {
return err
}

if args[0].GetType().EvalType() != types.ETJson {
return types.ErrInvalidJSONData.GenWithStackByArgs("1", "cast_as_array")
}

return nil
}

func (c *castAsArrayFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (sig builtinFunc, err error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}
arrayType := c.tp.ArrayType()
switch arrayType.GetType() {
case mysql.TypeYear, mysql.TypeJSON:
return nil, ErrNotSupportedYet.GenWithStackByArgs(fmt.Sprintf("CAST-ing data to array of %s", arrayType.String()))
}
if arrayType.EvalType() == types.ETString && arrayType.GetCharset() != charset.CharsetUTF8MB4 && arrayType.GetCharset() != charset.CharsetBin {
return nil, ErrNotSupportedYet.GenWithStackByArgs("specifying charset for multi-valued index", arrayType.String())
}

bf, err := newBaseBuiltinFunc(ctx, c.funcName, args, c.tp)
if err != nil {
return nil, err
}
sig = &castJSONAsArrayFunctionSig{bf}
return sig, nil
}

type castJSONAsArrayFunctionSig struct {
baseBuiltinFunc
}

func (b *castJSONAsArrayFunctionSig) Clone() builtinFunc {
newSig := &castJSONAsArrayFunctionSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *castJSONAsArrayFunctionSig) evalJSON(row chunk.Row) (res types.BinaryJSON, isNull bool, err error) {
val, isNull, err := b.args[0].EvalJSON(b.ctx, row)
if isNull || err != nil {
return res, isNull, err
}

if val.TypeCode != types.JSONTypeCodeArray {
return types.BinaryJSON{}, false, ErrNotSupportedYet.GenWithStackByArgs("CAST-ing Non-JSON Array type to array")
}

// TODO: impl the cast(... as ... array) function

return types.BinaryJSON{}, false, nil
}

type castAsJSONFunctionClass struct {
baseFunctionClass

Expand Down Expand Up @@ -1914,6 +1979,13 @@ func BuildCastCollationFunction(ctx sessionctx.Context, expr Expression, ec *Exp

// BuildCastFunction builds a CAST ScalarFunction from the Expression.
func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression) {
res, err := BuildCastFunctionWithCheck(ctx, expr, tp)
terror.Log(err)
return
}

// BuildCastFunctionWithCheck builds a CAST ScalarFunction from the Expression and return error if any.
func BuildCastFunctionWithCheck(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression, err error) {
argType := expr.GetType()
// If source argument's nullable, then target type should be nullable
if !mysql.HasNotNullFlag(argType.GetFlag()) {
Expand All @@ -1933,15 +2005,18 @@ func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldT
case types.ETDuration:
fc = &castAsDurationFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
case types.ETJson:
fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
if tp.IsArray() {
fc = &castAsArrayFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
} else {
fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
}
case types.ETString:
fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
if expr.GetType().GetType() == mysql.TypeBit {
tp.SetFlen((expr.GetType().GetFlen() + 7) / 8)
}
}
f, err := fc.getFunction(ctx, []Expression{expr})
terror.Log(err)
res = &ScalarFunction{
FuncName: model.NewCIStr(ast.Cast),
RetType: tp,
Expand All @@ -1950,10 +2025,10 @@ func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldT
// We do not fold CAST if the eval type of this scalar function is ETJson
// since we may reset the flag of the field type of CastAsJson later which
// would affect the evaluation of it.
if tp.EvalType() != types.ETJson {
if tp.EvalType() != types.ETJson && err == nil {
res = FoldConstant(res)
}
return res
return res, err
}

// WrapWithCastAsInt wraps `expr` with `cast` if the return type of expr is not
Expand Down
1 change: 1 addition & 0 deletions expression/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ var (
ErrInvalidTableSample = dbterror.ClassExpression.NewStd(mysql.ErrInvalidTableSample)
ErrInternal = dbterror.ClassOptimizer.NewStd(mysql.ErrInternal)
ErrNoDB = dbterror.ClassOptimizer.NewStd(mysql.ErrNoDB)
ErrNotSupportedYet = dbterror.ClassExpression.NewStd(mysql.ErrNotSupportedYet)

// All the un-exported errors are defined here:
errFunctionNotExists = dbterror.ClassExpression.NewStd(mysql.ErrSpDoesNotExist)
Expand Down
4 changes: 2 additions & 2 deletions expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ var EvalAstExpr func(sctx sessionctx.Context, expr ast.ExprNode) (types.Datum, e
// RewriteAstExpr rewrites ast expression directly.
// Note: initialized in planner/core
// import expression and planner/core together to use EvalAstExpr
var RewriteAstExpr func(sctx sessionctx.Context, expr ast.ExprNode, schema *Schema, names types.NameSlice) (Expression, error)
var RewriteAstExpr func(sctx sessionctx.Context, expr ast.ExprNode, schema *Schema, names types.NameSlice, allowCastArray bool) (Expression, error)

// VecExpr contains all vectorized evaluation methods.
type VecExpr interface {
Expand Down Expand Up @@ -998,7 +998,7 @@ func ColumnInfos2ColumnsAndNames(ctx sessionctx.Context, dbName, tblName model.C
if err != nil {
return nil, nil, errors.Trace(err)
}
e, err := RewriteAstExpr(ctx, expr, mockSchema, names)
e, err := RewriteAstExpr(ctx, expr, mockSchema, names, false)
if err != nil {
return nil, nil, errors.Trace(err)
}
Expand Down
47 changes: 47 additions & 0 deletions expression/multi_valued_index_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package expression_test

import (
"testing"

"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/testkit"
)

func TestMultiValuedIndexDDL(t *testing.T) {
store := testkit.CreateMockStore(t)

tk := testkit.NewTestKit(t, store)
tk.MustExec("USE test;")

tk.MustExec("create table t(a json);")
tk.MustGetErrCode("select cast(a as signed array) from t", errno.ErrNotSupportedYet)
tk.MustGetErrCode("select json_extract(cast(a as signed array), '$[0]') from t", errno.ErrNotSupportedYet)
tk.MustGetErrCode("select * from t where cast(a as signed array)", errno.ErrNotSupportedYet)
tk.MustGetErrCode("select cast('[1,2,3]' as unsigned array);", errno.ErrNotSupportedYet)

tk.MustExec("drop table t")
tk.MustGetErrCode("CREATE TABLE t(x INT, KEY k ((1 AND CAST(JSON_ARRAY(x) AS UNSIGNED ARRAY))));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(cast(f1 as unsigned array) as unsigned array))));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->>'$[*]' as unsigned array))));", errno.ErrInvalidJSONData)
tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->'$[*]' as year array))));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->'$[*]' as json array))));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("CREATE TABLE t1 (f1 json, key mvi((cast(f1->'$[*]' as char(10) charset gbk array))));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("create table t(j json, gc json as ((concat(cast(j->'$[*]' as unsigned array),\"x\"))));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("create table t(j json, gc json as (cast(j->'$[*]' as unsigned array)));", errno.ErrNotSupportedYet)
tk.MustGetErrCode("create view v as select cast('[1,2,3]' as unsigned array);", errno.ErrNotSupportedYet)
tk.MustExec("create table t(a json, index idx((cast(a as signed array))));")
}
Loading

0 comments on commit ff8f21a

Please sign in to comment.