Skip to content

Commit

Permalink
community[patch]: Support passing graph object to Neo4j integrations (#…
Browse files Browse the repository at this point in the history
…20876)

For driver connection reusage, we introduce passing the graph object to
neo4j integrations
  • Loading branch information
tomasonjo authored Apr 25, 2024
1 parent 748a6ae commit 520972f
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 52 deletions.
57 changes: 36 additions & 21 deletions libs/community/langchain_community/chat_message_histories/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, messages_from_dict
from langchain_core.utils import get_from_env
from langchain_core.utils import get_from_dict_or_env

from langchain_community.graphs import Neo4jGraph


class Neo4jChatMessageHistory(BaseChatMessageHistory):
Expand All @@ -17,6 +19,8 @@ def __init__(
database: str = "neo4j",
node_label: str = "Session",
window: int = 3,
*,
graph: Optional[Neo4jGraph] = None,
):
try:
import neo4j
Expand All @@ -30,30 +34,41 @@ def __init__(
if not session_id:
raise ValueError("Please ensure that the session_id parameter is provided")

url = get_from_env("url", "NEO4J_URI", url)
username = get_from_env("username", "NEO4J_USERNAME", username)
password = get_from_env("password", "NEO4J_PASSWORD", password)
database = get_from_env("database", "NEO4J_DATABASE", database)
# Graph object takes precedent over env or input params
if graph:
self._driver = graph._driver
self._database = graph._database
else:
# Handle if the credentials are environment variables
url = get_from_dict_or_env({"url": url}, "url", "NEO4J_URI")
username = get_from_dict_or_env(
{"username": username}, "username", "NEO4J_USERNAME"
)
password = get_from_dict_or_env(
{"password": password}, "password", "NEO4J_PASSWORD"
)
database = get_from_dict_or_env(
{"database": database}, "database", "NEO4J_DATABASE", "neo4j"
)

self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
self._database = database
self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
self._database = database
# Verify connection
try:
self._driver.verify_connectivity()
except neo4j.exceptions.ServiceUnavailable:
raise ValueError(
"Could not connect to Neo4j database. "
"Please ensure that the url is correct"
)
except neo4j.exceptions.AuthError:
raise ValueError(
"Could not connect to Neo4j database. "
"Please ensure that the username and password are correct"
)
self._session_id = session_id
self._node_label = node_label
self._window = window

# Verify connection
try:
self._driver.verify_connectivity()
except neo4j.exceptions.ServiceUnavailable:
raise ValueError(
"Could not connect to Neo4j database. "
"Please ensure that the url is correct"
)
except neo4j.exceptions.AuthError:
raise ValueError(
"Could not connect to Neo4j database. "
"Please ensure that the username and password are correct"
)
# Create session node
self._driver.execute_query(
f"MERGE (s:`{self._node_label}` {{id:$session_id}})",
Expand Down
68 changes: 37 additions & 31 deletions libs/community/langchain_community/vectorstores/neo4j_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from langchain_core.utils import get_from_dict_or_env
from langchain_core.vectorstores import VectorStore

from langchain_community.graphs import Neo4jGraph
from langchain_community.vectorstores.utils import DistanceStrategy

DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
Expand Down Expand Up @@ -483,6 +484,7 @@ def __init__(
retrieval_query: str = "",
relevance_score_fn: Optional[Callable[[float], float]] = None,
index_type: IndexType = DEFAULT_INDEX_TYPE,
graph: Optional[Neo4jGraph] = None,
) -> None:
try:
import neo4j
Expand All @@ -501,40 +503,44 @@ def __init__(
"distance_strategy must be either 'EUCLIDEAN_DISTANCE' or 'COSINE'"
)

# Handle if the credentials are environment variables

# Support URL for backwards compatibility
if not url:
url = os.environ.get("NEO4J_URL")

url = get_from_dict_or_env({"url": url}, "url", "NEO4J_URI")
username = get_from_dict_or_env(
{"username": username}, "username", "NEO4J_USERNAME"
)
password = get_from_dict_or_env(
{"password": password}, "password", "NEO4J_PASSWORD"
)
database = get_from_dict_or_env(
{"database": database}, "database", "NEO4J_DATABASE", "neo4j"
)

self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
self._database = database
self.schema = ""
# Verify connection
try:
self._driver.verify_connectivity()
except neo4j.exceptions.ServiceUnavailable:
raise ValueError(
"Could not connect to Neo4j database. "
"Please ensure that the url is correct"
# Graph object takes precedent over env or input params
if graph:
self._driver = graph._driver
self._database = graph._database
else:
# Handle if the credentials are environment variables
# Support URL for backwards compatibility
if not url:
url = os.environ.get("NEO4J_URL")

url = get_from_dict_or_env({"url": url}, "url", "NEO4J_URI")
username = get_from_dict_or_env(
{"username": username}, "username", "NEO4J_USERNAME"
)
except neo4j.exceptions.AuthError:
raise ValueError(
"Could not connect to Neo4j database. "
"Please ensure that the username and password are correct"
password = get_from_dict_or_env(
{"password": password}, "password", "NEO4J_PASSWORD"
)
database = get_from_dict_or_env(
{"database": database}, "database", "NEO4J_DATABASE", "neo4j"
)

self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
self._database = database
# Verify connection
try:
self._driver.verify_connectivity()
except neo4j.exceptions.ServiceUnavailable:
raise ValueError(
"Could not connect to Neo4j database. "
"Please ensure that the url is correct"
)
except neo4j.exceptions.AuthError:
raise ValueError(
"Could not connect to Neo4j database. "
"Please ensure that the username and password are correct"
)

self.schema = ""
# Verify if the version support vector index
self._is_enterprise = False
self.verify_version()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os

from langchain_core.messages import AIMessage, HumanMessage

from langchain_community.chat_message_histories import Neo4jChatMessageHistory
from langchain_community.graphs import Neo4jGraph


def test_add_messages() -> None:
"""Basic testing: adding messages to the Neo4jChatMessageHistory."""
assert os.environ.get("NEO4J_URI") is not None
assert os.environ.get("NEO4J_USERNAME") is not None
assert os.environ.get("NEO4J_PASSWORD") is not None
message_store = Neo4jChatMessageHistory("23334")
message_store.clear()
assert len(message_store.messages) == 0
message_store.add_user_message("Hello! Language Chain!")
message_store.add_ai_message("Hi Guys!")

# create another message store to check if the messages are stored correctly
message_store_another = Neo4jChatMessageHistory("46666")
message_store_another.clear()
assert len(message_store_another.messages) == 0
message_store_another.add_user_message("Hello! Bot!")
message_store_another.add_ai_message("Hi there!")
message_store_another.add_user_message("How's this pr going?")

# Now check if the messages are stored in the database correctly
assert len(message_store.messages) == 2
assert isinstance(message_store.messages[0], HumanMessage)
assert isinstance(message_store.messages[1], AIMessage)
assert message_store.messages[0].content == "Hello! Language Chain!"
assert message_store.messages[1].content == "Hi Guys!"

assert len(message_store_another.messages) == 3
assert isinstance(message_store_another.messages[0], HumanMessage)
assert isinstance(message_store_another.messages[1], AIMessage)
assert isinstance(message_store_another.messages[2], HumanMessage)
assert message_store_another.messages[0].content == "Hello! Bot!"
assert message_store_another.messages[1].content == "Hi there!"
assert message_store_another.messages[2].content == "How's this pr going?"

# Now clear the first history
message_store.clear()
assert len(message_store.messages) == 0
assert len(message_store_another.messages) == 3
message_store_another.clear()
assert len(message_store.messages) == 0
assert len(message_store_another.messages) == 0


def test_add_messages_graph_object() -> None:
"""Basic testing: Passing driver through graph object."""
assert os.environ.get("NEO4J_URI") is not None
assert os.environ.get("NEO4J_USERNAME") is not None
assert os.environ.get("NEO4J_PASSWORD") is not None
graph = Neo4jGraph()
# rewrite env for testing
os.environ["NEO4J_USERNAME"] = "foo"
message_store = Neo4jChatMessageHistory("23334", graph=graph)
message_store.clear()
assert len(message_store.messages) == 0
message_store.add_user_message("Hello! Language Chain!")
message_store.add_ai_message("Hi Guys!")
# Now check if the messages are stored in the database correctly
assert len(message_store.messages) == 2
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from langchain_core.documents import Document

from langchain_community.graphs import Neo4jGraph
from langchain_community.vectorstores.neo4j_vector import (
Neo4jVector,
SearchType,
Expand Down Expand Up @@ -902,3 +903,20 @@ def test_neo4jvector_relationship_index_retrieval() -> None:
assert output == [Document(page_content="foo-text", metadata={"foo": "bar"})]

drop_vector_indexes(docsearch)


def test_neo4jvector_passing_graph_object() -> None:
"""Test end to end construction and search with passing graph object."""
graph = Neo4jGraph()
# Rewrite env vars to make sure it fails if env is used
os.environ["NEO4J_URI"] = "foo"
docsearch = Neo4jVector.from_texts(
texts=texts,
embedding=FakeEmbeddingsWithOsDimension(),
graph=graph,
pre_delete_collection=True,
)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]

drop_vector_indexes(docsearch)

0 comments on commit 520972f

Please sign in to comment.