Merge pull request #78 from tgourdel/main
add local-rag-pdf app
RichmondAlake authored Jan 31, 2025
2 parents 9d9429f + 870942d commit ddf02d0
Showing 5 changed files with 331 additions and 0 deletions.
1 change: 1 addition & 0 deletions apps/local-rag-pdf/README.md
@@ -0,0 +1 @@
# local-rag-deepseek-mongodb
145 changes: 145 additions & 0 deletions apps/local-rag-pdf/app.py
@@ -0,0 +1,145 @@
import os
import tempfile
import time
import re
import streamlit as st
from rag_module import ChatPDF

st.set_page_config(page_title="Local RAG with MongoDB and DeepSeek")


def display_messages():
    """Display the chat history using Streamlit's native chat interface."""
    st.subheader("Chat History")
    for message in st.session_state["messages"]:
        with st.chat_message(message["role"]):
            if message["role"] == "assistant":
                # Process the content to hide <think>...</think> blocks
                content = message["content"]
                # Use regex to find all <think>...</think> blocks
                think_blocks = re.findall(r'<think>(.*?)</think>', content, re.DOTALL)
                # Remove all <think>...</think> blocks from the visible content
                visible_content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()

                # Display the visible content
                st.markdown(visible_content)

                # For each think block, add an expander to show the hidden content
                for think in think_blocks:
                    with st.expander("Show Hidden Reasoning", expanded=False):
                        st.markdown(think)
            else:
                # For user and system messages, display normally
                st.markdown(message["content"])


def process_query():
    """Process the user input and generate an assistant response."""
    user_input = st.session_state.get("user_input", "").strip()
    if user_input:
        # Add user message to chat history
        st.session_state["messages"].append({"role": "user", "content": user_input})

        with st.chat_message("user"):
            st.markdown(user_input)

        # Prepare conversation history for context (excluding system messages if any)
        conversation_history = [
            msg["content"] for msg in st.session_state["messages"] if msg["role"] != "system"
        ]

        # Display assistant response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                try:
                    # Generate the assistant response with context
                    agent_text = st.session_state["assistant"].query_with_context(
                        user_input,
                        conversation_history=conversation_history,
                        k=st.session_state["retrieval_k"],
                        score_threshold=st.session_state["retrieval_threshold"],
                    )
                except ValueError as e:
                    agent_text = str(e)

            st.markdown(agent_text)

        # Add assistant response to chat history
        st.session_state["messages"].append({"role": "assistant", "content": agent_text})

        # Clear the input box
        st.session_state["user_input"] = ""


def upload_and_index_file():
    """Handle file upload and ingestion."""
    st.session_state["assistant"].reset_retriever()
    st.session_state["messages"] = []
    st.session_state["user_input"] = ""

    for file in st.session_state["file_uploader"]:
        with tempfile.NamedTemporaryFile(delete=False) as tf:
            tf.write(file.getbuffer())
            file_path = tf.name

        with st.session_state["ingestion_spinner"], st.spinner(f"Uploading and indexing {file.name}..."):
            t0 = time.time()
            st.session_state["assistant"].upload_and_index_pdf(file_path)
            t1 = time.time()

        st.session_state["messages"].append(
            {"role": "system", "content": f"Uploaded and indexed {file.name} in {t1 - t0:.2f} seconds"}
        )
        os.remove(file_path)


def initialize_session_state():
    """Initialize session state variables."""
    if "messages" not in st.session_state:
        st.session_state["messages"] = []
    if "assistant" not in st.session_state:
        st.session_state["assistant"] = ChatPDF()
    if "ingestion_spinner" not in st.session_state:
        st.session_state["ingestion_spinner"] = st.empty()
    if "retrieval_k" not in st.session_state:
        st.session_state["retrieval_k"] = 5  # Default value
    if "retrieval_threshold" not in st.session_state:
        st.session_state["retrieval_threshold"] = 0.2  # Default value
    if "user_input" not in st.session_state:
        st.session_state["user_input"] = ""


def page():
    """Main app page layout."""
    initialize_session_state()

    st.header("Local RAG with MongoDB and DeepSeek")

    st.subheader("Upload a Document")
    st.file_uploader(
        "Upload a PDF document",
        type=["pdf"],
        key="file_uploader",
        on_change=upload_and_index_file,
        label_visibility="collapsed",
        accept_multiple_files=True,
    )

    # Display messages and text input
    display_messages()

    # Accept user input using the new chat input
    prompt = st.chat_input("Type your message here...")
    if prompt:
        st.session_state["user_input"] = prompt
        process_query()

    # Clear chat
    if st.button("Clear Chat"):
        st.session_state["messages"] = []
        st.session_state["assistant"].reset_retriever()
        st.session_state["user_input"] = ""


if __name__ == "__main__":
    page()
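
For reference, a minimal standalone sketch of the `<think>`-block filtering used in `display_messages` above. DeepSeek-R1 emits its reasoning inside `<think>...</think>` tags, which the app hides behind an expander; the sample response string here is invented for illustration.

```python
import re

# Hypothetical assistant output containing a hidden reasoning block.
sample = "<think>The user asks about chunk size.</think>Each PDF is split into 1024-character chunks."

# Collect the hidden reasoning and strip it from the visible answer,
# mirroring the two regexes in display_messages.
think_blocks = re.findall(r"<think>(.*?)</think>", sample, re.DOTALL)
visible = re.sub(r"<think>.*?</think>", "", sample, flags=re.DOTALL).strip()

print(visible)       # Each PDF is split into 1024-character chunks.
print(think_blocks)  # ['The user asks about chunk size.']
```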
5 changes: 5 additions & 0 deletions apps/local-rag-pdf/config.yaml
@@ -0,0 +1,5 @@
llm_model: "deepseek-r1:8b"
embedding_model: "nomic-embed-text"
mongo_connection_str: "mongodb://localhost:27017/?directConnection=true"
database_name: "knowledge_base"
collection_name: "documents"
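
Before launching the app, a quick sanity check can confirm that the connection string above points at a reachable MongoDB deployment. A minimal sketch, assuming this config.yaml sits in the working directory and the local MongoDB instance is already running:

```python
import yaml
from pymongo import MongoClient

# Load the same config.yaml shown above.
with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Ping the deployment; this fails fast if the connection string is wrong,
# before the app tries to create the vector search index.
client = MongoClient(config["mongo_connection_str"])
client.admin.command("ping")
print(f"Connected; will use {config['database_name']}.{config['collection_name']}")
```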
173 changes: 173 additions & 0 deletions apps/local-rag-pdf/rag_module.py
@@ -0,0 +1,173 @@
from langchain_core.globals import set_verbose, set_debug
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain.schema.output_parser import StrOutputParser
from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
from pymongo import MongoClient
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.vectorstores.utils import filter_complex_metadata
from langchain_core.prompts import ChatPromptTemplate
import logging
import yaml


# Enable verbose debugging
set_debug(True)
set_verbose(True)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def load_config(config_file: str = "config.yaml"):
    """Load configuration from a YAML file."""
    with open(config_file, "r") as file:
        return yaml.safe_load(file)


class ChatPDF:
    """A class designed for PDF ingestion and question answering using RAG with detailed debugging logs."""

    def __init__(self, config_file: str = "config.yaml"):
        """
        Initialize the ChatPDF instance using configuration from a YAML file.
        """
        config = load_config(config_file)

        # Read values from config
        llm_model = config["llm_model"]
        embedding_model = config["embedding_model"]
        mongo_connection_str = config["mongo_connection_str"]
        database_name = config["database_name"]
        collection_name = config["collection_name"]

        self.model = ChatOllama(model=llm_model)
        self.embeddings = OllamaEmbeddings(model=embedding_model)
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=100)
        self.prompt = ChatPromptTemplate.from_template(
            """
            You are a helpful assistant answering questions based on the uploaded document and the conversation.
            Conversation History:
            {conversation_history}
            Context from Documents:
            {context}
            Question:
            {question}
            Provide a concise, accurate answer (preferably within three sentences), ensuring it directly addresses the question.
            """
        )

        # Setup MongoDB connection
        self.client = MongoClient(mongo_connection_str)
        self.collection = self.client[database_name][collection_name]

        # Verbose connection check
        doc_count = self.collection.count_documents({})
        logger.info(f"MongoDB Connection Established - Document Count: {doc_count}")

        # Initialize the vector store with MongoDB Atlas
        self.vector_store = MongoDBAtlasVectorSearch(
            collection=self.collection,
            embedding=self.embeddings,
            index_name="vector_index",
            relevance_score_fn="cosine",
        )

        # Create vector search index on the collection
        # Adjust dimensions based on your embedding model
        self.vector_store.create_vector_search_index(dimensions=768)

        logger.info("Vector Store Initialized")

        self.retriever = None

    def upload_and_index_pdf(self, pdf_file_path: str):
        """
        Upload and index a PDF file, chunk its contents, and store the embeddings in MongoDB Atlas.
        """
        logger.info(f"Starting ingestion for file: {pdf_file_path}")
        docs = PyPDFLoader(file_path=pdf_file_path).load()

        logger.info(f"Loaded {len(docs)} pages from {pdf_file_path}")

        chunks = self.text_splitter.split_documents(docs)
        logger.info(f"Split into {len(chunks)} document chunks")

        # Optional: Log some sample chunks for verification
        for i, chunk in enumerate(chunks[:3]):
            logger.debug(f"Chunk {i+1} Content: {chunk.page_content[:200]}...")

        chunks = filter_complex_metadata(chunks)

        # Add documents to vector store and check embeddings
        self.vector_store.add_documents(documents=chunks)
        logger.info("Document embeddings stored successfully in MongoDB Atlas.")

    def query_with_context(self, query: str, conversation_history: list = None, k: int = 5, score_threshold: float = 0.2):
        """
        Answer a query using the RAG pipeline with verbose debugging and conversation history.
        Parameters:
        - query (str): The user's question.
        - conversation_history (list): List of previous messages in the conversation.
        - k (int): Number of retrieved documents.
        - score_threshold (float): Similarity score threshold for retrieval.
        Returns:
        - str: The assistant's response.
        """
        if not self.vector_store:
            raise ValueError("No vector store found. Please ingest a document first.")

        if not self.retriever:
            self.retriever = self.vector_store.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={"k": k, "score_threshold": score_threshold},
            )

        # Generate and log query embeddings
        query_embedding = self.embeddings.embed_query(query)
        logger.info(f"User Query: {query}")
        logger.debug(f"Query Embedding (sample values): {query_embedding[:10]}... [Total Length: {len(query_embedding)}]")

        logger.info(f"Retrieving context for query: {query}")
        retrieved_docs = self.retriever.invoke(query)

        if not retrieved_docs:
            logger.warning("No relevant documents retrieved.")
            return "No relevant context found in the document to answer your question."

        logger.info(f"Retrieved {len(retrieved_docs)} document(s)")
        for i, doc in enumerate(retrieved_docs):
            logger.debug(f"Document {i+1}: {doc.page_content[:200]}...")

        # Format the input for the LLM, including conversation history
        formatted_input = {
            "conversation_history": "\n".join(conversation_history) if conversation_history else "",
            "context": "\n\n".join(doc.page_content for doc in retrieved_docs),
            "question": query,
        }

        # Build the RAG chain
        chain = (
            RunnablePassthrough()   # Passes the input as-is
            | self.prompt           # Formats the input for the LLM
            | self.model            # Queries the LLM
            | StrOutputParser()     # Parses the LLM's output
        )

        logger.info("Generating response using the LLM.")
        response = chain.invoke(formatted_input)
        logger.debug(f"LLM Response: {response}")
        return response

    def reset_retriever(self):
        """
        Reset the retriever and optionally clear the vector store or other states.
        """
        logger.info("Resetting retriever and clearing state.")
        self.retriever = None
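
The class can also be exercised outside Streamlit. A minimal usage sketch, assuming Ollama is serving the models named in config.yaml, the MongoDB instance from the config is running, and `sample.pdf` is a placeholder path for a real local PDF:

```python
from rag_module import ChatPDF

# "sample.pdf" is a hypothetical file; substitute any local PDF.
chat = ChatPDF(config_file="config.yaml")
chat.upload_and_index_pdf("sample.pdf")

answer = chat.query_with_context(
    "What is the main topic of the document?",
    conversation_history=[],
    k=5,
    score_threshold=0.2,
)
print(answer)
```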
7 changes: 7 additions & 0 deletions apps/local-rag-pdf/requirements.txt
@@ -0,0 +1,7 @@
streamlit
langchain
langchain_ollama
langchain_community
langchain-mongodb
pymongo
pypdf
