diff --git a/registry/sql-registry/main.py b/registry/sql-registry/main.py index 00ac1d422..117d1dd7a 100644 --- a/registry/sql-registry/main.py +++ b/registry/sql-registry/main.py @@ -1,10 +1,12 @@ import os +import traceback from typing import Optional from uuid import UUID from fastapi import APIRouter, FastAPI, HTTPException +from fastapi.responses import JSONResponse from starlette.middleware.cors import CORSMiddleware from registry import * -from registry.db_registry import DbRegistry +from registry.db_registry import DbRegistry, ConflictError from registry.models import AnchorDef, AnchorFeatureDef, DerivedFeatureDef, EntityType, ProjectDef, SourceDef, to_snake rp = "/" @@ -28,6 +30,49 @@ allow_headers=["*"], ) +def exc_to_content(e: Exception) -> dict: + content={"message": str(e)} + if os.environ.get("REGISTRY_DEBUGGING"): + content["traceback"] = "".join(traceback.TracebackException.from_exception(e).format()) + return content + +@app.exception_handler(ConflictError) +async def conflict_error_handler(_, exc: ConflictError): + return JSONResponse( + status_code=409, + content=exc_to_content(exc), + ) + + +@app.exception_handler(ValueError) +async def value_error_handler(_, exc: ValueError): + return JSONResponse( + status_code=400, + content=exc_to_content(exc), + ) + +@app.exception_handler(TypeError) +async def type_error_handler(_, exc: ValueError): + return JSONResponse( + status_code=400, + content=exc_to_content(exc), + ) + + +@app.exception_handler(KeyError) +async def key_error_handler(_, exc: KeyError): + return JSONResponse( + status_code=404, + content=exc_to_content(exc), + ) + +@app.exception_handler(IndexError) +async def index_error_handler(_, exc: IndexError): + return JSONResponse( + status_code=404, + content=exc_to_content(exc), + ) + @router.get("/projects") def get_projects() -> list[str]: diff --git a/registry/sql-registry/registry/__init__.py b/registry/sql-registry/registry/__init__.py index 5ce157408..afcc69eee 100644 --- a/registry/sql-registry/registry/__init__.py +++ b/registry/sql-registry/registry/__init__.py @@ -3,4 +3,4 @@ from registry.models import * from registry.interface import Registry from registry.database import DbConnection, connect -from registry.db_registry import DbRegistry \ No newline at end of file +from registry.db_registry import DbRegistry, ConflictError \ No newline at end of file diff --git a/registry/sql-registry/registry/db_registry.py b/registry/sql-registry/registry/db_registry.py index 58f4b98db..70ab62f68 100644 --- a/registry/sql-registry/registry/db_registry.py +++ b/registry/sql-registry/registry/db_registry.py @@ -7,6 +7,9 @@ from registry.models import AnchorAttributes, AnchorDef, AnchorFeatureAttributes, AnchorFeatureDef, DerivedFeatureAttributes, DerivedFeatureDef, Edge, EntitiesAndRelations, Entity, EntityRef, EntityType, ProjectAttributes, ProjectDef, RelationshipType, SourceAttributes, SourceDef, _to_type, _to_uuid import json +class ConflictError(Exception): + pass + def quote(id): if isinstance(id, str): @@ -16,7 +19,6 @@ def quote(id): else: return ",".join([quote(i) for i in id]) - class DbRegistry(Registry): def __init__(self): self.conn = connect() @@ -41,6 +43,8 @@ def get_entity_id(self, id_or_name: Union[str, UUID]) -> UUID: # It is a name ret = self.conn.query( f"select entity_id from entities where qualified_name=%s", str(id_or_name)) + if len(ret) == 0: + raise KeyError(f"Entity {id_or_name} not found") return ret[0]["entity_id"] def get_neighbors(self, id_or_name: Union[str, UUID], relationship: RelationshipType) -> list[Edge]: @@ -138,7 +142,7 @@ def create_project(self, definition: ProjectDef) -> UUID: len(r), definition.qualified_name) # The entity with same name already exists but with different type if _to_type(r[0]["entity_type"], EntityType) != EntityType.Project: - raise ValueError("Entity %s already exists" % + raise ConflictError("Entity %s already exists" % definition.qualified_name) # Just return the existing project id return _to_uuid(r[0]["entity_id"]) @@ -166,7 +170,7 @@ def create_project_datasource(self, project_id: UUID, definition: SourceDef) -> len(r), definition.qualified_name) # The entity with same name already exists but with different type if _to_type(r[0]["entity_type"], EntityType) != EntityType.Source: - raise ValueError("Entity %s already exists" % + raise ConflictError("Entity %s already exists" % definition.qualified_name) attr: SourceAttributes = _to_type( json.loads(r[0]["attributes"]), SourceAttributes) @@ -179,7 +183,7 @@ def create_project_datasource(self, project_id: UUID, definition: SourceDef) -> # Creating exactly same entity # Just return the existing id return _to_uuid(r[0]["entity_id"]) - raise ValueError("Entity %s already exists" % + raise ConflictError("Entity %s already exists" % definition.qualified_name) id = uuid4() c.execute(f"insert into entities (entity_id, entity_type, qualified_name, attributes) values (%s, %s, %s, %s)", @@ -207,7 +211,7 @@ def create_project_anchor(self, project_id: UUID, definition: AnchorDef) -> UUID len(r), definition.qualified_name) # The entity with same name already exists but with different type if _to_type(r[0]["entity_type"], EntityType) != EntityType.Anchor: - raise ValueError("Entity %s already exists" % + raise ConflictError("Entity %s already exists" % definition.qualified_name) attr: AnchorAttributes = _to_type( json.loads(r[0]["attributes"]), AnchorAttributes) @@ -215,7 +219,7 @@ def create_project_anchor(self, project_id: UUID, definition: AnchorDef) -> UUID # Creating exactly same entity # Just return the existing id return _to_uuid(r[0]["entity_id"]) - raise ValueError("Entity %s already exists" % + raise ConflictError("Entity %s already exists" % definition.qualified_name) c.execute("select entity_id, qualified_name from entities where entity_id = %s and entity_type = %s", (str( definition.source_id), str(EntityType.Source))) @@ -257,7 +261,7 @@ def create_project_anchor_feature(self, project_id: UUID, anchor_id: UUID, defin len(r), definition.qualified_name) # The entity with same name already exists but with different type if _to_type(r[0]["entity_type"], EntityType) != EntityType.AnchorFeature: - raise ValueError("Entity %s already exists" % + raise ConflictError("Entity %s already exists" % definition.qualified_name) attr: AnchorFeatureAttributes = _to_type( json.loads(r[0]["attributes"]), AnchorFeatureAttributes) @@ -269,7 +273,7 @@ def create_project_anchor_feature(self, project_id: UUID, anchor_id: UUID, defin # Just return the existing id return _to_uuid(r[0]["entity_id"]) # The existing entity has different definition, that's a conflict - raise ValueError("Entity %s already exists" % + raise ConflictError("Entity %s already exists" % definition.qualified_name) source_id = anchor.attributes.source.id id = uuid4() @@ -305,7 +309,7 @@ def create_project_derived_feature(self, project_id: UUID, definition: DerivedFe len(r), definition.qualified_name) # The entity with same name already exists but with different type, that's conflict if _to_type(r[0]["entity_type"], EntityType) != EntityType.DerivedFeature: - raise ValueError("Entity %s already exists" % + raise ConflictError("Entity %s already exists" % definition.qualified_name) attr: DerivedFeatureAttributes = _to_type( json.loads(r[0]["attributes"]), DerivedFeatureAttributes) @@ -317,7 +321,7 @@ def create_project_derived_feature(self, project_id: UUID, definition: DerivedFe # Just return the existing id return _to_uuid(r[0]["entity_id"]) # The existing entity has different definition, that's a conflict - raise ValueError("Entity %s already exists" % + raise ConflictError("Entity %s already exists" % definition.qualified_name) r1 = [] # Fill `input_anchor_features`, from `definition` we have ids only, we still need qualified names @@ -429,7 +433,7 @@ def _get_entity(self, id_or_name: Union[str, UUID]) -> Entity: where entity_id = %s ''', self.get_entity_id(id_or_name)) if not row: - raise ValueError(f"Entity {id_or_name} not found") + raise KeyError(f"Entity {id_or_name} not found") row=row[0] row["attributes"] = json.loads(row["attributes"]) return _to_type(row, Entity)