Skip to content

Commit

Permalink
planner: Display truncate vector in EXPLAIN (#55934)
Browse files Browse the repository at this point in the history
ref #54245
  • Loading branch information
EricZequan authored Sep 29, 2024
1 parent 4df3389 commit 65d740f
Show file tree
Hide file tree
Showing 9 changed files with 153 additions and 29 deletions.
12 changes: 7 additions & 5 deletions pkg/executor/importer/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package importer

import (
"context"
"fmt"
"io"
"math"
"net/url"
Expand Down Expand Up @@ -810,13 +811,14 @@ func (p *Plan) initParameters(plan *plannercore.ImportInto) error {
setClause = sb.String()
}
optionMap := make(map[string]any, len(plan.Options))
var evalCtx expression.EvalContext
if plan.SCtx() != nil {
evalCtx = plan.SCtx().GetExprCtx().GetEvalCtx()
}
for _, opt := range plan.Options {
if opt.Value != nil {
val := opt.Value.StringWithCtx(evalCtx, errors.RedactLogDisable)
// The option attached to the import statement here are all
// parameters entered by the user. TiDB will process the
// parameters entered by the user as constant. so we can
// directly convert it to constant.
cons := opt.Value.(*expression.Constant)
val := fmt.Sprintf("%v", cons.Value.GetValue())
if opt.Name == cloudStorageURIOption {
val = ast.RedactURL(val)
}
Expand Down
24 changes: 12 additions & 12 deletions pkg/executor/testdata/prepare_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@
}
],
"Plan": [
"Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(len:78)->Column#1",
"Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(10,0) BINARY)->Column#1",
"└─TableDual_4 1.00 root rows:1"
],
"LastPlanUseCache": "0",
Expand Down Expand Up @@ -290,7 +290,7 @@
}
],
"Plan": [
"Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decima(len:79)->Column#1",
"Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decimal(10,0) BINARY)->Column#1",
"└─TableDual_4 1.00 root rows:1"
],
"LastPlanUseCache": "0",
Expand Down Expand Up @@ -363,7 +363,7 @@
}
],
"Plan": [
"Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(len:78)->Column#1",
"Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(10,0) BINARY)->Column#1",
"└─TableDual_4 1.00 root rows:1"
],
"LastPlanUseCache": "0",
Expand Down Expand Up @@ -431,7 +431,7 @@
}
],
"Plan": [
"Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decima(len:79)->Column#1",
"Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decimal(10,0) BINARY)->Column#1",
"└─TableDual_4 1.00 root rows:1"
],
"LastPlanUseCache": "0",
Expand Down Expand Up @@ -504,7 +504,7 @@
}
],
"Plan": [
"Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(len:77)->Column#1",
"Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(5,4) BINARY)->Column#1",
"└─TableDual_4 1.00 root rows:1"
],
"LastPlanUseCache": "0",
Expand Down Expand Up @@ -572,7 +572,7 @@
}
],
"Plan": [
"Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decima(len:78)->Column#1",
"Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decimal(5,4) BINARY)->Column#1",
"└─TableDual_4 1.00 root rows:1"
],
"LastPlanUseCache": "0",
Expand Down Expand Up @@ -645,7 +645,7 @@
}
],
"Plan": [
"Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(len:79)->Column#1",
"Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(64,30) BINARY)->Column#1",
"└─TableDual_4 1.00 root rows:1"
],
"LastPlanUseCache": "0",
Expand Down Expand Up @@ -713,7 +713,7 @@
}
],
"Plan": [
"Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decima(len:80)->Column#1",
"Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decimal(64,30) BINARY)->Column#1",
"└─TableDual_4 1.00 root rows:1"
],
"LastPlanUseCache": "0",
Expand Down Expand Up @@ -786,7 +786,7 @@
}
],
"Plan": [
"Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(len:78)->Column#1",
"Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(15,5) BINARY)->Column#1",
"└─TableDual_4 1.00 root rows:1"
],
"LastPlanUseCache": "0",
Expand Down Expand Up @@ -854,7 +854,7 @@
}
],
"Plan": [
"Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decima(len:79)->Column#1",
"Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decimal(15,5) BINARY)->Column#1",
"└─TableDual_4 1.00 root rows:1"
],
"LastPlanUseCache": "0",
Expand Down Expand Up @@ -927,7 +927,7 @@
}
],
"Plan": [
"Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(len:77)->Column#1",
"Projection_3 1.00 root cast(123456789.0123456789012345678901234567890123456789, decimal(5,5) BINARY)->Column#1",
"└─TableDual_4 1.00 root rows:1"
],
"LastPlanUseCache": "0",
Expand Down Expand Up @@ -995,7 +995,7 @@
}
],
"Plan": [
"Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decima(len:78)->Column#1",
"Projection_3 1.00 root cast(-123456789.0123456789012345678901234567890123456789, decimal(5,5) BINARY)->Column#1",
"└─TableDual_4 1.00 root rows:1"
],
"LastPlanUseCache": "0",
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,9 @@ func (c *Constant) StringWithCtx(ctx ParamValues, redact string) string {
return c.DeferredExpr.StringWithCtx(ctx, redact)
}
if redact == perrors.RedactLogDisable {
return fmt.Sprintf("%v", c.Value.GetValue())
return c.Value.TruncatedStringify()
} else if redact == perrors.RedactLogMarker {
return fmt.Sprintf("‹%v›", c.Value.GetValue())
return fmt.Sprintf("‹%s›", c.Value.TruncatedStringify())
}
return "?"
}
Expand Down
12 changes: 3 additions & 9 deletions pkg/expression/explain.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ func (expr *Constant) format(dt types.Datum) string {
return "NULL"
case types.KindString, types.KindBytes, types.KindMysqlEnum, types.KindMysqlSet,
types.KindMysqlJSON, types.KindBinaryLiteral, types.KindMysqlBit:
return fmt.Sprintf("\"%v\"", dt.GetValue())
return fmt.Sprintf("\"%s\"", dt.TruncatedStringify())
}
return fmt.Sprintf("%v", dt.GetValue())
return dt.TruncatedStringify()
}

