Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: fix issue #29429 in age_graph.py #29506

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 52 additions & 45 deletions libs/community/langchain_community/graphs/age_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,71 +473,78 @@ def _get_col_name(field: str, idx: int) -> str:
@staticmethod
def _wrap_query(query: str, graph_name: str) -> str:
"""
Convert a cypher query to an Apache Age compatible
sql query by wrapping the cypher query in ag_catalog.cypher,
casting results to agtype and building a select statement
Convert a Cyper query to an Apache Age compatible Sql Query.
Handles combined queries with UNION/EXCEPT operators

Args:
query (str): a valid cypher query
graph_name (str): the name of the graph to query
query (str) : A valid cypher query, can include UNION/EXCEPT operators
graph_name (str) : The name of the graph to query

Returns:
str: an equivalent pgsql query
Returns :
str : An equivalent pgSql query wrapped with ag_catalog.cypher

Raises:
ValueError : If query is empty, contain RETURN *, or has invalid field names
"""

if not query.strip():
raise ValueError("Empty query provided")

# pgsql template
template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
{query}
$$) AS ({fields});"""

# if there are any returned fields they must be added to the pgsql query
return_match = re.search(r'\breturn\b(?![^"]*")', query, re.IGNORECASE)
if return_match:
# Extract the part of the query after the RETURN keyword
return_clause = query[return_match.end() :]

# parse return statement to identify returned fields
fields = (
return_clause.lower()
.split("distinct")[-1]
.split("order by")[0]
.split("skip")[0]
.split("limit")[0]
.split(",")
)

# raise exception if RETURN * is found as we can't resolve the fields
if "*" in [x.strip() for x in fields]:
raise ValueError(
"AGE graph does not support 'RETURN *'"
+ " statements in Cypher queries"
# split the query into parts based on UNION and EXCEPT
parts = re.split(r"\b(UNION\b|\bEXCEPT)\b", query, flags=re.IGNORECASE)

all_fields = []

for part in parts:
if part.strip().upper() in ("UNION", "EXCEPT"):
continue

# if there are any returned fields they must be added to the pgsql query
return_match = re.search(r'\breturn\b(?![^"]*")', part, re.IGNORECASE)
if return_match:
# Extract the part of the query after the RETURN keyword
return_clause = part[return_match.end() :]

# parse return statement to identify returned fields
fields = (
return_clause.lower()
.split("distinct")[-1]
.split("order by")[0]
.split("skip")[0]
.split("limit")[0]
.split(",")
)

# get pgsql formatted field names
fields = [
AGEGraph._get_col_name(field, idx) for idx, field in enumerate(fields)
]

# build resulting pgsql relation
fields_str = ", ".join(
[
field.split(".")[-1] + " agtype"
for field in fields
if field.split(".")[-1]
]
)
# raise exception if RETURN * is found as we can't resolve the fields
clean_fileds = [f.strip() for f in fields if f.strip()]
if "*" in clean_fileds:
raise ValueError(
"Apache Age does not support RETURN * in Cypher queries"
)

# if no return statement we still need to return a single field of type agtype
else:
# Format fields and maintain order of appearance
for idx, field in enumerate(clean_fileds):
field_name = AGEGraph._get_col_name(field, idx)
if field_name not in all_fields:
all_fields.append(field_name)

# if no return statements found in any part
if not all_fields:
fields_str = "a agtype"

select_str = "*"
else:
fields_str = ", ".join(f"{field} agtype" for field in all_fields)

return template.format(
graph_name=graph_name,
query=query,
fields=fields_str,
projection=select_str,
projection="*",
)

@staticmethod
Expand Down
188 changes: 169 additions & 19 deletions libs/community/tests/unit_tests/graphs/test_age_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_get_col_name(self) -> None:
self.assertEqual(AGEGraph._get_col_name(*value), expected[idx])

def test_wrap_query(self) -> None:
"""Test basic query wrapping functionality."""
inputs = [
# Positive case: Simple return clause
"""
Expand All @@ -76,46 +77,195 @@ def test_wrap_query(self) -> None:

expected = [
# Expected output for the first positive case
"""
SELECT * FROM ag_catalog.cypher('test', $$
MATCH (keanu:Person {name:'Keanu Reeves'})
RETURN keanu.name AS name, keanu.born AS born
$$) AS (name agtype, born agtype);
""",
# Second test case (no RETURN clause)
"""
SELECT * FROM ag_catalog.cypher('test', $$
MERGE (n:a {id: 1})
$$) AS (a agtype);
""",
# Expected output for the negative cases (no RETURN clause)
"""
SELECT * FROM ag_catalog.cypher('test', $$
MATCH (n {description: "This will return a value"})
MERGE (n)-[:RELATED]->(m)
$$) AS (a agtype);
""",
"""
SELECT * FROM ag_catalog.cypher('test', $$
MATCH (n {returnValue: "some value"})
MERGE (n)-[:RELATED]->(m)
$$) AS (a agtype);
""",
]

for idx, value in enumerate(inputs):
result = AGEGraph._wrap_query(value, "test")
expected_result = expected[idx]
self.assertEqual(
re.sub(r"\s", "", result),
re.sub(r"\s", "", expected_result),
(
f"Failed on test case {idx + 1}\n"
f"Input:\n{value}\n"
f"Expected:\n{expected_result}\n"
f"Got:\n{result}"
),
)

