Skip to content

Commit

Permalink
fix: use Iterator instead of Iterable types
Browse files Browse the repository at this point in the history
Iterator is a better type for the method return value,
since it correctly type checks the use of `next()`.
Iterable interface doesn't.
  • Loading branch information
Filip Uhlik committed Aug 9, 2024
1 parent c771ec9 commit 77edb52
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 31 deletions.
28 changes: 14 additions & 14 deletions rossum_api/elis_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

if typing.TYPE_CHECKING:
import pathlib
from typing import Any, AsyncIterable, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Sequence, Tuple, Union

import httpx

Expand Down Expand Up @@ -80,7 +80,7 @@ async def list_all_queues(
self,
ordering: Sequence[str] = (),
**filters: Any,
) -> AsyncIterable[Queue]:
) -> AsyncIterator[Queue]:
"""https://elis.rossum.ai/api/docs/#list-all-queues."""
async for q in self._http_client.fetch_all(Resource.Queue, ordering, **filters):
yield self._deserializer(Resource.Queue, q)
Expand Down Expand Up @@ -212,7 +212,7 @@ async def retrieve_upload(
async def export_annotations_to_json(
self,
queue_id: int,
) -> AsyncIterable[Annotation]:
) -> AsyncIterator[Annotation]:
"""https://elis.rossum.ai/api/docs/#export-annotations.
JSON export is paginated and returns the result in a way similar to other list_all methods.
Expand All @@ -223,7 +223,7 @@ async def export_annotations_to_json(

async def export_annotations_to_file(
self, queue_id: int, export_format: ExportFileFormats
) -> AsyncIterable[bytes]:
) -> AsyncIterator[bytes]:
"""https://elis.rossum.ai/api/docs/#export-annotations.
XLSX/CSV/XML exports can be huge, therefore byte streaming is used to keep memory consumption low.
Expand All @@ -236,7 +236,7 @@ async def list_all_organizations(
self,
ordering: Sequence[str] = (),
**filters: Any,
) -> AsyncIterable[Organization]:
) -> AsyncIterator[Organization]:
"""https://elis.rossum.ai/api/docs/#list-all-organizations."""
async for o in self._http_client.fetch_all(Resource.Organization, ordering, **filters):
yield self._deserializer(Resource.Organization, o)
Expand All @@ -258,7 +258,7 @@ async def list_all_schemas(
self,
ordering: Sequence[str] = (),
**filters: Any,
) -> AsyncIterable[Schema]:
) -> AsyncIterator[Schema]:
"""https://elis.rossum.ai/api/docs/#list-all-schemas."""
async for s in self._http_client.fetch_all(Resource.Schema, ordering, **filters):
yield self._deserializer(Resource.Schema, s)
Expand All @@ -284,7 +284,7 @@ async def list_all_users(
self,
ordering: Sequence[str] = (),
**filters: Any,
) -> AsyncIterable[User]:
) -> AsyncIterator[User]:
"""https://elis.rossum.ai/api/docs/#list-all-users."""
async for u in self._http_client.fetch_all(Resource.User, ordering, **filters):
yield self._deserializer(Resource.User, u)
Expand Down Expand Up @@ -316,7 +316,7 @@ async def list_all_annotations(
sideloads: Sequence[str] = (),
content_schema_ids: Sequence[str] = (),
**filters: Any,
) -> AsyncIterable[Annotation]:
) -> AsyncIterator[Annotation]:
"""https://elis.rossum.ai/api/docs/#list-all-annotations."""
if sideloads and "content" in sideloads and not content_schema_ids:
raise ValueError(
Expand All @@ -334,7 +334,7 @@ async def search_for_annotations(
ordering: Sequence[str] = (),
sideloads: Sequence[str] = (),
**kwargs: Any,
) -> AsyncIterable[Annotation]:
) -> AsyncIterator[Annotation]:
"""https://elis.rossum.ai/api/docs/#search-for-annotations."""
if not query and not query_string:
raise ValueError("Either query or query_string must be provided")
Expand Down Expand Up @@ -516,7 +516,7 @@ async def list_all_workspaces(
self,
ordering: Sequence[str] = (),
**filters: Any,
) -> AsyncIterable[Workspace]:
) -> AsyncIterator[Workspace]:
"""https://elis.rossum.ai/api/docs/#list-all-workspaces."""
async for w in self._http_client.fetch_all(Resource.Workspace, ordering, **filters):
yield self._deserializer(Resource.Workspace, w)
Expand Down Expand Up @@ -556,7 +556,7 @@ async def list_all_connectors(
self,
ordering: Sequence[str] = (),
**filters: Any,
) -> AsyncIterable[Connector]:
) -> AsyncIterator[Connector]:
"""https://elis.rossum.ai/api/docs/#list-all-connectors."""
async for c in self._http_client.fetch_all(Resource.Connector, ordering, **filters):
yield self._deserializer(Resource.Connector, c)
Expand All @@ -578,7 +578,7 @@ async def list_all_hooks(
self,
ordering: Sequence[str] = (),
**filters: Any,
) -> AsyncIterable[Hook]:
) -> AsyncIterator[Hook]:
"""https://elis.rossum.ai/api/docs/#list-all-hooks."""
async for h in self._http_client.fetch_all(Resource.Hook, ordering, **filters):
yield self._deserializer(Resource.Hook, h)
Expand All @@ -600,13 +600,13 @@ async def list_all_user_roles(
self,
ordering: Sequence[str] = (),
**filters: Any,
) -> AsyncIterable[Group]:
) -> AsyncIterator[Group]:
"""https://elis.rossum.ai/api/docs/#list-all-user-roles."""
async for g in self._http_client.fetch_all(Resource.Group, ordering, **filters):
yield self._deserializer(Resource.Group, g)

