Skip to content

Commit

Permalink
expression: expression.BuildSimpleExpr supports to build `ParamMark…
Browse files Browse the repository at this point in the history
…er` (#55493)

close #55492
  • Loading branch information
lcwangchao authored Aug 19, 2024
1 parent 45b127d commit 509d1bd
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 66 deletions.
5 changes: 5 additions & 0 deletions pkg/expression/contextstatic/exprctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,11 @@ func (ctx *StaticExprContext) GetEvalCtx() exprctx.EvalContext {
return ctx.evalCtx
}

// GetStaticEvalCtx returns the inner `StaticEvalContext`.
func (ctx *StaticExprContext) GetStaticEvalCtx() *StaticEvalContext {
return ctx.evalCtx
}

// GetCharsetInfo implements the `ExprContext.GetCharsetInfo`.
func (ctx *StaticExprContext) GetCharsetInfo() (string, string) {
return ctx.charset, ctx.collation
Expand Down
9 changes: 4 additions & 5 deletions pkg/expression/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/opcode"
"github.com/pingcap/tidb/pkg/parser/terror"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/types"
driver "github.com/pingcap/tidb/pkg/types/parser_driver"
"github.com/pingcap/tidb/pkg/util/chunk"
Expand Down Expand Up @@ -1196,8 +1195,8 @@ func DatumToConstant(d types.Datum, tp byte, flag uint) *Constant {
}

// ParamMarkerExpression generate a getparam function expression.
func ParamMarkerExpression(ctx variable.SessionVarsProvider, v *driver.ParamMarkerExpr, needParam bool) (*Constant, error) {
useCache := ctx.GetSessionVars().StmtCtx.UseCache()
func ParamMarkerExpression(ctx BuildContext, v *driver.ParamMarkerExpr, needParam bool) (*Constant, error) {
useCache := ctx.IsUseCache()
tp := types.NewFieldType(mysql.TypeUnspecified)
types.InferParamTypeFromDatum(&v.Datum, tp)
value := &Constant{Value: v.Datum, RetType: tp}
Expand Down Expand Up @@ -1251,11 +1250,11 @@ func ConstructPositionExpr(p *driver.ParamMarkerExpr) *ast.PositionExpr {
}

// PosFromPositionExpr generates a position value from PositionExpr.
func PosFromPositionExpr(ctx BuildContext, vars variable.SessionVarsProvider, v *ast.PositionExpr) (int, bool, error) {
func PosFromPositionExpr(ctx BuildContext, v *ast.PositionExpr) (int, bool, error) {
if v.P == nil {
return v.N, false, nil
}
value, err := ParamMarkerExpression(vars, v.P.(*driver.ParamMarkerExpr), false)
value, err := ParamMarkerExpression(ctx, v.P.(*driver.ParamMarkerExpr), false)
if err != nil {
return 0, true, err
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/planner/core/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ go_test(
"//pkg/domain",
"//pkg/expression",
"//pkg/expression/aggregation",
"//pkg/expression/context",
"//pkg/expression/contextstatic",
"//pkg/infoschema",
"//pkg/kv",
"//pkg/metrics",
Expand Down
28 changes: 15 additions & 13 deletions pkg/planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1445,19 +1445,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
}
er.ctxStackAppend(value, types.EmptyName)
case *driver.ParamMarkerExpr:
withPlanCtx(func(planCtx *exprRewriterPlanCtx) {
var value *expression.Constant
value, er.err = expression.ParamMarkerExpression(planCtx.builder.ctx, v, false)
if er.err != nil {
return
}
initConstantRepertoire(er.sctx.GetEvalCtx(), value)
er.adjustUTF8MB4Collation(value.RetType)
if er.err != nil {
return
}
er.ctxStackAppend(value, types.EmptyName)
})
er.toParamMarker(v)
case *ast.VariableExpr:
withPlanCtx(func(planCtx *exprRewriterPlanCtx) {
er.rewriteVariable(planCtx, v)
Expand Down Expand Up @@ -2407,6 +2395,20 @@ func (er *expressionRewriter) toTable(v *ast.TableName) {
er.ctxStackAppend(val, types.EmptyName)
}

func (er *expressionRewriter) toParamMarker(v *driver.ParamMarkerExpr) {
var value *expression.Constant
value, er.err = expression.ParamMarkerExpression(er.sctx, v, false)
if er.err != nil {
return
}
initConstantRepertoire(er.sctx.GetEvalCtx(), value)
er.adjustUTF8MB4Collation(value.RetType)
if er.err != nil {
return
}
er.ctxStackAppend(value, types.EmptyName)
}

func (er *expressionRewriter) clause() clauseCode {
if er.planCtx != nil {
return er.planCtx.builder.curClause
Expand Down
57 changes: 33 additions & 24 deletions pkg/planner/core/expression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ import (

"github.com/pingcap/tidb/pkg/domain"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/expression/context"
"github.com/pingcap/tidb/pkg/expression/contextstatic"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/charset"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/testkit/testutil"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
Expand Down Expand Up @@ -404,39 +407,36 @@ func TestBuildExpression(t *testing.T) {
},
}

ctx := MockContext()
defer func() {
domain.GetDomain(ctx).StatsHandle().Close()
}()

ctx := contextstatic.NewStaticExprContext()
evalCtx := ctx.GetStaticEvalCtx()
cols, names, err := expression.ColumnInfos2ColumnsAndNames(ctx, model.NewCIStr(""), tbl.Name, tbl.Cols(), tbl)
require.NoError(t, err)
schema := expression.NewSchema(cols...)

// normal build
ctx.GetSessionVars().PlanColumnID.Store(0)
ctx = ctx.Apply(contextstatic.WithColumnIDAllocator(context.NewSimplePlanColumnIDAllocator(0)))
expr, err := buildExpr(t, ctx, "(1+a)*(3+b)", expression.WithTableInfo("", tbl))
require.NoError(t, err)
ctx.GetSessionVars().PlanColumnID.Store(0)
ctx = ctx.Apply(contextstatic.WithColumnIDAllocator(context.NewSimplePlanColumnIDAllocator(0)))
expr2, err := expression.ParseSimpleExpr(ctx, "(1+a)*(3+b)", expression.WithTableInfo("", tbl))
require.NoError(t, err)
require.True(t, expr.Equal(ctx, expr2))
val, _, err := expr.EvalInt(ctx, chunk.MutRowFromValues("", 1, 2).ToRow())
require.True(t, expr.Equal(evalCtx, expr2))
val, _, err := expr.EvalInt(evalCtx, chunk.MutRowFromValues("", 1, 2).ToRow())
require.NoError(t, err)
require.Equal(t, int64(10), val)
val, _, err = expr.EvalInt(ctx, chunk.MutRowFromValues("", 3, 4).ToRow())
val, _, err = expr.EvalInt(evalCtx, chunk.MutRowFromValues("", 3, 4).ToRow())
require.NoError(t, err)
require.Equal(t, int64(28), val)
val, _, err = expr2.EvalInt(ctx, chunk.MutRowFromValues("", 1, 2).ToRow())
val, _, err = expr2.EvalInt(evalCtx, chunk.MutRowFromValues("", 1, 2).ToRow())
require.NoError(t, err)
require.Equal(t, int64(10), val)
val, _, err = expr2.EvalInt(ctx, chunk.MutRowFromValues("", 3, 4).ToRow())
val, _, err = expr2.EvalInt(evalCtx, chunk.MutRowFromValues("", 3, 4).ToRow())
require.NoError(t, err)
require.Equal(t, int64(28), val)

expr, err = buildExpr(t, ctx, "(1+a)*(3+b)", expression.WithInputSchemaAndNames(schema, names, nil))
require.NoError(t, err)
val, _, err = expr.EvalInt(ctx, chunk.MutRowFromValues("", 1, 2).ToRow())
val, _, err = expr.EvalInt(evalCtx, chunk.MutRowFromValues("", 1, 2).ToRow())
require.NoError(t, err)
require.Equal(t, int64(10), val)

Expand All @@ -452,52 +452,61 @@ func TestBuildExpression(t *testing.T) {
// use WithAllowCastArray to allow casting to array
expr, err = buildExpr(t, ctx, `cast(json_extract('{"a": [1, 2, 3]}', '$.a') as signed array)`, expression.WithAllowCastArray(true))
require.NoError(t, err)
j, _, err := expr.EvalJSON(ctx, chunk.Row{})
j, _, err := expr.EvalJSON(evalCtx, chunk.Row{})
require.NoError(t, err)
require.Equal(t, types.JSONTypeCodeArray, j.TypeCode)
require.Equal(t, "[1, 2, 3]", j.String())

// default expr
expr, err = buildExpr(t, ctx, "default(id)", expression.WithTableInfo("", tbl))
require.NoError(t, err)
s, _, err := expr.EvalString(ctx, chunk.MutRowFromValues("", 1, 2).ToRow())
s, _, err := expr.EvalString(evalCtx, chunk.MutRowFromValues("", 1, 2).ToRow())
require.NoError(t, err)
require.Equal(t, 36, len(s), s)

expr, err = buildExpr(t, ctx, "default(id)", expression.WithInputSchemaAndNames(schema, names, tbl))
require.NoError(t, err)
s, _, err = expr.EvalString(ctx, chunk.MutRowFromValues("", 1, 2).ToRow())
s, _, err = expr.EvalString(evalCtx, chunk.MutRowFromValues("", 1, 2).ToRow())
require.NoError(t, err)
require.Equal(t, 36, len(s), s)

expr, err = buildExpr(t, ctx, "default(b)", expression.WithTableInfo("", tbl))
require.NoError(t, err)
d, err := expr.Eval(ctx, chunk.MutRowFromValues("", 1, 2).ToRow())
d, err := expr.Eval(evalCtx, chunk.MutRowFromValues("", 1, 2).ToRow())
require.NoError(t, err)
require.Equal(t, types.NewDatum(int64(123)), d)

// WithCastExprTo
expr, err = buildExpr(t, ctx, "1+2+3")
require.NoError(t, err)
require.Equal(t, mysql.TypeLonglong, expr.GetType(ctx).GetType())
require.Equal(t, mysql.TypeLonglong, expr.GetType(evalCtx).GetType())
castTo := types.NewFieldType(mysql.TypeVarchar)
expr, err = buildExpr(t, ctx, "1+2+3", expression.WithCastExprTo(castTo))
require.NoError(t, err)
require.Equal(t, mysql.TypeVarchar, expr.GetType(ctx).GetType())
v, err := expr.Eval(ctx, chunk.Row{})
require.Equal(t, mysql.TypeVarchar, expr.GetType(evalCtx).GetType())
v, err := expr.Eval(evalCtx, chunk.Row{})
require.NoError(t, err)
require.Equal(t, types.KindString, v.Kind())
require.Equal(t, "6", v.GetString())

// param marker
params := variable.NewPlanCacheParamList()
params.Append(types.NewIntDatum(5))
evalCtx = evalCtx.Apply(contextstatic.WithParamList(params))
ctx = ctx.Apply(contextstatic.WithEvalCtx(evalCtx))
expr, err = buildExpr(t, ctx, "a + ?", expression.WithTableInfo("", tbl))
require.NoError(t, err)
require.Equal(t, mysql.TypeLonglong, expr.GetType(evalCtx).GetType())
v, err = expr.Eval(evalCtx, chunk.MutRowFromValues(1, 2, 3).ToRow())
require.NoError(t, err)
require.Equal(t, types.KindInt64, v.Kind())
require.Equal(t, int64(7), v.GetInt64())

// should report error for default expr when source table not provided
_, err = buildExpr(t, ctx, "default(b)", expression.WithInputSchemaAndNames(schema, names, nil))
require.EqualError(t, err, "Unsupported expr *ast.DefaultExpr when source table not provided")

// subquery not supported
_, err = buildExpr(t, ctx, "a + (select b from t)", expression.WithTableInfo("", tbl))
require.EqualError(t, err, "node '*ast.SubqueryExpr' is not allowed when building an expression without planner")

// param marker not supported
_, err = buildExpr(t, ctx, "a + ?", expression.WithTableInfo("", tbl))
require.EqualError(t, err, "node '*driver.ParamMarkerExpr' is not allowed when building an expression without planner")
}
23 changes: 11 additions & 12 deletions pkg/planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (a *aggOrderByResolver) Enter(inNode ast.Node) (ast.Node, bool) {
a.exprDepth++
if n, ok := inNode.(*driver.ParamMarkerExpr); ok {
if a.exprDepth == 1 {
_, isNull, isExpectedType := getUintFromNode(a.ctx, n, false)
_, isNull, isExpectedType := getUintFromNode(a.ctx.GetExprCtx(), n, false)
// For constant uint expression in top level, it should be treated as position expression.
if !isNull && isExpectedType {
return expression.ConstructPositionExpr(n), true
Expand All @@ -109,7 +109,7 @@ func (a *aggOrderByResolver) Enter(inNode ast.Node) (ast.Node, bool) {

func (a *aggOrderByResolver) Leave(inNode ast.Node) (ast.Node, bool) {
if v, ok := inNode.(*ast.PositionExpr); ok {
pos, isNull, err := expression.PosFromPositionExpr(a.ctx.GetExprCtx(), a.ctx, v)
pos, isNull, err := expression.PosFromPositionExpr(a.ctx.GetExprCtx(), v)
if err != nil {
a.err = err
}
Expand Down Expand Up @@ -2006,7 +2006,7 @@ CheckReferenced:
// getUintFromNode gets uint64 value from ast.Node.
// For ordinary statement, node should be uint64 constant value.
// For prepared statement, node is string. We should convert it to uint64.
func getUintFromNode(ctx base.PlanContext, n ast.Node, mustInt64orUint64 bool) (uVal uint64, isNull bool, isExpectedType bool) {
func getUintFromNode(ctx expression.BuildContext, n ast.Node, mustInt64orUint64 bool) (uVal uint64, isNull bool, isExpectedType bool) {
var val any
switch v := n.(type) {
case *driver.ValueExpr:
Expand All @@ -2024,7 +2024,7 @@ func getUintFromNode(ctx base.PlanContext, n ast.Node, mustInt64orUint64 bool) (
if err != nil {
return 0, false, false
}
str, isNull, err := expression.GetStringFromConstant(ctx.GetExprCtx().GetEvalCtx(), param)
str, isNull, err := expression.GetStringFromConstant(ctx.GetEvalCtx(), param)
if err != nil {
return 0, false, false
}
Expand All @@ -2043,8 +2043,7 @@ func getUintFromNode(ctx base.PlanContext, n ast.Node, mustInt64orUint64 bool) (
return uint64(v), false, true
}
case string:
ctx := ctx.GetSessionVars().StmtCtx.TypeCtx()
uVal, err := types.StrToUint(ctx, v, false)
uVal, err := types.StrToUint(ctx.GetEvalCtx().TypeCtx(), v, false)
if err != nil {
return 0, false, false
}
Expand All @@ -2068,7 +2067,7 @@ func CheckParamTypeInt64orUint64(param *driver.ParamMarkerExpr) (bool, uint64) {
return false, 0
}

func extractLimitCountOffset(ctx base.PlanContext, limit *ast.Limit) (count uint64,
func extractLimitCountOffset(ctx expression.BuildContext, limit *ast.Limit) (count uint64,
offset uint64, err error) {
var isExpectedType bool
if limit.Count != nil {
Expand All @@ -2092,7 +2091,7 @@ func (b *PlanBuilder) buildLimit(src base.LogicalPlan, limit *ast.Limit) (base.L
offset, count uint64
err error
)
if count, offset, err = extractLimitCountOffset(b.ctx, limit); err != nil {
if count, offset, err = extractLimitCountOffset(b.ctx.GetExprCtx(), limit); err != nil {
return nil, err
}

Expand Down Expand Up @@ -2845,7 +2844,7 @@ func (g *gbyResolver) Enter(inNode ast.Node) (ast.Node, bool) {
case *driver.ParamMarkerExpr:
g.isParam = true
if g.exprDepth == 1 && !n.UseAsValueInGbyByClause {
_, isNull, isExpectedType := getUintFromNode(g.ctx, n, false)
_, isNull, isExpectedType := getUintFromNode(g.ctx.GetExprCtx(), n, false)
// For constant uint expression in top level, it should be treated as position expression.
if !isNull && isExpectedType {
return expression.ConstructPositionExpr(n), true
Expand Down Expand Up @@ -2892,7 +2891,7 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) {
return inNode, false
}
case *ast.PositionExpr:
pos, isNull, err := expression.PosFromPositionExpr(g.ctx.GetExprCtx(), g.ctx, v)
pos, isNull, err := expression.PosFromPositionExpr(g.ctx.GetExprCtx(), v)
if err != nil {
g.err = plannererrors.ErrUnknown.GenWithStackByArgs()
}
Expand Down Expand Up @@ -6069,7 +6068,7 @@ func (b *PlanBuilder) buildWindowFunctionFrameBound(_ context.Context, spec *ast
if bound.Type == ast.CurrentRow {
return bound, nil
}
numRows, _, _ := getUintFromNode(b.ctx, boundClause.Expr, false)
numRows, _, _ := getUintFromNode(b.ctx.GetExprCtx(), boundClause.Expr, false)
bound.Num = numRows
return bound, nil
}
Expand Down Expand Up @@ -6391,7 +6390,7 @@ func (b *PlanBuilder) checkOriginWindowFrameBound(bound *ast.FrameBound, spec *a
if bound.Unit != ast.TimeUnitInvalid {
return plannererrors.ErrWindowRowsIntervalUse.GenWithStackByArgs(getWindowName(spec.Name.O))
}
_, isNull, isExpectedType := getUintFromNode(b.ctx, bound.Expr, false)
_, isNull, isExpectedType := getUintFromNode(b.ctx.GetExprCtx(), bound.Expr, false)
if isNull || !isExpectedType {
return plannererrors.ErrWindowFrameIllegal.GenWithStackByArgs(getWindowName(spec.Name.O))
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -4011,7 +4011,7 @@ func (b PlanBuilder) getInsertColExpr(ctx context.Context, insertPlan *Insert, m
RetType: &x.Type,
}
case *driver.ParamMarkerExpr:
outExpr, err = expression.ParamMarkerExpression(b.ctx, x, false)
outExpr, err = expression.ParamMarkerExpression(b.ctx.GetExprCtx(), x, false)
default:
b.curClause = fieldList
// subquery in insert values should not reference upper scope
Expand Down
Loading

0 comments on commit 509d1bd

Please sign in to comment.