frontend.py
import streamlit as st
import os
# from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma

# load_dotenv()
# openai_api_key = os.getenv('OPENAI_API_KEY')


# loading a private local file (PDF, DOCX or TXT) as LangChain Documents
def load_document(file):
    name, extension = os.path.splitext(file)

    if extension == '.pdf':
        from langchain_community.document_loaders import PyPDFLoader
        print(f'Loading {file}')
        loader = PyPDFLoader(file)
    elif extension == '.docx':
        from langchain_community.document_loaders import Docx2txtLoader
        print(f'Loading {file}')
        loader = Docx2txtLoader(file)
    elif extension == '.txt':
        from langchain_community.document_loaders import TextLoader
        print(f'Loading {file}')
        loader = TextLoader(file)
    else:
        print('Document format is not supported!')
        return None

    data = loader.load()
    return data
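
# split the loaded documents into smaller, overlapping chunks so that only the
# most relevant passages need to be embedded and retrieved for a question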
def chunk_data(data, chunk_size=256, chunk_overlap=20):
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    chunks = text_splitter.split_documents(data)
    return chunks
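
# embed the chunks with OpenAI embeddings and keep them in an in-memory Chroma
# vector store (requires OPENAI_API_KEY to be set in the environment)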
def create_embeddings(chunks):
    embeddings = OpenAIEmbeddings()
    vector_store = Chroma.from_documents(chunks, embeddings)
    return vector_store
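
# fetch the k most similar chunks for the question and let the LLM answer from
# them, using a RetrievalQA chain with the 'stuff' chain type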
def ask_and_get_answer(vector_store, q, k=3):
    from langchain.chains import RetrievalQA
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model='gpt-3.5-turbo', temperature=1)
    retriever = vector_store.as_retriever(search_type='similarity', search_kwargs={'k': k})
    chain = RetrievalQA.from_chain_type(llm=llm, chain_type='stuff', retriever=retriever)

    answer = chain.invoke(q)
    return answer

# calculate the token count and estimated cost of embedding the chunks
# (text-embedding-ada-002 pricing: $0.0004 per 1,000 tokens)
def calculate_embedding_cost(texts):
    import tiktoken
    enc = tiktoken.encoding_for_model('text-embedding-ada-002')
    total_tokens = sum(len(enc.encode(page.page_content)) for page in texts)
    return total_tokens, total_tokens / 1000 * 0.0004

def clear_history():
    if 'history' in st.session_state:
        del st.session_state['history']
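
# Streamlit app: the sidebar collects the OpenAI API key, the uploaded file and
# the chunking/retrieval settings; the main page takes a question, shows the
# LLM answer and keeps a running chat history in st.session_state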
if __name__ == '__main__':
    from dotenv import load_dotenv, find_dotenv
    load_dotenv(find_dotenv(), override=True)

    with st.sidebar:
        api_key = st.text_input('OpenAI API key:', type='password')
        if api_key:
            os.environ['OPENAI_API_KEY'] = api_key

        # widgets
        uploaded_file = st.file_uploader('Upload a file of type pdf, txt, or docx:', type=['pdf', 'docx', 'txt'])
        chunk_size = st.number_input('Chunk size:', min_value=100, max_value=2048, value=512, on_change=clear_history)
        k = st.number_input('k', min_value=1, max_value=20, value=3, on_change=clear_history)
        add_data = st.button('Add Data', on_click=clear_history)

        if uploaded_file and add_data:
            with st.spinner('Reading, chunking and embedding file'):
                # write the uploaded bytes to disk so the document loaders can read the file
                bytes_data = uploaded_file.read()
                file_name = os.path.join('./', uploaded_file.name)
                with open(file_name, 'wb') as f:
                    f.write(bytes_data)

                data = load_document(file_name)
                chunks = chunk_data(data, chunk_size=chunk_size)
                st.write(f'Chunk size: {chunk_size}, Chunks: {len(chunks)}')

                tokens, embedding_cost = calculate_embedding_cost(chunks)
                st.write(f'Embedding cost: ${embedding_cost:.4f}')

                vector_store = create_embeddings(chunks)
                st.session_state.vs = vector_store
                st.success('File uploaded, chunked, and embedded successfully')
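
    # main page: the vector store is kept in st.session_state so it survives
    # Streamlit reruns between questions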
    question = st.text_input('Ask a question regarding the content of your file:')
    if question:
        if 'vs' in st.session_state:
            vector_store = st.session_state.vs
            answer = ask_and_get_answer(vector_store, question, k)

            st.text_area('Question:', value=question)
            st.text_area('LLM Answer:', value=answer['result'])

            st.divider()

            # prepend the current Q&A to the running history kept in session state
            if 'history' not in st.session_state:
                st.session_state.history = ''
            value = f'Question: {question} \nAnswer: {answer["result"]}'
            st.session_state.history = f'{value} \n {"-" * 100} \n {st.session_state.history}'
            h = st.session_state.history
            st.text_area(label='Chat history', value=h, key='history', height=400)