Skip to content

Commit

Permalink
Implement graph id provider
Browse files Browse the repository at this point in the history
  • Loading branch information
cutoffthetop committed Jan 16, 2024
1 parent 2a4badc commit d62dedd
Show file tree
Hide file tree
Showing 21 changed files with 315 additions and 311 deletions.
6 changes: 6 additions & 0 deletions mex/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Extend this package's module search path via `pkgutil.extend_path`, so that
# same-named package directories found elsewhere on `sys.path` are included.
from pkgutil import extend_path

__path__ = extend_path(__path__, __name__)

# NOTE(review): the imports below follow the `__path__` assignment, presumably on
# purpose - confirm before letting an import sorter move them to the top.
from mex.backend.identity.provider import GraphIdentityProvider
from mex.backend.types import BackendIdentityProvider
from mex.common.identity.registry import register_provider

# Import-time side effect: register the graph-backed provider class under the
# `BackendIdentityProvider.GRAPH` key.
register_provider(BackendIdentityProvider.GRAPH, GraphIdentityProvider)
34 changes: 7 additions & 27 deletions mex/backend/extracted/models.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,14 @@
from enum import Enum, EnumMeta, _EnumDict
from typing import TYPE_CHECKING, Generator, Literal, Union, cast
from typing import TYPE_CHECKING, Union

from pydantic import Field, create_model
from pydantic import Field

from mex.common.models import BASE_MODEL_CLASSES, BaseExtractedData, BaseModel
from mex.common.transform import dromedary_to_snake


def _collect_extracted_model_classes(
    base_models: list[type[BaseModel]],
) -> Generator[tuple[str, type[BaseExtractedData]], None, None]:
    """Yield a name and a dynamically created extracted class per base model.

    Args:
        base_models: Model classes to derive extracted classes from

    Returns:
        Generator for tuples of the derived class name and the created class
    """
    for base_model in base_models:
        # derive the class name by replacing "Base" with "Extracted"
        extracted_name = base_model.__name__.replace("Base", "Extracted")
        # the type field is pinned to the class name and exposed as `$type`
        type_field = Field(extracted_name, alias="$type", frozen=True)
        extracted_class = create_model(
            extracted_name,
            __base__=(base_model, BaseExtractedData),
            __module__=__name__,
            entityType=(Literal[extracted_name], type_field),
        )
        yield extracted_name, cast(type[BaseExtractedData], extracted_class)


# mx-1533 stopgap: because we do not yet have a backend-powered identity provider,
# we need to re-create the extracted models without automatic
# identifier and stableTargetId assignment
EXTRACTED_MODEL_CLASSES_BY_NAME: dict[str, type[BaseExtractedData]] = dict(
_collect_extracted_model_classes(BASE_MODEL_CLASSES)
from mex.common.models import (
EXTRACTED_MODEL_CLASSES_BY_NAME,
BaseExtractedData,
BaseModel,
)
from mex.common.transform import dromedary_to_snake


class ExtractedTypeMeta(EnumMeta):
Expand Down
2 changes: 1 addition & 1 deletion mex/backend/extracted/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ def transform_graph_results_to_extracted_item_search_response(
" \n".join("()-[{key}]->({value})".format(**r) for r in result["r"]),
exc_info=False,
)
return ExtractedItemSearchResponse(items=items, total=total)
return ExtractedItemSearchResponse.model_construct(items=items, total=total)
2 changes: 1 addition & 1 deletion mex/backend/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic.fields import FieldInfo

from mex.backend.extracted.models import EXTRACTED_MODEL_CLASSES_BY_NAME
from mex.common.models import EXTRACTED_MODEL_CLASSES_BY_NAME
from mex.common.types import Identifier, Text


Expand Down
20 changes: 19 additions & 1 deletion mex/backend/graph/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,16 @@
from mex.common.connector import BaseConnector
from mex.common.exceptions import MExError
from mex.common.logging import logger
from mex.common.models import EXTRACTED_MODEL_CLASSES_BY_NAME
from mex.common.models import (
EXTRACTED_MODEL_CLASSES_BY_NAME,
MEX_PRIMARY_SOURCE_IDENTIFIER_IN_PRIMARY_SOURCE,
MEX_PRIMARY_SOURCE_STABLE_TARGET_ID,
ExtractedPrimarySource,
)
from mex.common.types import Identifier

MEX_PRIMARY_SOURCE_IDENTIFIER = Identifier("00000000000000")


class GraphConnector(BaseConnector):
"""Connector to handle authentication and transactions with the graph database."""
Expand All @@ -48,6 +55,7 @@ def __init__(self) -> None:
self._check_connectivity_and_authentication()
self._seed_constraints()
self._seed_indices()
self._seed_primary_source()

def _check_connectivity_and_authentication(self) -> None:
"""Check the connectivity and authentication to the graph."""
Expand Down Expand Up @@ -87,6 +95,16 @@ def _seed_indices(self) -> GraphResult:
},
)