// ExplainExpressionList generates explain information for a list of expressions.
Expand All @@ -192,13 +192,7 @@ func ExplainExpressionList(ctx EvalContext, exprs []Expression, schema *Schema,
}
case *Constant:
v := expr.StringWithCtx(ctx, errors.RedactLogDisable)
length := 64
if len(v) < length {
redact.WriteRedact(builder, v, redactMode)
} else {
redact.WriteRedact(builder, v[:length], redactMode)
fmt.Fprintf(builder, "(len:%d)", len(v))
}
redact.WriteRedact(builder, v, redactMode)
builder.WriteString("->")
builder.WriteString(schema.Columns[i].StringWithCtx(ctx, redactMode))
default:
Expand Down
3 changes: 2 additions & 1 deletion pkg/expression/integration_test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ go_test(
"main_test.go",
],
flaky = True,
shard_count = 45,
shard_count = 46,
deps = [
"//pkg/config",
"//pkg/domain",
Expand All @@ -34,6 +34,7 @@ go_test(
"//pkg/types",
"//pkg/util/codec",
"//pkg/util/collate",
"//pkg/util/plancodec",
"//pkg/util/sem",
"//pkg/util/timeutil",
"//pkg/util/versioninfo",
Expand Down
62 changes: 62 additions & 0 deletions pkg/expression/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import (
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/codec"
"github.com/pingcap/tidb/pkg/util/collate"
"github.com/pingcap/tidb/pkg/util/plancodec"
"github.com/pingcap/tidb/pkg/util/sem"
"github.com/pingcap/tidb/pkg/util/versioninfo"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -321,6 +322,67 @@ func TestVectorColumnInfo(t *testing.T) {
tk.MustGetErrMsg("create table t(embedding VECTOR(16384))", "vector cannot have more than 16383 dimensions")
}

func TestVectorConstantExplain(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("CREATE TABLE t(c VECTOR);")
tk.MustQuery(`EXPLAIN SELECT VEC_COSINE_DISTANCE(c, '[1,2,3,4,5,6,7,8,9,10,11]') FROM t;`).Check(testkit.Rows(
"Projection_3 10000.00 root vec_cosine_distance(test.t.c, [1,2,3,4,5,(6 more)...])->Column#3",
"└─TableReader_5 10000.00 root data:TableFullScan_4",
" └─TableFullScan_4 10000.00 cop[tikv] table:t keep order:false, stats:pseudo",
))
tk.MustQuery(`EXPLAIN SELECT VEC_COSINE_DISTANCE(c, VEC_FROM_TEXT('[1,2,3,4,5,6,7,8,9,10,11]')) FROM t;`).Check(testkit.Rows(
"Projection_3 10000.00 root vec_cosine_distance(test.t.c, [1,2,3,4,5,(6 more)...])->Column#3",
"└─TableReader_5 10000.00 root data:TableFullScan_4",
" └─TableFullScan_4 10000.00 cop[tikv] table:t keep order:false, stats:pseudo",
))
tk.MustQuery(`EXPLAIN SELECT VEC_COSINE_DISTANCE(c, '[1,2,3,4,5,6,7,8,9,10,11]') AS d FROM t ORDER BY d LIMIT 10;`).Check(testkit.Rows(
"Projection_6 10.00 root vec_cosine_distance(test.t.c, [1,2,3,4,5,(6 more)...])->Column#3",
"└─Projection_13 10.00 root test.t.c",
" └─TopN_7 10.00 root Column#4, offset:0, count:10",
" └─Projection_14 10.00 root test.t.c, vec_cosine_distance(test.t.c, [1,2,3,4,5,(6 more)...])->Column#4",
" └─TableReader_12 10.00 root data:TopN_11",
" └─TopN_11 10.00 cop[tikv] vec_cosine_distance(test.t.c, [1,2,3,4,5,(6 more)...]), offset:0, count:10",
" └─TableFullScan_10 10000.00 cop[tikv] table:t keep order:false, stats:pseudo",
))

// Prepare a large Vector string
vb := strings.Builder{}
vb.WriteString("[")
for i := 0; i < 100; i++ {
if i > 0 {
vb.WriteString(",")
}
vb.WriteString("100")
}
vb.WriteString("]")

stmtID, _, _, err := tk.Session().PrepareStmt("SELECT VEC_COSINE_DISTANCE(c, ?) FROM t")
require.Nil(t, err)
rs, err := tk.Session().ExecutePreparedStmt(context.Background(), stmtID, expression.Args2Expressions4Test(vb.String()))
require.NoError(t, err)

p, ok := tk.Session().GetSessionVars().StmtCtx.GetPlan().(base.Plan)
require.True(t, ok)

flat := plannercore.FlattenPhysicalPlan(p, true)
encodedPlanTree := plannercore.EncodeFlatPlan(flat)
planTree, err := plancodec.DecodePlan(encodedPlanTree)
require.NoError(t, err)
fmt.Println(planTree)
fmt.Println("++++")
require.Equal(t, strings.Join([]string{
` id task estRows operator info actRows execution info memory disk`,
` Projection_3 root 10000 vec_cosine_distance(test.t.c, cast([100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100...(len:401), vector))->Column#3 0 time:0s, loops:0 0 Bytes N/A`,
` └─TableReader_5 root 10000 data:TableFullScan_4 0 time:0s, loops:0 0 Bytes N/A`,
` └─TableFullScan_4 cop[tikv] 10000 table:t, keep order:false, stats:pseudo 0 N/A N/A`,
}, "\n"), planTree)

// No need to check result at all.
tk.ResultSetToResult(rs, fmt.Sprintf("%v", rs))
}

func TestFixedVector(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
Expand Down
1 change: 1 addition & 0 deletions pkg/expression/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,7 @@ func (m *MockExpr) VecEvalJSON(ctx EvalContext, input *chunk.Chunk, result *chun
}

func (m *MockExpr) StringWithCtx(ParamValues, string) string { return "" }

func (m *MockExpr) Eval(ctx EvalContext, row chunk.Row) (types.Datum, error) {
return types.NewDatum(m.i), m.err
}
Expand Down
31 changes: 31 additions & 0 deletions pkg/types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,37 @@ func (d *Datum) GetValue() any {
}
}

// TruncatedStringify returns the %v representation of the datum
// but truncated (for example, for strings, only first 64 bytes is printed).
// This function is useful in contexts like EXPLAIN.
func (d *Datum) TruncatedStringify() string {
const maxLen = 64

switch d.k {
case KindString, KindBytes:
str := d.GetString()
if len(str) > maxLen {
// This efficiently returns the truncated string without
// less possible allocations.
return fmt.Sprintf("%s...(len:%d)", str[:maxLen], len(str))
}
return str
case KindMysqlJSON:
// For now we can only stringify then truncate.
str := d.GetMysqlJSON().String()
if len(str) > maxLen {
return fmt.Sprintf("%s...(len:%d)", str[:maxLen], len(str))
}
return str
case KindVectorFloat32:
// Vector supports native efficient truncation.
return d.GetVectorFloat32().TruncatedString()
default:
// For other types, no truncation is needed.
return fmt.Sprintf("%v", d.GetValue())
}
}

// SetValueWithDefaultCollation sets any kind of value.
func (d *Datum) SetValueWithDefaultCollation(val any) {
switch x := val.(type) {
Expand Down
33 changes: 33 additions & 0 deletions pkg/types/vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package types

import (
"encoding/binary"
"fmt"
"math"
"strconv"
"unsafe"
Expand Down Expand Up @@ -93,6 +94,38 @@ func (v VectorFloat32) Elements() []float32 {
return unsafe.Slice((*float32)(unsafe.Pointer(&v.data[4])), l)
}

// TruncatedString prints the vector in a truncated form, which is useful for
// outputting in logs or EXPLAIN statements.
func (v VectorFloat32) TruncatedString() string {
const (
maxDisplayElements = 5
)

truncatedElements := 0
elements := v.Elements()

if len(elements) > maxDisplayElements {
truncatedElements = len(elements) - maxDisplayElements
elements = elements[:maxDisplayElements]
}

buf := make([]byte, 0, 2+v.Len()*2)
buf = append(buf, '[')
for i, v := range elements {
if i > 0 {
buf = append(buf, ","...)
}
buf = strconv.AppendFloat(buf, float64(v), 'g', 2, 32)
}
if truncatedElements > 0 {
buf = append(buf, fmt.Sprintf(",(%d more)...", truncatedElements)...)
}
buf = append(buf, ']')

// buf is not used elsewhere, so it's safe to just cast to String
return unsafe.String(unsafe.SliceData(buf), len(buf))
}

// String returns a string representation of the vector, which can be parsed later.
func (v VectorFloat32) String() string {
elements := v.Elements()
Expand Down

0 comments on commit 65d740f

Please sign in to comment.