-
Notifications
You must be signed in to change notification settings - Fork 6
/
vectodb.go
122 lines (106 loc) · 3.11 KB
/
vectodb.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
package vectodb
// https://golang.org/cmd/cgo/
// When the Go tool sees that one or more Go files use the special import "C", it will look for other non-Go files in the directory and compile them as part of the Go package.
// #cgo CXXFLAGS: -std=c++17 -I${SRCDIR}
// #cgo LDFLAGS: -L${SRCDIR}/faiss -lglog -lgflags -lfaiss -lopenblas -lgomp -lstdc++ -lstdc++fs -ljemalloc
// #include "vectodb.h"
// #include <stdlib.h>
import "C"
import (
	"fmt"
	"unsafe"

	log "github.com/sirupsen/logrus"
)
// VectoDB is the Go-side handle of a vector database implemented in native
// code and reached through cgo.
type VectoDB struct {
	// vdbC is the opaque native handle returned by VectodbNew; Destroy
	// resets it to nil.
	vdbC unsafe.Pointer

	// dim is the vector dimension given to NewVectoDB. It is used to validate
	// the length of the flat vector slices passed to AddWithIds and Search.
	dim int

	// workDir is the working directory given to NewVectoDB.
	workDir string

	// flatThreshold is not assigned anywhere in this file.
	// NOTE(review): presumably a tuning knob for the native index — confirm
	// whether it is still needed.
	flatThreshold int
}
// NewVectoDB creates the native database through VectodbNew, passing it
// workDir and dimIn, and wraps the returned handle in a VectoDB.
//
// The returned error is always nil; the result of VectodbNew is not inspected.
func NewVectoDB(workDir string, dimIn int) (vdb *VectoDB, err error) {
	log.Infof("creating VectoDB %v", workDir)

	// C.CString returns a C-allocated copy that Go does not manage, so it has
	// to be released explicitly once the native call is done.
	cWorkDir := C.CString(workDir)
	defer C.free(unsafe.Pointer(cWorkDir))

	handle := C.VectodbNew(cWorkDir, C.long(dimIn))
	return &VectoDB{
		vdbC:    handle,
		dim:     dimIn,
		workDir: workDir,
	}, nil
}
// Destroy hands the native handle to VectodbDelete and clears it on the Go
// side. The returned error is always nil.
func (vdb *VectoDB) Destroy() (err error) {
	log.Infof("destroying VectoDB %+v", vdb)

	// Detach the handle first so the struct never keeps a pointer that has
	// already been given to VectodbDelete.
	handle := vdb.vdbC
	vdb.vdbC = nil
	C.VectodbDelete(handle)
	return nil
}
/*
AddWithIds passes nb vectors and their ids to the native VectodbAddWithIds.

input parameters:
@param xb: nb vectors stored back to back; len(xb) must equal nb*dim
@param xids: nb vector ids. Layout of a 64-bit xid: the high 32 bits are the uid (user ID), the low 32 bits are the pid (picture ID)

return parameters:
@return err: non-nil when the database dimension is not positive or when len(xb) does not equal len(xids)*dim
*/
func (vdb *VectoDB) AddWithIds(xb []float32, xids []int64) (err error) {
	if vdb.dim <= 0 {
		return fmt.Errorf("invalid dim %v, want a positive value", vdb.dim)
	}
	nb := len(xids)
	if len(xb) != nb*vdb.dim {
		// Report bad input to the caller instead of terminating the process.
		return fmt.Errorf("invalid length of xb, want %v, have %v", nb*vdb.dim, len(xb))
	}
	if nb == 0 {
		// Nothing to add; taking &xb[0] / &xids[0] of empty slices would panic.
		return nil
	}
	C.VectodbAddWithIds(vdb.vdbC, C.long(nb), (*C.float)(&xb[0]), (*C.long)(&xids[0]))
	return nil
}
// SyncIndex invokes the native VectodbSyncIndex on this database's handle.
// The returned error is always nil.
func (vdb *VectoDB) SyncIndex() (err error) {
	C.VectodbSyncIndex(vdb.vdbC)
	return nil
}
// GetTotal returns the value reported by the native VectodbGetTotal,
// converted to int. The returned error is always nil.
func (vdb *VectoDB) GetTotal() (total int, err error) {
	return int(C.VectodbGetTotal(vdb.vdbC)), nil
}
// XidScore is one search hit: a vector id together with the score the native
// search reported for it.
type XidScore struct {
	// Xid is the id of the matched vector.
	Xid int64
	// Score is the score of the match.
	Score float32
}
/**
Search runs a k-nearest-neighbour query for every vector in xq through the
native VectodbSearch.

input parameters:
@param k: the k of kNN; must be positive
@param xq: nq query vectors stored back to back; len(xq) must be a multiple of dim
@param uids: nq serialized roaring bitmaps, one per query

return parameters:
@return res: res[i] holds the hits of query i in the order the native search wrote them, cut off at the first xid equal to -1; it stays nil when query i has no hit
@return err: non-nil when dim or k is not positive, or when the length of xq or uids is inconsistent
*/
func (vdb *VectoDB) Search(k int, xq []float32, uids []string) (res [][]XidScore, err error) {
	// Validate up front and report problems through err instead of
	// terminating the process. A zero dim would otherwise divide by zero.
	if vdb.dim <= 0 {
		return nil, fmt.Errorf("invalid dim %v, want a positive value", vdb.dim)
	}
	if k <= 0 {
		return nil, fmt.Errorf("invalid k %v, want a positive value", k)
	}
	nq := len(xq) / vdb.dim
	if len(xq) != nq*vdb.dim {
		return nil, fmt.Errorf("invalid length of xq, want %v, have %v", nq*vdb.dim, len(xq))
	}
	if len(uids) != nq {
		return nil, fmt.Errorf("invalid length of uids, want %v, have %v", nq, len(uids))
	}

	res = make([][]XidScore, nq)
	if nq == 0 {
		// No queries: taking &xq[0] below would panic on the empty slice.
		return res, nil
	}

	// Flat output buffers: the k scores/xids of query 0, then those of
	// query 1, and so on.
	scores := make([]float32, nq*k)
	xids := make([]int64, nq*k)

	// NOTE(review): only the length of uids is checked; its contents are never
	// forwarded. A single zero-valued int64 is passed in the filter position —
	// confirm against the native VectodbSearch contract.
	var uidsFilter int64
	C.VectodbSearch(vdb.vdbC, C.long(nq), C.long(k), (*C.float)(&xq[0]), (*C.long)(&uidsFilter), (*C.float)(&scores[0]), (*C.long)(&xids[0]))

	for i := range res {
		base := i * k
		for j := 0; j < k; j++ {
			xid := xids[base+j]
			if xid == -1 {
				// -1 marks the end of the valid hits for this query.
				break
			}
			res[i] = append(res[i], XidScore{Xid: xid, Score: scores[base+j]})
		}
	}
	return res, nil
}
/**
 * Static methods.
 */

// VectodbClearWorkDir passes workDir to the native VectodbClearDir.
// The returned error is always nil.
func VectodbClearWorkDir(workDir string) (err error) {
	log.Infof("clearing VectoDB %v", workDir)

	// C.CString returns a C-allocated copy that must be freed explicitly.
	cWorkDir := C.CString(workDir)
	defer C.free(unsafe.Pointer(cWorkDir))

	C.VectodbClearDir(cWorkDir)
	return nil
}