diff --git a/mex/backend/graph/connector.py b/mex/backend/graph/connector.py index 88a4067..39f9d5a 100644 --- a/mex/backend/graph/connector.py +++ b/mex/backend/graph/connector.py @@ -8,6 +8,7 @@ FROZEN_FIELDS_BY_CLASS_NAME, LINK_FIELDS_BY_CLASS_NAME, MUTABLE_FIELDS_BY_CLASS_NAME, + REFERENCE_FIELDS_BY_CLASS_NAME, SEARCHABLE_CLASSES, SEARCHABLE_FIELDS, TEXT_FIELDS_BY_CLASS_NAME, @@ -16,7 +17,6 @@ from mex.backend.graph.queries import q from mex.backend.graph.transform import ( expand_references_in_search_result, - transform_model_to_labels_and_parameters, ) from mex.backend.transform import to_primitive from mex.common.connector import BaseConnector @@ -223,24 +223,27 @@ def merge_node(self, model: AnyExtractedModel) -> Result: text_values = to_primitive(model, include=set(text_fields)) link_values = to_primitive(model, include=set(link_fields)) - nested_spec: list[tuple[str, str]] = [] + nested_edge_labels: list[str] = [] + nested_node_labels: list[str] = [] nested_positions: list[int] = [] nested_values: list[dict[str, Any]] = [] - for nested_label, raws in [ + for nested_node_label, raws in [ (Text.__name__, text_values), (Link.__name__, link_values), ]: - for edge_label, raw_values in to_key_and_values(raws): + for nested_edge_label, raw_values in to_key_and_values(raws): for position, raw_value in enumerate(raw_values): - nested_spec.append((edge_label, nested_label)) + nested_edge_labels.append(nested_edge_label) + nested_node_labels.append(nested_node_label) nested_positions.append(position) nested_values.append(raw_value) statement = q.merge_node( extracted_label=extracted_type, merged_label=merged_type, - nested_spec=nested_spec, + nested_edge_labels=nested_edge_labels, + nested_node_labels=nested_node_labels, ) return self.commit( @@ -253,7 +256,7 @@ def merge_node(self, model: AnyExtractedModel) -> Result: nested_positions=nested_positions, ) - def merge_edges(self, model: AnyExtractedModel) -> list[Result]: + def merge_edges(self, model: AnyExtractedModel) -> Result: """Merge edges into the graph for all relations in the given model. All fields containing references will be iterated over. When the targeted node @@ -265,12 +268,31 @@ def merge_edges(self, model: AnyExtractedModel) -> list[Result]: Returns: Graph result instance """ - results = [] - for label, parameters in transform_model_to_labels_and_parameters(model): - result = self.commit(q.merge_edge(edge_label=label), **parameters) - results.append(result) - # TODO prune edges - return results + extracted_type = model.entityType + ref_fields = REFERENCE_FIELDS_BY_CLASS_NAME[model.entityType] + ref_values = to_primitive(model, include=set(ref_fields)) + + ref_labels: list[str] = [] + ref_identifiers: list[str] = [] + ref_positions: list[int] = [] + + for field, identifiers in to_key_and_values(ref_values): + for position, identifier in enumerate(identifiers): + ref_identifiers.append(str(identifier)) + ref_positions.append(position) + ref_labels.append(field) + + statement = q.merge_edges( + extracted_label=extracted_type, + ref_labels=ref_labels, + ) + + return self.commit( + statement, + identifier=model.identifier, + ref_identifiers=ref_identifiers, + ref_positions=ref_positions, + ) def ingest(self, models: list[AnyExtractedModel]) -> list[Identifier]: """Ingest a list of models into the graph as nodes and connect all edges. diff --git a/mex/backend/graph/cypher/merge_edge.cypher b/mex/backend/graph/cypher/merge_edge.cypher deleted file mode 100644 index 78fb7ff..0000000 --- a/mex/backend/graph/cypher/merge_edge.cypher +++ /dev/null @@ -1,4 +0,0 @@ -MATCH (fromNode:<> {identifier: $source_node}) -MATCH (toNode:<> {identifier: $target_node}) -MERGE (fromNode)-[edge:<> {position: $position}]->(toNode) -RETURN edge; diff --git a/mex/backend/graph/cypher/merge_edges.cypher b/mex/backend/graph/cypher/merge_edges.cypher new file mode 100644 index 0000000..e6dfc60 --- /dev/null +++ b/mex/backend/graph/cypher/merge_edges.cypher @@ -0,0 +1,25 @@ +MATCH (source:<> {identifier: $identifier}) +CALL { +<% if ref_labels %> + WITH source +<% set union = joiner("UNION") %> +<% for ref_label in ref_labels %> +<% set index = loop.index0 %> + <> + MATCH (target_<> {identifier: $ref_identifiers[<>]}) + MERGE (source)-[edge:<> {position: $ref_positions[<>]}]->(target_<>) + RETURN edge +<% endfor %> +<% else %> + RETURN null as edge +<% endif %> +} +WITH source, collect(edge) as edges +CALL { + WITH source, edges + MATCH (source)-[gc]->(:<>) + WHERE NOT gc IN edges + DELETE gc + RETURN count(gc) as pruned +} +RETURN count(edges) as merged, pruned, edges diff --git a/mex/backend/graph/cypher/merge_node.cypher b/mex/backend/graph/cypher/merge_node.cypher index 16e946b..0a67c23 100644 --- a/mex/backend/graph/cypher/merge_node.cypher +++ b/mex/backend/graph/cypher/merge_node.cypher @@ -2,9 +2,9 @@ MERGE (merged:<> {identifier: $stable_target_id}) MERGE (extracted:<> {identifier: $identifier})-[stableTargetId:stableTargetId {position: 0}]->(merged) ON CREATE SET extracted = $on_create ON MATCH SET extracted += $on_match -<%- for edge_label, node_label in nested_spec -%> +<%- for edge_label in nested_edge_labels -%> <%- set index = loop.index0 %> -MERGE (extracted)-[edge_<>:<> {position: $nested_positions[<>]}]->(value_<>:<>) +MERGE (extracted)-[edge_<>:<> {position: $nested_positions[<>]}]->(value_<>:<>) ON CREATE SET value_<> = $nested_values[<>] ON MATCH SET value_<> += $nested_values[<>] <%- endfor %> diff --git a/mex/backend/graph/transform.py b/mex/backend/graph/transform.py index 6380551..4268d5b 100644 --- a/mex/backend/graph/transform.py +++ b/mex/backend/graph/transform.py @@ -1,38 +1,4 @@ -from typing import Any, Generator, TypedDict, cast - -from mex.backend.fields import REFERENCE_FIELDS_BY_CLASS_NAME -from mex.backend.transform import to_primitive -from mex.common.models import AnyExtractedModel -from mex.common.transform import to_key_and_values - - -class MergeEdgeParameters(TypedDict): - """Helper class for merging edges into the graph.""" - - source_node: str # the node from which the edge starts - target_node: str # the node to which the edge leads - position: int # the order in a list for labels with multiple edges - - -def transform_model_to_labels_and_parameters( - model: AnyExtractedModel, -) -> Generator[tuple[str, MergeEdgeParameters], None, None]: - """Transform a model into tuples of edge labels and parameters for the merge query. - - All reference fields except `stableTargetId` are converted to label and parameters. - """ - ref_fields = REFERENCE_FIELDS_BY_CLASS_NAME[model.entityType] - raw_model = to_primitive( - model, - include=set(ref_fields) | {"stableTargetId"}, # we add this during node-merge - ) - for field, stable_target_ids in to_key_and_values(raw_model): - for position, stable_target_id in enumerate(stable_target_ids): - yield field, MergeEdgeParameters( - position=position, - source_node=str(model.identifier), - target_node=str(stable_target_id), - ) +from typing import Any, TypedDict, cast class SearchResultReference(TypedDict):