From 535172d96b702234d26ee7ace4b1634723f3c9b6 Mon Sep 17 00:00:00 2001 From: Manolo Santos Date: Mon, 25 Nov 2024 16:51:30 +0100 Subject: [PATCH 01/12] feat: Add document id and chunk ids to segments fixing order. --- src/raglite/_database.py | 159 +++++++++++++++++++++++++++++++-------- src/raglite/_eval.py | 13 ++-- src/raglite/_rag.py | 81 +++++++++++++------- src/raglite/_search.py | 78 +++++++++---------- tests/test_search.py | 4 +- 5 files changed, 225 insertions(+), 110 deletions(-) diff --git a/src/raglite/_database.py b/src/raglite/_database.py index 36a7fe4..a9b779f 100644 --- a/src/raglite/_database.py +++ b/src/raglite/_database.py @@ -2,25 +2,19 @@ import datetime import json +from collections.abc import Callable +from dataclasses import dataclass from functools import lru_cache from hashlib import sha256 from pathlib import Path from typing import Any +from xml.sax.saxutils import escape import numpy as np from markdown_it import MarkdownIt from pydantic import ConfigDict from sqlalchemy.engine import Engine, make_url -from sqlmodel import ( - JSON, - Column, - Field, - Relationship, - Session, - SQLModel, - create_engine, - text, -) +from sqlmodel import JSON, Column, Field, Relationship, Session, SQLModel, create_engine, text from raglite._config import RAGLiteConfig from raglite._litellm import get_embedding_dim @@ -83,11 +77,7 @@ class Chunk(SQLModel, table=True): @staticmethod def from_body( - document_id: str, - index: int, - body: str, - headings: str = "", - **kwargs: Any, + document_id: str, index: int, body: str, headings: str = "", **kwargs: Any ) -> "Chunk": """Create a chunk from Markdown.""" return Chunk( @@ -221,10 +211,7 @@ class Eval(SQLModel, table=True): @staticmethod def from_chunks( - question: str, - contexts: list[Chunk], - ground_truth: str, - **kwargs: Any, + question: str, contexts: list[Chunk], ground_truth: str, **kwargs: Any ) -> "Eval": """Create a chunk from Markdown.""" document_id = contexts[0].document_id @@ -284,18 +271,22 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine: with Session(engine) as session: metrics = {"cosine": "cosine", "dot": "ip", "euclidean": "l2", "l1": "l1", "l2": "l2"} session.execute( - text(""" + text( + """ CREATE INDEX IF NOT EXISTS keyword_search_chunk_index ON chunk USING GIN (to_tsvector('simple', body)); - """) + """ + ) ) session.execute( - text(f""" + text( + f""" CREATE INDEX IF NOT EXISTS vector_search_chunk_index ON chunk_embedding USING hnsw ( (embedding::halfvec({embedding_dim})) halfvec_{metrics[config.vector_search_index_metric]}_ops ); - """) + """ + ) ) session.commit() elif db_backend == "sqlite": @@ -304,31 +295,137 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine: # [1] https://www.sqlite.org/fts5.html#external_content_tables with Session(engine) as session: session.execute( - text(""" + text( + """ CREATE VIRTUAL TABLE IF NOT EXISTS keyword_search_chunk_index USING fts5(body, content='chunk', content_rowid='rowid'); - """) + """ + ) ) session.execute( - text(""" + text( + """ CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_insert AFTER INSERT ON chunk BEGIN INSERT INTO keyword_search_chunk_index(rowid, body) VALUES (new.rowid, new.body); END; - """) + """ + ) ) session.execute( - text(""" + text( + """ CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_delete AFTER DELETE ON chunk BEGIN INSERT INTO keyword_search_chunk_index(keyword_search_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body); END; - """) + """ 
+ ) ) session.execute( - text(""" + text( + """ CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_update AFTER UPDATE ON chunk BEGIN INSERT INTO keyword_search_chunk_index(keyword_search_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body); INSERT INTO keyword_search_chunk_index(rowid, body) VALUES (new.rowid, new.body); END; - """) + """ + ) ) session.commit() return engine + + +@dataclass +class ContextSegment: + """A class representing a segment of context from a document. + + This class holds information about a specific segment of a document, + including its document ID and associated chunks of text with their IDs and scores. + + Attributes + ---------- + document_id (str): The unique identifier for the document. + chunks (list[Chunk]): List of chunks for this segment. + chunk_scores (list[float]): List of scores for each chunk. + + Raises + ------ + ValueError: If document_id is empty or if chunks is empty. + """ + + document_id: str + chunks: list[Chunk] + chunk_scores: list[float] + + def __post_init__(self) -> None: + """Validate the segment data after initialization.""" + if not isinstance(self.document_id, str) or not self.document_id.strip(): + msg = "document_id must be a non-empty string" + raise ValueError(msg) + if not self.chunks: + msg = "chunks cannot be empty" + raise ValueError(msg) + if not all(isinstance(chunk, Chunk) for chunk in self.chunks): + msg = "all elements in chunks must be Chunk instances" + raise ValueError(msg) + + def to_xml(self, indent: int = 4) -> str: + """Convert the segment to an XML string representation. + + Args: + indent (int): Number of spaces to use for indentation. + + Returns + ------- + str: XML representation of the segment. + """ + chunks_content = "\n".join(str(chunk) for chunk in self.chunks) + + # Create the final XML + chunk_ids = ",".join(self.chunk_ids) + xml = f"""\n{escape(str(chunks_content))}\n""" + + return xml + + def score(self, scoring_function: Callable[[list[float]], float] = sum) -> float: + """Return an aggregated score of the segment, given a scoring function.""" + return scoring_function(self.chunk_scores) + + @property + def chunk_ids(self) -> list[str]: + """Return a list of chunk IDs.""" + return [chunk.id for chunk in self.chunks] + + def __str__(self) -> str: + """Return a string representation reconstructing the document with headings. + + Shows each unique header exactly once, when it first appears. + For example: + - First chunk with "# A ## B" shows both headers + - Next chunk with "# A ## B" shows no headers as they're the same + - Next chunk with "# A ## C" only shows "## C" as it's the only new header + + Returns + ------- + str: A string containing content with each heading shown once. 
+ """ + if not self.chunks: + return "" + + result = [] + seen_headers = set() # Track headers we've already shown + + for chunk in self.chunks: + # Get all headers in this chunk + headers = [h.strip() for h in chunk.headings.split("\n") if h.strip()] + + # Add any headers we haven't seen before + new_headers = [h for h in headers if h not in seen_headers] + if new_headers: + result.extend(new_headers) + result.append("") # Empty line after headers + seen_headers.update(new_headers) # Mark these headers as seen + + # Add the chunk body if it's not empty + if chunk.body.strip(): + result.append(chunk.body.strip()) + + return "\n".join(result).strip() diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py index bfce7c4..e8b5ffc 100644 --- a/src/raglite/_eval.py +++ b/src/raglite/_eval.py @@ -78,7 +78,9 @@ def validate_question(cls, value: str) -> str: num_results=randint(2, max_contexts_per_eval // 2), # noqa: S311 config=config, ) - related_chunks = retrieve_segments(related_chunk_ids, config=config) + related_chunks = [ + str(segment) for segment in retrieve_segments(related_chunk_ids, config=config) + ] # Extract a question from the seed chunk's related chunks. try: question_response = extract_with_llm( @@ -157,9 +159,7 @@ class AnswerResponse(BaseModel): answer = answer_response.answer # Store the eval in the database. eval_ = Eval.from_chunks( - question=question, - contexts=relevant_chunks, - ground_truth=answer, + question=question, contexts=relevant_chunks, ground_truth=answer ) session.add(eval_) session.commit() @@ -185,7 +185,7 @@ def answer_evals( answer = "".join(response) answers.append(answer) chunk_ids, _ = search(eval_.question, config=config) - contexts.append(retrieve_segments(chunk_ids)) + contexts.append([str(segment) for segment in retrieve_segments(chunk_ids)]) # Collect the answered evals. answered_evals: dict[str, list[str] | list[list[str]]] = { "question": [eval_.question for eval_ in evals], @@ -199,8 +199,7 @@ def answer_evals( def evaluate( - answered_evals: pd.DataFrame | int = 100, - config: RAGLiteConfig | None = None, + answered_evals: pd.DataFrame | int = 100, config: RAGLiteConfig | None = None ) -> pd.DataFrame: """Evaluate the performance of a set of answered evals with Ragas.""" try: diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 81ffda2..89a7cc0 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -1,11 +1,12 @@ """Retrieval-augmented generation.""" from collections.abc import AsyncIterator, Iterator +from typing import Literal from litellm import acompletion, completion from raglite._config import RAGLiteConfig -from raglite._database import Chunk +from raglite._database import Chunk, ContextSegment from raglite._litellm import get_context_size from raglite._search import hybrid_search, rerank_chunks, retrieve_segments from raglite._typing import SearchMethod @@ -46,7 +47,7 @@ def _max_contexts( return max_contexts -def _contexts( # noqa: PLR0913 +def context_segments( # noqa: PLR0913 prompt: str, *, max_contexts: int = 5, @@ -54,7 +55,7 @@ def _contexts( # noqa: PLR0913 search: SearchMethod | list[str] | list[Chunk] = hybrid_search, messages: list[dict[str, str]] | None = None, config: RAGLiteConfig | None = None, -) -> list[str]: +) -> list[ContextSegment]: """Retrieve contexts for RAG.""" # Determine the maximum number of contexts. max_contexts = _max_contexts( @@ -95,25 +96,19 @@ def rag( # noqa: PLR0913 """Retrieval-augmented generation.""" # Get the contexts for RAG as contiguous segments of chunks. 
config = config or RAGLiteConfig() - segments = _contexts( + segments = context_segments( prompt, max_contexts=max_contexts, context_neighbors=context_neighbors, search=search, config=config, ) - system_prompt = f"{system_prompt}\n\n" + "\n\n".join( - f'\n{segment.strip()}\n' - for i, segment in enumerate(segments) - ) # Stream the LLM response. stream = completion( model=config.llm, - messages=[ - *(messages or []), - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt}, - ], + messages=_compose_messages( + prompt=prompt, system_prompt=system_prompt, messages=messages, segments=segments + ), stream=True, ) for output in stream: @@ -134,27 +129,61 @@ async def async_rag( # noqa: PLR0913 """Retrieval-augmented generation.""" # Get the contexts for RAG as contiguous segments of chunks. config = config or RAGLiteConfig() - segments = _contexts( + segments = context_segments( prompt, max_contexts=max_contexts, context_neighbors=context_neighbors, search=search, config=config, ) - system_prompt = f"{system_prompt}\n\n" + "\n\n".join( - f'\n{segment.strip()}\n' - for i, segment in enumerate(segments) + messages = _compose_messages( + prompt=prompt, system_prompt=system_prompt, messages=messages, segments=segments ) # Stream the LLM response. - async_stream = await acompletion( - model=config.llm, - messages=[ - *(messages or []), - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt}, - ], - stream=True, - ) + async_stream = await acompletion(model=config.llm, messages=messages, stream=True) async for output in async_stream: token: str = output["choices"][0]["delta"].get("content") or "" yield token + + +def _compose_messages( + prompt: str, + system_prompt: str, + messages: list[dict[str, str]] | None, + segments: list[ContextSegment] | None, + context_placement: Literal[ + "system_prompt", "user_prompt", "separate_system_prompt" + ] = "user_prompt", +) -> list[dict[str, str]]: + """Compose the messages for the LLM, placing the context in the desired position.""" + # Using the format recommended by Anthropic for documents in RAG + # (https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips#essential-tips-for-long-context-prompts + if not segments: + return [ + {"role": "system", "content": system_prompt}, + *(messages or []), + {"role": "user", "content": prompt}, + ] + context_content = ( + "\n\n\n" + "\n\n".join(seg.to_xml() for seg in segments) + "\n" + ) + if context_placement == "system_prompt": + return [ + {"role": "system", "content": system_prompt + "\n\n" + context_content}, + *(messages or []), + {"role": "user", "content": prompt}, + ] + if context_placement == "user_prompt": + return [ + {"role": "system", "content": system_prompt}, + *(messages or []), + {"role": "user", "content": prompt + "\n\n" + context_content}, + ] + + # Separate system prompt from context + return [ + {"role": "system", "content": system_prompt}, + *(messages or []), + {"role": "system", "content": context_content}, + {"role": "user", "content": prompt}, + ] diff --git a/src/raglite/_search.py b/src/raglite/_search.py index 30c7982..1cb9f03 100644 --- a/src/raglite/_search.py +++ b/src/raglite/_search.py @@ -5,6 +5,7 @@ from collections import defaultdict from collections.abc import Sequence from itertools import groupby +from operator import attrgetter, methodcaller from typing import cast import numpy as np @@ -13,16 +14,19 @@ from sqlmodel import Session, and_, col, or_, select, text from raglite._config import 
RAGLiteConfig -from raglite._database import Chunk, ChunkEmbedding, IndexMetadata, create_database_engine +from raglite._database import ( + Chunk, + ChunkEmbedding, + ContextSegment, + IndexMetadata, + create_database_engine, +) from raglite._embed import embed_sentences from raglite._typing import FloatMatrix def vector_search( - query: str | FloatMatrix, - *, - num_results: int = 3, - config: RAGLiteConfig | None = None, + query: str | FloatMatrix, *, num_results: int = 3, config: RAGLiteConfig | None = None ) -> tuple[list[str], list[float]]: """Search chunks using ANN vector search.""" # Read the config. @@ -104,13 +108,15 @@ def keyword_search( query_escaped = re.sub(r"[&|!():<>\"]", " ", query) tsv_query = " | ".join(query_escaped.split()) # Perform keyword search with tsvector. - statement = text(""" + statement = text( + """ SELECT id as chunk_id, ts_rank(to_tsvector('simple', body), to_tsquery('simple', :query)) AS score FROM chunk WHERE to_tsvector('simple', body) @@ to_tsquery('simple', :query) ORDER BY score DESC LIMIT :limit; - """) + """ + ) results = session.execute(statement, params={"query": tsv_query, "limit": num_results}) elif db_backend == "sqlite": # Convert the query to an FTS5 query [1]. @@ -120,13 +126,15 @@ def keyword_search( # Perform keyword search with FTS5. In FTS5, BM25 scores are negative [1], so we # negate them to make them positive. # [1] https://www.sqlite.org/fts5.html#the_bm25_function - statement = text(""" + statement = text( + """ SELECT chunk.id as chunk_id, -bm25(keyword_search_chunk_index) as score FROM chunk JOIN keyword_search_chunk_index ON chunk.rowid = keyword_search_chunk_index.rowid WHERE keyword_search_chunk_index MATCH :match ORDER BY score DESC LIMIT :limit; - """) + """ + ) results = session.execute(statement, params={"match": fts5_query, "limit": num_results}) # Unpack the results. results = list(results) # type: ignore[assignment] @@ -166,11 +174,7 @@ def hybrid_search( return chunk_ids, hybrid_score -def retrieve_chunks( - chunk_ids: list[str], - *, - config: RAGLiteConfig | None = None, -) -> list[Chunk]: +def retrieve_chunks(chunk_ids: list[str], *, config: RAGLiteConfig | None = None) -> list[Chunk]: """Retrieve chunks by their ids.""" config = config or RAGLiteConfig() engine = create_database_engine(config) @@ -185,7 +189,7 @@ def retrieve_segments( *, neighbors: tuple[int, ...] | None = (-1, 1), config: RAGLiteConfig | None = None, -) -> list[str]: +) -> list[ContextSegment]: """Group chunks into contiguous segments and retrieve them.""" # Retrieve the chunks. config = config or RAGLiteConfig() @@ -204,40 +208,26 @@ def retrieve_segments( for offset in neighbors ] chunks += list(session.exec(select(Chunk).where(or_(*neighbor_conditions))).all()) - # Keep only the unique chunks. - chunks = list(set(chunks)) - # Sort the chunks by document_id and index (needed for groupby). - chunks = sorted(chunks, key=lambda chunk: (chunk.document_id, chunk.index)) - # Group the chunks into contiguous segments. - segments: list[list[Chunk]] = [] - for _, group in groupby(chunks, key=lambda chunk: chunk.document_id): - segment: list[Chunk] = [] - for chunk in group: - if not segment or chunk.index == segment[-1].index + 1: - segment.append(chunk) - else: - segments.append(segment) - segment = [chunk] - segments.append(segment) - # Rank segments according to the aggregate relevance of their chunks. + # Assign a reciprocal ranking score to each chunk based on its position in the original list. 
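    # The reciprocal-rank scoring below, illustrated with placeholder chunk
    # ids: the top-ranked chunk gets score 1, the second 1/2, the third 1/3,
    # and so on. A segment's aggregate relevance is then an aggregation (by
    # default sum) of its chunks' scores.
    ranked_chunk_ids = ["chunk-a", "chunk-b", "chunk-c"]  # Most to least relevant.
    scores = {cid: 1 / (i + 1) for i, cid in enumerate(ranked_chunk_ids)}
    assert scores == {"chunk-a": 1.0, "chunk-b": 0.5, "chunk-c": 1 / 3}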
chunk_id_to_score = {chunk.id: 1 / (i + 1) for i, chunk in enumerate(chunks)} - segments.sort( - key=lambda segment: sum(chunk_id_to_score.get(chunk.id, 0.0) for chunk in segment), - reverse=True, - ) - # Convert the segments into strings. - segments = [ - segment[0].headings.strip() + "\n\n" + "".join(chunk.body for chunk in segment).strip() # type: ignore[misc] - for segment in segments + # Deduplicate and sort the chunks by document_id and index (needed for groupby). + unique_chunks = sorted(set(chunks), key=lambda chunk: (chunk.document_id, chunk.index)) + # Group the chunks into contiguous segments. + context_segments: list[ContextSegment] = [ + ContextSegment( + document_id=doc_id, + chunks=(doc_chunks := list(group)), + chunk_scores=[chunk_id_to_score.get(chunk.id, 0.0) for chunk in doc_chunks], + ) + for doc_id, group in groupby(unique_chunks, key=attrgetter("document_id")) ] - return segments # type: ignore[return-value] + # Rank segments according to the aggregate relevance of their chunks. + context_segments.sort(key=methodcaller("score", scoring_function=sum), reverse=True) + return context_segments def rerank_chunks( - query: str, - chunk_ids: list[str] | list[Chunk], - *, - config: RAGLiteConfig | None = None, + query: str, chunk_ids: list[str] | list[Chunk], *, config: RAGLiteConfig | None = None ) -> list[Chunk]: """Rerank chunks according to their relevance to a given query.""" # Retrieve the chunks. diff --git a/tests/test_search.py b/tests/test_search.py index 9cea9d1..8e74fd3 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -10,7 +10,7 @@ retrieve_segments, vector_search, ) -from raglite._database import Chunk +from raglite._database import Chunk, ContextSegment from raglite._typing import SearchMethod @@ -45,7 +45,7 @@ def test_search(raglite_test_config: RAGLiteConfig, search_method: SearchMethod) assert any("Definition of Simultaneity" in str(chunk) for chunk in chunks) # Extend the chunks with their neighbours and group them into contiguous segments. segments = retrieve_segments(chunk_ids, neighbors=(-1, 1), config=raglite_test_config) - assert all(isinstance(segment, str) for segment in segments) + assert all(isinstance(segment, ContextSegment) for segment in segments) def test_search_no_results(raglite_test_config: RAGLiteConfig, search_method: SearchMethod) -> None: From 04d9eb501777fd9129296cd2f5413b8fa07b4535 Mon Sep 17 00:00:00 2001 From: Manolo Santos Date: Mon, 25 Nov 2024 23:18:45 +0100 Subject: [PATCH 02/12] feat: Improve segment reconstruction from chunks. --- src/raglite/_database.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/src/raglite/_database.py b/src/raglite/_database.py index a9b779f..9ed394d 100644 --- a/src/raglite/_database.py +++ b/src/raglite/_database.py @@ -397,34 +397,35 @@ def chunk_ids(self) -> list[str]: def __str__(self) -> str: """Return a string representation reconstructing the document with headings. - Shows each unique header exactly once, when it first appears. - For example: - - First chunk with "# A ## B" shows both headers - - Next chunk with "# A ## B" shows no headers as they're the same - - Next chunk with "# A ## C" only shows "## C" as it's the only new header + Treats headings as a stack, showing headers only when they differ from + the current stack path. - Returns - ------- - str: A string containing content with each heading shown once. 
+ For example: + - "# A ## B" shows both headers + - "# A ## B" shows nothing (already seen) + - "# A ## C" shows only "## C" (new branch) + - "# D ## B" shows both (new path) """ if not self.chunks: return "" - result = [] - seen_headers = set() # Track headers we've already shown + result: list[str] = [] + stack: list[str] = [] for chunk in self.chunks: - # Get all headers in this chunk headers = [h.strip() for h in chunk.headings.split("\n") if h.strip()] - # Add any headers we haven't seen before - new_headers = [h for h in headers if h not in seen_headers] - if new_headers: - result.extend(new_headers) - result.append("") # Empty line after headers - seen_headers.update(new_headers) # Mark these headers as seen + # Find first differing header + i = 0 + while i < len(headers) and i < len(stack) and headers[i] == stack[i]: + i += 1 + + # Update stack and show new headers + stack[i:] = headers[i:] + if headers[i:]: + result.extend(headers[i:]) + result.append("") - # Add the chunk body if it's not empty if chunk.body.strip(): result.append(chunk.body.strip()) From 4e3291409fc9b683263475fb0425efffcc567ebc Mon Sep 17 00:00:00 2001 From: Manolo Santos Date: Tue, 26 Nov 2024 19:47:03 +0100 Subject: [PATCH 03/12] feat: Simplify the api and minor improvements. --- src/raglite/_database.py | 36 +++++++++------------ src/raglite/_rag.py | 68 +++++++++++++++++----------------------- 2 files changed, 44 insertions(+), 60 deletions(-) diff --git a/src/raglite/_database.py b/src/raglite/_database.py index 9ed394d..e1a5f6c 100644 --- a/src/raglite/_database.py +++ b/src/raglite/_database.py @@ -355,33 +355,26 @@ class ContextSegment: chunks: list[Chunk] chunk_scores: list[float] - def __post_init__(self) -> None: - """Validate the segment data after initialization.""" - if not isinstance(self.document_id, str) or not self.document_id.strip(): - msg = "document_id must be a non-empty string" - raise ValueError(msg) - if not self.chunks: - msg = "chunks cannot be empty" - raise ValueError(msg) - if not all(isinstance(chunk, Chunk) for chunk in self.chunks): - msg = "all elements in chunks must be Chunk instances" - raise ValueError(msg) - - def to_xml(self, indent: int = 4) -> str: - """Convert the segment to an XML string representation. + def __str__(self) -> str: + """Return a string representation of the segment.""" + return self.as_xml - Args: - indent (int): Number of spaces to use for indentation. + @property + def as_xml(self) -> str: + """Returns the segment as an XML string representation. Returns ------- str: XML representation of the segment. """ - chunks_content = "\n".join(str(chunk) for chunk in self.chunks) - - # Create the final XML chunk_ids = ",".join(self.chunk_ids) - xml = f"""\n{escape(str(chunks_content))}\n""" + xml = "\n".join( + [ + f'', + escape(self.as_str), + "", + ] + ) return xml @@ -394,7 +387,8 @@ def chunk_ids(self) -> list[str]: """Return a list of chunk IDs.""" return [chunk.id for chunk in self.chunks] - def __str__(self) -> str: + @property + def as_str(self) -> str: """Return a string representation reconstructing the document with headings. 
Treats headings as a stack, showing headers only when they differ from diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 89a7cc0..90e7262 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -1,7 +1,7 @@ """Retrieval-augmented generation.""" from collections.abc import AsyncIterator, Iterator -from typing import Literal +from typing import cast from litellm import acompletion, completion @@ -88,7 +88,7 @@ def rag( # noqa: PLR0913 *, max_contexts: int = 5, context_neighbors: tuple[int, ...] | None = (-1, 1), - search: SearchMethod | list[str] | list[Chunk] = hybrid_search, + search: SearchMethod | list[str] | list[Chunk] | list[ContextSegment] = hybrid_search, messages: list[dict[str, str]] | None = None, system_prompt: str = RAG_SYSTEM_PROMPT, config: RAGLiteConfig | None = None, @@ -96,13 +96,17 @@ def rag( # noqa: PLR0913 """Retrieval-augmented generation.""" # Get the contexts for RAG as contiguous segments of chunks. config = config or RAGLiteConfig() - segments = context_segments( - prompt, - max_contexts=max_contexts, - context_neighbors=context_neighbors, - search=search, - config=config, - ) + segments: list[ContextSegment] + if isinstance(search, list) and any(isinstance(segment, ContextSegment) for segment in search): + segments = cast(list[ContextSegment], search) + else: + segments = context_segments( + prompt, + max_contexts=max_contexts, + context_neighbors=context_neighbors, + search=search, # type: ignore[arg-type] + config=config, + ) # Stream the LLM response. stream = completion( model=config.llm, @@ -121,7 +125,7 @@ async def async_rag( # noqa: PLR0913 *, max_contexts: int = 5, context_neighbors: tuple[int, ...] | None = (-1, 1), - search: SearchMethod | list[str] | list[Chunk] = hybrid_search, + search: SearchMethod | list[str] | list[Chunk] | list[ContextSegment] = hybrid_search, messages: list[dict[str, str]] | None = None, system_prompt: str = RAG_SYSTEM_PROMPT, config: RAGLiteConfig | None = None, @@ -129,13 +133,17 @@ async def async_rag( # noqa: PLR0913 """Retrieval-augmented generation.""" # Get the contexts for RAG as contiguous segments of chunks. 
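    # A standalone sketch of the heading-stack reconstruction that
    # ContextSegment.as_str performs above: headings act as a path, and only
    # the components that differ from the previous chunk's path are emitted
    # again. The heading paths below are invented examples.
    def dedupe_heading_paths(heading_paths: list[list[str]]) -> list[str]:
        shown: list[str] = []
        stack: list[str] = []
        for headers in heading_paths:
            i = 0
            while i < len(headers) and i < len(stack) and headers[i] == stack[i]:
                i += 1
            stack[i:] = headers[i:]  # Replace the diverging tail of the path.
            shown.extend(headers[i:])  # Emit only the new headings.
        return shown

    assert dedupe_heading_paths(
        [["# A", "## B"], ["# A", "## B"], ["# A", "## C"]]
    ) == ["# A", "## B", "## C"]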
config = config or RAGLiteConfig() - segments = context_segments( - prompt, - max_contexts=max_contexts, - context_neighbors=context_neighbors, - search=search, - config=config, - ) + segments: list[ContextSegment] + if isinstance(search, list) and any(isinstance(segment, ContextSegment) for segment in search): + segments = cast(list[ContextSegment], search) + else: + segments = context_segments( + prompt, + max_contexts=max_contexts, + context_neighbors=context_neighbors, + search=search, # type: ignore[arg-type] + config=config, + ) messages = _compose_messages( prompt=prompt, system_prompt=system_prompt, messages=messages, segments=segments ) @@ -151,11 +159,8 @@ def _compose_messages( system_prompt: str, messages: list[dict[str, str]] | None, segments: list[ContextSegment] | None, - context_placement: Literal[ - "system_prompt", "user_prompt", "separate_system_prompt" - ] = "user_prompt", ) -> list[dict[str, str]]: - """Compose the messages for the LLM, placing the context in the desired position.""" + """Compose the messages for the LLM, placing the context in the user position.""" # Using the format recommended by Anthropic for documents in RAG # (https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips#essential-tips-for-long-context-prompts if not segments: @@ -164,26 +169,11 @@ def _compose_messages( *(messages or []), {"role": "user", "content": prompt}, ] - context_content = ( - "\n\n\n" + "\n\n".join(seg.to_xml() for seg in segments) + "\n" - ) - if context_placement == "system_prompt": - return [ - {"role": "system", "content": system_prompt + "\n\n" + context_content}, - *(messages or []), - {"role": "user", "content": prompt}, - ] - if context_placement == "user_prompt": - return [ - {"role": "system", "content": system_prompt}, - *(messages or []), - {"role": "user", "content": prompt + "\n\n" + context_content}, - ] - # Separate system prompt from context + context_content = "\n" + "\n".join(str(seg) for seg in segments) + "\n" + return [ {"role": "system", "content": system_prompt}, *(messages or []), - {"role": "system", "content": context_content}, - {"role": "user", "content": prompt}, + {"role": "user", "content": prompt + "\n\n" + context_content}, ] From c978e0eca9d327f7637333d7a46d5d15b16365a5 Mon Sep 17 00:00:00 2001 From: Manolo Santos Date: Wed, 27 Nov 2024 12:45:58 +0100 Subject: [PATCH 04/12] feat: decouple generation from context retrieval --- src/raglite/__init__.py | 29 ++++++++++---------- src/raglite/_chainlit.py | 10 ++++--- src/raglite/_database.py | 42 +++++------------------------ src/raglite/_eval.py | 9 ++++--- src/raglite/_rag.py | 58 ++++++++++------------------------------ tests/test_rag.py | 6 +++-- 6 files changed, 50 insertions(+), 104 deletions(-) diff --git a/src/raglite/__init__.py b/src/raglite/__init__.py index c1c9dc9..27f9822 100644 --- a/src/raglite/__init__.py +++ b/src/raglite/__init__.py @@ -5,7 +5,7 @@ from raglite._eval import answer_evals, evaluate, insert_evals from raglite._insert import insert_document from raglite._query_adapter import update_query_adapter -from raglite._rag import async_rag, rag +from raglite._rag import async_generate, generate, get_context_segments from raglite._search import ( hybrid_search, keyword_search, @@ -18,24 +18,25 @@ __all__ = [ # Config "RAGLiteConfig", - # Insert - "insert_document", + "answer_evals", + "async_generate", + # CLI + "cli", + "evaluate", + # RAG + "generate", + "get_context_segments", # Search "hybrid_search", + # Insert + "insert_document", + # 
Evaluate + "insert_evals", "keyword_search", - "vector_search", + "rerank_chunks", "retrieve_chunks", "retrieve_segments", - "rerank_chunks", - # RAG - "async_rag", - "rag", # Query adapter "update_query_adapter", - # Evaluate - "insert_evals", - "answer_evals", - "evaluate", - # CLI - "cli", + "vector_search", ] diff --git a/src/raglite/_chainlit.py b/src/raglite/_chainlit.py index 9499baf..7330b2d 100644 --- a/src/raglite/_chainlit.py +++ b/src/raglite/_chainlit.py @@ -8,7 +8,8 @@ from raglite import ( RAGLiteConfig, - async_rag, + async_generate, + get_context_segments, hybrid_search, insert_document, rerank_chunks, @@ -107,10 +108,11 @@ async def handle_message(user_message: cl.Message) -> None: ] # Stream the LLM response. assistant_message = cl.Message(content="") - async for token in async_rag( + context_segments = get_context_segments(user_prompt, config=config) + async for token in async_generate( prompt=user_prompt, - search=chunks, - messages=cl.chat_context.to_openai()[-5:], # type: ignore[no-untyped-call] + messages=cl.chat_context.to_openai()[-5:-1], # type: ignore[no-untyped-call] + context_segments=context_segments, config=config, ): await assistant_message.stream_token(token) diff --git a/src/raglite/_database.py b/src/raglite/_database.py index e1a5f6c..673e64f 100644 --- a/src/raglite/_database.py +++ b/src/raglite/_database.py @@ -371,7 +371,7 @@ def as_xml(self) -> str: xml = "\n".join( [ f'', - escape(self.as_str), + escape(self.reconstructed_str), "", ] ) @@ -388,39 +388,9 @@ def chunk_ids(self) -> list[str]: return [chunk.id for chunk in self.chunks] @property - def as_str(self) -> str: - """Return a string representation reconstructing the document with headings. + def reconstructed_str(self) -> str: + """Return a string representation reconstructing the document with headings.""" + heading = self.chunks[0].headings if self.chunks else "" + bodies = "\n".join(chunk.body for chunk in self.chunks) - Treats headings as a stack, showing headers only when they differ from - the current stack path. 
- - For example: - - "# A ## B" shows both headers - - "# A ## B" shows nothing (already seen) - - "# A ## C" shows only "## C" (new branch) - - "# D ## B" shows both (new path) - """ - if not self.chunks: - return "" - - result: list[str] = [] - stack: list[str] = [] - - for chunk in self.chunks: - headers = [h.strip() for h in chunk.headings.split("\n") if h.strip()] - - # Find first differing header - i = 0 - while i < len(headers) and i < len(stack) and headers[i] == stack[i]: - i += 1 - - # Update stack and show new headers - stack[i:] = headers[i:] - if headers[i:]: - result.extend(headers[i:]) - result.append("") - - if chunk.body.strip(): - result.append(chunk.body.strip()) - - return "\n".join(result).strip() + return f"{heading}\n\n{bodies}".strip() diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py index e8b5ffc..87f1e2c 100644 --- a/src/raglite/_eval.py +++ b/src/raglite/_eval.py @@ -12,7 +12,7 @@ from raglite._config import RAGLiteConfig from raglite._database import Chunk, Document, Eval, create_database_engine from raglite._extract import extract_with_llm -from raglite._rag import rag +from raglite._rag import generate, get_context_segments from raglite._search import hybrid_search, retrieve_segments, vector_search from raglite._typing import SearchMethod @@ -181,7 +181,8 @@ def answer_evals( answers: list[str] = [] contexts: list[list[str]] = [] for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True): - response = rag(eval_.question, search=search, config=config) + segments = get_context_segments(eval_.question, search=search, config=config) + response = generate(eval_.question, context_segments=segments, config=config) answer = "".join(response) answers.append(answer) chunk_ids, _ = search(eval_.question, config=config) @@ -233,13 +234,13 @@ def evaluate( verbose=llm.verbose, ) else: - lc_llm = ChatLiteLLM(model=config.llm) # type: ignore[call-arg] + lc_llm = ChatLiteLLM(model=config.llm) # Load the embedder. if not config.embedder.startswith("llama-cpp-python"): error_message = "Currently, only `llama-cpp-python` embedders are supported." raise NotImplementedError(error_message) embedder = LlamaCppPythonLLM().llm(model=config.embedder, embedding=True) - lc_embedder = LlamaCppEmbeddings( # type: ignore[call-arg] + lc_embedder = LlamaCppEmbeddings( model_path=embedder.model_path, n_batch=embedder.n_batch, n_ctx=embedder.n_ctx(), diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 90e7262..ca3d579 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -1,7 +1,6 @@ """Retrieval-augmented generation.""" from collections.abc import AsyncIterator, Iterator -from typing import cast from litellm import acompletion, completion @@ -47,7 +46,7 @@ def _max_contexts( return max_contexts -def context_segments( # noqa: PLR0913 +def get_context_segments( # noqa: PLR0913 prompt: str, *, max_contexts: int = 5, @@ -83,35 +82,22 @@ def context_segments( # noqa: PLR0913 return segments -def rag( # noqa: PLR0913 +def generate( prompt: str, *, - max_contexts: int = 5, - context_neighbors: tuple[int, ...] 
| None = (-1, 1), - search: SearchMethod | list[str] | list[Chunk] | list[ContextSegment] = hybrid_search, - messages: list[dict[str, str]] | None = None, system_prompt: str = RAG_SYSTEM_PROMPT, - config: RAGLiteConfig | None = None, + messages: list[dict[str, str]] | None = None, + context_segments: list[ContextSegment], + config: RAGLiteConfig, ) -> Iterator[str]: - """Retrieval-augmented generation.""" - # Get the contexts for RAG as contiguous segments of chunks. - config = config or RAGLiteConfig() - segments: list[ContextSegment] - if isinstance(search, list) and any(isinstance(segment, ContextSegment) for segment in search): - segments = cast(list[ContextSegment], search) - else: - segments = context_segments( - prompt, - max_contexts=max_contexts, - context_neighbors=context_neighbors, - search=search, # type: ignore[arg-type] - config=config, - ) + messages = _compose_messages( + prompt=prompt, system_prompt=system_prompt, messages=messages, segments=context_segments + ) # Stream the LLM response. stream = completion( model=config.llm, messages=_compose_messages( - prompt=prompt, system_prompt=system_prompt, messages=messages, segments=segments + prompt=prompt, system_prompt=system_prompt, messages=messages, segments=context_segments ), stream=True, ) @@ -120,32 +106,16 @@ def rag( # noqa: PLR0913 yield token -async def async_rag( # noqa: PLR0913 +async def async_generate( prompt: str, *, - max_contexts: int = 5, - context_neighbors: tuple[int, ...] | None = (-1, 1), - search: SearchMethod | list[str] | list[Chunk] | list[ContextSegment] = hybrid_search, - messages: list[dict[str, str]] | None = None, system_prompt: str = RAG_SYSTEM_PROMPT, - config: RAGLiteConfig | None = None, + messages: list[dict[str, str]] | None = None, + context_segments: list[ContextSegment], + config: RAGLiteConfig, ) -> AsyncIterator[str]: - """Retrieval-augmented generation.""" - # Get the contexts for RAG as contiguous segments of chunks. - config = config or RAGLiteConfig() - segments: list[ContextSegment] - if isinstance(search, list) and any(isinstance(segment, ContextSegment) for segment in search): - segments = cast(list[ContextSegment], search) - else: - segments = context_segments( - prompt, - max_contexts=max_contexts, - context_neighbors=context_neighbors, - search=search, # type: ignore[arg-type] - config=config, - ) messages = _compose_messages( - prompt=prompt, system_prompt=system_prompt, messages=messages, segments=segments + prompt=prompt, system_prompt=system_prompt, messages=messages, segments=context_segments ) # Stream the LLM response. async_stream = await acompletion(model=config.llm, messages=messages, stream=True) diff --git a/tests/test_rag.py b/tests/test_rag.py index 150a31b..3cbcf04 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -6,7 +6,8 @@ import pytest from llama_cpp import llama_supports_gpu_offload -from raglite import RAGLiteConfig, hybrid_search, rag, retrieve_chunks +from raglite import RAGLiteConfig, hybrid_search, retrieve_chunks +from raglite._rag import generate, get_context_segments if TYPE_CHECKING: from raglite._database import Chunk @@ -32,7 +33,8 @@ def test_rag(raglite_test_config: RAGLiteConfig) -> None: ] # Answer a question with RAG. 
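    # The decoupled retrieve-then-generate flow introduced in this patch, as a
    # hedged standalone example. The LLM name is an assumption; any
    # LiteLLM-compatible model id should work here.
    from raglite import RAGLiteConfig, generate, get_context_segments

    config = RAGLiteConfig(llm="gpt-4o-mini")
    prompt = "What does it mean for two events to be simultaneous?"
    segments = get_context_segments(prompt, config=config)
    answer = "".join(generate(prompt, context_segments=segments, config=config))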
for search_input in search_inputs: - stream = rag(prompt, search=search_input, config=raglite_test_config) + segments = get_context_segments(prompt, search=search_input, config=raglite_test_config) + stream = generate(prompt, context_segments=segments, config=raglite_test_config) answer = "" for update in stream: assert isinstance(update, str) From 0c247eabaa1f22f39c13f0ab4b36a15165743436 Mon Sep 17 00:00:00 2001 From: Manolo Santos Date: Sun, 1 Dec 2024 23:05:45 +0100 Subject: [PATCH 05/12] fix: consecutive chunks wrongly scored as neighbours. --- src/raglite/_search.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/raglite/_search.py b/src/raglite/_search.py index 1cb9f03..10ff8c3 100644 --- a/src/raglite/_search.py +++ b/src/raglite/_search.py @@ -198,6 +198,8 @@ def retrieve_segments( if all(isinstance(chunk_id, str) for chunk_id in chunk_ids) else chunk_ids ) + # Assign a reciprocal ranking score to each chunk based on its position in the original list. + chunk_id_to_score = {chunk.id: 1 / (i + 1) for i, chunk in enumerate(chunks)} # Extend the chunks with their neighbouring chunks. if neighbors: engine = create_database_engine(config) @@ -208,8 +210,6 @@ def retrieve_segments( for offset in neighbors ] chunks += list(session.exec(select(Chunk).where(or_(*neighbor_conditions))).all()) - # Assign a reciprocal ranking score to each chunk based on its position in the original list. - chunk_id_to_score = {chunk.id: 1 / (i + 1) for i, chunk in enumerate(chunks)} # Deduplicate and sort the chunks by document_id and index (needed for groupby). unique_chunks = sorted(set(chunks), key=lambda chunk: (chunk.document_id, chunk.index)) # Group the chunks into contiguous segments. From 851311cdcf97bb70afc265f1e7cfd71bf1b29dad Mon Sep 17 00:00:00 2001 From: Manolo Santos Date: Mon, 2 Dec 2024 10:33:47 +0100 Subject: [PATCH 06/12] fix: type validation. --- src/raglite/_eval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py index 87f1e2c..54e4803 100644 --- a/src/raglite/_eval.py +++ b/src/raglite/_eval.py @@ -234,13 +234,13 @@ def evaluate( verbose=llm.verbose, ) else: - lc_llm = ChatLiteLLM(model=config.llm) + lc_llm = ChatLiteLLM(model=config.llm) # type: ignore[call-arg] # Load the embedder. if not config.embedder.startswith("llama-cpp-python"): error_message = "Currently, only `llama-cpp-python` embedders are supported." 
raise NotImplementedError(error_message) embedder = LlamaCppPythonLLM().llm(model=config.embedder, embedding=True) - lc_embedder = LlamaCppEmbeddings( + lc_embedder = LlamaCppEmbeddings( # type: ignore[call-arg] model_path=embedder.model_path, n_batch=embedder.n_batch, n_ctx=embedder.n_ctx(), From bbde96f5c8d5fe15ecb8de14472b4433f68c4aef Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Tue, 3 Dec 2024 12:06:04 +0100 Subject: [PATCH 07/12] feat: simplify RAG API --- README.md | 75 +++++++++++++--- src/raglite/__init__.py | 35 ++++---- src/raglite/_chainlit.py | 38 ++++---- src/raglite/_database.py | 142 ++++++++++++++---------------- src/raglite/_eval.py | 22 +++-- src/raglite/_extract.py | 2 +- src/raglite/_insert.py | 4 +- src/raglite/_rag.py | 183 ++++++++++++++------------------------- src/raglite/_search.py | 139 +++++++++++++++++------------ src/raglite/_typing.py | 5 ++ tests/test_rag.py | 39 ++++----- tests/test_search.py | 13 ++- 12 files changed, 359 insertions(+), 338 deletions(-) diff --git a/README.md b/README.md index 2cf9e4c..572aa04 100644 --- a/README.md +++ b/README.md @@ -157,38 +157,85 @@ insert_document(Path("Special Relativity.pdf"), config=my_config) ### 3. Searching and Retrieval-Augmented Generation (RAG) -Now, you can search for chunks with vector search, keyword search, or a hybrid of the two. You can also rerank the search results with the configured reranker. And you can use any search method of your choice (`hybrid_search` is the default) together with reranking to answer questions with RAG: +#### 3.1 Simple RAG pipeline + +Now you can run a simple but powerful RAG pipeline that consists of retrieving the most relevant chunk spans (each of which is a list of consecutive chunks) with hybrid search and reranking, converting the user prompt to a RAG instruction and appending it to the message history, and finally generating the RAG response: + +```python +from raglite import create_rag_instruction, rag, retrieve_rag_context + +# Retrieve relevant chunk spans with hybrid search and reranking: +user_prompt = "How is intelligence measured?" +chunk_spans = retrieve_rag_context(query=user_prompt, num_chunks=5, config=my_config) + +# Append a RAG instruction based on the user prompt and context to the message history: +messages = [] # Or start with an existing message history. +messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans)) + +# Stream the RAG response: +stream = rag(messages, config=my_config) +for update in stream: + print(update, end="") + +# Access the documents cited in the RAG response: +documents = [chunk_span.document for chunk_span in chunk_spans] +``` + +#### 3.2 Advanced RAG pipeline + +> [!TIP] +> 🥇 Reranking can significantly improve the output quality of a RAG application. To add reranking to your application: first search for a larger set of 20 relevant chunks, then rerank them with a [rerankers](https://github.com/AnswerDotAI/rerankers) reranker, and finally keep the top 5 chunks. + +In addition to the simple RAG pipeline, RAGLite also offers more advanced control over the individual steps of the pipeline. A full pipeline consists of several steps: + +1. Searching for relevant chunks with keyword, vector, or hybrid search +2. Retrieving the chunks from the database +3. Reranking the chunks and truncating the results to the top 5 +4. Extending the chunks with their neighbors and grouping them into chunk spans +5. Converting the user prompt to a RAG instruction and appending it to the message history +6. 
Streaming an LLM response to the message history +7. Accessing the cited documents from the chunk spans ```python # Search for chunks: from raglite import hybrid_search, keyword_search, vector_search -prompt = "How is intelligence measured?" -chunk_ids_vector, _ = vector_search(prompt, num_results=20, config=my_config) -chunk_ids_keyword, _ = keyword_search(prompt, num_results=20, config=my_config) -chunk_ids_hybrid, _ = hybrid_search(prompt, num_results=20, config=my_config) +user_prompt = "How is intelligence measured?" +chunk_ids_vector, _ = vector_search(user_prompt, num_results=20, config=my_config) +chunk_ids_keyword, _ = keyword_search(user_prompt, num_results=20, config=my_config) +chunk_ids_hybrid, _ = hybrid_search(user_prompt, num_results=20, config=my_config) # Retrieve chunks: from raglite import retrieve_chunks chunks_hybrid = retrieve_chunks(chunk_ids_hybrid, config=my_config) -# Rerank chunks: +# Rerank chunks and keep the top 5 (optional, but recommended): from raglite import rerank_chunks -chunks_reranked = rerank_chunks(prompt, chunks_hybrid, config=my_config) +chunks_reranked = rerank_chunks(user_prompt, chunks_hybrid, config=my_config) +chunks_reranked = chunks_reranked[:5] + +# Extend chunks with their neighbors and group them into chunk spans: +from raglite import retrieve_chunk_spans + +chunk_spans = retrieve_chunk_spans(chunks_reranked, config=my_config) + +# Append a RAG instruction based on the user prompt and context to the message history: +from raglite import create_rag_instruction + +messages = [] # Or start with an existing message history. +messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans)) -# Answer questions with RAG: +# Stream the RAG response: from raglite import rag -prompt = "What does it mean for two events to be simultaneous?" -stream = rag(prompt, config=my_config) +stream = rag(messages, config=my_config) for update in stream: print(update, end="") -# You can also pass a search method or search results directly: -stream = rag(prompt, search=hybrid_search, config=my_config) -stream = rag(prompt, search=chunks_reranked, config=my_config) +# Access the documents cited in the RAG response: +documents = [chunk_span.document for chunk_span in chunk_spans] ``` ### 4. Computing and using an optimal query adapter @@ -200,7 +247,7 @@ RAGLite can compute and apply an [optimal closed-form query adapter](src/raglite from raglite import insert_evals, update_query_adapter insert_evals(num_evals=100, config=my_config) -update_query_adapter(config=my_config) # From here, simply call vector_search to use the query adapter. +update_query_adapter(config=my_config) # From here, every vector search will use the query adapter. ``` ### 5. 
Evaluation of retrieval and generation diff --git a/src/raglite/__init__.py b/src/raglite/__init__.py index 27f9822..8ef7a26 100644 --- a/src/raglite/__init__.py +++ b/src/raglite/__init__.py @@ -5,38 +5,39 @@ from raglite._eval import answer_evals, evaluate, insert_evals from raglite._insert import insert_document from raglite._query_adapter import update_query_adapter -from raglite._rag import async_generate, generate, get_context_segments +from raglite._rag import async_rag, create_rag_instruction, rag, retrieve_rag_context from raglite._search import ( hybrid_search, keyword_search, rerank_chunks, + retrieve_chunk_spans, retrieve_chunks, - retrieve_segments, vector_search, ) __all__ = [ # Config "RAGLiteConfig", - "answer_evals", - "async_generate", - # CLI - "cli", - "evaluate", - # RAG - "generate", - "get_context_segments", - # Search - "hybrid_search", # Insert "insert_document", - # Evaluate - "insert_evals", + # Search + "hybrid_search", "keyword_search", - "rerank_chunks", + "vector_search", "retrieve_chunks", - "retrieve_segments", + "retrieve_chunk_spans", + "rerank_chunks", + # RAG + "retrieve_rag_context", + "create_rag_instruction", + "async_rag", + "rag", # Query adapter "update_query_adapter", - "vector_search", + # Evaluate + "insert_evals", + "answer_evals", + "evaluate", + # CLI + "cli", ] diff --git a/src/raglite/_chainlit.py b/src/raglite/_chainlit.py index 7330b2d..bebd730 100644 --- a/src/raglite/_chainlit.py +++ b/src/raglite/_chainlit.py @@ -8,11 +8,12 @@ from raglite import ( RAGLiteConfig, - async_generate, - get_context_segments, + async_rag, + create_rag_instruction, hybrid_search, insert_document, rerank_chunks, + retrieve_chunk_spans, retrieve_chunks, ) from raglite._markdown import document_to_markdown @@ -20,6 +21,7 @@ async_insert_document = cl.make_async(insert_document) async_hybrid_search = cl.make_async(hybrid_search) async_retrieve_chunks = cl.make_async(retrieve_chunks) +async_retrieve_chunk_spans = cl.make_async(retrieve_chunk_spans) async_rerank_chunks = cl.make_async(rerank_chunks) @@ -85,9 +87,12 @@ async def handle_message(user_message: cl.Message) -> None: step.input = Path(file.path).name await async_insert_document(Path(file.path), config=config) # Append any inline attachments to the user prompt. - user_prompt = f"{user_message.content}\n\n" + "\n\n".join( - f'\n{attachment.strip()}\n' - for i, attachment in enumerate(inline_attachments) + user_prompt = ( + "\n\n".join( + f'\n{attachment.strip()}\n' + for i, attachment in enumerate(inline_attachments) + ) + + f"\n\n{user_message.content}" ) # Search for relevant contexts for RAG. async with cl.Step(name="search", type="retrieval") as step: @@ -95,25 +100,22 @@ async def handle_message(user_message: cl.Message) -> None: chunk_ids, _ = await async_hybrid_search(query=user_prompt, num_results=10, config=config) chunks = await async_retrieve_chunks(chunk_ids=chunk_ids, config=config) step.output = chunks - step.elements = [ # Show the top 3 chunks inline. - cl.Text(content=str(chunk), display="inline") for chunk in chunks[:3] + step.elements = [ # Show the top chunks inline. + cl.Text(content=str(chunk), display="inline") for chunk in chunks[:5] ] - # Rerank the chunks. + # Rerank the chunks and group them into chunk spans. async with cl.Step(name="rerank", type="rerank") as step: step.input = chunks chunks = await async_rerank_chunks(query=user_prompt, chunk_ids=chunks, config=config) - step.output = chunks - step.elements = [ # Show the top 3 chunks inline. 
- cl.Text(content=str(chunk), display="inline") for chunk in chunks[:3] + chunk_spans = await async_retrieve_chunk_spans(chunks[:5], config=config) + step.output = chunk_spans + step.elements = [ # Show the top chunk spans inline. + cl.Text(content=str(chunk_span), display="inline") for chunk_span in chunk_spans ] # Stream the LLM response. assistant_message = cl.Message(content="") - context_segments = get_context_segments(user_prompt, config=config) - async for token in async_generate( - prompt=user_prompt, - messages=cl.chat_context.to_openai()[-5:-1], # type: ignore[no-untyped-call] - context_segments=context_segments, - config=config, - ): + messages: list[dict[str, str]] = cl.chat_context.to_openai() # type: ignore[no-untyped-call] + messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans)) + async for token in async_rag(messages, config=config): await assistant_message.stream_token(token) await assistant_message.update() # type: ignore[no-untyped-call] diff --git a/src/raglite/_database.py b/src/raglite/_database.py index 673e64f..8c47f4a 100644 --- a/src/raglite/_database.py +++ b/src/raglite/_database.py @@ -2,7 +2,6 @@ import datetime import json -from collections.abc import Callable from dataclasses import dataclass from functools import lru_cache from hashlib import sha256 @@ -18,7 +17,16 @@ from raglite._config import RAGLiteConfig from raglite._litellm import get_embedding_dim -from raglite._typing import Embedding, FloatMatrix, FloatVector, PickledObject +from raglite._typing import ( + ChunkId, + DocumentId, + Embedding, + EvalId, + FloatMatrix, + FloatVector, + IndexId, + PickledObject, +) def hash_bytes(data: bytes, max_len: int = 16) -> str: @@ -33,7 +41,7 @@ class Document(SQLModel, table=True): model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment] # Table columns. - id: str = Field(..., primary_key=True) + id: DocumentId = Field(..., primary_key=True) filename: str url: str | None = Field(default=None) metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON)) @@ -64,8 +72,8 @@ class Chunk(SQLModel, table=True): model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment] # Table columns. - id: str = Field(..., primary_key=True) - document_id: str = Field(..., foreign_key="document.id", index=True) + id: ChunkId = Field(..., primary_key=True) + document_id: DocumentId = Field(..., foreign_key="document.id", index=True) index: int = Field(..., index=True) headings: str body: str @@ -77,7 +85,7 @@ class Chunk(SQLModel, table=True): @staticmethod def from_body( - document_id: str, index: int, body: str, headings: str = "", **kwargs: Any + document_id: DocumentId, index: int, body: str, headings: str = "", **kwargs: Any ) -> "Chunk": """Create a chunk from Markdown.""" return Chunk( @@ -129,10 +137,55 @@ def __repr__(self) -> str: indent=4, ) - def __str__(self) -> str: - """Context representation of this chunk.""" + @property + def content(self) -> str: + """Return this chunk's contextual heading and body.""" return f"{self.headings.strip()}\n\n{self.body.strip()}".strip() + def __str__(self) -> str: + """Return this chunk's content.""" + return self.content + + +@dataclass +class ChunkSpan: + """A consecutive sequence of chunks from a single document.""" + + chunks: list[Chunk] + document: Document + + def to_xml(self, index: int | None = None) -> str: + """Convert this chunk span to an XML representation. 
+ + The XML representation follows Anthropic's best practices [1]. + + [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips + """ + if not self.chunks: + return "" + index_attribute = f' index="{index}"' if index is not None else "" + xml = "\n".join( + [ + f'', + f"{self.document.url if self.document.url else self.document.filename}" + f"{escape(self.chunks[0].headings.strip())}" + f"\n{escape(''.join(chunk.body for chunk in self.chunks).strip())}\n", + "", + ] + ) + return xml + + @property + def content(self) -> str: + """Return this chunk span's contextual heading and chunk bodies.""" + heading = self.chunks[0].headings.strip() if self.chunks else "" + bodies = "".join(chunk.body for chunk in self.chunks) + return f"{heading}\n\n{bodies}".strip() + + def __str__(self) -> str: + """Return this chunk span's content.""" + return self.content + class ChunkEmbedding(SQLModel, table=True): """A (sub-)chunk embedding.""" @@ -144,7 +197,7 @@ class ChunkEmbedding(SQLModel, table=True): # Table columns. id: int = Field(..., primary_key=True) - chunk_id: str = Field(..., foreign_key="chunk.id", index=True) + chunk_id: ChunkId = Field(..., foreign_key="chunk.id", index=True) embedding: FloatVector = Field(..., sa_column=Column(Embedding(dim=-1))) # Add relationship so we can access embedding.chunk. @@ -165,7 +218,7 @@ class IndexMetadata(SQLModel, table=True): model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment] # Table columns. - id: str = Field(..., primary_key=True) + id: IndexId = Field(..., primary_key=True) version: datetime.datetime = Field( default_factory=lambda: datetime.datetime.now(datetime.timezone.utc) ) @@ -198,9 +251,9 @@ class Eval(SQLModel, table=True): model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment] # Table columns. - id: str = Field(..., primary_key=True) - document_id: str = Field(..., foreign_key="document.id", index=True) - chunk_ids: list[str] = Field(default_factory=list, sa_column=Column(JSON)) + id: EvalId = Field(..., primary_key=True) + document_id: DocumentId = Field(..., foreign_key="document.id", index=True) + chunk_ids: list[ChunkId] = Field(default_factory=list, sa_column=Column(JSON)) question: str contexts: list[str] = Field(default_factory=list, sa_column=Column(JSON)) ground_truth: str @@ -331,66 +384,3 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine: ) session.commit() return engine - - -@dataclass -class ContextSegment: - """A class representing a segment of context from a document. - - This class holds information about a specific segment of a document, - including its document ID and associated chunks of text with their IDs and scores. - - Attributes - ---------- - document_id (str): The unique identifier for the document. - chunks (list[Chunk]): List of chunks for this segment. - chunk_scores (list[float]): List of scores for each chunk. - - Raises - ------ - ValueError: If document_id is empty or if chunks is empty. - """ - - document_id: str - chunks: list[Chunk] - chunk_scores: list[float] - - def __str__(self) -> str: - """Return a string representation of the segment.""" - return self.as_xml - - @property - def as_xml(self) -> str: - """Returns the segment as an XML string representation. - - Returns - ------- - str: XML representation of the segment. 
- """ - chunk_ids = ",".join(self.chunk_ids) - xml = "\n".join( - [ - f'', - escape(self.reconstructed_str), - "", - ] - ) - - return xml - - def score(self, scoring_function: Callable[[list[float]], float] = sum) -> float: - """Return an aggregated score of the segment, given a scoring function.""" - return scoring_function(self.chunk_scores) - - @property - def chunk_ids(self) -> list[str]: - """Return a list of chunk IDs.""" - return [chunk.id for chunk in self.chunks] - - @property - def reconstructed_str(self) -> str: - """Return a string representation reconstructing the document with headings.""" - heading = self.chunks[0].headings if self.chunks else "" - bodies = "\n".join(chunk.body for chunk in self.chunks) - - return f"{heading}\n\n{bodies}".strip() diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py index 54e4803..f26789c 100644 --- a/src/raglite/_eval.py +++ b/src/raglite/_eval.py @@ -12,8 +12,8 @@ from raglite._config import RAGLiteConfig from raglite._database import Chunk, Document, Eval, create_database_engine from raglite._extract import extract_with_llm -from raglite._rag import generate, get_context_segments -from raglite._search import hybrid_search, retrieve_segments, vector_search +from raglite._rag import create_rag_instruction, rag, retrieve_rag_context +from raglite._search import hybrid_search, retrieve_chunk_spans, vector_search from raglite._typing import SearchMethod @@ -74,12 +74,13 @@ def validate_question(cls, value: str) -> str: continue # Expand the seed chunk into a set of related chunks. related_chunk_ids, _ = vector_search( - np.mean(seed_chunk.embedding_matrix, axis=0, keepdims=True), + query=np.mean(seed_chunk.embedding_matrix, axis=0, keepdims=True), num_results=randint(2, max_contexts_per_eval // 2), # noqa: S311 config=config, ) related_chunks = [ - str(segment) for segment in retrieve_segments(related_chunk_ids, config=config) + str(chunk_spans) + for chunk_spans in retrieve_chunk_spans(related_chunk_ids, config=config) ] # Extract a question from the seed chunk's related chunks. try: @@ -92,7 +93,7 @@ def validate_question(cls, value: str) -> str: question = question_response.question # Search for candidate chunks to answer the generated question. candidate_chunk_ids, _ = hybrid_search( - question, num_results=max_contexts_per_eval, config=config + query=question, num_results=max_contexts_per_eval, config=config ) candidate_chunks = [session.get(Chunk, chunk_id) for chunk_id in candidate_chunk_ids] @@ -181,12 +182,15 @@ def answer_evals( answers: list[str] = [] contexts: list[list[str]] = [] for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True): - segments = get_context_segments(eval_.question, search=search, config=config) - response = generate(eval_.question, context_segments=segments, config=config) + chunk_spans = retrieve_rag_context(query=eval_.question, search=search, config=config) + messages = [create_rag_instruction(user_prompt=eval_.question, context=chunk_spans)] + response = rag(messages, config=config) answer = "".join(response) answers.append(answer) - chunk_ids, _ = search(eval_.question, config=config) - contexts.append([str(segment) for segment in retrieve_segments(chunk_ids)]) + chunk_ids, _ = search(query=eval_.question, config=config) + contexts.append( + [str(chunk_span) for chunk_span in retrieve_chunk_spans(chunk_ids, config=config)] + ) # Collect the answered evals. 
diff --git a/src/raglite/_extract.py b/src/raglite/_extract.py
index 95e46e3..f3d73ff 100644
--- a/src/raglite/_extract.py
+++ b/src/raglite/_extract.py
@@ -61,7 +61,7 @@ class MyNameResponse(BaseModel):
     # Concatenate the user prompt if it is a list of strings.
     if isinstance(user_prompt, list):
         user_prompt = "\n\n".join(
-            f'<context{i + 1}>\n{chunk.strip()}\n</context{i + 1}>'
+            f'<context index="{i + 1}">\n{chunk.strip()}\n</context>'
             for i, chunk in enumerate(user_prompt)
         )
     # Enable JSON schema validation.
diff --git a/src/raglite/_insert.py b/src/raglite/_insert.py
index 804d0b7..42061a6 100644
--- a/src/raglite/_insert.py
+++ b/src/raglite/_insert.py
@@ -13,11 +13,11 @@
 from raglite._markdown import document_to_markdown
 from raglite._split_chunks import split_chunks
 from raglite._split_sentences import split_sentences
-from raglite._typing import FloatMatrix
+from raglite._typing import DocumentId, FloatMatrix
 
 
 def _create_chunk_records(
-    document_id: str,
+    document_id: DocumentId,
     chunks: list[str],
     chunk_embeddings: list[FloatMatrix],
     config: RAGLiteConfig,
diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py
index ca3d579..d9db48d 100644
--- a/src/raglite/_rag.py
+++ b/src/raglite/_rag.py
@@ -2,148 +2,95 @@
 
 from collections.abc import AsyncIterator, Iterator
 
+import numpy as np
 from litellm import acompletion, completion
 
 from raglite._config import RAGLiteConfig
-from raglite._database import Chunk, ContextSegment
+from raglite._database import ChunkSpan
 from raglite._litellm import get_context_size
-from raglite._search import hybrid_search, rerank_chunks, retrieve_segments
+from raglite._search import hybrid_search, rerank_chunks, retrieve_chunk_spans
 from raglite._typing import SearchMethod
 
-RAG_SYSTEM_PROMPT = """
+# The default RAG instruction template follows Anthropic's best practices [1].
+# [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
+RAG_INSTRUCTION_TEMPLATE = """
 You are a friendly and knowledgeable assistant that provides complete and insightful answers.
-Answer the user's question using only the context below.
-When responding, you MUST NOT reference the existence of the context, directly or indirectly.
-Instead, you MUST treat the context as if its contents are entirely part of your working memory.
+You MUST observe the following rules:
+1. Whenever possible, use only the provided context below to answer the question at the end.
+2. Cite your sources with inline numerical citations of the form "[n]", where n is the document index.
+   Use commas to separate citations as "[a], [b], [c]" when citing multiple sources consecutively.
+3. Do not print a list of sources at the end.
+
+{context}
+
+{user_prompt}
 """.strip()
 
 
-def _max_contexts(
-    prompt: str,
-    *,
-    max_contexts: int = 5,
-    context_neighbors: tuple[int, ...] | None = (-1, 1),
-    messages: list[dict[str, str]] | None = None,
-    config: RAGLiteConfig | None = None,
-) -> int:
-    """Determine the maximum number of contexts for RAG."""
-    # Get the model's context size.
-    config = config or RAGLiteConfig()
-    max_tokens = get_context_size(config)
-    # Reduce the maximum number of contexts to take into account the LLM's context size.
-    max_context_tokens = (
-        max_tokens
-        - sum(len(message["content"]) // 3 for message in messages or [])  # Previous messages.
-        - len(RAG_SYSTEM_PROMPT) // 3  # System prompt.
-        - len(prompt) // 3  # User prompt.
- ) - max_tokens_per_context = config.chunk_max_size // 3 - max_tokens_per_context *= 1 + len(context_neighbors or []) - max_contexts = min(max_contexts, max_context_tokens // max_tokens_per_context) - if max_contexts <= 0: - error_message = "Not enough context tokens available for RAG." - raise ValueError(error_message) - return max_contexts - - -def get_context_segments( # noqa: PLR0913 - prompt: str, +def retrieve_rag_context( + query: str, *, - max_contexts: int = 5, - context_neighbors: tuple[int, ...] | None = (-1, 1), - search: SearchMethod | list[str] | list[Chunk] = hybrid_search, - messages: list[dict[str, str]] | None = None, + num_chunks: int = 5, + chunk_neighbors: tuple[int, ...] | None = (-1, 1), + search: SearchMethod = hybrid_search, config: RAGLiteConfig | None = None, -) -> list[ContextSegment]: - """Retrieve contexts for RAG.""" - # Determine the maximum number of contexts. - max_contexts = _max_contexts( - prompt, - max_contexts=max_contexts, - context_neighbors=context_neighbors, - messages=messages, - config=config, - ) - # Retrieve the top chunks. +) -> list[ChunkSpan]: + """Retrieve context for RAG.""" + # If the user has configured a reranker, we retrieve extra contexts to rerank. config = config or RAGLiteConfig() - chunks: list[str] | list[Chunk] - if callable(search): - # If the user has configured a reranker, we retrieve extra contexts to rerank. - extra_contexts = 3 * max_contexts if config.reranker else 0 - # Retrieve relevant contexts. - chunk_ids, _ = search(prompt, num_results=max_contexts + extra_contexts, config=config) - # Rerank the relevant contexts. - chunks = rerank_chunks(query=prompt, chunk_ids=chunk_ids, config=config) - else: - # The user has passed a list of chunk_ids or chunks directly. - chunks = search + extra_chunks = 3 * num_chunks if config.reranker else 0 + # Search for relevant chunks. + chunk_ids, _ = search(query, num_results=num_chunks + extra_chunks, config=config) + # Rerank the chunks from most to least relevant. + chunks = rerank_chunks(query, chunk_ids=chunk_ids, config=config) # Extend the top contexts with their neighbors and group chunks into contiguous segments. - segments = retrieve_segments(chunks[:max_contexts], neighbors=context_neighbors, config=config) - return segments + context = retrieve_chunk_spans(chunks[:num_chunks], neighbors=chunk_neighbors, config=config) + return context -def generate( - prompt: str, +def create_rag_instruction( + user_prompt: str, + context: list[ChunkSpan], *, - system_prompt: str = RAG_SYSTEM_PROMPT, - messages: list[dict[str, str]] | None = None, - context_segments: list[ContextSegment], - config: RAGLiteConfig, -) -> Iterator[str]: - messages = _compose_messages( - prompt=prompt, system_prompt=system_prompt, messages=messages, segments=context_segments - ) - # Stream the LLM response. - stream = completion( - model=config.llm, - messages=_compose_messages( - prompt=prompt, system_prompt=system_prompt, messages=messages, segments=context_segments + rag_instruction_template: str = RAG_INSTRUCTION_TEMPLATE, +) -> dict[str, str]: + """Convert a user prompt to a RAG instruction. + + The RAG instruction's format follows Anthropic's best practices [1]. 
+
+    [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
+    """
+    message = {
+        "role": "user",
+        "content": rag_instruction_template.format(
+            user_prompt=user_prompt,
+            context="\n".join(
+                chunk_span.to_xml(index=i + 1) for i, chunk_span in enumerate(context)
+            ),
         ),
-        stream=True,
-    )
+    }
+    return message
+
+
+def rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> Iterator[str]:
+    # Truncate the oldest messages so we don't hit the context limit.
+    max_tokens = get_context_size(config)
+    cum_tokens = np.cumsum([len(message.get("content", "")) // 3 for message in messages][::-1])
+    messages = messages[-np.searchsorted(cum_tokens, max_tokens) :]
+    # Stream the LLM response.
+    stream = completion(model=config.llm, messages=messages, stream=True)
     for output in stream:
         token: str = output["choices"][0]["delta"].get("content") or ""
         yield token
 
 
-async def async_generate(
-    prompt: str,
-    *,
-    system_prompt: str = RAG_SYSTEM_PROMPT,
-    messages: list[dict[str, str]] | None = None,
-    context_segments: list[ContextSegment],
-    config: RAGLiteConfig,
-) -> AsyncIterator[str]:
-    messages = _compose_messages(
-        prompt=prompt, system_prompt=system_prompt, messages=messages, segments=context_segments
-    )
-    # Stream the LLM response.
+async def async_rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> AsyncIterator[str]:
+    # Truncate the oldest messages so we don't hit the context limit.
+    max_tokens = get_context_size(config)
+    cum_tokens = np.cumsum([len(message.get("content", "")) // 3 for message in messages][::-1])
+    messages = messages[-np.searchsorted(cum_tokens, max_tokens) :]
+    # Asynchronously stream the LLM response.
     async_stream = await acompletion(model=config.llm, messages=messages, stream=True)
     async for output in async_stream:
         token: str = output["choices"][0]["delta"].get("content") or ""
         yield token
-
-
-def _compose_messages(
-    prompt: str,
-    system_prompt: str,
-    messages: list[dict[str, str]] | None,
-    segments: list[ContextSegment] | None,
-) -> list[dict[str, str]]:
-    """Compose the messages for the LLM, placing the context in the user position."""
-    # Using the format recommended by Anthropic for documents in RAG
-    # (https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips#essential-tips-for-long-context-prompts)
-    if not segments:
-        return [
-            {"role": "system", "content": system_prompt},
-            *(messages or []),
-            {"role": "user", "content": prompt},
-        ]
-
-    context_content = "<documents>\n" + "\n".join(str(seg) for seg in segments) + "\n</documents>"
-
-    return [
-        {"role": "system", "content": system_prompt},
-        *(messages or []),
-        {"role": "user", "content": prompt + "\n\n" + context_content},
-    ]
diff --git a/src/raglite/_search.py b/src/raglite/_search.py
index 10ff8c3..2a00675 100644
--- a/src/raglite/_search.py
+++ b/src/raglite/_search.py
@@ -1,33 +1,33 @@
-"""Query documents."""
+"""Search and retrieve chunks."""
 
 import re
 import string
 from collections import defaultdict
 from collections.abc import Sequence
 from itertools import groupby
-from operator import attrgetter, methodcaller
 from typing import cast
 
 import numpy as np
 from langdetect import detect
 from sqlalchemy.engine import make_url
+from sqlalchemy.orm import joinedload
 from sqlmodel import Session, and_, col, or_, select, text
 
 from raglite._config import RAGLiteConfig
 from raglite._database import (
     Chunk,
     ChunkEmbedding,
-    ContextSegment,
+    ChunkSpan,
     IndexMetadata,
     create_database_engine,
 )
 from raglite._embed import embed_sentences
-from raglite._typing import FloatMatrix +from raglite._typing import ChunkId, FloatMatrix def vector_search( query: str | FloatMatrix, *, num_results: int = 3, config: RAGLiteConfig | None = None -) -> tuple[list[str], list[float]]: +) -> tuple[list[ChunkId], list[float]]: """Search chunks using ANN vector search.""" # Read the config. config = config or RAGLiteConfig() @@ -94,7 +94,7 @@ def vector_search( def keyword_search( query: str, *, num_results: int = 3, config: RAGLiteConfig | None = None -) -> tuple[list[str], list[float]]: +) -> tuple[list[ChunkId], list[float]]: """Search chunks using BM25 keyword search.""" # Read the config. config = config or RAGLiteConfig() @@ -144,8 +144,8 @@ def keyword_search( def reciprocal_rank_fusion( - rankings: list[list[str]], *, k: int = 60 -) -> tuple[list[str], list[float]]: + rankings: list[list[ChunkId]], *, k: int = 60 +) -> tuple[list[ChunkId], list[float]]: """Reciprocal Rank Fusion.""" # Compute the RRF score. chunk_ids = {chunk_id for ranking in rankings for chunk_id in ranking} @@ -163,7 +163,7 @@ def reciprocal_rank_fusion( def hybrid_search( query: str, *, num_results: int = 3, num_rerank: int = 100, config: RAGLiteConfig | None = None -) -> tuple[list[str], list[float]]: +) -> tuple[list[ChunkId], list[float]]: """Search chunks by combining ANN vector search with BM25 keyword search.""" # Run both searches. vs_chunk_ids, _ = vector_search(query, num_results=num_rerank, config=config) @@ -174,67 +174,34 @@ def hybrid_search( return chunk_ids, hybrid_score -def retrieve_chunks(chunk_ids: list[str], *, config: RAGLiteConfig | None = None) -> list[Chunk]: +def retrieve_chunks( + chunk_ids: list[ChunkId], *, config: RAGLiteConfig | None = None +) -> list[Chunk]: """Retrieve chunks by their ids.""" config = config or RAGLiteConfig() engine = create_database_engine(config) with Session(engine) as session: - chunks = list(session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all()) + chunks = list( + session.exec( + select(Chunk) + .where(col(Chunk.id).in_(chunk_ids)) + # Eagerly load chunk.document. + .options(joinedload(Chunk.document)) # type: ignore[arg-type] + ).all() + ) chunks = sorted(chunks, key=lambda chunk: chunk_ids.index(chunk.id)) return chunks -def retrieve_segments( - chunk_ids: list[str] | list[Chunk], - *, - neighbors: tuple[int, ...] | None = (-1, 1), - config: RAGLiteConfig | None = None, -) -> list[ContextSegment]: - """Group chunks into contiguous segments and retrieve them.""" - # Retrieve the chunks. - config = config or RAGLiteConfig() - chunks: list[Chunk] = ( - retrieve_chunks(chunk_ids, config=config) # type: ignore[arg-type,assignment] - if all(isinstance(chunk_id, str) for chunk_id in chunk_ids) - else chunk_ids - ) - # Assign a reciprocal ranking score to each chunk based on its position in the original list. - chunk_id_to_score = {chunk.id: 1 / (i + 1) for i, chunk in enumerate(chunks)} - # Extend the chunks with their neighbouring chunks. - if neighbors: - engine = create_database_engine(config) - with Session(engine) as session: - neighbor_conditions = [ - and_(Chunk.document_id == chunk.document_id, Chunk.index == chunk.index + offset) - for chunk in chunks - for offset in neighbors - ] - chunks += list(session.exec(select(Chunk).where(or_(*neighbor_conditions))).all()) - # Deduplicate and sort the chunks by document_id and index (needed for groupby). - unique_chunks = sorted(set(chunks), key=lambda chunk: (chunk.document_id, chunk.index)) - # Group the chunks into contiguous segments. 
- context_segments: list[ContextSegment] = [ - ContextSegment( - document_id=doc_id, - chunks=(doc_chunks := list(group)), - chunk_scores=[chunk_id_to_score.get(chunk.id, 0.0) for chunk in doc_chunks], - ) - for doc_id, group in groupby(unique_chunks, key=attrgetter("document_id")) - ] - # Rank segments according to the aggregate relevance of their chunks. - context_segments.sort(key=methodcaller("score", scoring_function=sum), reverse=True) - return context_segments - - def rerank_chunks( - query: str, chunk_ids: list[str] | list[Chunk], *, config: RAGLiteConfig | None = None + query: str, chunk_ids: list[ChunkId] | list[Chunk], *, config: RAGLiteConfig | None = None ) -> list[Chunk]: """Rerank chunks according to their relevance to a given query.""" # Retrieve the chunks. config = config or RAGLiteConfig() chunks: list[Chunk] = ( retrieve_chunks(chunk_ids, config=config) # type: ignore[arg-type,assignment] - if all(isinstance(chunk_id, str) for chunk_id in chunk_ids) + if all(isinstance(chunk_id, ChunkId) for chunk_id in chunk_ids) else chunk_ids ) # Early exit if no reranker is configured. @@ -259,3 +226,65 @@ def rerank_chunks( results = reranker.rank(query=query, docs=[str(chunk) for chunk in chunks]) chunks = [chunks[result.doc_id] for result in results.results] return chunks + + +def retrieve_chunk_spans( + chunk_ids: list[ChunkId] | list[Chunk], + *, + neighbors: tuple[int, ...] | None = (-1, 1), + config: RAGLiteConfig | None = None, +) -> list[ChunkSpan]: + """Group chunks into contiguous chunk spans and retrieve them. + + Chunk spans are ordered according to the aggregate relevance of their underlying chunks, as + determined by the order in which they are provided to this function. + """ + # Retrieve the chunks. + config = config or RAGLiteConfig() + chunks: list[Chunk] = ( + retrieve_chunks(chunk_ids, config=config) # type: ignore[arg-type,assignment] + if all(isinstance(chunk_id, ChunkId) for chunk_id in chunk_ids) + else chunk_ids + ) + # Assign a reciprocal ranking score to each chunk based on its position in the original list. + chunk_id_to_score = {chunk.id: 1 / (i + 1) for i, chunk in enumerate(chunks)} + # Extend the chunks with their neighbouring chunks. + engine = create_database_engine(config) + with Session(engine) as session: + if neighbors: + neighbor_conditions = [ + and_(Chunk.document_id == chunk.document_id, Chunk.index == chunk.index + offset) + for chunk in chunks + for offset in neighbors + ] + chunks += list( + session.exec( + select(Chunk) + .where(or_(*neighbor_conditions)) + # Eagerly load chunk.document. + .options(joinedload(Chunk.document)) # type: ignore[arg-type] + ).all() + ) + # Deduplicate and sort the chunks by document_id and index (needed for groupby). + unique_chunks = sorted(set(chunks), key=lambda chunk: (chunk.document_id, chunk.index)) + # Group the chunks into contiguous segments. + chunk_spans: list[ChunkSpan] = [] + for _, group in groupby(unique_chunks, key=lambda chunk: chunk.document_id): + chunk_sequence: list[Chunk] = [] + for chunk in group: + if not chunk_sequence or chunk.index == chunk_sequence[-1].index + 1: + chunk_sequence.append(chunk) + else: + chunk_spans.append( + ChunkSpan(chunks=chunk_sequence, document=chunk_sequence[0].document) + ) + chunk_sequence = [chunk] + chunk_spans.append(ChunkSpan(chunks=chunk_sequence, document=chunk_sequence[0].document)) + # Rank segments according to the aggregate relevance of their chunks. 
+ chunk_spans.sort( + key=lambda chunk_span: sum( + chunk_id_to_score.get(chunk.id, 0.0) for chunk in chunk_span.chunks + ), + reverse=True, + ) + return chunk_spans diff --git a/src/raglite/_typing.py b/src/raglite/_typing.py index 07a6904..9846ecc 100644 --- a/src/raglite/_typing.py +++ b/src/raglite/_typing.py @@ -12,6 +12,11 @@ from raglite._config import RAGLiteConfig +ChunkId = str +DocumentId = str +EvalId = str +IndexId = str + FloatMatrix = np.ndarray[tuple[int, int], np.dtype[np.floating[Any]]] FloatVector = np.ndarray[tuple[int], np.dtype[np.floating[Any]]] IntVector = np.ndarray[tuple[int], np.dtype[np.intp]] diff --git a/tests/test_rag.py b/tests/test_rag.py index 3cbcf04..7643bcf 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -1,17 +1,16 @@ """Test RAGLite's RAG functionality.""" import os -from typing import TYPE_CHECKING import pytest from llama_cpp import llama_supports_gpu_offload -from raglite import RAGLiteConfig, hybrid_search, retrieve_chunks -from raglite._rag import generate, get_context_segments - -if TYPE_CHECKING: - from raglite._database import Chunk - from raglite._typing import SearchMethod +from raglite import ( + RAGLiteConfig, + create_rag_instruction, + retrieve_rag_context, +) +from raglite._rag import rag def is_accelerator_available() -> bool: @@ -22,21 +21,13 @@ def is_accelerator_available() -> bool: @pytest.mark.skipif(not is_accelerator_available(), reason="No accelerator available") def test_rag(raglite_test_config: RAGLiteConfig) -> None: """Test Retrieval-Augmented Generation.""" - # Assemble different types of search inputs for RAG. - prompt = "What does it mean for two events to be simultaneous?" - search_inputs: list[SearchMethod | list[str] | list[Chunk]] = [ - hybrid_search, # A search method as input. - hybrid_search(prompt, config=raglite_test_config)[0], # Chunk ids as input. - retrieve_chunks( # Chunks as input. - hybrid_search(prompt, config=raglite_test_config)[0], config=raglite_test_config - ), - ] # Answer a question with RAG. - for search_input in search_inputs: - segments = get_context_segments(prompt, search=search_input, config=raglite_test_config) - stream = generate(prompt, context_segments=segments, config=raglite_test_config) - answer = "" - for update in stream: - assert isinstance(update, str) - answer += update - assert "simultaneous" in answer.lower() + user_prompt = "What does it mean for two events to be simultaneous?" 
+    chunk_spans = retrieve_rag_context(query=user_prompt, config=raglite_test_config)
+    messages = [create_rag_instruction(user_prompt, context=chunk_spans)]
+    stream = rag(messages, config=raglite_test_config)
+    answer = ""
+    for update in stream:
+        assert isinstance(update, str)
+        answer += update
+    assert "simultaneous" in answer.lower()
diff --git a/tests/test_search.py b/tests/test_search.py
index 8e74fd3..e677465 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -6,11 +6,11 @@
     RAGLiteConfig,
     hybrid_search,
     keyword_search,
+    retrieve_chunk_spans,
     retrieve_chunks,
-    retrieve_segments,
     vector_search,
 )
-from raglite._database import Chunk, ContextSegment
+from raglite._database import Chunk, ChunkSpan, Document
 from raglite._typing import SearchMethod
 
 
@@ -43,9 +43,14 @@ def test_search(raglite_test_config: RAGLiteConfig, search_method: SearchMethod)
     assert all(isinstance(chunk, Chunk) for chunk in chunks)
     assert all(chunk_id == chunk.id for chunk_id, chunk in zip(chunk_ids, chunks, strict=True))
     assert any("Definition of Simultaneity" in str(chunk) for chunk in chunks)
+    assert all(isinstance(chunk.document, Document) for chunk in chunks)
     # Extend the chunks with their neighbours and group them into contiguous segments.
-    segments = retrieve_segments(chunk_ids, neighbors=(-1, 1), config=raglite_test_config)
-    assert all(isinstance(segment, ContextSegment) for segment in segments)
+    chunk_spans = retrieve_chunk_spans(chunk_ids, neighbors=(-1, 1), config=raglite_test_config)
+    assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans)
+    assert all(isinstance(chunk_span.document, Document) for chunk_span in chunk_spans)
+    chunk_spans = retrieve_chunk_spans(chunks, neighbors=(-1, 1), config=raglite_test_config)
+    assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans)
+    assert all(isinstance(chunk_span.document, Document) for chunk_span in chunk_spans)
 
 
 def test_search_no_results(raglite_test_config: RAGLiteConfig, search_method: SearchMethod) -> None:

From 60a72928851fbac56dd30c4cd9b45d1c926d3c94 Mon Sep 17 00:00:00 2001
From: Laurent Sorber
Date: Tue, 3 Dec 2024 12:51:31 +0100
Subject: [PATCH 08/12] docs: add new features to the README

---
 README.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 572aa04..4b1be88 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,8 @@ RAGLite is a Python toolkit for Retrieval-Augmented Generation (RAG) with Postgr
 - 🧬 Multi-vector chunk embedding with [late chunking](https://weaviate.io/blog/late-chunking) and [contextual chunk headings](https://d-star.ai/solving-the-out-of-context-chunk-problem-for-rag)
 - ✂️ Optimal [level 4 semantic chunking](https://medium.com/@anuragmishra_27746/five-levels-of-chunking-strategies-in-rag-notes-from-gregs-video-7b735895694d) by solving a [binary integer programming problem](https://en.wikipedia.org/wiki/Integer_programming)
 - 🔍 [Hybrid search](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) with the database's native keyword & vector search ([tsvector](https://www.postgresql.org/docs/current/datatype-textsearch.html)+[pgvector](https://github.com/pgvector/pgvector), [FTS5](https://www.sqlite.org/fts5.html)+[sqlite-vec](https://github.com/asg017/sqlite-vec)[^1])
+- 💰 Improved cost and latency with a [prompt caching-aware message array structure](https://platform.openai.com/docs/guides/prompt-caching)
+- 🍰 Improved output quality with [Anthropic's long-context prompt format](https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips)
 - 🌀 Optimal [closed-form linear query adapter](src/raglite/_query_adapter.py) by solving an [orthogonal Procrustes problem](https://en.wikipedia.org/wiki/Orthogonal_Procrustes_problem)
 
 ##### Extensible
@@ -190,7 +192,7 @@ In addition to the simple RAG pipeline, RAGLite also offers more advanced contro
 
 1. Searching for relevant chunks with keyword, vector, or hybrid search
 2. Retrieving the chunks from the database
-3. Reranking the chunks and truncating the results to the top 5
+3. Reranking the chunks and selecting the top 5 results
 4. Extending the chunks with their neighbors and grouping them into chunk spans
 5. Converting the user prompt to a RAG instruction and appending it to the message history
 6. Streaming an LLM response to the message history
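Editor's note: the six steps listed above map onto the functions introduced in this patch series. Below is a hedged sketch of that advanced pipeline; `my_config`, the database URL, and the example question are assumptions, and `rag`/`rerank_chunks` are imported from the internal modules used by the tests in this series.

```python
# A sketch of the advanced RAG pipeline described in the README hunk above.
from raglite import (
    RAGLiteConfig,
    create_rag_instruction,
    hybrid_search,
    retrieve_chunk_spans,
    retrieve_chunks,
)
from raglite._rag import rag
from raglite._search import rerank_chunks

my_config = RAGLiteConfig(db_url="sqlite:///raglite.db")  # assumed database location
user_prompt = "How is intelligence measured?"             # assumed example question

chunk_ids, _ = hybrid_search(user_prompt, num_results=20, config=my_config)      # 1. search
chunks = retrieve_chunks(chunk_ids, config=my_config)                            # 2. retrieve
chunks = rerank_chunks(user_prompt, chunks, config=my_config)[:5]                # 3. rerank, keep top 5
chunk_spans = retrieve_chunk_spans(chunks, neighbors=(-1, 1), config=my_config)  # 4. extend and group
messages = [create_rag_instruction(user_prompt, context=chunk_spans)]            # 5. build the RAG instruction
for update in rag(messages, config=my_config):                                   # 6. stream the answer
    print(update, end="")
```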
From 4a559572123b2f283dbed7afba1b068701ef570d Mon Sep 17 00:00:00 2001
From: Laurent Sorber
Date: Tue, 3 Dec 2024 13:17:17 +0100
Subject: [PATCH 09/12] fix: correct Chainlit message history

---
 src/raglite/_chainlit.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/raglite/_chainlit.py b/src/raglite/_chainlit.py
index bebd730..785a937 100644
--- a/src/raglite/_chainlit.py
+++ b/src/raglite/_chainlit.py
@@ -114,7 +114,7 @@ async def handle_message(user_message: cl.Message) -> None:
     ]
     # Stream the LLM response.
     assistant_message = cl.Message(content="")
-    messages: list[dict[str, str]] = cl.chat_context.to_openai()  # type: ignore[no-untyped-call]
+    messages: list[dict[str, str]] = cl.chat_context.to_openai()[:-1]  # type: ignore[no-untyped-call]
     messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))
     async for token in async_rag(messages, config=config):
         await assistant_message.stream_token(token)

From b450fe37a3c92a47b512a3b55822c01829c2e019 Mon Sep 17 00:00:00 2001
From: Laurent Sorber
Date: Tue, 3 Dec 2024 13:31:43 +0100
Subject: [PATCH 10/12] fix: add workaround for incorrect Chainlit step output
 order

---
 src/raglite/_chainlit.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/raglite/_chainlit.py b/src/raglite/_chainlit.py
index 785a937..1f3eeeb 100644
--- a/src/raglite/_chainlit.py
+++ b/src/raglite/_chainlit.py
@@ -103,6 +103,7 @@ async def handle_message(user_message: cl.Message) -> None:
         step.elements = [  # Show the top chunks inline.
             cl.Text(content=str(chunk), display="inline") for chunk in chunks[:5]
         ]
+        await step.update()  # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602.
     # Rerank the chunks and group them into chunk spans.
     async with cl.Step(name="rerank", type="rerank") as step:
         step.input = chunks
@@ -112,6 +113,7 @@ async def handle_message(user_message: cl.Message) -> None:
         step.elements = [  # Show the top chunk spans inline.
             cl.Text(content=str(chunk_span), display="inline") for chunk_span in chunk_spans
         ]
+        await step.update()  # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602.
     # Stream the LLM response.
     assistant_message = cl.Message(content="")
     messages: list[dict[str, str]] = cl.chat_context.to_openai()[:-1]  # type: ignore[no-untyped-call]

From 555648ba1c795eea0770861ee80a80b966990099 Mon Sep 17 00:00:00 2001
From: Laurent Sorber
Date: Tue, 3 Dec 2024 13:53:56 +0100
Subject: [PATCH 11/12] fix: replace from|to_chunk_id with from|to_chunk_index

---
 src/raglite/_database.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/raglite/_database.py b/src/raglite/_database.py
index 8c47f4a..55761b8 100644
--- a/src/raglite/_database.py
+++ b/src/raglite/_database.py
@@ -166,7 +166,7 @@ def to_xml(self, index: int | None = None) -> str:
         index_attribute = f' index="{index}"' if index is not None else ""
         xml = "\n".join(
             [
-                f'<document{index_attribute} id="{self.document.id}" from_chunk_id="{self.chunks[0].id}" to_chunk_id="{self.chunks[-1].id}">',
+                f'<document{index_attribute} id="{self.document.id}" from_chunk_index="{self.chunks[0].index}" to_chunk_index="{self.chunks[-1].index}">',
                 f"<source>{self.document.url if self.document.url else self.document.filename}</source>"
                 f"<headings>{escape(self.chunks[0].headings.strip())}</headings>"
                 f"<content>\n{escape(''.join(chunk.body for chunk in self.chunks).strip())}\n</content>",
                 "</document>",

From 91a41b0fe8508d90aa551e3965890f1aa35f8064 Mon Sep 17 00:00:00 2001
From: Laurent Sorber
Date: Tue, 3 Dec 2024 17:42:30 +0100
Subject: [PATCH 12/12] fix: avoid inline numerical citations

---
 src/raglite/_database.py | 19 +++++++++++++------
 src/raglite/_rag.py      | 10 ++++------
 src/raglite/_search.py   |  6 ++----
 3 files changed, 19 insertions(+), 16 deletions(-)

diff --git a/src/raglite/_database.py b/src/raglite/_database.py
index 55761b8..510a3bb 100644
--- a/src/raglite/_database.py
+++ b/src/raglite/_database.py
@@ -2,7 +2,7 @@
 
 import datetime
 import json
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import lru_cache
 from hashlib import sha256
 from pathlib import Path
@@ -152,7 +152,12 @@ class ChunkSpan:
     """A consecutive sequence of chunks from a single document."""
 
     chunks: list[Chunk]
-    document: Document
+    document: Document = field(init=False)
+
+    def __post_init__(self) -> None:
+        """Set the document field."""
+        if self.chunks:
+            self.document = self.chunks[0].document
 
     def to_xml(self, index: int | None = None) -> str:
         """Convert this chunk span to an XML representation.
@@ -166,10 +171,12 @@ def to_xml(self, index: int | None = None) -> str:
         index_attribute = f' index="{index}"' if index is not None else ""
         xml = "\n".join(
             [
-                f'<document{index_attribute} id="{self.document.id}" from_chunk_index="{self.chunks[0].index}" to_chunk_index="{self.chunks[-1].index}">',
-                f"<source>{self.document.url if self.document.url else self.document.filename}</source>"
-                f"<headings>{escape(self.chunks[0].headings.strip())}</headings>"
-                f"<content>\n{escape(''.join(chunk.body for chunk in self.chunks).strip())}\n</content>",
-                "</document>",
+                f'<document{index_attribute} id="{self.document.id}">',
+                f"<source>{self.document.url if self.document.url else self.document.filename}</source>",
+                f'<span from_chunk_index="{self.chunks[0].index}" to_chunk_index="{self.chunks[-1].index}">',
+                f"<headings>\n{escape(self.chunks[0].headings.strip())}\n</headings>",
+                f"<content>\n{escape(''.join(chunk.body for chunk in self.chunks).strip())}\n</content>",
+                "</span>",
+                "</document>",
             ]
         )
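Editor's note: after the `field(init=False)` change above, a `ChunkSpan` derives its document from its first chunk instead of taking it as a constructor argument. A minimal sketch, assuming a populated database and the search helpers from earlier in this series:

```python
# Sketch of the new construction contract: no document argument needed.
from raglite import RAGLiteConfig, hybrid_search, retrieve_chunks
from raglite._database import ChunkSpan

config = RAGLiteConfig(db_url="sqlite:///raglite.db")  # assumed database location
chunk_ids, _ = hybrid_search("simultaneity", num_results=3, config=config)
chunks = retrieve_chunks(chunk_ids, config=config)
span = ChunkSpan(chunks=chunks)  # document is now derived in __post_init__
assert span.document is chunks[0].document
```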
diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py
index d9db48d..8fb1a0c 100644
--- a/src/raglite/_rag.py
+++ b/src/raglite/_rag.py
@@ -15,11 +15,9 @@
 RAG_INSTRUCTION_TEMPLATE = """
 You are a friendly and knowledgeable assistant that provides complete and insightful answers.
-You MUST observe the following rules:
-1. Whenever possible, use only the provided context below to answer the question at the end.
-2. Cite your sources with inline numerical citations of the form "[n]", where n is the document index.
-   Use commas to separate citations as "[a], [b], [c]" when citing multiple sources consecutively.
-3. Do not print a list of sources at the end.
+Whenever possible, use only the provided context to respond to the question at the end.
+When responding, you MUST NOT reference the existence of the context, directly or indirectly.
+Instead, you MUST treat the context as if its contents are entirely part of your working memory.
 
 {context}
 
 {user_prompt}
@@ -63,7 +61,7 @@ def create_rag_instruction(
     message = {
         "role": "user",
         "content": rag_instruction_template.format(
-            user_prompt=user_prompt,
+            user_prompt=user_prompt.strip(),
             context="\n".join(
                 chunk_span.to_xml(index=i + 1) for i, chunk_span in enumerate(context)
             ),
diff --git a/src/raglite/_search.py b/src/raglite/_search.py
index 2a00675..b7976cb 100644
--- a/src/raglite/_search.py
+++ b/src/raglite/_search.py
@@ -275,11 +275,9 @@ def retrieve_chunk_spans(
         if not chunk_sequence or chunk.index == chunk_sequence[-1].index + 1:
             chunk_sequence.append(chunk)
         else:
-            chunk_spans.append(
-                ChunkSpan(chunks=chunk_sequence, document=chunk_sequence[0].document)
-            )
+            chunk_spans.append(ChunkSpan(chunks=chunk_sequence))
             chunk_sequence = [chunk]
-    chunk_spans.append(ChunkSpan(chunks=chunk_sequence, document=chunk_sequence[0].document))
+    chunk_spans.append(ChunkSpan(chunks=chunk_sequence))
     # Rank segments according to the aggregate relevance of their chunks.
     chunk_spans.sort(
         key=lambda chunk_span: sum(
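Editor's note: taken together, the series lands on a message-based RAG API in which the RAG instruction is appended as a user message and `rag()`/`async_rag()` truncate the oldest messages to fit the model's context window. A hedged end-to-end sketch of multi-turn usage follows; the database URL and the example questions are assumptions, and `rag` is imported from the internal module as the tests in this series do.

```python
# End-to-end sketch of the final message-based API in this patch series.
from raglite import RAGLiteConfig, create_rag_instruction, retrieve_rag_context
from raglite._rag import rag

config = RAGLiteConfig(db_url="sqlite:///raglite.db")  # assumed database and default LLM
messages: list[dict[str, str]] = []
for user_prompt in ("What is late chunking?", "How does it relate to chunk headings?"):
    chunk_spans = retrieve_rag_context(query=user_prompt, num_chunks=5, config=config)
    messages.append(create_rag_instruction(user_prompt, context=chunk_spans))
    answer = "".join(rag(messages, config=config))  # rag() drops the oldest messages if the context would overflow
    messages.append({"role": "assistant", "content": answer})
```

Because earlier messages stay byte-for-byte identical across turns, this structure also plays well with provider-side prompt caching, which is what the README's cost and latency bullet refers to.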