Merge pull request #78 from tgourdel/main
add local-rag-pdf app
Showing 5 changed files with 331 additions and 0 deletions.
@@ -0,0 +1 @@

```markdown
# local-rag-deepseek-mongodb
```
@@ -0,0 +1,145 @@

```python
import os
import tempfile
import time
import re

import streamlit as st
from rag_module import ChatPDF

st.set_page_config(page_title="Local RAG with MongoDB and DeepSeek")


def display_messages():
    """Display the chat history using Streamlit's native chat interface."""
    st.subheader("Chat History")
    for message in st.session_state["messages"]:
        with st.chat_message(message["role"]):
            if message["role"] == "assistant":
                # Process the content to hide <think>...</think> blocks
                content = message["content"]
                # Use regex to find all <think>...</think> blocks
                think_blocks = re.findall(r'<think>(.*?)</think>', content, re.DOTALL)
                # Remove all <think>...</think> blocks from the visible content
                visible_content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()

                # Display the visible content
                st.markdown(visible_content)

                # For each think block, add an expander to show the hidden content
                for think in think_blocks:
                    with st.expander("Show Hidden Reasoning", expanded=False):
                        st.markdown(think)
            else:
                # For user and system messages, display normally
                st.markdown(message["content"])


def process_query():
    """Process the user input and generate an assistant response."""
    user_input = st.session_state.get("user_input", "").strip()
    if user_input:
        # Add the user message to the chat history
        st.session_state["messages"].append({"role": "user", "content": user_input})

        with st.chat_message("user"):
            st.markdown(user_input)

        # Prepare conversation history for context (excluding system messages, if any)
        conversation_history = [
            msg["content"] for msg in st.session_state["messages"] if msg["role"] != "system"
        ]

        # Display the assistant response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                try:
                    # Generate the assistant response with context
                    agent_text = st.session_state["assistant"].query_with_context(
                        user_input,
                        conversation_history=conversation_history,
                        k=st.session_state["retrieval_k"],
                        score_threshold=st.session_state["retrieval_threshold"],
                    )
                except ValueError as e:
                    agent_text = str(e)

            st.markdown(agent_text)

        # Add the assistant response to the chat history
        st.session_state["messages"].append({"role": "assistant", "content": agent_text})

        # Clear the input box
        st.session_state["user_input"] = ""


def upload_and_index_file():
    """Handle file upload and ingestion."""
    st.session_state["assistant"].reset_retriever()
    st.session_state["messages"] = []
    st.session_state["user_input"] = ""

    for file in st.session_state["file_uploader"]:
        # Write the upload to a temporary path so the PDF loader can read it from disk
        with tempfile.NamedTemporaryFile(delete=False) as tf:
            tf.write(file.getbuffer())
            file_path = tf.name

        with st.session_state["ingestion_spinner"], st.spinner(f"Uploading and indexing {file.name}..."):
            t0 = time.time()
            st.session_state["assistant"].upload_and_index_pdf(file_path)
            t1 = time.time()

        st.session_state["messages"].append(
            {"role": "system", "content": f"Uploaded and indexed {file.name} in {t1 - t0:.2f} seconds"}
        )
        os.remove(file_path)


def initialize_session_state():
    """Initialize session state variables."""
    if "messages" not in st.session_state:
        st.session_state["messages"] = []
    if "assistant" not in st.session_state:
        st.session_state["assistant"] = ChatPDF()
    if "ingestion_spinner" not in st.session_state:
        st.session_state["ingestion_spinner"] = st.empty()
    if "retrieval_k" not in st.session_state:
        st.session_state["retrieval_k"] = 5  # Default value
    if "retrieval_threshold" not in st.session_state:
        st.session_state["retrieval_threshold"] = 0.2  # Default value
    if "user_input" not in st.session_state:
        st.session_state["user_input"] = ""


def page():
    """Main app page layout."""
    initialize_session_state()

    st.header("Local RAG with MongoDB and DeepSeek")

    st.subheader("Upload a Document")
    st.file_uploader(
        "Upload a PDF document",
        type=["pdf"],
        key="file_uploader",
        on_change=upload_and_index_file,
        label_visibility="collapsed",
        accept_multiple_files=True,
    )

    # Display messages and text input
    display_messages()

    # Accept user input via the chat input widget
    prompt = st.chat_input("Type your message here...")
    if prompt:
        st.session_state["user_input"] = prompt
        process_query()

    # Clear chat
    if st.button("Clear Chat"):
        st.session_state["messages"] = []
        st.session_state["assistant"].reset_retriever()
        st.session_state["user_input"] = ""


if __name__ == "__main__":
    page()
```
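The `<think>` handling in `display_messages` is plain string processing on the model output: DeepSeek-R1 emits its chain of thought between `<think>` tags, and the app moves that text into a collapsed expander. A minimal standalone sketch of the same transformation (the sample string is invented for illustration):

```python
import re

# Invented sample in DeepSeek-R1's output style: reasoning inside <think> tags
raw = "<think>The user greets me; a short reply suffices.</think>Hello! How can I help?"

# Same two regexes as display_messages above
think_blocks = re.findall(r"<think>(.*?)</think>", raw, re.DOTALL)
visible = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL).strip()

print(visible)       # Hello! How can I help?
print(think_blocks)  # ['The user greets me; a short reply suffices.']
```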
@@ -0,0 +1,5 @@

```yaml
llm_model: "deepseek-r1:8b"
embedding_model: "nomic-embed-text"
mongo_connection_str: "mongodb://localhost:27017/?directConnection=true"
database_name: "knowledge_base"
collection_name: "documents"
```
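A note on the embedding model: `rag_module.py` below hard-codes `dimensions=768` when it creates the vector search index, which matches the 768-dimensional output of `nomic-embed-text`. If you swap `embedding_model` in this config, a quick probe like this sketch (assuming Ollama is running and the model is pulled) tells you the dimension to use instead:

```python
from langchain_ollama import OllamaEmbeddings

# Embed a throwaway string and inspect the vector length; this is the value
# to pass as `dimensions` when creating the vector search index.
emb = OllamaEmbeddings(model="nomic-embed-text")
print(len(emb.embed_query("dimension probe")))  # 768 for nomic-embed-text
```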
@@ -0,0 +1,173 @@

```python
from langchain_core.globals import set_verbose, set_debug
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain.schema.output_parser import StrOutputParser
from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
from pymongo import MongoClient
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.vectorstores.utils import filter_complex_metadata
from langchain_core.prompts import ChatPromptTemplate
import logging
import yaml

# Enable verbose debugging
set_debug(True)
set_verbose(True)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def load_config(config_file: str = "config.yaml"):
    """Load configuration from a YAML file."""
    with open(config_file, "r") as file:
        return yaml.safe_load(file)


class ChatPDF:
    """A class for PDF ingestion and question answering using RAG, with detailed debugging logs."""

    def __init__(self, config_file: str = "config.yaml"):
        """Initialize the ChatPDF instance using configuration from a YAML file."""
        config = load_config(config_file)

        # Read values from config
        llm_model = config["llm_model"]
        embedding_model = config["embedding_model"]
        mongo_connection_str = config["mongo_connection_str"]
        database_name = config["database_name"]
        collection_name = config["collection_name"]

        self.model = ChatOllama(model=llm_model)
        self.embeddings = OllamaEmbeddings(model=embedding_model)
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=100)
        self.prompt = ChatPromptTemplate.from_template(
            """
            You are a helpful assistant answering questions based on the uploaded document and the conversation.
            Conversation History:
            {conversation_history}
            Context from Documents:
            {context}
            Question:
            {question}
            Provide a concise, accurate answer (preferably within three sentences), ensuring it directly addresses the question.
            """
        )

        # Set up the MongoDB connection
        self.client = MongoClient(mongo_connection_str)
        self.collection = self.client[database_name][collection_name]

        # Verbose connection check
        doc_count = self.collection.count_documents({})
        logger.info(f"MongoDB Connection Established - Document Count: {doc_count}")

        # Initialize the vector store with MongoDB Atlas
        self.vector_store = MongoDBAtlasVectorSearch(
            collection=self.collection,
            embedding=self.embeddings,
            index_name="vector_index",
            relevance_score_fn="cosine",
        )

        # Create the vector search index on the collection.
        # Adjust dimensions to match your embedding model (nomic-embed-text -> 768).
        # Creation fails if the index already exists (e.g. on app restart), so guard it.
        try:
            self.vector_store.create_vector_search_index(dimensions=768)
        except Exception as e:
            logger.info(f"Skipping vector index creation (it may already exist): {e}")

        logger.info("Vector Store Initialized")

        self.retriever = None

    def upload_and_index_pdf(self, pdf_file_path: str):
        """Upload and index a PDF file, chunk its contents, and store the embeddings in MongoDB Atlas."""
        logger.info(f"Starting ingestion for file: {pdf_file_path}")
        docs = PyPDFLoader(file_path=pdf_file_path).load()

        logger.info(f"Loaded {len(docs)} pages from {pdf_file_path}")

        chunks = self.text_splitter.split_documents(docs)
        logger.info(f"Split into {len(chunks)} document chunks")

        # Optional: log a few sample chunks for verification
        for i, chunk in enumerate(chunks[:3]):
            logger.debug(f"Chunk {i+1} Content: {chunk.page_content[:200]}...")

        chunks = filter_complex_metadata(chunks)

        # Add documents to the vector store (embeddings are computed on insert)
        self.vector_store.add_documents(documents=chunks)
        logger.info("Document embeddings stored successfully in MongoDB Atlas.")

    def query_with_context(self, query: str, conversation_history: list = None, k: int = 5, score_threshold: float = 0.2):
        """
        Answer a query using the RAG pipeline with verbose debugging and conversation history.

        Parameters:
        - query (str): The user's question.
        - conversation_history (list): List of previous messages in the conversation.
        - k (int): Number of retrieved documents.
        - score_threshold (float): Similarity score threshold for retrieval.

        Returns:
        - str: The assistant's response.
        """
        if not self.vector_store:
            raise ValueError("No vector store found. Please ingest a document first.")

        if not self.retriever:
            self.retriever = self.vector_store.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={"k": k, "score_threshold": score_threshold},
            )

        # Generate and log the query embedding
        query_embedding = self.embeddings.embed_query(query)
        logger.info(f"User Query: {query}")
        logger.debug(f"Query Embedding (sample values): {query_embedding[:10]}... [Total Length: {len(query_embedding)}]")

        logger.info(f"Retrieving context for query: {query}")
        retrieved_docs = self.retriever.invoke(query)

        if not retrieved_docs:
            logger.warning("No relevant documents retrieved.")
            return "No relevant context found in the document to answer your question."

        logger.info(f"Retrieved {len(retrieved_docs)} document(s)")
        for i, doc in enumerate(retrieved_docs):
            logger.debug(f"Document {i+1}: {doc.page_content[:200]}...")

        # Format the input for the LLM, including conversation history
        formatted_input = {
            "conversation_history": "\n".join(conversation_history) if conversation_history else "",
            "context": "\n\n".join(doc.page_content for doc in retrieved_docs),
            "question": query,
        }

        # Build the RAG chain
        chain = (
            RunnablePassthrough()   # Passes the input dict through unchanged
            | self.prompt           # Formats the input for the LLM
            | self.model            # Queries the LLM
            | StrOutputParser()     # Parses the LLM's output to a string
        )

        logger.info("Generating response using the LLM.")
        response = chain.invoke(formatted_input)
        logger.debug(f"LLM Response: {response}")
        return response

    def reset_retriever(self):
        """Reset the retriever and optionally clear the vector store or other state."""
        logger.info("Resetting retriever and clearing state.")
        self.retriever = None
```
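For reference, `create_vector_search_index(dimensions=768)` builds the Atlas Vector Search index programmatically; creating it by hand would use a definition roughly like the one below. The `embedding` path is assumed to be langchain-mongodb's default vector field name, and the exact shape follows the documented Atlas index format, so treat this as a sketch rather than the library's literal output:

```python
# Approximate hand-written equivalent of create_vector_search_index(dimensions=768).
# "vector_index" is the index_name passed in __init__; "embedding" is assumed to be
# the library's default vector field name.
vector_index_definition = {
    "name": "vector_index",
    "type": "vectorSearch",
    "definition": {
        "fields": [
            {
                "type": "vector",
                "path": "embedding",
                "numDimensions": 768,
                "similarity": "cosine",
            }
        ]
    },
}
```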
@@ -0,0 +1,7 @@

```text
streamlit
langchain
langchain_ollama
langchain_community
langchain-mongodb
pymongo
pypdf
```
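To try the app, the usual steps would be: `pip install -r requirements.txt`, pull the two Ollama models named in `config.yaml` (`ollama pull deepseek-r1:8b` and `ollama pull nomic-embed-text`), start a local MongoDB deployment that supports Atlas Vector Search (the connection string in `config.yaml` points at `localhost:27017`), and launch the Streamlit script with `streamlit run`; the script's filename is not shown in this diff.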