Skip to content

Commit

Permalink
feat: Support deserializing of API payloads to user-defined model cla…
Browse files Browse the repository at this point in the history
…sses

Users can provide a custom deserializer that allows customising which
models should be returned by the client.
  • Loading branch information
asgeirrr committed Nov 2, 2023
1 parent acbf731 commit a1e1385
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 48 deletions.
102 changes: 55 additions & 47 deletions rossum_api/elis_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,28 @@
from enum import Enum

import aiofiles
import dacite

from rossum_api.api_client import APIClient, Resource
from rossum_api.models.annotation import Annotation
from rossum_api.models.connector import Connector
from rossum_api.models.hook import Hook
from rossum_api.models.inbox import Inbox
from rossum_api.models.organization import Organization
from rossum_api.models.queue import Queue
from rossum_api.models.schema import Schema
from rossum_api.models.user import User
from rossum_api.models.user_role import UserRole
from rossum_api.models.workspace import Workspace
from rossum_api.models import deserialize_default

if typing.TYPE_CHECKING:
import pathlib
from typing import Any, AsyncIterable, Callable, Dict, List, Optional, Sequence, Tuple, Union

import httpx

from rossum_api.models import Deserializer
from rossum_api.models.annotation import Annotation
from rossum_api.models.connector import Connector
from rossum_api.models.hook import Hook
from rossum_api.models.inbox import Inbox
from rossum_api.models.organization import Organization
from rossum_api.models.queue import Queue
from rossum_api.models.schema import Schema
from rossum_api.models.user import User
from rossum_api.models.user_role import UserRole
from rossum_api.models.workspace import Workspace


class ExportFileFormats(Enum):
CSV = "csv"
Expand All @@ -44,8 +46,16 @@ def __init__(
token: Optional[str] = None,
base_url: Optional[str] = None,
http_client: Optional[APIClient] = None,
deserializer: Optional[Deserializer] = None,
):
"""
Parameters
----------
deserializer
pass a custom deserialization callable if different model classes should be returned
"""
self._http_client = http_client or APIClient(username, password, token, base_url)
self._deserializer = deserializer or deserialize_default

