diff --git a/go.mod b/go.mod index 7e8852c518981..d83f947961a6b 100644 --- a/go.mod +++ b/go.mod @@ -239,7 +239,7 @@ require ( github.com/jfcg/sixb v1.3.8 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/jonboulle/clockwork v0.4.0 // indirect - github.com/json-iterator/go v1.1.12 // indirect + github.com/json-iterator/go v1.1.12 github.com/klauspost/cpuid v1.3.1 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect diff --git a/pkg/ddl/column.go b/pkg/ddl/column.go index 7d81d99142d76..c479c005a8c82 100644 --- a/pkg/ddl/column.go +++ b/pkg/ddl/column.go @@ -399,6 +399,8 @@ func needChangeColumnData(oldCol, newCol *model.ColumnInfo) bool { if types.IsBinaryStr(&oldCol.FieldType) { return newCol.GetFlen() != oldCol.GetFlen() } + case mysql.TypeTiDBVectorFloat32: + return newCol.GetFlen() != types.UnspecifiedLength && oldCol.GetFlen() != newCol.GetFlen() } return needTruncationOrToggleSign() diff --git a/pkg/ddl/executor.go b/pkg/ddl/executor.go index add9846fc9048..0086afbfc86c4 100644 --- a/pkg/ddl/executor.go +++ b/pkg/ddl/executor.go @@ -1636,7 +1636,7 @@ func getDefaultValue(ctx exprctx.BuildContext, col *table.Column, option *ast.Co } if v.Kind() == types.KindBinaryLiteral || v.Kind() == types.KindMysqlBit { - if types.IsTypeBlob(tp) || tp == mysql.TypeJSON { + if types.IsTypeBlob(tp) || tp == mysql.TypeJSON || tp == mysql.TypeTiDBVectorFloat32 { // BLOB/TEXT/JSON column cannot have a default value. // Skip the unnecessary decode procedure. return v.GetString(), false, err @@ -3483,7 +3483,8 @@ func checkPartitionByList(ctx sessionctx.Context, tbInfo *model.TableInfo) error func isValidKeyPartitionColType(fieldType types.FieldType) bool { switch fieldType.GetType() { - case mysql.TypeBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeJSON, mysql.TypeGeometry, mysql.TypeTiDBVectorFloat32: + case mysql.TypeBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeJSON, mysql.TypeGeometry, + mysql.TypeTiDBVectorFloat32: return false default: return true diff --git a/pkg/ddl/index.go b/pkg/ddl/index.go index 6331a53a2cea4..4c3aa25bba417 100644 --- a/pkg/ddl/index.go +++ b/pkg/ddl/index.go @@ -204,6 +204,14 @@ func checkIndexColumn(ctx sessionctx.Context, col *model.ColumnInfo, indexColumn return errors.Trace(dbterror.ErrJSONUsedAsKey.GenWithStackByArgs(col.Name.O)) } + // Vector column cannot index, for now. + if col.FieldType.GetType() == mysql.TypeTiDBVectorFloat32 { + if col.Hidden { + return errors.Errorf("Cannot create an expression index on a function that returns a VECTOR value") + } + return errors.Trace(dbterror.ErrWrongKeyColumn.GenWithStackByArgs(col.Name)) + } + // Length must be specified and non-zero for BLOB and TEXT column indexes. if types.IsTypeBlob(col.FieldType.GetType()) { if indexColumnLen == types.UnspecifiedLength { diff --git a/pkg/executor/aggfuncs/aggfuncs.go b/pkg/executor/aggfuncs/aggfuncs.go index d2a3083581a19..586666554b9e6 100644 --- a/pkg/executor/aggfuncs/aggfuncs.go +++ b/pkg/executor/aggfuncs/aggfuncs.go @@ -33,6 +33,7 @@ var ( _ AggFunc = (*countOriginal4Time)(nil) _ AggFunc = (*countOriginal4Duration)(nil) _ AggFunc = (*countOriginal4JSON)(nil) + _ AggFunc = (*countOriginal4VectorFloat32)(nil) _ AggFunc = (*countOriginal4String)(nil) _ AggFunc = (*countOriginalWithDistinct4Int)(nil) _ AggFunc = (*countOriginalWithDistinct4Real)(nil) @@ -61,6 +62,7 @@ var ( _ AggFunc = (*firstRow4Float32)(nil) _ AggFunc = (*firstRow4Float64)(nil) _ AggFunc = (*firstRow4JSON)(nil) + _ AggFunc = (*firstRow4VectorFloat32)(nil) _ AggFunc = (*firstRow4Enum)(nil) _ AggFunc = (*firstRow4Set)(nil) @@ -73,6 +75,7 @@ var ( _ AggFunc = (*maxMin4String)(nil) _ AggFunc = (*maxMin4Duration)(nil) _ AggFunc = (*maxMin4JSON)(nil) + _ AggFunc = (*maxMin4VectorFloat32)(nil) _ AggFunc = (*maxMin4Enum)(nil) _ AggFunc = (*maxMin4Set)(nil) diff --git a/pkg/executor/aggfuncs/builder.go b/pkg/executor/aggfuncs/builder.go index abbff30bae6fb..901ad0e469e74 100644 --- a/pkg/executor/aggfuncs/builder.go +++ b/pkg/executor/aggfuncs/builder.go @@ -251,6 +251,8 @@ func buildCount(ctx expression.EvalContext, aggFuncDesc *aggregation.AggFuncDesc return &countOriginal4Duration{baseCount{base}} case types.ETJson: return &countOriginal4JSON{baseCount{base}} + case types.ETVectorFloat32: + return &countOriginal4VectorFloat32{baseCount{base}} case types.ETString: return &countOriginal4String{baseCount{base}} } @@ -378,6 +380,8 @@ func buildFirstRow(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { return &firstRow4String{base} case types.ETJson: return &firstRow4JSON{base} + case types.ETVectorFloat32: + return &firstRow4VectorFloat32{base} } } return nil @@ -431,6 +435,8 @@ func buildMaxMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int, isMax bool) return &maxMin4Duration{base} case types.ETJson: return &maxMin4JSON{base} + case types.ETVectorFloat32: + return &maxMin4VectorFloat32{base} } } return nil diff --git a/pkg/executor/aggfuncs/func_count.go b/pkg/executor/aggfuncs/func_count.go index 3b21798b6c5f6..3f111b9f2b187 100644 --- a/pkg/executor/aggfuncs/func_count.go +++ b/pkg/executor/aggfuncs/func_count.go @@ -360,6 +360,55 @@ func (e *countOriginal4JSON) Slide(sctx AggFuncUpdateContext, getRow func(uint64 return nil } +type countOriginal4VectorFloat32 struct { + baseCount +} + +func (e *countOriginal4VectorFloat32) UpdatePartialResult(sctx AggFuncUpdateContext, rowsInGroup []chunk.Row, pr PartialResult) (memDelta int64, err error) { + p := (*partialResult4Count)(pr) + + for _, row := range rowsInGroup { + _, isNull, err := e.args[0].EvalVectorFloat32(sctx, row) + if err != nil { + return 0, err + } + if isNull { + continue + } + + *p++ + } + + return 0, nil +} + +var _ SlidingWindowAggFunc = &countOriginal4VectorFloat32{} + +func (e *countOriginal4VectorFloat32) Slide(sctx AggFuncUpdateContext, getRow func(uint64) chunk.Row, lastStart, lastEnd uint64, shiftStart, shiftEnd uint64, pr PartialResult) error { + p := (*partialResult4Count)(pr) + for i := uint64(0); i < shiftStart; i++ { + _, isNull, err := e.args[0].EvalVectorFloat32(sctx, getRow(lastStart+i)) + if err != nil { + return err + } + if isNull { + continue + } + *p-- + } + for i := uint64(0); i < shiftEnd; i++ { + _, isNull, err := e.args[0].EvalVectorFloat32(sctx, getRow(lastEnd+i)) + if err != nil { + return err + } + if isNull { + continue + } + *p++ + } + return nil +} + type countOriginal4String struct { baseCount } diff --git a/pkg/executor/aggfuncs/func_count_distinct.go b/pkg/executor/aggfuncs/func_count_distinct.go index 13c6a2f53e808..85d85b74b963b 100644 --- a/pkg/executor/aggfuncs/func_count_distinct.go +++ b/pkg/executor/aggfuncs/func_count_distinct.go @@ -401,6 +401,13 @@ func evalAndEncode( break } encodedBytes = val.HashValue(encodedBytes) + case types.ETVectorFloat32: + var val types.VectorFloat32 + val, isNull, err = arg.EvalVectorFloat32(sctx, row) + if err != nil || isNull { + break + } + encodedBytes = val.SerializeTo(encodedBytes) case types.ETString: var val string val, isNull, err = arg.EvalString(sctx, row) diff --git a/pkg/executor/aggfuncs/func_first_row.go b/pkg/executor/aggfuncs/func_first_row.go index 70a306551179a..aeba22e27f054 100644 --- a/pkg/executor/aggfuncs/func_first_row.go +++ b/pkg/executor/aggfuncs/func_first_row.go @@ -39,6 +39,8 @@ const ( DefPartialResult4FirstRowDurationSize = int64(unsafe.Sizeof(partialResult4FirstRowDuration{})) // DefPartialResult4FirstRowJSONSize is the size of partialResult4FirstRowJSON DefPartialResult4FirstRowJSONSize = int64(unsafe.Sizeof(partialResult4FirstRowJSON{})) + // DefPartialResult4FirstRowVectorFloat32Size is the size of partialResult4FirstRowVectorFloat32 + DefPartialResult4FirstRowVectorFloat32Size = int64(unsafe.Sizeof(partialResult4FirstRowVectorFloat32{})) // DefPartialResult4FirstRowDecimalSize is the size of partialResult4FirstRowDecimal DefPartialResult4FirstRowDecimalSize = int64(unsafe.Sizeof(partialResult4FirstRowDecimal{})) // DefPartialResult4FirstRowEnumSize is the size of partialResult4FirstRowEnum @@ -104,6 +106,12 @@ type partialResult4FirstRowJSON struct { val types.BinaryJSON } +type partialResult4FirstRowVectorFloat32 struct { + basePartialResult4FirstRow + + val types.VectorFloat32 +} + type partialResult4FirstRowEnum struct { basePartialResult4FirstRow @@ -579,6 +587,52 @@ func (e *firstRow4JSON) deserializeForSpill(helper *deserializeHelper) (PartialR return pr, memDelta } +type firstRow4VectorFloat32 struct { + baseAggFunc +} + +func (*firstRow4VectorFloat32) AllocPartialResult() (pr PartialResult, memDelta int64) { + return PartialResult(new(partialResult4FirstRowVectorFloat32)), DefPartialResult4FirstRowVectorFloat32Size +} + +func (*firstRow4VectorFloat32) ResetPartialResult(pr PartialResult) { + p := (*partialResult4FirstRowVectorFloat32)(pr) + p.isNull, p.gotFirstRow = false, false +} + +func (e *firstRow4VectorFloat32) UpdatePartialResult(sctx AggFuncUpdateContext, rowsInGroup []chunk.Row, pr PartialResult) (memDelta int64, err error) { + p := (*partialResult4FirstRowVectorFloat32)(pr) + if p.gotFirstRow { + return memDelta, nil + } + if len(rowsInGroup) > 0 { + input, isNull, err := e.args[0].EvalVectorFloat32(sctx, rowsInGroup[0]) + if err != nil { + return memDelta, err + } + p.gotFirstRow, p.isNull, p.val = true, isNull, input.Clone() + memDelta += int64(input.EstimatedMemUsage()) + } + return memDelta, nil +} +func (*firstRow4VectorFloat32) MergePartialResult(_ AggFuncUpdateContext, src, dst PartialResult) (memDelta int64, err error) { + p1, p2 := (*partialResult4FirstRowVectorFloat32)(src), (*partialResult4FirstRowVectorFloat32)(dst) + if !p2.gotFirstRow { + *p2 = *p1 + } + return memDelta, nil +} + +func (e *firstRow4VectorFloat32) AppendFinalResult2Chunk(_ AggFuncUpdateContext, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4FirstRowVectorFloat32)(pr) + if p.isNull || !p.gotFirstRow { + chk.AppendNull(e.ordinal) + return nil + } + chk.AppendVectorFloat32(e.ordinal, p.val) + return nil +} + type firstRow4Decimal struct { baseAggFunc } diff --git a/pkg/executor/aggfuncs/func_max_min.go b/pkg/executor/aggfuncs/func_max_min.go index 9903a880b2523..2eb975051545e 100644 --- a/pkg/executor/aggfuncs/func_max_min.go +++ b/pkg/executor/aggfuncs/func_max_min.go @@ -156,6 +156,8 @@ const ( DefPartialResult4MaxMinStringSize = int64(unsafe.Sizeof(partialResult4MaxMinString{})) // DefPartialResult4MaxMinJSONSize is the size of partialResult4MaxMinJSON DefPartialResult4MaxMinJSONSize = int64(unsafe.Sizeof(partialResult4MaxMinJSON{})) + // DefPartialResult4MaxMinVectorFloat32Size is the size of partialResult4MaxMinVectorFloat32 + DefPartialResult4MaxMinVectorFloat32Size = int64(unsafe.Sizeof(partialResult4MaxMinVectorFloat32{})) // DefPartialResult4MaxMinEnumSize is the size of partialResult4MaxMinEnum DefPartialResult4MaxMinEnumSize = int64(unsafe.Sizeof(partialResult4MaxMinEnum{})) // DefPartialResult4MaxMinSetSize is the size of partialResult4MaxMinSet @@ -221,6 +223,11 @@ type partialResult4MaxMinJSON struct { isNull bool } +type partialResult4MaxMinVectorFloat32 struct { + val types.VectorFloat32 + isNull bool +} + type partialResult4MaxMinEnum struct { val types.Enum isNull bool @@ -1632,6 +1639,75 @@ func (e *maxMin4JSON) deserializeForSpill(helper *deserializeHelper) (PartialRes return pr, memDelta } +type maxMin4VectorFloat32 struct { + baseMaxMinAggFunc +} + +func (*maxMin4VectorFloat32) AllocPartialResult() (pr PartialResult, memDelta int64) { + p := new(partialResult4MaxMinVectorFloat32) + p.isNull = true + return PartialResult(p), DefPartialResult4MaxMinVectorFloat32Size +} + +func (*maxMin4VectorFloat32) ResetPartialResult(pr PartialResult) { + p := (*partialResult4MaxMinVectorFloat32)(pr) + p.isNull = true +} + +func (e *maxMin4VectorFloat32) AppendFinalResult2Chunk(_ AggFuncUpdateContext, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4MaxMinVectorFloat32)(pr) + if p.isNull { + chk.AppendNull(e.ordinal) + return nil + } + chk.AppendVectorFloat32(e.ordinal, p.val) + return nil +} + +func (e *maxMin4VectorFloat32) UpdatePartialResult(sctx AggFuncUpdateContext, rowsInGroup []chunk.Row, pr PartialResult) (memDelta int64, err error) { + p := (*partialResult4MaxMinVectorFloat32)(pr) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalVectorFloat32(sctx, row) + if err != nil { + return memDelta, err + } + if isNull { + continue + } + if p.isNull { + p.val = input.Clone() + memDelta += int64(input.EstimatedMemUsage()) + p.isNull = false + continue + } + cmp := input.Compare(p.val) + if e.isMax && cmp > 0 || !e.isMax && cmp < 0 { + oldMem := p.val.EstimatedMemUsage() + newMem := input.EstimatedMemUsage() + memDelta += int64(newMem - oldMem) + p.val = input.Clone() + } + } + return memDelta, nil +} + +func (e *maxMin4VectorFloat32) MergePartialResult(_ AggFuncUpdateContext, src, dst PartialResult) (memDelta int64, err error) { + p1, p2 := (*partialResult4MaxMinVectorFloat32)(src), (*partialResult4MaxMinVectorFloat32)(dst) + if p1.isNull { + return 0, nil + } + if p2.isNull { + *p2 = *p1 + return 0, nil + } + cmp := p1.val.Compare(p2.val) + if e.isMax && cmp > 0 || !e.isMax && cmp < 0 { + p2.val = p1.val + p2.isNull = false + } + return 0, nil +} + type maxMin4Enum struct { baseMaxMinAggFunc } diff --git a/pkg/executor/aggfuncs/func_value.go b/pkg/executor/aggfuncs/func_value.go index 3753cc0178888..45f3044a51c0c 100644 --- a/pkg/executor/aggfuncs/func_value.go +++ b/pkg/executor/aggfuncs/func_value.go @@ -15,6 +15,7 @@ package aggfuncs import ( + "fmt" "unsafe" "github.com/pingcap/tidb/pkg/expression" @@ -47,6 +48,8 @@ const ( DefValue4StringSize = int64(unsafe.Sizeof(value4String{})) // DefValue4JSONSize is the size of value4JSON DefValue4JSONSize = int64(unsafe.Sizeof(value4JSON{})) + // DefValue4VectorFloat32Size is the size of value4VectorFloat32 + DefValue4VectorFloat32Size = int64(unsafe.Sizeof(value4VectorFloat32{})) ) // valueEvaluator is used to evaluate values for `first_value`, `last_value`, `nth_value`, @@ -207,6 +210,26 @@ func (v *value4JSON) appendResult(chk *chunk.Chunk, colIdx int) { } } +type value4VectorFloat32 struct { + val types.VectorFloat32 + isNull bool +} + +func (v *value4VectorFloat32) evaluateRow(ctx expression.EvalContext, expr expression.Expression, row chunk.Row) (memDelta int64, err error) { + originalLength := v.val.EstimatedMemUsage() + v.val, v.isNull, err = expr.EvalVectorFloat32(ctx, row) + v.val = v.val.Clone() // deep copy to avoid content change. + return int64(v.val.EstimatedMemUsage() - originalLength), err +} + +func (v *value4VectorFloat32) appendResult(chk *chunk.Chunk, colIdx int) { + if v.isNull { + chk.AppendNull(colIdx) + } else { + chk.AppendVectorFloat32(colIdx, v.val) + } +} + func buildValueEvaluator(tp *types.FieldType) (ve valueEvaluator, memDelta int64) { evalType := tp.EvalType() if tp.GetType() == mysql.TypeBit { @@ -232,6 +255,10 @@ func buildValueEvaluator(tp *types.FieldType) (ve valueEvaluator, memDelta int64 return &value4String{}, DefValue4StringSize case types.ETJson: return &value4JSON{}, DefValue4JSONSize + case types.ETVectorFloat32: + return &value4VectorFloat32{}, DefValue4VectorFloat32Size + default: + panic(fmt.Sprintf("unsupported eval type %v", evalType)) } return nil, 0 } diff --git a/pkg/executor/internal/vecgroupchecker/BUILD.bazel b/pkg/executor/internal/vecgroupchecker/BUILD.bazel index 3a143a8cf4ace..4cc2db005b45f 100644 --- a/pkg/executor/internal/vecgroupchecker/BUILD.bazel +++ b/pkg/executor/internal/vecgroupchecker/BUILD.bazel @@ -10,6 +10,7 @@ go_library( "//pkg/types", "//pkg/util/chunk", "//pkg/util/codec", + "@com_github_pingcap_errors//:errors", ], ) diff --git a/pkg/executor/internal/vecgroupchecker/vec_group_checker.go b/pkg/executor/internal/vecgroupchecker/vec_group_checker.go index abb070420254e..242eb21009a28 100644 --- a/pkg/executor/internal/vecgroupchecker/vec_group_checker.go +++ b/pkg/executor/internal/vecgroupchecker/vec_group_checker.go @@ -16,8 +16,8 @@ package vecgroupchecker import ( "bytes" - "fmt" + "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" @@ -291,6 +291,27 @@ func (e *VecGroupChecker) getFirstAndLastRowDatum( } else { lastRowDatum.SetNull() } + case types.ETVectorFloat32: + firstRowVal, firstRowIsNull, err := item.EvalVectorFloat32(e.ctx, chk.GetRow(0)) + if err != nil { + return err + } + lastRowVal, lastRowIsNull, err := item.EvalVectorFloat32(e.ctx, chk.GetRow(numRows-1)) + if err != nil { + return err + } + if !firstRowIsNull { + // make a copy to avoid DATA RACE + firstRowDatum.SetVectorFloat32(firstRowVal.Clone()) + } else { + firstRowDatum.SetNull() + } + if !lastRowIsNull { + // make a copy to avoid DATA RACE + lastRowDatum.SetVectorFloat32(lastRowVal.Clone()) + } else { + lastRowDatum.SetNull() + } case types.ETString: firstRowVal, firstRowIsNull, err := item.EvalString(e.ctx, chk.GetRow(0)) if err != nil { @@ -315,7 +336,7 @@ func (e *VecGroupChecker) getFirstAndLastRowDatum( lastRowDatum.SetNull() } default: - err = fmt.Errorf("invalid eval type %v", eType) + err = errors.Errorf("unsupported type %s during evaluation", eType) return err } @@ -452,6 +473,30 @@ func (e *VecGroupChecker) evalGroupItemsAndResolveGroups( } previousIsNull = isNull } + case types.ETVectorFloat32: + var previousKey, key types.VectorFloat32 + if !previousIsNull { + previousKey = col.GetVectorFloat32(0) + } + for i := 1; i < numRows; i++ { + isNull := col.IsNull(i) + if !isNull { + key = col.GetVectorFloat32(i) + } + if e.sameGroup[i] { + if isNull == previousIsNull { + if !isNull && previousKey.Compare(key) != 0 { + e.sameGroup[i] = false + } + } else { + e.sameGroup[i] = false + } + } + if !isNull { + previousKey = key + } + previousIsNull = isNull + } case types.ETString: previousKey := codec.ConvertByCollationStr(col.GetString(0), tp) for i := 1; i < numRows; i++ { @@ -466,7 +511,7 @@ func (e *VecGroupChecker) evalGroupItemsAndResolveGroups( previousIsNull = isNull } default: - err = fmt.Errorf("invalid eval type %v", eType) + err = errors.Errorf("unsupported type %s during evaluation", eType) } if err != nil { return err diff --git a/pkg/executor/reload_expr_pushdown_blacklist.go b/pkg/executor/reload_expr_pushdown_blacklist.go index c751ec3f9623a..65411f9715ec9 100644 --- a/pkg/executor/reload_expr_pushdown_blacklist.go +++ b/pkg/executor/reload_expr_pushdown_blacklist.go @@ -353,4 +353,12 @@ var funcName2Alias = map[string]string{ "json_depth": ast.JSONDepth, "json_keys": ast.JSONKeys, "json_length": ast.JSONLength, + "vec_dims": ast.VecDims, + "vec_l1_distance": ast.VecL1Distance, + "vec_l2_distance": ast.VecL2Distance, + "vec_negative_inner_product": ast.VecNegativeInnerProduct, + "vec_cosine_distance": ast.VecCosineDistance, + "vec_l2_norm": ast.VecL2Norm, + "vec_from_text": ast.VecFromText, + "vec_as_text": ast.VecAsText, } diff --git a/pkg/executor/select_into.go b/pkg/executor/select_into.go index 5161291d1eada..7c8033d4a8f0c 100644 --- a/pkg/executor/select_into.go +++ b/pkg/executor/select_into.go @@ -189,6 +189,8 @@ func (s *SelectIntoExec) dumpToOutfile() error { s.fieldBuf = append(s.fieldBuf, row.GetSet(j).String()...) case mysql.TypeJSON: s.fieldBuf = append(s.fieldBuf, row.GetJSON(j).String()...) + case mysql.TypeTiDBVectorFloat32: + s.fieldBuf = append(s.fieldBuf, row.GetVectorFloat32(j).String()...) } switch col.GetType(s.Ctx().GetExprCtx().GetEvalCtx()).EvalType() { diff --git a/pkg/executor/show.go b/pkg/executor/show.go index 523593d340eb5..e7c41f4f4469e 100644 --- a/pkg/executor/show.go +++ b/pkg/executor/show.go @@ -2007,6 +2007,8 @@ func (e *ShowExec) appendRow(row []any) { e.result.AppendTime(i, x) case types.BinaryJSON: e.result.AppendJSON(i, x) + case types.VectorFloat32: + e.result.AppendVectorFloat32(i, x) case types.Duration: e.result.AppendDuration(i, x) case types.Enum: diff --git a/pkg/expression/BUILD.bazel b/pkg/expression/BUILD.bazel index 10530152578a7..00ec95d759bbe 100644 --- a/pkg/expression/BUILD.bazel +++ b/pkg/expression/BUILD.bazel @@ -43,6 +43,7 @@ go_library( "builtin_time.go", "builtin_time_vec.go", "builtin_time_vec_generated.go", + "builtin_vec.go", "builtin_vectorized.go", "chunk_executor.go", "collation.go", diff --git a/pkg/expression/aggregation/aggregation.go b/pkg/expression/aggregation/aggregation.go index d02625b7865d7..8dae3a85e59e0 100644 --- a/pkg/expression/aggregation/aggregation.go +++ b/pkg/expression/aggregation/aggregation.go @@ -239,6 +239,9 @@ func CheckAggPushDown(ctx expression.EvalContext, aggFunc *AggFuncDesc, storeTyp if aggFunc.Name == ast.AggFuncApproxPercentile { return false } + if checkVectorAggPushDown(ctx, aggFunc) != nil { + return false + } ret := true switch storeType { case kv.TiFlash: @@ -253,6 +256,22 @@ func CheckAggPushDown(ctx expression.EvalContext, aggFunc *AggFuncDesc, storeTyp return ret } +// checkVectorAggPushDown returns error if this aggregate function is not supported to push down. +// - The aggregate function is not calculated over a Vector column (returns nil) +// - The aggregate function is calculated over a Vector column and the function is supported (returns nil) +// - The aggregate function is calculated over a Vector column and the function is not supported (returns error) +func checkVectorAggPushDown(ctx expression.EvalContext, aggFunc *AggFuncDesc) error { + switch aggFunc.Name { + case ast.AggFuncCount, ast.AggFuncMin, ast.AggFuncMax, ast.AggFuncFirstRow: + return nil + default: + if aggFunc.Args[0].GetType(ctx).GetType() == mysql.TypeTiDBVectorFloat32 { + return errors.Errorf("Aggregate function %s is not supported for VectorFloat32", aggFunc.Name) + } + } + return nil +} + // CheckAggPushFlash checks whether an agg function can be pushed to flash storage. func CheckAggPushFlash(ctx expression.EvalContext, aggFunc *AggFuncDesc) bool { for _, arg := range aggFunc.Args { diff --git a/pkg/expression/aggregation/base_func.go b/pkg/expression/aggregation/base_func.go index 1fc9f1ff7cbfd..2bfd2a1209fbb 100644 --- a/pkg/expression/aggregation/base_func.go +++ b/pkg/expression/aggregation/base_func.go @@ -423,8 +423,10 @@ func (a *baseFuncDesc) WrapCastForAggArgs(ctx expression.BuildContext) { castFunc = expression.WrapWithCastAsDuration case types.ETJson: castFunc = expression.WrapWithCastAsJSON + case types.ETVectorFloat32: + castFunc = expression.WrapWithCastAsVectorFloat32 default: - panic("should never happen in baseFuncDesc.WrapCastForAggArgs") + panic(fmt.Sprintf("unsupported type %s during evaluation", retTp.EvalType())) } for i := range a.Args { // Do not cast the second args of these functions, as they are simply non-negative numbers. diff --git a/pkg/expression/builtin.go b/pkg/expression/builtin.go index b38c379677ddf..e01405efda29a 100644 --- a/pkg/expression/builtin.go +++ b/pkg/expression/builtin.go @@ -154,6 +154,8 @@ func newReturnFieldTypeForBaseBuiltinFunc(funcName string, retType types.EvalTyp fieldType = types.NewFieldTypeBuilder().SetType(mysql.TypeDuration).SetFlag(mysql.BinaryFlag).SetFlen(mysql.MaxDurationWidthWithFsp).SetDecimal(types.MaxFsp).BuildP() case types.ETJson: fieldType = types.NewFieldTypeBuilder().SetType(mysql.TypeJSON).SetFlag(mysql.BinaryFlag).SetFlen(mysql.MaxBlobWidth).SetCharset(mysql.DefaultCharset).SetCollate(mysql.DefaultCollationName).BuildP() + case types.ETVectorFloat32: + fieldType = types.NewFieldTypeBuilder().SetType(mysql.TypeTiDBVectorFloat32).SetFlag(mysql.BinaryFlag).SetFlen(types.UnspecifiedLength).BuildP() } if mysql.HasBinaryFlag(fieldType.GetFlag()) && fieldType.GetType() != mysql.TypeJSON { fieldType.SetCharset(charset.CharsetBin) @@ -202,6 +204,10 @@ func newBaseBuiltinFuncWithTp(ctx BuildContext, funcName string, args []Expressi args[i] = WrapWithCastAsDuration(ctx, args[i]) case types.ETJson: args[i] = WrapWithCastAsJSON(ctx, args[i]) + case types.ETVectorFloat32: + args[i] = WrapWithCastAsVectorFloat32(ctx, args[i]) + default: + return baseBuiltinFunc{}, errors.Errorf("%s is not supported", argTps[i]) } } @@ -330,6 +336,10 @@ func (*baseBuiltinFunc) vecEvalJSON(EvalContext, *chunk.Chunk, *chunk.Column) er return errors.Errorf("baseBuiltinFunc.vecEvalJSON() should never be called, please contact the TiDB team for help") } +func (*baseBuiltinFunc) vecEvalVectorFloat32(EvalContext, *chunk.Chunk, *chunk.Column) error { + return errors.Errorf("baseBuiltinFunc.vecEvalVectorFloat32() should never be called, please contact the TiDB team for help") +} + func (*baseBuiltinFunc) evalInt(EvalContext, chunk.Row) (int64, bool, error) { return 0, false, errors.Errorf("baseBuiltinFunc.evalInt() should never be called, please contact the TiDB team for help") } @@ -358,6 +368,10 @@ func (*baseBuiltinFunc) evalJSON(EvalContext, chunk.Row) (types.BinaryJSON, bool return types.BinaryJSON{}, false, errors.Errorf("baseBuiltinFunc.evalJSON() should never be called, please contact the TiDB team for help") } +func (*baseBuiltinFunc) evalVectorFloat32(EvalContext, chunk.Row) (types.VectorFloat32, bool, error) { + return types.ZeroVectorFloat32, false, errors.Errorf("baseBuiltinFunc.evalVectorFloat32() should never be called, please contact the TiDB team for help") +} + func (*baseBuiltinFunc) vectorized() bool { return false } @@ -476,6 +490,9 @@ type vecBuiltinFunc interface { // vecEvalJSON evaluates this builtin function in a vectorized manner. vecEvalJSON(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error + + // vecEvalVectorFloat32 evaluates this builtin function in a vectorized manner. + vecEvalVectorFloat32(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error } // builtinFunc stands for a particular function signature. @@ -497,6 +514,7 @@ type builtinFunc interface { evalDuration(ctx EvalContext, row chunk.Row) (val types.Duration, isNull bool, err error) // evalJSON evaluates JSON representation of builtinFunc by given row. evalJSON(ctx EvalContext, row chunk.Row) (val types.BinaryJSON, isNull bool, err error) + evalVectorFloat32(ctx EvalContext, row chunk.Row) (val types.VectorFloat32, isNull bool, err error) // getArgs returns the arguments expressions. getArgs() []Expression // equal check if this function equals to another function. @@ -910,6 +928,16 @@ var funcs = map[string]functionClass{ ast.JSONKeys: &jsonKeysFunctionClass{baseFunctionClass{ast.JSONKeys, 1, 2}}, ast.JSONLength: &jsonLengthFunctionClass{baseFunctionClass{ast.JSONLength, 1, 2}}, + // vector functions (TiDB extension) + ast.VecDims: &vecDimsFunctionClass{baseFunctionClass{ast.VecDims, 1, 1}}, + ast.VecL1Distance: &vecL1DistanceFunctionClass{baseFunctionClass{ast.VecL1Distance, 2, 2}}, + ast.VecL2Distance: &vecL2DistanceFunctionClass{baseFunctionClass{ast.VecL2Distance, 2, 2}}, + ast.VecNegativeInnerProduct: &vecNegativeInnerProductFunctionClass{baseFunctionClass{ast.VecNegativeInnerProduct, 2, 2}}, + ast.VecCosineDistance: &vecCosineDistanceFunctionClass{baseFunctionClass{ast.VecCosineDistance, 2, 2}}, + ast.VecL2Norm: &vecL2NormFunctionClass{baseFunctionClass{ast.VecL2Norm, 1, 1}}, + ast.VecFromText: &vecFromTextFunctionClass{baseFunctionClass{ast.VecFromText, 1, 1}}, + ast.VecAsText: &vecAsTextFunctionClass{baseFunctionClass{ast.VecAsText, 1, 1}}, + // TiDB internal function. ast.TiDBDecodeKey: &tidbDecodeKeyFunctionClass{baseFunctionClass{ast.TiDBDecodeKey, 1, 1}}, // This function is used to show tidb-server version info. diff --git a/pkg/expression/builtin_arithmetic.go b/pkg/expression/builtin_arithmetic.go index d2cdb64f9b329..3b7034c303315 100644 --- a/pkg/expression/builtin_arithmetic.go +++ b/pkg/expression/builtin_arithmetic.go @@ -57,6 +57,10 @@ var ( _ builtinFunc = &builtinArithmeticModIntSignedSignedSig{} _ builtinFunc = &builtinArithmeticModRealSig{} _ builtinFunc = &builtinArithmeticModDecimalSig{} + + _ builtinFunc = &builtinArithmeticPlusVectorFloat32Sig{} + _ builtinFunc = &builtinArithmeticMinusVectorFloat32Sig{} + _ builtinFunc = &builtinArithmeticMultiplyVectorFloat32Sig{} ) // isConstantBinaryLiteral return true if expr is constant binary literal @@ -167,6 +171,15 @@ func (c *arithmeticPlusFunctionClass) getFunction(ctx BuildContext, args []Expre if err := c.verifyArgs(args); err != nil { return nil, err } + if args[0].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() || args[1].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() { + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETVectorFloat32, types.ETVectorFloat32, types.ETVectorFloat32) + if err != nil { + return nil, err + } + sig := &builtinArithmeticPlusVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_PlusVectorFloat32) + return sig, nil + } lhsEvalTp, rhsEvalTp := numericContextResultType(ctx.GetEvalCtx(), args[0]), numericContextResultType(ctx.GetEvalCtx(), args[1]) if lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal { bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, types.ETReal, types.ETReal) @@ -317,6 +330,15 @@ func (c *arithmeticMinusFunctionClass) getFunction(ctx BuildContext, args []Expr if err := c.verifyArgs(args); err != nil { return nil, err } + if args[0].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() || args[1].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() { + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETVectorFloat32, types.ETVectorFloat32, types.ETVectorFloat32) + if err != nil { + return nil, err + } + sig := &builtinArithmeticMinusVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_PlusVectorFloat32) + return sig, nil + } lhsEvalTp, rhsEvalTp := numericContextResultType(ctx.GetEvalCtx(), args[0]), numericContextResultType(ctx.GetEvalCtx(), args[1]) if lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal { bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, types.ETReal, types.ETReal) @@ -1157,3 +1179,90 @@ func (s *builtinArithmeticModIntSignedSignedSig) evalInt(ctx EvalContext, row ch return a % b, false, nil } + +type builtinArithmeticPlusVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (s *builtinArithmeticPlusVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinArithmeticPlusVectorFloat32Sig{} + newSig.cloneFrom(&s.baseBuiltinFunc) + return newSig +} + +func (s *builtinArithmeticPlusVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) { + a, isLHSNull, err := s.args[0].EvalVectorFloat32(ctx, row) + if err != nil { + return types.ZeroVectorFloat32, isLHSNull, err + } + b, isRHSNull, err := s.args[1].EvalVectorFloat32(ctx, row) + if err != nil { + return types.ZeroVectorFloat32, isRHSNull, err + } + if isLHSNull || isRHSNull { + return types.ZeroVectorFloat32, true, nil + } + v, err := a.Add(b) + if err != nil { + return types.ZeroVectorFloat32, true, err + } + return v, false, nil +} + +type builtinArithmeticMinusVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (s *builtinArithmeticMinusVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinArithmeticMinusVectorFloat32Sig{} + newSig.cloneFrom(&s.baseBuiltinFunc) + return newSig +} + +func (s *builtinArithmeticMinusVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) { + a, isLHSNull, err := s.args[0].EvalVectorFloat32(ctx, row) + if err != nil { + return types.ZeroVectorFloat32, isLHSNull, err + } + b, isRHSNull, err := s.args[1].EvalVectorFloat32(ctx, row) + if err != nil { + return types.ZeroVectorFloat32, isRHSNull, err + } + if isLHSNull || isRHSNull { + return types.ZeroVectorFloat32, true, nil + } + v, err := a.Sub(b) + if err != nil { + return types.ZeroVectorFloat32, true, err + } + return v, false, nil +} + +type builtinArithmeticMultiplyVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (s *builtinArithmeticMultiplyVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinArithmeticMultiplyVectorFloat32Sig{} + newSig.cloneFrom(&s.baseBuiltinFunc) + return newSig +} + +func (s *builtinArithmeticMultiplyVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) { + a, isLHSNull, err := s.args[0].EvalVectorFloat32(ctx, row) + if err != nil { + return types.ZeroVectorFloat32, isLHSNull, err + } + b, isRHSNull, err := s.args[1].EvalVectorFloat32(ctx, row) + if err != nil { + return types.ZeroVectorFloat32, isRHSNull, err + } + if isLHSNull || isRHSNull { + return types.ZeroVectorFloat32, true, nil + } + v, err := a.Mul(b) + if err != nil { + return types.ZeroVectorFloat32, true, err + } + return v, false, nil +} diff --git a/pkg/expression/builtin_cast.go b/pkg/expression/builtin_cast.go index 314629cefdc3a..16f7a4c2f44eb 100644 --- a/pkg/expression/builtin_cast.go +++ b/pkg/expression/builtin_cast.go @@ -107,6 +107,12 @@ var ( _ builtinFunc = &builtinCastJSONAsTimeSig{} _ builtinFunc = &builtinCastJSONAsDurationSig{} _ builtinFunc = &builtinCastJSONAsJSONSig{} + + _ builtinFunc = &builtinCastStringAsVectorFloat32Sig{} + _ builtinFunc = &builtinCastVectorFloat32AsStringSig{} + _ builtinFunc = &builtinCastVectorFloat32AsVectorFloat32Sig{} + _ builtinFunc = &builtinCastUnsupportedAsVectorFloat32Sig{} + _ builtinFunc = &builtinCastVectorFloat32AsUnsupportedSig{} ) type castAsIntFunctionClass struct { @@ -153,8 +159,11 @@ func (c *castAsIntFunctionClass) getFunction(ctx BuildContext, args []Expression case types.ETString: sig = &builtinCastStringAsIntSig{bf} sig.setPbCode(tipb.ScalarFuncSig_CastStringAsInt) + case types.ETVectorFloat32: + sig = &builtinCastVectorFloat32AsUnsupportedSig{bf.baseBuiltinFunc} + // sig.setPbCode(tipb.ScalarFuncSig_CastVectorFloat32AsInt) default: - panic("unsupported types.EvalType in castAsIntFunctionClass") + return nil, errors.Errorf("cannot cast from %s to %s", argTp, "Int") } return sig, nil } @@ -209,8 +218,11 @@ func (c *castAsRealFunctionClass) getFunction(ctx BuildContext, args []Expressio case types.ETString: sig = &builtinCastStringAsRealSig{bf} sig.setPbCode(tipb.ScalarFuncSig_CastStringAsReal) + case types.ETVectorFloat32: + sig = &builtinCastVectorFloat32AsUnsupportedSig{bf.baseBuiltinFunc} + // sig.setPbCode(tipb.ScalarFuncSig_CastVectorFloat32AsReal) default: - panic("unsupported types.EvalType in castAsRealFunctionClass") + return nil, errors.Errorf("cannot cast from %s to %s", argTp, "Real") } return sig, nil } @@ -264,8 +276,11 @@ func (c *castAsDecimalFunctionClass) getFunction(ctx BuildContext, args []Expres case types.ETString: sig = &builtinCastStringAsDecimalSig{bf} sig.setPbCode(tipb.ScalarFuncSig_CastStringAsDecimal) + case types.ETVectorFloat32: + sig = &builtinCastVectorFloat32AsUnsupportedSig{bf.baseBuiltinFunc} + // sig.setPbCode(tipb.ScalarFuncSig_CastVectorFloat32AsDecimal) default: - panic("unsupported types.EvalType in castAsDecimalFunctionClass") + return nil, errors.Errorf("cannot cast from %s to %s", argTp, "Decimal") } return sig, nil } @@ -326,6 +341,9 @@ func (c *castAsStringFunctionClass) getFunction(ctx BuildContext, args []Express case types.ETJson: sig = &builtinCastJSONAsStringSig{bf} sig.setPbCode(tipb.ScalarFuncSig_CastJsonAsString) + case types.ETVectorFloat32: + sig = &builtinCastVectorFloat32AsStringSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_CastVectorFloat32AsString) case types.ETString: // When cast from binary to some other charsets, we should check if the binary is valid or not. // so we build a from_binary function to do this check. @@ -333,7 +351,7 @@ func (c *castAsStringFunctionClass) getFunction(ctx BuildContext, args []Express sig = &builtinCastStringAsStringSig{bf} sig.setPbCode(tipb.ScalarFuncSig_CastStringAsString) default: - panic("unsupported types.EvalType in castAsStringFunctionClass") + return nil, errors.Errorf("cannot cast from %s to %s", argTp, "String") } return sig, nil } @@ -375,8 +393,11 @@ func (c *castAsTimeFunctionClass) getFunction(ctx BuildContext, args []Expressio case types.ETString: sig = &builtinCastStringAsTimeSig{bf} sig.setPbCode(tipb.ScalarFuncSig_CastStringAsTime) + case types.ETVectorFloat32: + sig = &builtinCastVectorFloat32AsUnsupportedSig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_CastVectorFloat32AsTime) default: - panic("unsupported types.EvalType in castAsTimeFunctionClass") + return nil, errors.Errorf("cannot cast from %s to %s", argTp, "Datetime") } return sig, nil } @@ -418,8 +439,11 @@ func (c *castAsDurationFunctionClass) getFunction(ctx BuildContext, args []Expre case types.ETString: sig = &builtinCastStringAsDurationSig{bf} sig.setPbCode(tipb.ScalarFuncSig_CastStringAsDuration) + case types.ETVectorFloat32: + sig = &builtinCastVectorFloat32AsUnsupportedSig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_CastVectorFloat32AsDuration) default: - panic("unsupported types.EvalType in castAsDurationFunctionClass") + return nil, errors.Errorf("cannot cast from %s to %s", argTp, "Time") } return sig, nil } @@ -619,12 +643,175 @@ func (c *castAsJSONFunctionClass) getFunction(ctx BuildContext, args []Expressio sig = &builtinCastStringAsJSONSig{bf} sig.getRetTp().AddFlag(mysql.ParseToJSONFlag) sig.setPbCode(tipb.ScalarFuncSig_CastStringAsJson) + case types.ETVectorFloat32: + sig = &builtinCastVectorFloat32AsUnsupportedSig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_CastVectorFloat32AsJson) + default: + return nil, errors.Errorf("cannot cast from %s to %s", argTp, "Json") + } + return sig, nil +} + +type castAsVectorFloat32FunctionClass struct { + baseFunctionClass + + tp *types.FieldType +} + +func (c *castAsVectorFloat32FunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFunc(ctx, c.funcName, args, c.tp) + if err != nil { + return nil, err + } + argTp := args[0].GetType(ctx.GetEvalCtx()).EvalType() + switch argTp { + case types.ETInt: + sig = &builtinCastUnsupportedAsVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_CastIntAsVectorFloat32) + case types.ETReal: + sig = &builtinCastUnsupportedAsVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_CastRealAsVectorFloat32) + case types.ETDecimal: + sig = &builtinCastUnsupportedAsVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_CastDecimalAsVectorFloat32) + case types.ETDatetime, types.ETTimestamp: + sig = &builtinCastUnsupportedAsVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_CastTimeAsVectorFloat32) + case types.ETDuration: + sig = &builtinCastUnsupportedAsVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_CastDurationAsVectorFloat32) + case types.ETJson: + sig = &builtinCastUnsupportedAsVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_CastJsonAsVectorFloat32) + case types.ETVectorFloat32: + sig = &builtinCastVectorFloat32AsVectorFloat32Sig{bf} + sig.setPbCode(tipb.ScalarFuncSig_CastVectorFloat32AsVectorFloat32) + case types.ETString: + sig = &builtinCastStringAsVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_CastStringAsVectorFloat32) default: - panic("unsupported types.EvalType in castAsJSONFunctionClass") + return nil, errors.Errorf("cannot cast from %s to %s", argTp, "VectorFloat32") } return sig, nil } +type builtinCastUnsupportedAsVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinCastUnsupportedAsVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinCastUnsupportedAsVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinCastUnsupportedAsVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, _ chunk.Row) (res types.VectorFloat32, isNull bool, err error) { + return types.ZeroVectorFloat32, false, errors.Errorf( + "cannot cast from %s to vector", + types.TypeStr(b.args[0].GetType(ctx).GetType())) +} + +type builtinCastVectorFloat32AsUnsupportedSig struct { + baseBuiltinFunc +} + +func (b *builtinCastVectorFloat32AsUnsupportedSig) Clone() builtinFunc { + newSig := &builtinCastVectorFloat32AsUnsupportedSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinCastVectorFloat32AsUnsupportedSig) evalInt(_ EvalContext, _ chunk.Row) (int64, bool, error) { + return 0, false, errors.Errorf( + "cannot cast from vector to %s", + types.TypeStr(b.tp.GetType())) +} + +func (b *builtinCastVectorFloat32AsUnsupportedSig) evalReal(_ EvalContext, _ chunk.Row) (float64, bool, error) { + return 0, false, errors.Errorf( + "cannot cast from vector to %s", + types.TypeStr(b.tp.GetType())) +} + +func (b *builtinCastVectorFloat32AsUnsupportedSig) evalDecimal(_ EvalContext, _ chunk.Row) (*types.MyDecimal, bool, error) { + return nil, false, errors.Errorf( + "cannot cast from vector to %s", + types.TypeStr(b.tp.GetType())) +} + +func (b *builtinCastVectorFloat32AsUnsupportedSig) evalString(_ EvalContext, _ chunk.Row) (string, bool, error) { + return "", false, errors.Errorf( + "cannot cast from vector to %s", + types.TypeStr(b.tp.GetType())) +} + +func (b *builtinCastVectorFloat32AsUnsupportedSig) evalTime(_ EvalContext, _ chunk.Row) (types.Time, bool, error) { + return types.ZeroTime, false, errors.Errorf( + "cannot cast from vector to %s", + types.TypeStr(b.tp.GetType())) +} + +func (b *builtinCastVectorFloat32AsUnsupportedSig) evalDuration(_ EvalContext, _ chunk.Row) (types.Duration, bool, error) { + return types.ZeroDuration, false, errors.Errorf( + "cannot cast from vector to %s", + types.TypeStr(b.tp.GetType())) +} + +func (b *builtinCastVectorFloat32AsUnsupportedSig) evalJSON(_ EvalContext, _ chunk.Row) (types.BinaryJSON, bool, error) { + return types.BinaryJSON{}, false, errors.Errorf( + "cannot cast from vector to %s", + types.TypeStr(b.tp.GetType())) +} + +type builtinCastStringAsVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinCastStringAsVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinCastStringAsVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinCastStringAsVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) { + val, isNull, err := b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return types.ZeroVectorFloat32, isNull, err + } + vec, err := types.ParseVectorFloat32(val) + if err != nil { + return types.ZeroVectorFloat32, false, err + } + if err = vec.CheckDimsFitColumn(b.tp.GetFlen()); err != nil { + return types.ZeroVectorFloat32, isNull, err + } + return vec, false, nil +} + +type builtinCastVectorFloat32AsVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinCastVectorFloat32AsVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinCastVectorFloat32AsVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinCastVectorFloat32AsVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) { + val, isNull, err := b.args[0].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return types.ZeroVectorFloat32, isNull, err + } + if err = val.CheckDimsFitColumn(b.tp.GetFlen()); err != nil { + return types.ZeroVectorFloat32, isNull, err + } + return val, false, nil +} + type builtinCastIntAsIntSig struct { baseBuiltinCastFunc } @@ -1912,6 +2099,28 @@ func (b *builtinCastJSONAsStringSig) evalString(ctx EvalContext, row chunk.Row) return s, false, nil } +type builtinCastVectorFloat32AsStringSig struct { + baseBuiltinFunc +} + +func (b *builtinCastVectorFloat32AsStringSig) Clone() builtinFunc { + newSig := &builtinCastVectorFloat32AsStringSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinCastVectorFloat32AsStringSig) evalString(ctx EvalContext, row chunk.Row) (res string, isNull bool, err error) { + val, isNull, err := b.args[0].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + s, err := types.ProduceStrWithSpecifiedTp(val.String(), b.tp, typeCtx(ctx), false) + if err != nil { + return res, false, err + } + return s, false, nil +} + type builtinCastJSONAsTimeSig struct { baseBuiltinFunc } @@ -2124,11 +2333,15 @@ func BuildCastFunctionWithCheck(ctx BuildContext, expr Expression, tp *types.Fie } else { fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} } + case types.ETVectorFloat32: + fc = &castAsVectorFloat32FunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} case types.ETString: fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} if expr.GetType(ctx.GetEvalCtx()).GetType() == mysql.TypeBit { tp.SetFlen((expr.GetType(ctx.GetEvalCtx()).GetFlen() + 7) / 8) } + default: + return nil, errors.Errorf("cannot cast from %s", tp.EvalType()) } f, err := fc.getFunction(ctx, []Expression{expr}) res = &ScalarFunction{ @@ -2347,6 +2560,16 @@ func WrapWithCastAsJSON(ctx BuildContext, expr Expression) Expression { return BuildCastFunction(ctx, expr, tp) } +// WrapWithCastAsVectorFloat32 wraps `expr` with `cast` if the return type of expr is not +// type VectorFloat32, otherwise, returns `expr` directly. +func WrapWithCastAsVectorFloat32(ctx BuildContext, expr Expression) Expression { + if expr.GetType(ctx.GetEvalCtx()).GetType() == mysql.TypeTiDBVectorFloat32 { + return expr + } + tp := types.NewFieldType(mysql.TypeTiDBVectorFloat32) + return BuildCastFunction(ctx, expr, tp) +} + // TryPushCastIntoControlFunctionForHybridType try to push cast into control function for Hybrid Type. // If necessary, it will rebuild control function using changed args. // When a hybrid type is the output of a control function, the result may be as a numeric type to subsequent calculation diff --git a/pkg/expression/builtin_compare.go b/pkg/expression/builtin_compare.go index 55a44dff281af..449d17841d063 100644 --- a/pkg/expression/builtin_compare.go +++ b/pkg/expression/builtin_compare.go @@ -46,6 +46,7 @@ var ( _ builtinFunc = &builtinCoalesceStringSig{} _ builtinFunc = &builtinCoalesceTimeSig{} _ builtinFunc = &builtinCoalesceDurationSig{} + _ builtinFunc = &builtinCoalesceVectorFloat32Sig{} _ builtinFunc = &builtinGreatestIntSig{} _ builtinFunc = &builtinGreatestRealSig{} @@ -54,6 +55,7 @@ var ( _ builtinFunc = &builtinGreatestDurationSig{} _ builtinFunc = &builtinGreatestTimeSig{} _ builtinFunc = &builtinGreatestCmpStringAsTimeSig{} + _ builtinFunc = &builtinGreatestVectorFloat32Sig{} _ builtinFunc = &builtinLeastIntSig{} _ builtinFunc = &builtinLeastRealSig{} _ builtinFunc = &builtinLeastDecimalSig{} @@ -63,6 +65,7 @@ var ( _ builtinFunc = &builtinLeastCmpStringAsTimeSig{} _ builtinFunc = &builtinIntervalIntSig{} _ builtinFunc = &builtinIntervalRealSig{} + _ builtinFunc = &builtinLeastVectorFloat32Sig{} _ builtinFunc = &builtinLTIntSig{} _ builtinFunc = &builtinLTRealSig{} @@ -167,6 +170,11 @@ func (c *coalesceFunctionClass) getFunction(ctx BuildContext, args []Expression) case types.ETJson: sig = &builtinCoalesceJSONSig{bf} sig.setPbCode(tipb.ScalarFuncSig_CoalesceJson) + case types.ETVectorFloat32: + sig = &builtinCoalesceVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_CoalesceVectorFloat32) + default: + return nil, errors.Errorf("%s is not supported for COALESCE()", retEvalTp) } return sig, nil @@ -329,6 +337,28 @@ func (b *builtinCoalesceJSONSig) evalJSON(ctx EvalContext, row chunk.Row) (res t return res, isNull, err } +// builtinCoalesceVectorFloat32Sig is builtin function coalesce signature which return type vector float32. +// See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_coalesce +type builtinCoalesceVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinCoalesceVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinCoalesceVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinCoalesceVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (res types.VectorFloat32, isNull bool, err error) { + for _, a := range b.getArgs() { + res, isNull, err = a.EvalVectorFloat32(ctx, row) + if err != nil || !isNull { + break + } + } + return res, isNull, err +} + func aggregateType(ctx EvalContext, args []Expression) *types.FieldType { fieldTypes := make([]*types.FieldType, len(args)) for i := range fieldTypes { @@ -497,6 +527,11 @@ func (c *greatestFunctionClass) getFunction(ctx BuildContext, args []Expression) sig = &builtinGreatestTimeSig{bf, false} sig.setPbCode(tipb.ScalarFuncSig_GreatestTime) } + case types.ETVectorFloat32: + sig = &builtinGreatestVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_GreatestVectorFloat32) + default: + return nil, errors.Errorf("unsupported type %s during evaluation", argTp) } flen, decimal := fixFlenAndDecimalForGreatestAndLeast(ctx.GetEvalCtx(), args) @@ -750,6 +785,29 @@ func (b *builtinGreatestDurationSig) evalDuration(ctx EvalContext, row chunk.Row return res, false, nil } +type builtinGreatestVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinGreatestVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinGreatestVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinGreatestVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (res types.VectorFloat32, isNull bool, err error) { + for i := 0; i < len(b.args); i++ { + v, isNull, err := b.args[i].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return types.VectorFloat32{}, true, err + } + if i == 0 || v.Compare(res) > 0 { + res = v + } + } + return res, false, nil +} + type leastFunctionClass struct { baseFunctionClass } @@ -810,6 +868,11 @@ func (c *leastFunctionClass) getFunction(ctx BuildContext, args []Expression) (s sig = &builtinLeastTimeSig{bf, false} sig.setPbCode(tipb.ScalarFuncSig_LeastTime) } + case types.ETVectorFloat32: + sig = &builtinLeastVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_LeastVectorFloat32) + default: + return nil, errors.Errorf("unsupported type %s during evaluation", argTp) } flen, decimal := fixFlenAndDecimalForGreatestAndLeast(ctx.GetEvalCtx(), args) sig.getRetTp().SetFlenUnderLimit(flen) @@ -1033,6 +1096,29 @@ func (b *builtinLeastDurationSig) evalDuration(ctx EvalContext, row chunk.Row) ( return res, false, nil } +type builtinLeastVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinLeastVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinLeastVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinLeastVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (res types.VectorFloat32, isNull bool, err error) { + for i := 0; i < len(b.args); i++ { + v, isNull, err := b.args[i].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return types.VectorFloat32{}, true, err + } + if i == 0 || v.Compare(res) < 0 { + res = v + } + } + return res, false, nil +} + type intervalFunctionClass struct { baseFunctionClass } @@ -1288,7 +1374,9 @@ func GetAccurateCmpType(ctx EvalContext, lhs, rhs Expression) types.EvalType { lhsFieldType, rhsFieldType := lhs.GetType(ctx), rhs.GetType(ctx) lhsEvalType, rhsEvalType := lhsFieldType.EvalType(), rhsFieldType.EvalType() cmpType := getBaseCmpType(lhsEvalType, rhsEvalType, lhsFieldType, rhsFieldType) - if (lhsEvalType.IsStringKind() && lhsFieldType.GetType() == mysql.TypeJSON) || (rhsEvalType.IsStringKind() && rhsFieldType.GetType() == mysql.TypeJSON) { + if lhsEvalType == types.ETVectorFloat32 || rhsEvalType == types.ETVectorFloat32 { + cmpType = types.ETVectorFloat32 + } else if (lhsEvalType.IsStringKind() && lhsFieldType.GetType() == mysql.TypeJSON) || (rhsEvalType.IsStringKind() && rhsFieldType.GetType() == mysql.TypeJSON) { cmpType = types.ETJson } else if cmpType == types.ETString && (types.IsTypeTime(lhsFieldType.GetType()) || types.IsTypeTime(rhsFieldType.GetType())) { // date[time] date[time] @@ -1355,8 +1443,11 @@ func GetCmpFunction(ctx BuildContext, lhs, rhs Expression) CompareFunc { return CompareTime case types.ETJson: return CompareJSON + case types.ETVectorFloat32: + return CompareVectorFloat32 + default: + panic(fmt.Sprintf("cannot compare with %s", GetAccurateCmpType(ctx.GetEvalCtx(), lhs, rhs))) } - return nil } // isTemporalColumn checks if a expression is a temporal column, @@ -1958,6 +2049,32 @@ func (c *compareFunctionClass) generateCmpSigs(ctx BuildContext, args []Expressi sig = &builtinNullEQJSONSig{bf} sig.setPbCode(tipb.ScalarFuncSig_NullEQJson) } + case types.ETVectorFloat32: + switch c.op { + case opcode.LT: + sig = &builtinLTVectorFloat32Sig{bf} + sig.setPbCode(tipb.ScalarFuncSig_LTVectorFloat32) + case opcode.LE: + sig = &builtinLEVectorFloat32Sig{bf} + sig.setPbCode(tipb.ScalarFuncSig_LEVectorFloat32) + case opcode.GT: + sig = &builtinGTVectorFloat32Sig{bf} + sig.setPbCode(tipb.ScalarFuncSig_GTVectorFloat32) + case opcode.GE: + sig = &builtinGEVectorFloat32Sig{bf} + sig.setPbCode(tipb.ScalarFuncSig_GEVectorFloat32) + case opcode.EQ: + sig = &builtinEQVectorFloat32Sig{bf} + sig.setPbCode(tipb.ScalarFuncSig_EQVectorFloat32) + case opcode.NE: + sig = &builtinNEVectorFloat32Sig{bf} + sig.setPbCode(tipb.ScalarFuncSig_NEVectorFloat32) + case opcode.NullEQ: + sig = &builtinNullEQVectorFloat32Sig{bf} + sig.setPbCode(tipb.ScalarFuncSig_NullEQVectorFloat32) + } + default: + return nil, errors.Errorf("operator %s is not supported for %s", c.op, tp) } return } @@ -2060,6 +2177,20 @@ func (b *builtinLTJSONSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, i return resOfLT(CompareJSON(ctx, b.args[0], b.args[1], row, row)) } +type builtinLTVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinLTVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinLTVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinLTVectorFloat32Sig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) { + return resOfLT(CompareVectorFloat32(ctx, b.args[0], b.args[1], row, row)) +} + type builtinLEIntSig struct { baseBuiltinFunc } @@ -2158,6 +2289,20 @@ func (b *builtinLEJSONSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, i return resOfLE(CompareJSON(ctx, b.args[0], b.args[1], row, row)) } +type builtinLEVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinLEVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinLEVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinLEVectorFloat32Sig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) { + return resOfLE(CompareVectorFloat32(ctx, b.args[0], b.args[1], row, row)) +} + type builtinGTIntSig struct { baseBuiltinFunc } @@ -2256,6 +2401,20 @@ func (b *builtinGTJSONSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, i return resOfGT(CompareJSON(ctx, b.args[0], b.args[1], row, row)) } +type builtinGTVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinGTVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinGTVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinGTVectorFloat32Sig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) { + return resOfGT(CompareVectorFloat32(ctx, b.args[0], b.args[1], row, row)) +} + type builtinGEIntSig struct { baseBuiltinFunc } @@ -2354,6 +2513,20 @@ func (b *builtinGEJSONSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, i return resOfGE(CompareJSON(ctx, b.args[0], b.args[1], row, row)) } +type builtinGEVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinGEVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinGEVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinGEVectorFloat32Sig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) { + return resOfGE(CompareVectorFloat32(ctx, b.args[0], b.args[1], row, row)) +} + type builtinEQIntSig struct { baseBuiltinFunc } @@ -2452,6 +2625,20 @@ func (b *builtinEQJSONSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, i return resOfEQ(CompareJSON(ctx, b.args[0], b.args[1], row, row)) } +type builtinEQVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinEQVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinEQVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinEQVectorFloat32Sig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) { + return resOfEQ(CompareVectorFloat32(ctx, b.args[0], b.args[1], row, row)) +} + type builtinNEIntSig struct { baseBuiltinFunc } @@ -2550,6 +2737,20 @@ func (b *builtinNEJSONSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, i return resOfNE(CompareJSON(ctx, b.args[0], b.args[1], row, row)) } +type builtinNEVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinNEVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinNEVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinNEVectorFloat32Sig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) { + return resOfNE(CompareVectorFloat32(ctx, b.args[0], b.args[1], row, row)) +} + type builtinNullEQIntSig struct { baseBuiltinFunc } @@ -2773,6 +2974,40 @@ func (b *builtinNullEQJSONSig) evalInt(ctx EvalContext, row chunk.Row) (val int6 return res, false, nil } +type builtinNullEQVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinNullEQVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinNullEQVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinNullEQVectorFloat32Sig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) { + arg0, isNull0, err := b.args[0].EvalVectorFloat32(ctx, row) + if err != nil { + return 0, true, err + } + arg1, isNull1, err := b.args[1].EvalVectorFloat32(ctx, row) + if err != nil { + return 0, true, err + } + var res int64 + switch { + case isNull0 && isNull1: + res = 1 + case isNull0 != isNull1: + return res, false, nil + default: + cmpRes := arg0.Compare(arg1) + if cmpRes == 0 { + res = 1 + } + } + return res, false, nil +} + func resOfLT(val int64, isNull bool, err error) (int64, bool, error) { if isNull || err != nil { return 0, isNull, err @@ -2996,3 +3231,21 @@ func CompareJSON(sctx EvalContext, lhsArg, rhsArg Expression, lhsRow, rhsRow chu } return int64(types.CompareBinaryJSON(arg0, arg1)), false, nil } + +// CompareVectorFloat32 compares two float32 vectors. +func CompareVectorFloat32(sctx EvalContext, lhsArg, rhsArg Expression, lhsRow, rhsRow chunk.Row) (int64, bool, error) { + arg0, isNull0, err := lhsArg.EvalVectorFloat32(sctx, lhsRow) + if err != nil { + return 0, true, err + } + + arg1, isNull1, err := rhsArg.EvalVectorFloat32(sctx, rhsRow) + if err != nil { + return 0, true, err + } + + if isNull0 || isNull1 { + return compareNull(isNull0, isNull1), true, nil + } + return int64(arg0.Compare(arg1)), false, nil +} diff --git a/pkg/expression/builtin_control.go b/pkg/expression/builtin_control.go index 4dbafb2ced5fd..099d0b7443249 100644 --- a/pkg/expression/builtin_control.go +++ b/pkg/expression/builtin_control.go @@ -15,6 +15,7 @@ package expression import ( + "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/types" @@ -37,6 +38,7 @@ var ( _ builtinFunc = &builtinCaseWhenTimeSig{} _ builtinFunc = &builtinCaseWhenDurationSig{} _ builtinFunc = &builtinCaseWhenJSONSig{} + _ builtinFunc = &builtinCaseWhenVectorFloat32Sig{} _ builtinFunc = &builtinIfNullIntSig{} _ builtinFunc = &builtinIfNullRealSig{} _ builtinFunc = &builtinIfNullDecimalSig{} @@ -44,6 +46,7 @@ var ( _ builtinFunc = &builtinIfNullTimeSig{} _ builtinFunc = &builtinIfNullDurationSig{} _ builtinFunc = &builtinIfNullJSONSig{} + _ builtinFunc = &builtinIfNullVectorFloat32Sig{} _ builtinFunc = &builtinIfIntSig{} _ builtinFunc = &builtinIfRealSig{} _ builtinFunc = &builtinIfDecimalSig{} @@ -51,6 +54,7 @@ var ( _ builtinFunc = &builtinIfTimeSig{} _ builtinFunc = &builtinIfDurationSig{} _ builtinFunc = &builtinIfJSONSig{} + _ builtinFunc = &builtinIfVectorFloat32Sig{} ) func maxlen(lhsFlen, rhsFlen int) int { @@ -372,6 +376,11 @@ func (c *caseWhenFunctionClass) getFunction(ctx BuildContext, args []Expression) case types.ETJson: sig = &builtinCaseWhenJSONSig{bf} sig.setPbCode(tipb.ScalarFuncSig_CaseWhenJson) + case types.ETVectorFloat32: + sig = &builtinCaseWhenVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_CaseWhenVectorFloat32) + default: + return nil, errors.Errorf("%s is not supported for CASE WHEN", tp) } return sig, nil } @@ -626,6 +635,40 @@ func (b *builtinCaseWhenJSONSig) evalJSON(ctx EvalContext, row chunk.Row) (ret t return ret, true, nil } +type builtinCaseWhenVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinCaseWhenVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinCaseWhenVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalVectorFloat32 evals a builtinCaseWhenVectorFloat32Sig. +// See https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#operator_case +func (b *builtinCaseWhenVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (ret types.VectorFloat32, isNull bool, err error) { + var condition int64 + args, l := b.getArgs(), len(b.getArgs()) + for i := 0; i < l-1; i += 2 { + condition, isNull, err = args[i].EvalInt(ctx, row) + if err != nil { + return + } + if isNull || condition == 0 { + continue + } + return args[i+1].EvalVectorFloat32(ctx, row) + } + // when clause(condition, result) -> args[i], args[i+1]; (i >= 0 && i+1 < l-1) + // else clause -> args[l-1] + // If case clause has else clause, l%2 == 1. + if l%2 == 1 { + return args[l-1].EvalVectorFloat32(ctx, row) + } + return ret, true, nil +} + type ifFunctionClass struct { baseFunctionClass } @@ -673,6 +716,11 @@ func (c *ifFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig case types.ETJson: sig = &builtinIfJSONSig{bf} sig.setPbCode(tipb.ScalarFuncSig_IfJson) + case types.ETVectorFloat32: + sig = &builtinIfVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_IfVectorFloat32) + default: + return nil, errors.Errorf("%s is not supported for IF()", evalTps) } return sig, nil } @@ -824,6 +872,27 @@ func (b *builtinIfJSONSig) evalJSON(ctx EvalContext, row chunk.Row) (ret types.B return b.args[2].EvalJSON(ctx, row) } +type builtinIfVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinIfVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinIfVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinIfVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (ret types.VectorFloat32, isNull bool, err error) { + arg0, isNull0, err := b.args[0].EvalInt(ctx, row) + if err != nil { + return ret, true, err + } + if !isNull0 && arg0 != 0 { + return b.args[1].EvalVectorFloat32(ctx, row) + } + return b.args[2].EvalVectorFloat32(ctx, row) +} + type ifNullFunctionClass struct { baseFunctionClass } @@ -873,6 +942,11 @@ func (c *ifNullFunctionClass) getFunction(ctx BuildContext, args []Expression) ( case types.ETJson: sig = &builtinIfNullJSONSig{bf} sig.setPbCode(tipb.ScalarFuncSig_IfNullJson) + case types.ETVectorFloat32: + sig = &builtinIfNullVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_IfNullVectorFloat32) + default: + return nil, errors.Errorf("%s is not supported for IFNULL()", evalTps) } return sig, nil } @@ -1009,3 +1083,22 @@ func (b *builtinIfNullJSONSig) evalJSON(ctx EvalContext, row chunk.Row) (types.B arg1, isNull, err := b.args[1].EvalJSON(ctx, row) return arg1, isNull || err != nil, err } + +type builtinIfNullVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinIfNullVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinIfNullVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinIfNullVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) { + arg0, isNull, err := b.args[0].EvalVectorFloat32(ctx, row) + if !isNull { + return arg0, err != nil, err + } + arg1, isNull, err := b.args[1].EvalVectorFloat32(ctx, row) + return arg1, isNull || err != nil, err +} diff --git a/pkg/expression/builtin_info.go b/pkg/expression/builtin_info.go index e0dc23eb125dc..98f1780273518 100644 --- a/pkg/expression/builtin_info.go +++ b/pkg/expression/builtin_info.go @@ -735,7 +735,7 @@ func (b *builtinBenchmarkSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bo } } default: // Should never go into here. - return 0, true, errors.Errorf("EvalType %v not implemented for builtin BENCHMARK()", evalType) + return 0, true, errors.Errorf("%s is not supported for BENCHMARK()", evalType) } // Return value of BENCHMARK() is always 0. diff --git a/pkg/expression/builtin_info_vec.go b/pkg/expression/builtin_info_vec.go index b4755d2f4bec7..5887ad6475b6c 100644 --- a/pkg/expression/builtin_info_vec.go +++ b/pkg/expression/builtin_info_vec.go @@ -320,7 +320,7 @@ func (b *builtinBenchmarkSig) vecEvalInt(ctx EvalContext, input *chunk.Chunk, re } } default: // Should never go into here. - return errors.Errorf("EvalType %v not implemented for builtin BENCHMARK()", evalType) + return errors.Errorf("%s is not supported for BENCHMARK()", evalType) } // Return value of BENCHMARK() is always 0. diff --git a/pkg/expression/builtin_miscellaneous.go b/pkg/expression/builtin_miscellaneous.go index 9477b50f0f55d..859932f3ea8b7 100644 --- a/pkg/expression/builtin_miscellaneous.go +++ b/pkg/expression/builtin_miscellaneous.go @@ -360,6 +360,9 @@ func (c *anyValueFunctionClass) getFunction(ctx BuildContext, args []Expression) case types.ETJson: sig = &builtinJSONAnyValueSig{bf} sig.setPbCode(tipb.ScalarFuncSig_JSONAnyValue) + case types.ETVectorFloat32: + sig = &builtinVectorFloat32AnyValueSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_VectorFloat32AnyValue) case types.ETReal: sig = &builtinRealAnyValueSig{bf} sig.setPbCode(tipb.ScalarFuncSig_RealAnyValue) @@ -374,7 +377,7 @@ func (c *anyValueFunctionClass) getFunction(ctx BuildContext, args []Expression) sig = &builtinTimeAnyValueSig{bf} sig.setPbCode(tipb.ScalarFuncSig_TimeAnyValue) default: - return nil, errIncorrectArgs.GenWithStackByArgs("ANY_VALUE") + return nil, errors.Errorf("%s is not supported for ANY_VALUE()", argTp) } return sig, nil } @@ -443,6 +446,20 @@ func (b *builtinJSONAnyValueSig) evalJSON(ctx EvalContext, row chunk.Row) (types return b.args[0].EvalJSON(ctx, row) } +type builtinVectorFloat32AnyValueSig struct { + baseBuiltinFunc +} + +func (b *builtinVectorFloat32AnyValueSig) Clone() builtinFunc { + newSig := &builtinVectorFloat32AnyValueSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinVectorFloat32AnyValueSig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) { + return b.args[0].EvalVectorFloat32(ctx, row) +} + type builtinRealAnyValueSig struct { baseBuiltinFunc } @@ -1162,6 +1179,8 @@ func (c *nameConstFunctionClass) getFunction(ctx BuildContext, args []Expression sig = &builtinNameConstIntSig{bf} case types.ETJson: sig = &builtinNameConstJSONSig{bf} + case types.ETVectorFloat32: + sig = &builtinNameConstVectorFloat32Sig{bf} case types.ETReal: sig = &builtinNameConstRealSig{bf} case types.ETString: @@ -1173,7 +1192,7 @@ func (c *nameConstFunctionClass) getFunction(ctx BuildContext, args []Expression bf.tp.SetFlag(0) sig = &builtinNameConstTimeSig{bf} default: - return nil, errIncorrectArgs.GenWithStackByArgs("NAME_CONST") + return nil, errors.Errorf("%s is not supported for NAME_CONST()", argTp) } return sig, nil } @@ -1248,6 +1267,20 @@ func (b *builtinNameConstJSONSig) evalJSON(ctx EvalContext, row chunk.Row) (type return b.args[1].EvalJSON(ctx, row) } +type builtinNameConstVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinNameConstVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinNameConstVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinNameConstVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) { + return b.args[1].EvalVectorFloat32(ctx, row) +} + type builtinNameConstDurationSig struct { baseBuiltinFunc } diff --git a/pkg/expression/builtin_op.go b/pkg/expression/builtin_op.go index 8f09c0f89ac80..74622619861e0 100644 --- a/pkg/expression/builtin_op.go +++ b/pkg/expression/builtin_op.go @@ -486,6 +486,13 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx BuildContext, args []Expres } else { sig.setPbCode(tipb.ScalarFuncSig_IntIsTrue) } + case types.ETVectorFloat32: + sig = &builtinVectorFloat32IsTrueSig{bf, c.keepNull} + // if c.keepNull { + // sig.setPbCode(tipb.ScalarFuncSig_VectorFloat32IsTrueWithNull) + // } else { + // sig.setPbCode(tipb.ScalarFuncSig_VectorFloat32IsTrue) + // } default: return nil, errors.Errorf("unexpected types.EvalType %v", argTp) } @@ -512,6 +519,13 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx BuildContext, args []Expres } else { sig.setPbCode(tipb.ScalarFuncSig_IntIsFalse) } + case types.ETVectorFloat32: + sig = &builtinVectorFloat32IsFalseSig{bf, c.keepNull} + // if c.keepNull { + // sig.setPbCode(tipb.ScalarFuncSig_VectorFloat32IsFalseWithNull) + // } else { + // sig.setPbCode(tipb.ScalarFuncSig_VectorFloat32IsFalse) + // } default: return nil, errors.Errorf("unexpected types.EvalType %v", argTp) } @@ -594,6 +608,31 @@ func (b *builtinIntIsTrueSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bo return 1, false, nil } +type builtinVectorFloat32IsTrueSig struct { + baseBuiltinFunc + keepNull bool +} + +func (b *builtinVectorFloat32IsTrueSig) Clone() builtinFunc { + newSig := &builtinVectorFloat32IsTrueSig{keepNull: b.keepNull} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinVectorFloat32IsTrueSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + input, isNull, err := b.args[0].EvalVectorFloat32(ctx, row) + if err != nil { + return 0, true, err + } + if b.keepNull && isNull { + return 0, true, nil + } + if isNull || input.IsZeroValue() { + return 0, false, nil + } + return 1, false, nil +} + type builtinRealIsFalseSig struct { baseBuiltinFunc keepNull bool @@ -669,6 +708,31 @@ func (b *builtinIntIsFalseSig) evalInt(ctx EvalContext, row chunk.Row) (int64, b return 1, false, nil } +type builtinVectorFloat32IsFalseSig struct { + baseBuiltinFunc + keepNull bool +} + +func (b *builtinVectorFloat32IsFalseSig) Clone() builtinFunc { + newSig := &builtinVectorFloat32IsFalseSig{keepNull: b.keepNull} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinVectorFloat32IsFalseSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + input, isNull, err := b.args[0].EvalVectorFloat32(ctx, row) + if err != nil { + return 0, true, err + } + if b.keepNull && isNull { + return 0, true, nil + } + if isNull || !input.IsZeroValue() { + return 0, false, nil + } + return 1, false, nil +} + type bitNegFunctionClass struct { baseFunctionClass } @@ -743,7 +807,7 @@ func (c *unaryNotFunctionClass) getFunction(ctx BuildContext, args []Expression) sig = &builtinUnaryNotJSONSig{bf} sig.setPbCode(tipb.ScalarFuncSig_UnaryNotJSON) default: - return nil, errors.Errorf("unexpected types.EvalType %v", argTp) + return nil, errors.Errorf("%s is not supported for unary not operator", argTp) } return sig, nil } @@ -1043,8 +1107,11 @@ func (c *isNullFunctionClass) getFunction(ctx BuildContext, args []Expression) ( case types.ETString: sig = &builtinStringIsNullSig{bf} sig.setPbCode(tipb.ScalarFuncSig_StringIsNull) + case types.ETVectorFloat32: + sig = &builtinVectorFloat32IsNullSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_VectorFloat32IsNull) default: - panic("unexpected types.EvalType") + return nil, errors.Errorf("%s is not supported for ISNULL()", argTp) } return sig, nil } @@ -1134,6 +1201,21 @@ func (b *builtinStringIsNullSig) evalInt(ctx EvalContext, row chunk.Row) (int64, return evalIsNull(isNull, err) } +type builtinVectorFloat32IsNullSig struct { + baseBuiltinFunc +} + +func (b *builtinVectorFloat32IsNullSig) Clone() builtinFunc { + newSig := &builtinVectorFloat32IsNullSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinVectorFloat32IsNullSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + _, isNull, err := b.args[0].EvalVectorFloat32(ctx, row) + return evalIsNull(isNull, err) +} + type builtinTimeIsNullSig struct { baseBuiltinFunc } diff --git a/pkg/expression/builtin_other.go b/pkg/expression/builtin_other.go index 14e0b7600ef19..aa6fd599f7e5b 100644 --- a/pkg/expression/builtin_other.go +++ b/pkg/expression/builtin_other.go @@ -57,6 +57,7 @@ var ( _ builtinFunc = &builtinInTimeSig{} _ builtinFunc = &builtinInDurationSig{} _ builtinFunc = &builtinInJSONSig{} + _ builtinFunc = &builtinInVectorFloat32Sig{} _ builtinFunc = &builtinRowSig{} _ builtinFunc = &builtinSetStringVarSig{} _ builtinFunc = &builtinSetIntVarSig{} @@ -153,6 +154,11 @@ func (c *inFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig case types.ETJson: sig = &builtinInJSONSig{baseBuiltinFunc: bf} sig.setPbCode(tipb.ScalarFuncSig_InJson) + case types.ETVectorFloat32: + sig = &builtinInVectorFloat32Sig{baseBuiltinFunc: bf} + // sig.setPbCode(tipb.ScalarFuncSig_InVectorFloat32) + default: + return nil, errors.Errorf("%s is not supported for IN()", args[0].GetType(ctx.GetEvalCtx()).EvalType()) } return sig, nil } @@ -681,6 +687,39 @@ func (b *builtinInJSONSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, return 0, hasNull, nil } +type builtinInVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinInVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinInVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinInVectorFloat32Sig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + arg0, isNull0, err := b.args[0].EvalVectorFloat32(ctx, row) + if isNull0 || err != nil { + return 0, isNull0, err + } + var hasNull bool + for _, arg := range b.args[1:] { + evaledArg, isNull, err := arg.EvalVectorFloat32(ctx, row) + if err != nil { + return 0, true, err + } + if isNull { + hasNull = true + continue + } + result := arg0.Compare(evaledArg) + if result == 0 { + return 1, false, nil + } + } + return 0, hasNull, nil +} + type rowFunctionClass struct { baseFunctionClass } @@ -1265,6 +1304,8 @@ func (c *valuesFunctionClass) getFunction(ctx BuildContext, args []Expression) ( sig = &builtinValuesDurationSig{baseBuiltinFunc: bf, offset: c.offset} case types.ETJson: sig = &builtinValuesJSONSig{baseBuiltinFunc: bf, offset: c.offset} + default: + return nil, errors.Errorf("%s is not supported for VALUES()", c.tp.EvalType()) } return sig, nil } diff --git a/pkg/expression/builtin_vec.go b/pkg/expression/builtin_vec.go new file mode 100644 index 0000000000000..e1d8374461828 --- /dev/null +++ b/pkg/expression/builtin_vec.go @@ -0,0 +1,422 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "math" + + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tipb/go-tipb" +) + +var ( + _ functionClass = &vecDimsFunctionClass{} + _ functionClass = &vecL1DistanceFunctionClass{} + _ functionClass = &vecL2DistanceFunctionClass{} + _ functionClass = &vecNegativeInnerProductFunctionClass{} + _ functionClass = &vecCosineDistanceFunctionClass{} + _ functionClass = &vecL2NormFunctionClass{} + _ functionClass = &vecFromTextFunctionClass{} + _ functionClass = &vecAsTextFunctionClass{} +) + +var ( + _ builtinFunc = &builtinVecDimsSig{} + _ builtinFunc = &builtinVecL1DistanceSig{} + _ builtinFunc = &builtinVecL2DistanceSig{} + _ builtinFunc = &builtinVecNegativeInnerProductSig{} + _ builtinFunc = &builtinVecCosineDistanceSig{} + _ builtinFunc = &builtinVecL2NormSig{} + _ builtinFunc = &builtinVecFromTextSig{} + _ builtinFunc = &builtinVecAsTextSig{} +) + +type vecDimsFunctionClass struct { + baseFunctionClass +} + +type builtinVecDimsSig struct { + baseBuiltinFunc +} + +func (b *builtinVecDimsSig) Clone() builtinFunc { + newSig := &builtinVecDimsSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *vecDimsFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETVectorFloat32) + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, argTps...) + if err != nil { + return nil, err + } + sig := &builtinVecDimsSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_VecDimsSig) + return sig, nil +} + +func (b *builtinVecDimsSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { + v, isNull, err := b.args[0].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + return int64(v.Len()), false, nil +} + +type vecL1DistanceFunctionClass struct { + baseFunctionClass +} + +type builtinVecL1DistanceSig struct { + baseBuiltinFunc +} + +func (b *builtinVecL1DistanceSig) Clone() builtinFunc { + newSig := &builtinVecL1DistanceSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *vecL1DistanceFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETVectorFloat32, types.ETVectorFloat32) + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, argTps...) + if err != nil { + return nil, err + } + sig := &builtinVecL1DistanceSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_VecL1DistanceSig) + return sig, nil +} + +func (b *builtinVecL1DistanceSig) evalReal(ctx EvalContext, row chunk.Row) (res float64, isNull bool, err error) { + v1, isNull, err := b.args[0].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + v2, isNull, err := b.args[1].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + d, err := v1.L1Distance(v2) + if err != nil { + return res, false, err + } + + if math.IsNaN(d) { + return 0, true, nil + } + return d, false, nil +} + +type vecL2DistanceFunctionClass struct { + baseFunctionClass +} + +type builtinVecL2DistanceSig struct { + baseBuiltinFunc +} + +func (b *builtinVecL2DistanceSig) Clone() builtinFunc { + newSig := &builtinVecL2DistanceSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *vecL2DistanceFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETVectorFloat32, types.ETVectorFloat32) + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, argTps...) + if err != nil { + return nil, err + } + sig := &builtinVecL2DistanceSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_VecL2DistanceSig) + return sig, nil +} + +func (b *builtinVecL2DistanceSig) evalReal(ctx EvalContext, row chunk.Row) (res float64, isNull bool, err error) { + v1, isNull, err := b.args[0].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + v2, isNull, err := b.args[1].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + d, err := v1.L2Distance(v2) + if err != nil { + return res, false, err + } + + if math.IsNaN(d) { + return 0, true, nil + } + return d, false, nil +} + +type vecNegativeInnerProductFunctionClass struct { + baseFunctionClass +} + +type builtinVecNegativeInnerProductSig struct { + baseBuiltinFunc +} + +func (b *builtinVecNegativeInnerProductSig) Clone() builtinFunc { + newSig := &builtinVecNegativeInnerProductSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *vecNegativeInnerProductFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETVectorFloat32, types.ETVectorFloat32) + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, argTps...) + if err != nil { + return nil, err + } + sig := &builtinVecNegativeInnerProductSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_VecNegativeInnerProductSig) + return sig, nil +} + +func (b *builtinVecNegativeInnerProductSig) evalReal(ctx EvalContext, row chunk.Row) (res float64, isNull bool, err error) { + v1, isNull, err := b.args[0].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + v2, isNull, err := b.args[1].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + d, err := v1.NegativeInnerProduct(v2) + if err != nil { + return res, false, err + } + + if math.IsNaN(d) { + return 0, true, nil + } + return d, false, nil +} + +type vecCosineDistanceFunctionClass struct { + baseFunctionClass +} + +type builtinVecCosineDistanceSig struct { + baseBuiltinFunc +} + +func (b *builtinVecCosineDistanceSig) Clone() builtinFunc { + newSig := &builtinVecCosineDistanceSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *vecCosineDistanceFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETVectorFloat32, types.ETVectorFloat32) + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, argTps...) + if err != nil { + return nil, err + } + sig := &builtinVecCosineDistanceSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_VecCosineDistanceSig) + return sig, nil +} + +func (b *builtinVecCosineDistanceSig) evalReal(ctx EvalContext, row chunk.Row) (res float64, isNull bool, err error) { + v1, isNull, err := b.args[0].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + v2, isNull, err := b.args[1].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + d, err := v1.CosineDistance(v2) + if err != nil { + return res, false, err + } + + if math.IsNaN(d) { + return 0, true, nil + } + return d, false, nil +} + +type vecL2NormFunctionClass struct { + baseFunctionClass +} + +type builtinVecL2NormSig struct { + baseBuiltinFunc +} + +func (b *builtinVecL2NormSig) Clone() builtinFunc { + newSig := &builtinVecL2NormSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *vecL2NormFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETVectorFloat32) + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, argTps...) + if err != nil { + return nil, err + } + sig := &builtinVecL2NormSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_VecL2NormSig) + return sig, nil +} + +func (b *builtinVecL2NormSig) evalReal(ctx EvalContext, row chunk.Row) (res float64, isNull bool, err error) { + v, isNull, err := b.args[0].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + d := v.L2Norm() + if math.IsNaN(d) { + return 0, true, nil + } + return d, false, nil +} + +type vecFromTextFunctionClass struct { + baseFunctionClass +} + +type builtinVecFromTextSig struct { + baseBuiltinFunc +} + +func (b *builtinVecFromTextSig) Clone() builtinFunc { + newSig := &builtinVecFromTextSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *vecFromTextFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETString) + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETVectorFloat32, argTps...) + if err != nil { + return nil, err + } + sig := &builtinVecFromTextSig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_VecFromTextSig) + return sig, nil +} + +func (b *builtinVecFromTextSig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (res types.VectorFloat32, isNull bool, err error) { + v, isNull, err := b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + vec, err := types.ParseVectorFloat32(v) + if err != nil { + return types.ZeroVectorFloat32, false, err + } + if err = vec.CheckDimsFitColumn(b.tp.GetFlen()); err != nil { + return types.ZeroVectorFloat32, isNull, err + } + + return vec, false, nil +} + +type vecAsTextFunctionClass struct { + baseFunctionClass +} + +type builtinVecAsTextSig struct { + baseBuiltinFunc +} + +func (b *builtinVecAsTextSig) Clone() builtinFunc { + newSig := &builtinVecAsTextSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *vecAsTextFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETVectorFloat32) + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, argTps...) + if err != nil { + return nil, err + } + sig := &builtinVecAsTextSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_VecAsTextSig) + return sig, nil +} + +func (b *builtinVecAsTextSig) evalString(ctx EvalContext, row chunk.Row) (res string, isNull bool, err error) { + v, isNull, err := b.args[0].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + return v.String(), false, nil +} diff --git a/pkg/expression/chunk_executor.go b/pkg/expression/chunk_executor.go index fd8128fdff3b5..b8f93fcba1d74 100644 --- a/pkg/expression/chunk_executor.go +++ b/pkg/expression/chunk_executor.go @@ -15,6 +15,7 @@ package expression import ( + "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/types" @@ -168,6 +169,8 @@ func evalOneVec(ctx EvalContext, expr Expression, input *chunk.Chunk, output *ch return expr.VecEvalDuration(ctx, input, result) case types.ETJson: return expr.VecEvalJSON(ctx, input, result) + case types.ETVectorFloat32: + return expr.VecEvalVectorFloat32(ctx, input, result) case types.ETString: if err := expr.VecEvalString(ctx, input, result); err != nil { return err @@ -213,6 +216,8 @@ func evalOneVec(ctx EvalContext, expr Expression, input *chunk.Chunk, output *ch } output.SetCol(colIdx, buf) } + default: + return errors.Errorf("unsupported type %s during evaluation", ft.EvalType()) } return nil } @@ -243,10 +248,16 @@ func evalOneColumn(ctx EvalContext, expr Expression, iterator *chunk.Iterator4Ch for row := iterator.Begin(); err == nil && row != iterator.End(); row = iterator.Next() { err = executeToJSON(ctx, expr, fieldType, row, output, colID) } + case types.ETVectorFloat32: + for row := iterator.Begin(); err == nil && row != iterator.End(); row = iterator.Next() { + err = executeToVectorFloat32(ctx, expr, fieldType, row, output, colID) + } case types.ETString: for row := iterator.Begin(); err == nil && row != iterator.End(); row = iterator.Next() { err = executeToString(ctx, expr, fieldType, row, output, colID) } + default: + return errors.Errorf("unsupported type %s during evaluation", evalType) } return err } @@ -265,8 +276,12 @@ func evalOneCell(ctx EvalContext, expr Expression, row chunk.Row, output *chunk. err = executeToDuration(ctx, expr, fieldType, row, output, colID) case types.ETJson: err = executeToJSON(ctx, expr, fieldType, row, output, colID) + case types.ETVectorFloat32: + err = executeToVectorFloat32(ctx, expr, fieldType, row, output, colID) case types.ETString: err = executeToString(ctx, expr, fieldType, row, output, colID) + default: + return errors.Errorf("unsupported type %s during evaluation", evalType) } return err } @@ -370,6 +385,19 @@ func executeToJSON(ctx EvalContext, expr Expression, fieldType *types.FieldType, return nil } +func executeToVectorFloat32(ctx EvalContext, expr Expression, fieldType *types.FieldType, row chunk.Row, output *chunk.Chunk, colID int) error { + res, isNull, err := expr.EvalVectorFloat32(ctx, row) + if err != nil { + return err + } + if isNull { + output.AppendNull(colID) + } else { + output.AppendVectorFloat32(colID, res) + } + return nil +} + func executeToString(ctx EvalContext, expr Expression, fieldType *types.FieldType, row chunk.Row, output *chunk.Chunk, colID int) error { res, isNull, err := expr.EvalString(ctx, row) if err != nil { diff --git a/pkg/expression/column.go b/pkg/expression/column.go index acc0e00202876..dba257abddce4 100644 --- a/pkg/expression/column.go +++ b/pkg/expression/column.go @@ -83,6 +83,11 @@ func (col *CorrelatedColumn) VecEvalJSON(ctx EvalContext, input *chunk.Chunk, re return genVecFromConstExpr(ctx, col, types.ETJson, input, result) } +// VecEvalVectorFloat32 evaluates this expression in a vectorized manner. +func (col *CorrelatedColumn) VecEvalVectorFloat32(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error { + return genVecFromConstExpr(ctx, col, types.ETVectorFloat32, input, result) +} + // Traverse implements the TraverseDown interface. func (col *CorrelatedColumn) Traverse(action TraverseAction) Expression { return action.Transform(col) @@ -154,6 +159,14 @@ func (col *CorrelatedColumn) EvalJSON(ctx EvalContext, row chunk.Row) (types.Bin return col.Data.GetMysqlJSON(), false, nil } +// EvalVectorFloat32 returns VectorFloat32 representation of CorrelatedColumn. +func (col *CorrelatedColumn) EvalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) { + if col.Data.IsNull() { + return types.ZeroVectorFloat32, true, nil + } + return col.Data.GetVectorFloat32(), false, nil +} + // Equal implements Expression interface. func (col *CorrelatedColumn) Equal(_ EvalContext, expr Expression) bool { return col.EqualColumn(expr) @@ -387,6 +400,12 @@ func (col *Column) VecEvalJSON(ctx EvalContext, input *chunk.Chunk, result *chun return nil } +// VecEvalVectorFloat32 evaluates this expression in a vectorized manner. +func (col *Column) VecEvalVectorFloat32(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error { + input.Column(col.Index).CopyReconstruct(input.Sel(), result) + return nil +} + const columnPrefix = "Column#" // StringWithCtx implements Expression interface. @@ -399,6 +418,21 @@ func (col *Column) String() string { return col.string(errors.RedactLogDisable) } +// StringForExplain implements Explainable interface. +func (col *Column) StringForExplain(_ ParamValues, redact string) string { + if col.IsHidden && col.VirtualExpr != nil { + // A hidden column without virtual expression indicates it's a stored type. + // a virtual column should be able to be stringified without context. + return col.VirtualExpr.StringForExplain(exprctx.EmptyParamValues, redact) + } + if col.OrigName != "" { + return col.OrigName + } + var builder strings.Builder + fmt.Fprintf(&builder, "%s%d", columnPrefix, col.UniqueID) + return builder.String() +} + func (col *Column) string(redact string) string { if col.IsHidden && col.VirtualExpr != nil { // A hidden column without virtual expression indicates it's a stored type. @@ -514,6 +548,14 @@ func (col *Column) EvalJSON(ctx EvalContext, row chunk.Row) (types.BinaryJSON, b return row.GetJSON(col.Index), false, nil } +// EvalVectorFloat32 returns VectorFloat32 representation of Column. +func (col *Column) EvalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) { + if row.IsNull(col.Index) { + return types.ZeroVectorFloat32, true, nil + } + return row.GetVectorFloat32(col.Index), false, nil +} + // Clone implements Expression interface. func (col *Column) Clone() Expression { newCol := *col diff --git a/pkg/expression/constant.go b/pkg/expression/constant.go index e660feb629d64..ffe9c4c71401a 100644 --- a/pkg/expression/constant.go +++ b/pkg/expression/constant.go @@ -160,6 +160,26 @@ func (c *Constant) StringWithCtx(ctx ParamValues, redact string) string { return "?" } +// StringForExplain implements Explainable interface. +func (c *Constant) StringForExplain(ctx ParamValues, redact string) string { + if c.ParamMarker != nil { + dt, err := c.ParamMarker.GetUserVar(ctx) + intest.AssertNoError(err, "fail to get param") + if err != nil { + return "?" + } + c.Value.SetValue(dt.GetValue(), c.RetType) + } else if c.DeferredExpr != nil { + return c.DeferredExpr.StringForExplain(ctx, redact) + } + if redact == perrors.RedactLogDisable { + return fmt.Sprintf("%v", c.Value.GetValue()) + } else if redact == perrors.RedactLogMarker { + return fmt.Sprintf("‹%v›", c.Value.GetValue()) + } + return "?" +} + // Clone implements Expression interface. func (c *Constant) Clone() Expression { con := *c @@ -240,6 +260,14 @@ func (c *Constant) VecEvalJSON(ctx EvalContext, input *chunk.Chunk, result *chun return c.DeferredExpr.VecEvalJSON(ctx, input, result) } +// VecEvalVectorFloat32 evaluates this expression in a vectorized manner. +func (c *Constant) VecEvalVectorFloat32(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error { + if c.DeferredExpr == nil { + return genVecFromConstExpr(ctx, c, types.ETVectorFloat32, input, result) + } + return c.DeferredExpr.VecEvalVectorFloat32(ctx, input, result) +} + func (c *Constant) getLazyDatum(ctx EvalContext, row chunk.Row) (dt types.Datum, isLazy bool, err error) { if c.ParamMarker != nil { val, err := c.ParamMarker.GetUserVar(ctx) @@ -423,6 +451,21 @@ func (c *Constant) EvalJSON(ctx EvalContext, row chunk.Row) (types.BinaryJSON, b return dt.GetMysqlJSON(), false, nil } +// EvalVectorFloat32 returns VectorFloat32 representation of Constant. +func (c *Constant) EvalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) { + dt, lazy, err := c.getLazyDatum(ctx, row) + if err != nil { + return types.ZeroVectorFloat32, false, err + } + if !lazy { + dt = c.Value + } + if c.GetType(ctx).GetType() == mysql.TypeNull || dt.IsNull() { + return types.ZeroVectorFloat32, true, nil + } + return dt.GetVectorFloat32(), false, nil +} + // Equal implements Expression interface. func (c *Constant) Equal(ctx EvalContext, b Expression) bool { y, ok := b.(*Constant) diff --git a/pkg/expression/distsql_builtin.go b/pkg/expression/distsql_builtin.go index 0012410a60c80..ed83d0cdd6688 100644 --- a/pkg/expression/distsql_builtin.go +++ b/pkg/expression/distsql_builtin.go @@ -1076,6 +1076,42 @@ func getSignatureByPB(ctx BuildContext, sigCode tipb.ScalarFuncSig, tp *tipb.Fie case tipb.ScalarFuncSig_FromBinary: // TODO: set the `cannotConvertStringAsWarning` accordingly f = &builtinInternalFromBinarySig{base, false} + case tipb.ScalarFuncSig_CastVectorFloat32AsString: + f = &builtinCastVectorFloat32AsStringSig{base} + case tipb.ScalarFuncSig_CastVectorFloat32AsVectorFloat32: + f = &builtinCastVectorFloat32AsVectorFloat32Sig{base} + case tipb.ScalarFuncSig_LTVectorFloat32: + f = &builtinLTVectorFloat32Sig{base} + case tipb.ScalarFuncSig_LEVectorFloat32: + f = &builtinLEVectorFloat32Sig{base} + case tipb.ScalarFuncSig_GTVectorFloat32: + f = &builtinGTVectorFloat32Sig{base} + case tipb.ScalarFuncSig_GEVectorFloat32: + f = &builtinGEVectorFloat32Sig{base} + case tipb.ScalarFuncSig_NEVectorFloat32: + f = &builtinNEVectorFloat32Sig{base} + case tipb.ScalarFuncSig_EQVectorFloat32: + f = &builtinEQVectorFloat32Sig{base} + case tipb.ScalarFuncSig_NullEQVectorFloat32: + f = &builtinNullEQVectorFloat32Sig{base} + case tipb.ScalarFuncSig_VectorFloat32AnyValue: + f = &builtinVectorFloat32AnyValueSig{base} + case tipb.ScalarFuncSig_VectorFloat32IsNull: + f = &builtinVectorFloat32IsNullSig{base} + case tipb.ScalarFuncSig_VecAsTextSig: + f = &builtinVecAsTextSig{base} + case tipb.ScalarFuncSig_VecDimsSig: + f = &builtinVecDimsSig{base} + case tipb.ScalarFuncSig_VecL1DistanceSig: + f = &builtinVecL1DistanceSig{base} + case tipb.ScalarFuncSig_VecL2DistanceSig: + f = &builtinVecL2DistanceSig{base} + case tipb.ScalarFuncSig_VecNegativeInnerProductSig: + f = &builtinVecNegativeInnerProductSig{base} + case tipb.ScalarFuncSig_VecCosineDistanceSig: + f = &builtinVecCosineDistanceSig{base} + case tipb.ScalarFuncSig_VecL2NormSig: + f = &builtinVecL2NormSig{base} default: e = ErrFunctionNotExists.GenWithStackByArgs("FUNCTION", sigCode) @@ -1149,6 +1185,8 @@ func PBToExpr(ctx BuildContext, expr *tipb.Expr, tps []*types.FieldType) (Expres return convertJSON(expr.Val) case tipb.ExprType_MysqlEnum: return convertEnum(expr.Val, expr.FieldType) + case tipb.ExprType_TiDBVectorFloat32: + return convertVectorFloat32(expr.Val) } if expr.Tp != tipb.ExprType_ScalarFunc { panic("should be a tipb.ExprType_ScalarFunc") @@ -1293,6 +1331,16 @@ func convertJSON(val []byte) (*Constant, error) { return &Constant{Value: d, RetType: types.NewFieldType(mysql.TypeJSON)}, nil } +func convertVectorFloat32(val []byte) (*Constant, error) { + v, _, err := types.ZeroCopyDeserializeVectorFloat32(val) + if err != nil { + return nil, errors.Errorf("invalid VectorFloat32 %x", val) + } + var d types.Datum + d.SetVectorFloat32(v) + return &Constant{Value: d, RetType: types.NewFieldType(mysql.TypeTiDBVectorFloat32)}, nil +} + func convertEnum(val []byte, tp *tipb.FieldType) (*Constant, error) { _, uVal, err := codec.DecodeUint(val) if err != nil { diff --git a/pkg/expression/distsql_builtin_test.go b/pkg/expression/distsql_builtin_test.go index 47214f4dd14e8..ca52e767c69a4 100644 --- a/pkg/expression/distsql_builtin_test.go +++ b/pkg/expression/distsql_builtin_test.go @@ -907,6 +907,12 @@ func datumExpr(t *testing.T, d types.Datum) *tipb.Expr { expr.Val = make([]byte, 0, 1024) expr.Val, err = codec.EncodeValue(time.UTC, expr.Val, d) require.NoError(t, err) + case types.KindVectorFloat32: + expr.Tp = tipb.ExprType_TiDBVectorFloat32 + var err error + expr.Val = make([]byte, 0, 1024) + expr.Val, err = codec.EncodeValue(nil, expr.Val, d) + require.NoError(t, err) case types.KindMysqlTime: expr.Tp = tipb.ExprType_MysqlTime var err error diff --git a/pkg/expression/explain.go b/pkg/expression/explain.go index cf43d6291c9fb..c31a1c67ea4ba 100644 --- a/pkg/expression/explain.go +++ b/pkg/expression/explain.go @@ -184,14 +184,14 @@ func ExplainExpressionList(ctx EvalContext, exprs []Expression, schema *Schema, for i, expr := range exprs { switch expr.(type) { case *Column, *CorrelatedColumn: - builder.WriteString(expr.StringWithCtx(ctx, redactMode)) + builder.WriteString(expr.StringForExplain(ctx, redactMode)) if expr.StringWithCtx(ctx, redactMode) != schema.Columns[i].StringWithCtx(ctx, redactMode) { // simple col projected again with another uniqueID without origin name. builder.WriteString("->") builder.WriteString(schema.Columns[i].StringWithCtx(ctx, redactMode)) } case *Constant: - v := expr.StringWithCtx(ctx, errors.RedactLogDisable) + v := expr.StringForExplain(ctx, errors.RedactLogDisable) length := 64 if len(v) < length { redact.WriteRedact(builder, v, redactMode) @@ -200,11 +200,11 @@ func ExplainExpressionList(ctx EvalContext, exprs []Expression, schema *Schema, fmt.Fprintf(builder, "(len:%d)", len(v)) } builder.WriteString("->") - builder.WriteString(schema.Columns[i].StringWithCtx(ctx, redactMode)) + builder.WriteString(schema.Columns[i].StringForExplain(ctx, redactMode)) default: - builder.WriteString(expr.StringWithCtx(ctx, redactMode)) + builder.WriteString(expr.StringForExplain(ctx, redactMode)) builder.WriteString("->") - builder.WriteString(schema.Columns[i].StringWithCtx(ctx, redactMode)) + builder.WriteString(schema.Columns[i].StringForExplain(ctx, redactMode)) } if i+1 < len(exprs) { builder.WriteString(", ") diff --git a/pkg/expression/expr_to_pb.go b/pkg/expression/expr_to_pb.go index 4c74eedb5e441..9ecce6745657c 100644 --- a/pkg/expression/expr_to_pb.go +++ b/pkg/expression/expr_to_pb.go @@ -175,6 +175,9 @@ func (pc *PbConverter) encodeDatum(ft *types.FieldType, d types.Datum) (tipb.Exp case types.KindMysqlEnum: tp = tipb.ExprType_MysqlEnum val = codec.EncodeUint(nil, d.GetUint64()) + case types.KindVectorFloat32: + tp = tipb.ExprType_TiDBVectorFloat32 + val = d.GetVectorFloat32().ZeroCopySerialize() default: return tp, nil, false } diff --git a/pkg/expression/expression.go b/pkg/expression/expression.go index ad8f9bf01b7c1..4c5805cc448a4 100644 --- a/pkg/expression/expression.go +++ b/pkg/expression/expression.go @@ -138,6 +138,9 @@ type VecExpr interface { // VecEvalJSON evaluates this expression in a vectorized manner. VecEvalJSON(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error + + // VecEvalBool evaluates this expression in a vectorized manner. + VecEvalVectorFloat32(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error } // TraverseAction define the interface for action when traversing down an expression. @@ -163,9 +166,16 @@ const ( ConstStrict ) +// Explainable is the interface for expressions to output themselves in EXPLAIN context. +type Explainable interface { + // StringForExplain + StringForExplain(ParamValues, string) string +} + // Expression represents all scalar expression in SQL. type Expression interface { VecExpr + Explainable CollationInfo Traverse(TraverseAction) Expression @@ -194,6 +204,9 @@ type Expression interface { // EvalJSON returns the JSON representation of expression. EvalJSON(ctx EvalContext, row chunk.Row) (val types.BinaryJSON, isNull bool, err error) + // EvalVectorFloat32 returns the VectorFloat32 representation of expression. + EvalVectorFloat32(ctx EvalContext, row chunk.Row) (val types.VectorFloat32, isNull bool, err error) + // GetType gets the type that the expression returns. GetType(ctx EvalContext) *types.FieldType @@ -584,6 +597,20 @@ func toBool(tc types.Context, tp *types.FieldType, eType types.EvalType, buf *ch } } } + case types.ETVectorFloat32: + for i := range sel { + if buf.IsNull(i) { + isZero[i] = -1 + } else { + if buf.GetVectorFloat32(i).IsZeroValue() { + isZero[i] = 0 + } else { + isZero[i] = 1 + } + } + } + default: + return errors.Errorf("unsupported type %s during evaluation", eType) } return nil } @@ -632,10 +659,12 @@ func EvalExpr(ctx EvalContext, vecEnabled bool, expr Expression, evalType types. err = expr.VecEvalString(ctx, input, result) case types.ETJson: err = expr.VecEvalJSON(ctx, input, result) + case types.ETVectorFloat32: + err = expr.VecEvalVectorFloat32(ctx, input, result) case types.ETDecimal: err = expr.VecEvalDecimal(ctx, input, result) default: - err = fmt.Errorf("invalid eval type %v", expr.GetType(ctx).EvalType()) + err = errors.Errorf("unsupported type %s during evaluation", evalType) } } else { ind, n := 0, input.NumRows() @@ -727,6 +756,19 @@ func EvalExpr(ctx EvalContext, vecEnabled bool, expr Expression, evalType types. result.AppendJSON(value) } } + case types.ETVectorFloat32: + result.ReserveVectorFloat32(n) + for it := iter.Begin(); it != iter.End(); it = iter.Next() { + value, isNull, err := expr.EvalVectorFloat32(ctx, it) + if err != nil { + return err + } + if isNull { + result.AppendNull() + } else { + result.AppendVectorFloat32(value) + } + } case types.ETDecimal: result.ResizeDecimal(n, false) d64s := result.Decimals() @@ -743,7 +785,7 @@ func EvalExpr(ctx EvalContext, vecEnabled bool, expr Expression, evalType types. ind++ } default: - err = fmt.Errorf("invalid eval type %v", expr.GetType(ctx).EvalType()) + err = errors.Errorf("unsupported type %s during evaluation", expr.GetType(ctx).EvalType()) } } return diff --git a/pkg/expression/infer_pushdown.go b/pkg/expression/infer_pushdown.go index 4af5c0e912a8d..60648002d4c55 100644 --- a/pkg/expression/infer_pushdown.go +++ b/pkg/expression/infer_pushdown.go @@ -198,6 +198,9 @@ func scalarExprSupportedByTiKV(ctx EvalContext, sf *ScalarFunction) bool { ast.JSONInsert, ast.JSONReplace, ast.JSONRemove, ast.JSONLength, ast.JSONMergePatch, ast.JSONUnquote, ast.JSONContains, ast.JSONValid, ast.JSONMemberOf, ast.JSONArrayAppend, + // vector functions. + ast.VecDims, ast.VecL1Distance, ast.VecL2Distance, ast.VecNegativeInnerProduct, ast.VecCosineDistance, ast.VecL2Norm, ast.VecAsText, + // date functions. ast.Date, ast.Week /* ast.YearWeek, ast.ToSeconds */, ast.DateDiff, /* ast.TimeDiff, ast.AddTime, ast.SubTime, */ @@ -335,6 +338,9 @@ func scalarExprSupportedByFlash(ctx EvalContext, function *ScalarFunction) bool return function.GetArgs()[0].GetType(ctx).GetType() != mysql.TypeYear case tipb.ScalarFuncSig_CastTimeAsDuration: return retType.GetType() == mysql.TypeDuration + case tipb.ScalarFuncSig_CastVectorFloat32AsString, + tipb.ScalarFuncSig_CastVectorFloat32AsVectorFloat32: + return true case tipb.ScalarFuncSig_CastIntAsJson, tipb.ScalarFuncSig_CastRealAsJson, tipb.ScalarFuncSig_CastDecimalAsJson, tipb.ScalarFuncSig_CastStringAsJson, tipb.ScalarFuncSig_CastTimeAsJson, tipb.ScalarFuncSig_CastDurationAsJson, tipb.ScalarFuncSig_CastJsonAsJson: return true @@ -397,6 +403,8 @@ func scalarExprSupportedByFlash(ctx EvalContext, function *ScalarFunction) bool return true case ast.IsIPv4, ast.IsIPv6: return true + case ast.VecDims, ast.VecL1Distance, ast.VecL2Distance, ast.VecNegativeInnerProduct, ast.VecCosineDistance, ast.VecL2Norm, ast.VecAsText: + return true case ast.Grouping: // grouping function for grouping sets identification. return true } diff --git a/pkg/expression/integration_test/BUILD.bazel b/pkg/expression/integration_test/BUILD.bazel index 8d15c35a41453..7fdbe83f3b79a 100644 --- a/pkg/expression/integration_test/BUILD.bazel +++ b/pkg/expression/integration_test/BUILD.bazel @@ -8,7 +8,7 @@ go_test( "main_test.go", ], flaky = True, - shard_count = 27, + shard_count = 42, deps = [ "//pkg/config", "//pkg/domain", diff --git a/pkg/expression/integration_test/integration_test.go b/pkg/expression/integration_test/integration_test.go index 8d289da9ca611..575c8c31d39a9 100644 --- a/pkg/expression/integration_test/integration_test.go +++ b/pkg/expression/integration_test/integration_test.go @@ -57,29 +57,613 @@ import ( "github.com/tikv/client-go/v2/oracle" ) -func TestVector(t *testing.T) { - // Currently we only allow parsing Vector type, but not using it. +func TestVectorColumnInfo(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + // MUST enable vector type first. + tk.MustExec("drop table if exists t;") + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=0;") + tk.MustExecToErr("create table t(embedding VECTOR)") + + // Create vector type column without specified dimension. + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=1;") + tk.MustExec("create table t(embedding VECTOR)") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(embedding VECTOR)") + + // SHOW CREATE TABLE + tk.MustQuery("show create table t").Check(testkit.Rows( + "t CREATE TABLE `t` (\n" + + " `embedding` vector DEFAULT NULL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin", + )) + + // SHOW COLUMNS + tk.MustQuery("show columns from t").Check(testkit.Rows( + "embedding vector YES ", + )) + + // Create vector type column with specified dimension. + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(embedding VECTOR(3))") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(embedding VECTOR(3))") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(embedding VECTOR(0))") + + // SHOW CREATE TABLE + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(embedding VECTOR(3))") + tk.MustQuery("show create table t").Check(testkit.Rows( + "t CREATE TABLE `t` (\n" + + " `embedding` vector(3) DEFAULT NULL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin", + )) + + // SHOW COLUMNS + tk.MustQuery("show columns from t").Check(testkit.Rows( + "embedding vector(3) YES ", + )) + + // INFORMATION_SCHEMA.COLUMNS + tk.MustQuery("SELECT data_type, column_type FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = 't'").Check(testkit.Rows( + "vector vector(3)", + )) + + // Vector dimension MUST be equal or less than 16000. + tk.MustExec("drop table if exists t;") + tk.MustGetErrMsg("create table t(embedding VECTOR(16001))", "vector cannot have more than 16000 dimensions") +} +func TestVectorConstantExplain(t *testing.T) { store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=1;") + tk.MustExec("CREATE TABLE t(c VECTOR);") + tk.MustQuery(`EXPLAIN SELECT VEC_COSINE_DISTANCE(c, '[1,2,3,4,5,6,7,8,9,10,11]') FROM t;`).Check(testkit.Rows( + "Projection_3 10000.00 root vec_cosine_distance(test.t.c, [1,2,3,4,5,(6 more)...])->Column#3", + "└─TableReader_5 10000.00 root data:TableFullScan_4", + " └─TableFullScan_4 10000.00 cop[tikv] table:t keep order:false, stats:pseudo", + )) + tk.MustQuery(`EXPLAIN SELECT VEC_COSINE_DISTANCE(c, VEC_FROM_TEXT('[1,2,3,4,5,6,7,8,9,10,11]')) FROM t;`).Check(testkit.Rows( + "Projection_3 10000.00 root vec_cosine_distance(test.t.c, [1,2,3,4,5,(6 more)...])->Column#3", + "└─TableReader_5 10000.00 root data:TableFullScan_4", + " └─TableFullScan_4 10000.00 cop[tikv] table:t keep order:false, stats:pseudo", + )) +} +func TestFixedVector(t *testing.T) { + store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) tk.MustExec("use test") + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=1;") + + tk.MustExec("create table t(embedding VECTOR)") + tk.MustExec("insert into t values ('[1,2,3]')") + tk.MustExec("insert into t values ('[1,2,3,4]')") + + // Failed to modify column type cause vectors with different dimension. + tk.MustContainErrMsg("alter table t modify column embedding VECTOR(3)", "vector has 4 dimensions, does not fit VECTOR(3)") + + // Mixed dimension to fixed dimension. + tk.MustExec("delete from t where vec_dims(embedding) != 3") + tk.MustExec("alter table t modify column embedding VECTOR(3)") + tk.MustGetErrMsg("insert into t values ('[]')", "vector has 0 dimensions, does not fit VECTOR(3)") + tk.MustGetErrMsg("insert into t values ('[1,2,3,4]')", "vector has 4 dimensions, does not fit VECTOR(3)") + tk.MustGetErrMsg("insert into t values (VEC_FROM_TEXT('[]'))", "vector has 0 dimensions, does not fit VECTOR(3)") + tk.MustGetErrMsg("insert into t values (VEC_FROM_TEXT('[1,2,3,4]'))", "vector has 4 dimensions, does not fit VECTOR(3)") + tk.MustGetErrMsg("update t set embedding = '[1,2,3,4]' where embedding = '[1,2,3]'", "vector has 4 dimensions, does not fit VECTOR(3)") + tk.MustGetErrMsg("update t set embedding = '[]' where embedding = '[1,2,3]'", "vector has 0 dimensions, does not fit VECTOR(3)") + + // Fixed dimension to mixed dimension. + tk.MustExec("alter table t modify column embedding VECTOR") + tk.MustExec("insert into t values ('[1,2,3,4]')") + + // Vector dimension MUST be equal or less than 16000. + tk.MustGetErrMsg("alter table t modify column embedding VECTOR(16001)", "vector cannot have more than 16000 dimensions") +} + +func TestVectorVariable(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + + tk.MustExec("CREATE USER semroot") + tk.MustExec(`GRANT ALL PRIVILEGES ON *.* TO semroot;`) + tk.MustExec("GRANT RESTRICTED_VARIABLES_ADMIN ON *.* to semroot;") + + sem.Disable() + + tk.MustExecToErr(`CREATE TABLE t1 (v VECTOR);`) + + sem.Enable() + + // root cannot set global variable in SEMLevelBasic + require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) + tk.MustExecToErr(`CREATE TABLE t1 (v VECTOR);`) + tk.MustExecToErr("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=1;") + tk.MustExecToErr(`CREATE TABLE t1 (v VECTOR);`) + + // semroot can set global variable in SEMLevelBasic + require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "semroot", Hostname: "localhost"}, nil, nil, nil)) + tk.MustExecToErr(`CREATE TABLE t1 (v VECTOR);`) + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=1;") + tk.MustExec(`CREATE TABLE t1 (v VECTOR);`) + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=0;") + tk.MustExec(`DROP TABLE t1;`) + tk.MustExecToErr(`CREATE TABLE t1 (v VECTOR);`) + + sem.Enable() + + // root cannot set global variable in SEMLevelStrict + require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) + tk.MustExecToErr("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=1;") + tk.MustExecToErr(`CREATE TABLE t1 (v VECTOR);`) + + // semroot can set global variable in SEMLevelStrict + require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "semroot", Hostname: "localhost"}, nil, nil, nil)) + tk.MustExecToErr(`CREATE TABLE t1 (v VECTOR);`) + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=1;") + tk.MustExec(`CREATE TABLE t1 (v VECTOR);`) + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=0;") + tk.MustExec(`DROP TABLE t1;`) + tk.MustExecToErr(`CREATE TABLE t1 (v VECTOR);`) + + sem.Disable() + + // root can set global variable when SEM is disabled + require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=1;") + tk.MustExec(`CREATE TABLE t1 (v VECTOR);`) + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=0;") + tk.MustExec(`DROP TABLE t1;`) + tk.MustExecToErr(`CREATE TABLE t1 (v VECTOR);`) + + // semroot can set global variable when SEM is disabled + require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "semroot", Hostname: "localhost"}, nil, nil, nil)) + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=1;") + tk.MustExec(`CREATE TABLE t1 (v VECTOR);`) + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=0;") + tk.MustExec(`DROP TABLE t1;`) + tk.MustExecToErr(`CREATE TABLE t1 (v VECTOR);`) +} + +func TestVector(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + + tk.MustExecToErr(`CREATE TABLE t1 (v VECTOR);`) + + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=1;") + tk.MustExec(`CREATE TABLE t1 (v VECTOR);`) + tk.MustExecToErr(`INSERT INTO t1 VALUES ('abc');`) + tk.MustExec(`INSERT INTO t1 VALUES ('[1,2.1,3.3]');`) + tk.MustExecToErr(`INSERT INTO t1 VALUES ('[1,2.1,null]');`) + tk.MustExecToErr(`INSERT INTO t1 VALUES ('[1,2.1,inf]');`) + tk.MustExecToErr(`INSERT INTO t1 VALUES ('[1,2.1,nan]');`) + tk.MustExec(`INSERT INTO t1 VALUES ('[]');`) + tk.MustExec(`INSERT INTO t1 VALUES (NULL);`) + tk.MustQuery("SELECT * FROM t1;").Check(testkit.Rows("[1,2.1,3.3]", "[]", "")) + tk.MustQuery("SELECT VEC_DIMS(v) FROM t1;").Check(testkit.Rows("3", "0", "")) + + tk.MustQuery("SELECT VEC_DIMS(NULL);").Check(testkit.Rows("")) + tk.MustQuery("SELECT VEC_DIMS('[]');").Check(testkit.Rows("0")) + tk.MustQuery("SELECT VEC_DIMS('[5, 3, 2]');").Check(testkit.Rows("3")) + tk.MustQuery("SELECT VEC_FROM_TEXT('[]');").Check(testkit.Rows("[]")) + + // Basic sort + tk.MustExec(`CREATE TABLE t(val VECTOR);`) + tk.MustExec(`INSERT INTO t VALUES + ('[8.7, 5.7, 7.7, 9.8, 1.5]'), + ('[3.6, 9.7, 2.4, 6.6, 4.9]'), + ('[4.7, 4.9, 2.6, 5.2, 7.4]'), + ('[7.7, 6.7, 8.3, 7.8, 5.7]'), + ('[1.4, 4.5, 8.5, 7.7, 6.2]'); + `) + tk.MustQuery(`SELECT * FROM t ORDER BY val DESC;`).Check(testkit.Rows( + "[8.7,5.7,7.7,9.8,1.5]", + "[7.7,6.7,8.3,7.8,5.7]", + "[4.7,4.9,2.6,5.2,7.4]", + "[3.6,9.7,2.4,6.6,4.9]", + "[1.4,4.5,8.5,7.7,6.2]", + )) + + // Golang produce different results in different Arch for float points. + // Adding a ROUND to make this test stable. + // See https://go.dev/ref/spec#Arithmetic_operators + tk.MustQuery(`SELECT val, + ROUND(VEC_Cosine_Distance(val, '[1,2,3,4,5]'), 5) AS d + FROM t ORDER BY d DESC; + `).Check(testkit.Rows( + "[8.7,5.7,7.7,9.8,1.5] 0.25641", + "[3.6,9.7,2.4,6.6,4.9] 0.18577", + "[7.7,6.7,8.3,7.8,5.7] 0.12677", + "[4.7,4.9,2.6,5.2,7.4] 0.06925", + "[1.4,4.5,8.5,7.7,6.2] 0.04973", + )) +} + +func TestVectorOperators(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=1;") + tk.MustExec(`CREATE TABLE t(embedding VECTOR);`) + tk.MustExec(`INSERT INTO t VALUES + ('[1, 2, 3]'), + ('[4, 5, 6]'), + ('[7, 8, 9]'); + `) + + tk.MustQuery(`SELECT VEC_FROM_TEXT('[]') IS TRUE`).Check(testkit.Rows("0")) + tk.MustQuery(`SELECT VEC_FROM_TEXT('[]') IS FALSE`).Check(testkit.Rows("1")) + tk.MustQuery(`SELECT VEC_FROM_TEXT('[]') IS UNKNOWN`).Check(testkit.Rows("0")) + tk.MustQuery(`SELECT VEC_FROM_TEXT('[]') IS NOT NULL`).Check(testkit.Rows("1")) + tk.MustQuery(`SELECT VEC_FROM_TEXT('[]') IS NULL`).Check(testkit.Rows("0")) + tk.MustQuery(`SELECT * FROM t WHERE embedding = VEC_FROM_TEXT('[1,2,3]');`).Check(testkit.Rows("[1,2,3]")) + tk.MustQuery(`SELECT * FROM t WHERE embedding BETWEEN '[1, 2, 3]' AND '[4, 5, 6]'`).Check(testkit.Rows("[1,2,3]", "[4,5,6]")) + tk.MustQuery(`SELECT * FROM t WHERE embedding IN ('[1, 2, 3]', '[4, 5, 6]')`).Check(testkit.Rows("[1,2,3]", "[4,5,6]")) + tk.MustQuery(`SELECT * FROM t WHERE embedding NOT IN ('[1, 2, 3]', '[4, 5, 6]')`).Check(testkit.Rows("[7,8,9]")) +} + +func TestVectorCompare(t *testing.T) { + store := testkit.CreateMockStore(t) - err := tk.ExecToErr("CREATE TABLE c(a VECTOR)") - require.ErrorContains(t, err, "vector type is not supported") - err = tk.ExecToErr("CREATE TABLE c(a VECTOR(3))") - require.ErrorContains(t, err, "vector type is not supported") - err = tk.ExecToErr("SELECT CAST('123' AS VECTOR)") - require.ErrorContains(t, err, "vector type is not supported") + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + + tk.MustQuery("SELECT VEC_FROM_TEXT('[]') = VEC_FROM_TEXT('[]');").Check(testkit.Rows("1")) + tk.MustQuery("SELECT VEC_FROM_TEXT('[]') != VEC_FROM_TEXT('[]');").Check(testkit.Rows("0")) + tk.MustQuery("SELECT VEC_FROM_TEXT('[]') > VEC_FROM_TEXT('[]');").Check(testkit.Rows("0")) + tk.MustQuery("SELECT VEC_FROM_TEXT('[]') >= VEC_FROM_TEXT('[]');").Check(testkit.Rows("1")) + tk.MustQuery("SELECT VEC_FROM_TEXT('[]') < VEC_FROM_TEXT('[]');").Check(testkit.Rows("0")) + tk.MustQuery("SELECT VEC_FROM_TEXT('[]') <= VEC_FROM_TEXT('[]');").Check(testkit.Rows("1")) + + tk.MustQuery("SELECT VEC_FROM_TEXT('[1, 2, 3]') = VEC_FROM_TEXT('[1, 2, 3]');").Check(testkit.Rows("1")) + tk.MustQuery("SELECT VEC_FROM_TEXT('[1, 2, 3]') != VEC_FROM_TEXT('[1, 2, 3]');").Check(testkit.Rows("0")) + + tk.MustQuery("SELECT VEC_FROM_TEXT('[1, 2, 3]') > VEC_FROM_TEXT('[1]');").Check(testkit.Rows("1")) + tk.MustQuery("SELECT VEC_FROM_TEXT('[1, 2, 3]') >= VEC_FROM_TEXT('[1]');").Check(testkit.Rows("1")) + tk.MustQuery("SELECT VEC_FROM_TEXT('[1, 2, 3]') < VEC_FROM_TEXT('[1]');").Check(testkit.Rows("0")) + tk.MustQuery("SELECT VEC_FROM_TEXT('[1, 2, 3]') <= VEC_FROM_TEXT('[1]');").Check(testkit.Rows("0")) + + tk.MustQuery("SELECT VEC_FROM_TEXT('[1, 2, 3]') > '[1]';").Check(testkit.Rows("1")) + tk.MustQuery("SELECT VEC_FROM_TEXT('[1, 2, 3]') >= '[1]';").Check(testkit.Rows("1")) + tk.MustQuery("SELECT VEC_FROM_TEXT('[1, 2, 3]') < '[1]';").Check(testkit.Rows("0")) + tk.MustQuery("SELECT VEC_FROM_TEXT('[1, 2, 3]') <= '[1]';").Check(testkit.Rows("0")) + + tk.MustQuery(`SELECT GREATEST(VEC_FROM_TEXT('[1, 2, 3]'), VEC_FROM_TEXT('[4, 5, 6]'), VEC_FROM_TEXT('[7, 8, 9]')) AS result;`).Check(testkit.Rows("[7,8,9]")) + tk.MustQuery(`SELECT LEAST(VEC_FROM_TEXT('[1, 2, 3]'), VEC_FROM_TEXT('[4, 5, 6]'), VEC_FROM_TEXT('[7, 8, 9]')) AS result;`).Check(testkit.Rows("[1,2,3]")) + tk.MustQuery(`SELECT COALESCE(VEC_FROM_TEXT('[1, 2, 3]'), VEC_FROM_TEXT('[4, 5, 6]')) AS result;`).Check(testkit.Rows("[1,2,3]")) + tk.MustQuery(`SELECT COALESCE(NULL, VEC_FROM_TEXT('[1, 2, 3]')) AS result;`).Check(testkit.Rows("[1,2,3]")) + tk.MustQuery(`SELECT COALESCE(VEC_FROM_TEXT('[1, 2, 3]'), 1) AS result;`).Check(testkit.Rows("[1,2,3]")) + tk.MustQuery(`SELECT COALESCE(VEC_FROM_TEXT('[1, 2, 3]'), '1') AS result;`).Check(testkit.Rows("[1,2,3]")) + tk.MustQuery(`SELECT COALESCE(1, VEC_FROM_TEXT('[1, 2, 3]'), 1) AS result;`).Check(testkit.Rows("1")) + tk.MustQuery(`SELECT COALESCE('1', VEC_FROM_TEXT('[1, 2, 3]'), '1') AS result;`).Check(testkit.Rows("1")) +} + +func TestVectorConversion(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=1;") + tk.MustExec(`CREATE TABLE t1 (val vector);`) + + // CAST + tk.MustQuery("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS BINARY);").Check(testkit.Rows("[1,2,3]")) + tk.MustQuery("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS CHAR);").Check(testkit.Rows("[1,2,3]")) + tk.MustQueryToErr("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS JSON);") + tk.MustQueryToErr("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS DECIMAL(2));") + tk.MustQueryToErr("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS DOUBLE);") + tk.MustQueryToErr("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS FLOAT);") + tk.MustQueryToErr("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS REAL);") + tk.MustQueryToErr("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS SIGNED);") + tk.MustQueryToErr("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS UNSIGNED);") + tk.MustQueryToErr("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS YEAR);") + tk.MustQueryToErr("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS DATETIME);") + tk.MustQueryToErr("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS DATE);") + tk.MustQueryToErr("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS TIME);") + + tk.MustQuery("SELECT CAST('[1,2,3]' AS VECTOR);").Check(testkit.Rows("[1,2,3]")) + tk.MustQuery("SELECT CAST('[]' AS VECTOR);").Check(testkit.Rows("[]")) + tk.MustQuery("SELECT CAST('[1,2,3]' AS VECTOR);").Check(testkit.Rows("[1,2,3]")) + tk.MustContainErrMsg("SELECT CAST('[1,2,3]' AS VECTOR);", "Only VECTOR is supported for now") + + tk.MustQuery("SELECT CAST('[1,2,3]' AS VECTOR(3));").Check(testkit.Rows("[1,2,3]")) + err := tk.QueryToErr("SELECT CAST('[1,2,3]' AS VECTOR(2));") + require.EqualError(t, err, "vector has 3 dimensions, does not fit VECTOR(2)") + + tk.MustQuery("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS VECTOR(3));").Check(testkit.Rows("[1,2,3]")) + err = tk.QueryToErr("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS VECTOR(2));") + require.EqualError(t, err, "vector has 3 dimensions, does not fit VECTOR(2)") + + // CONVERT + tk.MustQuery("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), BINARY);").Check(testkit.Rows("[1,2,3]")) + tk.MustQuery("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), CHAR);").Check(testkit.Rows("[1,2,3]")) + tk.MustQueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), JSON);") + tk.MustQueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), DECIMAL);") + tk.MustQueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), DOUBLE);") + tk.MustQueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), FLOAT);") + tk.MustQueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), REAL);") + tk.MustQueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), SIGNED);") + tk.MustQueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), UNSIGNED);") + tk.MustQueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), YEAR);") + tk.MustQueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), DATETIME);") + tk.MustQueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), DATE);") + tk.MustQueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), TIME);") + + tk.MustQuery("SELECT CONVERT('[1,2,3]', VECTOR);").Check(testkit.Rows("[1,2,3]")) + tk.MustQuery("SELECT CONVERT('[]', VECTOR);").Check(testkit.Rows("[]")) + tk.MustQuery("SELECT CONVERT('[1,2,3]', VECTOR);").Check(testkit.Rows("[1,2,3]")) + tk.MustContainErrMsg("SELECT CONVERT('[1,2,3]', VECTOR);", "Only VECTOR is supported for now") + + tk.MustQuery("SELECT CONVERT('[1,2,3]', VECTOR(3));").Check(testkit.Rows("[1,2,3]")) + err = tk.QueryToErr("SELECT CONVERT('[1,2,3]', VECTOR(2));") + require.EqualError(t, err, "vector has 3 dimensions, does not fit VECTOR(2)") + + tk.MustQuery("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), VECTOR(3));").Check(testkit.Rows("[1,2,3]")) + err = tk.QueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), VECTOR(2));") + require.EqualError(t, err, "vector has 3 dimensions, does not fit VECTOR(2)") +} + +func TestVectorAssignVariable(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=1;") + tk.MustExec(`SET @a = VEC_FROM_TEXT('[1,2,3]');`) + tk.MustQuery(`SELECT @a;`).Check(testkit.Rows("[1,2,3]")) +} + +func TestVectorControlFlow(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=1;") + + // IF + tk.MustQuery("SELECT IF(VEC_FROM_TEXT('[1, 2, 3]'), 1, 0);").Check(testkit.Rows("1")) + tk.MustQuery("SELECT IF(TRUE, VEC_FROM_TEXT('[1, 2, 3]'), VEC_FROM_TEXT('[4, 5, 6]'));").Check(testkit.Rows("[1,2,3]")) - tk.MustExec("CREATE TABLE c(pk INT)") + // IFNULL + tk.MustQuery("SELECT IFNULL(VEC_FROM_TEXT('[1, 2, 3]'), 1);").Check(testkit.Rows("[1,2,3]")) + tk.MustQuery("SELECT IFNULL(NULL, VEC_FROM_TEXT('[1, 2, 3]'));").Check(testkit.Rows("[1,2,3]")) - err = tk.ExecToErr("ALTER TABLE c ADD COLUMN a VECTOR") - require.ErrorContains(t, err, "vector type is not supported") - err = tk.ExecToErr("ALTER TABLE c MODIFY pk VECTOR") - require.ErrorContains(t, err, "vector type is not supported") + // NULLIF + tk.MustQuery("SELECT NULLIF(VEC_FROM_TEXT('[1, 2, 3]'), VEC_FROM_TEXT('[1, 2, 3]'));").Check(testkit.Rows("")) + tk.MustQuery("SELECT NULLIF(VEC_FROM_TEXT('[1, 2, 3]'), VEC_FROM_TEXT('[4, 5, 6]'));").Check(testkit.Rows("[1,2,3]")) - tk.MustExec("DROP TABLE c") + // CASE WHEN + tk.MustQuery("SELECT CASE WHEN TRUE THEN VEC_FROM_TEXT('[1, 2, 3]') ELSE VEC_FROM_TEXT('[4, 5, 6]') END;").Check(testkit.Rows("[1,2,3]")) +} + +func TestVectorStringCompare(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=1;") + + tk.MustExec("DROP TABLE IF EXISTS t1;") + tk.MustExec("CREATE TABLE t1 (val vector);") + tk.MustExec("INSERT INTO t1 VALUES ('[1,2,3]'), ('[4,5,6]');") + + // LIKE + tk.MustQuery("SELECT * FROM t1 WHERE val LIKE '%2%';").Check(testkit.Rows("[1,2,3]")) + + // ILIKE + tk.MustQuery("SELECT * FROM t1 WHERE val ILIKE '%2%';").Check(testkit.Rows("[1,2,3]")) + + // STRCMP + tk.MustQuery("SELECT STRCMP('[1,2,3]', VEC_FROM_TEXT('[1,2,3]'));").Check(testkit.Rows("0")) + tk.MustQuery("SELECT STRCMP('[4,5,6]', VEC_FROM_TEXT('[1,2,3]'));").Check(testkit.Rows("1")) +} + +func TestVectorAggregations(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=1;") + tk.MustExec(`CREATE TABLE t(val VECTOR);`) + tk.MustExec(`INSERT INTO t VALUES + ('[8.7, 5.7, 7.7, 9.8, 1.5]'), + ('[3.6, 9.7, 2.4, 6.6, 4.9]'), + ('[4.7, 4.9, 2.6, 5.2, 7.4]'), + ('[4.7, 4.9, 2.6, 5.2, 7.4]'), + ('[7.7, 6.7, 8.3, 7.8, 5.7]'), + ('[1.4, 4.5, 8.5, 7.7, 6.2]'); + `) + tk.MustExec(`ANALYZE TABLE t;`) + + tk.MustQuery(`SELECT COUNT(*), val FROM t GROUP BY val ORDER BY val`).Check(testkit.Rows( + "1 [1.4,4.5,8.5,7.7,6.2]", + "1 [3.6,9.7,2.4,6.6,4.9]", + "2 [4.7,4.9,2.6,5.2,7.4]", + "1 [7.7,6.7,8.3,7.8,5.7]", + "1 [8.7,5.7,7.7,9.8,1.5]", + )) + tk.MustQuery(`SELECT COUNT(val) FROM t`).Check(testkit.Rows("6")) + tk.MustQuery(`SELECT COUNT(DISTINCT val) FROM t`).Check(testkit.Rows("5")) + tk.MustQuery(`SELECT MIN(val) FROM t`).Check(testkit.Rows("[1.4,4.5,8.5,7.7,6.2]")) + tk.MustQuery(`SELECT MAX(val) FROM t`).Check(testkit.Rows("[8.7,5.7,7.7,9.8,1.5]")) + tk.MustQueryToErr(`SELECT SUM(val) FROM t`) + tk.MustQueryToErr(`SELECT AVG(val) FROM t`) + tk.MustQuery(`SELECT val FROM t GROUP BY val HAVING val > VEC_FROM_TEXT('[4.7,4.9,2.6,5.2,7.4]') ORDER BY val`).Check(testkit.Rows( + "[7.7,6.7,8.3,7.8,5.7]", + "[8.7,5.7,7.7,9.8,1.5]", + )) +} + +func TestVectorWindow(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=1;") + tk.MustExec(`DROP TABLE IF EXISTS t;`) + tk.MustExec(`CREATE TABLE t (embedding VECTOR);`) + tk.MustExec(`INSERT INTO t VALUES + ('[1, 2, 3]'), + ('[4, 5, 601]'), + ('[4, 5, 61]'); + `) + + tk.MustQuery(`SELECT embedding, FIRST_VALUE(embedding) OVER w AS first, NTH_VALUE(embedding, 2) OVER w AS second, LAST_VALUE(embedding) OVER w AS last + FROM t WINDOW w AS (ORDER BY embedding) ORDER BY embedding;`).Check(testkit.Rows( + "[1,2,3] [1,2,3] [1,2,3]", + "[4,5,61] [1,2,3] [4,5,61] [4,5,61]", + "[4,5,601] [1,2,3] [4,5,61] [4,5,601]", + )) + + tk.MustExec(`DELETE FROM t WHERE 1 = 1`) + tk.MustExec(`INSERT INTO t VALUES + ('[1, 2, 3]'), + ('[4, 5, 6]'), + ('[4, 5, 6]'), + ('[7, 8, 9]'); + `) + + tk.MustQuery(`SELECT embedding, ROW_NUMBER() OVER w AS 'row_num', RANK() OVER w AS 'rank', DENSE_RANK() OVER w AS 'dense_rank' + FROM t WINDOW w AS (ORDER BY embedding) ORDER BY embedding;`).Check(testkit.Rows( + "[1,2,3] 1 1 1", + "[4,5,6] 2 2 2", + "[4,5,6] 3 2 2", + "[7,8,9] 4 4 3", + )) + + tk.MustQuery(`SELECT embedding, LAG(embedding) OVER w AS 'lag', LEAD(embedding) OVER w AS 'lead' + FROM t WINDOW w AS (ORDER BY embedding) ORDER BY embedding;`).Check(testkit.Rows( + "[1,2,3] [4,5,6]", + "[4,5,6] [1,2,3] [4,5,6]", + "[4,5,6] [4,5,6] [7,8,9]", + "[7,8,9] [4,5,6] ", + )) + + tk.MustQuery(`SELECT embedding, ROW_NUMBER() OVER (PARTITION BY embedding ORDER BY embedding) AS 'row_num' + FROM t ORDER BY embedding;`).Check(testkit.Rows( + "[1,2,3] 1", + "[4,5,6] 1", + "[4,5,6] 2", + "[7,8,9] 1", + )) +} + +func TestVectorSetOperation(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=1;") + tk.MustExec(`DROP TABLE IF EXISTS t1;`) + tk.MustExec(`CREATE TABLE t1 (embedding VECTOR);`) + tk.MustExec(`INSERT INTO t1 VALUES + ('[1, 2, 3]'), + ('[4, 5, 6]'); + `) + + tk.MustExec(`DROP TABLE IF EXISTS t2;`) + tk.MustExec(`CREATE TABLE t2 (embedding VECTOR);`) + tk.MustExec(`INSERT INTO t2 VALUES + ('[4, 5, 6]'), + ('[7, 8, 9]'); + `) + + tk.MustQuery(`(SELECT embedding FROM t1 UNION SELECT embedding FROM t2) ORDER BY embedding;`).Check(testkit.Rows( + "[1,2,3]", + "[4,5,6]", + "[7,8,9]", + )) + + tk.MustQuery(`(SELECT embedding FROM t1 UNION ALL SELECT embedding FROM t2) ORDER BY embedding;`).Check(testkit.Rows( + "[1,2,3]", + "[4,5,6]", + "[4,5,6]", + "[7,8,9]", + )) + + tk.MustQuery(`SELECT embedding FROM t1 INTERSECT SELECT embedding FROM t2;`).Check(testkit.Rows( + "[4,5,6]", + )) + + tk.MustQuery(`SELECT embedding FROM t1 EXCEPT SELECT embedding FROM t2;`).Check(testkit.Rows( + "[1,2,3]", + )) +} + +func TestVectorArithmatic(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + tk.MustExec("SET @@GLOBAL.TIDB_ENABLE_VECTOR_TYPE=1;") + tk.MustExec(`CREATE TABLE t(embedding VECTOR);`) + tk.MustExec(`INSERT INTO t VALUES + ('[1, 2, 3]'), + ('[4, 5, 6]'), + ('[7, 8, 9]'); + `) + tk.MustQuery(`SELECT embedding + '[1, 2, 3]' FROM t;`).Check(testkit.Rows("[2,4,6]", "[5,7,9]", "[8,10,12]")) + tk.MustQuery(`SELECT embedding + embedding FROM t;`).Check(testkit.Rows("[2,4,6]", "[8,10,12]", "[14,16,18]")) + tk.MustQueryToErr(`SELECT embedding + 1 FROM t;`) + tk.MustQueryToErr(`SELECT embedding + '[]' FROM t;`) + tk.MustQuery(`SELECT embedding - '[1, 2, 3]' FROM t;`).Check(testkit.Rows("[0,0,0]", "[3,3,3]", "[6,6,6]")) + tk.MustQuery(`SELECT embedding - embedding FROM t;`).Check(testkit.Rows("[0,0,0]", "[0,0,0]", "[0,0,0]")) + tk.MustQueryToErr(`SELECT embedding - '[1]' FROM t;`) + + tk.MustQuery(`SELECT VEC_FROM_TEXT('[1,2]') + VEC_FROM_TEXT('[2,3]');`).Check(testkit.Rows("[3,5]")) + tk.MustQuery(`SELECT VEC_FROM_TEXT('[1,2]') + '[2,3]';`).Check(testkit.Rows("[3,5]")) + tk.MustQueryToErr(`SELECT VEC_FROM_TEXT('[1,2]') + '[2,3,4]';`) + tk.MustQueryToErr(`SELECT VEC_FROM_TEXT('[1]') + 2;`) + tk.MustQueryToErr(`SELECT VEC_FROM_TEXT('[1]') + '2';`) + + tk.MustQueryToErr(`SELECT VEC_FROM_TEXT('[3e38]') + '[3e38]';`) + tk.MustQuery(`SELECT VEC_FROM_TEXT('[1,2,3]') * '[4,5,6]';`).Check(testkit.Rows("[4,10,18]")) + tk.MustQueryToErr(`SELECT VEC_FROM_TEXT('[1e37]') * '[1e37]';`) +} + +func TestVectorFunctions(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + + tk.MustQuery(`SELECT VEC_L1_DISTANCE('[0,0]', '[3,4]');`).Check(testkit.Rows("7")) + tk.MustQuery(`SELECT VEC_L1_DISTANCE('[0,0]', '[0,1]');`).Check(testkit.Rows("1")) + tk.MustQueryToErr("SELECT VEC_L1_DISTANCE('[1,2]', '[3]');") + tk.MustQuery(`SELECT VEC_L1_DISTANCE('[3e38]', '[-3e38]');`).Check(testkit.Rows("+Inf")) + + tk.MustQuery(`SELECT VEC_L2_DISTANCE('[0,0]', '[3,4]');`).Check(testkit.Rows("5")) + tk.MustQuery(`SELECT VEC_L2_DISTANCE('[0,0]', '[0,1]');`).Check(testkit.Rows("1")) + tk.MustQueryToErr(`SELECT VEC_L2_DISTANCE('[1,2]', '[3]');`) + tk.MustQuery(`SELECT VEC_L2_DISTANCE('[3e38]', '[-3e38]');`).Check(testkit.Rows("+Inf")) + + tk.MustQuery(`SELECT VEC_NEGATIVE_INNER_PRODUCT('[1,2]', '[3,4]');`).Check(testkit.Rows("-11")) + tk.MustQueryToErr(`SELECT VEC_NEGATIVE_INNER_PRODUCT('[1,2]', '[3]');`) + tk.MustQuery(`SELECT VEC_NEGATIVE_INNER_PRODUCT('[3e38]', '[3e38]');`).Check(testkit.Rows("-Inf")) + + tk.MustQuery(`SELECT VEC_COSINE_DISTANCE('[1,2]', '[2,4]');`).Check(testkit.Rows("0")) + tk.MustQuery(`SELECT VEC_COSINE_DISTANCE('[1,2]', '[0,0]');`).Check(testkit.Rows("")) + tk.MustQuery(`SELECT VEC_COSINE_DISTANCE('[1,1]', '[1,1]');`).Check(testkit.Rows("0")) + tk.MustQuery(`SELECT VEC_COSINE_DISTANCE('[1,0]', '[0,2]');`).Check(testkit.Rows("1")) + tk.MustQuery(`SELECT VEC_COSINE_DISTANCE('[1,1]', '[-1,-1]');`).Check(testkit.Rows("2")) + tk.MustQueryToErr(`SELECT VEC_COSINE_DISTANCE('[1,2]', '[3]');`) + tk.MustQuery(`SELECT VEC_COSINE_DISTANCE('[1,1]', '[1.1,1.1]');`).Check(testkit.Rows("0")) + tk.MustQuery(`SELECT VEC_COSINE_DISTANCE('[1,1]', '[-1.1,-1.1]');`).Check(testkit.Rows("2")) + tk.MustQuery(`SELECT VEC_COSINE_DISTANCE('[3e38]', '[3e38]');`).Check(testkit.Rows("")) + + tk.MustQuery(`SELECT VEC_L2_NORM('[3,4]');`).Check(testkit.Rows("5")) + tk.MustQuery(`SELECT VEC_L2_NORM('[0,1]');`).Check(testkit.Rows("1")) } func TestGetLock(t *testing.T) { diff --git a/pkg/expression/scalar_function.go b/pkg/expression/scalar_function.go index 02953af79f936..53cd4366a73be 100644 --- a/pkg/expression/scalar_function.go +++ b/pkg/expression/scalar_function.go @@ -109,6 +109,11 @@ func (sf *ScalarFunction) VecEvalJSON(ctx EvalContext, input *chunk.Chunk, resul return sf.Function.vecEvalJSON(ctx, input, result) } +// VecEvalVectorFloat32 evaluates this expression in a vectorized manner. +func (sf *ScalarFunction) VecEvalVectorFloat32(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error { + return sf.Function.vecEvalVectorFloat32(ctx, input, result) +} + // GetArgs gets arguments of function. func (sf *ScalarFunction) GetArgs() []Expression { return sf.Function.getArgs() @@ -147,6 +152,29 @@ func (sf *ScalarFunction) String() string { return sf.StringWithCtx(exprctx.EmptyParamValues, errors.RedactLogDisable) } +// StringForExplain implements Explainable interface. +func (sf *ScalarFunction) StringForExplain(ctx ParamValues, redact string) string { + var buffer bytes.Buffer + fmt.Fprintf(&buffer, "%s(", sf.FuncName.L) + switch sf.FuncName.L { + case ast.Cast: + for _, arg := range sf.GetArgs() { + buffer.WriteString(arg.StringForExplain(ctx, redact)) + buffer.WriteString(", ") + buffer.WriteString(sf.RetType.String()) + } + default: + for i, arg := range sf.GetArgs() { + buffer.WriteString(arg.StringForExplain(ctx, redact)) + if i+1 != len(sf.GetArgs()) { + buffer.WriteString(", ") + } + } + } + buffer.WriteString(")") + return buffer.String() +} + // typeInferForNull infers the NULL constants field type and set the field type // of NULL constant same as other non-null operands. func typeInferForNull(ctx EvalContext, args []Expression) { @@ -448,6 +476,8 @@ func (sf *ScalarFunction) Eval(ctx EvalContext, row chunk.Row) (d types.Datum, e res, isNull, err = sf.EvalDuration(ctx, row) case types.ETJson: res, isNull, err = sf.EvalJSON(ctx, row) + case types.ETVectorFloat32: + res, isNull, err = sf.EvalVectorFloat32(ctx, row) case types.ETString: var str string str, isNull, err = sf.EvalString(ctx, row) @@ -531,6 +561,11 @@ func (sf *ScalarFunction) EvalJSON(ctx EvalContext, row chunk.Row) (types.Binary return sf.Function.evalJSON(ctx, row) } +// EvalVectorFloat32 implements Expression interface. +func (sf *ScalarFunction) EvalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) { + return sf.Function.evalVectorFloat32(ctx, row) +} + // HashCode implements Expression interface. func (sf *ScalarFunction) HashCode() []byte { if len(sf.hashcode) > 0 { diff --git a/pkg/expression/util_test.go b/pkg/expression/util_test.go index 7f8c305ea000e..50c8074bb4952 100644 --- a/pkg/expression/util_test.go +++ b/pkg/expression/util_test.go @@ -510,6 +510,9 @@ func BenchmarkExtractColumns(b *testing.B) { } b.ReportAllocs() } +func (m *MockExpr) VecEvalVectorFloat32(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error { + return nil +} func BenchmarkExprFromSchema(b *testing.B) { conditions := []Expression{ @@ -558,7 +561,8 @@ func (m *MockExpr) VecEvalJSON(ctx EvalContext, input *chunk.Chunk, result *chun return nil } -func (m *MockExpr) StringWithCtx(ParamValues, string) string { return "" } +func (m *MockExpr) StringWithCtx(ParamValues, string) string { return "" } +func (m *MockExpr) StringForExplain(ParamValues, string) string { return "" } func (m *MockExpr) Eval(ctx EvalContext, row chunk.Row) (types.Datum, error) { return types.NewDatum(m.i), m.err } @@ -604,6 +608,13 @@ func (m *MockExpr) EvalJSON(ctx EvalContext, row chunk.Row) (val types.BinaryJSO } return types.BinaryJSON{}, m.i == nil, m.err } +func (m *MockExpr) EvalVectorFloat32(ctx EvalContext, row chunk.Row) (val types.VectorFloat32, isNull bool, err error) { + if x, ok := m.i.(types.VectorFloat32); ok { + return x, false, m.err + } + return types.ZeroVectorFloat32, m.i == nil, m.err +} + func (m *MockExpr) GetType(_ EvalContext) *types.FieldType { return m.t } func (m *MockExpr) Clone() Expression { return nil } func (m *MockExpr) Equal(ctx EvalContext, e Expression) bool { return false } diff --git a/pkg/expression/vectorized.go b/pkg/expression/vectorized.go index 8232b64873cc4..eeaa58a6c2d9e 100644 --- a/pkg/expression/vectorized.go +++ b/pkg/expression/vectorized.go @@ -115,6 +115,21 @@ func genVecFromConstExpr(ctx EvalContext, expr Expression, targetType types.Eval result.AppendJSON(v) } } + case types.ETVectorFloat32: + result.ReserveVectorFloat32(n) + v, isNull, err := expr.EvalVectorFloat32(ctx, chunk.Row{}) + if err != nil { + return err + } + if isNull { + for i := 0; i < n; i++ { + result.AppendNull() + } + } else { + for i := 0; i < n; i++ { + result.AppendVectorFloat32(v) + } + } case types.ETString: result.ReserveString(n) v, isNull, err := expr.EvalString(ctx, chunk.Row{}) @@ -131,7 +146,7 @@ func genVecFromConstExpr(ctx EvalContext, expr Expression, targetType types.Eval } } default: - return errors.Errorf("unsupported Constant type for vectorized evaluation") + return errors.Errorf("unsupported type %s during evaluation", targetType) } return nil } diff --git a/pkg/kv/checker.go b/pkg/kv/checker.go index 80115927fa41f..007b967f46c5d 100644 --- a/pkg/kv/checker.go +++ b/pkg/kv/checker.go @@ -47,7 +47,8 @@ func (RequestTypeSupportedChecker) supportExpr(exprType tipb.ExprType) bool { switch exprType { case tipb.ExprType_Null, tipb.ExprType_Int64, tipb.ExprType_Uint64, tipb.ExprType_String, tipb.ExprType_Bytes, tipb.ExprType_MysqlDuration, tipb.ExprType_MysqlTime, tipb.ExprType_MysqlDecimal, - tipb.ExprType_Float32, tipb.ExprType_Float64, tipb.ExprType_ColumnRef, tipb.ExprType_MysqlEnum, tipb.ExprType_MysqlBit: + tipb.ExprType_Float32, tipb.ExprType_Float64, tipb.ExprType_ColumnRef, tipb.ExprType_MysqlEnum, tipb.ExprType_MysqlBit, + tipb.ExprType_TiDBVectorFloat32: return true // aggregate functions. // NOTE: tipb.ExprType_GroupConcat is only supported by TiFlash, So checking it for TiKV case outside. diff --git a/pkg/parser/ast/functions.go b/pkg/parser/ast/functions.go index b1ba7be0fc81e..173018e6babea 100644 --- a/pkg/parser/ast/functions.go +++ b/pkg/parser/ast/functions.go @@ -356,6 +356,16 @@ const ( JSONKeys = "json_keys" JSONLength = "json_length" + // vector functions (tidb extension) + VecDims = "vec_dims" + VecL1Distance = "vec_l1_distance" + VecL2Distance = "vec_l2_distance" + VecNegativeInnerProduct = "vec_negative_inner_product" + VecCosineDistance = "vec_cosine_distance" + VecL2Norm = "vec_l2_norm" + VecFromText = "vec_from_text" + VecAsText = "vec_as_text" + // TiDB internal function. TiDBDecodeKey = "tidb_decode_key" TiDBDecodeBase64Key = "tidb_decode_base64_key" diff --git a/pkg/parser/parser_test.go b/pkg/parser/parser_test.go index 0a6c475899699..f530318e3bec5 100644 --- a/pkg/parser/parser_test.go +++ b/pkg/parser/parser_test.go @@ -635,6 +635,11 @@ func TestDMLStmt(t *testing.T) { {"CREATE VIEW v AS (TABLE t)", true, "CREATE ALGORITHM = UNDEFINED DEFINER = CURRENT_USER SQL SECURITY DEFINER VIEW `v` AS (TABLE `t`)"}, {"SELECT * FROM t1 WHERE a IN (TABLE t2)", true, "SELECT * FROM `t1` WHERE `a` IN (TABLE `t2`)"}, + // vector type + {"CREATE TABLE foo (v VECTOR)", true, "CREATE TABLE `foo` (`v` VECTOR)"}, + {"CREATE TABLE foo (v VECTOR)", true, "CREATE TABLE `foo` (`v` VECTOR)"}, + {"CREATE TABLE foo (v VECTOR)", false, ""}, + // values statement {"VALUES ROW(1)", true, "VALUES ROW(1)"}, {"VALUES ROW()", true, "VALUES ROW()"}, @@ -7556,10 +7561,8 @@ func TestCompatTypes(t *testing.T) { func TestVector(t *testing.T) { table := []testCase{ - {"CREATE TABLE t (a VECTOR)", true, "CREATE TABLE `t` (`a` VECTOR)"}, - {"CREATE TABLE t (a VECTOR)", true, "CREATE TABLE `t` (`a` VECTOR)"}, - {"CREATE TABLE t (a VECTOR(3))", true, "CREATE TABLE `t` (`a` VECTOR(3))"}, - {"CREATE TABLE t (a VECTOR(3))", true, "CREATE TABLE `t` (`a` VECTOR(3))"}, + {"CREATE TABLE t (a VECTOR)", true, "CREATE TABLE `t` (`a` VECTOR)"}, + {"CREATE TABLE t (a VECTOR)", true, "CREATE TABLE `t` (`a` VECTOR)"}, {"CREATE TABLE t (a VECTOR)", false, ""}, {"CREATE TABLE t (a VECTOR)", false, ""}, {"CREATE TABLE t (a VECTOR)", false, ""}, diff --git a/pkg/parser/types/etc.go b/pkg/parser/types/etc.go index 74c89ac0d1185..493833823ab1e 100644 --- a/pkg/parser/types/etc.go +++ b/pkg/parser/types/etc.go @@ -56,7 +56,7 @@ var type2Str = map[byte]string{ mysql.TypeEnum: "enum", mysql.TypeFloat: "float", mysql.TypeGeometry: "geometry", - mysql.TypeTiDBVectorFloat32: "vector", + mysql.TypeTiDBVectorFloat32: "vector", mysql.TypeInt24: "mediumint", mysql.TypeJSON: "json", mysql.TypeLong: "int", @@ -77,34 +77,34 @@ var type2Str = map[byte]string{ } var str2Type = map[string]byte{ - "bit": mysql.TypeBit, - "text": mysql.TypeBlob, - "date": mysql.TypeDate, - "datetime": mysql.TypeDatetime, - "unspecified": mysql.TypeUnspecified, - "decimal": mysql.TypeNewDecimal, - "double": mysql.TypeDouble, - "enum": mysql.TypeEnum, - "float": mysql.TypeFloat, - "geometry": mysql.TypeGeometry, - "vector": mysql.TypeTiDBVectorFloat32, - "mediumint": mysql.TypeInt24, - "json": mysql.TypeJSON, - "int": mysql.TypeLong, - "bigint": mysql.TypeLonglong, - "longtext": mysql.TypeLongBlob, - "mediumtext": mysql.TypeMediumBlob, - "null": mysql.TypeNull, - "set": mysql.TypeSet, - "smallint": mysql.TypeShort, - "char": mysql.TypeString, - "time": mysql.TypeDuration, - "timestamp": mysql.TypeTimestamp, - "tinyint": mysql.TypeTiny, - "tinytext": mysql.TypeTinyBlob, - "varchar": mysql.TypeVarchar, - "var_string": mysql.TypeVarString, - "year": mysql.TypeYear, + "bit": mysql.TypeBit, + "text": mysql.TypeBlob, + "date": mysql.TypeDate, + "datetime": mysql.TypeDatetime, + "unspecified": mysql.TypeUnspecified, + "decimal": mysql.TypeNewDecimal, + "double": mysql.TypeDouble, + "enum": mysql.TypeEnum, + "float": mysql.TypeFloat, + "geometry": mysql.TypeGeometry, + "vector": mysql.TypeTiDBVectorFloat32, + "mediumint": mysql.TypeInt24, + "json": mysql.TypeJSON, + "int": mysql.TypeLong, + "bigint": mysql.TypeLonglong, + "longtext": mysql.TypeLongBlob, + "mediumtext": mysql.TypeMediumBlob, + "null": mysql.TypeNull, + "set": mysql.TypeSet, + "smallint": mysql.TypeShort, + "char": mysql.TypeString, + "time": mysql.TypeDuration, + "timestamp": mysql.TypeTimestamp, + "tinyint": mysql.TypeTiny, + "tinytext": mysql.TypeTinyBlob, + "varchar": mysql.TypeVarchar, + "var_string": mysql.TypeVarString, + "year": mysql.TypeYear, } // TypeStr converts tp to a string. diff --git a/pkg/parser/types/eval_type.go b/pkg/parser/types/eval_type.go index 47775953d97c5..40694c4551e06 100644 --- a/pkg/parser/types/eval_type.go +++ b/pkg/parser/types/eval_type.go @@ -13,6 +13,8 @@ package types +import "fmt" + // EvalType indicates the specified types that arguments and result of a built-in function should be. type EvalType byte @@ -33,10 +35,43 @@ const ( ETDuration // ETJson represents type JSON in evaluation. ETJson + // ETVectorFloat32 represents type VectorFloat32 in evaluation. + ETVectorFloat32 ) // IsStringKind returns true for ETString, ETDatetime, ETTimestamp, ETDuration, ETJson EvalTypes. func (et EvalType) IsStringKind() bool { return et == ETString || et == ETDatetime || - et == ETTimestamp || et == ETDuration || et == ETJson + et == ETTimestamp || et == ETDuration || et == ETJson || et == ETVectorFloat32 +} + +// IsVectorKind returns true for ETVectorXxx EvalTypes. +func (et EvalType) IsVectorKind() bool { + return et == ETVectorFloat32 +} + +// String implements fmt.Stringer interface. +func (et EvalType) String() string { + switch et { + case ETInt: + return "Int" + case ETReal: + return "Real" + case ETDecimal: + return "Decimal" + case ETString: + return "String" + case ETDatetime: + return "Datetime" + case ETTimestamp: + return "Timestamp" + case ETDuration: + return "Time" + case ETJson: + return "Json" + case ETVectorFloat32: + return "VectorFloat32" + default: + panic(fmt.Sprintf("invalid EvalType %d", et)) + } } diff --git a/pkg/parser/types/field_type.go b/pkg/parser/types/field_type.go index 827fa0e12b3a4..839e0ae0871ac 100644 --- a/pkg/parser/types/field_type.go +++ b/pkg/parser/types/field_type.go @@ -342,6 +342,8 @@ func (ft *FieldType) EvalType() EvalType { return ETDuration case mysql.TypeJSON: return ETJson + case mysql.TypeTiDBVectorFloat32: + return ETVectorFloat32 case mysql.TypeEnum, mysql.TypeSet: if ft.flag&mysql.EnumSetAsIntFlag > 0 { return ETInt @@ -586,7 +588,7 @@ func (ft *FieldType) RestoreAsCastType(ctx *format.RestoreCtx, explicitCharset b case mysql.TypeYear: ctx.WriteKeyWord("YEAR") case mysql.TypeTiDBVectorFloat32: - ctx.WriteKeyWord("VECTOR") + ctx.WriteKeyWord("VECTOR") } if ft.array { ctx.WritePlain(" ") diff --git a/pkg/planner/core/optimizer.go b/pkg/planner/core/optimizer.go index bea428e35c6eb..75c63c57ac3ee 100644 --- a/pkg/planner/core/optimizer.go +++ b/pkg/planner/core/optimizer.go @@ -1218,7 +1218,7 @@ func existsOverlongType(schema *expression.Schema) bool { for _, column := range schema.Columns { switch column.RetType.GetType() { case mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, - mysql.TypeBlob, mysql.TypeJSON: + mysql.TypeBlob, mysql.TypeJSON, mysql.TypeTiDBVectorFloat32: return true case mysql.TypeVarString, mysql.TypeVarchar: // if the column is varchar and the length of diff --git a/pkg/planner/core/preprocess.go b/pkg/planner/core/preprocess.go index 6cfaa397c8175..6341691bbf97c 100644 --- a/pkg/planner/core/preprocess.go +++ b/pkg/planner/core/preprocess.go @@ -1458,7 +1458,14 @@ func checkColumn(colDef *ast.ColumnDef) error { return types.ErrTooBigDisplayWidth.GenWithStackByArgs(colDef.Name.Name.O, mysql.MaxBitDisplayWidth) } case mysql.TypeTiDBVectorFloat32: - return errors.Errorf("vector type is not supported") + if !variable.EnableVectorType.Load() { + return errors.Errorf("vector type is not supported") + } + if tp.GetFlen() != types.UnspecifiedLength { + if err := types.CheckVectorDimValid(tp.GetFlen()); err != nil { + return err + } + } default: // TODO: Add more types. } @@ -1744,10 +1751,6 @@ func (p *preprocessor) checkFuncCastExpr(node *ast.FuncCastExpr) { return } } - if node.Tp.GetType() == mysql.TypeTiDBVectorFloat32 { - p.err = errors.Errorf("vector type is not supported") - return - } } func (p *preprocessor) updateStateFromStaleReadProcessor() error { diff --git a/pkg/server/internal/column/column.go b/pkg/server/internal/column/column.go index c69d41ff4a28b..b227ef7fe2197 100644 --- a/pkg/server/internal/column/column.go +++ b/pkg/server/internal/column/column.go @@ -183,6 +183,9 @@ func DumpTextRow(buffer []byte, columns []*Info, row chunk.Row, d *ResultEncoder // To compatible with MySQL, here we treat it as utf-8. d.UpdateDataEncoding(mysql.DefaultCollationID) buffer = dump.LengthEncodedString(buffer, d.EncodeData(hack.Slice(row.GetJSON(i).String()))) + case mysql.TypeTiDBVectorFloat32: + d.UpdateDataEncoding(mysql.DefaultCollationID) + buffer = dump.LengthEncodedString(buffer, d.EncodeData(hack.Slice(row.GetVectorFloat32(i).String()))) default: return nil, err.ErrInvalidType.GenWithStack("invalid type %v", columns[i].Type) } @@ -242,6 +245,9 @@ func DumpBinaryRow(buffer []byte, columns []*Info, row chunk.Row, d *ResultEncod // To compatible with MySQL, here we treat it as utf-8. d.UpdateDataEncoding(mysql.DefaultCollationID) buffer = dump.LengthEncodedString(buffer, d.EncodeData(hack.Slice(row.GetJSON(i).String()))) + case mysql.TypeTiDBVectorFloat32: + d.UpdateDataEncoding(mysql.DefaultCollationID) + buffer = dump.LengthEncodedString(buffer, d.EncodeData(hack.Slice(row.GetVectorFloat32(i).String()))) default: return nil, err.ErrInvalidType.GenWithStack("invalid type %v", columns[i].Type) } diff --git a/pkg/sessionctx/variable/sysvar.go b/pkg/sessionctx/variable/sysvar.go index 4e79ab3ab2325..d4b29006bdf48 100644 --- a/pkg/sessionctx/variable/sysvar.go +++ b/pkg/sessionctx/variable/sysvar.go @@ -3256,6 +3256,12 @@ var defaultSysVars = []*SysVar{ }, IsHintUpdatableVerified: true, }, + {Scope: ScopeGlobal, Name: TiDBEnableVectorType, Value: BoolToOnOff(DefTiDBEnableVectorType), Type: TypeBool, SetGlobal: func(ctx context.Context, vars *SessionVars, s string) error { + EnableVectorType.Store(TiDBOptOn(s)) + return nil + }, GetGlobal: func(ctx context.Context, vars *SessionVars) (string, error) { + return BoolToOnOff(EnableVectorType.Load()), nil + }}, {Scope: ScopeGlobal | ScopeSession, Name: TiDBEnableLazyCursorFetch, Value: BoolToOnOff(DefTiDBEnableLazyCursorFetch), Type: TypeBool, SetSession: func(s *SessionVars, val string) error { s.EnableLazyCursorFetch = TiDBOptOn(val) return nil diff --git a/pkg/sessionctx/variable/tidb_vars.go b/pkg/sessionctx/variable/tidb_vars.go index 2dc5792824058..a2b5aedaed0b1 100644 --- a/pkg/sessionctx/variable/tidb_vars.go +++ b/pkg/sessionctx/variable/tidb_vars.go @@ -1096,6 +1096,8 @@ const ( // TiDBTTLRunningTasks limits the count of running ttl tasks. Default to 0, means 3 times the count of TiKV (or no // limitation, if the storage is not TiKV). TiDBTTLRunningTasks = "tidb_ttl_running_tasks" + // TiDBEnableVectorType indicates whether to enable VECTOR data type. + TiDBEnableVectorType = "tidb_enable_vector_type" // AuthenticationLDAPSASLAuthMethodName defines the authentication method used by LDAP SASL authentication plugin AuthenticationLDAPSASLAuthMethodName = "authentication_ldap_sasl_auth_method_name" // AuthenticationLDAPSASLCAPath defines the ca certificate to verify LDAP connection in LDAP SASL authentication plugin @@ -1493,6 +1495,7 @@ const ( DefRuntimeFilterType = "IN" DefRuntimeFilterMode = "OFF" DefTiDBLockUnchangedKeys = true + DefTiDBEnableVectorType = false DefTiDBEnableCheckConstraint = false DefTiDBSkipMissingPartitionStats = true DefTiDBOptEnableHashJoin = true @@ -1613,6 +1616,7 @@ var ( // always set the default value to false because the resource control in kv-client is not inited // It will be initialized to the right value after the first call of `rebuildSysVarCache` EnableResourceControl = atomic.NewBool(false) + EnableVectorType = atomic.NewBool(false) EnableResourceControlStrictMode = atomic.NewBool(true) EnableCheckConstraint = atomic.NewBool(DefTiDBEnableCheckConstraint) SkipMissingPartitionStats = atomic.NewBool(DefTiDBSkipMissingPartitionStats) diff --git a/pkg/table/column.go b/pkg/table/column.go index a29be800c84a0..46ebdf71d5435 100644 --- a/pkg/table/column.go +++ b/pkg/table/column.go @@ -732,6 +732,8 @@ func GetZeroValue(col *model.ColumnInfo) types.Datum { d.SetMysqlEnum(types.Enum{}, col.GetCollate()) case mysql.TypeJSON: d.SetMysqlJSON(types.CreateBinaryJSON(nil)) + case mysql.TypeTiDBVectorFloat32: + d.SetVectorFloat32(types.ZeroVectorFloat32) } return d } diff --git a/pkg/testkit/testkit.go b/pkg/testkit/testkit.go index a7dfd85443fc0..1517d7cb7f710 100644 --- a/pkg/testkit/testkit.go +++ b/pkg/testkit/testkit.go @@ -190,6 +190,12 @@ func (tk *TestKit) EventuallyMustQueryAndCheck(sql string, args []any, }, waitFor, tick) } +// MustQueryToErr query the sql statement and must return Error. +func (tk *TestKit) MustQueryToErr(sql string, args ...any) { + err := tk.QueryToErr(sql, args...) + tk.require.Error(err) +} + // MustQueryWithContext query the statements and returns result rows. func (tk *TestKit) MustQueryWithContext(ctx context.Context, sql string, args ...any) *Result { comment := fmt.Sprintf("sql:%s, args:%v", sql, args) diff --git a/pkg/types/BUILD.bazel b/pkg/types/BUILD.bazel index bd29720a792d6..f7bad15da90c9 100644 --- a/pkg/types/BUILD.bazel +++ b/pkg/types/BUILD.bazel @@ -37,6 +37,8 @@ go_library( "set.go", "time.go", "truncate.go", + "vector.go", + "vector_functions.go", ], importpath = "github.com/pingcap/tidb/pkg/types", visibility = [ @@ -62,6 +64,7 @@ go_library( "//pkg/util/parser", "//pkg/util/size", "//pkg/util/stringutil", + "@com_github_json_iterator_go//:go", "@com_github_pingcap_errors//:errors", "@com_github_pingcap_log//:log", "@org_uber_go_zap//:zap", @@ -97,6 +100,7 @@ go_test( "overflow_test.go", "set_test.go", "time_test.go", + "vector_test.go", ], embed = [":types"], flaky = True, diff --git a/pkg/types/datum.go b/pkg/types/datum.go index 0da830536c34b..9fec3b2be2d41 100644 --- a/pkg/types/datum.go +++ b/pkg/types/datum.go @@ -59,6 +59,7 @@ const ( KindMaxValue byte = 16 KindRaw byte = 17 KindMysqlJSON byte = 18 + KindVectorFloat32 byte = 19 ) // Datum is a data box holds different kind of data. @@ -405,6 +406,21 @@ func (d *Datum) SetMysqlJSON(b BinaryJSON) { d.b = b.Value } +// SetVectorFloat32 sets VectorFloat32 value +func (d *Datum) SetVectorFloat32(vec VectorFloat32) { + d.k = KindVectorFloat32 + d.b = vec.ZeroCopySerialize() +} + +// GetVectorFloat32 gets VectorFloat32 value +func (d *Datum) GetVectorFloat32() VectorFloat32 { + v, _, err := ZeroCopyDeserializeVectorFloat32(d.b) + if err != nil { + panic(err) + } + return v +} + // GetMysqlTime gets types.Time value func (d *Datum) GetMysqlTime() Time { return d.x.(Time) @@ -479,6 +495,8 @@ func (d Datum) String() string { t = "KindRaw" case KindMysqlJSON: t = "KindMysqlJSON" + case KindVectorFloat32: + t = "KindVectorFloat32" default: t = "Unknown" } @@ -526,6 +544,8 @@ func (d *Datum) GetValue() any { return d.GetMysqlJSON() case KindMysqlTime: return d.GetMysqlTime() + case KindVectorFloat32: + return d.GetVectorFloat32() default: return d.GetInterface() } @@ -574,6 +594,8 @@ func (d *Datum) SetValueWithDefaultCollation(val any) { d.SetMysqlJSON(x) case Time: d.SetMysqlTime(x) + case VectorFloat32: + d.SetVectorFloat32(x) default: d.SetInterface(x) } @@ -622,6 +644,8 @@ func (d *Datum) SetValue(val any, tp *types.FieldType) { d.SetMysqlJSON(x) case Time: d.SetMysqlTime(x) + case VectorFloat32: + d.SetVectorFloat32(x) default: d.SetInterface(x) } @@ -676,6 +700,8 @@ func (d *Datum) Compare(ctx Context, ad *Datum, comparer collate.Collator) (int, return d.compareMysqlJSON(ad.GetMysqlJSON()) case KindMysqlTime: return d.compareMysqlTime(ctx, ad.GetMysqlTime()) + case KindVectorFloat32: + return d.compareVectorFloat32(ctx, ad.GetVectorFloat32()) default: return 0, nil } @@ -901,6 +927,20 @@ func (d *Datum) compareMysqlTime(ctx Context, time Time) (int, error) { } } +func (d *Datum) compareVectorFloat32(ctx Context, vec VectorFloat32) (int, error) { + switch d.k { + case KindNull, KindMinNotNull: + return -1, nil + case KindMaxValue: + return 1, nil + case KindVectorFloat32: + return d.GetVectorFloat32().Compare(vec), nil + // Note: We expect cast is applied before compare, when comparing with String and other vector types. + default: + return 0, errors.New("cannot compare vector and non-vector, cast is required") + } +} + // ConvertTo converts a datum to the target field type. // change this method need sync modification to type2Kind in rowcodec/types.go func (d *Datum) ConvertTo(ctx Context, target *FieldType) (Datum, error) { @@ -937,6 +977,8 @@ func (d *Datum) ConvertTo(ctx Context, target *FieldType) (Datum, error) { return d.convertToMysqlSet(ctx, target) case mysql.TypeJSON: return d.convertToMysqlJSON(target) + case mysql.TypeTiDBVectorFloat32: + return d.convertToVectorFloat32(ctx, target) case mysql.TypeNull: return Datum{}, nil default: @@ -1073,6 +1115,8 @@ func (d *Datum) convertToString(ctx Context, target *FieldType) (Datum, error) { } case KindMysqlJSON: s = d.GetMysqlJSON().String() + case KindVectorFloat32: + s = d.GetVectorFloat32().String() default: return invalidConv(d, target.GetType()) } @@ -1685,6 +1729,8 @@ func (d *Datum) convertToMysqlSet(ctx Context, target *FieldType) (Datum, error) s, err = ParseSet(target.GetElems(), d.GetMysqlEnum().Name, target.GetCollate()) case KindMysqlSet: s, err = ParseSet(target.GetElems(), d.GetMysqlSet().Name, target.GetCollate()) + case KindVectorFloat32: + return invalidConv(d, mysql.TypeSet) default: var uintDatum Datum uintDatum, err = d.convertToUint(ctx, target) @@ -1750,6 +1796,29 @@ func (d *Datum) convertToMysqlJSON(_ *FieldType) (ret Datum, err error) { return ret, errors.Trace(err) } +func (d *Datum) convertToVectorFloat32(_ Context, target *FieldType) (ret Datum, err error) { + switch d.k { + case KindVectorFloat32: + v := d.GetVectorFloat32() + if err = v.CheckDimsFitColumn(target.GetFlen()); err != nil { + return ret, errors.Trace(err) + } + ret = *d + case KindString, KindBytes: + var v VectorFloat32 + if v, err = ParseVectorFloat32(d.GetString()); err != nil { + return ret, errors.Trace(err) + } + if err = v.CheckDimsFitColumn(target.GetFlen()); err != nil { + return ret, errors.Trace(err) + } + ret.SetVectorFloat32(v) + default: + return invalidConv(d, mysql.TypeTiDBVectorFloat32) + } + return ret, errors.Trace(err) +} + // ToBool converts to a bool. // We will use 1 for true, and 0 for false. func (d *Datum) ToBool(ctx Context) (int64, error) { @@ -1784,6 +1853,8 @@ func (d *Datum) ToBool(ctx Context) (int64, error) { case KindMysqlJSON: val := d.GetMysqlJSON() isZero = val.IsZero() + case KindVectorFloat32: + isZero = d.GetVectorFloat32().IsZeroValue() default: return 0, errors.Errorf("cannot convert %v(type %T) to bool", d.GetValue(), d.GetValue()) } @@ -1831,7 +1902,7 @@ func ConvertDatumToDecimal(ctx Context, d Datum) (*MyDecimal, error) { } dec = f default: - err = fmt.Errorf("can't convert %v to decimal", d.GetValue()) + err = errors.Errorf("can't convert %v to decimal", d.GetValue()) } return dec, errors.Trace(err) } @@ -2002,6 +2073,8 @@ func (d *Datum) ToString() (string, error) { return d.GetMysqlJSON().String(), nil case KindBinaryLiteral, KindMysqlBit: return d.GetBinaryLiteral().ToString(), nil + case KindVectorFloat32: + return d.GetVectorFloat32().String(), nil case KindNull: return "", nil default: @@ -2009,6 +2082,16 @@ func (d *Datum) ToString() (string, error) { } } +// StringForExplain implements Explainable interface. +func (d *Datum) StringForExplain(ctx ParamValues, redact string) string { + switch d.Kind() { + case KindVectorFloat32: + return d.GetVectorFloat32().StringForExplain(ctx, redact) + default: + return fmt.Sprintf("%v", d.GetValue()) + } +} + // ToBytes gets the bytes representation of the datum. func (d *Datum) ToBytes() ([]byte, error) { switch d.k { @@ -2217,6 +2300,12 @@ func NewJSONDatum(j BinaryJSON) (d Datum) { return d } +// NewVectorFloat32Datum creates a new Datum from a VectorFloat32 value +func NewVectorFloat32Datum(v VectorFloat32) (d Datum) { + d.SetVectorFloat32(v) + return d +} + // NewBinaryLiteralDatum creates a new BinaryLiteral Datum for a BinaryLiteral value. func NewBinaryLiteralDatum(b BinaryLiteral) (d Datum) { d.SetBinaryLiteral(b) @@ -2581,6 +2670,8 @@ func (d Datum) EstimatedMemUsage() int64 { bytesConsumed += sizeOfMyDecimal case KindMysqlTime: bytesConsumed += sizeOfMysqlTime + case KindVectorFloat32: + bytesConsumed += d.GetVectorFloat32().EstimatedMemUsage() default: bytesConsumed += len(d.b) } diff --git a/pkg/types/etc.go b/pkg/types/etc.go index b87345c38b659..25e30e70bcc3f 100644 --- a/pkg/types/etc.go +++ b/pkg/types/etc.go @@ -37,6 +37,9 @@ var IsTypeBlob = ast.IsTypeBlob // whether the tp is the char type like a string type or a varchar type. var IsTypeChar = ast.IsTypeChar +// IsTypeVector returns whether tp is a vector type. +var IsTypeVector = ast.IsTypeVector + // IsTypeVarchar returns a boolean indicating // whether the tp is the varchar type like a varstring type or a varchar type. func IsTypeVarchar(tp byte) bool { @@ -159,6 +162,7 @@ var kind2Str = map[byte]string{ KindMaxValue: "max_value", KindRaw: "raw", KindMysqlJSON: "json", + KindVectorFloat32: "vector", } // TypeStr converts tp to a string. diff --git a/pkg/types/eval_type.go b/pkg/types/eval_type.go index fd4cd93318723..24ad4a33ba2ef 100644 --- a/pkg/types/eval_type.go +++ b/pkg/types/eval_type.go @@ -36,4 +36,6 @@ const ( ETDuration = ast.ETDuration // ETJson represents type JSON in evaluation. ETJson = ast.ETJson + // ETVectorFloat32 represents type VectorFloat32 in evaluation. + ETVectorFloat32 = ast.ETVectorFloat32 ) diff --git a/pkg/types/field_type.go b/pkg/types/field_type.go index 60c00db8994c8..bfa84118d8c93 100644 --- a/pkg/types/field_type.go +++ b/pkg/types/field_type.go @@ -337,6 +337,11 @@ func DefaultTypeForValue(value any, tp *FieldType, char string, collate string) tp.SetDecimal(0) tp.SetCharset(charset.CharsetUTF8MB4) tp.SetCollate(charset.CollationUTF8MB4) + case VectorFloat32: + tp.SetType(mysql.TypeTiDBVectorFloat32) + tp.SetFlen(UnspecifiedLength) + tp.SetDecimal(0) + SetBinChsClnFlag(tp) default: tp.SetType(mysql.TypeUnspecified) tp.SetFlen(UnspecifiedLength) @@ -389,22 +394,46 @@ func mergeTypeFlag(a, b uint) uint { return a & (b&mysql.NotNullFlag | ^mysql.NotNullFlag) & (b&mysql.UnsignedFlag | ^mysql.UnsignedFlag) } -func getFieldTypeIndex(tp byte) int { - itp := int(tp) - if itp < fieldTypeTearFrom { - return itp +var ( + fieldTypeIndexes = map[byte]int{ + mysql.TypeUnspecified: 0, + mysql.TypeTiny: 1, + mysql.TypeShort: 2, + mysql.TypeLong: 3, + mysql.TypeFloat: 4, + mysql.TypeDouble: 5, + mysql.TypeNull: 6, + mysql.TypeTimestamp: 7, + mysql.TypeLonglong: 8, + mysql.TypeInt24: 9, + mysql.TypeDate: 10, + mysql.TypeDuration: 11, + mysql.TypeDatetime: 12, + mysql.TypeYear: 13, + mysql.TypeNewDate: 14, + mysql.TypeVarchar: 15, + mysql.TypeBit: 16, + mysql.TypeJSON: 17, + mysql.TypeNewDecimal: 18, + mysql.TypeEnum: 19, + mysql.TypeSet: 20, + mysql.TypeTinyBlob: 21, + mysql.TypeMediumBlob: 22, + mysql.TypeLongBlob: 23, + mysql.TypeBlob: 24, + mysql.TypeVarString: 25, + mysql.TypeString: 26, + mysql.TypeGeometry: 27, + mysql.TypeTiDBVectorFloat32: 28, } - return fieldTypeTearFrom + itp - fieldTypeTearTo - 1 -} - -const ( - fieldTypeTearFrom = int(mysql.TypeBit) + 1 - fieldTypeTearTo = int(mysql.TypeJSON) - 1 - fieldTypeNum = fieldTypeTearFrom + (255 - fieldTypeTearTo) ) +func getFieldTypeIndex(tp byte) int { + return fieldTypeIndexes[tp] +} + // https://github.com/mysql/mysql-server/blob/8.0/sql/field.cc#L248 -var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ +var fieldTypeMergeRules = [29][29]byte{ /* mysql.TypeUnspecified -> */ { // mysql.TypeUnspecified mysql.TypeTiny @@ -437,6 +466,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeTiny -> */ { @@ -470,6 +501,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeShort -> */ { @@ -503,6 +536,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeLong -> */ { @@ -536,6 +571,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeFloat -> */ { @@ -569,6 +606,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeDouble -> */ { @@ -602,6 +641,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeNull -> */ { @@ -635,6 +676,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeGeometry, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeTiDBVectorFloat32, }, /* mysql.TypeTimestamp -> */ { @@ -668,6 +711,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeLonglong -> */ { @@ -701,6 +746,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeInt24 -> */ { @@ -734,6 +781,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeDate -> */ { @@ -767,6 +816,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeTime -> */ { @@ -800,6 +851,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeDatetime -> */ { @@ -833,6 +886,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeYear -> */ { @@ -866,6 +921,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeNewDate -> */ { @@ -899,6 +956,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeVarchar -> */ { @@ -932,6 +991,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeVarchar, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeBit -> */ { @@ -965,6 +1026,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeJSON -> */ { @@ -998,6 +1061,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeLongBlob, mysql.TypeVarchar, // mysql.TypeString MYSQL_TYPE_GEOMETRY mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeNewDecimal -> */ { @@ -1031,6 +1096,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeEnum -> */ { @@ -1064,6 +1131,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeSet -> */ { @@ -1097,6 +1166,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeTinyBlob -> */ { @@ -1130,6 +1201,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeTinyBlob, // mysql.TypeString mysql.TypeGeometry mysql.TypeTinyBlob, mysql.TypeTinyBlob, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeLongBlob, }, /* mysql.TypeMediumBlob -> */ { @@ -1163,6 +1236,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeMediumBlob, mysql.TypeMediumBlob, // mysql.TypeString mysql.TypeGeometry mysql.TypeMediumBlob, mysql.TypeMediumBlob, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeLongBlob, }, /* mysql.TypeLongBlob -> */ { @@ -1196,6 +1271,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeLongBlob, mysql.TypeLongBlob, // mysql.TypeString mysql.TypeGeometry mysql.TypeLongBlob, mysql.TypeLongBlob, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeLongBlob, }, /* mysql.TypeBlob -> */ { @@ -1229,6 +1306,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeBlob, // mysql.TypeString mysql.TypeGeometry mysql.TypeBlob, mysql.TypeBlob, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeLongBlob, }, /* mysql.TypeVarString -> */ { @@ -1262,6 +1341,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeVarchar, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, }, /* mysql.TypeString -> */ { @@ -1295,6 +1376,8 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeString, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeString, }, /* mysql.TypeGeometry -> */ { @@ -1328,6 +1411,43 @@ var fieldTypeMergeRules = [fieldTypeNum][fieldTypeNum]byte{ mysql.TypeBlob, mysql.TypeVarchar, // mysql.TypeString mysql.TypeGeometry mysql.TypeString, mysql.TypeGeometry, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeVarchar, + }, + /* mysql.TypeTiDBVectorFloat32 -> */ + { + // mysql.TypeUnspecified mysql.TypeTiny + mysql.TypeVarchar, mysql.TypeVarchar, + // mysql.TypeShort mysql.TypeLong + mysql.TypeVarchar, mysql.TypeVarchar, + // mysql.TypeNewFloat mysql.TypeDouble + mysql.TypeVarchar, mysql.TypeVarchar, + // mysql.TypeNull mysql.TypeTimestamp + mysql.TypeTiDBVectorFloat32, mysql.TypeVarchar, + // mysql.TypeLongLONG mysql.TypeInt24 + mysql.TypeVarchar, mysql.TypeVarchar, + // mysql.TypeDate MYSQL_TYPE_TIME + mysql.TypeVarchar, mysql.TypeVarchar, + // mysql.TypeDatetime MYSQL_TYPE_YEAR + mysql.TypeVarchar, mysql.TypeVarchar, + // mysql.TypeNewDate mysql.TypeVarchar + mysql.TypeVarchar, mysql.TypeVarchar, + // mysql.TypeBit <16>-<244> + mysql.TypeVarchar, + // mysql.TypeJSON + mysql.TypeVarchar, + // mysql.TypeNewDecimal MYSQL_TYPE_ENUM + mysql.TypeVarchar, mysql.TypeVarchar, + // mysql.TypeSet mysql.TypeTinyBlob + mysql.TypeVarchar, mysql.TypeLongBlob, + // mysql.TypeMediumBlob mysql.TypeLongBlob + mysql.TypeLongBlob, mysql.TypeLongBlob, + // mysql.TypeBlob mysql.TypeVarString + mysql.TypeLongBlob, mysql.TypeVarchar, + // mysql.TypeString MYSQL_TYPE_GEOMETRY + mysql.TypeString, mysql.TypeVarchar, + // mysql.TypeTiDBVectorFloat32 + mysql.TypeTiDBVectorFloat32, }, } @@ -1461,6 +1581,11 @@ func checkTypeChangeSupported(origin *FieldType, to *FieldType) bool { return false } + if origin.GetType() == mysql.TypeTiDBVectorFloat32 || to.GetType() == mysql.TypeTiDBVectorFloat32 { + // TODO: Vector type not supported. + return false + } + if (origin.GetType() == mysql.TypeEnum || origin.GetType() == mysql.TypeSet || origin.GetType() == mysql.TypeBit) && to.GetType() == mysql.TypeDuration { // TODO: Currently enum/set/bit cast to time are not support yet, should fix here after supported. diff --git a/pkg/types/parser_driver/value_expr.go b/pkg/types/parser_driver/value_expr.go index c9aa039c5d792..1ccb760a1cd1f 100644 --- a/pkg/types/parser_driver/value_expr.go +++ b/pkg/types/parser_driver/value_expr.go @@ -128,7 +128,8 @@ func (n *ValueExpr) Restore(ctx *format.RestoreCtx) error { case types.KindMysqlEnum, types.KindMysqlBit, types.KindMysqlSet, types.KindInterface, types.KindMinNotNull, types.KindMaxValue, - types.KindRaw, types.KindMysqlJSON: + types.KindRaw, types.KindMysqlJSON, + types.KindVectorFloat32: // TODO implement Restore function return errors.New("Not implemented") default: diff --git a/pkg/types/vector.go b/pkg/types/vector.go new file mode 100644 index 0000000000000..a5687e52e41ca --- /dev/null +++ b/pkg/types/vector.go @@ -0,0 +1,238 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "encoding/binary" + "fmt" + "math" + "strconv" + "unsafe" + + jsoniter "github.com/json-iterator/go" + "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/parser/types" +) + +func init() { + var buf [4]byte + binary.NativeEndian.PutUint32(buf[:], 0x2) + if buf[0] == 0x02 && buf[1] == 0x00 && buf[2] == 0x00 && buf[3] == 0x00 { + return + } + panic("VectorFloat32 only supports little endian") +} + +// ParamValuesVectorFloat32 is a readonly interface to return param for VectorFloat32 +type ParamValues interface { + // GetParamValue returns the value of the parameter by index. + GetParamValue(idx int) (Datum, error) +} + +// VectorFloat32 represents a vector of float32. +// +// Memory Format: +// 4 byte - Length +// 4 byte * N - Data in Float32 +// +// Normally, the data layout in storage (i.e. after serialization) is identical +// to the memory layout. However, in BigEndian machines, we have BigEndian in +// memory and always have LittleEndian in storage or during data exchange. +type VectorFloat32 struct { + data []byte // Note: data must be a well-formatted byte slice (len >= 4) +} + +// ZeroVectorFloat32 is a zero value of VectorFloat32. +var ZeroVectorFloat32 = InitVectorFloat32( /* dims= */ 0) + +// InitVectorFloat32 initializes a vector with the given dimension. The values are initialized to zero. +func InitVectorFloat32(dims int) VectorFloat32 { + data := make([]byte, 4+dims*4) + binary.LittleEndian.PutUint32(data, uint32(dims)) + return VectorFloat32{data: data} +} + +// CheckVectorDimValid checks if the vector's dimension is valid. +func CheckVectorDimValid(dim int) error { + const ( + maxVectorDimension = 16000 + ) + if dim < 0 { + return errors.Errorf("dimensions for type vector must be at least 0") + } + if dim > maxVectorDimension { + return errors.Errorf("vector cannot have more than %d dimensions", maxVectorDimension) + } + return nil +} + +// CheckDimsFitColumn checks if the vector has the expected dimension, which is defined by the column type or cast type. +func (v VectorFloat32) CheckDimsFitColumn(expectedFlen int) error { + if expectedFlen != types.UnspecifiedLength && v.Len() != expectedFlen { + return errors.Errorf("vector has %d dimensions, does not fit VECTOR(%d)", v.Len(), expectedFlen) + } + return nil +} + +// Len returns the length (dimension) of the vector. +func (v VectorFloat32) Len() int { + return int(binary.LittleEndian.Uint32(v.data)) +} + +// Elements returns a mutable typed slice of the elements. +func (v VectorFloat32) Elements() []float32 { + l := v.Len() + if l == 0 { + return nil + } + return unsafe.Slice((*float32)(unsafe.Pointer(&v.data[4])), l) +} + +// StringForExplain implements Explainable interface. +// In EXPLAIN context, we truncate the elements to avoid too long output. +func (v VectorFloat32) StringForExplain(ctx ParamValues, redact string) string { + const ( + maxDisplayElements = 5 + ) + + truncatedElements := 0 + elements := v.Elements() + + if len(elements) > maxDisplayElements { + truncatedElements = len(elements) - maxDisplayElements + elements = elements[:maxDisplayElements] + } + + buf := make([]byte, 0, 2+v.Len()*2) + buf = append(buf, '[') + for i, v := range elements { + if i > 0 { + buf = append(buf, ","...) + } + buf = strconv.AppendFloat(buf, float64(v), 'g', 2, 32) + } + if truncatedElements > 0 { + buf = append(buf, fmt.Sprintf(",(%d more)...", truncatedElements)...) + } + buf = append(buf, ']') + + // buf is not used elsewhere, so it's safe to just cast to String + return unsafe.String(unsafe.SliceData(buf), len(buf)) +} + +// String implements the fmt.Stringer interface. +// It returns a string representation of the vector which can be parsed later. +func (v VectorFloat32) String() string { + elements := v.Elements() + + buf := make([]byte, 0, 2+v.Len()*2) + buf = append(buf, '[') + for i, v := range elements { + if i > 0 { + buf = append(buf, ',') + } + buf = strconv.AppendFloat(buf, float64(v), 'f', -1, 32) + } + buf = append(buf, ']') + + // buf is not used elsewhere, so it's safe to just cast to String + return unsafe.String(unsafe.SliceData(buf), len(buf)) +} + +// ZeroCopySerialize serializes the vector into a new byte slice, without memory copy. +func (v VectorFloat32) ZeroCopySerialize() []byte { + return v.data +} + +// SerializeTo serializes the vector into the byte slice. +func (v VectorFloat32) SerializeTo(b []byte) []byte { + return append(b, v.data...) +} + +// SerializedSize returns the size of the serialized data. +func (v VectorFloat32) SerializedSize() int { + return len(v.data) +} + +// EstimatedMemUsage returns the estimated memory usage. +func (v VectorFloat32) EstimatedMemUsage() int { + return int(unsafe.Sizeof(v)) + len(v.data) +} + +// ZeroCopyDeserializeVectorFloat32 deserializes the byte slice into a vector, without memory copy. +// Note: b must not be mutated, because this function does zero copy. +func ZeroCopyDeserializeVectorFloat32(b []byte) (VectorFloat32, []byte, error) { + if len(b) < 4 { + return ZeroVectorFloat32, b, errors.Errorf("bad VectorFloat32 value header (len=%d)", len(b)) + } + + elements := binary.LittleEndian.Uint32(b) + totalDataSize := elements*4 + 4 + if len(b) < int(totalDataSize) { + return ZeroVectorFloat32, b, errors.Errorf("bad VectorFloat32 value (len=%d, expected=%d)", len(b), totalDataSize) + } + + data := b[:totalDataSize] + remaining := b[totalDataSize:] + return VectorFloat32{data: data}, remaining, nil +} + +// ParseVectorFloat32 parses a string into a vector. +func ParseVectorFloat32(s string) (VectorFloat32, error) { + var values []float32 + var valueError error + // We explicitly use a JSON float parser to reject other JSON types. + parser := jsoniter.ParseString(jsoniter.ConfigDefault, s) + parser.ReadArrayCB(func(parser *jsoniter.Iterator) bool { + v := parser.ReadFloat64() + if math.IsNaN(v) { + valueError = errors.Errorf("NaN not allowed in vector") + return false + } + if math.IsInf(v, 0) { + valueError = errors.Errorf("infinite value not allowed in vector") + return false + } + values = append(values, float32(v)) + return true + }) + if parser.Error != nil { + return ZeroVectorFloat32, errors.Errorf("Invalid vector text: %s", s) + } + if valueError != nil { + return ZeroVectorFloat32, valueError + } + + dim := len(values) + if err := CheckVectorDimValid(dim); err != nil { + return ZeroVectorFloat32, err + } + + vec := InitVectorFloat32(dim) + copy(vec.Elements(), values) + return vec, nil +} + +// Clone returns a deep copy of the vector. +func (v VectorFloat32) Clone() VectorFloat32 { + data := make([]byte, len(v.data)) + copy(data, v.data) + return VectorFloat32{data: data} +} + +// IsZeroValue returns true if the vector is a zero value (which length is zero). +func (v VectorFloat32) IsZeroValue() bool { + return v.Len() == 0 +} diff --git a/pkg/types/vector_functions.go b/pkg/types/vector_functions.go new file mode 100644 index 0000000000000..c75ade920b21c --- /dev/null +++ b/pkg/types/vector_functions.go @@ -0,0 +1,271 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "math" + + "github.com/pingcap/errors" +) + +func (a VectorFloat32) checkIdenticalDims(b VectorFloat32) error { + if a.Len() != b.Len() { + return errors.Errorf("vectors have different dimensions: %d and %d", a.Len(), b.Len()) + } + return nil +} + +// L2SquaredDistance returns the squared L2 distance between two vectors. +// This saves a sqrt calculation. +func (a VectorFloat32) L2SquaredDistance(b VectorFloat32) (float64, error) { + if err := a.checkIdenticalDims(b); err != nil { + return 0, errors.Trace(err) + } + + var distance float32 = 0.0 + va := a.Elements() + vb := b.Elements() + + for i, iMax := 0, a.Len(); i < iMax; i++ { + // Hope this can be vectorized. + diff := va[i] - vb[i] + distance += diff * diff + } + + return float64(distance), nil +} + +// L2Distance returns the L2 distance between two vectors. +func (a VectorFloat32) L2Distance(b VectorFloat32) (float64, error) { + d, err := a.L2SquaredDistance(b) + if err != nil { + return 0, errors.Trace(err) + } + return math.Sqrt(d), nil +} + +// InnerProduct returns the inner product of two vectors. +func (a VectorFloat32) InnerProduct(b VectorFloat32) (float64, error) { + if err := a.checkIdenticalDims(b); err != nil { + return 0, errors.Trace(err) + } + + var distance float32 = 0.0 + va := a.Elements() + vb := b.Elements() + + for i, iMax := 0, a.Len(); i < iMax; i++ { + // Hope this can be vectorized. + distance += va[i] * vb[i] + } + + return float64(distance), nil +} + +// NegativeInnerProduct returns the negative inner product of two vectors. +func (a VectorFloat32) NegativeInnerProduct(b VectorFloat32) (float64, error) { + d, err := a.InnerProduct(b) + if err != nil { + return 0, errors.Trace(err) + } + return d * -1, nil +} + +// CosineDistance returns the cosine distance between two vectors. +func (a VectorFloat32) CosineDistance(b VectorFloat32) (float64, error) { + if err := a.checkIdenticalDims(b); err != nil { + return 0, errors.Trace(err) + } + + var distance float32 = 0.0 + var norma float32 = 0.0 + var normb float32 = 0.0 + va := a.Elements() + vb := b.Elements() + + for i, iMax := 0, a.Len(); i < iMax; i++ { + // Hope this can be vectorized. + distance += va[i] * vb[i] + norma += va[i] * va[i] + normb += vb[i] * vb[i] + } + + similarity := float64(distance) / math.Sqrt(float64(norma)*float64(normb)) + + if math.IsNaN(similarity) { + // Divide by zero + return math.NaN(), nil + } + + if similarity > 1.0 { + similarity = 1.0 + } else if similarity < -1.0 { + similarity = -1.0 + } + + return 1.0 - similarity, nil +} + +// L1Distance returns the L1 distance between two vectors. +func (a VectorFloat32) L1Distance(b VectorFloat32) (float64, error) { + if err := a.checkIdenticalDims(b); err != nil { + return 0, errors.Trace(err) + } + + var distance float32 = 0.0 + va := a.Elements() + vb := b.Elements() + + for i, iMax := 0, a.Len(); i < iMax; i++ { + // Hope this can be vectorized. + diff := va[i] - vb[i] + if diff < 0 { + diff = -diff + } + distance += diff + } + + return float64(distance), nil +} + +// L2Norm returns the L2 norm of the vector. +func (a VectorFloat32) L2Norm() float64 { + // Note: We align the impl with pgvector: Only l2_norm use double + // precision during calculation. + var norm float64 = 0.0 + + va := a.Elements() + for i, iMax := 0, a.Len(); i < iMax; i++ { + // Hope this can be vectorized. + norm += float64(va[i]) * float64(va[i]) + } + return math.Sqrt(norm) +} + +// Add adds two vectors. The vectors must have the same dimension. +func (a VectorFloat32) Add(b VectorFloat32) (VectorFloat32, error) { + if err := a.checkIdenticalDims(b); err != nil { + return ZeroVectorFloat32, errors.Trace(err) + } + + result := InitVectorFloat32(a.Len()) + + va := a.Elements() + vb := b.Elements() + vr := result.Elements() + + for i, iMax := 0, a.Len(); i < iMax; i++ { + // Hope this can be vectorized. + vr[i] = va[i] + vb[i] + } + for i, iMax := 0, a.Len(); i < iMax; i++ { + if math.IsInf(float64(vr[i]), 0) { + return ZeroVectorFloat32, errors.Errorf("value out of range: overflow") + } + if math.IsNaN(float64(vr[i])) { + return ZeroVectorFloat32, errors.Errorf("value out of range: NaN") + } + } + + return result, nil +} + +// Sub subtracts two vectors. The vectors must have the same dimension. +func (a VectorFloat32) Sub(b VectorFloat32) (VectorFloat32, error) { + if err := a.checkIdenticalDims(b); err != nil { + return ZeroVectorFloat32, errors.Trace(err) + } + + result := InitVectorFloat32(a.Len()) + + va := a.Elements() + vb := b.Elements() + vr := result.Elements() + + for i, iMax := 0, a.Len(); i < iMax; i++ { + // Hope this can be vectorized. + vr[i] = va[i] - vb[i] + } + + for i, iMax := 0, a.Len(); i < iMax; i++ { + if math.IsInf(float64(vr[i]), 0) { + return ZeroVectorFloat32, errors.Errorf("value out of range: overflow") + } + if math.IsNaN(float64(vr[i])) { + return ZeroVectorFloat32, errors.Errorf("value out of range: NaN") + } + } + + return result, nil +} + +// Mul multiplies two vectors. The vectors must have the same dimension. +func (a VectorFloat32) Mul(b VectorFloat32) (VectorFloat32, error) { + if err := a.checkIdenticalDims(b); err != nil { + return ZeroVectorFloat32, errors.Trace(err) + } + + result := InitVectorFloat32(a.Len()) + + va := a.Elements() + vb := b.Elements() + vr := result.Elements() + + for i, iMax := 0, a.Len(); i < iMax; i++ { + // Hope this can be vectorized. + vr[i] = va[i] * vb[i] + } + + for i, iMax := 0, a.Len(); i < iMax; i++ { + if math.IsInf(float64(vr[i]), 0) { + return ZeroVectorFloat32, errors.Errorf("value out of range: overflow") + } + if math.IsNaN(float64(vr[i])) { + return ZeroVectorFloat32, errors.Errorf("value out of range: NaN") + } + + // TODO: Check for underflow. + // See https://github.com/pgvector/pgvector/blob/81d13bd40f03890bb5b6360259628cd473c2e467/src/vector.c#L873 + } + + return result, nil +} + +// Compare returns an integer comparing two vectors. The result will be 0 if a==b, -1 if a < b, and +1 if a > b. +func (a VectorFloat32) Compare(b VectorFloat32) int { + la := a.Len() + lb := b.Len() + commonLen := la + if lb < commonLen { + commonLen = lb + } + + va := a.Elements() + vb := b.Elements() + + for i := 0; i < commonLen; i++ { + if va[i] < vb[i] { + return -1 + } else if va[i] > vb[i] { + return 1 + } + } + if la < lb { + return -1 + } else if la > lb { + return 1 + } + return 0 +} diff --git a/pkg/types/vector_test.go b/pkg/types/vector_test.go new file mode 100644 index 0000000000000..dddf3686dcbc4 --- /dev/null +++ b/pkg/types/vector_test.go @@ -0,0 +1,133 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types_test + +import ( + "testing" + + "github.com/pingcap/tidb/pkg/types" + "github.com/stretchr/testify/require" +) + +func TestVectorEndianess(t *testing.T) { + // Note: This test fails in BigEndian machines. + v := types.InitVectorFloat32(2) + vv := v.Elements() + vv[0] = 1.1 + vv[1] = 2.2 + require.Equal(t, []byte{ + /* Length = 0x02 */ + 0x02, 0x00, 0x00, 0x00, + /* Element 1 = 0x3f8ccccd */ + 0xcd, 0xcc, 0x8c, 0x3f, + /* Element 2 = 0x400ccccd */ + 0xcd, 0xcc, 0x0c, 0x40, + }, v.SerializeTo(nil)) +} + +func TestZeroVector(t *testing.T) { + require.True(t, types.ZeroVectorFloat32.IsZeroValue()) + require.Equal(t, 0, types.ZeroVectorFloat32.Compare(types.ZeroVectorFloat32)) + + require.Equal(t, []byte{0, 0, 0, 0}, types.ZeroVectorFloat32.ZeroCopySerialize()) + require.Equal(t, 4, types.ZeroVectorFloat32.SerializedSize()) + require.Equal(t, []byte{0, 0, 0, 0}, types.ZeroVectorFloat32.SerializeTo(nil)) + require.Equal(t, []byte{1, 2, 3, 0, 0, 0, 0}, types.ZeroVectorFloat32.SerializeTo([]byte{1, 2, 3})) + + v, remaining, err := types.ZeroCopyDeserializeVectorFloat32([]byte{0, 0, 0, 0}) + require.Nil(t, err) + require.Len(t, remaining, 0) + require.Equal(t, 0, v.Len()) + require.Equal(t, "[]", v.String()) + require.True(t, v.IsZeroValue()) + require.Equal(t, 0, v.Compare(types.ZeroVectorFloat32)) + require.Equal(t, 0, types.ZeroVectorFloat32.Compare(v)) +} + +func TestVectorParse(t *testing.T) { + v, err := types.ParseVectorFloat32(`abc`) + require.NotNil(t, err) + require.True(t, v.IsZeroValue()) + + // Note: Currently we will parse "null" into []. + v, err = types.ParseVectorFloat32(`null`) + require.Nil(t, err) + require.True(t, v.IsZeroValue()) + + v, err = types.ParseVectorFloat32(`"json_str"`) + require.NotNil(t, err) + require.True(t, v.IsZeroValue()) + + v, err = types.ParseVectorFloat32(`123`) + require.NotNil(t, err) + require.True(t, v.IsZeroValue()) + + v, err = types.ParseVectorFloat32(`[123`) + require.NotNil(t, err) + require.True(t, v.IsZeroValue()) + + v, err = types.ParseVectorFloat32(`123]`) + require.NotNil(t, err) + require.True(t, v.IsZeroValue()) + + v, err = types.ParseVectorFloat32(`[123,]`) + require.NotNil(t, err) + require.True(t, v.IsZeroValue()) + + v, err = types.ParseVectorFloat32(`[]`) + require.Nil(t, err) + require.Equal(t, 0, v.Len()) + require.Equal(t, "[]", v.String()) + require.True(t, v.IsZeroValue()) + require.Equal(t, 0, v.Compare(types.ZeroVectorFloat32)) + require.Equal(t, 0, types.ZeroVectorFloat32.Compare(v)) + + v, err = types.ParseVectorFloat32(`[1.1, 2.2, 3.3]`) + require.Nil(t, err) + require.Equal(t, 3, v.Len()) + require.Equal(t, "[1.1,2.2,3.3]", v.String()) + require.False(t, v.IsZeroValue()) + require.Equal(t, 1, v.Compare(types.ZeroVectorFloat32)) + require.Equal(t, -1, types.ZeroVectorFloat32.Compare(v)) +} + +func TestVectorDatum(t *testing.T) { + d := types.NewDatum(nil) + d.SetVectorFloat32(types.ZeroVectorFloat32) + v := d.GetVectorFloat32() + require.Equal(t, 0, v.Len()) + require.Equal(t, "[]", v.String()) + require.True(t, v.IsZeroValue()) + require.Equal(t, 0, v.Compare(types.ZeroVectorFloat32)) + require.Equal(t, 0, types.ZeroVectorFloat32.Compare(v)) +} + +func TestVectorCompare(t *testing.T) { + v1, err := types.ParseVectorFloat32(`[1.1, 2.2, 3.3]`) + require.NoError(t, err) + v2, err := types.ParseVectorFloat32(`[-1.1, 4.2]`) + require.NoError(t, err) + + require.Equal(t, 1, v1.Compare(v2)) + require.Equal(t, -1, v2.Compare(v1)) + + v1, err = types.ParseVectorFloat32(`[1.1, 2.2, 3.3]`) + require.NoError(t, err) + v2, err = types.ParseVectorFloat32(`[1.1, 4.2]`) + require.NoError(t, err) + + require.Equal(t, -1, v1.Compare(v2)) + require.Equal(t, 1, v2.Compare(v1)) +} diff --git a/pkg/util/chunk/chunk.go b/pkg/util/chunk/chunk.go index cf17cdb388143..76a552852da97 100644 --- a/pkg/util/chunk/chunk.go +++ b/pkg/util/chunk/chunk.go @@ -612,6 +612,12 @@ func (c *Chunk) AppendJSON(colIdx int, j types.BinaryJSON) { c.columns[colIdx].AppendJSON(j) } +// AppendVectorFloat32 appends a VectorFloat32 value to the chunk. +func (c *Chunk) AppendVectorFloat32(colIdx int, v types.VectorFloat32) { + c.appendSel(colIdx) + c.columns[colIdx].AppendVectorFloat32(v) +} + func (c *Chunk) appendSel(colIdx int) { if colIdx == 0 && c.sel != nil { // use column 0 as standard c.sel = append(c.sel, c.columns[0].length) @@ -645,6 +651,8 @@ func (c *Chunk) AppendDatum(colIdx int, d *types.Datum) { c.AppendTime(colIdx, d.GetMysqlTime()) case types.KindMysqlJSON: c.AppendJSON(colIdx, d.GetMysqlJSON()) + case types.KindVectorFloat32: + c.AppendVectorFloat32(colIdx, d.GetVectorFloat32()) } } diff --git a/pkg/util/chunk/column.go b/pkg/util/chunk/column.go index 13a1877a06da7..5cdbcb2b986c4 100644 --- a/pkg/util/chunk/column.go +++ b/pkg/util/chunk/column.go @@ -53,6 +53,12 @@ func (c *Column) AppendJSON(j types.BinaryJSON) { c.finishAppendVar() } +// AppendVectorFloat32 appends a VectorFloat32 value into this Column. +func (c *Column) AppendVectorFloat32(v types.VectorFloat32) { + c.data = v.SerializeTo(c.data) + c.finishAppendVar() +} + // AppendSet appends a Set value into this Column. func (c *Column) AppendSet(set types.Set) { c.appendNameValue(set.Name, set.Value) @@ -161,6 +167,8 @@ func (c *Column) Reset(eType types.EvalType) { c.ResizeGoDuration(0, false) case types.ETJson: c.ReserveJSON(0) + case types.ETVectorFloat32: + c.ReserveVectorFloat32(0) default: panic(fmt.Sprintf("invalid EvalType %v", eType)) } @@ -546,6 +554,11 @@ func (c *Column) ReserveJSON(n int) { c.reserve(n, 8) } +// ReserveVectorFloat32 changes the column capacity to store n vectorFloat32 elements and set the length to zero. +func (c *Column) ReserveVectorFloat32(n int) { + c.reserve(n, 8) +} + // ReserveSet changes the column capacity to store n set elements and set the length to zero. func (c *Column) ReserveSet(n int) { c.reserve(n, 8) @@ -648,6 +661,16 @@ func (c *Column) GetJSON(rowID int) types.BinaryJSON { return types.BinaryJSON{TypeCode: c.data[start], Value: c.data[start+1 : c.offsets[rowID+1]]} } +// GetVectorFloat32 returns the VectorFloat32 in the specific row. +func (c *Column) GetVectorFloat32(rowID int) types.VectorFloat32 { + data := c.data[c.offsets[rowID]:c.offsets[rowID+1]] + v, _, err := types.ZeroCopyDeserializeVectorFloat32(data) + if err != nil { + panic(err) + } + return v +} + // GetBytes returns the byte slice in the specific row. func (c *Column) GetBytes(rowID int) []byte { return c.data[c.offsets[rowID]:c.offsets[rowID+1]] diff --git a/pkg/util/chunk/compare.go b/pkg/util/chunk/compare.go index 86ab092f66f84..86f93ac640337 100644 --- a/pkg/util/chunk/compare.go +++ b/pkg/util/chunk/compare.go @@ -53,6 +53,8 @@ func GetCompareFunc(tp *types.FieldType) CompareFunc { return cmpBit case mysql.TypeJSON: return cmpJSON + case mysql.TypeTiDBVectorFloat32: + return cmpVectorFloat32 case mysql.TypeNull: return cmpNullConst } @@ -175,6 +177,15 @@ func cmpNullConst(_ Row, _ int, _ Row, _ int) int { return 0 } +func cmpVectorFloat32(l Row, lCol int, r Row, rCol int) int { + lNull, rNull := l.IsNull(lCol), r.IsNull(rCol) + if lNull || rNull { + return cmpNull(lNull, rNull) + } + lv, rv := l.GetVectorFloat32(lCol), r.GetVectorFloat32(rCol) + return lv.Compare(rv) +} + // Compare compares the value with ad. // We assume that the collation information of the column is the same with the datum. func Compare(row Row, colIdx int, ad *types.Datum) int { @@ -218,6 +229,9 @@ func Compare(row Row, colIdx int, ad *types.Datum) int { case types.KindMysqlJSON: l, r := row.GetJSON(colIdx), ad.GetMysqlJSON() return types.CompareBinaryJSON(l, r) + case types.KindVectorFloat32: + l, r := row.GetVectorFloat32(colIdx), ad.GetVectorFloat32() + return l.Compare(r) case types.KindMysqlTime: l, r := row.GetTime(colIdx), ad.GetMysqlTime() return l.Compare(r) diff --git a/pkg/util/chunk/mutrow.go b/pkg/util/chunk/mutrow.go index 5408631b19882..fa19ca76bf056 100644 --- a/pkg/util/chunk/mutrow.go +++ b/pkg/util/chunk/mutrow.go @@ -113,6 +113,8 @@ func zeroValForType(tp *types.FieldType) any { return types.Enum{} case mysql.TypeJSON: return types.CreateBinaryJSON(nil) + case mysql.TypeTiDBVectorFloat32: + return types.ZeroVectorFloat32 default: return nil } @@ -155,6 +157,8 @@ func makeMutRowColumn(in any) *Column { col.data[0] = x.TypeCode copy(col.data[1:], x.Value) return col + case types.VectorFloat32: + return makeMutRowBytesColumn(x.ZeroCopySerialize()) case types.Duration: col := newMutRowFixedLenColumn(8) *(*int64)(unsafe.Pointer(&col.data[0])) = int64(x.Duration) @@ -278,6 +282,8 @@ func (mr MutRow) SetValue(colIdx int, val any) { setMutRowNameValue(col, x.Name, x.Value) case types.BinaryJSON: setMutRowJSON(col, x) + case types.VectorFloat32: + setMutRowBytes(col, x.ZeroCopySerialize()) } col.nullBitmap[0] = 1 } @@ -311,6 +317,8 @@ func (mr MutRow) SetDatum(colIdx int, d types.Datum) { *(*types.MyDecimal)(unsafe.Pointer(&col.data[0])) = *d.GetMysqlDecimal() case types.KindMysqlJSON: setMutRowJSON(col, d.GetMysqlJSON()) + case types.KindVectorFloat32: + setMutRowBytes(col, d.GetVectorFloat32().ZeroCopySerialize()) case types.KindMysqlEnum: e := d.GetMysqlEnum() setMutRowNameValue(col, e.Name, e.Value) diff --git a/pkg/util/chunk/row.go b/pkg/util/chunk/row.go index d08dabbb7916f..d3fd285f1936f 100644 --- a/pkg/util/chunk/row.go +++ b/pkg/util/chunk/row.go @@ -115,6 +115,11 @@ func (r Row) GetJSON(colIdx int) types.BinaryJSON { return r.c.columns[colIdx].GetJSON(r.idx) } +// GetVectorFloat32 returns the VectorFloat32 value with the colIdx. +func (r Row) GetVectorFloat32(colIdx int) types.VectorFloat32 { + return r.c.columns[colIdx].GetVectorFloat32(r.idx) +} + // GetDatumRow converts chunk.Row to types.DatumRow. // Keep in mind that GetDatumRow has a reference to r.c, which is a chunk, // this function works only if the underlying chunk is valid or unchanged. @@ -206,6 +211,10 @@ func (r Row) DatumWithBuffer(colIdx int, tp *types.FieldType, d *types.Datum) { if !r.IsNull(colIdx) { d.SetMysqlJSON(r.GetJSON(colIdx)) } + case mysql.TypeTiDBVectorFloat32: + if !r.IsNull(colIdx) { + d.SetVectorFloat32(r.GetVectorFloat32(colIdx)) + } } if r.IsNull(colIdx) { d.SetNull() @@ -263,6 +272,8 @@ func (r Row) ToString(ft []*types.FieldType) string { case mysql.TypeDouble: buf = strconv.AppendFloat(buf, r.GetFloat64(colIdx), 'f', -1, 64) } + case types.ETVectorFloat32: + buf = append(buf, r.GetVectorFloat32(colIdx).String()...) } } if colIdx != r.Chunk().NumCols()-1 { diff --git a/pkg/util/codec/codec.go b/pkg/util/codec/codec.go index 3c693254387ca..2b8764e0ed10d 100644 --- a/pkg/util/codec/codec.go +++ b/pkg/util/codec/codec.go @@ -17,7 +17,6 @@ package codec import ( "bytes" "encoding/binary" - "fmt" "hash" "io" "time" @@ -36,18 +35,19 @@ import ( // First byte in the encoded value which specifies the encoding type. const ( - NilFlag byte = 0 - bytesFlag byte = 1 - compactBytesFlag byte = 2 - intFlag byte = 3 - uintFlag byte = 4 - floatFlag byte = 5 - decimalFlag byte = 6 - durationFlag byte = 7 - varintFlag byte = 8 - uvarintFlag byte = 9 - jsonFlag byte = 10 - maxFlag byte = 250 + NilFlag byte = 0 + bytesFlag byte = 1 + compactBytesFlag byte = 2 + intFlag byte = 3 + uintFlag byte = 4 + floatFlag byte = 5 + decimalFlag byte = 6 + durationFlag byte = 7 + varintFlag byte = 8 + uvarintFlag byte = 9 + jsonFlag byte = 10 + vectorFloat32Flag byte = 20 + maxFlag byte = 250 ) // IntHandleFlag is only used to encode int handle key. @@ -72,6 +72,8 @@ func preRealloc(b []byte, vals []types.Datum, comparable1 bool) []byte { size++ case types.KindMysqlJSON: size += 2 + len(vals[i].GetBytes()) + case types.KindVectorFloat32: + size += 1 + vals[i].GetVectorFloat32().SerializedSize() case types.KindMysqlDecimal: size += 1 + types.MyDecimalStructSize default: @@ -126,6 +128,11 @@ func encode(loc *time.Location, b []byte, vals []types.Datum, comparable1 bool) j := vals[i].GetMysqlJSON() b = append(b, j.TypeCode) b = append(b, j.Value...) + case types.KindVectorFloat32: + // Always do a small deser + ser for sanity check + b = append(b, vectorFloat32Flag) + v := vals[i].GetVectorFloat32() + b = v.SerializeTo(b) case types.KindNull: b = append(b, NilFlag) case types.KindMinNotNull: @@ -169,6 +176,9 @@ func EstimateValueSize(typeCtx types.Context, val types.Datum) (int, error) { l = valueSizeOfUnsignedInt(val) case types.KindMysqlJSON: l = 2 + len(val.GetMysqlJSON().Value) + case types.KindVectorFloat32: + v := val.GetVectorFloat32() + l = 1 + v.SerializedSize() case types.KindNull, types.KindMinNotNull, types.KindMaxValue: l = 1 default: @@ -385,6 +395,10 @@ func encodeHashChunkRowIdx(typeCtx types.Context, row chunk.Row, tp *types.Field flag = jsonFlag json := row.GetJSON(idx) b = json.HashValue(b) + case mysql.TypeTiDBVectorFloat32: + flag = vectorFloat32Flag + v := row.GetVectorFloat32(idx) + b = v.SerializeTo(nil) default: return 0, nil, errors.Errorf("unsupport column type for encode %d", tp.GetType()) } @@ -826,6 +840,25 @@ func HashChunkSelected(typeCtx types.Context, h []hash.Hash64, chk *chunk.Chunk, b = json.HashValue(b) } + // As the golang doc described, `Hash.Write` never returns an error.. + // See https://golang.org/pkg/hash/#Hash + _, _ = h[i].Write(buf) + _, _ = h[i].Write(b) + } + case mysql.TypeTiDBVectorFloat32: + for i := 0; i < rows; i++ { + if sel != nil && !sel[i] { + continue + } + if column.IsNull(i) { + buf[0], b = NilFlag, nil + isNull[i] = !ignoreNull + } else { + buf[0] = vectorFloat32Flag + v := column.GetVectorFloat32(i) + b = v.SerializeTo(nil) + } + // As the golang doc described, `Hash.Write` never returns an error.. // See https://golang.org/pkg/hash/#Hash _, _ = h[i].Write(buf) @@ -1037,6 +1070,13 @@ func DecodeOne(b []byte) (remain []byte, d types.Datum, err error) { j := types.BinaryJSON{TypeCode: b[0], Value: b[1:size]} d.SetMysqlJSON(j) b = b[size:] + case vectorFloat32Flag: + v, remaining, err := types.ZeroCopyDeserializeVectorFloat32(b) + if err != nil { + return b, d, errors.Trace(err) + } + d.SetVectorFloat32(v) + b = remaining case NilFlag: default: return b, d, errors.Errorf("invalid encoded key flag %v", flag) @@ -1338,6 +1378,13 @@ func (decoder *Decoder) DecodeOne(b []byte, colIdx int, ft *types.FieldType) (re } chk.AppendJSON(colIdx, types.BinaryJSON{TypeCode: b[0], Value: b[1:size]}) b = b[size:] + case vectorFloat32Flag: + v, remaining, err := types.ZeroCopyDeserializeVectorFloat32(b) + if err != nil { + return nil, errors.Trace(err) + } + chk.AppendVectorFloat32(colIdx, v) + b = remaining case NilFlag: chk.AppendNull(colIdx) default: @@ -1482,8 +1529,16 @@ func HashGroupKey(loc *time.Location, n int, col *chunk.Column, buf [][]byte, ft buf[i] = encodeBytes(buf[i], ConvertByCollation(col.GetBytes(i), ft), false) } } + case types.ETVectorFloat32: + for i := 0; i < n; i++ { + if col.IsNull(i) { + buf[i] = append(buf[i], NilFlag) + } else { + buf[i] = col.GetVectorFloat32(i).SerializeTo(buf[i]) + } + } default: - return nil, fmt.Errorf("invalid eval type %v", ft.EvalType()) + return nil, errors.Errorf("unsupported type %s during evaluation", ft.EvalType()) } return buf, nil } @@ -1539,6 +1594,10 @@ func HashCode(b []byte, d types.Datum) []byte { j := d.GetMysqlJSON() b = append(b, j.TypeCode) b = append(b, j.Value...) + case types.KindVectorFloat32: + b = append(b, vectorFloat32Flag) + v := d.GetVectorFloat32() + b = v.SerializeTo(b) case types.KindNull: b = append(b, NilFlag) case types.KindMinNotNull: diff --git a/pkg/util/rowcodec/common.go b/pkg/util/rowcodec/common.go index a3afecb6e6f52..518dd16fa9681 100644 --- a/pkg/util/rowcodec/common.go +++ b/pkg/util/rowcodec/common.go @@ -40,16 +40,17 @@ var ( // First byte in the encoded value which specifies the encoding type. const ( - NilFlag byte = 0 - BytesFlag byte = 1 - CompactBytesFlag byte = 2 - IntFlag byte = 3 - UintFlag byte = 4 - FloatFlag byte = 5 - DecimalFlag byte = 6 - VarintFlag byte = 8 - VaruintFlag byte = 9 - JSONFlag byte = 10 + NilFlag byte = 0 + BytesFlag byte = 1 + CompactBytesFlag byte = 2 + IntFlag byte = 3 + UintFlag byte = 4 + FloatFlag byte = 5 + DecimalFlag byte = 6 + VarintFlag byte = 8 + VaruintFlag byte = 9 + JSONFlag byte = 10 + VectorFloat32Flag byte = 20 ) func bytesToU32Slice(b []byte) []uint32 { @@ -350,6 +351,8 @@ func appendDatumForChecksum(loc *time.Location, buf []byte, dat *data.Datum, typ out = binary.LittleEndian.AppendUint64(buf, v) case mysql.TypeJSON: out = appendLengthValue(buf, []byte(dat.GetMysqlJSON().String())) + case mysql.TypeTiDBVectorFloat32: + out = dat.GetVectorFloat32().SerializeTo(buf) case mysql.TypeNull, mysql.TypeGeometry: out = buf default: diff --git a/pkg/util/rowcodec/decoder.go b/pkg/util/rowcodec/decoder.go index 3fe751d0eade4..2d633853cc422 100644 --- a/pkg/util/rowcodec/decoder.go +++ b/pkg/util/rowcodec/decoder.go @@ -172,6 +172,12 @@ func (decoder *DatumMapDecoder) decodeColDatum(col *ColInfo, colData []byte) (ty j.TypeCode = colData[0] j.Value = colData[1:] d.SetMysqlJSON(j) + case mysql.TypeTiDBVectorFloat32: + v, _, err := types.ZeroCopyDeserializeVectorFloat32(colData) + if err != nil { + return d, err + } + d.SetVectorFloat32(v) default: return d, errors.Errorf("unknown type %d", col.Ft.GetType()) } @@ -351,6 +357,12 @@ func (decoder *ChunkDecoder) decodeColToChunk(colIdx int, col *ColInfo, colData j.TypeCode = colData[0] j.Value = colData[1:] chk.AppendJSON(colIdx, j) + case mysql.TypeTiDBVectorFloat32: + v, _, err := types.ZeroCopyDeserializeVectorFloat32(colData) + if err != nil { + return err + } + chk.AppendVectorFloat32(colIdx, v) default: return errors.Errorf("unknown type %d", col.Ft.GetType()) } @@ -511,6 +523,8 @@ func fieldType2Flag(tp byte, signed bool) (flag byte) { flag = UintFlag case mysql.TypeJSON: flag = JSONFlag + case mysql.TypeTiDBVectorFloat32: + flag = VectorFloat32Flag case mysql.TypeNull: flag = NilFlag default: diff --git a/pkg/util/rowcodec/encoder.go b/pkg/util/rowcodec/encoder.go index 0651c481e9ceb..6ba51fb625907 100644 --- a/pkg/util/rowcodec/encoder.go +++ b/pkg/util/rowcodec/encoder.go @@ -215,6 +215,9 @@ func encodeValueDatum(loc *time.Location, d *types.Datum, buffer []byte) (nBuffe j := d.GetMysqlJSON() buffer = append(buffer, j.TypeCode) buffer = append(buffer, j.Value...) + case types.KindVectorFloat32: + v := d.GetVectorFloat32() + buffer = v.SerializeTo(buffer) default: err = errors.Errorf("unsupport encode type %d", d.Kind()) } diff --git a/pkg/util/schemacmp/type.go b/pkg/util/schemacmp/type.go index 5e0e52e9bbb59..8fd36adadad46 100644 --- a/pkg/util/schemacmp/type.go +++ b/pkg/util/schemacmp/type.go @@ -166,6 +166,8 @@ func (a typ) getStandardDefaultValue() any { return "0000" case mysql.TypeJSON: return "null" + case mysql.TypeTiDBVectorFloat32: + return "[]" case mysql.TypeEnum: return a.Tuple[fieldTypeTupleIndexElems].(StringList)[0] case mysql.TypeString: diff --git a/pkg/util/sem/sem.go b/pkg/util/sem/sem.go index 25ae2a9413450..f594d789dba29 100644 --- a/pkg/util/sem/sem.go +++ b/pkg/util/sem/sem.go @@ -158,6 +158,7 @@ func IsInvisibleSysVar(varNameInLower string) bool { variable.TiDBRestrictedReadOnly, variable.TiDBTopSQLMaxTimeSeriesCount, variable.TiDBTopSQLMaxMetaCount, + variable.TiDBEnableVectorType, tidbAuditRetractLog: return true } diff --git a/tests/integrationtest/r/executor/show.result b/tests/integrationtest/r/executor/show.result index 481704dd5a0fb..be1fe6fc7e470 100644 --- a/tests/integrationtest/r/executor/show.result +++ b/tests/integrationtest/r/executor/show.result @@ -875,6 +875,14 @@ uuid uuid_short uuid_to_bin validate_password_strength +vec_as_text +vec_cosine_distance +vec_dims +vec_from_text +vec_l1_distance +vec_l2_distance +vec_l2_norm +vec_negative_inner_product version vitess_hash week