def _seed_primary_source(self) -> Identifier:
    """Ingest the `mex` primary source, which points to itself as its source.

    Returns:
        First element of the result returned by `self.ingest`
    """
    # `hadPrimarySource` and `stableTargetId` carry the same value, which makes
    # the seeded primary source reference itself.
    seed_values = {
        "hadPrimarySource": MEX_PRIMARY_SOURCE_STABLE_TARGET_ID,
        "identifier": MEX_PRIMARY_SOURCE_IDENTIFIER,
        "identifierInPrimarySource": MEX_PRIMARY_SOURCE_IDENTIFIER_IN_PRIMARY_SOURCE,
        "stableTargetId": MEX_PRIMARY_SOURCE_STABLE_TARGET_ID,
    }
    seed = ExtractedPrimarySource.model_construct(**seed_values)
    ingested = self.ingest([seed])
    return ingested[0]

def mcommit(
self, *statements_with_parameters: tuple[str, dict[str, Any] | None]
) -> list[GraphResult]:
Expand Down
2 changes: 1 addition & 1 deletion mex/backend/graph/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

CREATE_CONSTRAINTS_STATEMENT = r"""
CREATE CONSTRAINT IF NOT EXISTS
CREATE CONSTRAINT identifier_uniqueness IF NOT EXISTS
FOR (n:{node_label})
REQUIRE n.identifier IS UNIQUE;
"""
Expand Down
7 changes: 2 additions & 5 deletions mex/backend/graph/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@

from pydantic import BaseModel as PydanticBaseModel

from mex.backend.extracted.models import (
EXTRACTED_MODEL_CLASSES_BY_NAME,
AnyExtractedModel,
)
from mex.backend.extracted.models import AnyExtractedModel
from mex.backend.fields import REFERENCE_FIELDS_BY_CLASS_NAME
from mex.backend.graph.hydrate import dehydrate, hydrate
from mex.backend.transform import to_primitive
from mex.common.identity import Identity
from mex.common.models import BaseModel, MExModel
from mex.common.models import EXTRACTED_MODEL_CLASSES_BY_NAME, BaseModel, MExModel


class MergableNode(PydanticBaseModel):
Expand Down
42 changes: 5 additions & 37 deletions mex/backend/identity/main.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
from fastapi import APIRouter
from fastapi.exceptions import HTTPException

from mex.backend.graph.connector import GraphConnector
from mex.backend.graph.transform import transform_identity_result_to_identity
from mex.backend.identity.models import IdentityAssignRequest, IdentityFetchResponse
from mex.backend.identity.provider import GraphIdentityProvider
from mex.common.exceptions import MExError
from mex.common.identity.models import Identity
from mex.common.models import (
MEX_PRIMARY_SOURCE_IDENTIFIER_IN_PRIMARY_SOURCE,
MEX_PRIMARY_SOURCE_STABLE_TARGET_ID,
)
from mex.common.types import Identifier

