From f5856680fe55ca1ee515d8ece7aa3032413eb38e Mon Sep 17 00:00:00 2001
From: Philippe PRADOS
Date: Fri, 19 Jul 2024 19:54:12 +0200
Subject: [PATCH] community[minor]: add mongodb byte store (#23876)

The `MongoDBStore` can manage only documents.
It's not possible to use MongoDB for a `CacheBackedEmbeddings`.

With this new implementation, it's possible to use:
```python
CacheBackedEmbeddings.from_bytes_store(
    underlying_embeddings=embeddings,
    document_embedding_cache=MongoDBByteStore(
        connection_string=db_uri,
        db_name=db_name,
        collection_name=collection_name,
    ),
)
```
and use MongoDB to cache the embeddings!
---
 .../langchain_community/storage/__init__.py   |   6 +-
 .../langchain_community/storage/mongodb.py    | 124 +++++++++++++++++-
 .../integration_tests/storage/test_mongodb.py |  21 ++-
 .../tests/unit_tests/storage/test_imports.py  |   1 +
 4 files changed, 146 insertions(+), 6 deletions(-)

diff --git a/libs/community/langchain_community/storage/__init__.py b/libs/community/langchain_community/storage/__init__.py
index d75b497bf584f..21a6090bd143d 100644
--- a/libs/community/langchain_community/storage/__init__.py
+++ b/libs/community/langchain_community/storage/__init__.py
@@ -25,9 +25,7 @@
     from langchain_community.storage.cassandra import (
         CassandraByteStore,
     )
-    from langchain_community.storage.mongodb import (
-        MongoDBStore,
-    )
+    from langchain_community.storage.mongodb import MongoDBByteStore, MongoDBStore
     from langchain_community.storage.redis import (
         RedisStore,
     )
@@ -44,6 +42,7 @@
     "AstraDBStore",
     "CassandraByteStore",
     "MongoDBStore",
+    "MongoDBByteStore",
     "RedisStore",
     "SQLStore",
     "UpstashRedisByteStore",
@@ -55,6 +54,7 @@
     "AstraDBStore": "langchain_community.storage.astradb",
     "CassandraByteStore": "langchain_community.storage.cassandra",
     "MongoDBStore": "langchain_community.storage.mongodb",
+    "MongoDBByteStore": "langchain_community.storage.mongodb",
     "RedisStore": "langchain_community.storage.redis",
     "SQLStore": "langchain_community.storage.sql",
     "UpstashRedisByteStore": "langchain_community.storage.upstash_redis",
diff --git a/libs/community/langchain_community/storage/mongodb.py b/libs/community/langchain_community/storage/mongodb.py
index 97447f832175e..264faefcba649 100644
--- a/libs/community/langchain_community/storage/mongodb.py
+++ b/libs/community/langchain_community/storage/mongodb.py
@@ -4,6 +4,126 @@
 from langchain_core.stores import BaseStore
 
 
+class MongoDBByteStore(BaseStore[str, bytes]):
+    """ByteStore implementation using MongoDB as the underlying store.
+
+    Examples:
+        Create a MongoDBByteStore instance and perform operations on it:
+
+        .. code-block:: python
+
+            # Instantiate the MongoDBByteStore with a MongoDB connection
+            from langchain_community.storage import MongoDBByteStore
+
+            mongo_conn_str = "mongodb://localhost:27017/"
+            mongodb_store = MongoDBByteStore(mongo_conn_str, db_name="test-db",
+                                             collection_name="test-collection")
+
+            # Set values for keys
+            mongodb_store.mset([("key1", b"hello"), ("key2", b"world")])
+
+            # Get values for keys
+            values = mongodb_store.mget(["key1", "key2"])
+            # [b"hello", b"world"]
+
+            # Iterate over keys
+            for key in mongodb_store.yield_keys():
+                print(key)
+
+            # Delete keys
+            mongodb_store.mdelete(["key1", "key2"])
+    """
+
+    def __init__(
+        self,
+        connection_string: str,
+        db_name: str,
+        collection_name: str,
+        *,
+        client_kwargs: Optional[dict] = None,
+    ) -> None:
+        """Initialize the MongoDBByteStore with a MongoDB connection string.
+
+        Args:
+            connection_string (str): MongoDB connection string
+            db_name (str): name of the database to use
+            collection_name (str): collection name to use
+            client_kwargs (dict): Keyword arguments to pass to the Mongo client
+        """
+        try:
+            from pymongo import MongoClient
+        except ImportError as e:
+            raise ImportError(
+                "The MongoDBByteStore requires the pymongo library to be "
+                "installed. "
+                "pip install pymongo"
+            ) from e
+
+        if not connection_string:
+            raise ValueError("connection_string must be provided.")
+        if not db_name:
+            raise ValueError("db_name must be provided.")
+        if not collection_name:
+            raise ValueError("collection_name must be provided.")
+
+        self.client: MongoClient = MongoClient(
+            connection_string, **(client_kwargs or {})
+        )
+        self.collection = self.client[db_name][collection_name]
+
+    def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
+        """Get the values associated with the given keys.
+
+        Args:
+            keys (list[str]): A list of keys to look up.
+
+        Returns:
+            list[Optional[bytes]]: The values in the same order as ``keys``,
+                with None for every key that is not found.
+        """
+        result = self.collection.find({"_id": {"$in": list(keys)}})
+        result_dict = {doc["_id"]: doc["value"] for doc in result}
+        return [result_dict.get(key) for key in keys]
+
+    def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
+        """Set the given key-value pairs, overwriting existing keys.
+
+        Args:
+            key_value_pairs (list[tuple[str, bytes]]): Key/value pairs to store.
+        """
+        from pymongo import UpdateOne
+
+        operations = [
+            UpdateOne({"_id": k}, {"$set": {"_id": k, "value": v}}, upsert=True)
+            for k, v in key_value_pairs
+        ]
+        if operations:  # bulk_write raises InvalidOperation on an empty list
+            self.collection.bulk_write(operations)
+
+    def mdelete(self, keys: Sequence[str]) -> None:
+        """Delete the given keys and their values.
+
+        Args:
+            keys (list[str]): A list of keys to delete; unknown keys are ignored.
+        """
+        self.collection.delete_many({"_id": {"$in": list(keys)}})
+
+    def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
+        """Yield keys in the store.
+
+        Args:
+            prefix (str): if given, only keys starting with it are yielded.
+        """
+        import re
+
+        query: dict = {}
+        if prefix is not None:
+            # Escape the prefix so regex metacharacters in it match literally.
+            query["_id"] = {"$regex": f"^{re.escape(prefix)}"}
+        for doc in self.collection.find(query, projection=["_id"]):
+            yield doc["_id"]
+
+
 class MongoDBStore(BaseStore[str, Document]):
     """BaseStore implementation using MongoDB as the underlying store.
 
@@ -68,7 +188,9 @@ def __init__(
         if not collection_name:
             raise ValueError("collection_name must be provided.")
 
-        self.client = MongoClient(connection_string, **(client_kwargs or {}))
+        self.client: MongoClient = MongoClient(
+            connection_string, **(client_kwargs or {})
+        )
         self.collection = self.client[db_name][collection_name]
 
     def mget(self, keys: Sequence[str]) -> List[Optional[Document]]:
diff --git a/libs/community/tests/integration_tests/storage/test_mongodb.py b/libs/community/tests/integration_tests/storage/test_mongodb.py
index 44062e994e051..850e4f73c2b47 100644
--- a/libs/community/tests/integration_tests/storage/test_mongodb.py
+++ b/libs/community/tests/integration_tests/storage/test_mongodb.py
@@ -1,9 +1,10 @@
-from typing import Generator
+from typing import Generator, Tuple
 
 import pytest
 from langchain_core.documents import Document
+from langchain_standard_tests.integration_tests.base_store import BaseStoreSyncTests
 
-from langchain_community.storage.mongodb import MongoDBStore
+from langchain_community.storage.mongodb import MongoDBByteStore, MongoDBStore
 
 pytest.importorskip("pymongo")
 
@@ -71,3 +72,19 @@ def test_mdelete(mongo_store: MongoDBStore) -> None:
 def test_init_errors() -> None:
     with pytest.raises(ValueError):
         MongoDBStore("", "", "")
+
+
+class TestMongoDBByteStore(BaseStoreSyncTests):
+    @pytest.fixture
+    def three_values(self) -> Tuple[bytes, bytes, bytes]:  # <-- Provide 3 values
+        return b"foo", b"bar", b"buzz"
+
+    @pytest.fixture
+    def kv_store(self) -> MongoDBByteStore:
+        import mongomock
+
+        # mongomock creates a mock MongoDB instance for testing purposes
+        with mongomock.patch(servers=(("localhost", 27017),)):
+            return MongoDBByteStore(
+                "mongodb://localhost:27017/", "test_db", "test_collection"
+            )
diff --git a/libs/community/tests/unit_tests/storage/test_imports.py b/libs/community/tests/unit_tests/storage/test_imports.py
index 791f0298cc5e7..2501be0ce17f2 100644
--- a/libs/community/tests/unit_tests/storage/test_imports.py
+++ b/libs/community/tests/unit_tests/storage/test_imports.py
@@ -4,6 +4,7 @@
     "AstraDBStore",
     "AstraDBByteStore",
     "CassandraByteStore",
+    "MongoDBByteStore",
     "MongoDBStore",
     "SQLStore",
     "RedisStore",