Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: support fixed dimension vector #55002

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
98b15cf
Add Vector data type
EricZequan Jul 15, 2024
68e3546
Add Vector data type
EricZequan Jul 15, 2024
f01373a
Add Vector data type
EricZequan Jul 15, 2024
3420f9b
Add Vector data type
EricZequan Jul 15, 2024
9b22faf
Add Vector data type
EricZequan Jul 16, 2024
0037a86
Add Vector data type
EricZequan Jul 16, 2024
7227ca2
Add Vector data type(2)
EricZequan Jul 16, 2024
acfdf39
Add Vector data type(3)
EricZequan Jul 17, 2024
c57f4b1
Add Vector data type(4)
EricZequan Jul 17, 2024
becbe6a
Add Vector data type(4)
EricZequan Jul 17, 2024
74e8abe
Add Vector data type(5)
EricZequan Jul 17, 2024
eb1571a
Add Vector data type(6)
EricZequan Jul 17, 2024
cb09611
Add Vector data type(7)
EricZequan Jul 17, 2024
066ed32
Add Vector data type(8)
EricZequan Jul 18, 2024
4cce327
Add Vector data type(9)
EricZequan Jul 18, 2024
66ed138
Add Vector data type(10)
EricZequan Jul 18, 2024
0e9229d
vector data type(11)
EricZequan Jul 22, 2024
786ab30
Add Vector Data Type(12)
EricZequan Jul 22, 2024
3feccce
Add Vector Data Type(12)
EricZequan Jul 22, 2024
f4bfc0b
Merge branch 'pingcap:master' into vector-type
EricZequan Jul 22, 2024
e251f72
Add Vector Data Type(13)
EricZequan Jul 25, 2024
08a2db3
Merge remote-tracking branch 'origin' into vector-type
EricZequan Jul 25, 2024
62cd94b
fixed dimension vector
EricZequan Jul 29, 2024
2ac243b
fixed dimension vector
EricZequan Jul 29, 2024
02df1dd
remove some unneed line
EricZequan Jul 29, 2024
b780290
fix a bug when using 'update'
EricZequan Jul 30, 2024
2d6a2b0
fix test example fail
EricZequan Jul 30, 2024
54b6e38
modify some code write style
EricZequan Jul 31, 2024
a53550f
fix a test-function run fail
EricZequan Jul 31, 2024
ab0a49b
change code writing
EricZequan Jul 31, 2024
7589def
modify pkg/parser/parser.y
EricZequan Jul 31, 2024
1a161d8
modify pkg/ddl/index.go
EricZequan Jul 31, 2024
f132098
modify pkg/ddl/index.go
EricZequan Jul 31, 2024
6803a28
Merge branch 'vector-type' into pr-872
EricZequan Jul 31, 2024
3448a11
modify pkg/types/datum.go
EricZequan Aug 1, 2024
f277c83
Merge branch 'feature/vector-search/vector-data-type' into pr-872
EricZequan Aug 2, 2024
157592d
removed the variable EnableVectorType
EricZequan Aug 2, 2024
2bebf25
fix TestVectorColumnInfo fail
EricZequan Aug 5, 2024
91190c6
fix TestVectorColumnInfo fail
EricZequan Aug 5, 2024
5e30460
change the limitation of vector-dimension into 16383
EricZequan Aug 5, 2024
70cdff3
modify vector dimention test example
EricZequan Aug 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pkg/ddl/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,8 @@ func needChangeColumnData(oldCol, newCol *model.ColumnInfo) bool {
if types.IsBinaryStr(&oldCol.FieldType) {
return newCol.GetFlen() != oldCol.GetFlen()
}
case mysql.TypeTiDBVectorFloat32:
return newCol.GetFlen() != types.UnspecifiedLength && oldCol.GetFlen() != newCol.GetFlen()
}