def test_wrap_query_union_except(self) -> None:
"""Test query wrapping with UNION and EXCEPT operators."""
inputs = [
# UNION case
"""
MATCH (n:Person)
RETURN n.name AS name, n.age AS age
UNION
MATCH (n:Employee)
RETURN n.name AS name, n.salary AS salary
""",
"""
MATCH (a:Employee {name: "Alice"})
RETURN a.name AS name
UNION
MATCH (b:Manager {name: "Bob"})
RETURN b.name AS name
""",
# Complex UNION case
"""
MATCH (n)-[r]->(m)
RETURN n.name AS source, type(r) AS relationship, m.name AS target
UNION
MATCH (m)-[r]->(n)
RETURN m.name AS source, type(r) AS relationship, n.name AS target
""",
"""
MATCH (a:Person)-[:FRIEND]->(b:Person)
WHERE a.age > 30
RETURN a.name AS name
UNION
MATCH (c:Person)-[:FRIEND]->(d:Person)
WHERE c.age < 25
RETURN c.name AS name
""",
# EXCEPT case
"""
MATCH (n:Person)
RETURN n.name AS name
EXCEPT
MATCH (n:Employee)
RETURN n.name AS name
""",
"""
MATCH (a:Person)
RETURN a.name AS name, a.age AS age
EXCEPT
MATCH (b:Person {name: "Alice", age: 30})
RETURN b.name AS name, b.age AS age
""",
]

expected = [
"""
SELECT * FROM ag_catalog.cypher('test', $$
MATCH (keanu:Person {name:'Keanu Reeves'})
RETURN keanu.name AS name, keanu.born AS born
$$) AS (name agtype, born agtype);
MATCH (n:Person)
RETURN n.name AS name, n.age AS age
UNION
MATCH (n:Employee)
RETURN n.name AS name, n.salary AS salary
$$) AS (name agtype, age agtype, salary agtype);
""",
"""
SELECT * FROM ag_catalog.cypher('test', $$
MERGE (n:a {id: 1})
$$) AS (a agtype);
MATCH (a:Employee {name: "Alice"})
RETURN a.name AS name
UNION
MATCH (b:Manager {name: "Bob"})
RETURN b.name AS name
$$) AS (name agtype);
""",
# Expected output for the negative cases (no return clause)
"""
SELECT * FROM ag_catalog.cypher('test', $$
MATCH (n {description: "This will return a value"})
MERGE (n)-[:RELATED]->(m)
$$) AS (a agtype);
MATCH (n)-[r]->(m)
RETURN n.name AS source, type(r) AS relationship, m.name AS target
UNION
MATCH (m)-[r]->(n)
RETURN m.name AS source, type(r) AS relationship, n.name AS target
$$) AS (source agtype, relationship agtype, target agtype);
""",
"""
SELECT * FROM ag_catalog.cypher('test', $$
MATCH (n {returnValue: "some value"})
MERGE (n)-[:RELATED]->(m)
$$) AS (a agtype);
MATCH (a:Person)-[:FRIEND]->(b:Person)
WHERE a.age > 30
RETURN a.name AS name
UNION
MATCH (c:Person)-[:FRIEND]->(d:Person)
WHERE c.age < 25
RETURN c.name AS name
$$) AS (name agtype);
""",
"""
SELECT * FROM ag_catalog.cypher('test', $$
MATCH (n:Person)
RETURN n.name AS name
EXCEPT
MATCH (n:Employee)
RETURN n.name AS name
$$) AS (name agtype);
""",
"""
SELECT * FROM ag_catalog.cypher('test', $$
MATCH (a:Person)
RETURN a.name AS name, a.age AS age
EXCEPT
MATCH (b:Person {name: "Alice", age: 30})
RETURN b.name AS name, b.age AS age
$$) AS (name agtype, age agtype);
""",
]

for idx, value in enumerate(inputs):
result = AGEGraph._wrap_query(value, "test")
expected_result = expected[idx]
self.assertEqual(
re.sub(r"\s", "", AGEGraph._wrap_query(value, "test")),
re.sub(r"\s", "", expected[idx]),
re.sub(r"\s", "", result),
re.sub(r"\s", "", expected_result),
(
f"Failed on test case {idx + 1}\n"
f"Input:\n{value}\n"
f"Expected:\n{expected_result}\n"
f"Got:\n{result}"
),
)

with self.assertRaises(ValueError):
AGEGraph._wrap_query(
"""
def test_wrap_query_errors(self) -> None:
"""Test error cases for query wrapping."""
error_cases = [
# Empty query
"",
# Return * case
"""
MATCH ()
RETURN *
""",
"test",
)
# Return * in UNION
"""
MATCH (n:Person)
RETURN n.name
UNION
MATCH ()
RETURN *
""",
]

for query in error_cases:
with self.assertRaises(ValueError):
AGEGraph._wrap_query(query, "test")

def test_format_properties(self) -> None:
inputs: List[Dict[str, Any]] = [{}, {"a": "b"}, {"a": "b", "c": 1, "d": True}]
Expand Down
Loading