Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RAG into EmoLLM #191

Merged
merged 5 commits into from
Apr 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions rag/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ JSON 数据格式如下
如果已经有 vector DB 则会直接加载对应数据库


**注意**: 可以直接从 xlab 下载对应 DB(请在rag文件目录下执行对应 code)
```bash
# https://openxlab.org.cn/models/detail/Anooyman/EmoLLMRAGTXT/tree/main
git lfs install
git clone https://code.openxlab.org.cn/Anooyman/EmoLLMRAGTXT.git
```


### 配置 config 文件

根据需要改写 config.config 文件:
Expand Down
5 changes: 3 additions & 2 deletions rag/src/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl') # pkl
doc_dir = os.path.join(data_dir, 'txt')
qa_dir = os.path.join(data_dir, 'json')
cloud_vector_db_dir = os.path.join(base_dir, 'EmoLLMRAGTXT')

# log
log_dir = os.path.join(base_dir, 'log') # log
Expand All @@ -30,13 +31,13 @@
chunk_overlap=100

# vector DB
vector_db_dir = os.path.join(data_dir, 'vector_db')
vector_db_dir = os.path.join(cloud_vector_db_dir, 'vector_db')

# RAG related
# select num: 代表rerank 之后选取多少个 documents 进入 LLM
# retrieval num: 代表从 vector db 中检索多少 documents。(retrieval num 应该大于等于 select num)
select_num = 3
retrieval_num = 10
retrieval_num = 3

# LLM key
glm_key = ''
Expand Down
7 changes: 3 additions & 4 deletions rag/src/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from loguru import logger
from langchain_community.vectorstores import FAISS
from config.config import (
from rag.src.config.config import (
embedding_path,
embedding_model_name,
doc_dir, qa_dir,
Expand All @@ -19,7 +19,6 @@
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.document_loaders import DirectoryLoader
from langchain_core.documents.base import Document
from FlagEmbedding import FlagReranker

Expand Down Expand Up @@ -199,7 +198,7 @@ def create_vector_db(self, emb_model):
创建并保存向量库
'''
logger.info(f'Creating index...')
#split_doc = self.split_document(doc_dir)
split_doc = self.split_document(doc_dir)
split_qa = self.split_conversation(qa_dir)
# logger.info(f'split_doc == {len(split_doc)}')
# logger.info(f'split_qa == {len(split_qa)}')
Expand All @@ -218,7 +217,7 @@ def load_vector_db(self, knowledge_pkl_path=knowledge_pkl_path, doc_dir=doc_dir,
if not os.path.exists(vector_db_dir) or not os.listdir(vector_db_dir):
db = self.create_vector_db(emb_model)
else:
db = FAISS.load_local(vector_db_dir, emb_model, allow_dangerous_deserialization=True)
db = FAISS.load_local(vector_db_dir, emb_model)
return db

if __name__ == "__main__":
Expand Down
12 changes: 6 additions & 6 deletions rag/src/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from langchain_core.prompts import PromptTemplate
from transformers.utils import logging

from data_processing import Data_process
from config.config import prompt_template
from rag.src.data_processing import Data_process
from rag.src.config.config import prompt_template
logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -48,19 +48,19 @@ def get_retrieval_content(self, query) -> str:
output: 检索后并且 rerank 的内容
"""

content = ''
content = []
documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)

for doc in documents:
content += doc.page_content
content.append(doc.page_content)

# 如果需要rerank,调用接口对 documents 进行 rerank
if self.rerank_flag:
documents, _ = self.data_processing_obj.rerank(documents, self.select_num)

content = ''
content = []
for doc in documents:
content += doc
content.append(doc)
logger.info(f'Retrieval data: {content}')
return content

Expand Down
10 changes: 8 additions & 2 deletions web_internlm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import os
import warnings
from dataclasses import asdict, dataclass
from rag.src.pipeline import EmoLLMRAG
from typing import Callable, List, Optional

import streamlit as st
Expand Down Expand Up @@ -188,8 +189,9 @@ def prepare_generation_config():
cur_query_prompt = "<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n"


def combine_history(prompt):
def combine_history(prompt, retrieval_content=''):
messages = st.session_state.messages
prompt = f"你需要根据以下从书本中检索到的专业知识:`{retrieval_content}`。从一个心理专家的专业角度来回答后续提问:{prompt}"
meta_instruction = (
"你是一个由aJupyter、Farewell、jujimeizuo、Smiling&Weeping研发(排名按字母顺序排序,不分先后)、散步提供技术支持、上海人工智能实验室提供支持开发的心理健康大模型。现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。"
)
Expand All @@ -211,6 +213,7 @@ def main():
# torch.cuda.empty_cache()
print("load model begin.")
model, tokenizer = load_model()
rag_obj = EmoLLMRAG(model)
print("load model end.")

user_avator = "assets/user.png"
Expand All @@ -232,9 +235,12 @@ def main():
# Accept user input
if prompt := st.chat_input("What is up?"):
# Display user message in chat message container
retrieval_content = rag_obj.get_retrieval_content(prompt)
with st.chat_message("user", avatar=user_avator):
st.markdown(prompt)
real_prompt = combine_history(prompt)
#st.markdown(retrieval_content)

real_prompt = combine_history(prompt, retrieval_content)
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt, "avatar": user_avator})

Expand Down