diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ce30b2d..b95f8cf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,12 +3,12 @@ default_language_version: python: python3.11 repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.2 + rev: v0.3.5 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/psf/black - rev: 24.2.0 + rev: 24.3.0 hooks: - id: black - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/CHANGELOG.md b/CHANGELOG.md index 66bc2a4..83b49b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,10 +11,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changes +- re-implemented queries as templated cql files +- updated graph connector for new queries +- improved isolation of neo4j dependency +- improved documentation and code-readability + ### Deprecated ### Removed +- trashed hydration module + ### Fixed ### Security diff --git a/mex/backend/extracted/main.py b/mex/backend/extracted/main.py index f81cc16..93c961b 100644 --- a/mex/backend/extracted/main.py +++ b/mex/backend/extracted/main.py @@ -2,11 +2,9 @@ from fastapi import APIRouter, Query -from mex.backend.extracted.models import ExtractedItemSearchResponse, ExtractedType -from mex.backend.extracted.transform import ( - transform_graph_results_to_extracted_item_search_response, -) +from mex.backend.extracted.models import ExtractedItemSearchResponse from mex.backend.graph.connector import GraphConnector +from mex.backend.types import ExtractedType from mex.common.types import Identifier router = APIRouter() @@ -15,8 +13,8 @@ @router.get("/extracted-item", tags=["editor"]) def search_extracted_items( q: Annotated[str, Query(max_length=100)] = "", - stableTargetId: Identifier | None = None, # noqa: N803 - entityType: Annotated[ # noqa: N803 + stableTargetId: Identifier | None = None, + entityType: Annotated[ Sequence[ExtractedType], Query(max_length=len(ExtractedType)) ] = [], skip: Annotated[int, Query(ge=0, le=10e10)] = 0, @@ -24,11 +22,11 @@ def search_extracted_items( ) -> ExtractedItemSearchResponse: """Search for extracted items by query text or by type and id.""" graph = GraphConnector.get() - query_results = graph.query_nodes( + result = graph.fetch_extracted_data( q, stableTargetId, [str(t.value) for t in entityType or ExtractedType], skip, limit, ) - return transform_graph_results_to_extracted_item_search_response(query_results) + return ExtractedItemSearchResponse.model_validate(result.one()) diff --git a/mex/backend/extracted/models.py b/mex/backend/extracted/models.py index 544ecea..04a7d1e 100644 --- a/mex/backend/extracted/models.py +++ b/mex/backend/extracted/models.py @@ -1,30 +1,12 @@ -from enum import Enum -from typing import TYPE_CHECKING, Union +from typing import Annotated from pydantic import Field -from mex.backend.types import DynamicStrEnum -from mex.common.models import ( - EXTRACTED_MODEL_CLASSES_BY_NAME, - BaseExtractedData, - BaseModel, -) - - -class ExtractedType(Enum, metaclass=DynamicStrEnum): - """Enumeration of possible types for extracted items.""" - - __names__ = list(EXTRACTED_MODEL_CLASSES_BY_NAME) - - -if TYPE_CHECKING: # pragma: no cover - AnyExtractedModel = BaseExtractedData -else: - AnyExtractedModel = Union[*EXTRACTED_MODEL_CLASSES_BY_NAME.values()] +from mex.common.models import AnyExtractedModel, BaseModel class ExtractedItemSearchResponse(BaseModel): """Response body for the extracted item search endpoint.""" total: int - items: 
list[AnyExtractedModel] = Field(discriminator="entityType") + items: Annotated[list[AnyExtractedModel], Field(discriminator="entityType")] diff --git a/mex/backend/extracted/transform.py b/mex/backend/extracted/transform.py deleted file mode 100644 index f23f560..0000000 --- a/mex/backend/extracted/transform.py +++ /dev/null @@ -1,40 +0,0 @@ -import json - -from neo4j.exceptions import Neo4jError - -from mex.backend.extracted.models import ExtractedItemSearchResponse -from mex.backend.graph.models import GraphResult -from mex.backend.graph.transform import transform_search_result_to_model -from mex.common.logging import logger - - -def transform_graph_results_to_extracted_item_search_response( - graph_results: list[GraphResult], -) -> ExtractedItemSearchResponse: - """Transform graph results to extracted item search results. - - Args: - graph_results: Results of a search and a count query - - Returns: - Search response instance - """ - search_result, count_result = graph_results - total = count_result.data[0]["c"] - items = [] - for result in search_result.data: - try: - model = transform_search_result_to_model(result) - items.append(model) - except Neo4jError as error: # noqa: PERF203 - logger.exception( - "%s\n__node__\n %s\n__refs__\n%s\n", - error, - " \n".join( - "{}: {}".format(k, json.dumps(v, separators=(",", ":"))) - for k, v in result["n"].items() - ), - " \n".join("()-[{key}]->({value})".format(**r) for r in result["r"]), - exc_info=False, - ) - return ExtractedItemSearchResponse.model_construct(items=items, total=total) diff --git a/mex/backend/fields.py b/mex/backend/fields.py index 61538d3..83ed47a 100644 --- a/mex/backend/fields.py +++ b/mex/backend/fields.py @@ -1,60 +1,164 @@ -from types import UnionType -from typing import Annotated, Any, Generator, Union, get_args, get_origin +from types import NoneType, UnionType +from typing import ( + Annotated, + Any, + Callable, + Generator, + Mapping, + Union, + get_args, + get_origin, +) +from pydantic import BaseModel from pydantic.fields import FieldInfo +from mex.backend.types import LiteralStringType from mex.common.models import EXTRACTED_MODEL_CLASSES_BY_NAME -from mex.common.types import Identifier, Text +from mex.common.types import MERGED_IDENTIFIER_CLASSES, Link, Text def _get_inner_types(annotation: Any) -> Generator[type, None, None]: - """Yield all inner types from Unions, lists and annotations.""" + """Yield all inner types from unions, lists and type annotations (except NoneType). + + Args: + annotation: A valid python type annotation + + Returns: + A generator for all (non-NoneType) types found in the annotation + """ if get_origin(annotation) == Annotated: yield from _get_inner_types(get_args(annotation)[0]) elif get_origin(annotation) in (Union, UnionType, list): for arg in get_args(annotation): yield from _get_inner_types(arg) - elif annotation is None: - yield type(None) - else: + elif annotation not in (None, NoneType): yield annotation -def is_reference_field(field: FieldInfo) -> bool: - """Return whether the given field contains a stable target id.""" - return any( - isinstance(t, type) and issubclass(t, Identifier) - for t in _get_inner_types(field.annotation) - ) +def _contains_only_types(field: FieldInfo, *types: type) -> bool: + """Return whether a `field` is annotated as one of the given `types`. + Unions, lists and type annotations are checked for their inner types and only the + non-`NoneType` types are considered for the type-check. 
-def is_text_field(field: FieldInfo) -> bool:
-    """Return whether the given field is holding text objects."""
-    return any(
-        isinstance(t, type) and issubclass(t, Text)
-        for t in _get_inner_types(field.annotation)
-    )
+    Args:
+        field: A pydantic `FieldInfo` object
+        types: Types to look for in the field's annotation
+
+    Returns:
+        Whether the field is annotated with only the given types
+    """
+    if inner_types := list(_get_inner_types(field.annotation)):
+        return all(inner_type in types for inner_type in inner_types)
+    return False
 
-REFERENCE_FIELDS_BY_CLASS_NAME = {
-    name: {
-        field_name
-        for field_name, field_info in cls.model_fields.items()
-        if field_name
-        not in (
-            "identifier",
-            "stableTargetId",
+
+def _group_fields_by_class_name(
+    model_classes_by_name: Mapping[str, type[BaseModel]],
+    predicate: Callable[[FieldInfo], bool],
+) -> dict[str, list[str]]:
+    """Group the field names by model class and filter them by the given predicate.
+
+    Args:
+        model_classes_by_name: Map from class names to model classes
+        predicate: Function to filter the fields of the classes by
+
+    Returns:
+        Dictionary mapping class names to a list of field names filtered by `predicate`
+    """
+    return {
+        name: sorted(
+            {
+                field_name
+                for field_name, field_info in cls.model_fields.items()
+                if predicate(field_info)
+            }
         )
-        and is_reference_field(field_info)
+        for name, cls in model_classes_by_name.items()
     }
+
+
+# fields that are immutable and can only be set once
+FROZEN_FIELDS_BY_CLASS_NAME = _group_fields_by_class_name(
+    EXTRACTED_MODEL_CLASSES_BY_NAME, lambda field_info: field_info.frozen is True
+)
+
+# static fields that are set once on class-level to a literal type
+LITERAL_FIELDS_BY_CLASS_NAME = _group_fields_by_class_name(
+    EXTRACTED_MODEL_CLASSES_BY_NAME,
+    lambda field_info: isinstance(field_info.annotation, LiteralStringType),
+)
+
+# fields typed as merged identifiers containing references to merged items
+REFERENCE_FIELDS_BY_CLASS_NAME = _group_fields_by_class_name(
+    EXTRACTED_MODEL_CLASSES_BY_NAME,
+    lambda field_info: _contains_only_types(field_info, *MERGED_IDENTIFIER_CLASSES),
+)
+
+# nested fields that contain `Text` objects
+TEXT_FIELDS_BY_CLASS_NAME = _group_fields_by_class_name(
+    EXTRACTED_MODEL_CLASSES_BY_NAME,
+    lambda field_info: _contains_only_types(field_info, Text),
+)
+
+# nested fields that contain `Link` objects
+LINK_FIELDS_BY_CLASS_NAME = _group_fields_by_class_name(
+    EXTRACTED_MODEL_CLASSES_BY_NAME,
+    lambda field_info: _contains_only_types(field_info, Link),
+)
+
+# fields annotated as `str` type
+STRING_FIELDS_BY_CLASS_NAME = _group_fields_by_class_name(
+    EXTRACTED_MODEL_CLASSES_BY_NAME,
+    lambda field_info: _contains_only_types(field_info, str),
+)
+
+# fields that should be indexed as searchable fields
+SEARCHABLE_FIELDS = sorted(
+    {
+        field_name
+        for field_names in STRING_FIELDS_BY_CLASS_NAME.values()
+        for field_name in field_names
+    }
+)
+
+# classes that have fields that should be searchable
+SEARCHABLE_CLASSES = sorted(
+    {name for name, field_names in STRING_FIELDS_BY_CLASS_NAME.items() if field_names}
+)
+
+# fields with changeable values that are not nested objects or merged item references
+MUTABLE_FIELDS_BY_CLASS_NAME = {
+    name: sorted(
+        {
+            field_name
+            for field_name in cls.model_fields
+            if field_name
+            not in (
+                *FROZEN_FIELDS_BY_CLASS_NAME[name],
+                *REFERENCE_FIELDS_BY_CLASS_NAME[name],
+                *TEXT_FIELDS_BY_CLASS_NAME[name],
+                *LINK_FIELDS_BY_CLASS_NAME[name],
+            )
+        }
+    )
     for name, cls in EXTRACTED_MODEL_CLASSES_BY_NAME.items()
 }
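To illustrate how the new predicate unwraps `Annotated`, union and list annotations, a minimal sketch (assuming pydantic v2 `model_fields`; `Demo` is a hypothetical model, not part of this changeset):

```python
from pydantic import BaseModel

from mex.common.types import Text


class Demo(BaseModel):
    """Hypothetical model, only for illustration."""

    title: list[Text] | None = None  # inner types: just Text (NoneType is dropped)
    note: str | None = None  # inner types: just str


assert _contains_only_types(Demo.model_fields["title"], Text)
assert not _contains_only_types(Demo.model_fields["title"], str)
assert _contains_only_types(Demo.model_fields["note"], str)
```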
-TEXT_FIELDS_BY_CLASS_NAME = { - name: { - f"{field_name}_value" - for field_name, field_info in cls.model_fields.items() - if is_text_field(field_info) - } +# fields with values that should be set once but are neither literal nor references +FINAL_FIELDS_BY_CLASS_NAME = { + name: sorted( + { + field_name + for field_name in cls.model_fields + if field_name in FROZEN_FIELDS_BY_CLASS_NAME[name] + and field_name + not in ( + *LITERAL_FIELDS_BY_CLASS_NAME[name], + *REFERENCE_FIELDS_BY_CLASS_NAME[name], + ) + } + ) for name, cls in EXTRACTED_MODEL_CLASSES_BY_NAME.items() } diff --git a/mex/backend/graph/connector.py b/mex/backend/graph/connector.py index 097550f..9619e25 100644 --- a/mex/backend/graph/connector.py +++ b/mex/backend/graph/connector.py @@ -1,37 +1,46 @@ import json +from string import Template from typing import Any -from neo4j import GraphDatabase +from neo4j import Driver, GraphDatabase -from mex.backend.extracted.models import ( - AnyExtractedModel, -) -from mex.backend.fields import TEXT_FIELDS_BY_CLASS_NAME -from mex.backend.graph.models import GraphResult -from mex.backend.graph.queries import ( - CREATE_CONSTRAINTS_STATEMENT, - CREATE_INDEX_STATEMENT, - HAD_PRIMARY_SOURCE_AND_IDENTIFIER_IN_PRIMARY_SOURCE_IDENTITY_QUERY, - MERGE_EDGE_STATEMENT, - MERGE_NODE_STATEMENT, - QUERY_MAP, - STABLE_TARGET_ID_IDENTITY_QUERY, +from mex.backend.fields import ( + FINAL_FIELDS_BY_CLASS_NAME, + LINK_FIELDS_BY_CLASS_NAME, + MUTABLE_FIELDS_BY_CLASS_NAME, + REFERENCE_FIELDS_BY_CLASS_NAME, + SEARCHABLE_CLASSES, + SEARCHABLE_FIELDS, + TEXT_FIELDS_BY_CLASS_NAME, ) +from mex.backend.graph.models import Result +from mex.backend.graph.query import QueryBuilder from mex.backend.graph.transform import ( - transform_model_to_edges, - transform_model_to_node, + expand_references_in_search_result, ) +from mex.backend.settings import BackendSettings +from mex.backend.transform import to_primitive from mex.common.connector import BaseConnector from mex.common.exceptions import MExError from mex.common.logging import logger from mex.common.models import ( EXTRACTED_MODEL_CLASSES_BY_NAME, + MERGED_MODEL_CLASSES_BY_NAME, MEX_PRIMARY_SOURCE_IDENTIFIER, MEX_PRIMARY_SOURCE_IDENTIFIER_IN_PRIMARY_SOURCE, MEX_PRIMARY_SOURCE_STABLE_TARGET_ID, + AnyExtractedModel, ExtractedPrimarySource, ) -from mex.common.types import Identifier +from mex.common.transform import to_key_and_values +from mex.common.types import Identifier, Link, Text + +MEX_EXTRACTED_PRIMARY_SOURCE = ExtractedPrimarySource.model_construct( + hadPrimarySource=MEX_PRIMARY_SOURCE_STABLE_TARGET_ID, + identifier=MEX_PRIMARY_SOURCE_IDENTIFIER, + identifierInPrimarySource=MEX_PRIMARY_SOURCE_IDENTIFIER_IN_PRIMARY_SOURCE, + stableTargetId=MEX_PRIMARY_SOURCE_STABLE_TARGET_ID, +) class GraphConnector(BaseConnector): @@ -39,11 +48,16 @@ class GraphConnector(BaseConnector): def __init__(self) -> None: """Create a new graph database connection.""" - # break import cycle, sigh - from mex.backend.settings import BackendSettings + self.driver = self._init_driver() + self._check_connectivity_and_authentication() + self._seed_constraints() + self._seed_indices() + self._seed_data() + def _init_driver(self) -> Driver: + """Initialize and return a database driver.""" settings = BackendSettings.get() - self.driver = GraphDatabase.driver( + return GraphDatabase.driver( settings.graph_url, auth=( settings.graph_user.get_secret_value(), @@ -51,133 +65,115 @@ def __init__(self) -> None: ), database=settings.graph_db, ) - self._check_connectivity_and_authentication() - 
self._seed_constraints() - self._seed_indices() - self._seed_primary_source() - def _check_connectivity_and_authentication(self) -> None: + def _check_connectivity_and_authentication(self) -> Result: """Check the connectivity and authentication to the graph.""" - self.commit("RETURN 1;") - - def close(self) -> None: - """Close the connector's underlying requests session.""" - self.driver.close() + query_builder = QueryBuilder.get() + result = self.commit(query_builder.fetch_database_status()) + if (status := result["currentStatus"]) != "online": + raise MExError(f"Database is {status}.") + return result - def _seed_constraints(self) -> list[GraphResult]: + def _seed_constraints(self) -> list[Result]: """Ensure uniqueness constraints are enabled for all entity types.""" - constraint_statements = [ - (CREATE_CONSTRAINTS_STATEMENT.format(node_label=entity_type), None) - for entity_type in EXTRACTED_MODEL_CLASSES_BY_NAME - ] - return self.mcommit(*constraint_statements) - - def _seed_indices(self) -> GraphResult: - """Ensure there are full text search indices for all text fields.""" - node_labels = "|".join(TEXT_FIELDS_BY_CLASS_NAME.keys()) - node_fields = ", ".join( - sorted( - { - f"n.{f}" - for fields in TEXT_FIELDS_BY_CLASS_NAME.values() - for f in fields - } + query_builder = QueryBuilder.get() + return [ + self.commit( + query_builder.create_identifier_uniqueness_constraint( + node_label=class_name + ) ) - ) + for class_name in sorted( + set(EXTRACTED_MODEL_CLASSES_BY_NAME) | set(MERGED_MODEL_CLASSES_BY_NAME) + ) + ] + + def _seed_indices(self) -> Result: + """Ensure there is a full text search index for all searchable fields.""" + query_builder = QueryBuilder.get() + result = self.commit(query_builder.fetch_full_text_search_index()) + if (index := result.one_or_none()) and ( + set(index["node_labels"]) != set(SEARCHABLE_CLASSES) + or set(index["search_fields"]) != set(SEARCHABLE_FIELDS) + ): + # only drop the index if the classes or fields have changed + self.commit(query_builder.drop_full_text_search_index()) return self.commit( - CREATE_INDEX_STATEMENT.format( - node_labels=node_labels, node_fields=node_fields + query_builder.create_full_text_search_index( + node_labels=SEARCHABLE_CLASSES, + search_fields=SEARCHABLE_FIELDS, ), - config={ + index_config={ "fulltext.eventually_consistent": True, "fulltext.analyzer": "german", }, ) - def _seed_primary_source(self) -> Identifier: + def _seed_data(self) -> list[Identifier]: """Ensure the primary source `mex` is seeded and linked to itself.""" - mex_primary_source = ExtractedPrimarySource.model_construct( - hadPrimarySource=MEX_PRIMARY_SOURCE_STABLE_TARGET_ID, - identifier=MEX_PRIMARY_SOURCE_IDENTIFIER, - identifierInPrimarySource=MEX_PRIMARY_SOURCE_IDENTIFIER_IN_PRIMARY_SOURCE, - stableTargetId=MEX_PRIMARY_SOURCE_STABLE_TARGET_ID, - ) - return self.ingest([mex_primary_source])[0] - - def mcommit( - self, *statements_with_parameters: tuple[str, dict[str, Any] | None] - ) -> list[GraphResult]: - """Send and commit a batch of graph transactions.""" - with self.driver.session(database="neo4j") as session: - results = [] - logger.info( - "\033[95m\n%s\n\033[0m", - json.dumps( - { - "statements": [ - { - "statement": statement, - **({"parameters": parameters} if parameters else {}), - } - for statement, parameters in statements_with_parameters - ] - }, - indent=2, - ), - ) - for statement, parameters in statements_with_parameters: - result = session.run(statement, parameters) - results.append(GraphResult(data=result.data())) - return results + 
return self.ingest([MEX_EXTRACTED_PRIMARY_SOURCE]) + + def close(self) -> None: + """Close the connector's underlying requests session.""" + self.driver.close() - def commit(self, statement: str, **parameters: Any) -> GraphResult: + def commit(self, query: str, **parameters: Any) -> Result: """Send and commit a single graph transaction.""" - results = self.mcommit((statement, parameters)) - return results[0] + message = Template(query).safe_substitute( + { + k: json.dumps(v, ensure_ascii=False) + for k, v in (parameters or {}).items() + } + ) + try: + with self.driver.session() as session: + result = Result(session.run(query, parameters)) + except Exception as error: + logger.error("\n%s\n%s", message, error) + raise + if counters := result.get_update_counters(): + logger.info("\n%s\n%s", message, json.dumps(counters, indent=4)) + else: + logger.info("\n%s", message) + return result - def query_nodes( + def fetch_extracted_data( self, - query: str | None, + query_string: str | None, stable_target_id: str | None, entity_type: list[str] | None, skip: int, limit: int, - ) -> list[GraphResult]: + ) -> Result: """Query the graph for nodes. Args: - query: Fulltext search query term + query_string: Full text search query term stable_target_id: Optional stable target ID filter entity_type: Optional entity type filter skip: How many nodes to skip for pagination limit: How many nodes to return at most Returns: - Graph result instances + Graph result instance """ - search_statement, count_statement = QUERY_MAP[ - (bool(query), bool(stable_target_id), bool(entity_type)) - ] - return self.mcommit( - ( - search_statement, - dict( - query=query, - labels=entity_type, - stable_target_id=stable_target_id, - skip=skip, - limit=limit, - ), - ), - ( - count_statement, - dict( - query=query, - labels=entity_type, - stable_target_id=stable_target_id, - ), - ), + query_builder = QueryBuilder.get() + query = query_builder.fetch_extracted_data( + filter_by_query_string=bool(query_string), + filter_by_stable_target_id=bool(stable_target_id), + filter_by_labels=bool(entity_type), ) + result = self.commit( + query, + query_string=query_string, + stable_target_id=stable_target_id, + labels=entity_type, + skip=skip, + limit=limit, + ) + for item in result["items"]: + expand_references_in_search_result(item) + return result def fetch_identities( self, @@ -185,11 +181,11 @@ def fetch_identities( identifier_in_primary_source: str | None = None, stable_target_id: Identifier | None = None, limit: int = 1000, - ) -> GraphResult: + ) -> Result: """Search the graph for nodes matching the given ID combination. - Identity queries can be filtered either with just a `stable_target_id` - or with both `had_primary_source` and identifier_in_primary_source`. + Identity queries can be filtered by `stable_target_id`, + `had_primary_source` or `identifier_in_primary_source`. 
        Args:
            had_primary_source: The stableTargetId of a connected PrimarySource
@@ -197,25 +193,15 @@
            identifier_in_primary_source: The id the item had in its primary source
            stable_target_id: The stableTargetId of an item
            limit: How many results to return, defaults to 1000

-        Raises:
-            MExError: When a wrong combination of IDs is given
-
        Returns:
            A graph result set containing identities
        """
-        if (
-            not had_primary_source
-            and not identifier_in_primary_source
-            and stable_target_id
-        ):
-            query = STABLE_TARGET_ID_IDENTITY_QUERY
-        elif (
-            had_primary_source and identifier_in_primary_source and not stable_target_id
-        ):
-            query = HAD_PRIMARY_SOURCE_AND_IDENTIFIER_IN_PRIMARY_SOURCE_IDENTITY_QUERY
-        else:
-            raise MExError("invalid identity query parameters")
-
+        query_builder = QueryBuilder.get()
+        query = query_builder.fetch_identities(
+            filter_by_had_primary_source=bool(had_primary_source),
+            filter_by_identifier_in_primary_source=bool(identifier_in_primary_source),
+            filter_by_stable_target_id=bool(stable_target_id),
+        )
        return self.commit(
            query,
            had_primary_source=had_primary_source,
@@ -224,10 +210,17 @@
            identifier_in_primary_source=identifier_in_primary_source,
            stable_target_id=stable_target_id,
            limit=limit,
        )
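For illustration, a usage sketch with a made-up identifier value: the boolean `filter_by_*` flags merely select which of the templated query variants is rendered, while the concrete values always travel separately as driver parameters:

```python
# Hypothetical usage sketch: the flags pick the rendered query variant,
# the values are passed as parameters and never interpolated into the cypher.
connector = GraphConnector.get()
result = connector.fetch_identities(stable_target_id=Identifier("bFQoRhcVH5DHUx"))
for identity in result:  # Result is iterable, see mex/backend/graph/models.py
    print(identity["identifier"], "->", identity["stableTargetId"])
```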
-    def merge_node(self, model: AnyExtractedModel) -> GraphResult:
-        """Convert a model into a node and merge it into the graph.
+    def _merge_node(self, model: AnyExtractedModel) -> Result:
+        """Upsert an extracted model including merged item and nested objects.

-        Existing nodes will be updated, a new node will be created otherwise.
+        The given model is created or updated with all its inline properties.
+        All nested properties (like Text or Link) are created as their own nodes
+        and linked via edges. For multi-valued fields, the position of each nested
+        object is stored as a property on the outbound edge.
+        Any nested objects that are found in the graph but are no longer present
+        on the model are purged.
+        In addition, a merged item is created (if it does not exist yet) and the
+        extracted item is linked to it via an edge with the label `stableTargetId`.

        Args:
            model: Model to merge into the graph as a node
@@ -235,57 +228,115 @@

        Returns:
            Graph result instance
        """
-        entity_type = model.__class__.__name__
-        logger.info(
-            "MERGE (:%s {identifier:%s}) ",
-            entity_type,
-            model.identifier,
+        query_builder = QueryBuilder.get()
+        extracted_type = model.entityType
+        merged_type = extracted_type.replace("Extracted", "Merged")
+
+        text_fields = set(TEXT_FIELDS_BY_CLASS_NAME[extracted_type])
+        link_fields = set(LINK_FIELDS_BY_CLASS_NAME[extracted_type])
+        mutable_fields = set(MUTABLE_FIELDS_BY_CLASS_NAME[extracted_type])
+        final_fields = set(FINAL_FIELDS_BY_CLASS_NAME[extracted_type])
+
+        mutable_values = to_primitive(model, include=mutable_fields)
+        final_values = to_primitive(model, include=final_fields)
+        all_values = {**mutable_values, **final_values}
+
+        text_values = to_primitive(model, include=text_fields)
+        link_values = to_primitive(model, include=link_fields)
+
+        nested_edge_labels: list[str] = []
+        nested_node_labels: list[str] = []
+        nested_positions: list[int] = []
+        nested_values: list[dict[str, Any]] = []
+
+        for nested_node_label, raws in [
+            (Text.__name__, text_values),
+            (Link.__name__, link_values),
+        ]:
+            for nested_edge_label, raw_values in to_key_and_values(raws):
+                for position, raw_value in enumerate(raw_values):
+                    nested_edge_labels.append(nested_edge_label)
+                    nested_node_labels.append(nested_node_label)
+                    nested_positions.append(position)
+                    nested_values.append(raw_value)
+
+        query = query_builder.merge_node(
+            extracted_label=extracted_type,
+            merged_label=merged_type,
+            nested_edge_labels=nested_edge_labels,
+            nested_node_labels=nested_node_labels,
        )
-        node = transform_model_to_node(model)
+
        return self.commit(
-            MERGE_NODE_STATEMENT.format(node_label=entity_type),
+            query,
            identifier=model.identifier,
-            **node.model_dump(),
+            stable_target_id=model.stableTargetId,
+            on_match=mutable_values,
+            on_create=all_values,
+            nested_values=nested_values,
+            nested_positions=nested_positions,
        )
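To make the flattening concrete, a sketch with hypothetical field values: a model carrying two `alternativeName` texts and one `homepage` link is encoded into four parallel lists that line up index by index in the rendered query (language keys omitted for brevity):

```python
# Hypothetical input: alternativeName=[Text("RKI"), Text("Robert Koch-Institut")]
#                     homepage=[Link(url="https://example.org")]
# Resulting parallel lists passed alongside the rendered merge_node query:
nested_edge_labels = ["alternativeName", "alternativeName", "homepage"]
nested_node_labels = ["Text", "Text", "Link"]
nested_positions = [0, 1, 0]
nested_values = [
    {"value": "RKI"},
    {"value": "Robert Koch-Institut"},
    {"url": "https://example.org"},
]
```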
-    def merge_edges(self, model: AnyExtractedModel) -> list[GraphResult]:
-        """Merge edges into the graph for all relations in the given model.
+    def _merge_edges(self, model: AnyExtractedModel) -> Result:
+        """Merge edges into the graph for all relations originating from one model.

-        All fields containing references will be iterated over. When the targeted node
+        All fields containing references will be iterated over. When the referenced node
        is found and no such relation exists yet, it will be created.
+        A position attribute that stores the index the reference had in the list of
+        references on the originating model is added to every edge. This way we can
+        preserve the order of, for example, the `contact` persons referenced on an
+        activity.

        Args:
            model: Model to ensure all edges are created in the graph

        Returns:
-            Graph result instances
+            Graph result instance
        """
-        edges = transform_model_to_edges(model)
-        edge_statements_with_parameters = [
-            (MERGE_EDGE_STATEMENT.format(edge_label=edge.label), edge.parameters)
-            for edge in edges
-        ]
-        results = self.mcommit(*edge_statements_with_parameters)
-        for result, edge in zip(results, edges):
-            if result.data:
-                logger.info(f"MERGED {edge.log_message}")
-            else:
-                logger.error(f"FAILED {edge.log_message}")
-        return results
+        query_builder = QueryBuilder.get()
+        extracted_type = model.entityType
+        ref_fields = REFERENCE_FIELDS_BY_CLASS_NAME[model.entityType]
+        ref_values = to_primitive(model, include=set(ref_fields))
+
+        ref_labels: list[str] = []
+        ref_identifiers: list[str] = []
+        ref_positions: list[int] = []
+
+        for field, identifiers in to_key_and_values(ref_values):
+            for position, identifier in enumerate(identifiers):
+                ref_identifiers.append(identifier)
+                ref_positions.append(position)
+                ref_labels.append(field)
+
+        query = query_builder.merge_edges(
+            extracted_label=extracted_type,
+            ref_labels=ref_labels,
+        )
+
+        return self.commit(
+            query,
+            identifier=model.identifier,
+            ref_identifiers=ref_identifiers,
+            ref_positions=ref_positions,
+        )

    def ingest(self, models: list[AnyExtractedModel]) -> list[Identifier]:
        """Ingest a list of models into the graph as nodes and connect all edges.

+        This is a two-step process: first, all extracted and merged items are created
+        along with their nested objects (like Text and Link); then, in a second step,
+        all edges that represent references (like hadPrimarySource, parentUnit, etc.)
+        are added to the graph.
+
        Args:
-            models: List of extracted items
+            models: List of extracted models

        Returns:
-            List of identifiers from the ingested models
+            List of identifiers of the ingested models
        """
        for model in models:
-            self.merge_node(model)
+            self._merge_node(model)
        for model in models:
-            self.merge_edges(model)
+            self._merge_edges(model)
        return [m.identifier for m in models]
diff --git a/mex/backend/graph/cypher/README.md b/mex/backend/graph/cypher/README.md
new file mode 100644
index 0000000..8de361d
--- /dev/null
+++ b/mex/backend/graph/cypher/README.md
@@ -0,0 +1,21 @@
+The queries in this directory are written in the
+[cypher query language](https://neo4j.com/docs/getting-started/cypher-intro/),
+but contain dynamic templated elements utilizing the
+[jinja templating engine](https://jinja.palletsprojects.com/en/latest/).
+
+The templated elements never contain user input or concrete query parameters!
+Those are transmitted to the driver separately from the query to protect against
+injection and improve performance.
+See: https://neo4j.com/docs/python-manual/current/query-simple/#query-parameters
+Instead, the templated elements are only used to dynamically adjust to changes in
+the data structure or to render multiple similar queries from the same template:
+for example, a new model class or changed model fields are handled automatically
+and don't require rewriting any cypher query.
+
+Some of these use-cases could be covered by neo4j's [APOC](https://neo4j.com/labs/apoc/)
+add-on (e.g. `expand_references_in_search_result`). However, APOC is not included in the
+official neo4j docker image. So, to keep deployment simple, the use of APOC was avoided.
+
+Unlike the default jinja tags, which are built around curly braces, we use tags with
+less-than/greater-than signs, which collide with cypher syntax far less often.
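To make the delimiter choice concrete, a sketch of rendering one of the templates below through the builder referenced next (the rendered cypher in the comments is assumed output, abridged from the template shown further down):

```python
# Sketch: `<< >>`, `<% %>` and `<# #>` replace jinja's curly-brace defaults,
# so cypher maps like `{indexConfig: $index_config}` need no escaping.
query_builder = QueryBuilder.get()
query = query_builder.create_identifier_uniqueness_constraint(
    node_label="ExtractedActivity"
)
# CREATE CONSTRAINT ExtractedActivity_identifier_uniqueness IF NOT EXISTS
# FOR (n:ExtractedActivity)
# REQUIRE n.identifier IS UNIQUE;
```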
+See: `mex.backend.graph.query.QueryBuilder`
diff --git a/mex/backend/graph/cypher/create_full_text_search_index.cql b/mex/backend/graph/cypher/create_full_text_search_index.cql
new file mode 100644
index 0000000..50ccbb6
--- /dev/null
+++ b/mex/backend/graph/cypher/create_full_text_search_index.cql
@@ -0,0 +1,10 @@
+<# Create a full text search index for faster searches on the given fields.
+
+Args:
+    node_labels: List of labels of nodes that have searchable fields
+    search_fields: List of names of searchable fields
+-#>
+CREATE FULLTEXT INDEX search_index IF NOT EXISTS
+FOR (n:<<node_labels|join("|")>>)
+ON EACH [<<search_fields|map("ensure_prefix", "n.")|join(", ")>>]
+OPTIONS {indexConfig: $index_config};
diff --git a/mex/backend/graph/cypher/create_identifier_uniqueness_constraint.cql b/mex/backend/graph/cypher/create_identifier_uniqueness_constraint.cql
new file mode 100644
index 0000000..5182ca9
--- /dev/null
+++ b/mex/backend/graph/cypher/create_identifier_uniqueness_constraint.cql
@@ -0,0 +1,8 @@
+<# Create a uniqueness constraint on identifiers for one node label at a time.
+
+Args:
+    node_label: Label of node for which to create a constraint, e.g. ExtractedActivity
+-#>
+CREATE CONSTRAINT <<node_label>>_identifier_uniqueness IF NOT EXISTS
+FOR (n:<<node_label>>)
+REQUIRE n.identifier IS UNIQUE;
diff --git a/mex/backend/graph/cypher/drop_full_text_search_index.cql b/mex/backend/graph/cypher/drop_full_text_search_index.cql
new file mode 100644
index 0000000..7a8645b
--- /dev/null
+++ b/mex/backend/graph/cypher/drop_full_text_search_index.cql
@@ -0,0 +1,2 @@
+<# Drop the full text search index. -#>
+DROP INDEX search_index IF EXISTS;
diff --git a/mex/backend/graph/cypher/fetch_database_status.cql b/mex/backend/graph/cypher/fetch_database_status.cql
new file mode 100644
index 0000000..ac2f1c4
--- /dev/null
+++ b/mex/backend/graph/cypher/fetch_database_status.cql
@@ -0,0 +1,7 @@
+<# Get the current status of the default database.
+
+Returns:
+    currentStatus: The current database status as a string
+-#>
+SHOW DEFAULT DATABASE
+YIELD currentStatus;
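The connector consumes this status query through the `Result` wrapper introduced later in this diff (`mex/backend/graph/models.py`); a brief sketch, with `connector` and `query_builder` as hypothetical, already-initialized instances:

```python
# Sketch: Result.__getitem__ proxies to the single record via Result.one(),
# so an empty or ambiguous result raises NoResultFoundError or
# MultipleResultsFoundError instead of silently misbehaving.
result = connector.commit(query_builder.fetch_database_status())
if result["currentStatus"] != "online":
    raise MExError(f"Database is {result['currentStatus']}.")
```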
diff --git a/mex/backend/graph/cypher/fetch_extracted_data.cql b/mex/backend/graph/cypher/fetch_extracted_data.cql
new file mode 100644
index 0000000..cccb385
--- /dev/null
+++ b/mex/backend/graph/cypher/fetch_extracted_data.cql
@@ -0,0 +1,63 @@
+<# Fetch extracted items, including their nested objects and referenced identifiers.
+
+Globals:
+    extracted_labels: List of all extracted class labels
+    merged_labels: List of all merged class labels
+    nested_labels: List of labels for all nestable objects
+
+Args:
+    filter_by_query_string: Whether the final query should accept a full text query string
+    filter_by_stable_target_id: Whether the final query should filter by stableTargetId
+    filter_by_labels: Whether the final query should filter by entity type labels
+
+Returns:
+    total: Count of all items found by this query
+    items: List of extracted data items, each item has an extra attribute `_refs` that
+           contains the values of nested objects as well as the identifiers of
+           referenced items
+-#>
+CALL {
+<%- block match_clause -%>
+<%- if filter_by_query_string %>
+    CALL db.index.fulltext.queryNodes("search_index", $query_string)
+    YIELD node AS hit, score
+<%- endif %>
+<%- if filter_by_stable_target_id %>
+    MATCH (n:<<extracted_labels|join("|")>>)-[:stableTargetId]->(merged:<<merged_labels|join("|")>>)
+<%- else %>
+    MATCH (n:<<extracted_labels|join("|")>>)
+<%- endif %>
+<%- if filter_by_query_string or filter_by_stable_target_id or filter_by_labels -%>
+<%- set and_ = joiner("AND ") %>
+    WHERE
+    <%- if filter_by_query_string %>
+        <<and_()>>elementId(hit) = elementId(n)
+    <%- endif %>
+    <%- if filter_by_stable_target_id %>
+        <<and_()>>merged.identifier = $stable_target_id
+    <%- endif %>
+    <%- if filter_by_labels %>
+        <<and_()>>ANY(label IN labels(n) WHERE label IN $labels)
+    <%- endif %>
+<%- endif %>
+<%- endblock %>
+    RETURN COUNT(n) AS total
+}
+CALL {
+    <<-self.match_clause()>>
+    CALL {
+        WITH n
+        MATCH (n)-[r]->(merged:<<merged_labels|join("|")>>)
+        RETURN type(r) as label, r.position as position, merged.identifier as value
+    UNION
+        WITH n
+        MATCH (n)-[r]->(nested:<<nested_labels|join("|")>>)
+        RETURN type(r) as label, r.position as position, properties(nested) as value
+    }
+    WITH n, collect({label: label, position: position, value: value}) as refs
+    RETURN n{.*, entityType: head(labels(n)), _refs: refs}
+    ORDER BY n.identifier ASC
+    SKIP $skip
+    LIMIT $limit
+}
+RETURN collect(n) AS items, total;
diff --git a/mex/backend/graph/cypher/fetch_full_text_search_index.cql b/mex/backend/graph/cypher/fetch_full_text_search_index.cql
new file mode 100644
index 0000000..12d0d65
--- /dev/null
+++ b/mex/backend/graph/cypher/fetch_full_text_search_index.cql
@@ -0,0 +1,10 @@
+<# Fetch the full text search index by its static name.
+
+Returns:
+    node_labels: List of labels of nodes with searchable fields
+    search_fields: List of names of searchable fields
+-#>
+SHOW INDEXES
+YIELD name, labelsOrTypes, properties
+WHERE name = "search_index"
+RETURN labelsOrTypes as node_labels, properties as search_fields;
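To see the `joiner` mechanics in the identity query below: a joiner emits its separator only between items, so rendering with a single active filter produces a bare condition without a dangling `AND` (a sketch; the rendered tail is shown as comments):

```python
# Sketch: only one filter flag is set, so `and_()` emits nothing before
# the first (and only) condition; the value still travels as a parameter.
query = query_builder.fetch_identities(
    filter_by_had_primary_source=False,
    filter_by_identifier_in_primary_source=False,
    filter_by_stable_target_id=True,
)
# ...WHERE
#     merged.identifier = $stable_target_id
# ...LIMIT $limit;
```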
diff --git a/mex/backend/graph/cypher/fetch_identities.cql b/mex/backend/graph/cypher/fetch_identities.cql
new file mode 100644
index 0000000..e35f060
--- /dev/null
+++ b/mex/backend/graph/cypher/fetch_identities.cql
@@ -0,0 +1,38 @@
+<# Fetch only identity-related fields for the given set of filters.
+
+Globals:
+    extracted_labels: List of all extracted class labels
+    merged_labels: List of all merged class labels
+
+Args:
+    filter_by_had_primary_source: Whether the final query should filter by identifiers
+        of MergedPrimarySources referenced by hadPrimarySource
+    filter_by_identifier_in_primary_source: Whether the final query should filter by
+        the value of identifierInPrimarySource
+    filter_by_stable_target_id: Whether the final query should filter by stableTargetId
+
+Returns:
+    List of identity objects.
+-#>
+MATCH (n:<<extracted_labels|join("|")>>)-[:stableTargetId]->(merged:<<merged_labels|join("|")>>)
+MATCH (n)-[:hadPrimarySource]->(primary_source:MergedPrimarySource)
+<%- if filter_by_had_primary_source or filter_by_identifier_in_primary_source or filter_by_stable_target_id %>
+WHERE
+    <%- set and_ = joiner("AND ") -%>
+    <%- if filter_by_had_primary_source %>
+    <<and_()>>primary_source.identifier = $had_primary_source
+    <%- endif %>
+    <%- if filter_by_identifier_in_primary_source %>
+    <<and_()>>n.identifierInPrimarySource = $identifier_in_primary_source
+    <%- endif -%>
+    <%- if filter_by_stable_target_id %>
+    <<and_()>>merged.identifier = $stable_target_id
+    <%- endif -%>
+<%- endif %>
+RETURN
+    merged.identifier as stableTargetId,
+    primary_source.identifier as hadPrimarySource,
+    n.identifierInPrimarySource as identifierInPrimarySource,
+    n.identifier as identifier
+ORDER BY n.identifier ASC
+LIMIT $limit;
diff --git a/mex/backend/graph/cypher/merge_edges.cql b/mex/backend/graph/cypher/merge_edges.cql
new file mode 100644
index 0000000..ea716d8
--- /dev/null
+++ b/mex/backend/graph/cypher/merge_edges.cql
@@ -0,0 +1,42 @@
+<# Merge all edges from a single extracted item to all connected merged items.
+
+This statement also prunes all edges originating from a given node that are not
+part of the references to be merged during this operation.
+
+Globals:
+    merged_labels: List of all merged class labels
+
+Args:
+    extracted_label: Label of the extracted item that is the source of all edges being
+        merged in this statement
+    ref_labels: Ordered list of merged item labels of the edges being merged
+
+Returns:
+    merged: Number of merged edges
+    pruned: Number of pruned edges
+    edges: List of the merged edge objects
+-#>
+MATCH (source:<<extracted_label>> {identifier: $identifier})
+CALL {
+<%- if ref_labels %>
+<%- set union = joiner("UNION\n    ") %>
+<%- for ref_label in ref_labels %>
+<%- set index = loop.index0 %>
+    <<union()>>WITH source
+    MATCH (target_<<index>> {identifier: $ref_identifiers[<<index>>]})
+    MERGE (source)-[edge:<<ref_label>> {position: $ref_positions[<<index>>]}]->(target_<<index>>)
+    RETURN edge
+<%- endfor %>
+<%- else %>
+    RETURN null as edge
+<%- endif %>
+}
+WITH source, collect(edge) as edges
+CALL {
+    WITH source, edges
+    MATCH (source)-[outdated_edge]->(:<<merged_labels|join("|")>>)
+    WHERE NOT outdated_edge IN edges
+    DELETE outdated_edge
+    RETURN count(outdated_edge) as pruned
+}
+RETURN count(edges) as merged, pruned, edges;
diff --git a/mex/backend/graph/cypher/merge_node.cql b/mex/backend/graph/cypher/merge_node.cql
new file mode 100644
index 0000000..1c343ea
--- /dev/null
+++ b/mex/backend/graph/cypher/merge_node.cql
@@ -0,0 +1,47 @@
+<# Upsert an extracted item with its associated merged item and all nested objects.
+
+One extracted item is created (or updated) along with its associated merged item.
+In addition, all nested objects of the extracted item are created as their own
+nodes and linked to the extracted item via edges that have the field names as labels.
+We also prune all connected nodes that have nested labels but are no longer part of
+the extracted item.
+
+Globals:
+    nested_labels: List of all labels for nested objects
+
+Args:
+    merged_label: Label for the merged item associated with the extracted item,
+        e.g. MergedPerson
+    extracted_label: Label of the current extracted item, e.g. ExtractedPerson
+    nested_edge_labels: Ordered list of field names that contain nested objects,
+        e.g. homepage, alternativeName, methodDescription
+    nested_node_labels: Ordered list of class names for the `nested_edge_labels`,
+        e.g.
Link, Text, Text
+
+Returns:
+    extracted: The extracted item with all inline properties
+    edges: List of edges connecting the extracted item with its nested objects
+    values: List of nested objects that are merged by this statement
+    pruned: Number of nested objects that were removed by this statement
+-#>
+MERGE (merged:<<merged_label>> {identifier: $stable_target_id})
+MERGE (extracted:<<extracted_label>> {identifier: $identifier})-[stableTargetId:stableTargetId {position: 0}]->(merged)
+ON CREATE SET extracted = $on_create
+ON MATCH SET extracted += $on_match
+<%- for edge_label in nested_edge_labels -%>
+<%- set index = loop.index0 %>
+MERGE (extracted)-[edge_<<index>>:<<edge_label>> {position: $nested_positions[<<index>>]}]->(value_<<index>>:<<nested_node_labels[index]>>)
+ON CREATE SET value_<<index>> = $nested_values[<<index>>]
+ON MATCH SET value_<<index>> += $nested_values[<<index>>]
+<%- endfor %>
+WITH extracted,
+    [<<range(nested_edge_labels|length)|map("ensure_prefix", "edge_")|join(", ")>>] as edges,
+    [<<range(nested_edge_labels|length)|map("ensure_prefix", "value_")|join(", ")>>] as values
+CALL {
+    WITH extracted, values
+    MATCH (extracted)-[]->(outdated_node:<<nested_labels|join("|")>>)
+    WHERE NOT outdated_node IN values
+    DETACH DELETE outdated_node
+    RETURN count(outdated_node) as pruned
+}
+RETURN extracted, edges, values, pruned;
diff --git a/mex/backend/graph/exceptions.py b/mex/backend/graph/exceptions.py
new file mode 100644
index 0000000..bba541d
--- /dev/null
+++ b/mex/backend/graph/exceptions.py
@@ -0,0 +1,9 @@
+from mex.common.exceptions import MExError
+
+
+class NoResultFoundError(MExError):
+    """A database result was required but none was found."""
+
+
+class MultipleResultsFoundError(MExError):
+    """A single database result was required but more than one was found."""
diff --git a/mex/backend/graph/hydrate.py b/mex/backend/graph/hydrate.py
deleted file mode 100644
index e72efa4..0000000
--- a/mex/backend/graph/hydrate.py
+++ /dev/null
@@ -1,210 +0,0 @@
-from types import NoneType
-from typing import Any, TypeGuard, TypeVar, cast, get_args
-
-from pydantic.fields import FieldInfo
-
-from mex.common.models import BaseModel
-
-KEY_SEPARATOR = "_"
-NULL_VALUE = ""
-DehydratedValue = str | list[str]
-HydratedValue = str | None | list[str | None]
-NestedValues = str | None | list["NestedValues"] | dict[str, "NestedValues"]
-NestedDict = dict[str, NestedValues]
-FlatDict = dict[str, str | list[str]]
-KeyPath = tuple[str | int, ...]
-AnyT = TypeVar("AnyT", bound=Any)
-ModelT = TypeVar("ModelT", bound=BaseModel)
-
-
-def are_instances(value: Any, *types: type[AnyT]) -> TypeGuard[list[AnyT]]:
-    """Return whether value is a list whose elements are all in the provided types."""
-    return isinstance(value, list) and all(isinstance(v, types) for v in value)
-
-
-def dehydrate_value(value: HydratedValue) -> DehydratedValue:
-    """Convert any pydantic value into something we can store in the graph."""
-    if value is None:
-        return NULL_VALUE
-    if isinstance(value, str):
-        return value
-    if are_instances(value, str, NoneType):
-        return cast(list[str], [dehydrate_value(v) for v in value])
-    raise TypeError("can only dehydrate strings or lists of strings")
-
-
-def hydrate_value(value: DehydratedValue) -> HydratedValue:
-    """Convert a value stored in the graph into something we can parse with pydantic."""
-    if value == NULL_VALUE:
-        return None
-    if isinstance(value, str):
-        return value
-    if are_instances(value, str):
-        return cast(list[str | None], [hydrate_value(v) for v in value])
-    raise TypeError("can only hydrate strings or lists of strings")
-
-
-def dehydrate(nested: NestedDict) -> FlatDict:  # noqa: C901
-    """Convert a nested pydantic dict into a flattened and dehydrated graph node.
- - Args: - nested: Dictionary with nested dictionary and at most one list per key-path - - Returns: - Dictionary with flat key-value-pairs - """ - - def _dehydrate( # noqa: C901 - in_: list[NestedValues] | dict[str, NestedValues], - out: FlatDict, - parent: KeyPath, - ) -> bool: - """Dehydrate the `in_` structure into an `out` dictionary. - - Args: - in_: List or dictionary of allowed values - out: Dictionary to write flat key-value-pairs into - parent: The parent key path for recursions - - Returns: - Whether this path needs to be recursed further - """ - key_value_iterable = enumerate(in_) if isinstance(in_, list) else in_.items() - has_item = False - for key, value in key_value_iterable: - has_item = True - key_path = parent + cast(KeyPath, (key,)) - if isinstance(value, (dict, list)): - has_child = _dehydrate(value, out, key_path) - if has_child or not isinstance(value, (dict, list)): - continue - position = None - flat_key = "" - for key_in_path in key_path: - if isinstance(key_in_path, int): - if position is not None: - raise TypeError("can only handle one list per path") - position = key_in_path - elif flat_key: - flat_key = KEY_SEPARATOR.join((flat_key, key_in_path)) - else: - flat_key = key_in_path - if position is not None: - out.setdefault(flat_key, []) - if not isinstance(out[flat_key], list): # pragma: no cover - raise RuntimeError("this key path should have been a list") - length_needed = 1 + position - len(out[flat_key]) - out[flat_key].extend([NULL_VALUE] * length_needed) # type: ignore - out[flat_key][position] = dehydrate_value(value) # type: ignore - elif flat_key in out: # pragma: no cover - raise RuntimeError("already dehydrated this key path") - else: - out[flat_key] = dehydrate_value(value) # type: ignore - return has_item - - flat: FlatDict = {} - _dehydrate(nested, flat, ()) - return flat - - -def hydrate(flat: FlatDict, model: type[BaseModel]) -> NestedDict: - """Convert a flattened and dehydrated graph node into a nested dict for pydantic. 
- - Args: - flat: Dictionary with flat key-value-pairs - model: MEx model class to infer structure from - - Returns: - A nested dictionary conforming to the model structure - """ - nested: NestedDict = {} - for flat_key, value in flat.items(): - (*branch_keys, leaf_key) = flat_key.split(KEY_SEPARATOR) - value_count = len(value) - value_is_list = isinstance(value, list) - empty_leaf_value = _initialize_branch_with_missing_expected_types( - branch_keys, model, nested, value_count, value_is_list - ) - - _set_leaf_values(empty_leaf_value, leaf_key, value) - - return nested - - -def _initialize_branch_with_missing_expected_types( - branch_keys: list[str], - model: type[BaseModel], - nested: NestedDict, - value_count: int, - value_is_list: bool, -) -> NestedDict | list[NestedDict]: - model_at_depth = model - nested_value_of_current_branch_key: NestedDict | list[NestedDict] = nested - for key_id, branch_key in enumerate(branch_keys): - nested_value_of_parent_branch_key = nested_value_of_current_branch_key - nested_value_of_current_branch_key = _set_branch_node_default( - nested_value_of_parent_branch_key, - branch_key, - model_at_depth, - value_count, - value_is_list, - ) - if len(branch_keys) - key_id > 1: # if for loop has iterations left - try: - model_at_depth = _get_base_model_from_field( - model_at_depth.model_fields[branch_key] - ) - except KeyError as error: - raise TypeError("flat dict does not align with target model") from error - return nested_value_of_current_branch_key - - -def _get_base_model_from_field(field: FieldInfo) -> type[BaseModel]: - if args := get_args(field.annotation): - args_wo_none = [arg for arg in args if arg is not type(None)] - if (args_count := len(args_wo_none)) != 1: - raise TypeError(f"Expected one non-None type, got {args_count}.") - base_model = args_wo_none[0] - else: - base_model = field.annotation - if not isinstance(base_model, type) or not issubclass(base_model, BaseModel): - raise TypeError("cannot hydrate paths with non base models") - return base_model - - -def _set_leaf_values( - empty_leaf_value: NestedDict | list[NestedDict], - leaf_key: str, - value: str | list[str], -) -> None: - if isinstance(empty_leaf_value, list): - for t, v in zip(empty_leaf_value, value): - t[leaf_key] = hydrate_value(v) # type: ignore - else: - empty_leaf_value[leaf_key] = hydrate_value(value) # type: ignore - - -def _set_branch_node_default( - target: NestedDict | list[NestedDict], - key: str, - model_at_depth: type[BaseModel], - value_count: int, - value_is_list: bool, -) -> NestedDict | list[NestedDict]: - if not issubclass(model_at_depth, BaseModel): - raise TypeError("cannot hydrate paths with non base models") - if key in model_at_depth._get_list_field_names(): - if not value_is_list: - raise TypeError("cannot hydrate non-list to list") - if isinstance(target, list): - raise TypeError("cannot handle multiple list branches") - target = cast( - list[NestedDict], target.setdefault(key, [{} for _ in range(value_count)]) - ) - elif isinstance(target, list): - if len(target) != value_count: # pragma: no cover - raise RuntimeError("branch count must match our values") - target = cast(list[NestedDict], [t.setdefault(key, {}) for t in target]) - else: - target = target.setdefault(key, {}) # type: ignore - return target diff --git a/mex/backend/graph/models.py b/mex/backend/graph/models.py index 1ad347e..864188e 100644 --- a/mex/backend/graph/models.py +++ b/mex/backend/graph/models.py @@ -1,9 +1,67 @@ -from typing import Any +from functools import cache +from typing import 
Any, Iterator

-from pydantic import BaseModel as PydanticBaseModel
+from neo4j import Result as Neo4jResult

+from mex.backend.graph.exceptions import MultipleResultsFoundError, NoResultFoundError

-class GraphResult(PydanticBaseModel):
-    """Model for graph query results."""

-    data: list[dict[str, Any]] = []
+class Result:
+    """Represent a set of graph results.
+
+    This class wraps `neo4j.Result` in an interface akin to `sqlalchemy.engine.Result`.
+    We do this to reduce vendor tie-in with neo4j and to limit the dependency scope of
+    the neo4j driver library to the `mex.backend.graph` submodule.
+    """
+
+    def __init__(self, result: Neo4jResult) -> None:
+        """Wrap a neo4j result object in a mex-backend result."""
+        self._records, self._summary, _ = result.to_eager_result()
+        self._get_cached_data = cache(lambda i: self._records[i].data())
+
+    def __getitem__(self, key: str) -> Any:
+        """Proxy a getitem instruction to the first record if exactly one exists."""
+        return self.one()[key]
+
+    def __iter__(self) -> Iterator[dict[str, Any]]:
+        """Return an iterator over all records."""
+        yield from (self._get_cached_data(index) for index in range(len(self._records)))
+
+    def __repr__(self) -> str:
+        """Return a human-readable representation of this result object."""
+        representation = f"Result({self.all()!r})"
+        if len(representation) > 90:
+            representation = f"{representation[:40]}... ...{representation[-40:]}"
+        return representation
+
+    def all(self) -> list[dict[str, Any]]:
+        """Return all records as a list."""
+        return list(self)
+
+    def one(self) -> dict[str, Any]:
+        """Return exactly one record or raise an exception."""
+        match len(self._records):
+            case 1:
+                return self._get_cached_data(0)
+            case 0:
+                raise NoResultFoundError from None
+            case _:
+                raise MultipleResultsFoundError from None
+
+    def one_or_none(self) -> dict[str, Any] | None:
+        """Return at most one result or raise an exception.
+
+        Returns None if the result has no records.
+        Raises MultipleResultsFoundError if multiple records are returned.
+ """ + match len(self._records): + case 1: + return self._get_cached_data(0) + case 0: + return None + case _: + raise MultipleResultsFoundError + + def get_update_counters(self) -> dict[str, int]: + """Return a summary of counters for operations the query triggered.""" + return {k: v for k, v in vars(self._summary.counters).items() if v} diff --git a/mex/backend/graph/queries.py b/mex/backend/graph/queries.py deleted file mode 100644 index ccd1cf1..0000000 --- a/mex/backend/graph/queries.py +++ /dev/null @@ -1,276 +0,0 @@ -NOOP_STATEMENT = r""" -RETURN 1; -""" - -CREATE_CONSTRAINTS_STATEMENT = r""" -CREATE CONSTRAINT identifier_uniqueness IF NOT EXISTS -FOR (n:{node_label}) -REQUIRE n.identifier IS UNIQUE; -""" - -CREATE_INDEX_STATEMENT = r""" -CREATE FULLTEXT INDEX text_fields IF NOT EXISTS -FOR (n:{node_labels}) -ON EACH [{node_fields}] -OPTIONS {{indexConfig: $config}}; -""" - -MERGE_NODE_STATEMENT = r""" -MERGE (n:{node_label} {{identifier:$identifier}}) -ON CREATE SET n = $on_create -ON MATCH SET n += $on_match -RETURN n; -""" - -MERGE_EDGE_STATEMENT = r""" -MATCH (s {{identifier:$fromID}}) -MATCH (t {{stableTargetId:$toSTI}}) -MERGE (s)-[e:{edge_label}]->(t) -RETURN e; -""" - -STABLE_TARGET_ID_IDENTITY_QUERY = r""" -MATCH (n)-[:hadPrimarySource]->(p:ExtractedPrimarySource) -WHERE n.stableTargetId = $stable_target_id -RETURN { - stableTargetId: n.stableTargetId, - hadPrimarySource: p.stableTargetId, - identifierInPrimarySource: n.identifierInPrimarySource, - identifier: n.identifier -} as i -ORDER BY n.identifier ASC -LIMIT $limit; -""" - -HAD_PRIMARY_SOURCE_AND_IDENTIFIER_IN_PRIMARY_SOURCE_IDENTITY_QUERY = r""" -MATCH (n)-[:hadPrimarySource]->(p:ExtractedPrimarySource) -WHERE n.identifierInPrimarySource = $identifier_in_primary_source - AND p.stableTargetId = $had_primary_source -RETURN { - stableTargetId: n.stableTargetId, - hadPrimarySource: p.stableTargetId, - identifierInPrimarySource: n.identifierInPrimarySource, - identifier: n.identifier -} as i -ORDER BY n.identifier ASC -LIMIT $limit; -""" - -FULL_TEXT_ID_AND_LABEL_FILTER_SEARCH_QUERY = r""" -CALL db.index.fulltext.queryNodes('text_fields', $query) -YIELD node AS hit, score -MATCH (n) -WHERE elementId(hit) = elementId(n) - AND n.stableTargetId = $stable_target_id - AND ANY(label IN labels(n) WHERE label IN $labels) -CALL { - WITH n - MATCH (n)-[r]->() - RETURN collect({key: type(r), value: endNode(r).stableTargetId}) as r -} -RETURN n, head(labels(n)) AS l, r -ORDER BY score DESC -SKIP $skip -LIMIT $limit; -""" - -FULL_TEXT_ID_AND_LABEL_FILTER_COUNT_QUERY = r""" -CALL db.index.fulltext.queryNodes('text_fields', $query) -YIELD node AS hit, score -MATCH (n) -WHERE elementId(hit) = elementId(n) - AND n.stableTargetId = $stable_target_id - AND ANY(label IN labels(n) WHERE label IN $labels) -RETURN COUNT(n) AS c; -""" - -FULL_TEXT_ID_FILTER_SEARCH_QUERY = r""" -CALL db.index.fulltext.queryNodes('text_fields', $query) -YIELD node AS hit, score -MATCH (n) -WHERE elementId(hit) = elementId(n) - AND n.stableTargetId = $stable_target_id -CALL { - WITH n - MATCH (n)-[r]->() - RETURN collect({key: type(r), value: endNode(r).stableTargetId}) as r -} -RETURN n, head(labels(n)) AS l, r -ORDER BY score DESC -SKIP $skip -LIMIT $limit; -""" - -FULL_TEXT_ID_FILTER_COUNT_QUERY = r""" -CALL db.index.fulltext.queryNodes('text_fields', $query) -YIELD node AS hit, score -MATCH (n) -WHERE elementId(hit) = elementId(n) - AND n.stableTargetId = $stable_target_id -RETURN COUNT(n) AS c; -""" - -FULL_TEXT_LABEL_FILTER_SEARCH_QUERY = r""" -CALL 
db.index.fulltext.queryNodes('text_fields', $query) -YIELD node AS hit, score -MATCH (n) -WHERE elementId(hit) = elementId(n) - AND ANY(label IN labels(n) WHERE label IN $labels) -CALL { - WITH n - MATCH (n)-[r]->() - RETURN collect({key: type(r), value: endNode(r).stableTargetId}) as r -} -RETURN n, head(labels(n)) AS l, r -ORDER BY score DESC -SKIP $skip -LIMIT $limit; -""" - -FULL_TEXT_LABEL_FILTER_COUNT_QUERY = r""" -CALL db.index.fulltext.queryNodes('text_fields', $query) -YIELD node AS hit, score -MATCH (n) -WHERE elementId(hit) = elementId(n) - AND ANY(label IN labels(n) WHERE label IN $labels) -RETURN COUNT(n) AS c; -""" - -FULL_TEXT_SEARCH_QUERY = r""" -CALL db.index.fulltext.queryNodes('text_fields', $query) -YIELD node AS hit, score -MATCH (n) -WHERE elementId(hit) = elementId(n) -CALL { - WITH n - MATCH (n)-[r]->() - RETURN collect({key: type(r), value: endNode(r).stableTargetId}) as r -} -RETURN n, head(labels(n)) AS l, r -ORDER BY score DESC -SKIP $skip -LIMIT $limit; -""" - -FULL_TEXT_COUNT_QUERY = r""" -CALL db.index.fulltext.queryNodes('text_fields', $query) -YIELD node AS hit, score -MATCH (n) -WHERE elementId(hit) = elementId(n) -RETURN COUNT(n) AS c; -""" - -ID_AND_LABEL_FILTER_SEARCH_QUERY = r""" -MATCH (n) -WHERE n.stableTargetId = $stable_target_id - AND ANY(label IN labels(n) WHERE label IN $labels) - CALL { - WITH n - MATCH (n)-[r]->() - RETURN collect({key: type(r), value: endNode(r).stableTargetId}) as r -} -RETURN n, head(labels(n)) AS l, r -ORDER BY n.identifier ASC -SKIP $skip -LIMIT $limit; -""" - -ID_AND_LABEL_FILTER_COUNT_QUERY = r""" -MATCH (n) -WHERE n.stableTargetId = $stable_target_id - AND ANY(label IN labels(n) WHERE label IN $labels) -RETURN COUNT(n) AS c; -""" - -ID_FILTER_SEARCH_QUERY = r""" -MATCH (n) -WHERE n.stableTargetId = $stable_target_id -CALL { - WITH n - MATCH (n)-[r]->() - RETURN collect({key: type(r), value: endNode(r).stableTargetId}) as r -} -RETURN n, head(labels(n)) AS l, r -ORDER BY n.identifier ASC -SKIP $skip -LIMIT $limit; -""" - -ID_FILTER_COUNT_QUERY = r""" -MATCH (n) -WHERE n.stableTargetId = $stable_target_id -RETURN COUNT(n) AS c; -""" - -LABEL_FILTER_SEARCH_QUERY = r""" -MATCH (n) -WHERE ANY(label IN labels(n) WHERE label IN $labels) -CALL { - WITH n - MATCH (n)-[r]->() - RETURN collect({key: type(r), value: endNode(r).stableTargetId}) as r -} -RETURN n, head(labels(n)) AS l, r -ORDER BY n.identifier ASC -SKIP $skip -LIMIT $limit; -""" - -LABEL_FILTER_COUNT_QUERY = r""" -MATCH (n) -WHERE ANY(label IN labels(n) WHERE label IN $labels) -RETURN COUNT(n) AS c; -""" - -GENERAL_SEARCH_QUERY = r""" -MATCH (n) -CALL { - WITH n MATCH (n)-[r]->() - RETURN collect({key: type(r), value: endNode(r).stableTargetId}) as r -} -RETURN n, head(labels(n)) AS l, r -ORDER BY n.identifier ASC -SKIP $skip -LIMIT $limit; -""" - -GENERAL_COUNT_QUERY = r""" -MATCH (n) -RETURN COUNT(n) AS c; -""" - -QUERY_MAP = { - # (full_text, id_filter, label_filter) => search_query, count_query - (True, True, True): ( - FULL_TEXT_ID_AND_LABEL_FILTER_SEARCH_QUERY, - FULL_TEXT_ID_AND_LABEL_FILTER_COUNT_QUERY, - ), - (True, True, False): ( - FULL_TEXT_ID_FILTER_SEARCH_QUERY, - FULL_TEXT_ID_FILTER_COUNT_QUERY, - ), - (True, False, True): ( - FULL_TEXT_LABEL_FILTER_SEARCH_QUERY, - FULL_TEXT_LABEL_FILTER_COUNT_QUERY, - ), - (True, False, False): ( - FULL_TEXT_SEARCH_QUERY, - FULL_TEXT_COUNT_QUERY, - ), - (False, True, True): ( - ID_AND_LABEL_FILTER_SEARCH_QUERY, - ID_AND_LABEL_FILTER_COUNT_QUERY, - ), - (False, True, False): ( - ID_FILTER_SEARCH_QUERY, - 
ID_FILTER_COUNT_QUERY, - ), - (False, False, True): ( - LABEL_FILTER_SEARCH_QUERY, - LABEL_FILTER_COUNT_QUERY, - ), - (False, False, False): ( - GENERAL_SEARCH_QUERY, - GENERAL_COUNT_QUERY, - ), -} diff --git a/mex/backend/graph/query.py b/mex/backend/graph/query.py new file mode 100644 index 0000000..da0bfeb --- /dev/null +++ b/mex/backend/graph/query.py @@ -0,0 +1,59 @@ +from typing import Callable + +from jinja2 import Environment, PackageLoader, StrictUndefined, select_autoescape + +from mex.backend.settings import BackendSettings +from mex.common.connector import BaseConnector +from mex.common.models import ( + EXTRACTED_MODEL_CLASSES_BY_NAME, + MERGED_MODEL_CLASSES_BY_NAME, +) +from mex.common.transform import ( + dromedary_to_kebab, + dromedary_to_snake, + ensure_prefix, + kebab_to_camel, + snake_to_dromedary, +) +from mex.common.types import NESTED_MODEL_CLASSES_BY_NAME + + +class QueryBuilder(BaseConnector): + """Wrapper around jinja template loading and rendering.""" + + def __init__(self) -> None: + """Create a new jinja environment with template loader, filters and globals.""" + settings = BackendSettings.get() + self._env = Environment( + loader=PackageLoader(__package__, package_path="cypher"), + autoescape=select_autoescape(), + auto_reload=settings.debug, + undefined=StrictUndefined, + block_start_string="<%", + block_end_string="%>", + variable_start_string="<<", + variable_end_string=">>", + comment_start_string="<#", + comment_end_string="#>", + ) + self._env.filters.update( + snake_to_dromedary=snake_to_dromedary, + dromedary_to_snake=dromedary_to_snake, + dromedary_to_kebab=dromedary_to_kebab, + kebab_to_camel=kebab_to_camel, + ensure_prefix=ensure_prefix, + ) + self._env.globals.update( + extracted_labels=list(EXTRACTED_MODEL_CLASSES_BY_NAME), + merged_labels=list(MERGED_MODEL_CLASSES_BY_NAME), + nested_labels=list(NESTED_MODEL_CLASSES_BY_NAME), + ) + + def __getattr__(self, name: str) -> Callable[..., str]: + """Load the template with the given `name` and return its `render` method.""" + template = self._env.get_template(f"{name}.cql") + return template.render + + def close(self) -> None: + """Clean up the connector.""" + pass # no clean-up needed diff --git a/mex/backend/graph/transform.py b/mex/backend/graph/transform.py index c60d344..ca6ccbb 100644 --- a/mex/backend/graph/transform.py +++ b/mex/backend/graph/transform.py @@ -1,110 +1,24 @@ -from collections import defaultdict -from typing import Any, TypedDict +from typing import Any, TypedDict, cast -from pydantic import BaseModel as PydanticBaseModel -from mex.backend.extracted.models import AnyExtractedModel -from mex.backend.fields import REFERENCE_FIELDS_BY_CLASS_NAME -from mex.backend.graph.hydrate import dehydrate, hydrate -from mex.backend.transform import to_primitive -from mex.common.identity import Identity -from mex.common.models import EXTRACTED_MODEL_CLASSES_BY_NAME, BaseModel, MExModel +class _SearchResultReference(TypedDict): + """Helper class to show the structure of search result references.""" + label: str # corresponds to the field name of the pydantic model + position: int # if the field in pydantic is a list this helps keep its order + value: str | dict[str, str | None] # this can be a raw Identifier, Text or Link -class MergableNode(PydanticBaseModel): - """Helper class for merging nodes into the graph.""" - on_create: dict[str, str | list[str]] - on_match: dict[str, str | list[str]] +def expand_references_in_search_result(item: dict[str, Any]) -> None: + """Expand the `_refs` collection in 
a search result item. - -def transform_model_to_node(model: BaseModel) -> MergableNode: - """Transform a pydantic model into a node that can be merged into the graph.""" - raw_model = to_primitive( - model, - exclude=REFERENCE_FIELDS_BY_CLASS_NAME[model.__class__.__name__] - | {"entityType"}, - ) - on_create = dehydrate(raw_model) - - on_match = on_create.copy() - on_match.pop("identifier") - on_match.pop("stableTargetId") - on_match.pop("identifierInPrimarySource") - - return MergableNode(on_create=on_create, on_match=on_match) - - -class MergableEdge(PydanticBaseModel): - """Helper class for merging edges into the graph.""" - - label: str - parameters: dict[str, Any] - log_message: str - - -def transform_model_to_edges(model: MExModel) -> list[MergableEdge]: - """Transform a model to a list of edges.""" - raw_model = to_primitive( - model, - exclude={"entityType"}, - include=REFERENCE_FIELDS_BY_CLASS_NAME[model.__class__.__name__], - ) - # TODO: add support for link fields in nested dicts, eg. for rules - edges = [] - for field, stable_target_ids in raw_model.items(): - if not isinstance(stable_target_ids, list): - stable_target_ids = [stable_target_ids] - from_id = str(model.identifier) - for stable_target_id in stable_target_ids: - stable_target_id = str(stable_target_id) - parameters = {"fromID": from_id, "toSTI": stable_target_id} - edges.append( - MergableEdge( - label=field, - parameters=parameters, - log_message=f"({from_id})-[:{field}]→({stable_target_id})", - ) - ) - return edges - - -class SearchResultReference(TypedDict): - """Type definition for references returned by search query.""" - - key: str # label of the edge, e.g. parentUnit or hadPrimarySource - value: list[str] | str # stableTargetId of the referenced Node - - -def transform_search_result_to_model( - search_result: dict[str, Any] -) -> AnyExtractedModel: - """Transform a graph search result to an extracted item.""" - model_class_name: str = search_result["l"] - flattened_dict: dict[str, Any] = search_result["n"] - references: list[SearchResultReference] = search_result["r"] - model_class = EXTRACTED_MODEL_CLASSES_BY_NAME[model_class_name] - raw_model = hydrate(flattened_dict, model_class) - - # duplicate references can occur because we link - # rule-sets and extracted-items, not merged-items - deduplicated_references: dict[str, set[str]] = defaultdict(set) - for reference in references: - reference_ids = ( - reference["value"] - if isinstance(reference["value"], list) - else [reference["value"]] - ) - deduplicated_references[reference["key"]].update(reference_ids) - sorted_deduplicated_reference_key_values = { - reference_type: sorted(reference_ids) - for reference_type, reference_ids in deduplicated_references.items() - } - raw_model.update(sorted_deduplicated_reference_key_values) # type: ignore[arg-type] - - return model_class.model_validate(raw_model) - - -def transform_identity_result_to_identity(identity_result: dict[str, Any]) -> Identity: - """Transform the result from an identity query into an Identity instance.""" - return Identity.model_validate(identity_result["i"]) + Each item in a search result has a collection of `_refs` in the form of + `_SearchResultReference`. Before parsing them into pydantic, we need to inline + the references back into the `item` dictionary. 
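A hedged usage sketch for the QueryBuilder introduced above (the template name and filter flags are only illustrative, taken from how the tests later exercise it): attribute access loads the matching `.cql` file from `mex/backend/graph/cypher` and returns its bound `render`, so calling the attribute with keyword arguments renders the template to a Cypher string.

from mex.backend.graph.query import QueryBuilder

query_builder = QueryBuilder.get()
# loads cypher/fetch_identities.cql and renders it with these jinja variables
query = query_builder.fetch_identities(
    filter_by_had_primary_source=False,
    filter_by_identifier_in_primary_source=False,
    filter_by_stable_target_id=True,
)
assert isinstance(query, str)  # ready to be passed to the neo4j driver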
+ """ + # XXX if we can use `apoc`, we might do this step directly in the cypher query + for ref in cast(list[_SearchResultReference], item.pop("_refs")): + target_list = item.setdefault(ref["label"], [None]) + length_needed = 1 + ref["position"] - len(target_list) + target_list.extend([None] * length_needed) + target_list[ref["position"]] = ref["value"] diff --git a/mex/backend/identity/main.py b/mex/backend/identity/main.py index 4eac152..1392d43 100644 --- a/mex/backend/identity/main.py +++ b/mex/backend/identity/main.py @@ -1,9 +1,7 @@ from fastapi import APIRouter -from fastapi.exceptions import HTTPException from mex.backend.identity.models import IdentityAssignRequest, IdentityFetchResponse from mex.backend.identity.provider import GraphIdentityProvider -from mex.common.exceptions import MExError from mex.common.identity.models import Identity from mex.common.types import Identifier @@ -22,9 +20,9 @@ def assign_identity(request: IdentityAssignRequest) -> Identity: @router.get("/identity", status_code=200, tags=["extractors"]) def fetch_identity( - hadPrimarySource: Identifier | None = None, # noqa: N803 - identifierInPrimarySource: str | None = None, # noqa: N803 - stableTargetId: Identifier | None = None, # noqa: N803 + hadPrimarySource: Identifier | None = None, + identifierInPrimarySource: str | None = None, + stableTargetId: Identifier | None = None, ) -> IdentityFetchResponse: """Find an Identity instance from the database if it can be found. @@ -32,12 +30,9 @@ def fetch_identity( and `identifierInPrimarySource` together to get a unique result. """ identity_provider = GraphIdentityProvider.get() - try: - identities = identity_provider.fetch( - had_primary_source=hadPrimarySource, - identifier_in_primary_source=identifierInPrimarySource, - stable_target_id=stableTargetId, - ) - except MExError as error: - raise HTTPException(400, error.args) from None + identities = identity_provider.fetch( + had_primary_source=hadPrimarySource, + identifier_in_primary_source=identifierInPrimarySource, + stable_target_id=stableTargetId, + ) return IdentityFetchResponse(items=identities, total=len(identities)) diff --git a/mex/backend/identity/models.py b/mex/backend/identity/models.py index 14a905d..f4a1de3 100644 --- a/mex/backend/identity/models.py +++ b/mex/backend/identity/models.py @@ -1,12 +1,12 @@ from mex.common.identity.models import Identity from mex.common.models import BaseModel -from mex.common.types import PrimarySourceID +from mex.common.types import MergedPrimarySourceIdentifier class IdentityAssignRequest(BaseModel): """Request body for identity upsert requests.""" - hadPrimarySource: PrimarySourceID + hadPrimarySource: MergedPrimarySourceIdentifier identifierInPrimarySource: str diff --git a/mex/backend/identity/provider.py b/mex/backend/identity/provider.py index d30890f..e686575 100644 --- a/mex/backend/identity/provider.py +++ b/mex/backend/identity/provider.py @@ -1,30 +1,39 @@ -from functools import cache +from functools import lru_cache from mex.backend.graph.connector import GraphConnector -from mex.backend.graph.transform import transform_identity_result_to_identity -from mex.common.exceptions import MExError from mex.common.identity import BaseProvider, Identity -from mex.common.types import Identifier, PrimarySourceID +from mex.common.types import Identifier, MergedPrimarySourceIdentifier class GraphIdentityProvider(BaseProvider, GraphConnector): - """Identity provider that communicates with the neo4j graph database.""" + """Identity provider that communicates with 
the graph database.""" + + def __init__(self) -> None: + """Create a new graph identity provider.""" + super().__init__() + # mitigating https://docs.astral.sh/ruff/rules/cached-instance-method + self._cached_assign = lru_cache(5000)(self._do_assign) - @cache # noqa: B019 def assign( self, - had_primary_source: PrimarySourceID, + had_primary_source: MergedPrimarySourceIdentifier, + identifier_in_primary_source: str, + ) -> Identity: + """Return a cached Identity from the database or newly assigned one.""" + return self._cached_assign(had_primary_source, identifier_in_primary_source) + + def _do_assign( + self, + had_primary_source: MergedPrimarySourceIdentifier, identifier_in_primary_source: str, ) -> Identity: """Find an Identity in the database or assign a new one.""" - graph_result = self.fetch_identities( + result = self.fetch_identities( had_primary_source=had_primary_source, identifier_in_primary_source=identifier_in_primary_source, ) - if len(graph_result.data) > 1: - raise MExError("found multiple identities indicating graph inconsistency") - if len(graph_result.data) == 1: - return transform_identity_result_to_identity(graph_result.data[0]) + if record := result.one_or_none(): + return Identity.model_validate(record) return Identity( hadPrimarySource=had_primary_source, identifier=Identifier.generate(), @@ -44,12 +53,9 @@ def fetch( Either provide `stable_target_id` or `had_primary_source` and `identifier_in_primary_source` together to get a unique result. """ - graph_result = self.fetch_identities( + result = self.fetch_identities( had_primary_source=had_primary_source, identifier_in_primary_source=identifier_in_primary_source, stable_target_id=stable_target_id, ) - return [ - transform_identity_result_to_identity(result) - for result in graph_result.data - ] + return [Identity.model_validate(result) for result in result] diff --git a/mex/backend/ingest/models.py b/mex/backend/ingest/models.py index 9d12740..cfef4ad 100644 --- a/mex/backend/ingest/models.py +++ b/mex/backend/ingest/models.py @@ -2,8 +2,11 @@ from pydantic import ConfigDict, create_model -from mex.backend.extracted.models import AnyExtractedModel -from mex.common.models import EXTRACTED_MODEL_CLASSES_BY_NAME, BaseModel +from mex.common.models import ( + EXTRACTED_MODEL_CLASSES_BY_NAME, + AnyExtractedModel, + BaseModel, +) from mex.common.types import Identifier diff --git a/mex/backend/main.py b/mex/backend/main.py index 06bfddc..437f401 100644 --- a/mex/backend/main.py +++ b/mex/backend/main.py @@ -7,10 +7,7 @@ from fastapi.openapi.utils import get_openapi from fastapi.responses import JSONResponse from pydantic import BaseModel, ValidationError -from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request -from starlette.responses import Response -from starlette.types import ASGIApp from mex.backend.extracted.main import router as extracted_router from mex.backend.identity.main import router as identity_router @@ -24,7 +21,6 @@ from mex.common.connector import ConnectorContext from mex.common.exceptions import MExError from mex.common.logging import logger -from mex.common.settings import SettingsContext from mex.common.types import Identifier from mex.common.types.identifier import MEX_ID_PATTERN @@ -56,7 +52,7 @@ def create_openapi_schema() -> dict[str, Any]: openapi_schema["components"]["schemas"][name] = { "title": name, "type": "string", - "description": f"Identifier for {name.replace('ID', '')} items.", + "description": subclass.__doc__, 
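Back in mex/backend/graph/transform.py above, the `_refs` expansion can be pictured with a small invented payload: `position` restores list order regardless of how the references arrive, sparse positions are padded with None, and the `_refs` collection itself is consumed.

from mex.backend.graph.transform import expand_references_in_search_result

item = {
    "entityType": "ExtractedOrganizationalUnit",
    "_refs": [
        {"label": "name", "position": 1, "value": {"value": "Unit 1", "language": "en"}},
        {"label": "name", "position": 0, "value": {"value": "Einheit 1", "language": "de"}},
        {"label": "hadPrimarySource", "position": 0, "value": "2222222222222222"},
    ],
}
expand_references_in_search_result(item)
assert item["name"] == [
    {"value": "Einheit 1", "language": "de"},
    {"value": "Unit 1", "language": "en"},
]
assert item["hadPrimarySource"] == ["2222222222222222"]
assert "_refs" not in item  # the collection is popped during expansion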
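The per-instance cache wiring in GraphIdentityProvider above is a general pattern worth spelling out; a minimal self-contained sketch with generic names (none of them from this patch):

from functools import lru_cache


class ExpensiveLookup:
    """Cache a method per instance instead of decorating it at class level."""

    def __init__(self) -> None:
        # binding lru_cache here avoids the pitfall behind ruff's B019 check:
        # a class-level @lru_cache would hold `self` in its cache keys and
        # keep every instance alive for the lifetime of the cache
        self._cached_fetch = lru_cache(maxsize=5000)(self._do_fetch)

    def fetch(self, key: str) -> str:
        return self._cached_fetch(key)

    def _do_fetch(self, key: str) -> str:
        return key.upper()  # stand-in for a round trip to the database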
"pattern": MEX_ID_PATTERN, } @@ -74,6 +70,7 @@ def close_connectors() -> None: logger.exception("Error closing %s", connector_type) else: logger.info("Closed %s", connector_type) + context.clear() @asynccontextmanager @@ -129,22 +126,6 @@ def handle_uncaught_exception(request: Request, exc: Exception) -> JSONResponse: return JSONResponse(to_primitive(body), 500) -class SettingsMiddleware(BaseHTTPMiddleware): - """Middleware to inject the settings into requests contexts.""" - - def __init__(self, app: ASGIApp) -> None: - """Store the settings on the middleware.""" - super().__init__(app) - self.settings = BackendSettings.get() - - async def dispatch( - self, request: Request, call_next: RequestResponseEndpoint - ) -> Response: - """Dispatch a new request with settings injected into its context.""" - SettingsContext.set(self.settings) - return await call_next(request) - - app.include_router(router) app.add_exception_handler(ValidationError, handle_uncaught_exception) app.add_exception_handler(MExError, handle_uncaught_exception) @@ -156,7 +137,6 @@ async def dispatch( allow_methods=["*"], allow_origins=["*"], ) -app.add_middleware(SettingsMiddleware) @entrypoint(BackendSettings) diff --git a/mex/backend/merged/main.py b/mex/backend/merged/main.py index 3a2bb61..5ed27a0 100644 --- a/mex/backend/merged/main.py +++ b/mex/backend/merged/main.py @@ -3,14 +3,9 @@ from fastapi import APIRouter, Query from mex.backend.graph.connector import GraphConnector -from mex.backend.merged.models import ( - MergedItemSearchResponse, - MergedType, - UnprefixedType, -) -from mex.backend.merged.transform import ( - transform_graph_results_to_merged_item_search_response_facade, -) +from mex.backend.merged.models import MergedItemSearchResponse +from mex.backend.types import MergedType, UnprefixedType +from mex.common.transform import ensure_prefix from mex.common.types import Identifier router = APIRouter() @@ -19,8 +14,9 @@ @router.get("/merged-item", tags=["editor", "public"]) def search_merged_items_facade( q: Annotated[str, Query(max_length=100)] = "", - stableTargetId: Annotated[Identifier | None, Query()] = None, # noqa: N803 - entityType: Annotated[ # noqa: N803 + stableTargetId: Annotated[Identifier | None, Query(deprecated=True)] = None, + identifier: Identifier | None = None, + entityType: Annotated[ Sequence[MergedType | UnprefixedType], Query(max_length=len(MergedType)) ] = [], skip: Annotated[int, Query(ge=0, le=10e10)] = 0, @@ -30,16 +26,23 @@ def search_merged_items_facade( # XXX We just search for extracted items and pretend they are already merged # as a stopgap for MX-1382. graph = GraphConnector.get() - query_results = graph.query_nodes( + result = graph.fetch_extracted_data( q, - stableTargetId, + identifier or stableTargetId, [ # Allow 'MergedPerson' as well as 'Person' as an entityType for this # endpoint to keep compatibility with previous API clients. 
- f"Extracted{t.value.removeprefix('Merged')}" + ensure_prefix(t.value.removeprefix("Merged"), "Extracted") for t in entityType or MergedType ], skip, limit, ) - return transform_graph_results_to_merged_item_search_response_facade(query_results) + + for item in result["items"]: + del item["hadPrimarySource"] + del item["identifierInPrimarySource"] + item["identifier"] = item.pop("stableTargetId") + item["entityType"] = item["entityType"].replace("Extracted", "Merged") + + return MergedItemSearchResponse.model_validate(result.one()) diff --git a/mex/backend/merged/models.py b/mex/backend/merged/models.py index 3ca4ebd..2b34fb4 100644 --- a/mex/backend/merged/models.py +++ b/mex/backend/merged/models.py @@ -1,33 +1,12 @@ -from enum import Enum -from typing import TYPE_CHECKING, Union +from typing import Annotated from pydantic import Field -from mex.backend.types import DynamicStrEnum -from mex.common.models import MERGED_MODEL_CLASSES_BY_NAME, MergedItem -from mex.common.models.base import BaseModel - - -class UnprefixedType(Enum, metaclass=DynamicStrEnum): - """Enumeration of possible unprefixed types for merged items.""" - - __names__ = list(m.removeprefix("Merged") for m in MERGED_MODEL_CLASSES_BY_NAME) - - -class MergedType(Enum, metaclass=DynamicStrEnum): - """Enumeration of possible types for merged items.""" - - __names__ = list(MERGED_MODEL_CLASSES_BY_NAME) - - -if TYPE_CHECKING: # pragma: no cover - AnyMergedModel = MergedItem -else: - AnyMergedModel = Union[*MERGED_MODEL_CLASSES_BY_NAME.values()] +from mex.common.models import AnyMergedModel, BaseModel class MergedItemSearchResponse(BaseModel): """Response body for the merged item search endpoint.""" total: int - items: list[AnyMergedModel] = Field(discriminator="entityType") + items: Annotated[list[AnyMergedModel], Field(discriminator="entityType")] diff --git a/mex/backend/merged/transform.py b/mex/backend/merged/transform.py deleted file mode 100644 index f1a57e9..0000000 --- a/mex/backend/merged/transform.py +++ /dev/null @@ -1,52 +0,0 @@ -import json - -from neo4j.exceptions import Neo4jError - -from mex.backend.graph.models import GraphResult -from mex.backend.graph.transform import transform_search_result_to_model -from mex.backend.merged.models import MergedItemSearchResponse -from mex.common.logging import logger -from mex.common.models import MERGED_MODEL_CLASSES_BY_NAME - - -def transform_graph_results_to_merged_item_search_response_facade( - graph_results: list[GraphResult], -) -> MergedItemSearchResponse: - """Transform graph results to extracted item search results. - - We just pretend they are merged items as a stopgap for MX-1382. 
- - Args: - graph_results: Results of a search and a count query - - Returns: - Search response instance - """ - search_result, count_result = graph_results - total = count_result.data[0]["c"] - items = [] - for result in search_result.data: - try: - model = transform_search_result_to_model(result) - model_dict = model.model_dump( - exclude={"hadPrimarySource", "identifierInPrimarySource"} - ) - # create a MergedModel class with the dictionary - model_dict["entityType"] = model_dict["entityType"].replace( - "Extracted", "Merged" - ) - model_class = MERGED_MODEL_CLASSES_BY_NAME[model_dict["entityType"]] - items.append(model_class.model_validate(model_dict)) - except Neo4jError as error: # noqa: PERF203 - logger.exception( - "%s\n__node__\n %s\n__refs__\n%s\n", - error, - " \n".join( - "{}: {}".format(k, json.dumps(v, separators=(",", ":"))) - for k, v in result["n"].items() - ), - " \n".join("()-[{key}]->({value})".format(**r) for r in result["r"]), - exc_info=False, - ) - # TODO: merge extracted items with rule sets - return MergedItemSearchResponse.model_validate({"items": items, "total": total}) diff --git a/mex/backend/settings.py b/mex/backend/settings.py index 472ddb2..d36d507 100644 --- a/mex/backend/settings.py +++ b/mex/backend/settings.py @@ -42,22 +42,22 @@ class BackendSettings(BaseSettings): ) graph_url: str = Field( "neo4j://localhost:7687", - description="URL of the neo4j HTTP API endpoint including the graph name.", + description="URL for connecting to the graph database.", validation_alias="MEX_GRAPH_URL", ) graph_db: str = Field( "neo4j", - description="Name of the neo4j graph database.", + description="Name of the default graph database.", validation_alias="MEX_GRAPH_NAME", ) graph_user: SecretStr = Field( SecretStr("neo4j"), - description="Username for authenticating with the neo4j graph.", + description="Username for authenticating with the graph database.", validation_alias="MEX_GRAPH_USER", ) graph_password: SecretStr = Field( SecretStr("password"), - description="Password for authenticating with the neo4j graph.", + description="Password for authenticating with the graph database.", validation_alias="MEX_GRAPH_PASSWORD", ) backend_api_key_database: APIKeyDatabase = Field( diff --git a/mex/backend/transform.py b/mex/backend/transform.py index 37e67f8..9fcb986 100644 --- a/mex/backend/transform.py +++ b/mex/backend/transform.py @@ -1,14 +1,14 @@ from enum import Enum -from typing import Any, Callable +from typing import Any, Callable, Final from fastapi.encoders import jsonable_encoder -from mex.common.types import Identifier, Timestamp +from mex.common.types import Identifier, TemporalEntity -JSON_ENCODERS = { - Enum: lambda obj: obj.value, +JSON_ENCODERS: Final[dict[type, Callable[[Any], str]]] = { + Enum: lambda obj: str(obj.value), Identifier: lambda obj: str(obj), - Timestamp: lambda obj: str(obj), + TemporalEntity: lambda obj: str(obj), } diff --git a/mex/backend/types.py b/mex/backend/types.py index bdb6003..abea95b 100644 --- a/mex/backend/types.py +++ b/mex/backend/types.py @@ -1,10 +1,18 @@ from enum import Enum, EnumMeta, _EnumDict +from typing import Literal from pydantic import SecretStr -from mex.common.models import BaseModel +from mex.common.models import ( + BASE_MODEL_CLASSES_BY_NAME, + EXTRACTED_MODEL_CLASSES_BY_NAME, + MERGED_MODEL_CLASSES_BY_NAME, + BaseModel, +) from mex.common.transform import dromedary_to_snake +LiteralStringType = type(Literal["str"]) + class AccessLevel(Enum): """Enum of access level.""" @@ -55,3 +63,21 @@ def __new__( for 
name in dct.pop("__names__"): dct[dromedary_to_snake(name).upper()] = name return super().__new__(cls, name, bases, dct) + + +class UnprefixedType(Enum, metaclass=DynamicStrEnum): + """Enumeration of possible types without any prefix.""" + + __names__ = list(m.removeprefix("Base") for m in BASE_MODEL_CLASSES_BY_NAME) + + +class ExtractedType(Enum, metaclass=DynamicStrEnum): + """Enumeration of possible types for extracted items.""" + + __names__ = list(EXTRACTED_MODEL_CLASSES_BY_NAME) + + +class MergedType(Enum, metaclass=DynamicStrEnum): + """Enumeration of possible types for merged items.""" + + __names__ = list(MERGED_MODEL_CLASSES_BY_NAME) diff --git a/poetry.lock b/poetry.lock index 36917ba..2df3aad 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "alabaster" @@ -366,18 +366,18 @@ tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipyth [[package]] name = "fastapi" -version = "0.110.0" +version = "0.110.1" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = false python-versions = ">=3.8" files = [ - {file = "fastapi-0.110.0-py3-none-any.whl", hash = "sha256:87a1f6fb632a218222c5984be540055346a8f5d8a68e8f6fb647b1dc9934de4b"}, - {file = "fastapi-0.110.0.tar.gz", hash = "sha256:266775f0dcc95af9d3ef39bad55cff525329a931d5fd51930aadd4f428bf7ff3"}, + {file = "fastapi-0.110.1-py3-none-any.whl", hash = "sha256:5df913203c482f820d31f48e635e022f8cbfe7350e4830ef05a3163925b1addc"}, + {file = "fastapi-0.110.1.tar.gz", hash = "sha256:6feac43ec359dfe4f45b2c18ec8c94edb8dc2dfc461d417d9e626590c071baad"}, ] [package.dependencies] pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" -starlette = ">=0.36.3,<0.37.0" +starlette = ">=0.37.2,<0.38.0" typing-extensions = ">=4.8.0" [package.extras] @@ -396,13 +396,13 @@ files = [ [[package]] name = "httpcore" -version = "1.0.4" +version = "1.0.5" description = "A minimal low-level HTTP client." 
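Stepping back to the DynamicStrEnum metaclass in mex/backend/types.py above, a small sketch of what it generates (the two names here are illustrative): each entry in `__names__` becomes an upper-snake-case member whose value keeps the original class name.

from enum import Enum

from mex.backend.types import DynamicStrEnum


class DemoType(Enum, metaclass=DynamicStrEnum):
    """Enumeration built from a plain list of model class names."""

    __names__ = ["ContactPoint", "PrimarySource"]


assert DemoType.CONTACT_POINT.value == "ContactPoint"
assert [member.value for member in DemoType] == ["ContactPoint", "PrimarySource"]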
optional = false python-versions = ">=3.8" files = [ - {file = "httpcore-1.0.4-py3-none-any.whl", hash = "sha256:ac418c1db41bade2ad53ae2f3834a3a0f5ae76b56cf5aa497d2d033384fc7d73"}, - {file = "httpcore-1.0.4.tar.gz", hash = "sha256:cb2839ccfcba0d2d3c1131d3c3e26dfc327326fbe7a5dc0dbfe9f6c9151bb022"}, + {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, + {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, ] [package.dependencies] @@ -413,7 +413,7 @@ h11 = ">=0.13,<0.15" asyncio = ["anyio (>=4.0,<5.0)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] -trio = ["trio (>=0.22.0,<0.25.0)"] +trio = ["trio (>=0.22.0,<0.26.0)"] [[package]] name = "httptools" @@ -537,13 +537,13 @@ ipython = {version = ">=7.31.1", markers = "python_version >= \"3.11\""} [[package]] name = "ipython" -version = "8.22.2" +version = "8.23.0" description = "IPython: Productive Interactive Computing" optional = false python-versions = ">=3.10" files = [ - {file = "ipython-8.22.2-py3-none-any.whl", hash = "sha256:3c86f284c8f3d8f2b6c662f885c4889a91df7cd52056fd02b7d8d6195d7f56e9"}, - {file = "ipython-8.22.2.tar.gz", hash = "sha256:2dcaad9049f9056f1fef63514f176c7d41f930daa78d05b82a176202818f2c14"}, + {file = "ipython-8.23.0-py3-none-any.whl", hash = "sha256:07232af52a5ba146dc3372c7bf52a0f890a23edf38d77caef8d53f9cdc2584c1"}, + {file = "ipython-8.23.0.tar.gz", hash = "sha256:7468edaf4f6de3e1b912e57f66c241e6fd3c7099f2ec2136e239e142e800274d"}, ] [package.dependencies] @@ -556,12 +556,14 @@ prompt-toolkit = ">=3.0.41,<3.1.0" pygments = ">=2.4.0" stack-data = "*" traitlets = ">=5.13.0" +typing-extensions = {version = "*", markers = "python_version < \"3.12\""} [package.extras] -all = ["ipython[black,doc,kernel,nbconvert,nbformat,notebook,parallel,qtconsole,terminal]", "ipython[test,test-extra]"] +all = ["ipython[black,doc,kernel,matplotlib,nbconvert,nbformat,notebook,parallel,qtconsole]", "ipython[test,test-extra]"] black = ["black"] doc = ["docrepr", "exceptiongroup", "ipykernel", "ipython[test]", "matplotlib", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "sphinxcontrib-jquery", "stack-data", "typing-extensions"] kernel = ["ipykernel"] +matplotlib = ["matplotlib"] nbconvert = ["nbconvert"] nbformat = ["nbformat"] notebook = ["ipywidgets", "notebook"] @@ -719,29 +721,49 @@ traitlets = "*" [[package]] name = "mex-common" -version = "0.19.3" +version = "0.22.0" description = "Common library for MEx python projects." 
optional = false -python-versions = "^3.11" +python-versions = ">=3.11" files = [] develop = false [package.dependencies] -backoff = "^2.2.1" -click = "^8.1.7" -langdetect = "^1.0.9" -ldap3 = "^2.9.1" -numpy = "^1.26.2" -pandas = "^2.1.4" -pydantic = "^2.6.0" -pydantic-settings = "^2.1.0" -requests = "^2.31.0" +backoff = ">=2.2.1" +click = ">=8.1.7" +langdetect = ">=1.0.9" +ldap3 = ">=2.9.1" +mex-model = {git = "https://github.com/robert-koch-institut/mex-model.git", rev = "2.3.0"} +numpy = ">=1.26.4" +pandas = ">=2.2.1" +pyarrow = ">=15.0.1" +pydantic = ">=2.6.4" +pydantic-settings = ">=2.2.1" +requests = ">=2.31.0" + +[package.extras] +dev = ["black (>=24.3.0)", "ipdb (>=0.13.13)", "mypy (>=1.9.0)", "pandas-stubs (>=2.2.1)", "pytest (>=8.1.1)", "pytest-cov (>=4.1.0)", "pytest-random-order (>=1.1.1)", "pytest-xdist (>=3.5.0)", "ruff (>=0.3.3)", "sphinx (>=7.2.6)", "types-ldap3 (>=2.9.13)", "types-pytz (>=2024.1.0)", "types-requests (>=2.31.0)"] [package.source] type = "git" url = "https://github.com/robert-koch-institut/mex-common.git" -reference = "0.19.3" -resolved_reference = "7a8ede7f28398393ca0c078cb6245b0f4a984545" +reference = "0.22.0" +resolved_reference = "f23a25131fcf35e228b86463b30f99d44c290e0d" + +[[package]] +name = "mex-model" +version = "2.3.0" +description = "Conceptual and machine-readable versions of the MEx metadata model." +optional = false +python-versions = "^3.11" +files = [] +develop = false + +[package.source] +type = "git" +url = "https://github.com/robert-koch-institut/mex-model.git" +reference = "2.3.0" +resolved_reference = "d4aba623fc54f9b0f1b412408a21e62b6913f875" [[package]] name = "mypy" @@ -802,12 +824,12 @@ files = [ [[package]] name = "neo4j" -version = "5.18.0" +version = "5.19.0" description = "Neo4j Bolt driver for Python" optional = false python-versions = ">=3.7" files = [ - {file = "neo4j-5.18.0.tar.gz", hash = "sha256:4014406ae5b8b485a8ba46c9f00b6f5b4aaf88e7c3a50603445030c2aab701c9"}, + {file = "neo4j-5.19.0.tar.gz", hash = "sha256:23704f604214174f3b7d15a38653a1462809986019dfdaf773ff7ca4e1b9e2de"}, ] [package.dependencies] @@ -1055,15 +1077,63 @@ files = [ [package.extras] tests = ["pytest"] +[[package]] +name = "pyarrow" +version = "15.0.2" +description = "Python library for Apache Arrow" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyarrow-15.0.2-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:88b340f0a1d05b5ccc3d2d986279045655b1fe8e41aba6ca44ea28da0d1455d8"}, + {file = "pyarrow-15.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:eaa8f96cecf32da508e6c7f69bb8401f03745c050c1dd42ec2596f2e98deecac"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23c6753ed4f6adb8461e7c383e418391b8d8453c5d67e17f416c3a5d5709afbd"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f639c059035011db8c0497e541a8a45d98a58dbe34dc8fadd0ef128f2cee46e5"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:290e36a59a0993e9a5224ed2fb3e53375770f07379a0ea03ee2fce2e6d30b423"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:06c2bb2a98bc792f040bef31ad3e9be6a63d0cb39189227c08a7d955db96816e"}, + {file = "pyarrow-15.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:f7a197f3670606a960ddc12adbe8075cea5f707ad7bf0dffa09637fdbb89f76c"}, + {file = "pyarrow-15.0.2-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:5f8bc839ea36b1f99984c78e06e7a06054693dc2af8920f6fb416b5bca9944e4"}, 
+ {file = "pyarrow-15.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f5e81dfb4e519baa6b4c80410421528c214427e77ca0ea9461eb4097c328fa33"}, + {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a4f240852b302a7af4646c8bfe9950c4691a419847001178662a98915fd7ee7"}, + {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e7d9cfb5a1e648e172428c7a42b744610956f3b70f524aa3a6c02a448ba853e"}, + {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2d4f905209de70c0eb5b2de6763104d5a9a37430f137678edfb9a675bac9cd98"}, + {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:90adb99e8ce5f36fbecbbc422e7dcbcbed07d985eed6062e459e23f9e71fd197"}, + {file = "pyarrow-15.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:b116e7fd7889294cbd24eb90cd9bdd3850be3738d61297855a71ac3b8124ee38"}, + {file = "pyarrow-15.0.2-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:25335e6f1f07fdaa026a61c758ee7d19ce824a866b27bba744348fa73bb5a440"}, + {file = "pyarrow-15.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:90f19e976d9c3d8e73c80be84ddbe2f830b6304e4c576349d9360e335cd627fc"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a22366249bf5fd40ddacc4f03cd3160f2d7c247692945afb1899bab8a140ddfb"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2a335198f886b07e4b5ea16d08ee06557e07db54a8400cc0d03c7f6a22f785f"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3e6d459c0c22f0b9c810a3917a1de3ee704b021a5fb8b3bacf968eece6df098f"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:033b7cad32198754d93465dcfb71d0ba7cb7cd5c9afd7052cab7214676eec38b"}, + {file = "pyarrow-15.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:29850d050379d6e8b5a693098f4de7fd6a2bea4365bfd073d7c57c57b95041ee"}, + {file = "pyarrow-15.0.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:7167107d7fb6dcadb375b4b691b7e316f4368f39f6f45405a05535d7ad5e5058"}, + {file = "pyarrow-15.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e85241b44cc3d365ef950432a1b3bd44ac54626f37b2e3a0cc89c20e45dfd8bf"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:248723e4ed3255fcd73edcecc209744d58a9ca852e4cf3d2577811b6d4b59818"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ff3bdfe6f1b81ca5b73b70a8d482d37a766433823e0c21e22d1d7dde76ca33f"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:f3d77463dee7e9f284ef42d341689b459a63ff2e75cee2b9302058d0d98fe142"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:8c1faf2482fb89766e79745670cbca04e7018497d85be9242d5350cba21357e1"}, + {file = "pyarrow-15.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:28f3016958a8e45a1069303a4a4f6a7d4910643fc08adb1e2e4a7ff056272ad3"}, + {file = "pyarrow-15.0.2-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:89722cb64286ab3d4daf168386f6968c126057b8c7ec3ef96302e81d8cdb8ae4"}, + {file = "pyarrow-15.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cd0ba387705044b3ac77b1b317165c0498299b08261d8122c96051024f953cd5"}, + {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad2459bf1f22b6a5cdcc27ebfd99307d5526b62d217b984b9f5c974651398832"}, + {file = 
"pyarrow-15.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58922e4bfece8b02abf7159f1f53a8f4d9f8e08f2d988109126c17c3bb261f22"}, + {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:adccc81d3dc0478ea0b498807b39a8d41628fa9210729b2f718b78cb997c7c91"}, + {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:8bd2baa5fe531571847983f36a30ddbf65261ef23e496862ece83bdceb70420d"}, + {file = "pyarrow-15.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6669799a1d4ca9da9c7e06ef48368320f5856f36f9a4dd31a11839dda3f6cc8c"}, + {file = "pyarrow-15.0.2.tar.gz", hash = "sha256:9c9bc803cb3b7bfacc1e96ffbfd923601065d9d3f911179d81e72d99fd74a3d9"}, +] + +[package.dependencies] +numpy = ">=1.16.6,<2" + [[package]] name = "pyasn1" -version = "0.5.1" +version = "0.6.0" description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +python-versions = ">=3.8" files = [ - {file = "pyasn1-0.5.1-py2.py3-none-any.whl", hash = "sha256:4439847c58d40b1d0a573d07e3856e95333f1976294494c325775aeca506eb58"}, - {file = "pyasn1-0.5.1.tar.gz", hash = "sha256:6d391a96e59b23130a5cfa74d6fd7f388dbbe26cc8f1edf39fdddf08d9d6676c"}, + {file = "pyasn1-0.6.0-py2.py3-none-any.whl", hash = "sha256:cca4bb0f2df5504f02f6f8a775b6e416ff9b0b3b16f7ee80b5a3153d9b804473"}, + {file = "pyasn1-0.6.0.tar.gz", hash = "sha256:3a35ab2c4b5ef98e17dfdec8ab074046fbda76e281c5a706ccd82328cfc8f64c"}, ] [[package]] @@ -1326,7 +1396,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -1384,28 +1453,28 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "ruff" -version = "0.3.3" +version = "0.3.5" description = "An extremely fast Python linter and code formatter, written in Rust." 
optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.3.3-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:973a0e388b7bc2e9148c7f9be8b8c6ae7471b9be37e1cc732f8f44a6f6d7720d"}, - {file = "ruff-0.3.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:cfa60d23269d6e2031129b053fdb4e5a7b0637fc6c9c0586737b962b2f834493"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1eca7ff7a47043cf6ce5c7f45f603b09121a7cc047447744b029d1b719278eb5"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e7d3f6762217c1da954de24b4a1a70515630d29f71e268ec5000afe81377642d"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b24c19e8598916d9c6f5a5437671f55ee93c212a2c4c569605dc3842b6820386"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5a6cbf216b69c7090f0fe4669501a27326c34e119068c1494f35aaf4cc683778"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:352e95ead6964974b234e16ba8a66dad102ec7bf8ac064a23f95371d8b198aab"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d6ab88c81c4040a817aa432484e838aaddf8bfd7ca70e4e615482757acb64f8"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:79bca3a03a759cc773fca69e0bdeac8abd1c13c31b798d5bb3c9da4a03144a9f"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:2700a804d5336bcffe063fd789ca2c7b02b552d2e323a336700abb8ae9e6a3f8"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fd66469f1a18fdb9d32e22b79f486223052ddf057dc56dea0caaf1a47bdfaf4e"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:45817af234605525cdf6317005923bf532514e1ea3d9270acf61ca2440691376"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:0da458989ce0159555ef224d5b7c24d3d2e4bf4c300b85467b08c3261c6bc6a8"}, - {file = "ruff-0.3.3-py3-none-win32.whl", hash = "sha256:f2831ec6a580a97f1ea82ea1eda0401c3cdf512cf2045fa3c85e8ef109e87de0"}, - {file = "ruff-0.3.3-py3-none-win_amd64.whl", hash = "sha256:be90bcae57c24d9f9d023b12d627e958eb55f595428bafcb7fec0791ad25ddfc"}, - {file = "ruff-0.3.3-py3-none-win_arm64.whl", hash = "sha256:0171aab5fecdc54383993389710a3d1227f2da124d76a2784a7098e818f92d61"}, - {file = "ruff-0.3.3.tar.gz", hash = "sha256:38671be06f57a2f8aba957d9f701ea889aa5736be806f18c0cd03d6ff0cbca8d"}, + {file = "ruff-0.3.5-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:aef5bd3b89e657007e1be6b16553c8813b221ff6d92c7526b7e0227450981eac"}, + {file = "ruff-0.3.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:89b1e92b3bd9fca249153a97d23f29bed3992cff414b222fcd361d763fc53f12"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e55771559c89272c3ebab23326dc23e7f813e492052391fe7950c1a5a139d89"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dabc62195bf54b8a7876add6e789caae0268f34582333cda340497c886111c39"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a05f3793ba25f194f395578579c546ca5d83e0195f992edc32e5907d142bfa3"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:dfd3504e881082959b4160ab02f7a205f0fadc0a9619cc481982b6837b2fd4c0"}, + {file = 
"ruff-0.3.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87258e0d4b04046cf1d6cc1c56fadbf7a880cc3de1f7294938e923234cf9e498"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:712e71283fc7d9f95047ed5f793bc019b0b0a29849b14664a60fd66c23b96da1"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a532a90b4a18d3f722c124c513ffb5e5eaff0cc4f6d3aa4bda38e691b8600c9f"}, + {file = "ruff-0.3.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:122de171a147c76ada00f76df533b54676f6e321e61bd8656ae54be326c10296"}, + {file = "ruff-0.3.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d80a6b18a6c3b6ed25b71b05eba183f37d9bc8b16ace9e3d700997f00b74660b"}, + {file = "ruff-0.3.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a7b6e63194c68bca8e71f81de30cfa6f58ff70393cf45aab4c20f158227d5936"}, + {file = "ruff-0.3.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:a759d33a20c72f2dfa54dae6e85e1225b8e302e8ac655773aff22e542a300985"}, + {file = "ruff-0.3.5-py3-none-win32.whl", hash = "sha256:9d8605aa990045517c911726d21293ef4baa64f87265896e491a05461cae078d"}, + {file = "ruff-0.3.5-py3-none-win_amd64.whl", hash = "sha256:dc56bb16a63c1303bd47563c60482a1512721053d93231cf7e9e1c6954395a0e"}, + {file = "ruff-0.3.5-py3-none-win_arm64.whl", hash = "sha256:faeeae9905446b975dcf6d4499dc93439b131f1443ee264055c5716dd947af55"}, + {file = "ruff-0.3.5.tar.gz", hash = "sha256:a067daaeb1dc2baf9b82a32dae67d154d95212080c80435eb052d95da647763d"}, ] [[package]] @@ -1590,13 +1659,13 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] [[package]] name = "starlette" -version = "0.36.3" +version = "0.37.2" description = "The little ASGI library that shines." optional = false python-versions = ">=3.8" files = [ - {file = "starlette-0.36.3-py3-none-any.whl", hash = "sha256:13d429aa93a61dc40bf503e8c801db1f1bca3dc706b10ef2434a36123568f044"}, - {file = "starlette-0.36.3.tar.gz", hash = "sha256:90a671733cfb35771d8cc605e0b679d23b992f8dcfad48cc60b38cb29aeb7080"}, + {file = "starlette-0.37.2-py3-none-any.whl", hash = "sha256:6fe59f29268538e5d0d182f2791a479a0c64638e6935d1c6989e63fb2699c6ee"}, + {file = "starlette-0.37.2.tar.gz", hash = "sha256:9af890290133b79fc3db55474ade20f6220a364a0402e0b556e7cd5e1e093823"}, ] [package.dependencies] @@ -1672,13 +1741,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "uvicorn" -version = "0.27.1" +version = "0.29.0" description = "The lightning-fast ASGI server." 
optional = false python-versions = ">=3.8" files = [ - {file = "uvicorn-0.27.1-py3-none-any.whl", hash = "sha256:5c89da2f3895767472a35556e539fd59f7edbe9b1e9c0e1c99eebeadc61838e4"}, - {file = "uvicorn-0.27.1.tar.gz", hash = "sha256:3d9a267296243532db80c83a959a3400502165ade2c1338dea4e67915fd4745a"}, + {file = "uvicorn-0.29.0-py3-none-any.whl", hash = "sha256:2c2aac7ff4f4365c206fd773a39bf4ebd1047c238f8b8268ad996829323473de"}, + {file = "uvicorn-0.29.0.tar.gz", hash = "sha256:6a69214c0b6a087462412670b3ef21224fa48cae0e452b5883e8e8bdfdd11dd0"}, ] [package.dependencies] @@ -1921,4 +1990,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "c5e28b651ff790f7f1646b70e9d6e98aab7ea03b383ea53ef076f02c15f313c6" +content-hash = "277e977f506ddefd3c7173ef245ee05f8ec06b88d29e00ab3d56b8763792e7bc" diff --git a/pyproject.toml b/pyproject.toml index c008cea..0ffbfc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,12 +10,12 @@ packages = [{ include = "mex" }] [tool.poetry.dependencies] python = "^3.11" -fastapi = "^0.110.0" +fastapi = "^0.110.1" httpx = "^0.27.0" -mex-common = { git = "https://github.com/robert-koch-institut/mex-common.git", rev = "0.19.3" } -neo4j = "^5.17.0" +mex-common = { git = "https://github.com/robert-koch-institut/mex-common.git", rev = "0.22.0" } +neo4j = "^5.18.0" pydantic = "2.6.4" -uvicorn = { version = "^0.27.0", extras = ["standard"] } +uvicorn = { version = "^0.29.0", extras = ["standard"] } [tool.poetry.group.dev.dependencies] black = "^24.3.0" @@ -24,7 +24,7 @@ mypy = "^1.9.0" pytest = "^8.1.1" pytest-cov = "^4.1.0" pytest-random-order = "^1.1.1" -ruff = "^0.3.3" +ruff = "^0.3.5" types-pytz = "^2024.1.0" sphinx = "^7.2.6" @@ -108,6 +108,9 @@ select = [ "N807", # Allow mocking `__init__` for tests "S101", # Allow use of `assert` in tests ] +"**/main.py" = [ + "N803", # Allow dromedaryCase query parameters +] [tool.ruff.lint.isort] known-first-party = ["mex", "tests"] diff --git a/tests/conftest.py b/tests/conftest.py index 0491fe5..465b28b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,11 +2,12 @@ from base64 import b64encode from functools import partial from itertools import count +from typing import Any from unittest.mock import MagicMock, Mock import pytest from fastapi.testclient import TestClient -from neo4j import GraphDatabase +from neo4j import GraphDatabase, SummaryCounters from pytest import MonkeyPatch from mex.backend.graph.connector import GraphConnector @@ -15,15 +16,23 @@ from mex.backend.types import APIKeyDatabase, APIUserDatabase from mex.common.models import ( MEX_PRIMARY_SOURCE_STABLE_TARGET_ID, - BaseExtractedData, + AnyExtractedModel, ExtractedActivity, ExtractedContactPoint, ExtractedOrganizationalUnit, - ExtractedPerson, ExtractedPrimarySource, ) +from mex.common.settings import BaseSettings from mex.common.transform import MExEncoder -from mex.common.types import Identifier, Link, Text, TextLanguage +from mex.common.types import ( + Email, + Identifier, + IdentityProvider, + Link, + Text, + TextLanguage, + Theme, +) from mex.common.types.identifier import IdentifierT pytest_plugins = ("mex.common.testing.plugin",) @@ -44,13 +53,14 @@ def settings() -> BackendSettings: @pytest.fixture(autouse=True) def skip_integration_test_in_ci(is_integration_test: bool) -> None: - """Overwrite fixture from plugin to not skip int tests in ci.""" + """Overwrite fixture from plugin to not skip integration tests in ci.""" @pytest.fixture def client() -> TestClient: """Return a fastAPI test client initialized 
with our app.""" - return TestClient(app) + with TestClient(app) as test_client: + return test_client @pytest.fixture @@ -93,16 +103,37 @@ def patch_test_client_json_encoder(monkeypatch: MonkeyPatch) -> None: ) +class MockedGraph: + def __init__(self, records: list[Any], session_run: MagicMock) -> None: + self.records = records + self.run = session_run + + @property + def return_value(self) -> list[Any]: # pragma: no cover + return self.records + + @return_value.setter + def return_value(self, value: list[Any]) -> None: + self.records[:] = [Mock(data=MagicMock(return_value=v)) for v in value] + + @property + def call_args_list(self) -> list[Any]: + return self.run.call_args_list + + @pytest.fixture -def mocked_graph(monkeypatch: MonkeyPatch) -> MagicMock: +def mocked_graph(monkeypatch: MonkeyPatch) -> MockedGraph: """Mock the graph connector and return the mocked `run` for easy manipulation.""" - data = MagicMock(return_value=[]) - run = MagicMock(return_value=Mock(data=data)) - data.run = run # make the call_args available in tests + records: list[Any] = [] + summary = Mock(counters=SummaryCounters({})) + result = Mock(to_eager_result=MagicMock(return_value=(records, summary, None))) + run = MagicMock(return_value=result) session = MagicMock(__enter__=MagicMock(return_value=Mock(run=run))) driver = Mock(session=MagicMock(return_value=session)) - monkeypatch.setattr(GraphDatabase, "driver", lambda _, **__: driver) - return data + monkeypatch.setattr( + GraphConnector, "__init__", lambda self: setattr(self, "driver", driver) + ) + return MockedGraph(records, run) @pytest.fixture(autouse=True) @@ -117,12 +148,20 @@ def generate(cls: type[IdentifierT], seed: int | None = None) -> IdentifierT: monkeypatch.setattr(Identifier, "generate", classmethod(generate)) +@pytest.fixture(autouse=True) +def set_identity_provider(is_integration_test: bool) -> None: + """Ensure the identifier provider is set to `MEMORY` in unit tests.""" + if not is_integration_test: + settings = BaseSettings.get() + settings.identity_provider = IdentityProvider.MEMORY + + @pytest.fixture(autouse=True) def isolate_graph_database( is_integration_test: bool, settings: BackendSettings ) -> None: - """Automatically flush the neo4j database for integration testing.""" - if is_integration_test: # pragma: no cover + """Automatically flush the graph database for integration testing.""" + if is_integration_test: with GraphDatabase.driver( settings.graph_url, auth=( @@ -132,35 +171,15 @@ def isolate_graph_database( database=settings.graph_db, ) as driver: driver.execute_query("MATCH (n) DETACH DELETE n;") - driver.execute_query("DROP INDEX text_fields IF EXISTS;") - driver.execute_query("DROP CONSTRAINT identifier_uniqueness IF EXISTS;") + for row in driver.execute_query("SHOW ALL CONSTRAINTS;").records: + driver.execute_query(f"DROP CONSTRAINT {row['name']};") + for row in driver.execute_query("SHOW ALL INDEXES;").records: + driver.execute_query(f"DROP INDEX {row['name']};") @pytest.fixture -def extracted_person() -> BaseExtractedData: - """Return an extracted person with static dummy values.""" - return ExtractedPerson.model_construct( - identifier=Identifier.generate(seed=6), - stableTargetId=Identifier.generate(seed=66), - affiliation=[ - Identifier.generate(seed=255), - Identifier.generate(seed=3810), - ], - email=["fictitiousf@rki.de", "info@rki.de"], - familyName=["Fictitious"], - givenName=["Frieda"], - identifierInPrimarySource="frieda", - fullName=["Fictitious, Frieda, Dr."], - hadPrimarySource=Identifier.generate(seed=64), 
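A hedged sketch of how the MockedGraph fixture above is driven in a unit test (the payload and assertions are invented for illustration): setting `return_value` feeds fake records through the stubbed driver, and `call_args_list` exposes the rendered query plus the parameters the connector actually sent.

from mex.backend.graph.connector import GraphConnector
from tests.conftest import MockedGraph


def test_fetch_with_mocked_graph(mocked_graph: MockedGraph) -> None:
    mocked_graph.return_value = [{"items": [], "total": 0}]
    graph = GraphConnector.get()
    result = graph.fetch_extracted_data("robert", None, [], 0, 10)
    assert result.one() == {"items": [], "total": 0}
    query, parameters = mocked_graph.call_args_list[-1].args
    assert parameters["limit"] == 10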
- memberOf=[ - Identifier.generate(seed=35), - ], - ) - - -@pytest.fixture -def load_dummy_data() -> None: - """Ingest dummy data into Graph Database.""" +def dummy_data() -> list[AnyExtractedModel]: + """Create a set of interlinked dummy data.""" primary_source_1 = ExtractedPrimarySource( hadPrimarySource=MEX_PRIMARY_SOURCE_STABLE_TARGET_ID, identifierInPrimarySource="ps-1", @@ -168,27 +187,27 @@ def load_dummy_data() -> None: primary_source_2 = ExtractedPrimarySource( hadPrimarySource=MEX_PRIMARY_SOURCE_STABLE_TARGET_ID, identifierInPrimarySource="ps-2", - title=[Text(value="A cool and searchable title", language=None)], + version="Cool Version v2.13", ) contact_point_1 = ExtractedContactPoint( - email="info@rki.de", + email=[Email("info@contact-point.one")], hadPrimarySource=primary_source_1.stableTargetId, identifierInPrimarySource="cp-1", ) contact_point_2 = ExtractedContactPoint( - email="mex@rki.de", + email=[Email("help@contact-point.two")], hadPrimarySource=primary_source_1.stableTargetId, identifierInPrimarySource="cp-2", ) organizational_unit_1 = ExtractedOrganizationalUnit( hadPrimarySource=primary_source_2.stableTargetId, identifierInPrimarySource="ou-1", - name="Unit 1", + name=[Text(value="Unit 1", language=TextLanguage.EN)], ) activity_1 = ExtractedActivity( abstract=[ Text(value="An active activity.", language=TextLanguage.EN), - Text(value="Mumble bumble boo.", language=None), + Text(value="Une activité active.", language=None), ], contact=[ contact_point_1.stableTargetId, @@ -198,17 +217,22 @@ def load_dummy_data() -> None: hadPrimarySource=primary_source_1.stableTargetId, identifierInPrimarySource="a-1", responsibleUnit=[organizational_unit_1.stableTargetId], - theme=["https://mex.rki.de/item/theme-3"], - title=["Activity 1"], + theme=[Theme["DIGITAL_PUBLIC_HEALTH"]], + title=[Text(value="Aktivität 1", language=TextLanguage.DE)], website=[Link(title="Activity Homepage", url="https://activity-1")], ) - GraphConnector.get().ingest( - [ - primary_source_1, - primary_source_2, - contact_point_1, - contact_point_2, - organizational_unit_1, - activity_1, - ] - ) + return [ + primary_source_1, + primary_source_2, + contact_point_1, + contact_point_2, + organizational_unit_1, + activity_1, + ] + + +@pytest.fixture +def load_dummy_data(dummy_data: list[AnyExtractedModel]) -> list[AnyExtractedModel]: + """Ingest dummy data into the graph.""" + GraphConnector.get().ingest(dummy_data) + return dummy_data diff --git a/tests/extracted/test_main.py b/tests/extracted/test_main.py index 208f338..701e0b9 100644 --- a/tests/extracted/test_main.py +++ b/tests/extracted/test_main.py @@ -1,48 +1,65 @@ from typing import Any -from unittest.mock import MagicMock import pytest from fastapi.testclient import TestClient +from mex.backend.transform import to_primitive +from mex.common.models import ExtractedOrganizationalUnit +from tests.conftest import MockedGraph + def test_search_extracted_items_mocked( - client_with_api_key_read_permission: TestClient, mocked_graph: MagicMock + client_with_api_key_read_permission: TestClient, mocked_graph: MockedGraph ) -> None: + unit = ExtractedOrganizationalUnit.model_validate( + { + "hadPrimarySource": "2222222222222222", + "identifierInPrimarySource": "unit-1", + "email": ["test@foo.bar"], + "name": [ + {"value": "Eine unit von einer Org.", "language": "de"}, + {"value": "A unit of an org.", "language": "en"}, + ], + } + ) mocked_graph.return_value = [ { - "c": 1, - "l": "ExtractedContactPoint", - "r": [{"key": "hadPrimarySource", "value": 
["2222222222222222"]}], - "n": { - "stableTargetId": "0000000000000000", - "identifier": "1111111111111111", - "identifierInPrimarySource": "test", - "email": "test@foo.bar", - }, - "i": { - "stableTargetId": "0000000000000000", - "hadPrimarySource": "2222222222222222", - "identifierInPrimarySource": "test", - "identifier": "1111111111111111", - }, - }, + "items": [ + { + "identifier": unit.identifier, + "identifierInPrimarySource": unit.identifierInPrimarySource, + "stableTargetId": unit.stableTargetId, + "email": ["test@foo.bar"], + "entityType": "ExtractedOrganizationalUnit", + "_refs": [ + { + "label": "hadPrimarySource", + "position": 0, + "value": "2222222222222222", + }, + { + "label": "name", + "position": 0, + "value": { + "value": "Eine unit von einer Org.", + "language": "de", + }, + }, + { + "label": "name", + "position": 1, + "value": {"value": "A unit of an org.", "language": "en"}, + }, + ], + } + ], + "total": 14, + } ] response = client_with_api_key_read_permission.get("/v0/extracted-item") assert response.status_code == 200, response.text - assert response.json() == { - "items": [ - { - "$type": "ExtractedContactPoint", - "email": ["test@foo.bar"], - "hadPrimarySource": "2222222222222222", - "identifier": "1111111111111111", - "identifierInPrimarySource": "test", - "stableTargetId": "0000000000000000", - } - ], - "total": 1, - } + assert response.json() == {"items": [to_primitive(unit)], "total": 14} @pytest.mark.parametrize( @@ -77,7 +94,7 @@ def test_search_extracted_items_mocked( "items": [ { "$type": "ExtractedContactPoint", - "email": ["info@rki.de"], + "email": ["info@contact-point.one"], "hadPrimarySource": "bFQoRhcVH5DHUr", "identifier": "bFQoRhcVH5DHUu", "identifierInPrimarySource": "cp-1", @@ -93,7 +110,7 @@ def test_search_extracted_items_mocked( "items": [ { "$type": "ExtractedContactPoint", - "email": ["info@rki.de"], + "email": ["info@contact-point.one"], "hadPrimarySource": "bFQoRhcVH5DHUr", "identifier": "bFQoRhcVH5DHUu", "identifierInPrimarySource": "cp-1", @@ -101,7 +118,7 @@ def test_search_extracted_items_mocked( }, { "$type": "ExtractedContactPoint", - "email": ["mex@rki.de"], + "email": ["help@contact-point.two"], "hadPrimarySource": "bFQoRhcVH5DHUr", "identifier": "bFQoRhcVH5DHUw", "identifierInPrimarySource": "cp-2", @@ -126,11 +143,9 @@ def test_search_extracted_items_mocked( "identifierInPrimarySource": "ps-2", "locatedAt": [], "stableTargetId": "bFQoRhcVH5DHUt", - "title": [ - {"language": None, "value": "A cool and searchable title"} - ], + "title": [], "unitInCharge": [], - "version": None, + "version": "Cool Version v2.13", } ], "total": 1, diff --git a/tests/graph/test_connector.py b/tests/graph/test_connector.py index 0c7e893..1e58ba0 100644 --- a/tests/graph/test_connector.py +++ b/tests/graph/test_connector.py @@ -1,56 +1,95 @@ -from unittest.mock import MagicMock, call +from typing import Callable import pytest +from black import Mode, format_str +from pytest import MonkeyPatch -from mex.backend.graph.connector import GraphConnector -from mex.backend.graph.queries import ( - HAD_PRIMARY_SOURCE_AND_IDENTIFIER_IN_PRIMARY_SOURCE_IDENTITY_QUERY, - STABLE_TARGET_ID_IDENTITY_QUERY, +from mex.backend.graph import connector as connector_module +from mex.backend.graph.connector import MEX_EXTRACTED_PRIMARY_SOURCE, GraphConnector +from mex.backend.graph.query import QueryBuilder +from mex.common.models import AnyExtractedModel +from mex.common.types import ( + ExtractedPrimarySourceIdentifier, + Identifier, + MergedPrimarySourceIdentifier, ) -from 
mex.common.exceptions import MExError -from mex.common.models import ExtractedPerson -from mex.common.types import Identifier +from tests.conftest import MockedGraph -@pytest.mark.usefixtures("mocked_graph") -def test_mocked_graph_init() -> None: - graph = GraphConnector.get() - result = graph.commit("MATCH (this);") - assert result.model_dump() == {"data": []} +@pytest.fixture +def mocked_query_builder(monkeypatch: MonkeyPatch) -> None: + def __getattr__(_: QueryBuilder, query: str) -> Callable[..., str]: + return lambda **parameters: format_str( + f"{query}({','.join(f'{k}={v!r}' for k, v in parameters.items())})", + mode=Mode(line_length=78), + ).strip() + + monkeypatch.setattr(QueryBuilder, "__getattr__", __getattr__) -def test_mocked_graph_seed_constraints(mocked_graph: MagicMock) -> None: +@pytest.mark.usefixtures("mocked_query_builder") +def test_mocked_graph_seed_constraints(mocked_graph: MockedGraph) -> None: graph = GraphConnector.get() graph._seed_constraints() - assert mocked_graph.run.call_args_list[-1] == call( - """ -CREATE CONSTRAINT identifier_uniqueness IF NOT EXISTS -FOR (n:ExtractedVariableGroup) -REQUIRE n.identifier IS UNIQUE; -""", - None, + assert mocked_graph.call_args_list[-1].args == ( + 'create_identifier_uniqueness_constraint(node_label="MergedVariableGroup")', + {}, ) -def test_mocked_graph_seed_indices(mocked_graph: MagicMock) -> None: +@pytest.mark.usefixtures("mocked_query_builder") +def test_mocked_graph_seed_indices( + mocked_graph: MockedGraph, monkeypatch: MonkeyPatch +) -> None: + monkeypatch.setattr( + connector_module, "SEARCHABLE_CLASSES", ["ExtractedThis", "ExtractedThat"] + ) + monkeypatch.setattr( + connector_module, + "SEARCHABLE_FIELDS", + ["title", "name", "keyword", "description"], + ) graph = GraphConnector.get() graph._seed_indices() - assert mocked_graph.run.call_args_list[-1] == call( - """ -CREATE FULLTEXT INDEX text_fields IF NOT EXISTS -FOR (n:ExtractedAccessPlatform|ExtractedActivity|ExtractedContactPoint|ExtractedDistribution|ExtractedOrganization|\ -ExtractedOrganizationalUnit|ExtractedPerson|ExtractedPrimarySource|ExtractedResource|ExtractedVariable|ExtractedVariableGroup) -ON EACH [n.abstract_value, n.alternativeName_value, n.alternativeTitle_value, \ -n.description_value, n.instrumentToolOrApparatus_value, n.keyword_value, \ -n.label_value, n.methodDescription_value, n.method_value, n.name_value, \ -n.officialName_value, n.qualityInformation_value, n.resourceTypeSpecific_value, \ -n.rights_value, n.shortName_value, n.spatial_value, n.title_value] -OPTIONS {indexConfig: $config}; -""", + assert mocked_graph.call_args_list[-1].args == ( + """\ +create_full_text_search_index( + node_labels=["ExtractedThis", "ExtractedThat"], + search_fields=["title", "name", "keyword", "description"], +)""", + { + "index_config": { + "fulltext.eventually_consistent": True, + "fulltext.analyzer": "german", + } + }, + ) + + mocked_graph.return_value = [ + { + "node_labels": ["ExtractedThis", "ExtractedThat"], + "search_fields": ["title", "name", "keyword", "description"], + } + ] + monkeypatch.setattr( + connector_module, + "SEARCHABLE_CLASSES", + ["ExtractedThis", "ExtractedThat", "ExtractedOther"], + ) + + graph._seed_indices() + + assert mocked_graph.call_args_list[-2].args == ("drop_full_text_search_index()", {}) + assert mocked_graph.call_args_list[-1].args == ( + """\ +create_full_text_search_index( + node_labels=["ExtractedThis", "ExtractedThat", "ExtractedOther"], + search_fields=["title", "name", "keyword", "description"], +)""", { - 
"config": { + "index_config": { "fulltext.eventually_consistent": True, "fulltext.analyzer": "german", } @@ -58,12 +97,112 @@ def test_mocked_graph_seed_indices(mocked_graph: MagicMock) -> None: ) -def test_mocked_graph_fetch_identities(mocked_graph: MagicMock) -> None: +@pytest.mark.usefixtures("mocked_query_builder") +def test_mocked_graph_seed_data(mocked_graph: MockedGraph) -> None: + graph = GraphConnector.get() + graph._seed_data() + + assert mocked_graph.call_args_list[-2].args == ( + """\ +merge_node( + extracted_label="ExtractedPrimarySource", + merged_label="MergedPrimarySource", + nested_edge_labels=[], + nested_node_labels=[], +)""", + { + "identifier": ExtractedPrimarySourceIdentifier("00000000000000"), + "stable_target_id": MergedPrimarySourceIdentifier("00000000000000"), + "on_match": {"version": None}, + "on_create": { + "version": None, + "identifier": "00000000000000", + "identifierInPrimarySource": "mex", + }, + "nested_positions": [], + "nested_values": [], + }, + ) + assert mocked_graph.call_args_list[-1].args == ( + """\ +merge_edges( + extracted_label="ExtractedPrimarySource", + ref_labels=["hadPrimarySource", "stableTargetId"], +)""", + { + "identifier": "00000000000000", + "ref_identifiers": ["00000000000000", "00000000000000"], + "ref_positions": [0, 0], + }, + ) + + +@pytest.mark.usefixtures("mocked_query_builder") +def test_mocked_graph_fetch_extracted_data(mocked_graph: MockedGraph) -> None: + mocked_graph.return_value = [ + { + "items": [ + { + "entityType": "ExtractedThis", + "inlineProperty": ["foo", "bar"], + "_refs": [ + {"value": "second", "position": 1, "label": "nestedProperty"}, + {"value": "first", "position": 0, "label": "nestedProperty"}, + ], + } + ], + "total": 1, + } + ] graph = GraphConnector.get() + result = graph.fetch_extracted_data( + query_string="my-query", + stable_target_id=Identifier.generate(99), + entity_type=[], + skip=10, + limit=100, + ) + assert mocked_graph.call_args_list[-1].args == ( + """\ +fetch_extracted_data( + filter_by_query_string=True, + filter_by_stable_target_id=True, + filter_by_labels=False, +)""", + { + "labels": [], + "limit": 100, + "query_string": "my-query", + "skip": 10, + "stable_target_id": "bFQoRhcVH5DHV1", + }, + ) + + assert result.one() == { + "items": [ + { + "entityType": "ExtractedThis", + "inlineProperty": ["foo", "bar"], + "nestedProperty": ["first", "second"], + } + ], + "total": 1, + } + + +@pytest.mark.usefixtures("mocked_query_builder") +def test_mocked_graph_fetch_identities(mocked_graph: MockedGraph) -> None: + graph = GraphConnector.get() graph.fetch_identities(stable_target_id=Identifier.generate(99)) - assert mocked_graph.run.call_args.args == ( - STABLE_TARGET_ID_IDENTITY_QUERY, + + assert mocked_graph.call_args_list[-1].args == ( + """\ +fetch_identities( + filter_by_had_primary_source=False, + filter_by_identifier_in_primary_source=False, + filter_by_stable_target_id=True, +)""", { "had_primary_source": None, "identifier_in_primary_source": None, @@ -75,8 +214,14 @@ def test_mocked_graph_fetch_identities(mocked_graph: MagicMock) -> None: graph.fetch_identities( had_primary_source=Identifier.generate(101), identifier_in_primary_source="one" ) - assert mocked_graph.run.call_args.args == ( - HAD_PRIMARY_SOURCE_AND_IDENTIFIER_IN_PRIMARY_SOURCE_IDENTITY_QUERY, + + assert mocked_graph.call_args_list[-1].args == ( + """\ +fetch_identities( + filter_by_had_primary_source=True, + filter_by_identifier_in_primary_source=True, + filter_by_stable_target_id=False, +)""", { "had_primary_source": 
Identifier.generate(101), "identifier_in_primary_source": "one", @@ -85,157 +230,106 @@ def test_mocked_graph_fetch_identities(mocked_graph: MagicMock) -> None: }, ) - with pytest.raises(MExError, match="invalid identity query parameters"): - graph.fetch_identities(identifier_in_primary_source="two") + graph.fetch_identities(identifier_in_primary_source="two") - -def test_mocked_graph_merges_node( - mocked_graph: MagicMock, extracted_person: ExtractedPerson -) -> None: - graph = GraphConnector.get() - graph.merge_node(extracted_person) - - assert ( - mocked_graph.run.call_args.args[0] - == """ -MERGE (n:ExtractedPerson {identifier:$identifier}) -ON CREATE SET n = $on_create -ON MATCH SET n += $on_match -RETURN n; -""" - ) - assert mocked_graph.run.call_args.args[1] == { - "identifier": str(extracted_person.identifier), - "on_create": { - "email": ["fictitiousf@rki.de", "info@rki.de"], - "familyName": ["Fictitious"], - "fullName": ["Fictitious, Frieda, Dr."], - "givenName": ["Frieda"], - "identifier": "bFQoRhcVH5DHUw", - "identifierInPrimarySource": "frieda", - "isniId": [], - "orcidId": [], - "stableTargetId": "bFQoRhcVH5DHVu", - }, - "on_match": { - "email": ["fictitiousf@rki.de", "info@rki.de"], - "familyName": ["Fictitious"], - "fullName": ["Fictitious, Frieda, Dr."], - "givenName": ["Frieda"], - "isniId": [], - "orcidId": [], + assert mocked_graph.call_args_list[-1].args == ( + """\ +fetch_identities( + filter_by_had_primary_source=False, + filter_by_identifier_in_primary_source=True, + filter_by_stable_target_id=False, +)""", + { + "had_primary_source": None, + "identifier_in_primary_source": "two", + "stable_target_id": None, + "limit": 1000, }, - } - - -def test_mocked_graph_merges_edges( - mocked_graph: MagicMock, extracted_person: ExtractedPerson -) -> None: - graph = GraphConnector.get() - graph.merge_edges(extracted_person) - mocked_graph.assert_has_calls( - [ - call.run( - """ -MATCH (s {identifier:$fromID}) -MATCH (t {stableTargetId:$toSTI}) -MERGE (s)-[e:hadPrimarySource]->(t) -RETURN e; -""", - { - "fromID": str(extracted_person.identifier), - "toSTI": str(extracted_person.hadPrimarySource), - }, - ) - ] ) -def test_mocked_graph_ingests_models( - mocked_graph: MagicMock, extracted_person: ExtractedPerson +@pytest.mark.usefixtures("mocked_query_builder") +def test_mocked_graph_merges_node( + mocked_graph: MockedGraph, dummy_data: list[AnyExtractedModel] ) -> None: + extracted_organizational_unit = dummy_data[4] graph = GraphConnector.get() - identifiers = graph.ingest([extracted_person]) - - assert identifiers == [extracted_person.identifier] - - # expect node is created - assert mocked_graph.run.call_args_list[-5:][0][0] == ( - """ -MERGE (n:ExtractedPerson {identifier:$identifier}) -ON CREATE SET n = $on_create -ON MATCH SET n += $on_match -RETURN n; -""", + graph._merge_node(extracted_organizational_unit) + + assert mocked_graph.call_args_list[-1].args == ( + """\ +merge_node( + extracted_label="ExtractedOrganizationalUnit", + merged_label="MergedOrganizationalUnit", + nested_edge_labels=["name"], + nested_node_labels=["Text"], +)""", { - "identifier": "bFQoRhcVH5DHUw", + "identifier": extracted_organizational_unit.identifier, + "nested_positions": [0], + "nested_values": [{"language": "en", "value": "Unit 1"}], "on_create": { - "stableTargetId": "bFQoRhcVH5DHVu", - "email": ["fictitiousf@rki.de", "info@rki.de"], - "familyName": ["Fictitious"], - "fullName": ["Fictitious, Frieda, Dr."], - "givenName": ["Frieda"], - "isniId": [], - "orcidId": [], - "identifier": 
"bFQoRhcVH5DHUw", - "identifierInPrimarySource": "frieda", - }, - "on_match": { - "email": ["fictitiousf@rki.de", "info@rki.de"], - "familyName": ["Fictitious"], - "fullName": ["Fictitious, Frieda, Dr."], - "givenName": ["Frieda"], - "isniId": [], - "orcidId": [], + "email": [], + "identifier": extracted_organizational_unit.identifier, + "identifierInPrimarySource": "ou-1", }, + "on_match": {"email": []}, + "stable_target_id": extracted_organizational_unit.stableTargetId, }, ) - # expect edges are created - assert mocked_graph.run.call_args_list[-5:][1][0] == ( - """ -MATCH (s {identifier:$fromID}) -MATCH (t {stableTargetId:$toSTI}) -MERGE (s)-[e:hadPrimarySource]->(t) -RETURN e; -""", - { - "fromID": str(extracted_person.identifier), - "toSTI": str(extracted_person.hadPrimarySource), - }, - ) - assert mocked_graph.run.call_args_list[-5:][2][0] == ( - """ -MATCH (s {identifier:$fromID}) -MATCH (t {stableTargetId:$toSTI}) -MERGE (s)-[e:affiliation]->(t) -RETURN e; -""", - { - "fromID": str(extracted_person.identifier), - "toSTI": str(extracted_person.affiliation[0]), - }, - ) - assert mocked_graph.run.call_args_list[-5:][3][0] == ( - """ -MATCH (s {identifier:$fromID}) -MATCH (t {stableTargetId:$toSTI}) -MERGE (s)-[e:affiliation]->(t) -RETURN e; -""", + + +@pytest.mark.usefixtures("mocked_query_builder") +def test_mocked_graph_merges_edges( + mocked_graph: MockedGraph, dummy_data: list[AnyExtractedModel] +) -> None: + extracted_activity = dummy_data[4] + graph = GraphConnector.get() + graph._merge_edges(extracted_activity) + + assert mocked_graph.call_args_list[-1].args == ( + """\ +merge_edges( + extracted_label="ExtractedOrganizationalUnit", + ref_labels=["hadPrimarySource", "stableTargetId"], +)""", { - "fromID": str(extracted_person.identifier), - "toSTI": str(extracted_person.affiliation[1]), + "identifier": extracted_activity.identifier, + "ref_identifiers": [ + extracted_activity.hadPrimarySource, + extracted_activity.stableTargetId, + ], + "ref_positions": [0, 0], }, ) - assert mocked_graph.run.call_args_list[-5:][4][0] == ( - """ -MATCH (s {identifier:$fromID}) -MATCH (t {stableTargetId:$toSTI}) -MERGE (s)-[e:memberOf]->(t) -RETURN e; -""", + + +@pytest.mark.usefixtures("mocked_graph") +def test_mocked_graph_ingests_models(dummy_data: list[AnyExtractedModel]) -> None: + graph = GraphConnector.get() + identifiers = graph.ingest(dummy_data) + + assert identifiers == [d.identifier for d in dummy_data] + + +@pytest.mark.usefixtures("load_dummy_data") +@pytest.mark.integration +def test_fetch_extracted_data() -> None: + connector = GraphConnector.get() + + result = connector.fetch_extracted_data(None, None, None, 0, 1) + + assert result.all() == [ { - "fromID": str(extracted_person.identifier), - "toSTI": str(extracted_person.memberOf[0]), - }, - ) + "items": [ + { + "entityType": MEX_EXTRACTED_PRIMARY_SOURCE.entityType, + "hadPrimarySource": [MEX_EXTRACTED_PRIMARY_SOURCE.hadPrimarySource], + "identifier": MEX_EXTRACTED_PRIMARY_SOURCE.identifier, + "identifierInPrimarySource": MEX_EXTRACTED_PRIMARY_SOURCE.identifierInPrimarySource, + "stableTargetId": [MEX_EXTRACTED_PRIMARY_SOURCE.stableTargetId], + } + ], + "total": 7, + } + ] diff --git a/tests/graph/test_hydrate.py b/tests/graph/test_hydrate.py deleted file mode 100644 index 34c03d9..0000000 --- a/tests/graph/test_hydrate.py +++ /dev/null @@ -1,250 +0,0 @@ -import pytest -from pydantic import BaseModel as PydanticBaseModel - -from mex.backend.graph.hydrate import ( - NULL_VALUE, - FlatDict, - NestedDict, - _get_base_model_from_field, 
- are_instances, - dehydrate, - dehydrate_value, - hydrate, - hydrate_value, -) -from mex.common.models import BaseModel - - -def test_are_instances() -> None: - assert are_instances(None, str) is False - assert are_instances([], str) is True - assert are_instances([3, 4], str) is False - assert are_instances(["foo", "bar"], str) is True - - -def test_dehydrate_value() -> None: - assert dehydrate_value(None) == NULL_VALUE - assert dehydrate_value("foo") == "foo" - assert dehydrate_value(["foo", None]) == ["foo", NULL_VALUE] - with pytest.raises(TypeError): - assert dehydrate_value([1.3, object()]) # type: ignore - - -def test_hydrate_value() -> None: - assert hydrate_value(NULL_VALUE) is None - assert hydrate_value("foo") == "foo" - assert hydrate_value(["foo", NULL_VALUE]) == ["foo", None] - with pytest.raises(TypeError): - assert hydrate_value([1.3, object()]) # type: ignore - - -class Leaf(BaseModel): - color: str | None - veins: list[str | None] - - -class Branch(BaseModel): - leaf: Leaf - leaves: list[Leaf] - - -class Tree(BaseModel): - branch: Branch - branches: list[Branch] - - -class Caterpillar(BaseModel): - home: Leaf | None - - -class Nature(BaseModel): - flora_or_fauna: Tree | Caterpillar # not supported - - -@pytest.mark.parametrize( - ("model", "attribute", "expected"), - [ - (Tree, "branch", Branch), - (Branch, "leaves", Leaf), - (Leaf, "color", "cannot hydrate paths with non base models"), - (Leaf, "veins", "cannot hydrate paths with non base models"), - (Caterpillar, "home", Leaf), - (Nature, "flora_or_fauna", "Expected one non-None type, got 2"), - ], -) -def test_get_base_model_from_field( - model: type[BaseModel], attribute: str, expected: type[BaseModel] | str -) -> None: - field = model.model_fields[attribute] - if isinstance(expected, str): - with pytest.raises(TypeError, match=expected): - _get_base_model_from_field(field) - else: - assert _get_base_model_from_field(field) is expected - - -@pytest.mark.parametrize( - ("nested", "flat", "model"), - [ - ({}, {}, Leaf), - ( - {"color": None, "veins": ["primary", None]}, - { - "color": "", - "veins": ["primary", ""], - }, - Leaf, - ), - ( - {"color": "green", "veins": ["primary", "secondary"]}, - { - "color": "green", - "veins": ["primary", "secondary"], - }, - Leaf, - ), - ( - { - "leaf": {"color": None, "veins": [None, "secondary"]}, - "leaves": [{"color": "red"}, {"color": None}], - }, - { - "leaf_color": "", - "leaf_veins": ["", "secondary"], - "leaves_color": ["red", ""], - }, - Branch, - ), - ( - { - "leaf": {"color": "red", "veins": ["primary", "secondary"]}, - "leaves": [{"color": "red"}, {"color": "brown"}, {"color": "green"}], - }, - { - "leaf_color": "red", - "leaf_veins": ["primary", "secondary"], - "leaves_color": ["red", "brown", "green"], - }, - Branch, - ), - ( - { - "branch": { - "leaf": {"veins": [None, "secondary"]}, - "leaves": [{"color": "red"}, {"color": None}], - }, - "branches": [ - {"leaf": {"color": None}}, - {"leaf": {"color": "gold"}}, - ], - }, - { - "branch_leaf_veins": ["", "secondary"], - "branch_leaves_color": ["red", ""], - "branches_leaf_color": ["", "gold"], - }, - Tree, - ), - ( - { - "branch": { - "leaf": {"veins": ["primary", "secondary"]}, - "leaves": [{"color": "red"}, {"color": "yellow"}], - }, - "branches": [ - {"leaf": {"color": "red"}}, - {"leaf": {"color": "gold"}}, - {"leaf": {"color": "yellow"}}, - ], - }, - { - "branch_leaf_veins": ["primary", "secondary"], - "branch_leaves_color": ["red", "yellow"], - "branches_leaf_color": ["red", "gold", "yellow"], - }, - Tree, - ), - 
], - ids=[ - "leaf-empty", - "leaf-sparse", - "leaf-full", - "branch-sparse", - "branch-full", - "tree-sparse", - "tree-full", - ], -) -def test_de_hydration_roundtrip( - nested: NestedDict, flat: FlatDict, model: type[BaseModel] -) -> None: - assert dehydrate(nested) == flat - assert hydrate(flat, model) == nested - - -class NonMExModel(PydanticBaseModel): - tree: Tree - - -@pytest.mark.parametrize( - ("flat", "model", "error"), - [ - ( - {"branch_leaf_color": "blue"}, - NonMExModel, - "cannot hydrate paths with non base models", - ), - ( - {"veins": [42, object(), 1.3]}, - Leaf, - "can only hydrate strings or lists of strings", - ), - ( - { - "branches_leaves_color": ["red", "green"], - }, - Branch, - "flat dict does not align with target model", - ), - ( - { - "branches_leaves_color": ["red", "green"], - "branches_leaf_veins": ["primary", "secondary"], - }, - Tree, - "cannot handle multiple list branches", - ), - ( - {"leaves_color": "yellow"}, - Branch, - "cannot hydrate non-list to list", - ), - ], - ids=[ - "non base model", - "disallowed objects", - "model misalignment", - "multiple list branches", - "non list incompatibility", - ], -) -def test_hydration_errors(flat: FlatDict, model: type[BaseModel], error: str) -> None: - with pytest.raises(Exception, match=error): - hydrate(flat, model) - - -@pytest.mark.parametrize( - ("nested", "error"), - [ - ( - {"leaves": [{"veins": ["primary"]}, {"veins": ["secondary"]}]}, - "can only handle one list per path", - ) - ], - ids=[ - "multiple lists per path", - ], -) -def test_dehydration_errors(nested: NestedDict, error: str) -> None: - with pytest.raises(Exception, match=error): - assert dehydrate(nested) diff --git a/tests/graph/test_models.py b/tests/graph/test_models.py new file mode 100644 index 0000000..7bcf957 --- /dev/null +++ b/tests/graph/test_models.py @@ -0,0 +1,123 @@ +from unittest.mock import MagicMock, Mock + +import pytest +from neo4j import ( + Record as Neo4jRecord, +) +from neo4j import ( + Result as Neo4jResult, +) +from neo4j import ( + ResultSummary as Neo4jResultSummary, +) + +from mex.backend.graph.exceptions import MultipleResultsFoundError, NoResultFoundError +from mex.backend.graph.models import Result +from mex.common.testing import Joker + + +@pytest.fixture +def summary() -> Mock: + + class SummaryCounters: + def __init__(self) -> None: + self.nodes_created = 73 + self.labels_added = 0 + self.constraints_removed = 0 + + return Mock(spec=Neo4jResultSummary, counters=SummaryCounters()) + + +@pytest.fixture +def multiple_results(summary: Mock) -> Mock: + records = [ + Mock(spec=Neo4jRecord, data=MagicMock(return_value={"num": 40})), + Mock(spec=Neo4jRecord, data=MagicMock(return_value={"num": 41})), + Mock(spec=Neo4jRecord, data=MagicMock(return_value={"num": 42})), + ] + return Mock( + spec=Neo4jResult, to_eager_result=MagicMock(return_value=(records, summary, [])) + ) + + +@pytest.fixture +def no_result(summary: Mock) -> Mock: + return Mock( + spec=Neo4jResult, to_eager_result=MagicMock(return_value=([], summary, [])) + ) + + +@pytest.fixture +def single_result(summary: Mock) -> Mock: + records = [ + Mock( + spec=Neo4jRecord, + data=MagicMock( + return_value={ + "text": "Lorem adipisicing elit consequat sint consectetur " + "proident cupidatat culpa voluptate. Aute commodo ea sunt mollit. " + "Lorem sint amet reprehenderit aliqua." 
+ } + ), + ), + ] + return Mock( + spec=Neo4jResult, to_eager_result=MagicMock(return_value=(records, summary, [])) + ) + + +def test_result_getitem(multiple_results: Mock, single_result: Mock) -> None: + # cannot access item when there are multiple results, + # because that might lead to unexpected results + with pytest.raises(MultipleResultsFoundError): + _ = Result(multiple_results)["num"] + + assert Result(single_result)["text"].startswith("Lorem") + + +def test_result_iter(multiple_results: Mock) -> None: + assert list(Result(multiple_results)) == [{"num": 40}, {"num": 41}, {"num": 42}] + + +def test_result_repr(multiple_results: Mock, single_result: Mock) -> None: + assert ( + repr(Result(multiple_results)) + == "Result([{'num': 40}, {'num': 41}, {'num': 42}])" + ) + + # when representation is too long, it should be abbreviated in the middle + assert repr(Result(single_result)) == ( + "Result([{'text': 'Lorem adipisicing elit... " + "...orem sint amet reprehenderit aliqua.'}])" + ) + + +def test_result_all(multiple_results: Mock) -> None: + assert Result(multiple_results).all() == [{"num": 40}, {"num": 41}, {"num": 42}] + + +def test_result_one( + multiple_results: Mock, no_result: Mock, single_result: Mock +) -> None: + assert "text" in Result(single_result).one() + + with pytest.raises(NoResultFoundError): + Result(no_result).one() + + with pytest.raises(MultipleResultsFoundError): + Result(multiple_results).one() + + +def test_result_one_or_none( + multiple_results: Mock, no_result: Mock, single_result: Mock +) -> None: + assert Result(single_result).one_or_none() == {"text": Joker()} + + assert Result(no_result).one_or_none() is None + + with pytest.raises(MultipleResultsFoundError): + Result(multiple_results).one_or_none() + + +def test_get_update_counters(multiple_results: Mock) -> None: + assert Result(multiple_results).get_update_counters() == {"nodes_created": 73} diff --git a/tests/graph/test_query.py b/tests/graph/test_query.py new file mode 100644 index 0000000..b581b93 --- /dev/null +++ b/tests/graph/test_query.py @@ -0,0 +1,390 @@ +import pytest + +from mex.backend.graph.query import QueryBuilder + + +@pytest.fixture +def query_builder() -> QueryBuilder: + builder = QueryBuilder.get() + builder._env.globals.update( + extracted_labels=["ExtractedThis", "ExtractedThat", "ExtractedOther"], + merged_labels=["MergedThis", "MergedThat", "MergedOther"], + nested_labels=["Link", "Text", "Location"], + ) + return builder + + +def test_create_full_text_search_index(query_builder: QueryBuilder) -> None: + query = query_builder.create_full_text_search_index( + node_labels=["Apple", "Orange"], + search_fields=["texture", "sugarContent", "color"], + ) + assert ( + query + == """\ +CREATE FULLTEXT INDEX search_index IF NOT EXISTS +FOR (n:Apple|Orange) +ON EACH [n.texture, n.sugarContent, n.color] +OPTIONS {indexConfig: $index_config};""" + ) + + +def test_create_identifier_uniqueness_constraint(query_builder: QueryBuilder) -> None: + query = query_builder.create_identifier_uniqueness_constraint( + node_label="BlueBerryPie" + ) + assert ( + query + == """\ +CREATE CONSTRAINT blue_berry_pie_identifier_uniqueness IF NOT EXISTS +FOR (n:BlueBerryPie) +REQUIRE n.identifier IS UNIQUE;""" + ) + + +def test_fetch_database_status(query_builder: QueryBuilder) -> None: + query = query_builder.fetch_database_status() + assert ( + query + == """\ +SHOW DEFAULT DATABASE +YIELD currentStatus;""" + ) + + +@pytest.mark.parametrize( + ( + "filter_by_query_string", + "filter_by_stable_target_id", + 
"filter_by_labels", + "expected", + ), + [ + ( + True, + True, + True, + """\ +CALL { + CALL db.index.fulltext.queryNodes("search_index", $query_string) + YIELD node AS hit, score + MATCH (n:ExtractedThis|ExtractedThat|ExtractedOther)-[:stableTargetId]->(merged:MergedThis|MergedThat|MergedOther) + WHERE + elementId(hit) = elementId(n) + AND merged.identifier = $stable_target_id + AND ANY(label IN labels(n) WHERE label IN $labels) + RETURN COUNT(n) AS total +} +CALL { + CALL db.index.fulltext.queryNodes("search_index", $query_string) + YIELD node AS hit, score + MATCH (n:ExtractedThis|ExtractedThat|ExtractedOther)-[:stableTargetId]->(merged:MergedThis|MergedThat|MergedOther) + WHERE + elementId(hit) = elementId(n) + AND merged.identifier = $stable_target_id + AND ANY(label IN labels(n) WHERE label IN $labels) + CALL { + WITH n + MATCH (n)-[r]->(merged:MergedThis|MergedThat|MergedOther) + RETURN type(r) as label, r.position as position, merged.identifier as value + UNION + WITH n + MATCH (n)-[r]->(nested:Link|Text|Location) + RETURN type(r) as label, r.position as position, properties(nested) as value + } + WITH n, collect({label: label, position: position, value: value}) as refs + RETURN n{.*, entityType: head(labels(n)), _refs: refs} + ORDER BY n.identifier ASC + SKIP $skip + LIMIT $limit +} +RETURN collect(n) AS items, total;""", + ), + ( + False, + False, + False, + """\ +CALL { + MATCH (n:ExtractedThis|ExtractedThat|ExtractedOther) + RETURN COUNT(n) AS total +} +CALL { + MATCH (n:ExtractedThis|ExtractedThat|ExtractedOther) + CALL { + WITH n + MATCH (n)-[r]->(merged:MergedThis|MergedThat|MergedOther) + RETURN type(r) as label, r.position as position, merged.identifier as value + UNION + WITH n + MATCH (n)-[r]->(nested:Link|Text|Location) + RETURN type(r) as label, r.position as position, properties(nested) as value + } + WITH n, collect({label: label, position: position, value: value}) as refs + RETURN n{.*, entityType: head(labels(n)), _refs: refs} + ORDER BY n.identifier ASC + SKIP $skip + LIMIT $limit +} +RETURN collect(n) AS items, total;""", + ), + ( + False, + False, + True, + """\ +CALL { + MATCH (n:ExtractedThis|ExtractedThat|ExtractedOther) + WHERE + ANY(label IN labels(n) WHERE label IN $labels) + RETURN COUNT(n) AS total +} +CALL { + MATCH (n:ExtractedThis|ExtractedThat|ExtractedOther) + WHERE + ANY(label IN labels(n) WHERE label IN $labels) + CALL { + WITH n + MATCH (n)-[r]->(merged:MergedThis|MergedThat|MergedOther) + RETURN type(r) as label, r.position as position, merged.identifier as value + UNION + WITH n + MATCH (n)-[r]->(nested:Link|Text|Location) + RETURN type(r) as label, r.position as position, properties(nested) as value + } + WITH n, collect({label: label, position: position, value: value}) as refs + RETURN n{.*, entityType: head(labels(n)), _refs: refs} + ORDER BY n.identifier ASC + SKIP $skip + LIMIT $limit +} +RETURN collect(n) AS items, total;""", + ), + ], + ids=["all-filters", "no-filters", "label-filter"], +) +def test_fetch_extracted_data( + query_builder: QueryBuilder, + filter_by_query_string: bool, + filter_by_stable_target_id: bool, + filter_by_labels: bool, + expected: str, +) -> None: + query = query_builder.fetch_extracted_data( + filter_by_query_string=filter_by_query_string, + filter_by_stable_target_id=filter_by_stable_target_id, + filter_by_labels=filter_by_labels, + ) + assert query == expected + + +@pytest.mark.parametrize( + ( + "filter_by_had_primary_source", + "filter_by_identifier_in_primary_source", + "filter_by_stable_target_id", + 
"expected", + ), + [ + ( + True, + True, + True, + """\ +MATCH (n:ExtractedThis|ExtractedThat|ExtractedOther)-[:stableTargetId]->(merged:MergedThis|MergedThat|MergedOther) +MATCH (n)-[:hadPrimarySource]->(primary_source:MergedPrimarySource) +WHERE + primary_source.identifier = $had_primary_source + AND n.identifierInPrimarySource = $identifier_in_primary_source + AND merged.identifier = $stable_target_id +RETURN + merged.identifier as stableTargetId, + primary_source.identifier as hadPrimarySource, + n.identifierInPrimarySource as identifierInPrimarySource, + n.identifier as identifier +ORDER BY n.identifier ASC +LIMIT $limit;""", + ), + ( + False, + False, + False, + """\ +MATCH (n:ExtractedThis|ExtractedThat|ExtractedOther)-[:stableTargetId]->(merged:MergedThis|MergedThat|MergedOther) +MATCH (n)-[:hadPrimarySource]->(primary_source:MergedPrimarySource) +RETURN + merged.identifier as stableTargetId, + primary_source.identifier as hadPrimarySource, + n.identifierInPrimarySource as identifierInPrimarySource, + n.identifier as identifier +ORDER BY n.identifier ASC +LIMIT $limit;""", + ), + ( + False, + False, + True, + """\ +MATCH (n:ExtractedThis|ExtractedThat|ExtractedOther)-[:stableTargetId]->(merged:MergedThis|MergedThat|MergedOther) +MATCH (n)-[:hadPrimarySource]->(primary_source:MergedPrimarySource) +WHERE + merged.identifier = $stable_target_id +RETURN + merged.identifier as stableTargetId, + primary_source.identifier as hadPrimarySource, + n.identifierInPrimarySource as identifierInPrimarySource, + n.identifier as identifier +ORDER BY n.identifier ASC +LIMIT $limit;""", + ), + ], + ids=["all-filters", "no-filters", "id-filter"], +) +def test_fetch_identities( + query_builder: QueryBuilder, + filter_by_had_primary_source: bool, + filter_by_identifier_in_primary_source: bool, + filter_by_stable_target_id: bool, + expected: str, +) -> None: + query = query_builder.fetch_identities( + filter_by_had_primary_source=filter_by_had_primary_source, + filter_by_identifier_in_primary_source=filter_by_identifier_in_primary_source, + filter_by_stable_target_id=filter_by_stable_target_id, + ) + assert query == expected + + +@pytest.mark.parametrize( + ( + "ref_labels", + "expected", + ), + [ + ( + ["personInCharge", "meetingScheduledBy", "agendaSignedOff"], + """\ +MATCH (source:ExtractedThat {identifier: $identifier}) +CALL { + WITH source + MATCH (target_0 {identifier: $ref_identifiers[0]}) + MERGE (source)-[edge:personInCharge {position: $ref_positions[0]}]->(target_0) + RETURN edge + UNION + WITH source + MATCH (target_1 {identifier: $ref_identifiers[1]}) + MERGE (source)-[edge:meetingScheduledBy {position: $ref_positions[1]}]->(target_1) + RETURN edge + UNION + WITH source + MATCH (target_2 {identifier: $ref_identifiers[2]}) + MERGE (source)-[edge:agendaSignedOff {position: $ref_positions[2]}]->(target_2) + RETURN edge +} +WITH source, collect(edge) as edges +CALL { + WITH source, edges + MATCH (source)-[outdated_edge]->(:MergedThis|MergedThat|MergedOther) + WHERE NOT outdated_edge IN edges + DELETE outdated_edge + RETURN count(outdated_edge) as pruned +} +RETURN count(edges) as merged, pruned, edges;""", + ), + ( + [], + """\ +MATCH (source:ExtractedThat {identifier: $identifier}) +CALL { + RETURN null as edge +} +WITH source, collect(edge) as edges +CALL { + WITH source, edges + MATCH (source)-[outdated_edge]->(:MergedThis|MergedThat|MergedOther) + WHERE NOT outdated_edge IN edges + DELETE outdated_edge + RETURN count(outdated_edge) as pruned +} +RETURN count(edges) as merged, pruned, 
edges;""", + ), + ], + ids=["has-ref-labels", "no-ref-labels"], +) +def test_merge_edges( + query_builder: QueryBuilder, ref_labels: list[str], expected: str +) -> None: + query = query_builder.merge_edges( + extracted_label="ExtractedThat", ref_labels=ref_labels + ) + assert query == expected + + +@pytest.mark.parametrize( + ("nested_edge_labels", "nested_node_labels", "expected"), + [ + ( + ["description", "homepage", "geoLocation"], + ["Text", "Link", "Location"], + """\ +MERGE (merged:MergedThat {identifier: $stable_target_id}) +MERGE (extracted:ExtractedThat {identifier: $identifier})-[stableTargetId:stableTargetId {position: 0}]->(merged) +ON CREATE SET extracted = $on_create +ON MATCH SET extracted += $on_match +MERGE (extracted)-[edge_0:description {position: $nested_positions[0]}]->(value_0:Text) +ON CREATE SET value_0 = $nested_values[0] +ON MATCH SET value_0 += $nested_values[0] +MERGE (extracted)-[edge_1:homepage {position: $nested_positions[1]}]->(value_1:Link) +ON CREATE SET value_1 = $nested_values[1] +ON MATCH SET value_1 += $nested_values[1] +MERGE (extracted)-[edge_2:geoLocation {position: $nested_positions[2]}]->(value_2:Location) +ON CREATE SET value_2 = $nested_values[2] +ON MATCH SET value_2 += $nested_values[2] +WITH extracted, + [edge_0, edge_1, edge_2] as edges, + [value_0, value_1, value_2] as values +CALL { + WITH extracted, values + MATCH (extracted)-[]->(outdated_node:Link|Text|Location) + WHERE NOT outdated_node IN values + DETACH DELETE outdated_node + RETURN count(outdated_node) as pruned +} +RETURN extracted, edges, values, pruned;""", + ), + ( + [], + [], + """\ +MERGE (merged:MergedThat {identifier: $stable_target_id}) +MERGE (extracted:ExtractedThat {identifier: $identifier})-[stableTargetId:stableTargetId {position: 0}]->(merged) +ON CREATE SET extracted = $on_create +ON MATCH SET extracted += $on_match +WITH extracted, + [] as edges, + [] as values +CALL { + WITH extracted, values + MATCH (extracted)-[]->(outdated_node:Link|Text|Location) + WHERE NOT outdated_node IN values + DETACH DELETE outdated_node + RETURN count(outdated_node) as pruned +} +RETURN extracted, edges, values, pruned;""", + ), + ], + ids=["has-nested-labels", "no-nested-labels"], +) +def test_merge_node( + query_builder: QueryBuilder, + nested_edge_labels: list[str], + nested_node_labels: list[str], + expected: str, +) -> None: + query = query_builder.merge_node( + extracted_label="ExtractedThat", + merged_label="MergedThat", + nested_edge_labels=nested_edge_labels, + nested_node_labels=nested_node_labels, + ) + assert query == expected diff --git a/tests/graph/test_transform.py b/tests/graph/test_transform.py index 46fa33c..315ae19 100644 --- a/tests/graph/test_transform.py +++ b/tests/graph/test_transform.py @@ -1,21 +1,65 @@ -from mex.backend.graph.transform import transform_identity_result_to_identity -from mex.common.identity import Identity -from mex.common.types import Identifier, PrimarySourceID +from mex.backend.graph.transform import expand_references_in_search_result -def test_transform_identity_result_to_identity() -> None: - assert transform_identity_result_to_identity( - { - "i": { - "identifier": "90200009120910", - "hadPrimarySource": "7827287287287287", - "identifierInPrimarySource": "one", - "stableTargetId": "6536536536536536536536", - } - } - ) == Identity( - identifier=Identifier("90200009120910"), - hadPrimarySource=PrimarySourceID("7827287287287287"), - identifierInPrimarySource="one", - stableTargetId=Identifier("6536536536536536536536"), - ) +def 
test_expand_references_in_search_result() -> None: + node_dict = { + "_refs": [ + {"label": "responsibleUnit", "position": 0, "value": "bFQoRhcVH5DHUz"}, + {"label": "contact", "position": 2, "value": "bFQoRhcVH5DHUz"}, + {"label": "contact", "position": 0, "value": "bFQoRhcVH5DHUv"}, + {"label": "contact", "position": 1, "value": "bFQoRhcVH5DHUx"}, + {"label": "hadPrimarySource", "position": 0, "value": "bFQoRhcVH5DHUr"}, + {"label": "stableTargetId", "position": 0, "value": "bFQoRhcVH5DHUB"}, + { + "label": "website", + "position": 0, + "value": {"title": "Activity Homepage", "url": "https://activity-1"}, + }, + { + "label": "abstract", + "position": 1, + "value": {"value": "Une activité active."}, + }, + { + "label": "title", + "position": 0, + "value": {"language": "de", "value": "Aktivität 1"}, + }, + { + "label": "abstract", + "position": 0, + "value": {"language": "en", "value": "An active activity."}, + }, + ], + "activityType": [], + "end": [], + "entityType": "ExtractedActivity", + "fundingProgram": [], + "identifier": "bFQoRhcVH5DHUA", + "identifierInPrimarySource": "a-1", + "start": [], + "theme": ["https://mex.rki.de/item/theme-3"], + } + + expand_references_in_search_result(node_dict) + + assert node_dict == { + "activityType": [], + "end": [], + "entityType": "ExtractedActivity", + "fundingProgram": [], + "identifier": "bFQoRhcVH5DHUA", + "identifierInPrimarySource": "a-1", + "start": [], + "theme": ["https://mex.rki.de/item/theme-3"], + "responsibleUnit": ["bFQoRhcVH5DHUz"], + "contact": ["bFQoRhcVH5DHUv", "bFQoRhcVH5DHUx", "bFQoRhcVH5DHUz"], + "hadPrimarySource": ["bFQoRhcVH5DHUr"], + "stableTargetId": ["bFQoRhcVH5DHUB"], + "website": [{"title": "Activity Homepage", "url": "https://activity-1"}], + "abstract": [ + {"language": "en", "value": "An active activity."}, + {"value": "Une activité active."}, + ], + "title": [{"language": "de", "value": "Aktivität 1"}], + } diff --git a/tests/identity/test_main.py b/tests/identity/test_main.py index 1a6d0d3..078ae76 100644 --- a/tests/identity/test_main.py +++ b/tests/identity/test_main.py @@ -1,5 +1,4 @@ from typing import Any -from unittest.mock import MagicMock import pytest from fastapi.testclient import TestClient @@ -9,6 +8,7 @@ MEX_PRIMARY_SOURCE_IDENTIFIER_IN_PRIMARY_SOURCE, MEX_PRIMARY_SOURCE_STABLE_TARGET_ID, ) +from tests.conftest import MockedGraph @pytest.mark.parametrize( @@ -30,12 +30,10 @@ ( [ { - "i": { - "hadPrimarySource": "psSti00000000001", - "identifier": "cpId000000000002", - "identifierInPrimarySource": "cp-2", - "stableTargetId": "cpSti00000000002", - } + "hadPrimarySource": "psSti00000000001", + "identifier": "cpId000000000002", + "identifierInPrimarySource": "cp-2", + "stableTargetId": "cpSti00000000002", } ], { @@ -54,7 +52,7 @@ ) def test_assign_identity_mocked( client_with_api_key_write_permission: TestClient, - mocked_graph: MagicMock, + mocked_graph: MockedGraph, mocked_return: list[dict[str, str]], post_body: dict[str, str], expected: dict[str, Any], @@ -67,24 +65,20 @@ def test_assign_identity_mocked( def test_assign_identity_inconsistency_mocked( client_with_api_key_write_permission: TestClient, - mocked_graph: MagicMock, + mocked_graph: MockedGraph, ) -> None: mocked_graph.return_value = [ { - "i": { - "hadPrimarySource": "psSti00000000001", - "identifier": "cpId000000000002", - "identifierInPrimarySource": "cp-2", - "stableTargetId": "cpSti00000000002", - } + "hadPrimarySource": "psSti00000000001", + "identifier": "cpId000000000002", + "identifierInPrimarySource": "cp-2", + "stableTargetId": 
"cpSti00000000002", }, { - "i": { - "hadPrimarySource": "psSti00000000001", - "identifier": "cpId000000000098", - "identifierInPrimarySource": "cp-2", - "stableTargetId": "cpSti00000000099", - } + "hadPrimarySource": "psSti00000000001", + "identifier": "cpId000000000098", + "identifierInPrimarySource": "cp-2", + "stableTargetId": "cpSti00000000099", }, ] response = client_with_api_key_write_permission.post( @@ -94,19 +88,8 @@ def test_assign_identity_inconsistency_mocked( "identifierInPrimarySource": "cp-2", }, ) - assert response.status_code == 500 - assert "graph inconsistency" in response.text - - -@pytest.mark.usefixtures("mocked_graph") -def test_fetch_identity_invalid_query_params_mocked( - client_with_api_key_write_permission: TestClient, -) -> None: - response = client_with_api_key_write_permission.get( - "/v0/identity", - ) - assert response.status_code == 400 - assert "invalid identity query parameters" in response.text + assert response.status_code == 500, response.text + assert "MultipleResultsFoundError" in response.text @pytest.mark.parametrize( @@ -130,8 +113,8 @@ def test_fetch_identity_invalid_query_params_mocked( "identifierInPrimarySource": "cp-2", }, { - "identifier": "bFQoRhcVH5DHUw", "hadPrimarySource": "bFQoRhcVH5DHUr", + "identifier": "bFQoRhcVH5DHUw", "identifierInPrimarySource": "cp-2", "stableTargetId": "bFQoRhcVH5DHUx", }, @@ -170,12 +153,10 @@ def test_assign_identity( ( [ { - "i": { - "hadPrimarySource": "28282828282828", - "identifier": "7878787878787878777", - "identifierInPrimarySource": "one", - "stableTargetId": "949494949494949494", - } + "hadPrimarySource": "28282828282828", + "identifier": "7878787878787878777", + "identifierInPrimarySource": "one", + "stableTargetId": "949494949494949494", } ], "?hadPrimarySource=28282828282828&identifierInPrimarySource=one", @@ -194,20 +175,16 @@ def test_assign_identity( ( [ { - "i": { - "hadPrimarySource": "28282828282828", - "identifier": "62626262626266262", - "identifierInPrimarySource": "two", - "stableTargetId": "949494949494949494", - } + "hadPrimarySource": "28282828282828", + "identifier": "62626262626266262", + "identifierInPrimarySource": "two", + "stableTargetId": "949494949494949494", }, { - "i": { - "hadPrimarySource": "39393939393939", - "identifier": "7878787878787878777", - "identifierInPrimarySource": "duo", - "stableTargetId": "949494949494949494", - } + "hadPrimarySource": "39393939393939", + "identifier": "7878787878787878777", + "identifierInPrimarySource": "duo", + "stableTargetId": "949494949494949494", }, ], "?stableTargetId=949494949494949494", @@ -234,7 +211,7 @@ def test_assign_identity( ) def test_fetch_identities_mocked( client_with_api_key_write_permission: TestClient, - mocked_graph: MagicMock, + mocked_graph: MockedGraph, mocked_return: list[dict[str, str]], query_string: str, expected: dict[str, Any], @@ -258,7 +235,7 @@ def test_fetch_identities_mocked( "identifier": "bFQoRhcVH5DHUq", "identifierInPrimarySource": "ps-1", "stableTargetId": "bFQoRhcVH5DHUr", - }, + } ], "total": 1, }, diff --git a/tests/identity/test_provider.py b/tests/identity/test_provider.py index 9aba7bb..5cba1d2 100644 --- a/tests/identity/test_provider.py +++ b/tests/identity/test_provider.py @@ -1,16 +1,20 @@ from typing import Any -from unittest.mock import MagicMock import pytest +from mex.backend.graph.exceptions import MultipleResultsFoundError from mex.backend.identity.provider import GraphIdentityProvider -from mex.common.exceptions import MExError from mex.common.models import ( 
MEX_PRIMARY_SOURCE_IDENTIFIER, MEX_PRIMARY_SOURCE_IDENTIFIER_IN_PRIMARY_SOURCE, MEX_PRIMARY_SOURCE_STABLE_TARGET_ID, ) -from mex.common.types import Identifier, PrimarySourceID +from mex.common.types import ( + Identifier, + MergedOrganizationalUnitIdentifier, + MergedPrimarySourceIdentifier, +) +from tests.conftest import MockedGraph @pytest.mark.parametrize( @@ -18,7 +22,7 @@ [ ( [], - PrimarySourceID("psSti00000000001"), + MergedPrimarySourceIdentifier("psSti00000000001"), "new-item", { "hadPrimarySource": "psSti00000000001", @@ -30,15 +34,13 @@ ( [ { - "i": { - "hadPrimarySource": "psSti00000000001", - "identifier": "cpId000000000002", - "identifierInPrimarySource": "existing-item", - "stableTargetId": "cpSti00000000002", - } + "hadPrimarySource": "psSti00000000001", + "identifier": "cpId000000000002", + "identifierInPrimarySource": "existing-item", + "stableTargetId": "cpSti00000000002", } ], - PrimarySourceID("psSti00000000001"), + MergedPrimarySourceIdentifier("psSti00000000001"), "existing-item", { "hadPrimarySource": "psSti00000000001", @@ -51,9 +53,9 @@ ids=["new item", "existing item"], ) def test_assign_identity_mocked( - mocked_graph: MagicMock, + mocked_graph: MockedGraph, mocked_return: list[dict[str, str]], - had_primary_source: PrimarySourceID, + had_primary_source: MergedPrimarySourceIdentifier, identifier_in_primary_source: str, expected: dict[str, Any], ) -> None: @@ -67,30 +69,26 @@ def test_assign_identity_mocked( def test_assign_identity_inconsistency_mocked( - mocked_graph: MagicMock, + mocked_graph: MockedGraph, ) -> None: mocked_graph.return_value = [ { - "i": { - "hadPrimarySource": "psSti00000000001", - "identifier": "cpId000000000002", - "identifierInPrimarySource": "existing-item", - "stableTargetId": "cpSti00000000002", - } + "hadPrimarySource": "psSti00000000001", + "identifier": "cpId000000000002", + "identifierInPrimarySource": "existing-item", + "stableTargetId": "cpSti00000000002", }, { - "i": { - "hadPrimarySource": "psSti00000000001", - "identifier": "cpId000000000098", - "identifierInPrimarySource": "existing-item", - "stableTargetId": "cpSti00000000099", - } + "hadPrimarySource": "psSti00000000001", + "identifier": "cpId000000000098", + "identifierInPrimarySource": "existing-item", + "stableTargetId": "cpSti00000000099", }, ] provider = GraphIdentityProvider.get() - with pytest.raises(MExError, match="graph inconsistency"): + with pytest.raises(MultipleResultsFoundError): provider.assign( - had_primary_source=PrimarySourceID("psSti00000000001"), + had_primary_source=MergedPrimarySourceIdentifier("psSti00000000001"), identifier_in_primary_source="existing-item", ) @@ -134,7 +132,7 @@ def test_assign_identity_inconsistency_mocked( @pytest.mark.usefixtures("load_dummy_data") @pytest.mark.integration def test_assign_identity( - had_primary_source: PrimarySourceID, + had_primary_source: MergedPrimarySourceIdentifier, identifier_in_primary_source: str, expected: dict[str, Any], ) -> None: @@ -159,15 +157,13 @@ def test_assign_identity( ( [ { - "i": { - "hadPrimarySource": "28282828282828", - "identifier": "7878787878787878777", - "identifierInPrimarySource": "one", - "stableTargetId": "949494949494949494", - } + "hadPrimarySource": "28282828282828", + "identifier": "7878787878787878777", + "identifierInPrimarySource": "one", + "stableTargetId": "949494949494949494", } ], - PrimarySourceID("28282828282828"), + MergedPrimarySourceIdentifier("28282828282828"), "one", None, [ @@ -182,20 +178,16 @@ def test_assign_identity( ( [ { - "i": { - "hadPrimarySource": 
"28282828282828", - "identifier": "62626262626266262", - "identifierInPrimarySource": "two", - "stableTargetId": "949494949494949494", - } + "hadPrimarySource": "28282828282828", + "identifier": "62626262626266262", + "identifierInPrimarySource": "two", + "stableTargetId": "949494949494949494", }, { - "i": { - "hadPrimarySource": "39393939393939", - "identifier": "7878787878787878777", - "identifierInPrimarySource": "duo", - "stableTargetId": "949494949494949494", - } + "hadPrimarySource": "39393939393939", + "identifier": "7878787878787878777", + "identifierInPrimarySource": "duo", + "stableTargetId": "949494949494949494", }, ], None, @@ -220,9 +212,9 @@ def test_assign_identity( ids=["nothing found", "one item", "two items"], ) def test_fetch_identities_mocked( - mocked_graph: MagicMock, + mocked_graph: MockedGraph, mocked_return: list[dict[str, str]], - had_primary_source: PrimarySourceID | None, + had_primary_source: MergedPrimarySourceIdentifier | None, identifier_in_primary_source: str | None, stable_target_id: Identifier | None, expected: list[dict[str, Any]], @@ -247,7 +239,7 @@ def test_fetch_identities_mocked( [ (None, None, Identifier("thisDoesNotExist"), []), ( - PrimarySourceID("00000000000000"), + MergedPrimarySourceIdentifier("00000000000000"), "ps-1", None, [ @@ -256,13 +248,13 @@ def test_fetch_identities_mocked( "identifier": "bFQoRhcVH5DHUq", "identifierInPrimarySource": "ps-1", "stableTargetId": "bFQoRhcVH5DHUr", - }, + } ], ), ( None, None, - Identifier("bFQoRhcVH5DHUz"), + MergedOrganizationalUnitIdentifier("bFQoRhcVH5DHUz"), [ { "identifier": "bFQoRhcVH5DHUy", @@ -282,7 +274,7 @@ def test_fetch_identities_mocked( @pytest.mark.usefixtures("load_dummy_data") @pytest.mark.integration def test_fetch_identities( - had_primary_source: PrimarySourceID | None, + had_primary_source: MergedPrimarySourceIdentifier | None, identifier_in_primary_source: str | None, stable_target_id: Identifier | None, expected: list[dict[str, Any]], diff --git a/tests/ingest/test_main.py b/tests/ingest/test_main.py index 3c69900..434392d 100644 --- a/tests/ingest/test_main.py +++ b/tests/ingest/test_main.py @@ -1,29 +1,23 @@ -from typing import Any +from collections import defaultdict +from typing import Any, cast import pytest from fastapi.testclient import TestClient -from mex.common.models import ExtractedContactPoint, ExtractedPrimarySource +from mex.backend.graph.connector import GraphConnector +from mex.common.models import AnyExtractedModel from mex.common.testing import Joker -from mex.common.types import Identifier +from tests.conftest import MockedGraph + +Payload = dict[str, list[dict[str, Any]]] @pytest.fixture -def post_payload() -> dict[str, Any]: - primary_source = ExtractedPrimarySource( - title="database", - identifierInPrimarySource="ps-1", - hadPrimarySource=Identifier.generate(), - ) - contact_point = ExtractedContactPoint( - email="info@rki.de", - identifierInPrimarySource="cp-1", - hadPrimarySource=primary_source.stableTargetId, - ) - return { - "ExtractedPrimarySource": [primary_source.model_dump()], - "ExtractedContactPoint": [contact_point.model_dump()], - } +def post_payload(dummy_data: list[AnyExtractedModel]) -> Payload: + payload = defaultdict(list) + for model in dummy_data: + payload[model.entityType].append(model.model_dump()) + return cast(Payload, dict(payload)) @pytest.mark.integration @@ -36,39 +30,26 @@ def test_bulk_insert_empty(client_with_api_key_write_permission: TestClient) -> @pytest.mark.integration def test_bulk_insert( - 
client_with_api_key_write_permission: TestClient, post_payload: dict[str, Any] + client_with_api_key_write_permission: TestClient, + post_payload: Payload, + dummy_data: list[AnyExtractedModel], ) -> None: - # post a single contact point to ingest endpoint - identifier = post_payload["ExtractedContactPoint"][0]["identifier"] - stable_target_id = post_payload["ExtractedContactPoint"][0]["stableTargetId"] - had_primary_source = post_payload["ExtractedContactPoint"][0]["hadPrimarySource"] - primary_source_id = post_payload["ExtractedPrimarySource"][0]["identifier"] + # get expected identifiers from the dummy data + expected_identifiers = sorted(d.identifier for d in dummy_data) + + # post the dummy data to the ingest endpoint response = client_with_api_key_write_permission.post( "/v0/ingest", json=post_payload ) # assert the response is the identifier of the contact point assert response.status_code == 201, response.text - assert response.json() == {"identifiers": [str(identifier), str(primary_source_id)]} + assert sorted(response.json()["identifiers"]) == expected_identifiers - # verify the node has actually been stored in the backend - response = client_with_api_key_write_permission.get( - f"/v0/extracted-item?stableTargetId={stable_target_id}", - ) - assert response.status_code == 200, response.text - assert response.json() == { - "total": 1, - "items": [ - { - "$type": "ExtractedContactPoint", - "email": ["info@rki.de"], - "identifier": str(identifier), - "identifierInPrimarySource": "cp-1", - "stableTargetId": str(stable_target_id), - "hadPrimarySource": str(had_primary_source), - } - ], - } + # verify the nodes have actually been stored in the database + graph = GraphConnector.get() + result = graph.fetch_extracted_data(None, None, None, 1, len(dummy_data)) + assert [i["identifier"] for i in result["items"]] == expected_identifiers def test_bulk_insert_malformed( @@ -92,17 +73,17 @@ def test_bulk_insert_malformed( } -@pytest.mark.usefixtures("mocked_graph") -def test_bulk_insert_mock( - client_with_api_key_write_permission: TestClient, post_payload: dict[str, Any] +def test_bulk_insert_mocked( + client_with_api_key_write_permission: TestClient, + post_payload: Payload, + dummy_data: list[AnyExtractedModel], + mocked_graph: MockedGraph, ) -> None: + mocked_graph.return_value = [] response = client_with_api_key_write_permission.post( "/v0/ingest", json=post_payload ) assert response.status_code == 201, response.text - assert response.json() == { - "identifiers": [ - post_payload["ExtractedContactPoint"][0]["identifier"], - post_payload["ExtractedPrimarySource"][0]["identifier"], - ] - } + assert sorted(response.json()["identifiers"]) == sorted( + d.identifier for d in dummy_data + ) diff --git a/tests/merged/test_main.py b/tests/merged/test_main.py index 7a86a67..f481a45 100644 --- a/tests/merged/test_main.py +++ b/tests/merged/test_main.py @@ -1,30 +1,58 @@ from typing import Any -from unittest.mock import MagicMock import pytest from fastapi.testclient import TestClient +from mex.common.models import ExtractedOrganizationalUnit +from tests.conftest import MockedGraph + def test_search_merged_items_mocked( - client_with_api_key_read_permission: TestClient, mocked_graph: MagicMock + client_with_api_key_read_permission: TestClient, mocked_graph: MockedGraph ) -> None: + unit = ExtractedOrganizationalUnit.model_validate( + { + "hadPrimarySource": "2222222222222222", + "identifierInPrimarySource": "unit-1", + "email": ["test@foo.bar"], + "name": [ + {"value": "Eine unit von einer Org.", 
"language": "de"}, + {"value": "A unit of an org.", "language": "en"}, + ], + } + ) mocked_graph.return_value = [ { - "c": 0, - "l": "ExtractedContactPoint", # stopgap mx-1382 (search for MergedContactPoint instead) - "r": [{"key": "hadPrimarySource", "value": ["2222222222222222"]}], - "n": { - "stableTargetId": "0000000000000000", - "identifier": "1111111111111111", - "identifierInPrimarySource": "test", - "email": "test@foo.bar", - }, - "i": { - "stableTargetId": "0000000000000000", - "identifier": "1111111111111111", - "identifierInPrimarySource": "test", - "hadPrimarySource": "2222222222222222", - }, + "items": [ + { + "identifier": unit.stableTargetId, + "identifierInPrimarySource": unit.identifierInPrimarySource, + "stableTargetId": unit.stableTargetId, + "email": ["test@foo.bar"], + "entityType": "ExtractedOrganizationalUnit", + "_refs": [ + { + "label": "hadPrimarySource", + "position": 0, + "value": "2222222222222222", + }, + { + "label": "name", + "position": 0, + "value": { + "value": "Eine unit von einer Org.", + "language": "de", + }, + }, + { + "label": "name", + "position": 1, + "value": {"value": "A unit of an org.", "language": "en"}, + }, + ], + } + ], + "total": 14, } ] @@ -33,13 +61,21 @@ def test_search_merged_items_mocked( assert response.json() == { "items": [ { - "$type": "MergedContactPoint", + "$type": "MergedOrganizationalUnit", + "alternativeName": [], "email": ["test@foo.bar"], - "identifier": "1111111111111111", - "stableTargetId": "0000000000000000", + "identifier": unit.stableTargetId, + "name": [ + {"language": "de", "value": "Eine unit von einer Org."}, + {"language": "en", "value": "A unit of an org."}, + ], + "parentUnit": None, + "shortName": [], + "unitOf": [], + "website": [], } ], - "total": 0, + "total": 14, } @@ -58,7 +94,6 @@ def test_search_merged_items_mocked( "documentation": [], "identifier": "00000000000000", "locatedAt": [], - "stableTargetId": "00000000000000", "title": [], "unitInCharge": [], "version": None, @@ -73,9 +108,8 @@ def test_search_merged_items_mocked( "items": [ { "$type": "MergedContactPoint", - "email": ["info@rki.de"], - "identifier": "bFQoRhcVH5DHUu", - "stableTargetId": "bFQoRhcVH5DHUv", + "email": ["info@contact-point.one"], + "identifier": "bFQoRhcVH5DHUv", } ], "total": 7, @@ -87,15 +121,13 @@ def test_search_merged_items_mocked( "items": [ { "$type": "MergedContactPoint", - "email": ["info@rki.de"], - "identifier": "bFQoRhcVH5DHUu", - "stableTargetId": "bFQoRhcVH5DHUv", + "email": ["info@contact-point.one"], + "identifier": "bFQoRhcVH5DHUv", }, { "$type": "MergedContactPoint", - "email": ["mex@rki.de"], - "identifier": "bFQoRhcVH5DHUw", - "stableTargetId": "bFQoRhcVH5DHUx", + "email": ["help@contact-point.two"], + "identifier": "bFQoRhcVH5DHUx", }, ], "total": 2, @@ -111,32 +143,28 @@ def test_search_merged_items_mocked( "contact": [], "description": [], "documentation": [], - "identifier": "bFQoRhcVH5DHUs", + "identifier": "bFQoRhcVH5DHUt", "locatedAt": [], - "stableTargetId": "bFQoRhcVH5DHUt", - "title": [ - {"language": None, "value": "A cool and searchable title"} - ], + "title": [], "unitInCharge": [], - "version": None, + "version": "Cool Version v2.13", } ], "total": 1, }, ), ( - "?stableTargetId=bFQoRhcVH5DHUz", + "?identifier=bFQoRhcVH5DHUz", { "items": [ { "$type": "MergedOrganizationalUnit", "alternativeName": [], "email": [], - "identifier": "bFQoRhcVH5DHUy", + "identifier": "bFQoRhcVH5DHUz", "name": [{"language": "en", "value": "Unit 1"}], "parentUnit": None, "shortName": [], - "stableTargetId": 
"bFQoRhcVH5DHUz", "unitOf": [], "website": [], } @@ -150,7 +178,7 @@ def test_search_merged_items_mocked( "skip 1", "entity type contact points", "full text search", - "stable target id filter", + "identifier filter", ], ) @pytest.mark.usefixtures("load_dummy_data") diff --git a/tests/test_fields.py b/tests/test_fields.py index ad8e70e..fe26376 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1,64 +1,69 @@ -from typing import Any +from typing import Annotated, Any import pytest from pydantic import BaseModel -from mex.backend.fields import _get_inner_types, is_reference_field, is_text_field -from mex.common.types import Identifier, OrganizationalUnitID, PersonID, Text +from mex.backend.fields import ( + _contains_only_types, + _get_inner_types, +) +from mex.common.types import ( + MERGED_IDENTIFIER_CLASSES, + Identifier, + MergedPersonIdentifier, +) @pytest.mark.parametrize( ("annotation", "expected_types"), ( (str, [str]), - (str | None, [str, type(None)]), + (str | None, [str]), + (str | int, [str, int]), (list[str | int | list[str]], [str, int, str]), - (None, [type(None)]), + (Annotated[list[str | int], "some-annotation"], [str, int]), + (None, []), ), + ids=[ + "simple annotation", + "optional type", + "type union", + "complex nested types", + "annotated list", + "static None", + ], ) -def test__get_inner_types(annotation: Any, expected_types: list[type]) -> None: +def test_get_inner_types(annotation: Any, expected_types: list[type]) -> None: assert list(_get_inner_types(annotation)) == expected_types @pytest.mark.parametrize( - ("annotation", "is_reference"), - ( - (str, False), - (str | None, False), - (list[str | int | list[str]], False), - (None, False), - (Identifier, True), - (PersonID, True), - (list[PersonID], True), - (str | PersonID, True), - (list[None | OrganizationalUnitID], True), - (list[None | list[OrganizationalUnitID]], True), - ), -) -def test_is_reference_field(annotation: Any, is_reference: bool) -> None: - class DummyModel(BaseModel): - attribute: annotation - - assert is_reference_field(DummyModel.model_fields["attribute"]) == is_reference - - -@pytest.mark.parametrize( - ("annotation", "is_text"), + ("annotation", "types", "expected"), ( - (str, False), - (str | None, False), - (list[str | int | list[str]], False), - (None, False), - (Identifier, False), - (Text, True), - (list[Text], True), - (str | Text, True), - (list[None | Text], True), - (list[None | list[Text]], True), + (None, [str], False), + (str, [str], True), + (str, [Identifier], False), + (Identifier, [str], False), + (list[str | int | list[str]], [str, float], False), + (list[str | int | list[str]], [int, str], True), + (MergedPersonIdentifier | None, MERGED_IDENTIFIER_CLASSES, True), ), + ids=[ + "static None", + "simple str", + "str vs identifier", + "identifier vs str", + "complex miss", + "complex hit", + "optional identifier", + ], ) -def test_is_text_field(annotation: Any, is_text: bool) -> None: +def test_contains_only_types( + annotation: Any, types: list[type], expected: bool +) -> None: class DummyModel(BaseModel): attribute: annotation - assert is_text_field(DummyModel.model_fields["attribute"]) == is_text + assert ( + _contains_only_types(DummyModel.model_fields["attribute"], *types) == expected + ) diff --git a/tests/test_main.py b/tests/test_main.py index 6dccf11..b7f7462 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,8 +1,7 @@ -import asyncio import json import logging from typing import Any -from unittest.mock import AsyncMock, MagicMock, Mock +from 
unittest.mock import MagicMock, Mock import pydantic_core import pytest @@ -12,7 +11,6 @@ from pytest import LogCaptureFixture from mex.backend.main import ( - SettingsMiddleware, app, close_connectors, handle_uncaught_exception, @@ -20,7 +18,6 @@ from mex.backend.settings import BackendSettings from mex.common.connector import ConnectorContext from mex.common.exceptions import MExError -from mex.common.settings import SettingsContext def test_openapi_schema(client: TestClient) -> None: @@ -30,10 +27,10 @@ def test_openapi_schema(client: TestClient) -> None: schema = response.json() assert schema["info"]["title"] == "mex-backend" assert schema["servers"] == [{"url": "http://localhost:8080/"}] - assert schema["components"]["schemas"]["PersonID"] == { - "title": "PersonID", + assert schema["components"]["schemas"]["MergedPersonIdentifier"] == { + "title": "MergedPersonIdentifier", "type": "string", - "description": "Identifier for Person items.", + "description": "Identifier for merged persons.", "pattern": "^[a-zA-Z0-9]{14,22}$", } @@ -95,28 +92,10 @@ def test_handle_uncaught_exception( exception: Exception, expected: dict[str, Any] ) -> None: response = handle_uncaught_exception(Mock(), exception) - assert response.status_code == 500 + assert response.status_code == 500, response.body assert json.loads(response.body) == expected -def test_settings_middleware() -> None: - SettingsContext.set(None) - assert SettingsContext.get() is None - - # check settings are loaded on middleware init - middleware = SettingsMiddleware(Mock()) - assert middleware.settings.debug is False - - # check settings are injected on middleware dispatch - async def dispatch_and_assert() -> None: - await middleware.dispatch(Mock(), AsyncMock()) - settings = SettingsContext.get() - assert settings - assert settings.debug is False - - asyncio.run(dispatch_and_assert()) - - def test_close_all_connectors(caplog: LogCaptureFixture) -> None: context = { "ConnectorA": Mock(close=MagicMock()), @@ -146,7 +125,7 @@ def test_all_endpoints_require_authorization(client: TestClient) -> None: @pytest.mark.integration -def test_database_is_empty(settings: BackendSettings): +def test_database_is_empty(settings: BackendSettings) -> None: with GraphDatabase.driver( settings.graph_url, auth=( diff --git a/tests/test_roundtrip.py b/tests/test_roundtrip.py new file mode 100644 index 0000000..906bd82 --- /dev/null +++ b/tests/test_roundtrip.py @@ -0,0 +1,25 @@ +from typing import Annotated + +import pytest +from pydantic import Field, TypeAdapter + +from mex.backend.graph.connector import MEX_EXTRACTED_PRIMARY_SOURCE, GraphConnector +from mex.common.models import AnyExtractedModel + + +@pytest.mark.integration +def test_graph_ingest_and_query_roundtrip( + load_dummy_data: list[AnyExtractedModel], +) -> None: + seeded_models = [*load_dummy_data, MEX_EXTRACTED_PRIMARY_SOURCE] + + connector = GraphConnector.get() + result = connector.fetch_extracted_data(None, None, None, 0, len(seeded_models)) + + extracted_model_adapter = TypeAdapter( + list[Annotated[AnyExtractedModel, Field(discriminator="entityType")]] + ) + + assert extracted_model_adapter.validate_python(result["items"]) == sorted( + seeded_models, key=lambda x: x.identifier + ) diff --git a/tests/test_transform.py b/tests/test_transform.py new file mode 100644 index 0000000..57aa360 --- /dev/null +++ b/tests/test_transform.py @@ -0,0 +1,17 @@ +from mex.backend.transform import to_primitive +from mex.common.types import APIType, MergedActivityIdentifier, YearMonth + + +def 
test_to_primitive_uses_custom_encoders() -> None: + primitive = to_primitive( + dict( + api=APIType["RPC"], + activity=MergedActivityIdentifier.generate(seed=99), + month=YearMonth(2005, 11), + ) + ) + assert primitive == { + "api": "https://mex.rki.de/item/api-type-5", + "activity": "bFQoRhcVH5DHV1", + "month": "2005-11", + }
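

A note on the query-rendering pattern the connector and query tests above rely on: the tests assert against fully rendered Cypher strings, while parameters stay as $-placeholders that are only bound at execution time, and the query_builder fixture tweaks builder._env.globals, which points at template-based rendering. Below is a minimal, self-contained sketch of that pattern, assuming jinja2; the inline template, the IDENTITY_TEMPLATE name, and the single filter flag are illustrative stand-ins, not the project's actual .cql templates.

from jinja2 import Environment

# Illustrative inline template; the real project loads comparable ".cql"
# files from disk. A boolean flag toggles an optional WHERE clause,
# mirroring calls like fetch_identities(filter_by_stable_target_id=True)
# in the tests above.
IDENTITY_TEMPLATE = """\
MATCH (n)-[:stableTargetId]->(merged)
{% if filter_by_stable_target_id %}
WHERE
    merged.identifier = $stable_target_id
{% endif %}
RETURN n.identifier as identifier
ORDER BY n.identifier ASC
LIMIT $limit;"""

# trim_blocks/lstrip_blocks make the {% if %} tags disappear from the
# rendered output instead of leaving blank lines behind.
env = Environment(trim_blocks=True, lstrip_blocks=True)
query = env.from_string(IDENTITY_TEMPLATE).render(filter_by_stable_target_id=True)
print(query)  # WHERE clause included; with the flag False it is omitted

The rendered string is then handed to the driver together with a parameter map, which is what the mocked_graph.call_args_list assertions above capture as (query, parameters) tuples.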