Skip to content

Commit

Permalink
Add framework integration for composite retrieval (#17536)
Browse files Browse the repository at this point in the history
  • Loading branch information
sourabhdesai authored Jan 21, 2025
1 parent 6904770 commit b020aa0
Show file tree
Hide file tree
Showing 9 changed files with 5,294 additions and 116 deletions.
69 changes: 69 additions & 0 deletions docs/docs/module_guides/indexing/llama_cloud_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ Currently, LlamaCloud supports
- Managed Ingestion API, handling parsing and document management
- Managed Retrieval API, configuring optimal retrieval for your RAG system

For additional documentation on LlamaCloud and this integration in particular, please reference our [official LlamaCloud docs](https://docs.cloud.llamaindex.ai/llamacloud/guides/framework_integration).

## Access

We are opening up a private beta to a limited set of enterprise partners for the managed ingestion and retrieval API. If you’re interested in centralizing your data pipelines and spending more time working on your actual RAG use cases, come [talk to us.](https://www.llamaindex.ai/contact)
Expand Down Expand Up @@ -84,3 +86,70 @@ A full list of retriever settings/kwargs is below:
- `enable_reranking`: Optional[bool] -- Whether to enable reranking or not. Sacrifices some speed for accuracy
- `rerank_top_n`: Optional[int] -- The number of nodes to return after reranking initial retrieval results
- `alpha` Optional[float] -- The weighting between dense and sparse retrieval. 1 = Full dense retrieval, 0 = Full sparse retrieval.


## Composite Retrieval Usage

Once you've set up multiple indices that are ingesting various forms of data, you may want to create an application that can query the data across all of your indices.

This is where you can use the `LlamaCloudCompositeRetriever` class. The following snippet shows you how to set up the composite retriever:

```python
import os
from llama_cloud import CompositeRetrievalMode, RetrieverPipeline
from llama_index.core import SimpleDirectoryReader
from llama_index.indices.managed.llama_cloud import (
    LlamaCloudIndex,
    LlamaCloudCompositeRetriever,
)

llama_cloud_api_key = os.environ["LLAMA_CLOUD_API_KEY"]
project_name = "Essays"

# Setup your indices
pg_documents = SimpleDirectoryReader("./examples/data/paul_graham").load_data()
pg_index = LlamaCloudIndex.from_documents(
documents=pg_documents,
name="PG Index",
project_name=project_name,
api_key=llama_cloud_api_key,
)

sama_documents = SimpleDirectoryReader(
"./examples/data/sam_altman"
).load_data()
sama_index = LlamaCloudIndex.from_documents(
documents=sama_documents,
name="Sam Index",
project_name=project_name,
api_key=llama_cloud_api_key,
)

retriever = LlamaCloudCompositeRetriever(
name="Essays Retriever",
project_name=project_name,
api_key=llama_cloud_api_key,
# If a Retriever named "Essays Retriever" doesn't already exist, one will be created
create_if_not_exists=True,
# CompositeRetrievalMode.FULL will query each index individually and globally rerank results at the end
mode=CompositeRetrievalMode.FULL,
rerank_top_n=5,
)

# Add the above indices to the composite retriever
# Carefully craft the description as this is used internally to route a query to an attached sub-index when CompositeRetrievalMode.ROUTING is used
retriever.add_index(pg_index, description="A collection of Paul Graham essays")
retriever.add_index(
sama_index, description="A collection of Sam Altman essays"
)

# Start retrieving context for your queries
# async .aretrieve() is also available
nodes = retriever.retrieve("What does YC do?")
```

### Composite Retrieval related parameters
There are a few parameters specific to tuning composite retrieval:
- `mode`: `Optional[CompositeRetrievalMode]` -- Can either be `CompositeRetrievalMode.FULL` or `CompositeRetrievalMode.ROUTING`
- `full`: In this mode, all attached sub-indices will be queried and reranking will be executed across all nodes received from these sub-indices.
- `routing`: In this mode, an agent determines which sub-indices are most relevant to the provided query (based on the sub-index's `name` & `description` you've provided) and only queries those indices that are deemed relevant. Only the nodes from that chosen subset of indices are then reranked before being returned in the retrieval response.
- `rerank_top_n`: `Optional[int]` -- Determines how many nodes to return after re-ranking across the nodes retrieved from all indices
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from llama_index.indices.managed.llama_cloud.base import LlamaCloudIndex
from llama_index.indices.managed.llama_cloud.retriever import LlamaCloudRetriever
from llama_index.indices.managed.llama_cloud.composite_retriever import (
LlamaCloudCompositeRetriever,
)

# Public API of the LlamaCloud managed-index integration package.
__all__ = [
    "LlamaCloudIndex",
    "LlamaCloudRetriever",
    "LlamaCloudCompositeRetriever",
]
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import urllib.parse
from httpx import Request

from llama_index.core.async_utils import run_jobs
from llama_cloud import (
AutoTransformConfig,
Pipeline,
Expand All @@ -9,7 +12,9 @@
PipelineType,
Project,
)
from llama_cloud.client import LlamaCloud
from llama_cloud.core import remove_none_from_dict
from llama_cloud.client import LlamaCloud, AsyncLlamaCloud
from llama_cloud.core.api_error import ApiError


def default_embedding_config() -> PipelineCreateEmbeddingConfig:
Expand Down Expand Up @@ -96,3 +101,104 @@ def resolve_project_and_pipeline(
)

return project, pipeline


def _build_get_page_screenshot_request(
    client: Union[LlamaCloud, AsyncLlamaCloud],
    file_id: str,
    page_index: int,
    project_id: str,
) -> Request:
    """Build (but do not send) the GET request for a file page screenshot.

    Shared by the sync and async fetch helpers; the caller is responsible
    for sending the request with the matching httpx client.
    """
    wrapper = client._client_wrapper
    # Trailing slash on the base URL is required so urljoin appends the
    # endpoint path instead of replacing the last path segment.
    endpoint = urllib.parse.urljoin(
        f"{wrapper.get_base_url()}/",
        f"api/v1/files/{file_id}/page_screenshots/{page_index}",
    )
    return wrapper.httpx_client.build_request(
        "GET",
        endpoint,
        params=remove_none_from_dict({"project_id": project_id}),
        headers=wrapper.get_headers(),
        timeout=60,
    )


def get_page_screenshot(
    client: LlamaCloud, file_id: str, page_index: int, project_id: str
) -> bytes:
    """Fetch the screenshot of one page of a file.

    Args:
        client: Synchronous LlamaCloud client used to send the request.
        file_id: ID of the file whose page is being fetched.
        page_index: Index of the page to fetch.
        project_id: ID of the project the file belongs to.

    Returns:
        The raw screenshot image bytes (``httpx`` response ``.content``).

    Raises:
        ApiError: If the API responds with a non-2xx status code.
    """
    request = _build_get_page_screenshot_request(
        client, file_id, page_index, project_id
    )
    _response = client._client_wrapper.httpx_client.send(request)
    if 200 <= _response.status_code < 300:
        # httpx Response.content is bytes; callers base64-encode it.
        return _response.content
    else:
        raise ApiError(status_code=_response.status_code, body=_response.text)


async def aget_page_screenshot(
    client: AsyncLlamaCloud, file_id: str, page_index: int, project_id: str
) -> bytes:
    """Fetch the screenshot of one page of a file (async).

    Args:
        client: Asynchronous LlamaCloud client used to send the request.
        file_id: ID of the file whose page is being fetched.
        page_index: Index of the page to fetch.
        project_id: ID of the project the file belongs to.

    Returns:
        The raw screenshot image bytes (``httpx`` response ``.content``).

    Raises:
        ApiError: If the API responds with a non-2xx status code.
    """
    request = _build_get_page_screenshot_request(
        client, file_id, page_index, project_id
    )
    _response = await client._client_wrapper.httpx_client.send(request)
    if 200 <= _response.status_code < 300:
        # httpx Response.content is bytes; callers base64-encode it.
        return _response.content
    else:
        raise ApiError(status_code=_response.status_code, body=_response.text)


from typing import List
import base64
from llama_cloud import PageScreenshotNodeWithScore
from llama_index.core.schema import NodeWithScore, ImageNode
from llama_cloud.client import LlamaCloud, AsyncLlamaCloud


def image_nodes_to_node_with_score(
    client: LlamaCloud,
    raw_image_nodes: List[PageScreenshotNodeWithScore],
    project_id: str,
) -> List[NodeWithScore]:
    """Convert raw page-screenshot nodes into scored ``ImageNode`` results.

    Fetches each page screenshot synchronously, base64-encodes the image
    bytes, and preserves the original retrieval score on each node.
    """
    scored_nodes: List[NodeWithScore] = []
    for raw_node in raw_image_nodes:
        screenshot = get_page_screenshot(
            client=client,
            file_id=raw_node.node.file_id,
            page_index=raw_node.node.page_index,
            project_id=project_id,
        )
        encoded = base64.b64encode(screenshot).decode("utf-8")
        scored_nodes.append(
            NodeWithScore(node=ImageNode(image=encoded), score=raw_node.score)
        )
    return scored_nodes


async def aimage_nodes_to_node_with_score(
    client: AsyncLlamaCloud,
    raw_image_nodes: List[PageScreenshotNodeWithScore],
    project_id: str,
) -> List[NodeWithScore]:
    """Async variant of ``image_nodes_to_node_with_score``.

    All page screenshots are fetched concurrently via ``run_jobs``, then
    base64-encoded and paired back with their original retrieval scores.
    """
    screenshot_jobs = [
        aget_page_screenshot(
            client=client,
            file_id=raw_node.node.file_id,
            page_index=raw_node.node.page_index,
            project_id=project_id,
        )
        for raw_node in raw_image_nodes
    ]
    all_image_bytes = await run_jobs(screenshot_jobs)

    scored_nodes: List[NodeWithScore] = []
    # run_jobs preserves input order, so zip keeps bytes/node pairs aligned.
    for img_bytes, raw_node in zip(all_image_bytes, raw_image_nodes):
        encoded = base64.b64encode(img_bytes).decode("utf-8")
        scored_nodes.append(
            NodeWithScore(node=ImageNode(image=encoded), score=raw_node.score)
        )
    return scored_nodes
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ def __init__(
self._service_context = None
self._callback_manager = callback_manager or Settings.callback_manager

    @property
    def id(self) -> str:
        """Return the pipeline (aka index) ID.

        NOTE(review): delegates to ``self.pipeline.id`` — the pipeline is
        the LlamaCloud-side object backing this managed index.
        """
        return self.pipeline.id

def wait_for_completion(
self,
verbose: bool = False,
Expand Down
Loading

0 comments on commit b020aa0

Please sign in to comment.