-
Notifications
You must be signed in to change notification settings - Fork 1
/
app.py
149 lines (127 loc) · 5.34 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""Streamlit app: chat over uploaded PDFs using Gemini embeddings + generation."""

import os
import textwrap

import numpy as np
import pandas as pd
import streamlit as st
import google.generativeai as genai
from unstructured.chunking.title import chunk_by_title
from unstructured.partition.pdf import partition_pdf

# Configure the Gemini client from the environment.
# NOTE: raises KeyError at import time if GEMINI_API_KEY is not set.
genai.configure(api_key=os.environ["GEMINI_API_KEY"])

# Template for condensing a follow-up question into a standalone question.
# NOTE(review): not referenced anywhere in the visible code — kept for
# backward compatibility; confirm before removing.
custom_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
# --- PDF ingestion ----------------------------------------------------------
def extract_data(docs):
    """Partition and chunk uploaded PDF files into title-based sections.

    Args:
        docs: list of filesystem paths to PDF files.

    Returns:
        A flat list of chunks across *all* files. (The previous
        implementation only processed ``docs[0]`` and silently ignored
        every other uploaded file.) Returns [] for an empty ``docs``.
    """
    chunks = []
    for path in docs:
        # "hi_res" strategy uses layout detection for better element boundaries.
        elements = partition_pdf(path, strategy="hi_res")
        chunks.extend(
            chunk_by_title(
                elements,
                max_characters=500,
                new_after_n_chars=2500,
                multipage_sections=True,
            )
        )
    return chunks
# --- Embedding helper -------------------------------------------------------
def embed_fn(text):
    """Return the retrieval-document embedding vector for ``text``."""
    response = genai.embed_content(
        model="models/embedding-001",
        content=str(text),
        task_type="retrieval_document",
    )
    return response["embedding"]
def find_best_passage(query, text_embeddings):
    """Return the stored text chunk most similar to ``query``.

    Similarity is the dot product between the query embedding and each
    pre-computed document embedding. (The previous docstring called these
    "distances", which inverted the meaning — argmax of a distance would
    pick the *worst* passage; argmax of a dot-product similarity is correct.)

    Args:
        query: user question string.
        text_embeddings: dataframe with 'text' and 'embeddings' columns.

    Returns:
        The 'text' value of the highest-scoring row.
    """
    query_embedding = genai.embed_content(
        model="models/embedding-001",
        content=query,
        task_type="retrieval_query",
    )["embedding"]
    # One matrix-vector product scores every stored chunk at once.
    similarities = np.dot(text_embeddings["embeddings"].tolist(), query_embedding)
    best_index = int(np.argmax(similarities))
    return text_embeddings.iloc[best_index]["text"]
def make_prompt(query, relevant_passage):
    """Build the grounded-QA prompt sent to the generative model.

    Args:
        query: the user's question, inserted verbatim.
        relevant_passage: retrieved reference text to ground the answer.

    Returns:
        The fully formatted prompt string.
    """
    # Fix: the original prompt contained the typo "converstional".
    template = textwrap.dedent("""\
        You are a helpful and informative bot that answers questions using text from the reference passage included below. \
        Be sure to respond in a complete sentence, being comprehensive, including all relevant background information. \
        However, you are talking to a non-technical audience, so be sure to break down complicated concepts and \
        strike a friendly and conversational tone. \
        If the passage is irrelevant to the answer, you may ignore it.
        QUESTION: '{query}'
        PASSAGE: '{relevant_passage}'
        ANSWER:
        """)
    return template.format(query=query, relevant_passage=relevant_passage)
def handle_question(question, dataframe, model_name="gemini-1.5-pro-latest"):
    """Answer ``question`` grounded in the best-matching chunk of ``dataframe``.

    Args:
        question: user question string.
        dataframe: dataframe with 'text' and 'embeddings' columns.
        model_name: Gemini model to generate with (optional; defaults to the
            previously hard-coded 'gemini-1.5-pro-latest').

    Returns:
        The model response object (the caller reads its ``.text`` attribute).
    """
    relevant_passage = find_best_passage(question, dataframe)
    prompt = make_prompt(question, relevant_passage)
    # Debug print of the raw response object removed — it dumped the whole
    # response repr to stdout on every question.
    model = genai.GenerativeModel(model_name)
    return model.generate_content(prompt)
def main():
    """Render the Streamlit UI: upload/process PDFs in the sidebar, chat in the body."""
    st.set_page_config(page_title="Chat with multiple PDFs", page_icon=":books:")
    # Inline CSS for the two chat-bubble styles rendered below.
    st.markdown(
        """
        <style>
        .user-message {
            background-color: #f0f0f0;
            color: black;
            padding: 10px;
            border-radius: 5px;
            margin-bottom: 10px;
        }
        .bot-message {
            background-color: #121212;
            color: white;
            padding: 10px;
            border-radius: 5px;
            margin-bottom: 10px;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )

    # Session state survives Streamlit's top-to-bottom reruns.
    if "conversation" not in st.session_state:
        st.session_state.conversation = []
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None

    st.header("Chat with multiple PDFs :books:")
    question = st.text_input("Ask question from your document:")

    # Only answer once documents have been processed into embeddings.
    if question and "embeddings" in st.session_state:
        answer = handle_question(question, st.session_state.embeddings).text
        st.session_state.conversation.append({"role": "user", "content": question})
        st.session_state.conversation.append({"role": "bot", "content": answer})

    st.markdown("## Chat History:")
    for message in st.session_state.conversation:
        if message["role"] == "user":
            st.markdown(f'<div class="user-message">{message["content"]}</div>', unsafe_allow_html=True)
        elif message["role"] == "bot":
            st.markdown(f'<div class="bot-message">{message["content"]}</div>', unsafe_allow_html=True)

    with st.sidebar:
        st.subheader("Your documents")
        docs = st.file_uploader(
            "Upload your PDF here and click on 'Process'",
            accept_multiple_files=True,
            type=["pdf"],
        )
        if docs:
            # Persist uploads to disk so unstructured can partition by path.
            os.makedirs("./uploads/", exist_ok=True)
            saved_files = []
            for doc in docs:
                file_path = "./uploads/" + doc.name
                with open(file_path, "wb") as f:
                    f.write(doc.getbuffer())
                saved_files.append(file_path)
            st.session_state.uploaded_files = saved_files

        if st.button("Process"):
            # BUG FIX: clicking "Process" before uploading anything used to
            # raise AttributeError on st.session_state.uploaded_files.
            if "uploaded_files" not in st.session_state:
                st.warning("Please upload at least one PDF first.")
            else:
                with st.spinner("Processing"):
                    chunks = extract_data(st.session_state.uploaded_files)
                    st.success("Files have been Chunked.")
                    df = pd.DataFrame(
                        {
                            "text": chunks,
                            "embeddings": [embed_fn(chunk) for chunk in chunks],
                        }
                    )
                    st.success("Embeddings have been successfully generated.")
                    st.session_state.embeddings = df


if __name__ == "__main__":
    main()