Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: Add vector functions #55021

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
98b15cf
Add Vector data type
EricZequan Jul 15, 2024
68e3546
Add Vector data type
EricZequan Jul 15, 2024
f01373a
Add Vector data type
EricZequan Jul 15, 2024
3420f9b
Add Vector data type
EricZequan Jul 15, 2024
9b22faf
Add Vector data type
EricZequan Jul 16, 2024
0037a86
Add Vector data type
EricZequan Jul 16, 2024
7227ca2
Add Vector data type(2)
EricZequan Jul 16, 2024
acfdf39
Add Vector data type(3)
EricZequan Jul 17, 2024
c57f4b1
Add Vector data type(4)
EricZequan Jul 17, 2024
becbe6a
Add Vector data type(4)
EricZequan Jul 17, 2024
74e8abe
Add Vector data type(5)
EricZequan Jul 17, 2024
eb1571a
Add Vector data type(6)
EricZequan Jul 17, 2024
cb09611
Add Vector data type(7)
EricZequan Jul 17, 2024
066ed32
Add Vector data type(8)
EricZequan Jul 18, 2024
4cce327
Add Vector data type(9)
EricZequan Jul 18, 2024
66ed138
Add Vector data type(10)
EricZequan Jul 18, 2024
0e9229d
vector data type(11)
EricZequan Jul 22, 2024
786ab30
Add Vector Data Type(12)
EricZequan Jul 22, 2024
3feccce
Add Vector Data Type(12)
EricZequan Jul 22, 2024
f4bfc0b
Merge branch 'pingcap:master' into vector-type
EricZequan Jul 22, 2024
e251f72
Add Vector Data Type(13)
EricZequan Jul 25, 2024
08a2db3
Merge remote-tracking branch 'origin' into vector-type
EricZequan Jul 25, 2024
62cd94b
fixed dimension vector
EricZequan Jul 29, 2024
2ac243b
fixed dimension vector
EricZequan Jul 29, 2024
02df1dd
remove some unneed line
EricZequan Jul 29, 2024
b780290
fix a bug when using 'update'
EricZequan Jul 30, 2024
2d6a2b0
fix test example fail
EricZequan Jul 30, 2024
9c16ebe
fix test example fail
EricZequan Jul 30, 2024
54b6e38
modify some code write style
EricZequan Jul 31, 2024
a53550f
fix a test-function run fail
EricZequan Jul 31, 2024
ed8ff17
fix a test-function run fail
EricZequan Jul 31, 2024
ab0a49b
change code writing
EricZequan Jul 31, 2024
7589def
modify pkg/parser/parser.y
EricZequan Jul 31, 2024
1a161d8
modify pkg/ddl/index.go
EricZequan Jul 31, 2024
f132098
modify pkg/ddl/index.go
EricZequan Jul 31, 2024
6803a28
Merge branch 'vector-type' into pr-872
EricZequan Jul 31, 2024
e139397
Merge branch 'pr-872' into pr-878
EricZequan Jul 31, 2024
3448a11
modify pkg/types/datum.go
EricZequan Aug 1, 2024
737b81f
Merge branch 'pr-872' into pr-878
EricZequan Aug 1, 2024
a85553c
Add vector function
EricZequan Aug 1, 2024
2b853cb
fix multiply function for vector
EricZequan Aug 1, 2024
f277c83
Merge branch 'feature/vector-search/vector-data-type' into pr-872
EricZequan Aug 2, 2024
157592d
removed the variable EnableVectorType
EricZequan Aug 2, 2024
2bebf25
fix TestVectorColumnInfo fail
EricZequan Aug 5, 2024
91190c6
fix TestVectorColumnInfo fail
EricZequan Aug 5, 2024
093d2a7
Merge branch 'pr-872' into pr-878
EricZequan Aug 5, 2024
c453c02
remove the 'vector-type-enable' in test
EricZequan Aug 5, 2024
5e30460
change the limitation of vector-dimension into 16383
EricZequan Aug 5, 2024
70cdff3
modify vector dimention test example
EricZequan Aug 5, 2024
03c24e6
Merge branch 'pr-872' into pr-878
EricZequan Aug 6, 2024
6598227
Merge branch 'feature/vector-search/vector-data-type' into pr-878
EricZequan Aug 6, 2024
b32e75e
change 'builtinLeastVectorFloat32Sig' location
EricZequan Aug 6, 2024
281bfc7
Merge branch 'pr-878' of https://github.com/EricZequan/tidb into pr-878
EricZequan Aug 6, 2024
026043f
change function 'checkVectorAggPushDown' to return bool
EricZequan Aug 9, 2024
4107e95
fix
EricZequan Aug 9, 2024
1293668
fix
EricZequan Aug 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions pkg/expression/aggregation/aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ func CheckAggPushDown(ctx expression.EvalContext, aggFunc *AggFuncDesc, storeTyp
if aggFunc.Name == ast.AggFuncApproxPercentile {
return false
}
if !checkVectorAggPushDown(ctx, aggFunc) {
return false
}
ret := true
switch storeType {
case kv.TiFlash:
Expand All @@ -253,6 +256,22 @@ func CheckAggPushDown(ctx expression.EvalContext, aggFunc *AggFuncDesc, storeTyp
return ret
}

// checkVectorAggPushDown returns false if this aggregate function is not supported to push down.
// - The aggregate function is not calculated over a Vector column (returns true)
// - The aggregate function is calculated over a Vector column and the function is supported (returns true)
// - The aggregate function is calculated over a Vector column and the function is not supported (returns false)
func checkVectorAggPushDown(ctx expression.EvalContext, aggFunc *AggFuncDesc) bool {
switch aggFunc.Name {
case ast.AggFuncCount, ast.AggFuncMin, ast.AggFuncMax, ast.AggFuncFirstRow:
return true
default:
if aggFunc.Args[0].GetType(ctx).GetType() == mysql.TypeTiDBVectorFloat32 {
return false
}
}
return true
}

// CheckAggPushFlash checks whether an agg function can be pushed to flash storage.
func CheckAggPushFlash(ctx expression.EvalContext, aggFunc *AggFuncDesc) bool {
for _, arg := range aggFunc.Args {
Expand Down
11 changes: 8 additions & 3 deletions pkg/expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -929,9 +929,14 @@ var funcs = map[string]functionClass{
ast.JSONLength: &jsonLengthFunctionClass{baseFunctionClass{ast.JSONLength, 1, 2}},

// vector functions (TiDB extension)
ast.VecDims: &vecDimsFunctionClass{baseFunctionClass{ast.VecDims, 1, 1}},
ast.VecFromText: &vecFromTextFunctionClass{baseFunctionClass{ast.VecFromText, 1, 1}},
ast.VecAsText: &vecAsTextFunctionClass{baseFunctionClass{ast.VecAsText, 1, 1}},
ast.VecDims: &vecDimsFunctionClass{baseFunctionClass{ast.VecDims, 1, 1}},
ast.VecL1Distance: &vecL1DistanceFunctionClass{baseFunctionClass{ast.VecL1Distance, 2, 2}},
ast.VecL2Distance: &vecL2DistanceFunctionClass{baseFunctionClass{ast.VecL2Distance, 2, 2}},
ast.VecNegativeInnerProduct: &vecNegativeInnerProductFunctionClass{baseFunctionClass{ast.VecNegativeInnerProduct, 2, 2}},
ast.VecCosineDistance: &vecCosineDistanceFunctionClass{baseFunctionClass{ast.VecCosineDistance, 2, 2}},
ast.VecL2Norm: &vecL2NormFunctionClass{baseFunctionClass{ast.VecL2Norm, 1, 1}},
ast.VecFromText: &vecFromTextFunctionClass{baseFunctionClass{ast.VecFromText, 1, 1}},
ast.VecAsText: &vecAsTextFunctionClass{baseFunctionClass{ast.VecAsText, 1, 1}},

// TiDB internal function.
ast.TiDBDecodeKey: &tidbDecodeKeyFunctionClass{baseFunctionClass{ast.TiDBDecodeKey, 1, 1}},
Expand Down
118 changes: 118 additions & 0 deletions pkg/expression/builtin_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ var (
_ builtinFunc = &builtinArithmeticModIntSignedSignedSig{}
_ builtinFunc = &builtinArithmeticModRealSig{}
_ builtinFunc = &builtinArithmeticModDecimalSig{}

_ builtinFunc = &builtinArithmeticPlusVectorFloat32Sig{}
_ builtinFunc = &builtinArithmeticMinusVectorFloat32Sig{}
_ builtinFunc = &builtinArithmeticMultiplyVectorFloat32Sig{}
)

// isConstantBinaryLiteral return true if expr is constant binary literal
Expand Down Expand Up @@ -167,6 +171,15 @@ func (c *arithmeticPlusFunctionClass) getFunction(ctx BuildContext, args []Expre
if err := c.verifyArgs(args); err != nil {
return nil, err
}
if args[0].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() || args[1].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() {
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETVectorFloat32, types.ETVectorFloat32, types.ETVectorFloat32)
if err != nil {
return nil, err
}
sig := &builtinArithmeticPlusVectorFloat32Sig{bf}
// sig.setPbCode(tipb.ScalarFuncSig_PlusVectorFloat32)
return sig, nil
}
lhsEvalTp, rhsEvalTp := numericContextResultType(ctx.GetEvalCtx(), args[0]), numericContextResultType(ctx.GetEvalCtx(), args[1])
if lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal {
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, types.ETReal, types.ETReal)
Expand Down Expand Up @@ -317,6 +330,15 @@ func (c *arithmeticMinusFunctionClass) getFunction(ctx BuildContext, args []Expr
if err := c.verifyArgs(args); err != nil {
return nil, err
}
if args[0].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() || args[1].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() {
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETVectorFloat32, types.ETVectorFloat32, types.ETVectorFloat32)
if err != nil {
return nil, err
}
sig := &builtinArithmeticMinusVectorFloat32Sig{bf}
// sig.setPbCode(tipb.ScalarFuncSig_PlusVectorFloat32)
return sig, nil
}
lhsEvalTp, rhsEvalTp := numericContextResultType(ctx.GetEvalCtx(), args[0]), numericContextResultType(ctx.GetEvalCtx(), args[1])
if lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal {
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, types.ETReal, types.ETReal)
Expand Down Expand Up @@ -500,6 +522,15 @@ func (c *arithmeticMultiplyFunctionClass) getFunction(ctx BuildContext, args []E
if err := c.verifyArgs(args); err != nil {
return nil, err
}
if args[0].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() || args[1].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() {
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETVectorFloat32, types.ETVectorFloat32, types.ETVectorFloat32)
if err != nil {
return nil, err
}
sig := &builtinArithmeticMultiplyVectorFloat32Sig{bf}
// sig.setPbCode(tipb.ScalarFuncSig_PlusVectorFloat32)
return sig, nil
}
lhsTp, rhsTp := args[0].GetType(ctx.GetEvalCtx()), args[1].GetType(ctx.GetEvalCtx())
lhsEvalTp, rhsEvalTp := numericContextResultType(ctx.GetEvalCtx(), args[0]), numericContextResultType(ctx.GetEvalCtx(), args[1])
if lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal {
Expand Down Expand Up @@ -1157,3 +1188,90 @@ func (s *builtinArithmeticModIntSignedSignedSig) evalInt(ctx EvalContext, row ch

return a % b, false, nil
}

type builtinArithmeticPlusVectorFloat32Sig struct {
baseBuiltinFunc
}

func (s *builtinArithmeticPlusVectorFloat32Sig) Clone() builtinFunc {
newSig := &builtinArithmeticPlusVectorFloat32Sig{}
newSig.cloneFrom(&s.baseBuiltinFunc)
return newSig
}

func (s *builtinArithmeticPlusVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) {
a, isLHSNull, err := s.args[0].EvalVectorFloat32(ctx, row)
if err != nil {
return types.ZeroVectorFloat32, isLHSNull, err
}
b, isRHSNull, err := s.args[1].EvalVectorFloat32(ctx, row)
if err != nil {
return types.ZeroVectorFloat32, isRHSNull, err
}
if isLHSNull || isRHSNull {
return types.ZeroVectorFloat32, true, nil
}
v, err := a.Add(b)
if err != nil {
return types.ZeroVectorFloat32, true, err
}
return v, false, nil
}

type builtinArithmeticMinusVectorFloat32Sig struct {
baseBuiltinFunc
}

func (s *builtinArithmeticMinusVectorFloat32Sig) Clone() builtinFunc {
newSig := &builtinArithmeticMinusVectorFloat32Sig{}
newSig.cloneFrom(&s.baseBuiltinFunc)
return newSig
}

func (s *builtinArithmeticMinusVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) {
a, isLHSNull, err := s.args[0].EvalVectorFloat32(ctx, row)
if err != nil {
return types.ZeroVectorFloat32, isLHSNull, err
}
b, isRHSNull, err := s.args[1].EvalVectorFloat32(ctx, row)
if err != nil {
return types.ZeroVectorFloat32, isRHSNull, err
}
if isLHSNull || isRHSNull {
return types.ZeroVectorFloat32, true, nil
}
v, err := a.Sub(b)
if err != nil {
return types.ZeroVectorFloat32, true, err
}
return v, false, nil
}

type builtinArithmeticMultiplyVectorFloat32Sig struct {
baseBuiltinFunc
}

func (s *builtinArithmeticMultiplyVectorFloat32Sig) Clone() builtinFunc {
newSig := &builtinArithmeticMultiplyVectorFloat32Sig{}
newSig.cloneFrom(&s.baseBuiltinFunc)
return newSig
}

func (s *builtinArithmeticMultiplyVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) {
a, isLHSNull, err := s.args[0].EvalVectorFloat32(ctx, row)
if err != nil {
return types.ZeroVectorFloat32, isLHSNull, err
}
b, isRHSNull, err := s.args[1].EvalVectorFloat32(ctx, row)
if err != nil {
return types.ZeroVectorFloat32, isRHSNull, err
}
if isLHSNull || isRHSNull {
return types.ZeroVectorFloat32, true, nil
}
v, err := a.Mul(b)
if err != nil {
return types.ZeroVectorFloat32, true, err
}
return v, false, nil
}
2 changes: 1 addition & 1 deletion pkg/expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ func (c *castAsVectorFloat32FunctionClass) getFunction(ctx BuildContext, args []
sig.setPbCode(tipb.ScalarFuncSig_CastVectorFloat32AsVectorFloat32)
case types.ETString:
sig = &builtinCastStringAsVectorFloat32Sig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CastStringAsVectorFloat32)
// sig.setPbCode(tipb.ScalarFuncSig_CastStringAsVectorFloat32)
default:
return nil, errors.Errorf("cannot cast from %s to %s", argTp, "VectorFloat32")
}
Expand Down
80 changes: 80 additions & 0 deletions pkg/expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ var (
_ builtinFunc = &builtinCoalesceStringSig{}
_ builtinFunc = &builtinCoalesceTimeSig{}
_ builtinFunc = &builtinCoalesceDurationSig{}
_ builtinFunc = &builtinCoalesceVectorFloat32Sig{}

_ builtinFunc = &builtinGreatestIntSig{}
_ builtinFunc = &builtinGreatestRealSig{}
Expand All @@ -54,13 +55,15 @@ var (
_ builtinFunc = &builtinGreatestDurationSig{}
_ builtinFunc = &builtinGreatestTimeSig{}
_ builtinFunc = &builtinGreatestCmpStringAsTimeSig{}
_ builtinFunc = &builtinGreatestVectorFloat32Sig{}
_ builtinFunc = &builtinLeastIntSig{}
_ builtinFunc = &builtinLeastRealSig{}
_ builtinFunc = &builtinLeastDecimalSig{}
_ builtinFunc = &builtinLeastStringSig{}
_ builtinFunc = &builtinLeastTimeSig{}
_ builtinFunc = &builtinLeastDurationSig{}
_ builtinFunc = &builtinLeastCmpStringAsTimeSig{}
_ builtinFunc = &builtinLeastVectorFloat32Sig{}
_ builtinFunc = &builtinIntervalIntSig{}
_ builtinFunc = &builtinIntervalRealSig{}

Expand Down Expand Up @@ -167,6 +170,9 @@ func (c *coalesceFunctionClass) getFunction(ctx BuildContext, args []Expression)
case types.ETJson:
sig = &builtinCoalesceJSONSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CoalesceJson)
case types.ETVectorFloat32:
sig = &builtinCoalesceVectorFloat32Sig{bf}
// sig.setPbCode(tipb.ScalarFuncSig_CoalesceVectorFloat32)
default:
return nil, errors.Errorf("%s is not supported for COALESCE()", retEvalTp)
}
Expand Down Expand Up @@ -331,6 +337,28 @@ func (b *builtinCoalesceJSONSig) evalJSON(ctx EvalContext, row chunk.Row) (res t
return res, isNull, err
}

// builtinCoalesceVectorFloat32Sig is builtin function coalesce signature which return type vector float32.
// See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_coalesce
type builtinCoalesceVectorFloat32Sig struct {
baseBuiltinFunc
}

func (b *builtinCoalesceVectorFloat32Sig) Clone() builtinFunc {
newSig := &builtinCoalesceVectorFloat32Sig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinCoalesceVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (res types.VectorFloat32, isNull bool, err error) {
for _, a := range b.getArgs() {
res, isNull, err = a.EvalVectorFloat32(ctx, row)
if err != nil || !isNull {
break
}
}
return res, isNull, err
}

func aggregateType(ctx EvalContext, args []Expression) *types.FieldType {
fieldTypes := make([]*types.FieldType, len(args))
for i := range fieldTypes {
Expand Down Expand Up @@ -499,6 +527,9 @@ func (c *greatestFunctionClass) getFunction(ctx BuildContext, args []Expression)
sig = &builtinGreatestTimeSig{bf, false}
sig.setPbCode(tipb.ScalarFuncSig_GreatestTime)
}
case types.ETVectorFloat32:
sig = &builtinGreatestVectorFloat32Sig{bf}
// sig.setPbCode(tipb.ScalarFuncSig_GreatestVectorFloat32)
default:
return nil, errors.Errorf("unsupported type %s during evaluation", argTp)
}
Expand Down Expand Up @@ -754,6 +785,29 @@ func (b *builtinGreatestDurationSig) evalDuration(ctx EvalContext, row chunk.Row
return res, false, nil
}

type builtinGreatestVectorFloat32Sig struct {
baseBuiltinFunc
}

func (b *builtinGreatestVectorFloat32Sig) Clone() builtinFunc {
newSig := &builtinGreatestVectorFloat32Sig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinGreatestVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (res types.VectorFloat32, isNull bool, err error) {
for i := 0; i < len(b.args); i++ {
v, isNull, err := b.args[i].EvalVectorFloat32(ctx, row)
if isNull || err != nil {
return types.VectorFloat32{}, true, err
}
if i == 0 || v.Compare(res) > 0 {
res = v
}
}
return res, false, nil
}

type leastFunctionClass struct {
baseFunctionClass
}
Expand Down Expand Up @@ -814,6 +868,9 @@ func (c *leastFunctionClass) getFunction(ctx BuildContext, args []Expression) (s
sig = &builtinLeastTimeSig{bf, false}
sig.setPbCode(tipb.ScalarFuncSig_LeastTime)
}
case types.ETVectorFloat32:
sig = &builtinLeastVectorFloat32Sig{bf}
// sig.setPbCode(tipb.ScalarFuncSig_LeastVectorFloat32)
default:
return nil, errors.Errorf("unsupported type %s during evaluation", argTp)
}
Expand Down Expand Up @@ -1039,6 +1096,29 @@ func (b *builtinLeastDurationSig) evalDuration(ctx EvalContext, row chunk.Row) (
return res, false, nil
}

type builtinLeastVectorFloat32Sig struct {
baseBuiltinFunc
}

func (b *builtinLeastVectorFloat32Sig) Clone() builtinFunc {
newSig := &builtinLeastVectorFloat32Sig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinLeastVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (res types.VectorFloat32, isNull bool, err error) {
for i := 0; i < len(b.args); i++ {
v, isNull, err := b.args[i].EvalVectorFloat32(ctx, row)
if isNull || err != nil {
return types.VectorFloat32{}, true, err
}
if i == 0 || v.Compare(res) < 0 {
res = v
}
}
return res, false, nil
}

type intervalFunctionClass struct {
baseFunctionClass
}
Expand Down
Loading