streamlit-rag-app.py
import json
import os

import streamlit as st
from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.schema import Document
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
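# Usage: streamlit run streamlit-rag-app.py
# Expects a .env file containing OPENAI_API_KEY and a prebuilt FAISS index in
# ./vector-store (see the hypothetical build_vector_store sketch below).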
# Load environment variables and read the OpenAI API key.
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    st.error("OPENAI_API_KEY is not set. Please add it to your .env file.")
    st.stop()

# Load the BPL metadata used to look up abstracts for retrieved sources.
# This file can take a few minutes to load.
with open("bpl_data.json") as f:
    bpl = json.load(f)
# Initialize session state variables.
if "vector_store" not in st.session_state:
    st.session_state.vector_store = None
if "qa_chain" not in st.session_state:
    st.session_state.qa_chain = None
def setup_qa_chain(vector_store):
    """Set up a RetrievalQA chain over the top-3 retrieved documents."""
    retriever = vector_store.as_retriever(search_kwargs={"k": 3})
    llm = ChatOpenAI(model="gpt-3.5-turbo", openai_api_key=OPENAI_API_KEY)
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm, retriever=retriever, return_source_documents=True
    )
    return qa_chain
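@st.cache_resource
def load_embeddings():
    """Load the sentence-transformer embedding model once per session.

    st.cache_resource keeps the model across Streamlit reruns, since the
    whole script re-executes on every user interaction.
    """
    return HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )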
def setup_custom_chain(vector_store, query):
    """Retrieve the top-3 documents for a query without running an LLM chain."""
    retriever = vector_store.as_retriever(search_kwargs={"k": 3})
    docs = retriever.invoke(query)
    return docs
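def build_vector_store(embeddings, path="vector-store"):
    """Hypothetical sketch of how the prebuilt FAISS index could be created.

    Assumes each record in bpl_data.json carries its text in
    attributes.abstract_tsi and its Digital Commonwealth id in "id", which is
    stored as the document's "source" metadata to match the lookup in main().
    """
    docs = [
        Document(
            page_content=record["attributes"].get("abstract_tsi") or "",
            metadata={"source": record["id"]},
        )
        for record in bpl["Data"]
    ]
    store = FAISS.from_documents(docs, embeddings)
    store.save_local(path)
    return store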
def main():
    # Set page title and header.
    st.set_page_config(page_title="LibRAG", page_icon="📚")
    st.title("Boston Public Library Database 📚")

    embeddings = load_embeddings()

    # Load the prebuilt FAISS index and set up the QA chain.
    try:
        st.session_state.vector_store = FAISS.load_local(
            "vector-store", embeddings, allow_dangerous_deserialization=True
        )
        st.session_state.qa_chain = setup_qa_chain(st.session_state.vector_store)
    except Exception as e:
        st.error(f"Error loading the knowledge base: {e}")
    # Query input and processing.
    st.header("Ask a Question")
    query = st.text_input("Enter your question about BPL's database")

    if query:
        # Check that the vector store and QA chain are initialized.
        if st.session_state.qa_chain is None:
            st.warning("The knowledge base has not been loaded yet.")
        else:
            # Run the query.
            try:
                response = st.session_state.qa_chain.invoke({"query": query})

                # Display the answer.
                st.subheader("Answer")
                st.write(response["result"])

                # Display the sources, matching each retrieved document's id
                # against the BPL metadata to recover its abstract.
                st.subheader("Sources")
                sources = response["source_documents"]
                for i, doc in enumerate(sources, 1):
                    source = doc.metadata["source"]
                    abstract = None
                    for record in bpl["Data"]:
                        if record["id"] == source:
                            abstract = record["attributes"]["abstract_tsi"]
                            break
                    with st.expander(f"Source {i}"):
                        st.write(f"**Content:** {abstract}")
                        st.write(f"**URL:** https://www.digitalcommonwealth.org/search/{source}")
            except Exception as e:
                st.error(f"An error occurred: {e}")
if __name__ == "__main__":
    main()