Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Qdrant 2.0 #742

Merged
merged 1 commit into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions superagi/vector_store/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,18 +213,19 @@ def __build_documents(
def create_collection(cls,
client: QdrantClient,
collection_name: str,
vector_params: VectorParams = VectorParams(size=1536, distance=Distance.COSINE)
size: int
):
"""
Create a new collection in Qdrant if it does not exist.

Args:
client : The Qdrant client.
collection_name: The name of the collection to create.
vector_params: The vector parameters for the new collection.
size: The size for the new collection.
"""
if not any(collection.name == collection_name for collection in client.get_collections().collections):
print("here")
client.create_collection(
collection_name=collection_name,
vectors_config=vector_params,
)
vectors_config=VectorParams(size=size, distance=Distance.COSINE),
)
13 changes: 9 additions & 4 deletions superagi/vector_store/vector_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def get_vector_storage(cls, vector_store: VectorStoreType, index_name, embedding
Returns:
The vector storage object.
"""
vector_store = VectorStoreType.get_vector_store_type(vector_store)
if isinstance(vector_store, str):
vector_store = VectorStoreType.get_vector_store_type(vector_store)
if vector_store == VectorStoreType.PINECONE:
try:
api_key = get_config("PINECONE_API_KEY")
Expand All @@ -51,7 +52,7 @@ def get_vector_storage(cls, vector_store: VectorStoreType, index_name, embedding
except UnauthorizedException:
raise ValueError("PineCone API key not found")

if vector_store == "Weaviate":
if vector_store == VectorStoreType.WEAVIATE:
use_embedded = get_config("WEAVIATE_USE_EMBEDDED")
url = get_config("WEAVIATE_URL")
api_key = get_config("WEAVIATE_API_KEY")
Expand All @@ -65,7 +66,11 @@ def get_vector_storage(cls, vector_store: VectorStoreType, index_name, embedding

if vector_store == VectorStoreType.QDRANT:
client = qdrant.create_qdrant_client()
Qdrant.create_collection(client, index_name)
sample_embedding = embedding_model.get_embedding("sample")
if "error" in sample_embedding:
logger.error(f"Error in embedding model {sample_embedding}")

Qdrant.create_collection(client, index_name, len(sample_embedding))
return qdrant.Qdrant(client, embedding_model, index_name)

raise ValueError(f"Vector store {vector_store} not supported")
raise ValueError(f"Vector store {vector_store} not supported")
Empty file.
67 changes: 67 additions & 0 deletions tests/unit_tests/vector_store/test_vector_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import unittest
from unittest.mock import patch, MagicMock
from superagi.types.vector_store_types import VectorStoreType
from superagi.vector_store.pinecone import Pinecone
from superagi.vector_store.weaviate import Weaviate
from superagi.vector_store.qdrant import Qdrant
from superagi.vector_store.vector_factory import VectorFactory
import pinecone
import weaviate


class MockPineconeIndex(pinecone.index.Index):
pass


class MockWeaviate(Weaviate):
pass


class MockQdrant(Qdrant):
pass


class TestVectorFactory(unittest.TestCase):

@patch('superagi.vector_store.vector_factory.get_config')
@patch('superagi.vector_store.vector_factory.pinecone')
@patch('superagi.vector_store.vector_factory.weaviate')
@patch('superagi.vector_store.vector_factory.Qdrant')
def test_get_vector_storage(self, mock_qdrant, mock_weaviate, mock_pinecone, mock_get_config):
mock_get_config.return_value = 'test'
mock_embedding_model = MagicMock()
mock_embedding_model.get_embedding.return_value = [0.1, 0.2, 0.3]

# Mock Pinecone index
mock_pinecone_index = MockPineconeIndex('test_index')
mock_pinecone.Index.return_value = mock_pinecone_index

# Test Pinecone
mock_pinecone.list_indexes.return_value = ['test_index']
vector_store = VectorFactory.get_vector_storage(VectorStoreType.PINECONE, 'test_index', mock_embedding_model)
self.assertIsInstance(vector_store, Pinecone)

# Mock Weaviate client
mock_weaviate_client = MagicMock()
mock_weaviate.create_weaviate_client.return_value = mock_weaviate_client
mock_weaviate.Weaviate = MockWeaviate

# Test Weaviate
vector_store = VectorFactory.get_vector_storage('Weaviate', 'test_index', mock_embedding_model)
self.assertIsInstance(vector_store, Weaviate)

# Test Qdrant
mock_qdrant_client = MagicMock()
mock_qdrant.create_qdrant_client.return_value = mock_qdrant_client
mock_qdrant.Qdrant = MockQdrant

vector_store = VectorFactory.get_vector_storage(VectorStoreType.QDRANT, 'test_index', mock_embedding_model)
self.assertIsInstance(vector_store, Qdrant)

# Test unsupported vector store
with self.assertRaises(ValueError):
VectorFactory.get_vector_storage('Unsupported', 'test_index', mock_embedding_model)


if __name__ == '__main__':
unittest.main()