router = APIRouter()
Expand All @@ -18,35 +13,11 @@
@router.post("/identity", status_code=200, tags=["extractors"])
def assign_identity(request: IdentityAssignRequest) -> Identity:
"""Insert a new identity or update an existing one."""
connector = GraphConnector.get()
graph_result = connector.fetch_identities(
identity_provider = GraphIdentityProvider.get()
return identity_provider.assign(
had_primary_source=request.hadPrimarySource,
identifier_in_primary_source=request.identifierInPrimarySource,
)
if len(graph_result.data) > 1:
raise MExError("found multiple identities indicating graph inconsistency")
if len(graph_result.data) == 1:
return transform_identity_result_to_identity(graph_result.data[0])
if (
request.identifierInPrimarySource
== MEX_PRIMARY_SOURCE_IDENTIFIER_IN_PRIMARY_SOURCE
and request.hadPrimarySource == MEX_PRIMARY_SOURCE_STABLE_TARGET_ID
):
# This is to deal with the edge case where primary source is the parent of
# all primary sources and has no parents for itself,
# this will add itself as its parent.
return Identity(
hadPrimarySource=request.hadPrimarySource,
identifier=MEX_PRIMARY_SOURCE_STABLE_TARGET_ID,
identifierInPrimarySource=request.identifierInPrimarySource,
stableTargetId=MEX_PRIMARY_SOURCE_STABLE_TARGET_ID,
)
return Identity(
hadPrimarySource=request.hadPrimarySource,
identifier=Identifier.generate(),
identifierInPrimarySource=request.identifierInPrimarySource,
stableTargetId=Identifier.generate(),
)


@router.get("/identity", status_code=200, tags=["extractors"])
Expand All @@ -60,16 +31,13 @@ def fetch_identity(
Either provide `stableTargetId` or `hadPrimarySource`
and `identifierInPrimarySource` together to get a unique result.
"""
connector = GraphConnector.get()
identity_provider = GraphIdentityProvider.get()
try:
graph_result = connector.fetch_identities(
identities = identity_provider.fetch(
had_primary_source=hadPrimarySource,
identifier_in_primary_source=identifierInPrimarySource,
stable_target_id=stableTargetId,
)
except MExError as error:
raise HTTPException(400, error.args)
identities = [
transform_identity_result_to_identity(result) for result in graph_result.data
]
return IdentityFetchResponse(items=identities, total=len(identities))
73 changes: 73 additions & 0 deletions mex/backend/identity/provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from functools import cache

from mex.backend.graph.connector import GraphConnector
from mex.backend.graph.transform import transform_identity_result_to_identity
from mex.common.exceptions import MExError
from mex.common.identity import BaseProvider, Identity
from mex.common.models import (
MEX_PRIMARY_SOURCE_IDENTIFIER_IN_PRIMARY_SOURCE,
MEX_PRIMARY_SOURCE_STABLE_TARGET_ID,
)
from mex.common.types import Identifier, PrimarySourceID


class GraphIdentityProvider(BaseProvider, GraphConnector):
    """Identity provider that looks up identities in the neo4j graph database."""

    # NOTE(review): `functools.cache` on a method puts `self` into the cache key and
    # keeps the instance alive as long as the cache exists (ruff B019). The cache
    # is also unbounded - presumably fine for a long-lived connector, confirm.
    @cache
    def assign(
        self,
        had_primary_source: PrimarySourceID,
        identifier_in_primary_source: str,
    ) -> Identity:
        """Return the stored identity for the given pair or build a new one.

        Results are memoized per argument combination by `functools.cache`.

        Args:
            had_primary_source: Identifier of the primary source of the item
            identifier_in_primary_source: Identifier of the item in that source

        Raises:
            MExError: If the graph holds more than one matching identity

        Returns:
            The single matching Identity, otherwise a newly built Identity
        """
        matches = self.fetch_identities(
            had_primary_source=had_primary_source,
            identifier_in_primary_source=identifier_in_primary_source,
        ).data
        if len(matches) > 1:
            raise MExError("found multiple identities indicating graph inconsistency")
        if len(matches) == 1:
            return transform_identity_result_to_identity(matches[0])

        # Edge case: the `mex` primary source is the parent of all primary sources
        # and has no parent itself, so it is set up as its own parent.
        is_mex_primary_source = (
            identifier_in_primary_source
            == MEX_PRIMARY_SOURCE_IDENTIFIER_IN_PRIMARY_SOURCE
            and had_primary_source == MEX_PRIMARY_SOURCE_STABLE_TARGET_ID
        )
        if is_mex_primary_source:
            identifier = stable_target_id = MEX_PRIMARY_SOURCE_STABLE_TARGET_ID
        else:
            identifier = Identifier.generate()
            stable_target_id = Identifier.generate()
        return Identity(
            hadPrimarySource=had_primary_source,
            identifier=identifier,
            identifierInPrimarySource=identifier_in_primary_source,
            stableTargetId=stable_target_id,
        )

    def fetch(
        self,
        *,
        had_primary_source: Identifier | None = None,
        identifier_in_primary_source: str | None = None,
        stable_target_id: Identifier | None = None,
    ) -> list[Identity]:
        """Return all Identity instances that match the given filters.

        Pass either `stable_target_id` alone or `had_primary_source` together
        with `identifier_in_primary_source` to get a unique result.

        Args:
            had_primary_source: Optional primary source identifier to filter by
            identifier_in_primary_source: Optional identifier in the primary source
            stable_target_id: Optional stable target identifier to filter by

        Returns:
            List of Identity instances, one per row of the graph result
        """
        rows = self.fetch_identities(
            had_primary_source=had_primary_source,
            identifier_in_primary_source=identifier_in_primary_source,
            stable_target_id=stable_target_id,
        ).data
        return [transform_identity_result_to_identity(row) for row in rows]
7 changes: 2 additions & 5 deletions mex/backend/ingest/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@

from pydantic import ConfigDict, create_model

from mex.backend.extracted.models import (
EXTRACTED_MODEL_CLASSES_BY_NAME,
AnyExtractedModel,
)
from mex.common.models import BaseModel
from mex.backend.extracted.models import AnyExtractedModel
from mex.common.models import EXTRACTED_MODEL_CLASSES_BY_NAME, BaseModel
from mex.common.types import Identifier


Expand Down
9 changes: 8 additions & 1 deletion mex/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,14 @@


def create_openapi_schema() -> dict[str, Any]:
"""Create an OpenAPI schema for the backend."""
"""Create an OpenAPI schema for the backend.
Settings:
backend_api_url: MEx backend API url.
Returns:
OpenApi schema as dictionary
"""
if app.openapi_schema:
return app.openapi_schema

Expand Down
7 changes: 3 additions & 4 deletions mex/backend/merged/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

from fastapi import APIRouter, Query

from mex.backend.extracted.models import ExtractedType
from mex.backend.graph.connector import GraphConnector
from mex.backend.merged.models import MergedItemSearchResponse
from mex.backend.merged.models import MergedItemSearchResponse, MergedType
from mex.backend.merged.transform import (
transform_graph_results_to_merged_item_search_response_facade,
)
Expand All @@ -17,7 +16,7 @@
def search_merged_items_facade(
q: str = Query("", max_length=1000),
stableTargetId: Identifier | None = Query(None), # noqa: N803
entityType: Sequence[ExtractedType] = Query([]), # noqa: N803
entityType: Sequence[MergedType] = Query([]), # noqa: N803
skip: int = Query(0, ge=0, le=10e10),
limit: int = Query(10, ge=1, le=100),
) -> MergedItemSearchResponse:
Expand All @@ -28,7 +27,7 @@ def search_merged_items_facade(
query_results = graph.query_nodes(
q,
stableTargetId,
[t.value for t in entityType or ExtractedType],
[t.value.replace("Merged", "Extracted") for t in entityType or MergedType],
skip,
limit,
)
Expand Down
2 changes: 1 addition & 1 deletion mex/backend/merged/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,5 @@ def transform_graph_results_to_merged_item_search_response_facade(
" \n".join("()-[{key}]->({value})".format(**r) for r in result["r"]),
exc_info=False,
)
# TODO merge extracted items with rule set
# TODO: merge extracted items with rule sets
return MergedItemSearchResponse.model_validate({"items": items, "total": total})
9 changes: 7 additions & 2 deletions mex/backend/settings.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from pydantic import Field, SecretStr

from mex.backend.types import APIKeyDatabase, APIUserDatabase
from mex.backend.types import APIKeyDatabase, APIUserDatabase, BackendIdentityProvider
from mex.common.settings import BaseSettings
from mex.common.types import Sink
from mex.common.types import IdentityProvider, Sink


class BackendSettings(BaseSettings):
Expand Down Expand Up @@ -70,3 +70,8 @@ class BackendSettings(BaseSettings):
description="Database of users.",
validation_alias="MEX_BACKEND_API_USER_DATABASE",
)
identity_provider: IdentityProvider | BackendIdentityProvider = Field(
BackendIdentityProvider.GRAPH,
description="Provider to assign stableTargetIds to new model instances.",
validation_alias="MEX_IDENTITY_PROVIDER",
) # type: ignore[assignment]
6 changes: 6 additions & 0 deletions mex/backend/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,9 @@ class APIUserDatabase(BaseModel):

read: dict[str, APIUserPassword] = {}
write: dict[str, APIUserPassword] = {}


class BackendIdentityProvider(Enum):
    """Choices of identity providers that mex-backend implements itself."""

    # identity provider identified by the value "graph"
    GRAPH = "graph"
Loading

0 comments on commit d62dedd

Please sign in to comment.