# ##### GENERIC METHODS #####
async def request_paginated(self, url: str, *args, **kwargs) -> AsyncIterable[dict]:
async def request_paginated(self, url: str, *args, **kwargs) -> AsyncIterator[dict]:
"""Use to perform requests to seldomly used or experimental endpoints with paginated response that do not have
direct support in the client and return iterable.
"""
Expand Down
34 changes: 17 additions & 17 deletions rossum_api/elis_api_client_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import pathlib
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(
except RuntimeError:
self.event_loop = asyncio.new_event_loop()

def _iter_over_async(self, ait: AsyncIterable[T]) -> Iterable[T]:
def _iter_over_async(self, ait: AsyncIterator[T]) -> Iterator[T]:
ait = ait.__aiter__()
while True:
try:
Expand All @@ -102,7 +102,7 @@ def list_all_queues(
self,
ordering: Sequence[str] = (),
**filters: Dict[str, Any],
) -> Iterable[Queue]:
) -> Iterator[Queue]:
"""https://elis.rossum.ai/api/docs/#list-all-queues."""
return self._iter_over_async(self.elis_api_client.list_all_queues(ordering, **filters))

Expand Down Expand Up @@ -186,7 +186,7 @@ def retrieve_upload(

return self.event_loop.run_until_complete(self.elis_api_client.retrieve_upload(upload_id))

def export_annotations_to_json(self, queue_id: int) -> Iterable[Annotation]:
def export_annotations_to_json(self, queue_id: int) -> Iterator[Annotation]:
"""https://elis.rossum.ai/api/docs/#export-annotations.
JSON export is paginated and returns the result in a way similar to other list_all methods.
Expand All @@ -195,7 +195,7 @@ def export_annotations_to_json(self, queue_id: int) -> Iterable[Annotation]:

def export_annotations_to_file(
self, queue_id: int, export_format: ExportFileFormats
) -> Iterable[bytes]:
) -> Iterator[bytes]:
"""https://elis.rossum.ai/api/docs/#export-annotations.
XLSX/CSV/XML exports can be huge, therefore byte streaming is used to keep memory consumption low.
Expand All @@ -209,7 +209,7 @@ def list_all_organizations(
self,
ordering: Sequence[str] = (),
**filters: Dict[str, Any],
) -> Iterable[Organization]:
) -> Iterator[Organization]:
"""https://elis.rossum.ai/api/docs/#list-all-organizations."""
return self._iter_over_async(
self.elis_api_client.list_all_organizations(ordering, **filters)
Expand All @@ -233,7 +233,7 @@ def list_all_schemas(
self,
ordering: Sequence[str] = (),
**filters: Dict[str, Any],
) -> Iterable[Schema]:
) -> Iterator[Schema]:
"""https://elis.rossum.ai/api/docs/#list-all-schemas."""
return self._iter_over_async(self.elis_api_client.list_all_schemas(ordering, **filters))

