Update RAG (#140)
aJupyter authored Mar 24, 2024
2 parents 5cb2585 + c50b834 commit 7fa3a8b
Showing 7 changed files with 190 additions and 116 deletions.
89 changes: 80 additions & 9 deletions rag/README.md
@@ -8,6 +8,86 @@
- Classic case studies
- Client background knowledge

## **Environment Setup**

```
langchain==0.1.13
langchain_community==0.0.29
langchain_core==0.1.33
langchain_openai==0.0.8
langchain_text_splitters==0.0.1
FlagEmbedding==1.2.8
unstructured==0.12.6
```

```bash
cd rag
pip3 install -r requirements.txt
```

## **Usage Guide**

### Preparing the Data

- txt data: place the files under the src.data.txt directory
- json data: place the files under the src.data.json directory

A vector DB is built from the prepared data; this produces a folder named vector_db under the data directory, containing index.faiss and index.pkl.

If a vector DB already exists, the corresponding database is loaded directly, as sketched below.
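
The following is a minimal sketch of that build-or-load step, assuming the default `BAAI/bge-small-zh-v1.5` embedding model from the config and illustrative data paths; the actual logic lives in `src/data_processing.py`.

```python
import os

from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter

embeddings = HuggingFaceBgeEmbeddings(
    model_name='BAAI/bge-small-zh-v1.5',
    model_kwargs={'device': 'cpu'},
    encode_kwargs={'normalize_embeddings': True},
)

vector_db_dir = 'data/vector_db'  # illustrative path
if os.path.exists(os.path.join(vector_db_dir, 'index.faiss')):
    # An existing DB (index.faiss + index.pkl) is loaded directly.
    vector_db = FAISS.load_local(vector_db_dir, embeddings)
else:
    # Otherwise, build the DB from the prepared txt files and persist it.
    docs = DirectoryLoader('data/txt', glob='*.txt', loader_cls=TextLoader).load()
    chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
    vector_db = FAISS.from_documents(chunks, embeddings)
    vector_db.save_local(vector_db_dir)
```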


### Configuring the config file

Edit the config.config file as needed:

```python
# Directory that holds all models
model_dir = os.path.join(base_dir, 'model')

# Embedding model path and model name
embedding_path = os.path.join(model_dir, 'embedding_model')
embedding_model_name = 'BAAI/bge-small-zh-v1.5'

# Rerank model path and model name
rerank_path = os.path.join(model_dir, 'rerank_model')
rerank_model_name = 'BAAI/bge-reranker-large'

# select_num: how many documents are kept after reranking and passed to the LLM
select_num = 3

# retrieval_num: how many documents are retrieved from the vector DB (retrieval_num should be >= select_num)
retrieval_num = 10

# API key for the Zhipu LLM. The demo currently supports only the Zhipu AI API for the final generation step
glm_key = ''

# Prompt template definition
prompt_template = """
You are Aiwei, a gentle girl-next-door big sister with rich knowledge of psychology. I have some psychological problems; please help me with your professional knowledge in a gentle, cute, playful tone, and feel free to sprinkle cute emoji or text emoticons into your reply.\n
Answer the question based on the retrieved information below.
{content}
Question: {query}
"""
```
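
For reference, `pipeline.py` feeds this template to LangChain's `PromptTemplate`; a minimal sketch of how it gets filled (the example values are illustrative):

```python
from langchain_core.prompts import PromptTemplate

from config.config import prompt_template

prompt = PromptTemplate(
    template=prompt_template,
    input_variables=["query", "content"],
)

# content comes from retrieval + rerank; query is the user input (illustrative values)
print(prompt.format(content="...retrieved passages...", query="I feel anxious before exams"))
```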

### Running

```bash
cd rag/src
python main.py
```
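
The pipeline can also be driven from Python. A sketch based on the updated `main.py`, assuming `glm_key` is set in the config (the exact signature of `get_glm` is defined in `src/util/llm.py`):

```python
from loguru import logger

from pipeline import EmoLLMRAG
from util.llm import get_glm

model = get_glm()           # Zhipu glm-4 wrapped as a LangChain chat model
rag_obj = EmoLLMRAG(model)  # defaults: retrieval_num=3, rerank_flag=False, select_num=3
res = rag_obj.main("I feel lost and afraid in my final year of high school")
logger.info(res)
```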


## **Dataset**

- Cleaned QA pairs: each QA pair is embedded as a single sample (see the sketch below)
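
A sketch of how one cleaned QA pair could be wrapped as a single embedding sample (the field names here are illustrative; `split_conversation` in `data_processing.py` does the real extraction):

```python
from langchain_core.documents.base import Document

qa = {
    "question": "How can I cope with exam anxiety?",
    "answer": "Paced breathing and a realistic revision plan can help...",
}

# One QA pair -> one Document -> one vector in the FAISS index
doc = Document(page_content=f"Q: {qa['question']}\nA: {qa['answer']}")
```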
@@ -65,12 +145,3 @@ The classic RAG evaluation framework evaluates along the following three dimensions:
- Add multi-query retrieval to increase recall, i.e. generate several similar queries from the user input and retrieve with each (a sketch follows below)
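
This is not implemented yet; a hypothetical sketch of such a multi-query retrieval step might look like this (`llm` and `vector_db` as elsewhere in this README):

```python
def multi_query_retrieve(vector_db, llm, query, k=10):
    """Generate paraphrases of the query, retrieve with each, and merge the results."""
    paraphrases = llm.predict(
        f"Rewrite the following question in three different ways, one per line:\n{query}"
    ).splitlines()
    seen, merged = set(), []
    for q in [query, *paraphrases]:
        for doc in vector_db.similarity_search(q.strip(), k=k):
            if doc.page_content not in seen:  # dedupe the union of all retrievals
                seen.add(doc.page_content)
                merged.append(doc)
    return merged
```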

8 changes: 7 additions & 1 deletion rag/requirements.txt
@@ -2,5 +2,11 @@ sentence_transformers
transformers
numpy
loguru
langchain
torch
langchain==0.1.13
langchain_community==0.0.29
langchain_core==0.1.33
langchain_openai==0.0.8
langchain_text_splitters==0.0.1
FlagEmbedding==1.2.8
unstructured==0.12.6
21 changes: 15 additions & 6 deletions rag/src/config/config.py
@@ -8,7 +8,10 @@
# model
model_dir = os.path.join(base_dir, 'model') # model
embedding_path = os.path.join(model_dir, 'embedding_model') # embedding
rerank_path = os.path.join(model_dir, 'rerank_model') # rerank
embedding_model_name = 'BAAI/bge-small-zh-v1.5'
rerank_path = os.path.join(model_dir, 'rerank_model') # rerank
rerank_model_name = 'BAAI/bge-reranker-large'

llm_path = os.path.join(model_dir, 'pythia-14m') # llm

# data
@@ -23,15 +26,21 @@
log_path = os.path.join(log_dir, 'log.log') # file

# vector DB
vector_db_dir = os.path.join(data_dir, 'vector_db.pkl')
vector_db_dir = os.path.join(data_dir, 'vector_db')

# RAG related
# select_num: how many documents are kept after reranking and passed to the LLM
# retrieval_num: how many documents are retrieved from the vector DB (retrieval_num should be >= select_num)
select_num = 3
retrieval_num = 10
system_prompt = """
You are Aiwei, a gentle girl-next-door big sister with rich knowledge of psychology. I have some psychological problems; please help me with your professional knowledge in a gentle, cute, playful tone, and feel free to sprinkle cute emoji or text emoticons into your reply.\n
"""

# LLM key
glm_key = ''

# prompt
prompt_template = """
{system_prompt}
You are Aiwei, a gentle girl-next-door big sister with rich knowledge of psychology. I have some psychological problems; please help me with your professional knowledge in a gentle, cute, playful tone, and feel free to sprinkle cute emoji or text emoticons into your reply.\n
Answer the question based on the retrieved information below.
{content}
Question: {query}
32 changes: 12 additions & 20 deletions rag/src/data_processing.py
@@ -1,33 +1,24 @@
import json
import pickle
import faiss
import pickle
import os

from loguru import logger
from sentence_transformers import SentenceTransformer
from langchain_community.vectorstores import FAISS
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path
from config.config import embedding_path, embedding_model_name, doc_dir, qa_dir, knowledge_pkl_path, data_dir, vector_db_dir, rerank_path, rerank_model_name
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter
from BCEmbedding import EmbeddingModel, RerankerModel
# from util.pipeline import EmoLLMRAG
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.document_loaders.pdf import PyPDFDirectoryLoader
from langchain.document_loaders import UnstructuredFileLoader,DirectoryLoader
from langchain_community.llms import Cohere
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.document_loaders import DirectoryLoader
from langchain_core.documents.base import Document
from FlagEmbedding import FlagReranker

class Data_process():

def __init__(self):
self.chunk_size: int=1000
self.chunk_overlap: int=100

def load_embedding_model(self, model_name='BAAI/bge-small-zh-v1.5', device='cpu', normalize_embeddings=True):
def load_embedding_model(self, model_name=embedding_model_name, device='cpu', normalize_embeddings=True):
"""
Load the embedding model.
@@ -61,7 +52,8 @@ def load_embedding_model(self, model_name='BAAI/bge-small-zh-v1.5', device='cpu', normalize_embeddings=True):
return None
return embeddings

def load_rerank_model(self, model_name='BAAI/bge-reranker-large'):
def load_rerank_model(self, model_name=rerank_model_name):

"""
Load the reranker model.
@@ -99,7 +91,6 @@ def load_rerank_model(self, model_name='BAAI/bge-reranker-large'):

return reranker_model


def extract_text_from_json(self, obj, content=None):
"""
Extract the text from a JSON object, used to build the vector store
@@ -128,7 +119,8 @@ def extract_text_from_json(self, obj, content=None):
return content


def split_document(self, data_path, chunk_size=500, chunk_overlap=100):
def split_document(self, data_path):

"""
Split all txt files under the data_path folder
@@ -143,7 +135,7 @@ def split_document(self, data_path):


# text_spliter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
text_spliter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
text_spliter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
split_docs = []
logger.info(f'Loading txt files from {data_path}')
if os.path.isdir(data_path):
@@ -188,7 +180,7 @@ def split_conversation(self, path):
# split_qa.append(Document(page_content = content))
# Split by conversation block
content = self.extract_text_from_json(conversation['conversation'], '')
logger.info(f'content====={content}')
#logger.info(f'content====={content}')
split_qa.append(Document(page_content = content))
# logger.info(f'split_qa size====={len(split_qa)}')
return split_qa
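For context, the `FlagReranker` imported above scores (query, passage) pairs; a minimal usage sketch under the pinned `FlagEmbedding==1.2.8` (the model path is an assumption):

```python
from FlagEmbedding import FlagReranker

reranker = FlagReranker('BAAI/bge-reranker-large', use_fp16=True)
pairs = [
    ["user query", "candidate passage A"],
    ["user query", "candidate passage B"],
]
scores = reranker.compute_score(pairs)  # one relevance score per pair; higher is more relevant
```
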
104 changes: 29 additions & 75 deletions rag/src/main.py
@@ -1,17 +1,6 @@
import os
import time
import jwt

from config.config import base_dir, data_dir
from data_processing import Data_process
from pipeline import EmoLLMRAG

from langchain_openai import ChatOpenAI
from util.llm import get_glm
from loguru import logger
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
from openxlab.model import download
'''
1) Build the full RAG pipeline: the input is the user query, the output is the answer
2) Call the embedding interface to vectorize the query
@@ -21,69 +10,34 @@
6) Assemble the prompt and call the model to return the result
'''
def get_glm(temprature):
llm = ChatOpenAI(
model_name="glm-4",
openai_api_base="https://open.bigmodel.cn/api/paas/v4",
openai_api_key=generate_token("api-key"),
streaming=False,
temperature=temprature
)
return llm

def generate_token(apikey: str, exp_seconds: int=100):
try:
id, secret = apikey.split(".")
except Exception as e:
raise Exception("invalid apikey", e)

payload = {
"api_key": id,
"exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
"timestamp": int(round(time.time() * 1000)),
}

return jwt.encode(
payload,
secret,
algorithm="HS256",
headers={"alg": "HS256", "sign_type": "SIGN"},
)

@st.cache_resource
def load_model():
model_dir = os.path.join(base_dir,'../model')
logger.info(f'Loading model from {model_dir}')
model = (
AutoModelForCausalLM.from_pretrained('model', trust_remote_code=True)
.to(torch.bfloat16)
.cuda()
)
tokenizer = AutoTokenizer.from_pretrained('model', trust_remote_code=True)
return model, tokenizer

def main(query, system_prompt=''):
logger.info(data_dir)
if not os.path.exists(data_dir):
os.mkdir(data_dir)
dp = Data_process()
vector_db = dp.load_vector_db()
docs, retriever = dp.retrieve(query, vector_db, k=10)
logger.info(f'Query: {query}')
logger.info("Retrieve results===============================")
for i, doc in enumerate(docs):
logger.info(doc)
passages,scores = dp.rerank(query, docs)
logger.info("After reranking===============================")
for i in range(len(scores)):
logger.info(passages[i])
logger.info(f'score: {str(scores[i])}')

if __name__ == "__main__":
query = "I am now in my final year of high school and feel very lost and afraid. I feel that I have been superfluous ever since I was born and that there is no need for me to exist in this world. Whether with my family, at school, with friends, or in front of teachers, I feel rejected. I am very sad; I had high hopes for the gaokao, but my grades are not good"
main(query)
#model = get_glm(0.7)
#rag_obj = EmoLLMRAG(model, 3)
#res = rag_obj.main(query)
#logger.info(res)
query = """
I am now in my final year of high school and feel very lost and afraid. I feel that I have been superfluous ever since I was born and that there is no need for me to exist in this world.
Whether with my family, at school, with friends, or in front of teachers, I feel rejected. I am very sad; I had high hopes for the gaokao, but my grades are not good
"""

"""
输入:
model_name='glm-4',
api_base="https://open.bigmodel.cn/api/paas/v4",
temprature=0.7,
streaming=False,
输出:
LLM Model
"""
model = get_glm()

"""
输入:
LLM model
retrieval_num=3
rerank_flag=False
select_num-3
"""
rag_obj = EmoLLMRAG(model)

res = rag_obj.main(query)

logger.info(res)

8 changes: 3 additions & 5 deletions rag/src/pipeline.py
@@ -3,7 +3,7 @@
from transformers.utils import logging

from data_processing import Data_process
from config.config import system_prompt, prompt_template
from config.config import prompt_template
logger = logging.get_logger(__name__)


@@ -16,7 +16,7 @@ class EmoLLMRAG(object):
4. Pass the query and the retrieved content to the LLM
"""

def __init__(self, model, retrieval_num, rerank_flag=False, select_num=3) -> None:
def __init__(self, model, retrieval_num=3, rerank_flag=False, select_num=3) -> None:
"""
Initialize with the given model
@@ -29,7 +29,6 @@ def __init__(self, model, retrieval_num, rerank_flag=False, select_num=3) -> None:
self.model = model
self.data_processing_obj = Data_process()
self.vectorstores = self._load_vector_db()
self.system_prompt = system_prompt
self.prompt_template = prompt_template
self.retrieval_num = retrieval_num
self.rerank_flag = rerank_flag
@@ -75,7 +74,7 @@ def generate_answer(self, query, content) -> str:
# The first version does not handle history, so the system prompt is folded directly into the template
prompt = PromptTemplate(
template=self.prompt_template,
input_variables=["query", "content", "system_prompt"],
input_variables=["query", "content"],
)

# Define the chain
@@ -87,7 +86,6 @@
{
"query": query,
"content": content,
"system_prompt": self.system_prompt
}
)
return generation
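The chain defined in `generate_answer` now receives only `query` and `content`. A minimal standalone sketch of an equivalent LCEL chain, under the assumption that `get_glm` returns a LangChain chat model (the hidden lines above define the actual chain):

```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

from config.config import prompt_template
from util.llm import get_glm

prompt = PromptTemplate(template=prompt_template, input_variables=["query", "content"])
chain = prompt | get_glm() | StrOutputParser()  # prompt -> LLM -> plain string
print(chain.invoke({"query": "I cannot sleep before exams", "content": "...retrieved passages..."}))
```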