Skip to content

Commit

Permalink
Add edge pruning
Browse files Browse the repository at this point in the history
  • Loading branch information
cutoffthetop committed Feb 14, 2024
1 parent 61d0902 commit 798524c
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 54 deletions.
48 changes: 35 additions & 13 deletions mex/backend/graph/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
FROZEN_FIELDS_BY_CLASS_NAME,
LINK_FIELDS_BY_CLASS_NAME,
MUTABLE_FIELDS_BY_CLASS_NAME,
REFERENCE_FIELDS_BY_CLASS_NAME,
SEARCHABLE_CLASSES,
SEARCHABLE_FIELDS,
TEXT_FIELDS_BY_CLASS_NAME,
Expand All @@ -16,7 +17,6 @@
from mex.backend.graph.queries import q
from mex.backend.graph.transform import (
expand_references_in_search_result,
transform_model_to_labels_and_parameters,
)
from mex.backend.transform import to_primitive
from mex.common.connector import BaseConnector
Expand Down Expand Up @@ -223,24 +223,27 @@ def merge_node(self, model: AnyExtractedModel) -> Result:
text_values = to_primitive(model, include=set(text_fields))
link_values = to_primitive(model, include=set(link_fields))

nested_spec: list[tuple[str, str]] = []
nested_edge_labels: list[str] = []
nested_node_labels: list[str] = []
nested_positions: list[int] = []
nested_values: list[dict[str, Any]] = []

for nested_label, raws in [
for nested_node_label, raws in [
(Text.__name__, text_values),
(Link.__name__, link_values),
]:
for edge_label, raw_values in to_key_and_values(raws):
for nested_edge_label, raw_values in to_key_and_values(raws):
for position, raw_value in enumerate(raw_values):
nested_spec.append((edge_label, nested_label))
nested_edge_labels.append(nested_edge_label)
nested_node_labels.append(nested_node_label)
nested_positions.append(position)
nested_values.append(raw_value)

statement = q.merge_node(
extracted_label=extracted_type,
merged_label=merged_type,
nested_spec=nested_spec,
nested_edge_labels=nested_edge_labels,
nested_node_labels=nested_node_labels,
)

return self.commit(
Expand All @@ -253,7 +256,7 @@ def merge_node(self, model: AnyExtractedModel) -> Result:
nested_positions=nested_positions,
)

def merge_edges(self, model: AnyExtractedModel) -> list[Result]:
def merge_edges(self, model: AnyExtractedModel) -> Result:
"""Merge edges into the graph for all relations in the given model.
All fields containing references will be iterated over. When the targeted node
Expand All @@ -265,12 +268,31 @@ def merge_edges(self, model: AnyExtractedModel) -> list[Result]:
Returns:
Graph result instance
"""
results = []
for label, parameters in transform_model_to_labels_and_parameters(model):
result = self.commit(q.merge_edge(edge_label=label), **parameters)
results.append(result)
# TODO prune edges
return results
extracted_type = model.entityType
ref_fields = REFERENCE_FIELDS_BY_CLASS_NAME[model.entityType]
ref_values = to_primitive(model, include=set(ref_fields))

ref_labels: list[str] = []
ref_identifiers: list[str] = []
ref_positions: list[int] = []

for field, identifiers in to_key_and_values(ref_values):
for position, identifier in enumerate(identifiers):
ref_identifiers.append(str(identifier))
ref_positions.append(position)
ref_labels.append(field)

statement = q.merge_edges(
extracted_label=extracted_type,
ref_labels=ref_labels,
)

return self.commit(
statement,
identifier=model.identifier,
ref_identifiers=ref_identifiers,
ref_positions=ref_positions,
)

def ingest(self, models: list[AnyExtractedModel]) -> list[Identifier]:
"""Ingest a list of models into the graph as nodes and connect all edges.
Expand Down
4 changes: 0 additions & 4 deletions mex/backend/graph/cypher/merge_edge.cypher

This file was deleted.

25 changes: 25 additions & 0 deletions mex/backend/graph/cypher/merge_edges.cypher
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
MATCH (source:<<extracted_label>> {identifier: $identifier})
CALL {
<% if ref_labels %>
WITH source
<% set union = joiner("UNION") %>
<% for ref_label in ref_labels %>
<% set index = loop.index0 %>
<<union()>>
MATCH (target_<<index>> {identifier: $ref_identifiers[<<index>>]})
MERGE (source)-[edge:<<ref_label>> {position: $ref_positions[<<index>>]}]->(target_<<index>>)
RETURN edge
<% endfor %>
<% else %>
RETURN null as edge
<% endif %>
}
WITH source, collect(edge) as edges
CALL {
WITH source, edges
MATCH (source)-[gc]->(:<<merged_labels|join("|")>>)
WHERE NOT gc IN edges
DELETE gc
RETURN count(gc) as pruned
}
RETURN count(edges) as merged, pruned, edges
4 changes: 2 additions & 2 deletions mex/backend/graph/cypher/merge_node.cypher
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ MERGE (merged:<<merged_label>> {identifier: $stable_target_id})
MERGE (extracted:<<extracted_label>> {identifier: $identifier})-[stableTargetId:stableTargetId {position: 0}]->(merged)
ON CREATE SET extracted = $on_create
ON MATCH SET extracted += $on_match
<%- for edge_label, node_label in nested_spec -%>
<%- for edge_label in nested_edge_labels -%>
<%- set index = loop.index0 %>
MERGE (extracted)-[edge_<<index>>:<<edge_label>> {position: $nested_positions[<<index>>]}]->(value_<<index>>:<<node_label>>)
MERGE (extracted)-[edge_<<index>>:<<edge_label>> {position: $nested_positions[<<index>>]}]->(value_<<index>>:<<nested_node_labels[index]>>)
ON CREATE SET value_<<index>> = $nested_values[<<index>>]
ON MATCH SET value_<<index>> += $nested_values[<<index>>]
<%- endfor %>
Expand Down
36 changes: 1 addition & 35 deletions mex/backend/graph/transform.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,4 @@
from typing import Any, Generator, TypedDict, cast

from mex.backend.fields import REFERENCE_FIELDS_BY_CLASS_NAME
from mex.backend.transform import to_primitive
from mex.common.models import AnyExtractedModel
from mex.common.transform import to_key_and_values


class MergeEdgeParameters(TypedDict):
"""Helper class for merging edges into the graph."""

source_node: str # the node from which the edge starts
target_node: str # the node to which the edge leads
position: int # the order in a list for labels with multiple edges


def transform_model_to_labels_and_parameters(
model: AnyExtractedModel,
) -> Generator[tuple[str, MergeEdgeParameters], None, None]:
"""Transform a model into tuples of edge labels and parameters for the merge query.
All reference fields except `stableTargetId` are converted to label and parameters.
"""
ref_fields = REFERENCE_FIELDS_BY_CLASS_NAME[model.entityType]
raw_model = to_primitive(
model,
include=set(ref_fields) | {"stableTargetId"}, # we add this during node-merge
)
for field, stable_target_ids in to_key_and_values(raw_model):
for position, stable_target_id in enumerate(stable_target_ids):
yield field, MergeEdgeParameters(
position=position,
source_node=str(model.identifier),
target_node=str(stable_target_id),
)
from typing import Any, TypedDict, cast


class SearchResultReference(TypedDict):
Expand Down

0 comments on commit 798524c

Please sign in to comment.