# chat.py
import os
import re
from flask import Flask, render_template, request, jsonify
from werkzeug.utils import secure_filename
import chromadb
from typing import List
import google.generativeai as genai
from chromadb import Documents, EmbeddingFunction, Embeddings
import fitz
import datetime
import uuid
from pymongo import MongoClient
import io
# NOTE: replace the placeholder below with your actual Gemini API key,
# or export GEMINI_API_KEY in your shell instead of hardcoding it here.
os.environ["GEMINI_API_KEY"] = "YOUR_API_KEY"
# Initialize Flask app
app = Flask(__name__)
app.config['ALLOWED_EXTENSIONS'] = {'pdf', 'txt', 'docx'}
# Configure MongoDB connection
mongo_client = MongoClient("mongodb://localhost:27017/") # Update with your MongoDB URI
db = mongo_client['RAG'] # Replace with your MongoDB database name
file_collection = db['files']
chroma_collection_metadata = db['chroma_collection_metadata']
# Global variable to store the name of the latest collection
latest_collection_name = None
# Split text into chunks
def split_text(text: str):
    chunks = re.split(r'\n \n', text)
    return [chunk for chunk in chunks if chunk]
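
# Illustrative example (assumed input, not part of the original flow): the splitter
# breaks the text on a "\n \n" separator and drops empty chunks, e.g.
#   split_text("First paragraph.\n \nSecond paragraph.")
#   -> ["First paragraph.", "Second paragraph."]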
# Custom embedding function with Gemini API
class GeminiEmbeddingFunction(EmbeddingFunction):
    def __call__(self, input: Documents) -> Embeddings:
        gemini_api_key = os.getenv("GEMINI_API_KEY")
        if not gemini_api_key:
            raise ValueError("Gemini API Key not provided. Set GEMINI_API_KEY as an environment variable.")
        genai.configure(api_key=gemini_api_key)
        model = "models/embedding-001"
        title = "Custom query"
        return genai.embed_content(
            model=model,
            content=input,
            task_type="retrieval_document",
            title=title
        )["embedding"]
from bson import ObjectId

# Function to create and populate ChromaDB with chunks
def create_chroma_db(documents: List[str], name: str):
    # Store document chunks in MongoDB and get document IDs
    doc_ids = []
    for doc in documents:
        doc_id = file_collection.insert_one({"collection_name": name, "content": doc}).inserted_id
        doc_ids.append(doc_id)  # Store ObjectId directly

    # Initialize ChromaDB client (no `database_uri` argument used)
    chroma_client = chromadb.PersistentClient()
    chroma_db = chroma_client.create_collection(name=name, embedding_function=GeminiEmbeddingFunction())

    # Populate ChromaDB with document chunks from MongoDB
    for doc_id in doc_ids:
        # Look each chunk up again by its ObjectId
        content = file_collection.find_one({"_id": ObjectId(doc_id)})
        if content:  # Check if the document exists
            chroma_db.add(documents=[content["content"]], ids=[str(doc_id)])
        else:
            print(f"Warning: Document with ID {doc_id} not found in MongoDB.")

    # Save metadata for easy lookup
    chroma_collection_metadata.insert_one({"name": name, "doc_ids": [str(doc_id) for doc_id in doc_ids]})
    return chroma_db, name
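
# Illustrative usage (assumed, not executed here): index two chunks under a throwaway
# collection name and confirm they were stored.
#   collection, _ = create_chroma_db(["chunk one", "chunk two"], name="rag_example")
#   print(collection.count())  # -> 2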
# Function to load an existing ChromaDB collection
def load_chroma_collection(name: str):
    # Retrieve an existing ChromaDB collection by name
    chroma_client = chromadb.PersistentClient()
    return chroma_client.get_collection(name=name, embedding_function=GeminiEmbeddingFunction())
# Retrieve relevant passages from ChromaDB
def get_relevant_passage(query: str, db, n_results: int):
    passages = db.query(query_texts=[query], n_results=n_results)['documents']
    return passages[0] if passages else []
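
# Illustrative query (assumed collection and question): the n_results chunks most
# similar to the question come back as a list of strings.
#   chunks = get_relevant_passage("What is the refund policy?", collection, n_results=3)
#   print(chunks[:1])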
# Create RAG prompt
def make_rag_prompt(query: str, relevant_passage: str) -> str:
    escaped = relevant_passage.replace("'", "").replace('"', "").replace("\n", " ")
    prompt = ("""You are a helpful and informative bot that answers questions using text from the reference passage included below. \
Be sure to respond in a complete sentence, being comprehensive, including all relevant background information. \
However, you are talking to a non-technical audience, so be sure to break down complicated concepts and \
strike a friendly and conversational tone. \
If the passage is irrelevant to the answer, you may ignore it.
QUESTION: '{query}'
PASSAGE: '{relevant_passage}'
ANSWER:
""").format(query=query, relevant_passage=escaped)
    return prompt
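
# Illustrative prompt construction (assumed inputs): quotes and newlines in the passage
# are stripped before interpolation, so the model sees one clean block of text.
#   print(make_rag_prompt("What does the report cover?", 'It covers "Q3 results".\nAnd more.'))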
# Generate response using Gemini API
def generate_response(prompt):
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    if not gemini_api_key:
        raise ValueError("Gemini API Key not provided. Please provide GEMINI_API_KEY as an environment variable.")
    genai.configure(api_key=gemini_api_key)
    model = genai.GenerativeModel('gemini-pro')
    answer = model.generate_content(prompt)
    return answer.text
# Generate answer using RAG pipeline
def generate_answer(db, query):
    relevant_text_chunks = get_relevant_passage(query, db, n_results=3)
    combined_text = "".join(relevant_text_chunks) if relevant_text_chunks else "No relevant information found."
    prompt = make_rag_prompt(query, relevant_passage=combined_text)
    response = generate_response(prompt)
    return response
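
# End-to-end sketch of the RAG pipeline above (assumed collection name, illustrative only):
#   collection = load_chroma_collection("rag_20240101120000")
#   print(generate_answer(collection, "Summarize the uploaded document."))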
# Function to check allowed file extensions
def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in app.config['ALLOWED_EXTENSIONS']
# Function to extract text from PDF (or file-like object)
def extract_text_from_pdf(file_content):
    with fitz.open("pdf", file_content) as doc:
        pdf_text = ""
        for page in doc:
            pdf_text += page.get_text()
    return pdf_text
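
# Illustrative use with an on-disk file (assumed path), mirroring how the upload route
# passes an in-memory BytesIO stream:
#   with open("example.pdf", "rb") as f:
#       print(extract_text_from_pdf(io.BytesIO(f.read()))[:200])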
# Routes: upload a document, get a response from the chat
@app.route("/")
def index():
return render_template('chat.html')
# Route to handle document upload, splitting, and ChromaDB storage
@app.route("/upload", methods=["POST"])
def upload_document():
global latest_collection_name # Use global to update the latest collection name
if 'document' not in request.files:
return "No file part", 400
file = request.files['document']
if file.filename == '':
return "No selected file", 400
# Check if the file is allowed
if not allowed_file(file.filename):
return "Invalid file type", 400
filename = secure_filename(file.filename)
file_content = file.read() # Read file content into memory
file_metadata = {
"filename": filename,
"upload_date": datetime.datetime.now(),
"file_type": filename.split('.')[-1],
"file_content": file_content # Store binary data
}
file_metadata_id = file_collection.insert_one(file_metadata).inserted_id
# Determine file type and extract text accordingly
if filename.endswith('.pdf'):
pdf_text = extract_text_from_pdf(io.BytesIO(file_content))
elif filename.endswith('.txt'):
pdf_text = file_content.decode('utf-8')
elif filename.endswith('.docx'):
from docx import Document
doc = Document(io.BytesIO(file_content))
pdf_text = "\n".join([para.text for para in doc.paragraphs])
else:
return "Unsupported file type", 400
# Split the document into chunks
chunked_text = split_text(pdf_text)
# Generate a unique collection name for each upload
db_name = f"rag_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
db, name = create_chroma_db(documents=chunked_text, name=db_name)
# Update latest collection name in MongoDB
latest_collection_name = db_name
chroma_collection_metadata.update_one(
{"_id": file_metadata_id},
{"$set": {"chroma_collection_name": db_name}}
)
return f"Document uploaded, processed, and indexed with collection name: {db_name}", 200
# Route to handle chat messages using RAG-based response generation
@app.route("/get", methods=["POST"])
def chat():
global latest_collection_name # Access the latest collection name
user_message = request.form["msg"]
# Check if a collection is available
if not latest_collection_name:
return "No document uploaded yet.", 400
# Load the most recent ChromaDB collection
db = load_chroma_collection(name=latest_collection_name)
# Generate answer using RAG pipeline
response = generate_answer(db=db, query=user_message)
return response
if __name__ == "__main__":
    app.run(debug=True)
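
# Illustrative local test of the two routes (assumed filenames and messages), once the
# server is running on the default Flask port:
#   curl -F "document=@report.pdf" http://127.0.0.1:5000/upload
#   curl -d "msg=What is the report about?" http://127.0.0.1:5000/get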