Expand All @@ -259,7 +259,7 @@ def list_all_users(
self,
ordering: Sequence[str] = (),
**filters: Dict[str, Any],
) -> Iterable[User]:
) -> Iterator[User]:
"""https://elis.rossum.ai/api/docs/#list-all-users."""
return self._iter_over_async(self.elis_api_client.list_all_users(ordering, **filters))

Expand All @@ -286,7 +286,7 @@ def list_all_annotations(
sideloads: Sequence[str] = (),
content_schema_ids: Sequence[str] = (),
**filters: Dict[str, Any],
) -> Iterable[Annotation]:
) -> Iterator[Annotation]:
"""https://elis.rossum.ai/api/docs/#list-all-annotations."""
return self._iter_over_async(
self.elis_api_client.list_all_annotations(
Expand All @@ -301,7 +301,7 @@ def search_for_annotations(
ordering: Sequence[str] = (),
sideloads: Sequence[str] = (),
**kwargs: Any,
) -> Iterable[Annotation]:
) -> Iterator[Annotation]:
"""https://elis.rossum.ai/api/docs/internal/#search-for-annotations."""
return self._iter_over_async(
self.elis_api_client.search_for_annotations(
Expand Down Expand Up @@ -433,7 +433,7 @@ def list_all_workspaces(
self,
ordering: Sequence[str] = (),
**filters: Dict[str, Any],
) -> Iterable[Workspace]:
) -> Iterator[Workspace]:
"""https://elis.rossum.ai/api/docs/#list-all-workspaces."""
return self._iter_over_async(self.elis_api_client.list_all_workspaces(ordering, **filters))

Expand Down Expand Up @@ -466,7 +466,7 @@ def list_all_connectors(
self,
ordering: Sequence[str] = (),
**filters: Dict[str, Any],
) -> Iterable[Connector]:
) -> Iterator[Connector]:
"""https://elis.rossum.ai/api/docs/#list-all-connectors."""
return self._iter_over_async(self.elis_api_client.list_all_connectors(ordering, **filters))

Expand All @@ -485,7 +485,7 @@ def list_all_hooks(
self,
ordering: Sequence[str] = (),
**filters: Dict[str, Any],
) -> Iterable[Hook]:
) -> Iterator[Hook]:
"""https://elis.rossum.ai/api/docs/#list-all-hooks."""
return self._iter_over_async(self.elis_api_client.list_all_hooks(ordering, **filters))

Expand All @@ -502,13 +502,13 @@ def list_all_user_roles(
self,
ordering: Sequence[str] = (),
**filters: Dict[str, Any],
) -> Iterable[Group]:
) -> Iterator[Group]:
"""https://elis.rossum.ai/api/docs/#list-all-user-roles."""
return self._iter_over_async(self.elis_api_client.list_all_user_roles(ordering, **filters))

def request_paginated(self, url: str, *args, **kwargs) -> Iterable[dict]:
def request_paginated(self, url: str, *args, **kwargs) -> Iterator[dict]:
"""Use to perform requests to seldomly used or experimental endpoints with paginated response that do not have
direct support in the client and return iterable.
direct support in the client and return Iterator.
"""
return self._iter_over_async(self.elis_api_client.request_paginated(url, *args, **kwargs))

Expand Down

0 comments on commit 77edb52

Please sign in to comment.