Skip to content

Commit

Permalink
ddl: Fix vector index for high dimensional vectors (#58717)
Browse files Browse the repository at this point in the history
ref #54245
  • Loading branch information
breezewish authored Jan 9, 2025
1 parent b6141ec commit 448e302
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 0 deletions.
8 changes: 8 additions & 0 deletions pkg/ddl/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,11 @@ func getIndexColumnLength(col *model.ColumnInfo, colLen int) (int, error) {
}

switch col.GetType() {
case mysql.TypeTiDBVectorFloat32:
// Vector Index does not actually create KV index, so it has length of 0.
// however 0 may cause some issues in other calculations, so we use 1 here.
// 1 is also minimal enough anyway.
return 1, nil
case mysql.TypeBit:
return (length + 7) >> 3, nil
case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeBlob, mysql.TypeLongBlob:
Expand Down Expand Up @@ -2930,6 +2935,9 @@ func newCleanUpIndexWorker(id int, t table.PhysicalTable, decodeColMap map[int64
indexes := make([]table.Index, 0, len(t.Indices()))
rowDecoder := decoder.NewRowDecoder(t, t.WritableCols(), decodeColMap)
for _, index := range t.Indices() {
if index.Meta().IsTiFlashLocalIndex() {
continue
}
if index.Meta().Global {
indexes = append(indexes, index)
}
Expand Down
1 change: 1 addition & 0 deletions pkg/expression/integration_test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ go_test(
deps = [
"//pkg/config",
"//pkg/domain",
"//pkg/domain/infosync",
"//pkg/errno",
"//pkg/expression",
"//pkg/kv",
Expand Down
96 changes: 96 additions & 0 deletions pkg/expression/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/config"
"github.com/pingcap/tidb/pkg/domain"
"github.com/pingcap/tidb/pkg/domain/infosync"
"github.com/pingcap/tidb/pkg/errno"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/kv"
Expand Down Expand Up @@ -61,6 +62,101 @@ import (
"github.com/tikv/client-go/v2/oracle"
)

func TestVectorLong(t *testing.T) {
store := testkit.CreateMockStoreWithSchemaLease(t, 1*time.Second, mockstore.WithMockTiFlash(2))

tk := testkit.NewTestKit(t, store)

tiflash := infosync.NewMockTiFlash()
infosync.SetMockTiFlash(tiflash)
defer func() {
tiflash.Lock()
tiflash.StatusServer.Close()
tiflash.Unlock()
}()

genVec := func(d int, startValue int) string {
vb := strings.Builder{}
vb.WriteString("[")
value := startValue
for i := 0; i < d; i++ {
if i > 0 {
vb.WriteString(",")
}
vb.WriteString(strconv.FormatInt(int64(value), 10))
value += 100
}
vb.WriteString("]")
return vb.String()
}

failpoint.Enable("github.com/pingcap/tidb/pkg/ddl/MockCheckVectorIndexProcess", `return(1)`)
defer func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/ddl/MockCheckVectorIndexProcess"))
}()

runWorkload := func() {
tk.MustExec(fmt.Sprintf(`insert into t1 values (1, '%s')`, genVec(16383, 100)))
tk.MustQuery(`select * from t1 order by id`).Check(testkit.Rows("1 " + genVec(16383, 100)))
tk.MustExec(fmt.Sprintf(`delete from t1 where vec > '%s'`, genVec(16383, 200)))
tk.MustQuery(`select * from t1 order by id`).Check(testkit.Rows("1 " + genVec(16383, 100)))
tk.MustExec(fmt.Sprintf(`delete from t1 where vec > '%s'`, genVec(16383, 50)))
tk.MustQuery(`select * from t1 order by id`).Check(testkit.Rows())
tk.MustExec(fmt.Sprintf(`insert into t1 values (1, '%s')`, genVec(16383, 100)))
tk.MustExec(fmt.Sprintf(`insert into t1 values (2, '%s')`, genVec(16383, 200)))
tk.MustExec(fmt.Sprintf(`insert into t1 values (3, '%s')`, genVec(16383, 300)))
tk.MustQuery(fmt.Sprintf(`select id from t1 order by vec_l2_distance(vec, '%s') limit 2`, genVec(16383, 180))).Check(testkit.Rows(
"2",
"1",
))
tk.MustExec(fmt.Sprintf(`update t1 set vec = '%s' where id = 1`, genVec(16383, 500)))
tk.MustQuery(`select * from t1 order by id`).Check(testkit.Rows(
"1 "+genVec(16383, 500),
"2 "+genVec(16383, 200),
"3 "+genVec(16383, 300),
))
tk.MustQuery(fmt.Sprintf(`select id from t1 order by vec_l2_distance(vec, '%s') limit 2`, genVec(16383, 180))).Check(testkit.Rows(
"2",
"3",
))
}

tk.MustExec("use test")
tk.MustExec(`
create table t1 (
id int primary key,
vec vector(16383)
)
`)
runWorkload()
tk.MustExec("drop table t1")

tk.MustExec(`
create table t1 (
id int primary key,
vec vector(16383),
VECTOR INDEX ((vec_cosine_distance(vec)))
)
`)
runWorkload()
tk.MustExec("drop table if exists t1")
tk.MustExec(`
create table t1 (
id int primary key,
vec vector(16383)
)
`)
tk.MustExec(`alter table t1 set tiflash replica 1`)
tbl, _ := domain.GetDomain(tk.Session()).InfoSchema().TableByName(context.Background(), ast.NewCIStr("test"), ast.NewCIStr("t1"))
tbl.Meta().TiFlashReplica = &model.TiFlashReplicaInfo{
Count: 1,
Available: true,
}
tk.MustExec(`alter table t1 add VECTOR INDEX ((vec_cosine_distance(vec)))`)
runWorkload()
tk.MustExec("drop table if exists t1")
}

func TestVectorDefaultValue(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
Expand Down
6 changes: 6 additions & 0 deletions pkg/meta/model/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ func (index *IndexInfo) IsPublic() bool {
return index.State == StatePublic
}

// IsTiFlashLocalIndex checks whether the index is a TiFlash local index.
// For a TiFlash local index, no actual index data need to be written to KV layer.
func (index *IndexInfo) IsTiFlashLocalIndex() bool {
return index.VectorInfo != nil
}

// FindIndexByColumns find IndexInfo in indices which is cover the specified columns.
func FindIndexByColumns(tbInfo *TableInfo, indices []*IndexInfo, cols ...ast.CIStr) *IndexInfo {
for _, index := range indices {
Expand Down
15 changes: 15 additions & 0 deletions pkg/table/tables/tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ func GetWritableIndexByName(idxName string, t table.Table) table.Index {
if !IsIndexWritable(idx) {
continue
}
if idx.Meta().IsTiFlashLocalIndex() {
continue
}
if idxName == idx.Meta().Name.L {
return idx
}
Expand Down Expand Up @@ -547,6 +550,9 @@ func (t *TableCommon) rebuildUpdateRecordIndices(
if t.meta.IsCommonHandle && idx.Meta().Primary {
continue
}
if idx.Meta().IsTiFlashLocalIndex() {
continue
}
for _, ic := range idx.Meta().Columns {
if !touched[ic.Offset] {
continue
Expand All @@ -566,6 +572,9 @@ func (t *TableCommon) rebuildUpdateRecordIndices(
if !IsIndexWritable(idx) {
continue
}
if idx.Meta().IsTiFlashLocalIndex() {
continue
}
if t.meta.IsCommonHandle && idx.Meta().Primary {
continue
}
Expand Down Expand Up @@ -926,6 +935,9 @@ func (t *TableCommon) addIndices(sctx table.MutateContext, recordID kv.Handle, r
if !IsIndexWritable(v) {
continue
}
if v.Meta().IsTiFlashLocalIndex() {
continue
}
if t.meta.IsCommonHandle && v.Meta().Primary {
continue
}
Expand Down Expand Up @@ -1185,6 +1197,9 @@ func (t *TableCommon) removeRowIndices(ctx table.MutateContext, txn kv.Transacti
if v.Meta().Primary && (t.Meta().IsCommonHandle || t.Meta().PKIsHandle) {
continue
}
if v.Meta().IsTiFlashLocalIndex() {
continue
}
var vals []types.Datum
if opt.HasIndexesLayout() {
vals, err = fetchIndexRow(v.Meta(), rec, nil, opt.GetIndexLayout(v.Meta().ID))
Expand Down

0 comments on commit 448e302

Please sign in to comment.