# ##### QUEUE #####
async def retrieve_queue(
Expand All @@ -55,7 +65,7 @@ async def retrieve_queue(
"""https://elis.rossum.ai/api/docs/#retrieve-a-queue-2."""
queue = await self._http_client.fetch_one(Resource.Queue, queue_id)

return dacite.from_dict(Queue, queue)
return self._deserializer(Resource.Queue, queue)

async def list_all_queues(
self,
Expand All @@ -64,13 +74,13 @@ async def list_all_queues(
) -> AsyncIterable[Queue]:
"""https://elis.rossum.ai/api/docs/#list-all-queues."""
async for q in self._http_client.fetch_all(Resource.Queue, ordering, **filters):
yield dacite.from_dict(Queue, q)
yield self._deserializer(Resource.Queue, q)

async def create_new_queue(self, data: Dict[str, Any]) -> Queue:
"""https://elis.rossum.ai/api/docs/#create-new-queue."""
queue = await self._http_client.create(Resource.Queue, data)

return dacite.from_dict(Queue, queue)
return self._deserializer(Resource.Queue, queue)

async def delete_queue(self, queue_id: int) -> None:
"""https://elis.rossum.ai/api/docs/#delete-a-queue."""
Expand Down Expand Up @@ -124,7 +134,7 @@ async def export_annotations_to_json(
"""
async for chunk in self._http_client.export(Resource.Queue, queue_id, "json"):
# JSON export can be translated directly to Annotation object
yield dacite.from_dict(Annotation, typing.cast(typing.Dict, chunk))
yield self._deserializer(Resource.Annotation, typing.cast(typing.Dict, chunk))

async def export_annotations_to_file(
self, queue_id: int, export_format: ExportFileFormats
Expand All @@ -144,15 +154,13 @@ async def list_all_organizations(
):
"""https://elis.rossum.ai/api/docs/#list-all-organizations."""
async for o in self._http_client.fetch_all(Resource.Organization, ordering, **filters):
yield dacite.from_dict(Organization, o)
yield self._deserializer(Resource.Organization, o)

async def retrieve_organization(self, org_id: int) -> Organization:
"""https://elis.rossum.ai/api/docs/#retrieve-an-organization."""
organization: Dict[Any, Any] = await self._http_client.fetch_one(
Resource.Organization, org_id
)
organization = await self._http_client.fetch_one(Resource.Organization, org_id)

return dacite.from_dict(Organization, organization)
return self._deserializer(Resource.Organization, organization)

async def retrieve_own_organization(self) -> Organization:
"""Retrieve organization of currently logged in user."""
Expand All @@ -168,19 +176,19 @@ async def list_all_schemas(
) -> AsyncIterable[Schema]:
"""https://elis.rossum.ai/api/docs/#list-all-schemas."""
async for s in self._http_client.fetch_all(Resource.Schema, ordering, **filters):
yield dacite.from_dict(Schema, s)
yield self._deserializer(Resource.Schema, s)

async def retrieve_schema(self, schema_id: int) -> Schema:
"""https://elis.rossum.ai/api/docs/#retrieve-a-schema."""
schema: Dict[Any, Any] = await self._http_client.fetch_one(Resource.Schema, schema_id)

return dacite.from_dict(Schema, schema)
return self._deserializer(Resource.Schema, schema)

async def create_new_schema(self, data: Dict[str, Any]) -> Schema:
"""https://elis.rossum.ai/api/docs/#create-a-new-schema."""
queue = await self._http_client.create(Resource.Schema, data)
schema = await self._http_client.create(Resource.Schema, data)

return dacite.from_dict(Schema, queue)
return self._deserializer(Resource.Schema, schema)

async def delete_schema(self, schema_id) -> None:
"""https://elis.rossum.ai/api/docs/#delete-a-schema."""
Expand All @@ -194,19 +202,19 @@ async def list_all_users(
) -> AsyncIterable[User]:
"""https://elis.rossum.ai/api/docs/#list-all-users."""
async for u in self._http_client.fetch_all(Resource.User, ordering, **filters):
yield dacite.from_dict(User, u)
yield self._deserializer(Resource.User, u)

async def retrieve_user(self, user_id: int) -> User:
"""https://elis.rossum.ai/api/docs/#retrieve-a-user-2."""
user = await self._http_client.fetch_one(Resource.User, user_id)

return dacite.from_dict(User, user)
return self._deserializer(Resource.User, user)

async def create_new_user(self, data: Dict[str, Any]) -> User:
"""https://elis.rossum.ai/api/docs/#create-new-user."""
user = await self._http_client.create(Resource.User, data)

return dacite.from_dict(User, user)
return self._deserializer(Resource.User, user)

# TODO: specific method in APICLient
def change_user_password(self, new_password: str) -> dict:
Expand All @@ -232,7 +240,7 @@ async def list_all_annotations(
async for a in self._http_client.fetch_all(
Resource.Annotation, ordering, sideloads, content_schema_ids, **filters
):
yield dacite.from_dict(Annotation, a)
yield self._deserializer(Resource.Annotation, a)

async def search_for_annotations(
self,
Expand All @@ -259,7 +267,7 @@ async def search_for_annotations(
method="POST",
**kwargs,
):
yield dacite.from_dict(Annotation, a)
yield self._deserializer(Resource.Annotation, a)

async def retrieve_annotation(
self, annotation_id: int, sideloads: Sequence[str] = ()
Expand All @@ -268,7 +276,7 @@ async def retrieve_annotation(
annotation_json = await self._http_client.fetch_one(Resource.Annotation, annotation_id)
if sideloads:
await self._sideload(annotation_json, sideloads)
return dacite.from_dict(Annotation, annotation_json)
return self._deserializer(Resource.Annotation, annotation_json)

async def poll_annotation(
self,
Expand All @@ -283,28 +291,28 @@ async def poll_annotation(
"""
annotation_json = await self._http_client.fetch_one(Resource.Annotation, annotation_id)
# Parse early, we want predicate to work with Annotation instances for convenience
annotation = dacite.from_dict(Annotation, annotation_json)
annotation = self._deserializer(Resource.Annotation, annotation_json)

while not predicate(annotation):
await asyncio.sleep(sleep_s)
annotation_json = await self._http_client.fetch_one(Resource.Annotation, annotation_id)
annotation = dacite.from_dict(Annotation, annotation_json)
annotation = self._deserializer(Resource.Annotation, annotation_json)

if sideloads:
await self._sideload(annotation_json, sideloads)
return dacite.from_dict(Annotation, annotation_json)
return self._deserializer(Resource.Annotation, annotation_json)

async def update_annotation(self, annotation_id: int, data: Dict[str, Any]) -> Annotation:
"""https://elis.rossum.ai/api/docs/#update-an-annotation."""
annotation = await self._http_client.replace(Resource.Annotation, annotation_id, data)

return dacite.from_dict(Annotation, annotation)
return self._deserializer(Resource.Annotation, annotation)

async def update_part_annotation(self, annotation_id: int, data: Dict[str, Any]) -> Annotation:
"""https://elis.rossum.ai/api/docs/#update-part-of-an-annotation."""
annotation = await self._http_client.update(Resource.Annotation, annotation_id, data)

return dacite.from_dict(Annotation, annotation)
return self._deserializer(Resource.Annotation, annotation)

# ##### WORKSPACES #####
async def list_all_workspaces(
Expand All @@ -314,19 +322,19 @@ async def list_all_workspaces(
) -> AsyncIterable[Workspace]:
"""https://elis.rossum.ai/api/docs/#list-all-workspaces."""
async for w in self._http_client.fetch_all(Resource.Workspace, ordering, **filters):
yield dacite.from_dict(Workspace, w)
yield self._deserializer(Resource.Workspace, w)

async def retrieve_workspace(self, workspace_id) -> Workspace:
"""https://elis.rossum.ai/api/docs/#retrieve-a-workspace."""
workspace = await self._http_client.fetch_one(Resource.Workspace, workspace_id)

return dacite.from_dict(Workspace, workspace)
return self._deserializer(Resource.Workspace, workspace)

async def create_new_workspace(self, data: Dict[str, Any]) -> Workspace:
"""https://elis.rossum.ai/api/docs/#create-a-new-workspace."""
workspace = await self._http_client.create(Resource.Workspace, data)

return dacite.from_dict(Workspace, workspace)
return self._deserializer(Resource.Workspace, workspace)

async def delete_workspace(self, workspace_id) -> None:
"""https://elis.rossum.ai/api/docs/#delete-a-workspace."""
Expand All @@ -337,7 +345,7 @@ async def create_new_inbox(self, data: Dict[str, Any]) -> Inbox:
"""https://elis.rossum.ai/api/docs/#create-a-new-inbox."""
inbox = await self._http_client.create(Resource.Inbox, data)

return dacite.from_dict(Inbox, inbox)
return self._deserializer(Resource.Inbox, inbox)

# ##### CONNECTORS #####
async def list_all_connectors(
Expand All @@ -347,19 +355,19 @@ async def list_all_connectors(
) -> AsyncIterable[Connector]:
"""https://elis.rossum.ai/api/docs/#list-all-connectors."""
async for c in self._http_client.fetch_all(Resource.Connector, ordering, **filters):
yield dacite.from_dict(Connector, c)
yield self._deserializer(Resource.Connector, c)

async def retrieve_connector(self, connector_id) -> Connector:
"""https://elis.rossum.ai/api/docs/#retrieve-a-connector."""
connector = await self._http_client.fetch_one(Resource.Connector, connector_id)

return dacite.from_dict(Connector, connector)
return self._deserializer(Resource.Connector, connector)

async def create_new_connector(self, data: Dict[str, Any]) -> Connector:
"""https://elis.rossum.ai/api/docs/#create-a-new-connector."""
connector = await self._http_client.create(Resource.Connector, data)

return dacite.from_dict(Connector, connector)
return self._deserializer(Resource.Connector, connector)

# ##### HOOKS #####
async def list_all_hooks(
Expand All @@ -369,19 +377,19 @@ async def list_all_hooks(
) -> AsyncIterable[Hook]:
"""https://elis.rossum.ai/api/docs/#list-all-hooks."""
async for h in self._http_client.fetch_all(Resource.Hook, ordering, **filters):
yield dacite.from_dict(Hook, h)
yield self._deserializer(Resource.Hook, h)

async def retrieve_hook(self, hook_id) -> Hook:
"""https://elis.rossum.ai/api/docs/#retrieve-a-hook."""
hook = await self._http_client.fetch_one(Resource.Hook, hook_id)

return dacite.from_dict(Hook, hook)
return self._deserializer(Resource.Hook, hook)

async def create_new_hook(self, data: Dict[str, Any]) -> Hook:
"""https://elis.rossum.ai/api/docs/#create-a-new-hook."""
hook = await self._http_client.create(Resource.Hook, data)

return dacite.from_dict(Hook, hook)
return self._deserializer(Resource.Hook, hook)

# ##### USER ROLES #####
async def list_all_user_roles(
Expand All @@ -390,8 +398,8 @@ async def list_all_user_roles(
**filters: Any,
) -> AsyncIterable[UserRole]:
"""https://elis.rossum.ai/api/docs/#list-all-user-roles."""
async for u in self._http_client.fetch_all(Resource.Group, ordering, **filters):
yield dacite.from_dict(UserRole, u)
async for g in self._http_client.fetch_all(Resource.Group, ordering, **filters):
yield self._deserializer(Resource.Group, g)

# ##### GENERIC METHODS #####
async def request_paginated(self, url: str, *args, **kwargs) -> AsyncIterable[dict]:
Expand Down
6 changes: 5 additions & 1 deletion rossum_api/elis_api_client_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from rossum_api import ExportFileFormats
from rossum_api.api_client import APIClient
from rossum_api.models import Deserializer
from rossum_api.models.annotation import Annotation
from rossum_api.models.connector import Connector
from rossum_api.models.hook import Hook
Expand Down Expand Up @@ -55,8 +56,11 @@ def __init__(
token: Optional[str] = None,
base_url: Optional[str] = None,
http_client: Optional[APIClient] = None,
deserializer: Optional[Deserializer] = None,
):
self.elis_api_client = ElisAPIClient(username, password, token, base_url, http_client)
self.elis_api_client = ElisAPIClient(
username, password, token, base_url, http_client, deserializer
)

try:
self.event_loop = asyncio.get_running_loop()
Expand Down
43 changes: 43 additions & 0 deletions rossum_api/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import dacite

from rossum_api.api_client import Resource
from rossum_api.models.annotation import Annotation
from rossum_api.models.connector import Connector
from rossum_api.models.hook import Hook
from rossum_api.models.inbox import Inbox
from rossum_api.models.organization import Organization
from rossum_api.models.queue import Queue
from rossum_api.models.schema import Schema
from rossum_api.models.user import User
from rossum_api.models.user_role import UserRole
from rossum_api.models.workspace import Workspace

if TYPE_CHECKING:
from typing import Any, Callable, Dict

JsonDict = Dict[str, Any]
Deserializer = Callable[[Resource, JsonDict], Any]


RESOURCE_TO_MODEL = {
Resource.Annotation: Annotation,
Resource.Connector: Connector,
Resource.Group: UserRole,
Resource.Hook: Hook,
Resource.Inbox: Inbox,
Resource.Organization: Organization,
Resource.Queue: Queue,
Resource.Schema: Schema,
Resource.User: User,
Resource.Workspace: Workspace,
}


def deserialize_default(resource: Resource, payload: JsonDict) -> Any:
"""Deserialize payload into dataclasses using dacite."""
model_class = RESOURCE_TO_MODEL[resource]
return dacite.from_dict(model_class, payload)

0 comments on commit a1e1385

Please sign in to comment.