Refactored ChromadbRM to use generic EmbeddingFunction #400

Merged (3 commits) on Feb 17, 2024
65 changes: 64 additions & 1 deletion docs/retrieval_models_client.md
@@ -8,6 +8,7 @@ This documentation provides an overview of the DSPy Retrieval Model Clients.
| --- | --- |
| ColBERTv2 | [ColBERTv2 Section](#ColBERTv2) |
| AzureCognitiveSearch | [AzureCognitiveSearch Section](#AzureCognitiveSearch) |
| ChromadbRM | [ChromadbRM Section](#ChromadbRM) |

## ColBERTv2

@@ -91,4 +92,66 @@ class AzureCognitiveSearch:

Refer to [ColBERTv2](#ColBERTv2) documentation. Keep in mind there is no `simplify` flag for AzureCognitiveSearch.

AzureCognitiveSearch supports sending queries and processing the received results, mapping content and scores to a correct format for the Azure Cognitive Search server.

## ChromadbRM

### Quickstart with OpenAI Embeddings

ChromadbRM offers the flexibility to use any of a variety of embedding functions, as outlined in the [chromadb embeddings documentation](https://docs.trychroma.com/embeddings). While several options are available, this example demonstrates how to use OpenAI embeddings specifically.

```python
from dspy.retrieve import ChromadbRM
import os
import openai
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction

embedding_function = OpenAIEmbeddingFunction(
    api_key=os.environ.get('OPENAI_API_KEY'),
    model_name="text-embedding-ada-002"
)

retriever_model = ChromadbRM(
    'your_collection_name',
    '/path/to/your/db',
    embedding_function=embedding_function,
    k=5
)

results = retriever_model("Explore the significance of quantum computing", k=5)

for result in results:
    print("Document:", result.long_text, "\n")
```
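Each returned result is a `dotdict` exposing the passage text under `long_text`. A minimal sketch of post-processing retrieved passages (deduplicating before stuffing them into a prompt), using `SimpleNamespace` stand-ins for the returned objects and hypothetical passage texts:

```python
from types import SimpleNamespace

# Stand-ins for the dotdict results ChromadbRM returns (hypothetical texts).
results = [
    SimpleNamespace(long_text="Quantum computing uses qubits."),
    SimpleNamespace(long_text="Qubits can be superposed."),
    SimpleNamespace(long_text="Quantum computing uses qubits."),  # duplicate
]

# Deduplicate while preserving retrieval order.
seen = set()
passages = []
for r in results:
    if r.long_text not in seen:
        seen.add(r.long_text)
        passages.append(r.long_text)

print(len(passages))  # 2 unique passages
```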

### Constructor

Initialize an instance of the `ChromadbRM` class, with the option to use OpenAI's embeddings or any alternative supported by chromadb, as detailed in the official [chromadb embeddings documentation](https://docs.trychroma.com/embeddings).

```python
ChromadbRM(
    collection_name: str,
    persist_directory: str,
    embedding_function: Optional[EmbeddingFunction[Embeddable]] = ef.DefaultEmbeddingFunction(),
    k: int = 7,
)
```

**Parameters:**
- `collection_name` (_str_): The name of the chromadb collection.
- `persist_directory` (_str_): Path to the directory where chromadb data is persisted.
- `embedding_function` (_Optional[EmbeddingFunction[Embeddable]]_, _optional_): The function used for embedding documents and queries. Defaults to `DefaultEmbeddingFunction()` if not specified.
- `k` (_int_, _optional_): The number of top passages to retrieve. Defaults to 7.
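The `embedding_function` only needs to be a callable that maps a batch of strings to a batch of vectors. A minimal sketch of that contract and the default-fallback pattern, where `default_ef` and `RetrieverSketch` are toy stand-ins (not chromadb's actual `DefaultEmbeddingFunction` or the real `ChromadbRM` class):

```python
from typing import Callable, List, Optional

# Toy stand-in for a real embedding function (illustration only):
# embeds each text as [length, 1.0].
def default_ef(input: List[str]) -> List[List[float]]:
    return [[float(len(text)), 1.0] for text in input]

class RetrieverSketch:
    """Mirrors the constructor contract: accept any callable embedder."""

    def __init__(
        self,
        embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = None,
        k: int = 7,
    ):
        # Fall back to a default when no function is supplied.
        self.ef = embedding_function or default_ef
        self.k = k

retriever = RetrieverSketch(k=5)
vectors = retriever.ef(["some passage", "another passage"])
print(len(vectors), retriever.k)  # 2 5
```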

### Methods

#### `forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None) -> dspy.Prediction`

Search the chromadb collection for the top `k` passages matching the given query or queries, using embeddings generated via the specified `embedding_function`.

**Parameters:**
- `query_or_queries` (_Union[str, List[str]]_): The query or list of queries to search for.
- `k` (_Optional[int]_, _optional_): The number of results to retrieve. If not specified, defaults to the value set during initialization.

**Returns:**
- `dspy.Prediction`: Contains the retrieved passages, each represented as a `dotdict` with a `long_text` attribute.
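Because `ChromadbRM` now accepts any chromadb-style `EmbeddingFunction`, a custom one can be plugged in. The class below is a deterministic toy (it hashes character codes into buckets, producing nothing semantically meaningful) that only illustrates the expected call shape; note that recent chromadb versions name the call parameter `input`, which may differ across versions:

```python
from typing import List

class ToyHashEmbeddingFunction:
    """Deterministic toy embeddings -- illustration of the protocol only."""

    def __init__(self, dim: int = 8):
        self.dim = dim

    def __call__(self, input: List[str]) -> List[List[float]]:
        embeddings = []
        for text in input:
            # Bucket character codes into `dim` slots, then L2-normalize.
            vec = [0.0] * self.dim
            for i, ch in enumerate(text):
                vec[i % self.dim] += ord(ch)
            norm = sum(v * v for v in vec) ** 0.5 or 1.0
            embeddings.append([v / norm for v in vec])
        return embeddings

ef = ToyHashEmbeddingFunction(dim=4)
vectors = ef(["hello", "world"])
print(len(vectors), len(vectors[0]))  # 2 4
```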
45 changes: 12 additions & 33 deletions dspy/retrieve/chromadb_rm.py
@@ -18,6 +18,11 @@
import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions
from chromadb.api.types import (
    Embeddable,
    EmbeddingFunction
)
import chromadb.utils.embedding_functions as ef
except ImportError:
chromadb = None

@@ -37,10 +42,8 @@ class ChromadbRM(dspy.Retrieve):
     Args:
         collection_name (str): chromadb collection name
         persist_directory (str): chromadb persist directory
-        openai_embed_model (str, optional): The OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
-        openai_api_key (str, optional): The API key for OpenAI. Defaults to None.
-        openai_org (str, optional): The organization for OpenAI. Defaults to None.
-        k (int, optional): The number of top passages to retrieve. Defaults to 3.
+        embedding_function (Optional[EmbeddingFunction[Embeddable]]): Optional function to use to embed documents. Defaults to DefaultEmbeddingFunction.
+        k (int, optional): The number of top passages to retrieve. Defaults to 7.
 
     Returns:
         dspy.Prediction: An object containing the retrieved passages.
@@ -65,29 +68,14 @@ def __init__(
         self,
         collection_name: str,
         persist_directory: str,
-        openai_embed_model: str = "text-embedding-ada-002",
-        openai_api_provider: Optional[str] = None,
-        openai_api_key: Optional[str] = None,
-        openai_api_type: Optional[str] = None,
-        openai_api_base: Optional[str] = None,
-        openai_api_version: Optional[str] = None,
+        embedding_function: Optional[
+            EmbeddingFunction[Embeddable]
+        ] = ef.DefaultEmbeddingFunction(),
         k: int = 7,
     ):
-        self._openai_embed_model = openai_embed_model
-
         self._init_chromadb(collection_name, persist_directory)
 
-        self.openai_ef = embedding_functions.OpenAIEmbeddingFunction(
-            api_key=openai_api_key,
-            api_base=openai_api_base,
-            api_type=openai_api_type,
-            api_version=openai_api_version,
-            model_name=openai_embed_model,
-        )
-        self.api_version = openai_api_version
-        self.api_base = openai_api_base
-        self.model_name = openai_embed_model
-        self.openai_api_type = openai_api_type
+        self.ef = embedding_function
 
         super().__init__(k=k)

@@ -130,16 +118,7 @@ def _get_embeddings(self, queries: List[str]) -> List[List[float]]:
         Returns:
             List[List[float]]: List of embeddings corresponding to each query.
         """
-
-        model_arg = {"engine": self.model_name,
-                     "deployment_id": self.model_name,
-                     "api_version": self.api_version,
-                     "api_base": self.api_base,
-                     }
-        embedding = self.openai_ef._client.create(
-            input=queries, model=self._openai_embed_model, **model_arg, api_provider=self.openai_api_type
-        )
-        return [embedding.embedding for embedding in embedding.data]
+        return self.ef(queries)
 
     def forward(
         self, query_or_queries: Union[str, List[str]], k: Optional[int] = None