From 76a7539581802bf2fe63a743a00cd114a615aaae Mon Sep 17 00:00:00 2001 From: Igor Ilic <30923996+dexters1@users.noreply.github.com> Date: Sat, 11 Jan 2025 05:03:06 +0100 Subject: [PATCH] Cognee integration update (#17482) --- .../llama-index-graph-rag-cognee/README.md | 26 ++++++- .../llama_index/graph_rag/cognee/graph_rag.py | 78 ++++++++++++++++--- .../pyproject.toml | 6 +- .../tests/test_add_data.py | 3 +- .../tests/test_get_graph_url.py | 2 +- .../tests/test_graph_rag_cognee.py | 15 +++- 6 files changed, 105 insertions(+), 25 deletions(-) diff --git a/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/README.md b/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/README.md index dfcce56ed606d..89312829970c9 100644 --- a/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/README.md +++ b/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/README.md @@ -40,7 +40,7 @@ async def example_graph_rag_cognee(): graph_db_provider="networkx", vector_db_provider="lancedb", relational_db_provider="sqlite", - db_name="cognee_db", + relational_db_name="cognee_db", ) # Add data to cognee @@ -50,12 +50,22 @@ async def example_graph_rag_cognee(): await cogneeRAG.process_data("test") # Answer prompt based on knowledge graph - search_results = await cogneeRAG.search("person") - print("\n\nExtracted sentences are:\n") + search_results = await cogneeRAG.search( + "Tell me who are the people mentioned?" + ) + print("\n\nAnswer based on knowledge graph:\n") + for result in search_results: + print(f"{result}\n") + + # Answer prompt based on RAG + search_results = await cogneeRAG.rag_search( + "Tell me who are the people mentioned?" 
+ ) + print("\n\nAnswer based on RAG:\n") for result in search_results: print(f"{result}\n") - # Search for related nodes + # Search for related nodes in graph search_results = await cogneeRAG.get_related_nodes("person") print("\n\nRelated nodes are:\n") for result in search_results: @@ -65,3 +75,11 @@ async def example_graph_rag_cognee(): if __name__ == "__main__": asyncio.run(example_graph_rag_cognee()) ``` + +## Supported databases + +**Relational databases:** SQLite, PostgreSQL + +**Vector databases:** LanceDB, PGVector, QDrant, Weaviate + +**Graph databases:** Neo4j, NetworkX diff --git a/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/llama_index/graph_rag/cognee/graph_rag.py b/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/llama_index/graph_rag/cognee/graph_rag.py index 6b4e91250ab62..96b01a2501b06 100644 --- a/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/llama_index/graph_rag/cognee/graph_rag.py +++ b/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/llama_index/graph_rag/cognee/graph_rag.py @@ -17,13 +17,25 @@ class CogneeGraphRAG(GraphRAG): This enables the system to retrieve more precise and structured information about an entity, its relationships, and its properties. Attributes: - llm_api_key: str: Api key for desired llm. - llm_provider: str: Provider for desired llm. - llm_model: str: Model for desired llm. - graph_db_provider: str: The graph database provider. - vector_db_provider: str: The vector database provider. - relational_db_provider: str: The relational database provider. - db_name: str: The name of the databases. + llm_api_key: str: API key for desired LLM. + llm_provider: str: Provider for desired LLM (default: "openai"). + llm_model: str: Model for desired LLM (default: "gpt-4o-mini"). + graph_db_provider: str: The graph database provider (default: "networkx"). + Supported providers: "neo4j", "networkx". + graph_database_url: str: URL for the graph database. 
+ graph_database_username: str: Username for accessing the graph database. + graph_database_password: str: Password for accessing the graph database. + vector_db_provider: str: The vector database provider (default: "lancedb"). + Supported providers: "lancedb", "pgvector", "qdrant", "weaviate". + vector_db_url: str: URL for the vector database. + vector_db_key: str: API key for accessing the vector database. + relational_db_provider: str: The relational database provider (default: "sqlite"). + Supported providers: "sqlite", "postgres". + relational_db_name: str: The name of the relational database (default: "cognee_db"). + relational_db_host: str: Host for the relational database. + relational_db_port: str: Port for the relational database. + relational_db_username: str: Username for the relational database. + relational_db_password: str: Password for the relational database. """ def __init__( @@ -32,9 +44,18 @@ def __init__( llm_provider: str = "openai", llm_model: str = "gpt-4o-mini", graph_db_provider: str = "networkx", + graph_database_url: str = "", + graph_database_username: str = "", + graph_database_password: str = "", vector_db_provider: str = "lancedb", + vector_db_url: str = "", + vector_db_key: str = "", relational_db_provider: str = "sqlite", - db_name: str = "cognee_db", + relational_db_name: str = "cognee_db", + relational_db_host: str = "", + relational_db_port: str = "", + relational_db_username: str = "", + relational_db_password: str = "", ) -> None: cognee.config.set_llm_config( { @@ -44,11 +65,33 @@ def __init__( } ) - cognee.config.set_vector_db_config({"vector_db_provider": vector_db_provider}) + cognee.config.set_vector_db_config( + { + "vector_db_url": vector_db_url, + "vector_db_key": vector_db_key, + "vector_db_provider": vector_db_provider, + } + ) cognee.config.set_relational_db_config( - {"db_provider": relational_db_provider, "db_name": db_name} + { + "db_path": "", + "db_name": relational_db_name, + "db_host": relational_db_host, + "db_port": relational_db_port, + "db_username": 
relational_db_username, + "db_password": relational_db_password, + "db_provider": relational_db_provider, + } + ) + + cognee.config.set_graph_db_config( + { + "graph_database_provider": graph_db_provider, + "graph_database_url": graph_database_url, + "graph_database_username": graph_database_username, + "graph_database_password": graph_database_password, + } ) - cognee.config.set_graph_database_provider(graph_db_provider) data_directory_path = str( pathlib.Path( @@ -119,6 +162,17 @@ async def get_graph_url(self, graphistry_password, graphistry_username) -> str: print(graph_url) return graph_url + async def rag_search(self, query: str) -> list: + """Answer query based on data chunk most relevant to query. + + Args: + query (str): The query string. + """ + user = await cognee.modules.users.methods.get_default_user() + return await cognee.search( + cognee.api.v1.search.SearchType.COMPLETION, query, user + ) + async def search(self, query: str) -> list: """Search the graph for relevant information based on a query. 
@@ -127,7 +181,7 @@ async def search(self, query: str) -> list: """ user = await cognee.modules.users.methods.get_default_user() return await cognee.search( - cognee.api.v1.search.SearchType.SUMMARIES, query, user + cognee.api.v1.search.SearchType.GRAPH_COMPLETION, query, user ) async def get_related_nodes(self, node_id: str) -> list: diff --git a/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/pyproject.toml b/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/pyproject.toml index 7e03fdaa6a858..eec1ae943d507 100644 --- a/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/pyproject.toml +++ b/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/pyproject.toml @@ -27,11 +27,11 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-graph-rag-cognee" readme = "README.md" -version = "0.1.0" +version = "0.1.1" [tool.poetry.dependencies] python = ">=3.10,<3.12" -cognee = "^0.1.20" +cognee = {extras = ["neo4j", "postgres", "qdrant", "weaviate"], version = "^0.1.21"} httpx = "~=0.27.0" llama-index-core = "^0.12.5" pytest-cov = "^6.0.0" @@ -40,7 +40,7 @@ pytest-cov = "^6.0.0" ipython = "8.10.0" jupyter = "^1.0.0" mypy = "0.991" -pre-commit = "3.2.0" +pre-commit = "^4.0.0" pylint = "2.15.10" pytest = "8.2" pytest-asyncio = "^0.25.0" diff --git a/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/tests/test_add_data.py b/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/tests/test_add_data.py index c0e41af04a5ae..4ad9df406b73f 100644 --- a/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/tests/test_add_data.py +++ b/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/tests/test_add_data.py @@ -14,10 +14,9 @@ async def test_add_data(monkeypatch): graph_db_provider="networkx", vector_db_provider="lancedb", relational_db_provider="sqlite", - db_name="cognee_db", + relational_db_name="cognee_db", ) - # Mock logging to graphistry async def mock_add_return(add, 
dataset_name): return True diff --git a/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/tests/test_get_graph_url.py b/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/tests/test_get_graph_url.py index 22ee4f6872ec9..01a7e277c4fc1 100644 --- a/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/tests/test_get_graph_url.py +++ b/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/tests/test_get_graph_url.py @@ -13,7 +13,7 @@ async def test_get_graph_url(monkeypatch): graph_db_provider="networkx", vector_db_provider="lancedb", relational_db_provider="sqlite", - db_name="cognee_db", + relational_db_name="cognee_db", ) # Mock logging to graphistry diff --git a/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/tests/test_graph_rag_cognee.py b/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/tests/test_graph_rag_cognee.py index 3573b18b7e591..ae15f89b45b25 100644 --- a/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/tests/test_graph_rag_cognee.py +++ b/llama-index-integrations/graph_rag/llama-index-graph-rag-cognee/tests/test_graph_rag_cognee.py @@ -30,7 +30,7 @@ async def test_graph_rag_cognee(): graph_db_provider="networkx", vector_db_provider="lancedb", relational_db_provider="sqlite", - db_name="cognee_db", + relational_db_name="cognee_db", ) # Add data to cognee @@ -39,11 +39,20 @@ async def test_graph_rag_cognee(): await cogneeRAG.process_data("test") # Answer prompt based on knowledge graph - search_results = await cogneeRAG.search("person") + search_results = await cogneeRAG.search("Tell me who are the people mentioned?") assert len(search_results) > 0, "No search results found" - print("\n\nExtracted sentences are:\n") + print("\n\nAnswer based on knowledge graph:\n") + for result in search_results: + print(f"{result}\n") + + # Answer prompt based on RAG + search_results = await cogneeRAG.rag_search("Tell me who are the people mentioned?") + + assert 
len(search_results) > 0, "No search results found" + + print("\n\nAnswer based on RAG:\n") for result in search_results: print(f"{result}\n")