Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ddl: Fix vector index for high dimensional vectors (#58717) #58835

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pkg/ddl/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,11 @@ func getIndexColumnLength(col *model.ColumnInfo, colLen int) (int, error) {
}

switch col.GetType() {
case mysql.TypeTiDBVectorFloat32:
// Vector Index does not actually create KV index, so it has length of 0.
// However, 0 may cause some issues in other calculations, so we use 1 here.
// 1 is also minimal enough anyway.
return 1, nil
case mysql.TypeBit:
return (length + 7) >> 3, nil
case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeBlob, mysql.TypeLongBlob:
Expand Down Expand Up @@ -2922,6 +2927,9 @@ func newCleanUpIndexWorker(id int, t table.PhysicalTable, decodeColMap map[int64
indexes := make([]table.Index, 0, len(t.Indices()))
rowDecoder := decoder.NewRowDecoder(t, t.WritableCols(), decodeColMap)
for _, index := range t.Indices() {
if index.Meta().IsTiFlashLocalIndex() {
continue
}
if index.Meta().Global {
indexes = append(indexes, index)
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/expression/integration_test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ go_test(
"main_test.go",
],
flaky = True,
shard_count = 48,
shard_count = 49,
deps = [
"//pkg/config",
"//pkg/domain",
"//pkg/domain/infosync",
"//pkg/errno",
"//pkg/expression",
"//pkg/kv",
Expand Down
96 changes: 96 additions & 0 deletions pkg/expression/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/config"
"github.com/pingcap/tidb/pkg/domain"
"github.com/pingcap/tidb/pkg/domain/infosync"
"github.com/pingcap/tidb/pkg/errno"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/kv"
Expand Down Expand Up @@ -61,6 +62,101 @@ import (
"github.com/tikv/client-go/v2/oracle"
)

// TestVectorLong runs the same DML and vector-distance workload against a
// vector(16383) column in three setups: a table without a vector index, a table
// whose vector index is declared in CREATE TABLE, and a table whose vector
// index is added afterwards with ALTER TABLE.
func TestVectorLong(t *testing.T) {
	// Dimension passed to every genVec call; it matches the vector(16383)
	// column declared in the DDL statements below.
	const dim = 16383
	const checkProcessFailpoint = "github.com/pingcap/tidb/pkg/ddl/MockCheckVectorIndexProcess"

	store := testkit.CreateMockStoreWithSchemaLease(t, 1*time.Second, mockstore.WithMockTiFlash(2))

	tk := testkit.NewTestKit(t, store)

	tiflash := infosync.NewMockTiFlash()
	infosync.SetMockTiFlash(tiflash)
	defer func() {
		tiflash.Lock()
		tiflash.StatusServer.Close()
		tiflash.Unlock()
	}()

	// genVec renders a d-dimensional vector literal whose first element is
	// startValue and whose following elements each grow by 100.
	genVec := func(d int, startValue int) string {
		var vb strings.Builder
		vb.WriteString("[")
		value := startValue
		for i := 0; i < d; i++ {
			if i > 0 {
				vb.WriteString(",")
			}
			vb.WriteString(strconv.FormatInt(int64(value), 10))
			value += 100
		}
		vb.WriteString("]")
		return vb.String()
	}

	// Fail fast if the failpoint cannot be enabled instead of silently running
	// the test without it.
	require.NoError(t, failpoint.Enable(checkProcessFailpoint, `return(1)`))
	defer func() {
		require.NoError(t, failpoint.Disable(checkProcessFailpoint))
	}()

	// runWorkload exercises insert, delete, update, full scans and
	// ORDER BY vec_l2_distance ... LIMIT queries on table t1.
	runWorkload := func() {
		tk.MustExec(fmt.Sprintf(`insert into t1 values (1, '%s')`, genVec(dim, 100)))
		tk.MustQuery(`select * from t1 order by id`).Check(testkit.Rows("1 " + genVec(dim, 100)))
		// No row compares greater than the 200-based vector, so nothing is deleted.
		tk.MustExec(fmt.Sprintf(`delete from t1 where vec > '%s'`, genVec(dim, 200)))
		tk.MustQuery(`select * from t1 order by id`).Check(testkit.Rows("1 " + genVec(dim, 100)))
		// The 100-based row compares greater than the 50-based vector and is removed.
		tk.MustExec(fmt.Sprintf(`delete from t1 where vec > '%s'`, genVec(dim, 50)))
		tk.MustQuery(`select * from t1 order by id`).Check(testkit.Rows())
		tk.MustExec(fmt.Sprintf(`insert into t1 values (1, '%s')`, genVec(dim, 100)))
		tk.MustExec(fmt.Sprintf(`insert into t1 values (2, '%s')`, genVec(dim, 200)))
		tk.MustExec(fmt.Sprintf(`insert into t1 values (3, '%s')`, genVec(dim, 300)))
		tk.MustQuery(fmt.Sprintf(`select id from t1 order by vec_l2_distance(vec, '%s') limit 2`, genVec(dim, 180))).Check(testkit.Rows(
			"2",
			"1",
		))
		tk.MustExec(fmt.Sprintf(`update t1 set vec = '%s' where id = 1`, genVec(dim, 500)))
		tk.MustQuery(`select * from t1 order by id`).Check(testkit.Rows(
			"1 "+genVec(dim, 500),
			"2 "+genVec(dim, 200),
			"3 "+genVec(dim, 300),
		))
		tk.MustQuery(fmt.Sprintf(`select id from t1 order by vec_l2_distance(vec, '%s') limit 2`, genVec(dim, 180))).Check(testkit.Rows(
			"2",
			"3",
		))
	}

	// Setup 1: no vector index.
	tk.MustExec("use test")
	tk.MustExec(`
create table t1 (
id int primary key,
vec vector(16383)
)
`)
	runWorkload()
	tk.MustExec("drop table t1")

	// Setup 2: vector index declared together with the table.
	tk.MustExec(`
create table t1 (
id int primary key,
vec vector(16383),
VECTOR INDEX ((vec_cosine_distance(vec)))
)
`)
	runWorkload()
	tk.MustExec("drop table if exists t1")

	// Setup 3: vector index added to an existing table.
	tk.MustExec(`
create table t1 (
id int primary key,
vec vector(16383)
)
`)
	tk.MustExec(`alter table t1 set tiflash replica 1`)
	tbl, err := domain.GetDomain(tk.Session()).InfoSchema().TableByName(context.Background(), pmodel.NewCIStr("test"), pmodel.NewCIStr("t1"))
	require.NoError(t, err)
	// Mark the replica as available directly in the table meta.
	// NOTE(review): presumably ADD VECTOR INDEX requires an available TiFlash
	// replica — confirm against the DDL checks.
	tbl.Meta().TiFlashReplica = &model.TiFlashReplicaInfo{
		Count:     1,
		Available: true,
	}
	tk.MustExec(`alter table t1 add VECTOR INDEX ((vec_cosine_distance(vec)))`)
	runWorkload()
	tk.MustExec("drop table if exists t1")
}

func TestVectorDefaultValue(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
Expand Down
6 changes: 6 additions & 0 deletions pkg/meta/model/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ func (index *IndexInfo) IsPublic() bool {
return index.State == StatePublic
}

// IsTiFlashLocalIndex reports whether the index is a TiFlash local index,
// which is the case exactly when the index carries vector info.
// For a TiFlash local index, no actual index data needs to be written to the KV layer.
func (index *IndexInfo) IsTiFlashLocalIndex() bool {
	hasVectorInfo := index.VectorInfo != nil
	return hasVectorInfo
}

// FindIndexByColumns finds the IndexInfo in indices which covers the specified columns.
func FindIndexByColumns(tbInfo *TableInfo, indices []*IndexInfo, cols ...model.CIStr) *IndexInfo {
for _, index := range indices {
Expand Down
15 changes: 15 additions & 0 deletions pkg/table/tables/tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ func GetWritableIndexByName(idxName string, t table.Table) table.Index {
if !IsIndexWritable(idx) {
continue
}
if idx.Meta().IsTiFlashLocalIndex() {
continue
}
if idxName == idx.Meta().Name.L {
return idx
}
Expand Down Expand Up @@ -561,6 +564,9 @@ func (t *TableCommon) rebuildUpdateRecordIndices(
if t.meta.IsCommonHandle && idx.Meta().Primary {
continue
}
if idx.Meta().IsTiFlashLocalIndex() {
continue
}
for _, ic := range idx.Meta().Columns {
if !touched[ic.Offset] {
continue
Expand All @@ -580,6 +586,9 @@ func (t *TableCommon) rebuildUpdateRecordIndices(
if !IsIndexWritable(idx) {
continue
}
if idx.Meta().IsTiFlashLocalIndex() {
continue
}
if t.meta.IsCommonHandle && idx.Meta().Primary {
continue
}
Expand Down Expand Up @@ -948,6 +957,9 @@ func (t *TableCommon) addIndices(sctx table.MutateContext, recordID kv.Handle, r
if !IsIndexWritable(v) {
continue
}
if v.Meta().IsTiFlashLocalIndex() {
continue
}
if t.meta.IsCommonHandle && v.Meta().Primary {
continue
}
Expand Down Expand Up @@ -1232,6 +1244,9 @@ func (t *TableCommon) removeRowIndices(ctx table.MutateContext, txn kv.Transacti
if v.Meta().Primary && (t.Meta().IsCommonHandle || t.Meta().PKIsHandle) {
continue
}
if v.Meta().IsTiFlashLocalIndex() {
continue
}
var vals []types.Datum
if opt.HasIndexesLayout() {
vals, err = fetchIndexRow(v.Meta(), rec, nil, opt.GetIndexLayout(v.Meta().ID))
Expand Down