return needTruncationOrToggleSign()
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func newReturnFieldTypeForBaseBuiltinFunc(funcName string, retType types.EvalTyp
case types.ETJson:
fieldType = types.NewFieldTypeBuilder().SetType(mysql.TypeJSON).SetFlag(mysql.BinaryFlag).SetFlen(mysql.MaxBlobWidth).SetCharset(mysql.DefaultCharset).SetCollate(mysql.DefaultCollationName).BuildP()
case types.ETVectorFloat32:
fieldType = types.NewFieldTypeBuilder().SetType(mysql.TypeTiDBVectorFloat32).SetFlag(mysql.BinaryFlag).SetFlen(mysql.MaxBlobWidth).BuildP()
fieldType = types.NewFieldTypeBuilder().SetType(mysql.TypeTiDBVectorFloat32).SetFlag(mysql.BinaryFlag).SetFlen(types.UnspecifiedLength).BuildP()
}
if mysql.HasBinaryFlag(fieldType.GetFlag()) && fieldType.GetType() != mysql.TypeJSON {
fieldType.SetCharset(charset.CharsetBin)
Expand Down
19 changes: 16 additions & 3 deletions pkg/expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -781,8 +781,14 @@ func (b *builtinCastStringAsVectorFloat32Sig) evalVectorFloat32(ctx EvalContext,
if isNull || err != nil {
return types.ZeroVectorFloat32, isNull, err
}
res, err := types.ParseVectorFloat32(val)
return res, false, err
vec, err := types.ParseVectorFloat32(val)
if err != nil {
return types.ZeroVectorFloat32, false, err
}
if err = vec.CheckDimsFitColumn(b.tp.GetFlen()); err != nil {
return types.ZeroVectorFloat32, isNull, err
}
return vec, false, nil
}

type builtinCastVectorFloat32AsVectorFloat32Sig struct {
Expand All @@ -796,7 +802,14 @@ func (b *builtinCastVectorFloat32AsVectorFloat32Sig) Clone() builtinFunc {
}

func (b *builtinCastVectorFloat32AsVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) {
return b.args[0].EvalVectorFloat32(ctx, row)
val, isNull, err := b.args[0].EvalVectorFloat32(ctx, row)
if isNull || err != nil {
return types.ZeroVectorFloat32, isNull, err
}
if err = val.CheckDimsFitColumn(b.tp.GetFlen()); err != nil {
return types.ZeroVectorFloat32, isNull, err
}
return val, false, nil
}

type builtinCastIntAsIntSig struct {
Expand Down
5 changes: 4 additions & 1 deletion pkg/expression/builtin_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ func (b *builtinVecFromTextSig) evalVectorFloat32(ctx EvalContext, row chunk.Row

vec, err := types.ParseVectorFloat32(v)
if err != nil {
return res, false, err
return types.ZeroVectorFloat32, false, err
}
if err = vec.CheckDimsFitColumn(b.tp.GetFlen()); err != nil {
return types.ZeroVectorFloat32, isNull, err
}

return vec, false, nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/integration_test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ go_test(
"main_test.go",
],
flaky = True,
shard_count = 33,
shard_count = 35,
deps = [
"//pkg/config",
"//pkg/domain",
Expand Down
116 changes: 111 additions & 5 deletions pkg/expression/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,90 @@ import (
"github.com/tikv/client-go/v2/oracle"
)

func TestVectorColumnInfo(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")

// Create vector type column without specified dimension.
tk.MustExec("create table t(embedding VECTOR)")
tk.MustExec("drop table if exists t;")
tk.MustExec("create table t(embedding VECTOR<FLOAT>)")

// SHOW CREATE TABLE
tk.MustQuery("show create table t").Check(testkit.Rows(
"t CREATE TABLE `t` (\n" +
" `embedding` vector<float> DEFAULT NULL\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin",
))

// SHOW COLUMNS
tk.MustQuery("show columns from t").Check(testkit.Rows(
"embedding vector<float> YES <nil> ",
))

// Create vector type column with specified dimension.
tk.MustExec("drop table if exists t;")
tk.MustExec("create table t(embedding VECTOR(3))")
tk.MustExec("drop table if exists t;")
tk.MustExec("create table t(embedding VECTOR<FLOAT>(3))")
tk.MustExec("drop table if exists t;")
tk.MustExec("create table t(embedding VECTOR<FLOAT>(0))")

// SHOW CREATE TABLE
tk.MustExec("drop table if exists t;")
tk.MustExec("create table t(embedding VECTOR(3))")
tk.MustQuery("show create table t").Check(testkit.Rows(
"t CREATE TABLE `t` (\n" +
" `embedding` vector<float>(3) DEFAULT NULL\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin",
))

// SHOW COLUMNS
tk.MustQuery("show columns from t").Check(testkit.Rows(
"embedding vector<float>(3) YES <nil> ",
))

// INFORMATION_SCHEMA.COLUMNS
tk.MustQuery("SELECT data_type, column_type FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = 't'").Check(testkit.Rows(
"vector<float> vector<float>(3)",
))

// Vector dimension MUST be equal or less than 16383.
tk.MustExec("drop table if exists t;")
tk.MustGetErrMsg("create table t(embedding VECTOR<FLOAT>(16384))", "vector cannot have more than 16383 dimensions")
}

func TestFixedVector(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")

tk.MustExec("create table t(embedding VECTOR)")
tk.MustExec("insert into t values ('[1,2,3]')")
tk.MustExec("insert into t values ('[1,2,3,4]')")

// Failed to modify column type cause vectors with different dimension.
tk.MustContainErrMsg("alter table t modify column embedding VECTOR(3)", "vector has 4 dimensions, does not fit VECTOR(3)")

// Mixed dimension to fixed dimension.
tk.MustExec("delete from t where vec_dims(embedding) != 3")
tk.MustExec("alter table t modify column embedding VECTOR(3)")
tk.MustGetErrMsg("insert into t values ('[]')", "vector has 0 dimensions, does not fit VECTOR(3)")
tk.MustGetErrMsg("insert into t values ('[1,2,3,4]')", "vector has 4 dimensions, does not fit VECTOR(3)")
tk.MustGetErrMsg("insert into t values (VEC_FROM_TEXT('[]'))", "vector has 0 dimensions, does not fit VECTOR(3)")
tk.MustGetErrMsg("insert into t values (VEC_FROM_TEXT('[1,2,3,4]'))", "vector has 4 dimensions, does not fit VECTOR(3)")
tk.MustGetErrMsg("update t set embedding = '[1,2,3,4]' where embedding = '[1,2,3]'", "vector has 4 dimensions, does not fit VECTOR(3)")
tk.MustGetErrMsg("update t set embedding = '[]' where embedding = '[1,2,3]'", "vector has 0 dimensions, does not fit VECTOR(3)")

// Fixed dimension to mixed dimension.
tk.MustExec("alter table t modify column embedding VECTOR")
tk.MustExec("insert into t values ('[1,2,3,4]')")

// Vector dimension MUST be equal or less than 16383.
tk.MustGetErrMsg("alter table t modify column embedding VECTOR(16384)", "vector cannot have more than 16383 dimensions")
}

func TestVector(t *testing.T) {
store := testkit.CreateMockStore(t)

Expand Down Expand Up @@ -106,6 +190,7 @@ func TestVectorOperators(t *testing.T) {

tk := testkit.NewTestKit(t, store)
tk.MustExec("USE test;")

tk.MustExec(`CREATE TABLE t(embedding VECTOR);`)
tk.MustExec(`INSERT INTO t VALUES
('[1, 2, 3]'),
Expand All @@ -119,7 +204,7 @@ func TestVectorOperators(t *testing.T) {
tk.MustQuery(`SELECT VEC_FROM_TEXT('[]') IS NOT NULL`).Check(testkit.Rows("1"))
tk.MustQuery(`SELECT VEC_FROM_TEXT('[]') IS NULL`).Check(testkit.Rows("0"))
tk.MustQuery(`SELECT * FROM t WHERE embedding = VEC_FROM_TEXT('[1,2,3]');`).Check(testkit.Rows("[1,2,3]"))
tk.MustQuery(`SELECT * FROM t WHERE embedding BETWEEN '[1, 2, 3]' AND '[4, 5, 6]'`).Check(testkit.Rows("[1,2,3]", "[4,5,6]"))
tk.MustQuery(`SELECT * FROM t WHERE embedding BETWEEN '[1,2,3]' AND '[4,5,6]'`).Check(testkit.Rows("[1,2,3]", "[4,5,6]"))
tk.MustExecToErr(`SELECT * FROM t WHERE embedding IN ('[1, 2, 3]', '[4, 5, 6]')`)
tk.MustExecToErr(`SELECT * FROM t WHERE embedding NOT IN ('[1, 2, 3]', '[4, 5, 6]')`)
}
Expand Down Expand Up @@ -173,10 +258,18 @@ func TestVectorConversion(t *testing.T) {
tk.MustQueryToErr("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS DATE);")
tk.MustQueryToErr("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS TIME);")

// expect error result
tk.MustExecToErr("SELECT CAST('[1,2,3]' AS VECTOR);")
tk.MustExecToErr("SELECT CAST('[1,2,3]' AS VECTOR<FLOAT>);")
tk.MustExecToErr("SELECT CAST('[1,2,3]' AS VECTOR<DOUBLE>);")
tk.MustQuery("SELECT CAST('[1,2,3]' AS VECTOR);").Check(testkit.Rows("[1,2,3]"))
tk.MustQuery("SELECT CAST('[]' AS VECTOR);").Check(testkit.Rows("[]"))
tk.MustQuery("SELECT CAST('[1,2,3]' AS VECTOR<FLOAT>);").Check(testkit.Rows("[1,2,3]"))
tk.MustContainErrMsg("SELECT CAST('[1,2,3]' AS VECTOR<DOUBLE>);", "Only VECTOR is supported for now")

tk.MustQuery("SELECT CAST('[1,2,3]' AS VECTOR<FLOAT>(3));").Check(testkit.Rows("[1,2,3]"))
err := tk.QueryToErr("SELECT CAST('[1,2,3]' AS VECTOR<FLOAT>(2));")
require.EqualError(t, err, "vector has 3 dimensions, does not fit VECTOR(2)")

tk.MustQuery("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS VECTOR<FLOAT>(3));").Check(testkit.Rows("[1,2,3]"))
err = tk.QueryToErr("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS VECTOR<FLOAT>(2));")
require.EqualError(t, err, "vector has 3 dimensions, does not fit VECTOR(2)")

// CONVERT
tk.MustQuery("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), BINARY);").Check(testkit.Rows("[1,2,3]"))
Expand All @@ -192,6 +285,19 @@ func TestVectorConversion(t *testing.T) {
tk.MustQueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), DATETIME);")
tk.MustQueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), DATE);")
tk.MustQueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), TIME);")

tk.MustQuery("SELECT CONVERT('[1,2,3]', VECTOR);").Check(testkit.Rows("[1,2,3]"))
tk.MustQuery("SELECT CONVERT('[]', VECTOR);").Check(testkit.Rows("[]"))
tk.MustQuery("SELECT CONVERT('[1,2,3]', VECTOR<FLOAT>);").Check(testkit.Rows("[1,2,3]"))
tk.MustContainErrMsg("SELECT CONVERT('[1,2,3]', VECTOR<DOUBLE>);", "Only VECTOR is supported for now")

tk.MustQuery("SELECT CONVERT('[1,2,3]', VECTOR<FLOAT>(3));").Check(testkit.Rows("[1,2,3]"))
err = tk.QueryToErr("SELECT CONVERT('[1,2,3]', VECTOR<FLOAT>(2));")
require.EqualError(t, err, "vector has 3 dimensions, does not fit VECTOR(2)")

tk.MustQuery("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), VECTOR<FLOAT>(3));").Check(testkit.Rows("[1,2,3]"))
err = tk.QueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), VECTOR<FLOAT>(2));")
require.EqualError(t, err, "vector has 3 dimensions, does not fit VECTOR(2)")
}

func TestVectorAggregations(t *testing.T) {
Expand Down
10 changes: 6 additions & 4 deletions pkg/planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -1457,6 +1457,12 @@ func checkColumn(colDef *ast.ColumnDef) error {
if tp.GetFlen() > mysql.MaxBitDisplayWidth {
return types.ErrTooBigDisplayWidth.GenWithStackByArgs(colDef.Name.Name.O, mysql.MaxBitDisplayWidth)
}
case mysql.TypeTiDBVectorFloat32:
if tp.GetFlen() != types.UnspecifiedLength {
if err := types.CheckVectorDimValid(tp.GetFlen()); err != nil {
return err
}
}
default:
// TODO: Add more types.
}
Expand Down Expand Up @@ -1742,10 +1748,6 @@ func (p *preprocessor) checkFuncCastExpr(node *ast.FuncCastExpr) {
return
}
}
if node.Tp.GetType() == mysql.TypeTiDBVectorFloat32 {
p.err = errors.Errorf("vector type is not supported")
return
}
}

func (p *preprocessor) updateStateFromStaleReadProcessor() error {
Expand Down
16 changes: 12 additions & 4 deletions pkg/types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ func (d *Datum) compareMysqlTime(ctx Context, time Time) (int, error) {
}
}

func (d *Datum) compareVectorFloat32(sc Context, vec VectorFloat32) (int, error) {
func (d *Datum) compareVectorFloat32(ctx Context, vec VectorFloat32) (int, error) {
breezewish marked this conversation as resolved.
Show resolved Hide resolved
switch d.k {
case KindNull, KindMinNotNull:
return -1, nil
Expand Down Expand Up @@ -1796,15 +1796,23 @@ func (d *Datum) convertToMysqlJSON(_ *FieldType) (ret Datum, err error) {
return ret, errors.Trace(err)
}

func (d *Datum) convertToVectorFloat32(_ Context, _ *FieldType) (ret Datum, err error) {
func (d *Datum) convertToVectorFloat32(_ Context, target *FieldType) (ret Datum, err error) {
switch d.k {
case KindVectorFloat32:
v := d.GetVectorFloat32()
if err = v.CheckDimsFitColumn(target.GetFlen()); err != nil {
return ret, errors.Trace(err)
}
ret = *d
case KindString, KindBytes:
var v VectorFloat32
if v, err = ParseVectorFloat32(d.GetString()); err == nil {
ret.SetVectorFloat32(v)
if v, err = ParseVectorFloat32(d.GetString()); err != nil {
return ret, errors.Trace(err)
}
if err = v.CheckDimsFitColumn(target.GetFlen()); err != nil {
return ret, errors.Trace(err)
}
ret.SetVectorFloat32(v)
default:
return invalidConv(d, mysql.TypeTiDBVectorFloat32)
}
Expand Down
30 changes: 29 additions & 1 deletion pkg/types/vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

jsoniter "github.com/json-iterator/go"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/parser/types"
)

func init() {
Expand Down Expand Up @@ -55,6 +56,28 @@ func InitVectorFloat32(dims int) VectorFloat32 {
return VectorFloat32{data: data}
}

// CheckVectorDimValid checks if the vector's dimension is valid.
func CheckVectorDimValid(dim int) error {
const (
maxVectorDimension = 16383
)
if dim < 0 {
return errors.Errorf("dimensions for type vector must be at least 0")
}
if dim > maxVectorDimension {
return errors.Errorf("vector cannot have more than %d dimensions", maxVectorDimension)
}
return nil
}

// CheckDimsFitColumn checks if the vector has the expected dimension, which is defined by the column type or cast type.
func (v VectorFloat32) CheckDimsFitColumn(expectedFlen int) error {
if expectedFlen != types.UnspecifiedLength && v.Len() != expectedFlen {
return errors.Errorf("vector has %d dimensions, does not fit VECTOR(%d)", v.Len(), expectedFlen)
}
return nil
}

// Len returns the length (dimension) of the vector.
func (v VectorFloat32) Len() int {
return int(binary.LittleEndian.Uint32(v.data))
Expand Down Expand Up @@ -139,7 +162,12 @@ func ParseVectorFloat32(s string) (VectorFloat32, error) {
return ZeroVectorFloat32, errors.Errorf("Invalid vector text: %s", s)
}

vec := InitVectorFloat32(len(values))
dim := len(values)
if err := CheckVectorDimValid(dim); err != nil {
return ZeroVectorFloat32, err
}

vec := InitVectorFloat32(dim)
copy(vec.Elements(), values)
return vec, nil
}
Expand Down