Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/feature' into personal/yfei/fast…
Browse files Browse the repository at this point in the history
…_bm25
  • Loading branch information
moria97 committed Jun 18, 2024
2 parents f1a18ba + 6c5eee4 commit 90d4b43
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/pai_rag/app/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def aupdate(new_config: Any = Body(None)):


@router.post("/upload_data")
async def load_data(input: DataInput, background_tasks: BackgroundTasks):
def load_data(input: DataInput, background_tasks: BackgroundTasks):
task_id = uuid.uuid4().hex
background_tasks.add_task(
rag_service.add_knowledge_async,
Expand Down
12 changes: 7 additions & 5 deletions src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import requests
import html
import markdown
import httpx

cache_config = None

Expand Down Expand Up @@ -95,11 +96,12 @@ def add_knowledge(self, file_dir: str, enable_qa_extraction: bool):
response = dotdict(json.loads(r.text))
return response

def get_knowledge_state(self, task_id: str):
    """Fetch the ingestion status of a background upload task (blocking).

    Args:
        task_id: Hex id of the background ingestion task.

    Returns:
        dotdict parsed from the service's JSON response; the UI caller
        reads ``response["status"]`` from it.
    """
    r = requests.get(self.get_load_state_url, params={"task_id": task_id})
    r.raise_for_status()  # propagate HTTP errors to the caller
    response = dotdict(json.loads(r.text))
    return response
async def get_knowledge_state(self, task_id: str):
    """Asynchronously fetch the ingestion status of a background upload task.

    Args:
        task_id: Hex id of the background ingestion task.

    Returns:
        dotdict parsed from the service's JSON response; the UI caller
        reads ``response["status"]`` from it.
    """
    # httpx.AsyncClient so the poll does not block the caller's event loop.
    async with httpx.AsyncClient() as client:
        r = await client.get(self.get_load_state_url, params={"task_id": task_id})
        r.raise_for_status()  # propagate HTTP errors to the caller
        response = dotdict(json.loads(r.text))
        return response

def reload_config(self, config: Any):
global cache_config
Expand Down
11 changes: 9 additions & 2 deletions src/pai_rag/app/web/tabs/upload_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pai_rag.app.web.view_model import view_model
from pai_rag.utils.file_utils import MyUploadFile
import pandas as pd
import asyncio


def upload_knowledge(upload_files, chunk_size, chunk_overlap, enable_qa_extraction):
Expand All @@ -15,7 +16,13 @@ def upload_knowledge(upload_files, chunk_size, chunk_overlap, enable_qa_extracti
rag_client.reload_config(new_config)

if not upload_files:
return "No file selected. Please choose at least one file."
yield [
gr.update(visible=False),
gr.update(
visible=True,
value="No file selected. Please choose at least one file.",
),
]

my_upload_files = []
for file in upload_files:
Expand All @@ -28,7 +35,7 @@ def upload_knowledge(upload_files, chunk_size, chunk_overlap, enable_qa_extracti
result = {"Info": ["StartTime", "EndTime", "Duration(s)", "Status"]}
while not all(file.finished is True for file in my_upload_files):
for file in my_upload_files:
response = rag_client.get_knowledge_state(str(file.task_id))
response = asyncio.run(rag_client.get_knowledge_state(str(file.task_id)))
file.update_state(response["status"])
file.update_process_duration()
result[file.file_name] = file.__info__()
Expand Down
15 changes: 13 additions & 2 deletions src/pai_rag/core/rag_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,31 @@ def reload(self, config):
self.logger.info("RagApplication reloaded successfully.")

# TODO: 大量文件上传实现异步添加
async def load_knowledge(self, file_dir, enable_qa_extraction=False):
async def aload_knowledge(self, file_dir, enable_qa_extraction=False):
    """Asynchronously load knowledge files from *file_dir*.

    Delegates to the DataLoaderModule resolved from the current config.

    Args:
        file_dir: Directory containing the files to ingest.
        enable_qa_extraction: When True, the loader also performs QA
            extraction on the loaded files.
    """
    data_loader = module_registry.get_module_with_config(
        "DataLoaderModule", self.config
    )
    await data_loader.aload(file_dir, enable_qa_extraction)

def load_knowledge(self, file_dir, enable_qa_extraction=False):
    """Synchronously load knowledge files from *file_dir*.

    Blocking counterpart of ``aload_knowledge``; delegates to the
    DataLoaderModule resolved from the current config.

    Args:
        file_dir: Directory containing the files to ingest.
        enable_qa_extraction: When True, the loader also performs QA
            extraction on the loaded files.
    """
    data_loader = module_registry.get_module_with_config(
        "DataLoaderModule", self.config
    )
    data_loader.load(file_dir, enable_qa_extraction)

async def aquery_retrieval(self, query: RetrievalQuery) -> RetrievalResponse:
if not query.question:
return RetrievalResponse(docs=[])

sessioned_config = self.config
if query.vector_db and query.vector_db.faiss_path:
sessioned_config = self.config.copy()
sessioned_config.index.update({"persist_path": query.vector_db.faiss_path})

query_bundle = QueryBundle(query.question)

query_engine = module_registry.get_module_with_config(
"QueryEngineModule", self.config
"QueryEngineModule", sessioned_config
)
node_results = await query_engine.aretrieve(query_bundle)

Expand Down
4 changes: 2 additions & 2 deletions src/pai_rag/core/rag_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def reload(self, new_config: Any):
self.rag.reload(self.rag_configuration.get_value())
self.rag_configuration.persist()

async def add_knowledge_async(
def add_knowledge_async(
self, task_id: str, file_dir: str, enable_qa_extraction: bool = False
):
self.tasks_status[task_id] = "processing"
try:
await self.rag.load_knowledge(file_dir, enable_qa_extraction)
self.rag.load_knowledge(file_dir, enable_qa_extraction)
self.tasks_status[task_id] = "completed"
except Exception as ex:
logger.error(f"Upload failed: {ex}")
Expand Down
4 changes: 2 additions & 2 deletions tests/core/test_rag_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ def rag_app():


# Test load knowledge file
async def test_add_knowledge_file(rag_app: RagApplication):
def test_add_knowledge_file(rag_app: RagApplication):
data_dir = os.path.join(BASE_DIR, "tests/testdata/paul_graham")
await rag_app.load_knowledge(data_dir)
rag_app.load_knowledge(data_dir)


# Test rag query
Expand Down

0 comments on commit 90d4b43

Please sign in to comment.