From 4fbed722518b21fbf33a894f2d24f3b89a55e847 Mon Sep 17 00:00:00 2001 From: "gary.y" Date: Tue, 9 Jul 2024 20:19:58 +0800 Subject: [PATCH] fix(kb): issue of chunking --- cmd/main/main.go | 1 + config/config.yaml | 2 +- .../000007_create_text_chunk_table.up.sql | 24 ++-- .../000008_create_embedding_table.up.sql | 3 +- pkg/handler/knowledgebasefiles.go | 31 +++-- pkg/milvus/milvus.go | 13 ++- pkg/minio/knowledgebase.go | 22 +++- pkg/mock/repository_i_mock.gen.go | 106 +++++++++++------- pkg/repository/chunk.go | 50 +++++++-- pkg/repository/embedding.go | 3 + pkg/repository/knowledgebasefile.go | 25 ++++- pkg/worker/worker.go | 54 +++++---- 12 files changed, 227 insertions(+), 107 deletions(-) diff --git a/cmd/main/main.go b/cmd/main/main.go index fdf389a..3dc7a26 100644 --- a/cmd/main/main.go +++ b/cmd/main/main.go @@ -43,6 +43,7 @@ import ( "github.com/instill-ai/artifact-backend/pkg/repository" servicePkg "github.com/instill-ai/artifact-backend/pkg/service" "github.com/instill-ai/artifact-backend/pkg/usage" + // "github.com/instill-ai/artifact-backend/pkg/worker" grpcclient "github.com/instill-ai/artifact-backend/pkg/client/grpc" httpclient "github.com/instill-ai/artifact-backend/pkg/client/http" diff --git a/config/config.yaml b/config/config.yaml index 0dbf600..9567a4a 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -22,7 +22,7 @@ database: host: pg-sql port: 5432 name: artifact - version: 5 + version: 8 timezone: Etc/UTC pool: idleconnections: 5 diff --git a/pkg/db/migration/000007_create_text_chunk_table.up.sql b/pkg/db/migration/000007_create_text_chunk_table.up.sql index fae20cb..c54273a 100644 --- a/pkg/db/migration/000007_create_text_chunk_table.up.sql +++ b/pkg/db/migration/000007_create_text_chunk_table.up.sql @@ -1,28 +1,34 @@ BEGIN; + CREATE TABLE text_chunk ( uid UUID PRIMARY KEY DEFAULT gen_random_uuid(), source_uid UUID NOT NULL, source_table VARCHAR(255) NOT NULL, - start INT NOT NULL, - end INT NOT NULL, + start_pos INT NOT NULL, + 
end_pos INT NOT NULL, content_dest VARCHAR(255) NOT NULL, tokens INT NOT NULL, retrievable BOOLEAN NOT NULL DEFAULT true, - order INT NOT NULL create_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + in_order INT NOT NULL, + create_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ); + -- Create indexes -CREATE UNIQUE INDEX idx_unique_source_table_uid_start_end ON text_chunk (source_table, source_uid, start,end -); +CREATE UNIQUE INDEX idx_unique_source_table_uid_start_end ON text_chunk (source_table, source_uid, start_pos, end_pos); + -- Comments for the table and columns COMMENT ON TABLE text_chunk IS 'Table to store text chunks with metadata'; COMMENT ON COLUMN text_chunk.uid IS 'Unique identifier for the text chunk'; COMMENT ON COLUMN text_chunk.source_uid IS 'Source unique identifier, references source table''s uid field'; COMMENT ON COLUMN text_chunk.source_table IS 'Name of the source table'; -COMMENT ON COLUMN text_chunk.start IS 'Start position of the text chunk'; -COMMENT ON COLUMN text_chunk. 
-end IS 'End position of the text chunk'; +COMMENT ON COLUMN text_chunk.start_pos IS 'Start position of the text chunk'; +COMMENT ON COLUMN text_chunk.end_pos IS 'End position of the text chunk'; COMMENT ON COLUMN text_chunk.content_dest IS 'dest of the text chunk''s content in file store'; +COMMENT ON COLUMN text_chunk.tokens IS 'Number of tokens in the text chunk'; +COMMENT ON COLUMN text_chunk.retrievable IS 'Flag indicating if the chunk is retrievable'; +COMMENT ON COLUMN text_chunk.in_order IS 'Order of the text chunk'; COMMENT ON COLUMN text_chunk.create_time IS 'Timestamp when the record was created'; COMMENT ON COLUMN text_chunk.update_time IS 'Timestamp when the record was last updated'; + COMMIT; diff --git a/pkg/db/migration/000008_create_embedding_table.up.sql b/pkg/db/migration/000008_create_embedding_table.up.sql index 3336dac..002d3c4 100644 --- a/pkg/db/migration/000008_create_embedding_table.up.sql +++ b/pkg/db/migration/000008_create_embedding_table.up.sql @@ -5,6 +5,7 @@ CREATE TABLE embedding ( source_uid UUID NOT NULL, source_table VARCHAR(255) NOT NULL, vector JSONB NOT NULL, + collection VARCHAR(255) NOT NULL, create_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ); @@ -15,7 +16,7 @@ COMMENT ON TABLE embedding IS 'Table to store embeddings with metadata'; COMMENT ON COLUMN embedding.uid IS 'Unique identifier for the embedding'; COMMENT ON COLUMN embedding.source_uid IS 'Source unique identifier, references source table''s uid field'; COMMENT ON COLUMN embedding.source_table IS 'Name of the source table'; -COMMENT ON COLUMN embedding.embedding_dest IS 'Destination of the embedding''s content in vector store'; +COMMENT ON COLUMN embedding.collection IS 'Destination of the embedding''s content in vector store'; COMMENT ON COLUMN embedding.create_time IS 'Timestamp when the record was created'; COMMENT ON COLUMN embedding.update_time IS 'Timestamp when the record was last updated'; 
COMMIT; diff --git a/pkg/handler/knowledgebasefiles.go b/pkg/handler/knowledgebasefiles.go index 3ed7842..800ea6e 100644 --- a/pkg/handler/knowledgebasefiles.go +++ b/pkg/handler/knowledgebasefiles.go @@ -27,6 +27,10 @@ func (ph *PublicHandler) UploadKnowledgeBaseFile(ctx context.Context, req *artif return nil, err } + if strings.Contains(req.File.Name, "/") { + return nil, fmt.Errorf("file name cannot contain '/'. err: %w", customerror.ErrInvalidArgument) + } + // TODO: ACL - check if the creator can upload file to this knowledge base. ACL. // ..... @@ -43,21 +47,12 @@ func (ph *PublicHandler) UploadKnowledgeBaseFile(ctx context.Context, req *artif // upload file to minio var kb *repository.KnowledgeBase - var filePathName string { kb, err = ph.service.Repository.GetKnowledgeBaseByOwnerAndID(ctx, ownerUID, req.KbId) if err != nil { return nil, fmt.Errorf("failed to get knowledge base by owner and id. err: %w", err) } - // check if the name has "/" which may cause folder creation in minio - if strings.Contains(req.File.Name, "/") { - return nil, fmt.Errorf("file name cannot contain '/'. 
err: %w", customerror.ErrInvalidArgument) - } - filePathName = kb.UID.String() + "/" + req.File.Name - err = ph.service.MinIO.UploadBase64File(ctx, filePathName, req.File.Content, fileTypeConvertToMime(req.File.Type)) - if err != nil { - return nil, err - } + } // create metadata in db @@ -74,23 +69,25 @@ func (ph *PublicHandler) UploadKnowledgeBaseFile(ctx context.Context, req *artif log.Error("failed to parse owner uid", zap.Error(err)) return nil, err } + destination := ph.service.MinIO.GetUploadedFilePathInKnowledgeBase(kb.UID.String(), req.File.Name) kbFile := repository.KnowledgeBaseFile{ Name: req.File.Name, Type: artifactpb.FileType_name[int32(req.File.Type)], Owner: ownerUIDUuid, CreatorUID: creatorUID, KnowledgeBaseUID: kb.UID, - Destination: filePathName, + Destination: destination, ProcessStatus: artifactpb.FileProcessStatus_name[int32(artifactpb.FileProcessStatus_FILE_PROCESS_STATUS_NOTSTARTED)], } - res, err = ph.service.Repository.CreateKnowledgeBaseFile(ctx, kbFile) - if err != nil { - err := ph.service.MinIO.DeleteFile(ctx, filePathName) + res, err = ph.service.Repository.CreateKnowledgeBaseFile(ctx, kbFile, func(FileUID string) error { + // upload file to minio + err = ph.service.MinIO.UploadBase64File(ctx, destination, req.File.Content, fileTypeConvertToMime(req.File.Type)) if err != nil { - log.Error("failed to delete file in minio", zap.Error(err)) + return err } - return nil, err - } + + return nil + }) } return &artifactpb.UploadKnowledgeBaseFileResponse{ diff --git a/pkg/milvus/milvus.go b/pkg/milvus/milvus.go index 10c4bf8..7b93326 100644 --- a/pkg/milvus/milvus.go +++ b/pkg/milvus/milvus.go @@ -21,6 +21,8 @@ type MilvusClientI interface { ListEmbeddings(ctx context.Context, collectionName string) ([]Embedding, error) SearchEmbeddings(ctx context.Context, collectionName string, vectors [][]float32, topK int) ([][]Embedding, error) DeleteEmbedding(ctx context.Context, collectionName string, embeddingUID []string) error + // 
GetKnowledgeBaseCollectionName returns the collection name for a knowledge base + GetKnowledgeBaseCollectionName(kbUID string) string Close() } @@ -81,7 +83,7 @@ func (m *MilvusClient) GetHealth(ctx context.Context) (bool, error) { // CreateKnowledgeBaseCollection func (m *MilvusClient) CreateKnowledgeBaseCollection(ctx context.Context, kbUID string) error { logger, _ := logger.GetZapLogger(ctx) - collectionName := getKnowledgeBaseCollectionName(kbUID) + collectionName := m.GetKnowledgeBaseCollectionName(kbUID) // 1. Check if the collection already exists has, err := m.c.HasCollection(ctx, collectionName) @@ -129,7 +131,7 @@ func (m *MilvusClient) CreateKnowledgeBaseCollection(ctx context.Context, kbUID // InsertVectorsToKnowledgeBaseCollection func (m *MilvusClient) InsertVectorsToKnowledgeBaseCollection(ctx context.Context, kbUID string, embeddings []Embedding) error { - collectionName := getKnowledgeBaseCollectionName(kbUID) + collectionName := m.GetKnowledgeBaseCollectionName(kbUID) // Check if the collection exists has, err := m.c.HasCollection(ctx, collectionName) @@ -380,7 +382,10 @@ func (m *MilvusClient) Close() { const kbCollectionPrefix = "kb_" -// getKnowledgeBaseCollectionName -func getKnowledgeBaseCollectionName(kbUID string) string { +// GetKnowledgeBaseCollectionName returns the collection name for a knowledge base +func (m *MilvusClient) GetKnowledgeBaseCollectionName(kbUID string) string { + // collection name can only contain numbers, letters and underscores: invalid parameter + // turn kbUID(uuid) into a valid collection name + kbUID = strings.ReplaceAll(kbUID, "-", "_") return kbCollectionPrefix + kbUID } diff --git a/pkg/minio/knowledgebase.go b/pkg/minio/knowledgebase.go index 1413777..3d8b9d9 100644 --- a/pkg/minio/knowledgebase.go +++ b/pkg/minio/knowledgebase.go @@ -13,11 +13,17 @@ type KnowledgeBaseI interface { SaveConvertedFile(ctx context.Context, kbUID, convertedFileUID, fileExt string, content []byte) error // SaveChunks saves batch 
of chunks(text files) to MinIO. SaveChunks(ctx context.Context, kbUID string, chunks map[ChunkUIDType]ChunkContentType) error + // GetUploadedFilePathInKnowledgeBase returns the path of the uploaded file in MinIO. + GetUploadedFilePathInKnowledgeBase(kbUID, dest string) string + // GetConvertedFilePathInKnowledgeBase returns the path of the converted file in MinIO. + GetConvertedFilePathInKnowledgeBase(kbUID, ConvertedFileUID, fileExt string) string + // GetChunkPathInKnowledgeBase returns the path of the chunk in MinIO. + GetChunkPathInKnowledgeBase(kbUID, chunkUID string) string } // SaveConvertedFile saves a converted file to MinIO with the appropriate MIME type. func (m *Minio) SaveConvertedFile(ctx context.Context, kbUID, convertedFileUID, fileExt string, content []byte) error { - filePathName := GetConvertedFilePathInKnowledgeBase(kbUID, convertedFileUID, fileExt) + filePathName := m.GetConvertedFilePathInKnowledgeBase(kbUID, convertedFileUID, fileExt) mimeType := "application/octet-stream" if fileExt == "md" { mimeType = "text/markdown" @@ -41,7 +47,7 @@ func (m *Minio) SaveChunks(ctx context.Context, kbUID string, chunks map[ChunkUI wg.Add(1) go func(chunkUID ChunkUIDType, chunkContent ChunkContentType) { defer wg.Done() - filePathName := GetChunkPathInKnowledgeBase(kbUID, string(chunkUID)) + filePathName := m.GetChunkPathInKnowledgeBase(kbUID, string(chunkUID)) err := m.UploadBase64File(ctx, filePathName, base64.StdEncoding.EncodeToString(chunkContent), "text/plain") if err != nil { @@ -61,3 +67,15 @@ func (m *Minio) SaveChunks(ctx context.Context, kbUID string, chunks map[ChunkUI } return nil } + +func (m *Minio) GetUploadedFilePathInKnowledgeBase(kbUID, dest string) string { + return kbUID + "/uploaded-file/" + dest +} + +func (m *Minio) GetConvertedFilePathInKnowledgeBase(kbUID, ConvertedFileUID, fileExt string) string { + return kbUID + "/converted-file/" + ConvertedFileUID + "." 
+ fileExt +} + +func (m *Minio) GetChunkPathInKnowledgeBase(kbUID, chunkUID string) string { + return kbUID + "/chunk/" + chunkUID + ".txt" +} diff --git a/pkg/mock/repository_i_mock.gen.go b/pkg/mock/repository_i_mock.gen.go index 158fc89..7608a4b 100644 --- a/pkg/mock/repository_i_mock.gen.go +++ b/pkg/mock/repository_i_mock.gen.go @@ -38,14 +38,14 @@ type RepositoryIMock struct { beforeCreateKnowledgeBaseCounter uint64 CreateKnowledgeBaseMock mRepositoryIMockCreateKnowledgeBase - funcCreateKnowledgeBaseFile func(ctx context.Context, kb mm_repository.KnowledgeBaseFile) (kp1 *mm_repository.KnowledgeBaseFile, err error) - inspectFuncCreateKnowledgeBaseFile func(ctx context.Context, kb mm_repository.KnowledgeBaseFile) + funcCreateKnowledgeBaseFile func(ctx context.Context, kb mm_repository.KnowledgeBaseFile, externalServiceCall func(FileUID string) error) (kp1 *mm_repository.KnowledgeBaseFile, err error) + inspectFuncCreateKnowledgeBaseFile func(ctx context.Context, kb mm_repository.KnowledgeBaseFile, externalServiceCall func(FileUID string) error) afterCreateKnowledgeBaseFileCounter uint64 beforeCreateKnowledgeBaseFileCounter uint64 CreateKnowledgeBaseFileMock mRepositoryIMockCreateKnowledgeBaseFile - funcDeleteAndCreateChunks func(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks []mm_repository.TextChunk, externalServiceCall func(chunkUIDs []string) error) (ta1 []mm_repository.TextChunk, err error) - inspectFuncDeleteAndCreateChunks func(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks []mm_repository.TextChunk, externalServiceCall func(chunkUIDs []string) error) + funcDeleteAndCreateChunks func(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks []*mm_repository.TextChunk, externalServiceCall func(chunkUIDs []string) (map[string]any, error)) (tpa1 []*mm_repository.TextChunk, err error) + inspectFuncDeleteAndCreateChunks func(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks 
[]*mm_repository.TextChunk, externalServiceCall func(chunkUIDs []string) (map[string]any, error)) afterDeleteAndCreateChunksCounter uint64 beforeDeleteAndCreateChunksCounter uint64 DeleteAndCreateChunksMock mRepositoryIMockDeleteAndCreateChunks @@ -1117,14 +1117,16 @@ type RepositoryIMockCreateKnowledgeBaseFileExpectation struct { // RepositoryIMockCreateKnowledgeBaseFileParams contains parameters of the RepositoryI.CreateKnowledgeBaseFile type RepositoryIMockCreateKnowledgeBaseFileParams struct { - ctx context.Context - kb mm_repository.KnowledgeBaseFile + ctx context.Context + kb mm_repository.KnowledgeBaseFile + externalServiceCall func(FileUID string) error } // RepositoryIMockCreateKnowledgeBaseFileParamPtrs contains pointers to parameters of the RepositoryI.CreateKnowledgeBaseFile type RepositoryIMockCreateKnowledgeBaseFileParamPtrs struct { - ctx *context.Context - kb *mm_repository.KnowledgeBaseFile + ctx *context.Context + kb *mm_repository.KnowledgeBaseFile + externalServiceCall *func(FileUID string) error } // RepositoryIMockCreateKnowledgeBaseFileResults contains results of the RepositoryI.CreateKnowledgeBaseFile @@ -1134,7 +1136,7 @@ type RepositoryIMockCreateKnowledgeBaseFileResults struct { } // Expect sets up expected params for RepositoryI.CreateKnowledgeBaseFile -func (mmCreateKnowledgeBaseFile *mRepositoryIMockCreateKnowledgeBaseFile) Expect(ctx context.Context, kb mm_repository.KnowledgeBaseFile) *mRepositoryIMockCreateKnowledgeBaseFile { +func (mmCreateKnowledgeBaseFile *mRepositoryIMockCreateKnowledgeBaseFile) Expect(ctx context.Context, kb mm_repository.KnowledgeBaseFile, externalServiceCall func(FileUID string) error) *mRepositoryIMockCreateKnowledgeBaseFile { if mmCreateKnowledgeBaseFile.mock.funcCreateKnowledgeBaseFile != nil { mmCreateKnowledgeBaseFile.mock.t.Fatalf("RepositoryIMock.CreateKnowledgeBaseFile mock is already set by Set") } @@ -1147,7 +1149,7 @@ func (mmCreateKnowledgeBaseFile *mRepositoryIMockCreateKnowledgeBaseFile) Expect 
mmCreateKnowledgeBaseFile.mock.t.Fatalf("RepositoryIMock.CreateKnowledgeBaseFile mock is already set by ExpectParams functions") } - mmCreateKnowledgeBaseFile.defaultExpectation.params = &RepositoryIMockCreateKnowledgeBaseFileParams{ctx, kb} + mmCreateKnowledgeBaseFile.defaultExpectation.params = &RepositoryIMockCreateKnowledgeBaseFileParams{ctx, kb, externalServiceCall} for _, e := range mmCreateKnowledgeBaseFile.expectations { if minimock.Equal(e.params, mmCreateKnowledgeBaseFile.defaultExpectation.params) { mmCreateKnowledgeBaseFile.mock.t.Fatalf("Expectation set by When has same params: %#v", *mmCreateKnowledgeBaseFile.defaultExpectation.params) @@ -1201,8 +1203,30 @@ func (mmCreateKnowledgeBaseFile *mRepositoryIMockCreateKnowledgeBaseFile) Expect return mmCreateKnowledgeBaseFile } +// ExpectExternalServiceCallParam3 sets up expected param externalServiceCall for RepositoryI.CreateKnowledgeBaseFile +func (mmCreateKnowledgeBaseFile *mRepositoryIMockCreateKnowledgeBaseFile) ExpectExternalServiceCallParam3(externalServiceCall func(FileUID string) error) *mRepositoryIMockCreateKnowledgeBaseFile { + if mmCreateKnowledgeBaseFile.mock.funcCreateKnowledgeBaseFile != nil { + mmCreateKnowledgeBaseFile.mock.t.Fatalf("RepositoryIMock.CreateKnowledgeBaseFile mock is already set by Set") + } + + if mmCreateKnowledgeBaseFile.defaultExpectation == nil { + mmCreateKnowledgeBaseFile.defaultExpectation = &RepositoryIMockCreateKnowledgeBaseFileExpectation{} + } + + if mmCreateKnowledgeBaseFile.defaultExpectation.params != nil { + mmCreateKnowledgeBaseFile.mock.t.Fatalf("RepositoryIMock.CreateKnowledgeBaseFile mock is already set by Expect") + } + + if mmCreateKnowledgeBaseFile.defaultExpectation.paramPtrs == nil { + mmCreateKnowledgeBaseFile.defaultExpectation.paramPtrs = &RepositoryIMockCreateKnowledgeBaseFileParamPtrs{} + } + mmCreateKnowledgeBaseFile.defaultExpectation.paramPtrs.externalServiceCall = &externalServiceCall + + return mmCreateKnowledgeBaseFile +} + // Inspect 
accepts an inspector function that has same arguments as the RepositoryI.CreateKnowledgeBaseFile -func (mmCreateKnowledgeBaseFile *mRepositoryIMockCreateKnowledgeBaseFile) Inspect(f func(ctx context.Context, kb mm_repository.KnowledgeBaseFile)) *mRepositoryIMockCreateKnowledgeBaseFile { +func (mmCreateKnowledgeBaseFile *mRepositoryIMockCreateKnowledgeBaseFile) Inspect(f func(ctx context.Context, kb mm_repository.KnowledgeBaseFile, externalServiceCall func(FileUID string) error)) *mRepositoryIMockCreateKnowledgeBaseFile { if mmCreateKnowledgeBaseFile.mock.inspectFuncCreateKnowledgeBaseFile != nil { mmCreateKnowledgeBaseFile.mock.t.Fatalf("Inspect function is already set for RepositoryIMock.CreateKnowledgeBaseFile") } @@ -1226,7 +1250,7 @@ func (mmCreateKnowledgeBaseFile *mRepositoryIMockCreateKnowledgeBaseFile) Return } // Set uses given function f to mock the RepositoryI.CreateKnowledgeBaseFile method -func (mmCreateKnowledgeBaseFile *mRepositoryIMockCreateKnowledgeBaseFile) Set(f func(ctx context.Context, kb mm_repository.KnowledgeBaseFile) (kp1 *mm_repository.KnowledgeBaseFile, err error)) *RepositoryIMock { +func (mmCreateKnowledgeBaseFile *mRepositoryIMockCreateKnowledgeBaseFile) Set(f func(ctx context.Context, kb mm_repository.KnowledgeBaseFile, externalServiceCall func(FileUID string) error) (kp1 *mm_repository.KnowledgeBaseFile, err error)) *RepositoryIMock { if mmCreateKnowledgeBaseFile.defaultExpectation != nil { mmCreateKnowledgeBaseFile.mock.t.Fatalf("Default expectation is already set for the RepositoryI.CreateKnowledgeBaseFile method") } @@ -1241,14 +1265,14 @@ func (mmCreateKnowledgeBaseFile *mRepositoryIMockCreateKnowledgeBaseFile) Set(f // When sets expectation for the RepositoryI.CreateKnowledgeBaseFile which will trigger the result defined by the following // Then helper -func (mmCreateKnowledgeBaseFile *mRepositoryIMockCreateKnowledgeBaseFile) When(ctx context.Context, kb mm_repository.KnowledgeBaseFile) 
*RepositoryIMockCreateKnowledgeBaseFileExpectation { +func (mmCreateKnowledgeBaseFile *mRepositoryIMockCreateKnowledgeBaseFile) When(ctx context.Context, kb mm_repository.KnowledgeBaseFile, externalServiceCall func(FileUID string) error) *RepositoryIMockCreateKnowledgeBaseFileExpectation { if mmCreateKnowledgeBaseFile.mock.funcCreateKnowledgeBaseFile != nil { mmCreateKnowledgeBaseFile.mock.t.Fatalf("RepositoryIMock.CreateKnowledgeBaseFile mock is already set by Set") } expectation := &RepositoryIMockCreateKnowledgeBaseFileExpectation{ mock: mmCreateKnowledgeBaseFile.mock, - params: &RepositoryIMockCreateKnowledgeBaseFileParams{ctx, kb}, + params: &RepositoryIMockCreateKnowledgeBaseFileParams{ctx, kb, externalServiceCall}, } mmCreateKnowledgeBaseFile.expectations = append(mmCreateKnowledgeBaseFile.expectations, expectation) return expectation @@ -1281,15 +1305,15 @@ func (mmCreateKnowledgeBaseFile *mRepositoryIMockCreateKnowledgeBaseFile) invoca } // CreateKnowledgeBaseFile implements repository.RepositoryI -func (mmCreateKnowledgeBaseFile *RepositoryIMock) CreateKnowledgeBaseFile(ctx context.Context, kb mm_repository.KnowledgeBaseFile) (kp1 *mm_repository.KnowledgeBaseFile, err error) { +func (mmCreateKnowledgeBaseFile *RepositoryIMock) CreateKnowledgeBaseFile(ctx context.Context, kb mm_repository.KnowledgeBaseFile, externalServiceCall func(FileUID string) error) (kp1 *mm_repository.KnowledgeBaseFile, err error) { mm_atomic.AddUint64(&mmCreateKnowledgeBaseFile.beforeCreateKnowledgeBaseFileCounter, 1) defer mm_atomic.AddUint64(&mmCreateKnowledgeBaseFile.afterCreateKnowledgeBaseFileCounter, 1) if mmCreateKnowledgeBaseFile.inspectFuncCreateKnowledgeBaseFile != nil { - mmCreateKnowledgeBaseFile.inspectFuncCreateKnowledgeBaseFile(ctx, kb) + mmCreateKnowledgeBaseFile.inspectFuncCreateKnowledgeBaseFile(ctx, kb, externalServiceCall) } - mm_params := RepositoryIMockCreateKnowledgeBaseFileParams{ctx, kb} + mm_params := RepositoryIMockCreateKnowledgeBaseFileParams{ctx, kb, 
externalServiceCall} // Record call args mmCreateKnowledgeBaseFile.CreateKnowledgeBaseFileMock.mutex.Lock() @@ -1308,7 +1332,7 @@ func (mmCreateKnowledgeBaseFile *RepositoryIMock) CreateKnowledgeBaseFile(ctx co mm_want := mmCreateKnowledgeBaseFile.CreateKnowledgeBaseFileMock.defaultExpectation.params mm_want_ptrs := mmCreateKnowledgeBaseFile.CreateKnowledgeBaseFileMock.defaultExpectation.paramPtrs - mm_got := RepositoryIMockCreateKnowledgeBaseFileParams{ctx, kb} + mm_got := RepositoryIMockCreateKnowledgeBaseFileParams{ctx, kb, externalServiceCall} if mm_want_ptrs != nil { @@ -1320,6 +1344,10 @@ func (mmCreateKnowledgeBaseFile *RepositoryIMock) CreateKnowledgeBaseFile(ctx co mmCreateKnowledgeBaseFile.t.Errorf("RepositoryIMock.CreateKnowledgeBaseFile got unexpected parameter kb, want: %#v, got: %#v%s\n", *mm_want_ptrs.kb, mm_got.kb, minimock.Diff(*mm_want_ptrs.kb, mm_got.kb)) } + if mm_want_ptrs.externalServiceCall != nil && !minimock.Equal(*mm_want_ptrs.externalServiceCall, mm_got.externalServiceCall) { + mmCreateKnowledgeBaseFile.t.Errorf("RepositoryIMock.CreateKnowledgeBaseFile got unexpected parameter externalServiceCall, want: %#v, got: %#v%s\n", *mm_want_ptrs.externalServiceCall, mm_got.externalServiceCall, minimock.Diff(*mm_want_ptrs.externalServiceCall, mm_got.externalServiceCall)) + } + } else if mm_want != nil && !minimock.Equal(*mm_want, mm_got) { mmCreateKnowledgeBaseFile.t.Errorf("RepositoryIMock.CreateKnowledgeBaseFile got unexpected parameters, want: %#v, got: %#v%s\n", *mm_want, mm_got, minimock.Diff(*mm_want, mm_got)) } @@ -1331,9 +1359,9 @@ func (mmCreateKnowledgeBaseFile *RepositoryIMock) CreateKnowledgeBaseFile(ctx co return (*mm_results).kp1, (*mm_results).err } if mmCreateKnowledgeBaseFile.funcCreateKnowledgeBaseFile != nil { - return mmCreateKnowledgeBaseFile.funcCreateKnowledgeBaseFile(ctx, kb) + return mmCreateKnowledgeBaseFile.funcCreateKnowledgeBaseFile(ctx, kb, externalServiceCall) } - mmCreateKnowledgeBaseFile.t.Fatalf("Unexpected call to 
RepositoryIMock.CreateKnowledgeBaseFile. %v %v", ctx, kb) + mmCreateKnowledgeBaseFile.t.Fatalf("Unexpected call to RepositoryIMock.CreateKnowledgeBaseFile. %v %v %v", ctx, kb, externalServiceCall) return } @@ -1425,8 +1453,8 @@ type RepositoryIMockDeleteAndCreateChunksParams struct { ctx context.Context sourceTable string sourceUID uuid.UUID - chunks []mm_repository.TextChunk - externalServiceCall func(chunkUIDs []string) error + chunks []*mm_repository.TextChunk + externalServiceCall func(chunkUIDs []string) (map[string]any, error) } // RepositoryIMockDeleteAndCreateChunksParamPtrs contains pointers to parameters of the RepositoryI.DeleteAndCreateChunks @@ -1434,18 +1462,18 @@ type RepositoryIMockDeleteAndCreateChunksParamPtrs struct { ctx *context.Context sourceTable *string sourceUID *uuid.UUID - chunks *[]mm_repository.TextChunk - externalServiceCall *func(chunkUIDs []string) error + chunks *[]*mm_repository.TextChunk + externalServiceCall *func(chunkUIDs []string) (map[string]any, error) } // RepositoryIMockDeleteAndCreateChunksResults contains results of the RepositoryI.DeleteAndCreateChunks type RepositoryIMockDeleteAndCreateChunksResults struct { - ta1 []mm_repository.TextChunk - err error + tpa1 []*mm_repository.TextChunk + err error } // Expect sets up expected params for RepositoryI.DeleteAndCreateChunks -func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) Expect(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks []mm_repository.TextChunk, externalServiceCall func(chunkUIDs []string) error) *mRepositoryIMockDeleteAndCreateChunks { +func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) Expect(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks []*mm_repository.TextChunk, externalServiceCall func(chunkUIDs []string) (map[string]any, error)) *mRepositoryIMockDeleteAndCreateChunks { if mmDeleteAndCreateChunks.mock.funcDeleteAndCreateChunks != nil { 
mmDeleteAndCreateChunks.mock.t.Fatalf("RepositoryIMock.DeleteAndCreateChunks mock is already set by Set") } @@ -1535,7 +1563,7 @@ func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) ExpectSour } // ExpectChunksParam4 sets up expected param chunks for RepositoryI.DeleteAndCreateChunks -func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) ExpectChunksParam4(chunks []mm_repository.TextChunk) *mRepositoryIMockDeleteAndCreateChunks { +func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) ExpectChunksParam4(chunks []*mm_repository.TextChunk) *mRepositoryIMockDeleteAndCreateChunks { if mmDeleteAndCreateChunks.mock.funcDeleteAndCreateChunks != nil { mmDeleteAndCreateChunks.mock.t.Fatalf("RepositoryIMock.DeleteAndCreateChunks mock is already set by Set") } @@ -1557,7 +1585,7 @@ func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) ExpectChun } // ExpectExternalServiceCallParam5 sets up expected param externalServiceCall for RepositoryI.DeleteAndCreateChunks -func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) ExpectExternalServiceCallParam5(externalServiceCall func(chunkUIDs []string) error) *mRepositoryIMockDeleteAndCreateChunks { +func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) ExpectExternalServiceCallParam5(externalServiceCall func(chunkUIDs []string) (map[string]any, error)) *mRepositoryIMockDeleteAndCreateChunks { if mmDeleteAndCreateChunks.mock.funcDeleteAndCreateChunks != nil { mmDeleteAndCreateChunks.mock.t.Fatalf("RepositoryIMock.DeleteAndCreateChunks mock is already set by Set") } @@ -1579,7 +1607,7 @@ func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) ExpectExte } // Inspect accepts an inspector function that has same arguments as the RepositoryI.DeleteAndCreateChunks -func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) Inspect(f func(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks 
[]mm_repository.TextChunk, externalServiceCall func(chunkUIDs []string) error)) *mRepositoryIMockDeleteAndCreateChunks { +func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) Inspect(f func(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks []*mm_repository.TextChunk, externalServiceCall func(chunkUIDs []string) (map[string]any, error))) *mRepositoryIMockDeleteAndCreateChunks { if mmDeleteAndCreateChunks.mock.inspectFuncDeleteAndCreateChunks != nil { mmDeleteAndCreateChunks.mock.t.Fatalf("Inspect function is already set for RepositoryIMock.DeleteAndCreateChunks") } @@ -1590,7 +1618,7 @@ func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) Inspect(f } // Return sets up results that will be returned by RepositoryI.DeleteAndCreateChunks -func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) Return(ta1 []mm_repository.TextChunk, err error) *RepositoryIMock { +func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) Return(tpa1 []*mm_repository.TextChunk, err error) *RepositoryIMock { if mmDeleteAndCreateChunks.mock.funcDeleteAndCreateChunks != nil { mmDeleteAndCreateChunks.mock.t.Fatalf("RepositoryIMock.DeleteAndCreateChunks mock is already set by Set") } @@ -1598,12 +1626,12 @@ func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) Return(ta1 if mmDeleteAndCreateChunks.defaultExpectation == nil { mmDeleteAndCreateChunks.defaultExpectation = &RepositoryIMockDeleteAndCreateChunksExpectation{mock: mmDeleteAndCreateChunks.mock} } - mmDeleteAndCreateChunks.defaultExpectation.results = &RepositoryIMockDeleteAndCreateChunksResults{ta1, err} + mmDeleteAndCreateChunks.defaultExpectation.results = &RepositoryIMockDeleteAndCreateChunksResults{tpa1, err} return mmDeleteAndCreateChunks.mock } // Set uses given function f to mock the RepositoryI.DeleteAndCreateChunks method -func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) Set(f func(ctx context.Context, sourceTable 
string, sourceUID uuid.UUID, chunks []mm_repository.TextChunk, externalServiceCall func(chunkUIDs []string) error) (ta1 []mm_repository.TextChunk, err error)) *RepositoryIMock { +func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) Set(f func(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks []*mm_repository.TextChunk, externalServiceCall func(chunkUIDs []string) (map[string]any, error)) (tpa1 []*mm_repository.TextChunk, err error)) *RepositoryIMock { if mmDeleteAndCreateChunks.defaultExpectation != nil { mmDeleteAndCreateChunks.mock.t.Fatalf("Default expectation is already set for the RepositoryI.DeleteAndCreateChunks method") } @@ -1618,7 +1646,7 @@ func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) Set(f func // When sets expectation for the RepositoryI.DeleteAndCreateChunks which will trigger the result defined by the following // Then helper -func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) When(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks []mm_repository.TextChunk, externalServiceCall func(chunkUIDs []string) error) *RepositoryIMockDeleteAndCreateChunksExpectation { +func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) When(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks []*mm_repository.TextChunk, externalServiceCall func(chunkUIDs []string) (map[string]any, error)) *RepositoryIMockDeleteAndCreateChunksExpectation { if mmDeleteAndCreateChunks.mock.funcDeleteAndCreateChunks != nil { mmDeleteAndCreateChunks.mock.t.Fatalf("RepositoryIMock.DeleteAndCreateChunks mock is already set by Set") } @@ -1632,8 +1660,8 @@ func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) When(ctx c } // Then sets up RepositoryI.DeleteAndCreateChunks return parameters for the expectation previously defined by the When method -func (e *RepositoryIMockDeleteAndCreateChunksExpectation) Then(ta1 []mm_repository.TextChunk, err error) 
*RepositoryIMock { - e.results = &RepositoryIMockDeleteAndCreateChunksResults{ta1, err} +func (e *RepositoryIMockDeleteAndCreateChunksExpectation) Then(tpa1 []*mm_repository.TextChunk, err error) *RepositoryIMock { + e.results = &RepositoryIMockDeleteAndCreateChunksResults{tpa1, err} return e.mock } @@ -1658,7 +1686,7 @@ func (mmDeleteAndCreateChunks *mRepositoryIMockDeleteAndCreateChunks) invocation } // DeleteAndCreateChunks implements repository.RepositoryI -func (mmDeleteAndCreateChunks *RepositoryIMock) DeleteAndCreateChunks(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks []mm_repository.TextChunk, externalServiceCall func(chunkUIDs []string) error) (ta1 []mm_repository.TextChunk, err error) { +func (mmDeleteAndCreateChunks *RepositoryIMock) DeleteAndCreateChunks(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks []*mm_repository.TextChunk, externalServiceCall func(chunkUIDs []string) (map[string]any, error)) (tpa1 []*mm_repository.TextChunk, err error) { mm_atomic.AddUint64(&mmDeleteAndCreateChunks.beforeDeleteAndCreateChunksCounter, 1) defer mm_atomic.AddUint64(&mmDeleteAndCreateChunks.afterDeleteAndCreateChunksCounter, 1) @@ -1676,7 +1704,7 @@ func (mmDeleteAndCreateChunks *RepositoryIMock) DeleteAndCreateChunks(ctx contex for _, e := range mmDeleteAndCreateChunks.DeleteAndCreateChunksMock.expectations { if minimock.Equal(*e.params, mm_params) { mm_atomic.AddUint64(&e.Counter, 1) - return e.results.ta1, e.results.err + return e.results.tpa1, e.results.err } } @@ -1717,7 +1745,7 @@ func (mmDeleteAndCreateChunks *RepositoryIMock) DeleteAndCreateChunks(ctx contex if mm_results == nil { mmDeleteAndCreateChunks.t.Fatal("No results are set for the RepositoryIMock.DeleteAndCreateChunks") } - return (*mm_results).ta1, (*mm_results).err + return (*mm_results).tpa1, (*mm_results).err } if mmDeleteAndCreateChunks.funcDeleteAndCreateChunks != nil { return mmDeleteAndCreateChunks.funcDeleteAndCreateChunks(ctx, sourceTable, 
sourceUID, chunks, externalServiceCall) diff --git a/pkg/repository/chunk.go b/pkg/repository/chunk.go index 91c21e0..2d9e2f9 100644 --- a/pkg/repository/chunk.go +++ b/pkg/repository/chunk.go @@ -7,10 +7,11 @@ import ( "github.com/google/uuid" "gorm.io/gorm" + "gorm.io/gorm/clause" ) type TextChunkI interface { - DeleteAndCreateChunks(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks []TextChunk, externalServiceCall func(chunkUIDs []string) error) ([]TextChunk, error) + DeleteAndCreateChunks(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks []*TextChunk, externalServiceCall func(chunkUIDs []string) (map[string]any, error)) ([]*TextChunk, error) DeleteChunksBySource(ctx context.Context, sourceTable string, sourceUID uuid.UUID) error DeleteChunksByUIDs(ctx context.Context, chunkUIDs []uuid.UUID) error GetTextChunksBySource(ctx context.Context, sourceTable string, sourceUID uuid.UUID) ([]TextChunk, error) @@ -22,13 +23,13 @@ type TextChunk struct { UID uuid.UUID `gorm:"column:uid;type:uuid;default:gen_random_uuid();primaryKey" json:"uid"` SourceUID uuid.UUID `gorm:"column:source_uid;type:uuid;not null" json:"source_uid"` SourceTable string `gorm:"column:source_table;size:255;not null" json:"source_table"` - Start int `gorm:"column:start;not null" json:"start"` - End int `gorm:"column:end;not null" json:"end"` + StartPos int `gorm:"column:start_pos;not null" json:"start"` + EndPos int `gorm:"column:end_pos;not null" json:"end"` // ContentDest is the destination path in minio ContentDest string `gorm:"column:content_dest;size:255;not null" json:"content_dest"` Tokens int `gorm:"column:tokens;not null" json:"tokens"` Retrievable bool `gorm:"column:retrievable;not null;default:true" json:"retrievable"` - Order int `gorm:"column:order;not null" json:"order"` + InOrder int `gorm:"column:in_order;not null" json:"order"` CreateTime *time.Time `gorm:"column:create_time;not null;default:CURRENT_TIMESTAMP" json:"create_time"` UpdateTime 
*time.Time `gorm:"column:update_time;not null;default:CURRENT_TIMESTAMP" json:"update_time"` } @@ -51,12 +52,12 @@ var TextChunkColumn = TextChunkColumns{ UID: "uid", SourceUID: "source_uid", SourceTable: "source_table", - Start: "start", - End: "end", + Start: "start_pos", + End: "end_pos", ContentDest: "content_dest", Tokens: "tokens", Retrievable: "retrievable", - Order: "order", + Order: "in_order", CreateTime: "create_time", UpdateTime: "update_time", } @@ -69,7 +70,7 @@ func (TextChunk) TableName() string { // DeleteAndCreateChunks deletes all the chunks associated with // a certain source table and sourceUID, then batch inserts the new chunks // within a transaction. -func (r *Repository) DeleteAndCreateChunks(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks []TextChunk, externalServiceCall func(chunkUIDs []string) error) ([]TextChunk, error) { +func (r *Repository) DeleteAndCreateChunks(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks []*TextChunk, externalServiceCall func(chunkUIDs []string) (map[string]any, error)) ([]*TextChunk, error) { // Start a transaction err := r.db.Transaction(func(tx *gorm.DB) error { // Delete existing chunks @@ -89,11 +90,25 @@ func (r *Repository) DeleteAndCreateChunks(ctx context.Context, sourceTable stri chunkUIDs = append(chunkUIDs, chunk.UID.String()) } if externalServiceCall != nil { - if err := externalServiceCall(chunkUIDs); err != nil { + if chunkDestMap, err := externalServiceCall(chunkUIDs); err != nil { return err + } else { + // update the content dest of each chunk + for _, chunk := range chunks { + if dest, ok := chunkDestMap[chunk.UID.String()]; ok { + if data, ok := dest.(string); ok { + chunk.ContentDest = data + } + } + } } } + // Update the content dest of each chunk + if err := BatchUpdateContentDest(ctx, tx, chunks); err != nil { + return err + } + return nil }) @@ -104,6 +119,23 @@ func (r *Repository) DeleteAndCreateChunks(ctx context.Context, sourceTable stri 
return chunks, nil } +// BatchUpdateContentDest upserts chunks through tx: rows that conflict on uid get the listed columns (including content_dest) overwritten; an empty slice is a no-op. +func BatchUpdateContentDest(ctx context.Context, tx *gorm.DB, chunks []*TextChunk) error { + if len(chunks) == 0 { + return nil + } + + return tx.WithContext(ctx).Model(&TextChunk{}). + Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: TextChunkColumn.UID}}, + DoUpdates: clause.AssignmentColumns([]string{ + TextChunkColumn.SourceUID, TextChunkColumn.SourceTable, TextChunkColumn.Start, TextChunkColumn.End, TextChunkColumn.ContentDest, + TextChunkColumn.Tokens, TextChunkColumn.Retrievable, TextChunkColumn.Order, + }), + }). + Create(chunks).Error +} + // DeleteChunksBySource deletes all the chunks associated with a certain source table and sourceUID. func (r *Repository) DeleteChunksBySource(ctx context.Context, sourceTable string, sourceUID uuid.UUID) error { where := fmt.Sprintf("%s = ? AND %s = ?", TextChunkColumn.SourceTable, TextChunkColumn.SourceUID) diff --git a/pkg/repository/embedding.go b/pkg/repository/embedding.go index 859afd7..5c79da5 100644 --- a/pkg/repository/embedding.go +++ b/pkg/repository/embedding.go @@ -22,6 +22,7 @@ type Embedding struct { SourceUID uuid.UUID `gorm:"column:source_uid;type:uuid;not null" json:"source_uid"` SourceTable string `gorm:"column:source_table;size:255;not null" json:"source_table"` Vector []float32 `gorm:"column:vector;type:jsonb;not null" json:"vector"` + Collection string `gorm:"column:collection;size:255;not null" json:"collection"` CreateTime *time.Time `gorm:"column:create_time;not null;default:CURRENT_TIMESTAMP" json:"create_time"` UpdateTime *time.Time `gorm:"column:update_time;not null;default:CURRENT_TIMESTAMP" json:"update_time"` } @@ -31,6 +32,7 @@ type EmbeddingColumns struct { SourceUID string SourceTable string Vector string + Collection string CreateTime string UpdateTime string } @@ -40,6 +42,7 @@ var EmbeddingColumn = EmbeddingColumns{ SourceUID: "source_uid", SourceTable: "source_table", Vector: "vector", + Collection: "collection",
CreateTime: "create_time", UpdateTime: "update_time", } diff --git a/pkg/repository/knowledgebasefile.go b/pkg/repository/knowledgebasefile.go index 896cfa2..02db390 100644 --- a/pkg/repository/knowledgebasefile.go +++ b/pkg/repository/knowledgebasefile.go @@ -12,7 +12,7 @@ import ( type KnowledgeBaseFileI interface { KnowledgeBaseFileTableName() string - CreateKnowledgeBaseFile(ctx context.Context, kb KnowledgeBaseFile) (*KnowledgeBaseFile, error) + CreateKnowledgeBaseFile(ctx context.Context, kb KnowledgeBaseFile, externalServiceCall func(FileUID string) error) (*KnowledgeBaseFile, error) ListKnowledgeBaseFiles(ctx context.Context, uid string, ownerUID string, kbUID string, pageSize int32, nextPageToken string, filesUID []string) ([]KnowledgeBaseFile, int, string, error) DeleteKnowledgeBaseFile(ctx context.Context, fileUID string) error ProcessKnowledgeBaseFiles(ctx context.Context, fileUids []string) ([]KnowledgeBaseFile, error) @@ -74,7 +74,7 @@ func (r *Repository) KnowledgeBaseFileTableName() string { return "knowledge_base_file" } -func (r *Repository) CreateKnowledgeBaseFile(ctx context.Context, kb KnowledgeBaseFile) (*KnowledgeBaseFile, error) { +func (r *Repository) CreateKnowledgeBaseFile(ctx context.Context, kb KnowledgeBaseFile, externalServiceCall func(FileUID string) error) (*KnowledgeBaseFile, error) { // check if the file already exists in the same knowledge base and not delete var existingFile KnowledgeBaseFile whereClause := fmt.Sprintf("%s = ? AND %s = ? 
AND %v is NULL", KnowledgeBaseFileColumn.KnowledgeBaseUID, KnowledgeBaseFileColumn.Name, KnowledgeBaseFileColumn.DeleteTime) @@ -95,9 +95,28 @@ func (r *Repository) CreateKnowledgeBaseFile(ctx context.Context, kb KnowledgeBa } kb.ExtraMetaData = "{}" - if err := r.db.WithContext(ctx).Create(&kb).Error; err != nil { + + // Use a transaction to create the knowledge base file and call the external service + err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + // Create the knowledge base file + if err := tx.Create(&kb).Error; err != nil { + return err + } + + // Call the external service + if externalServiceCall != nil { + if err := externalServiceCall(kb.UID.String()); err != nil { + return err + } + } + + return nil + }) + + if err != nil { return nil, err } + return &kb, nil } diff --git a/pkg/worker/worker.go b/pkg/worker/worker.go index 2285894..2cd449a 100644 --- a/pkg/worker/worker.go +++ b/pkg/worker/worker.go @@ -94,7 +94,7 @@ func (wp *fileToEmbWorkerPool) startDispatcher(ctx context.Context) { fmt.Println("Dispatcher received termination signal while dispatching") return case wp.channel <- file: - fmt.Printf("Dispatcher dispatched file. fileUID: %s", file.UID.String()) + fmt.Printf("Dispatcher dispatched file. 
fileUID: %s\n", file.UID.String()) } } } @@ -124,8 +124,11 @@ func (wp *fileToEmbWorkerPool) startWorker(ctx context.Context, workerID int) { } // register file process worker in redis and extend the lifetime - stopRegisterFunc := wp.registerFileWorker(ctx, file.UID.String(), extensionHelperPeriod, workerLifetime) + ok, stopRegisterFunc := wp.registerFileWorker(ctx, file.UID.String(), extensionHelperPeriod, workerLifetime) + if !ok { + continue + } // start file processing tracing fmt.Printf("Worker %d processing file: %s\n", workerID, file.UID.String()) @@ -159,15 +162,15 @@ type stopRegisterWorkerFunc func() // It returns a stopRegisterWorkerFunc that can be used to cancel the worker's lifetime extension and remove the worker key from Redis. // period: second // workerLifetime: second -func (wp *fileToEmbWorkerPool) registerFileWorker(ctx context.Context, fileUID string, period time.Duration, workerLifetime time.Duration) stopRegisterWorkerFunc { +func (wp *fileToEmbWorkerPool) registerFileWorker(ctx context.Context, fileUID string, period time.Duration, workerLifetime time.Duration) (bool, stopRegisterWorkerFunc) { ok, err := wp.svc.RedisClient.SetNX(ctx, getWorkerKey(fileUID), "1", workerLifetime).Result() if err != nil { - fmt.Printf("Error when setting worker key in redis. Error: %v", err) - return nil + fmt.Printf("Error when setting worker key in redis. Error: %v\n", err) + return false, nil } if !ok { - fmt.Printf("Worker already exists in redis. fileUID: %s", fileUID) - return nil + fmt.Printf("File is already being processed in redis. 
fileUID: %s\n", fileUID) + return false, nil } ctx, lifetimeHelperCancel := context.WithCancel(ctx) @@ -183,9 +186,10 @@ func (wp *fileToEmbWorkerPool) registerFileWorker(ctx context.Context, fileUID s return case <-ticker.C: // extend the lifetime of the worker + fmt.Printf("Extending %v's lifetime: %v \n", getWorkerKey(fileUID), workerLifetime) err := wp.svc.RedisClient.Expire(ctx, getWorkerKey(fileUID), workerLifetime).Err() if err != nil { - fmt.Printf("Error when extending worker lifetime in redis. Error: %v, worker: %v", err, getWorkerKey(fileUID)) + fmt.Printf("Error when extending worker lifetime in redis. Error: %v, worker: %v\n", err, getWorkerKey(fileUID)) return } } @@ -199,7 +203,7 @@ func (wp *fileToEmbWorkerPool) registerFileWorker(ctx context.Context, fileUID s wp.svc.RedisClient.Del(ctx, getWorkerKey(fileUID)) } - return stopRegisterWorker + return true, stopRegisterWorker } // checkFileWorker checks if any of the provided fileUIDs have active workers @@ -317,7 +321,7 @@ func (wp *fileToEmbWorkerPool) processWaitingFile(ctx context.Context, file repo artifactpb.FileType_name[int32(artifactpb.FileType_FILE_TYPE_MARKDOWN)]: updateMap := map[string]interface{}{ - repository.KnowledgeBaseFileColumn.ProcessStatus: artifactpb.FileProcessStatus_name[int32(artifactpb.FileProcessStatus_FILE_PROCESS_STATUS_EMBEDDING)], + repository.KnowledgeBaseFileColumn.ProcessStatus: artifactpb.FileProcessStatus_name[int32(artifactpb.FileProcessStatus_FILE_PROCESS_STATUS_CHUNKING)], } updatedFile, err := wp.svc.Repository.UpdateKnowledgeBaseFile(ctx, file.UID.String(), updateMap) if err != nil { @@ -402,7 +406,7 @@ func (wp *fileToEmbWorkerPool) processChunkingFile(ctx context.Context, file rep logger.Error("Failed to get converted file from minIO.", zap.String("Converted file uid", convertedFile.UID.String())) return nil, artifactpb.FileProcessStatus_FILE_PROCESS_STATUS_UNSPECIFIED, err } - // call the chunking pipeline + // call the markdown chunking pipeline chunks, 
err := wp.svc.SplitMarkdown(ctx, file.CreatorUID, string(convertedFileData)) if err != nil { logger.Error("Failed to get chunks from converted file.", zap.String("Converted file uid", convertedFile.UID.String())) @@ -435,8 +439,8 @@ func (wp *fileToEmbWorkerPool) processChunkingFile(ctx context.Context, file rep return nil, artifactpb.FileProcessStatus_FILE_PROCESS_STATUS_UNSPECIFIED, err } - // Call the markdown chunking pipeline - chunks, err := wp.svc.SplitMarkdown(ctx, file.CreatorUID, string(originalFile)) + // Call the text chunking pipeline + chunks, err := wp.svc.SplitText(ctx, file.CreatorUID, string(originalFile)) if err != nil { logger.Error("Failed to get chunks from original file.", zap.String("File uid", file.UID.String())) return nil, artifactpb.FileProcessStatus_FILE_PROCESS_STATUS_UNSPECIFIED, err @@ -557,12 +561,14 @@ func (wp *fileToEmbWorkerPool) processEmbeddingFile(ctx context.Context, file re } // save the embeddings into milvus and metadata into database + collection := wp.svc.MilvusClient.GetKnowledgeBaseCollectionName(file.KnowledgeBaseUID.String()) embeddings := make([]repository.Embedding, len(vectors)) for i, v := range vectors { embeddings[i] = repository.Embedding{ SourceUID: sourceUID, SourceTable: sourceTable, Vector: v, + Collection: collection, } } err = wp.saveEmbeddings(ctx, file.KnowledgeBaseUID.String(), embeddings) @@ -622,21 +628,21 @@ type chunk = struct { // save chunk into object storage and metadata into database func (wp *fileToEmbWorkerPool) saveChunks(ctx context.Context, kbUID string, sourceTable string, sourceUID uuid.UUID, chunks []chunk) error { logger, _ := logger.GetZapLogger(ctx) - textChunks := make([]repository.TextChunk, len(chunks)) + textChunks := make([]*repository.TextChunk, len(chunks)) for i, c := range chunks { - textChunks[i] = repository.TextChunk{ + textChunks[i] = &repository.TextChunk{ SourceUID: sourceUID, SourceTable: sourceTable, - Start: c.Start, - End: c.End, - ContentDest: 
minio.GetChunkPathInKnowledgeBase(kbUID, sourceUID.String()), + StartPos: c.Start, + EndPos: c.End, + ContentDest: "not set yet", Tokens: c.Tokens, Retrievable: true, - Order: i, + InOrder: i, } } _, err := wp.svc.Repository.DeleteAndCreateChunks(ctx, sourceTable, sourceUID, textChunks, - func(chunkUIDs []string) error { + func(chunkUIDs []string) (map[string]any, error) { // save the chunksForMinIO into object storage chunksForMinIO := make(map[minio.ChunkUIDType]minio.ChunkContentType, len(textChunks)) for i, uid := range chunkUIDs { @@ -645,9 +651,13 @@ func (wp *fileToEmbWorkerPool) saveChunks(ctx context.Context, kbUID string, sou err := wp.svc.MinIO.SaveChunks(ctx, kbUID, chunksForMinIO) if err != nil { logger.Error("Failed to save chunks into object storage.", zap.String("SourceUID", sourceUID.String())) - return err + return nil, err } - return nil + chunkDestMap := make(map[string]any, len(chunkUIDs)) + for _, chunkUID := range chunkUIDs { + chunkDestMap[chunkUID] = wp.svc.MinIO.GetChunkPathInKnowledgeBase(kbUID, string(chunkUID)) + } + return chunkDestMap, nil }, ) if err != nil {