Skip to content

Commit

Permalink
Merge pull request #49 from OndraSlama/ondra-add-search-for-query-method
Browse files Browse the repository at this point in the history
Add search_for_annotations method
  • Loading branch information
lbenka authored Aug 25, 2023
2 parents cba35f7 + 34b3f72 commit b994053
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 4 deletions.
14 changes: 10 additions & 4 deletions rossum_api/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ async def fetch_all(
content_schema_ids: Sequence[str] = (),
method: str = "GET",
max_pages: Optional[int] = None,
json: Optional[dict] = None,
**filters: Any,
) -> AsyncIterator[Dict[str, Any]]:
"""Retrieve a list of objects in a specific resource.
Expand All @@ -157,6 +158,8 @@ async def fetch_all(
method so that export() can re-use fetch_all() implementation
max_pages
maximum number of pages to fetch
json
json payload sent with the request. Used for POST requests.
filters
mapping from resource field to value used to filter records
"""
Expand All @@ -168,13 +171,15 @@ async def fetch_all(
**filters,
}
results, total_pages = await self._fetch_page(
f"{resource}", method, query_params, sideloads
f"{resource}", method, query_params, sideloads, json=json
)
# Fire async tasks to fetch the rest of the pages and start yielding results from page 1
last_page = min(total_pages, max_pages or total_pages)
page_requests = [
asyncio.create_task(
self._fetch_page(f"{resource}", method, {**query_params, "page": i}, sideloads)
self._fetch_page(
f"{resource}", method, {**query_params, "page": i}, sideloads, json=json
)
)
for i in range(2, last_page + 1)
]
Expand All @@ -193,8 +198,9 @@ async def _fetch_page(
method: str,
query_params: Dict[str, Any],
sideload_groups: Sequence[str],
json: Optional[dict] = None,
) -> Tuple[List[Dict[str, Any]], int]:
data = await self.request_json(method, resource, params=query_params)
data = await self.request_json(method, resource, params=query_params, json=json)
self._embed_sideloads(data, sideload_groups)
return data["results"], data["pagination"]["total_pages"]

Expand Down Expand Up @@ -305,7 +311,7 @@ async def export(
if export_format == "json":
# JSON export is paginated just like a regular fetch_all, it abuses **filters kwargs of
# fetch_all to pass export-specific query params
async for result in self.fetch_all(url, method=method, max_pages=None, **query_params):
async for result in self.fetch_all(url, method=method, **query_params): # type: ignore
yield result
else:
# In CSV/XML/XLSX case, all annotations are returned, i.e. the response can be large,
Expand Down
22 changes: 22 additions & 0 deletions rossum_api/elis_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,28 @@ async def list_all_annotations(
):
yield dacite.from_dict(Annotation, a)

async def search_for_annotations(
    self,
    query: Optional[dict] = None,
    query_string: Optional[dict] = None,
    ordering: Sequence[str] = (),
    sideloads: Sequence[str] = (),
    **kwargs: Any,
) -> AsyncIterable[Annotation]:
    """Search annotations using a structured query and/or a query string.

    See https://elis.rossum.ai/api/docs/#search-for-annotations.

    Parameters
    ----------
    query
        structured search query, sent under the ``"query"`` key of the JSON body
    query_string
        query string object, sent under the ``"query_string"`` key of the JSON body
    ordering
        passed positionally to ``fetch_all`` as its ordering argument
    sideloads
        passed positionally to ``fetch_all`` as its sideloads argument
    kwargs
        forwarded unchanged to ``fetch_all``

    Raises
    ------
    ValueError
        if both ``query`` and ``query_string`` are missing or empty. Because this is an
        async generator, the error surfaces when iteration starts, not at call time.
    """
    if not (query or query_string):
        raise ValueError("Either query or query_string must be provided")

    # Only the criteria that were actually supplied end up in the request body.
    criteria = (("query", query), ("query_string", query_string))
    search_body = {key: value for key, value in criteria if value}

    matches = self._http_client.fetch_all(
        "annotations/search", ordering, sideloads, json=search_body, method="POST", **kwargs
    )
    async for raw_annotation in matches:
        yield dacite.from_dict(Annotation, raw_annotation)

async def retrieve_annotation(
self, annotation_id: int, sideloads: Sequence[str] = ()
) -> Annotation:
Expand Down
15 changes: 15 additions & 0 deletions rossum_api/elis_api_client_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,21 @@ def list_all_annotations(
)
)

def search_for_annotations(
    self,
    query: Optional[dict] = None,
    query_string: Optional[dict] = None,
    ordering: Sequence[str] = (),
    sideloads: Sequence[str] = (),
    **kwargs: Any,
) -> Iterable[Annotation]:
    """Search annotations using a structured query and/or a query string.

    See https://elis.rossum.ai/api/docs/#search-for-annotations.

    Synchronous wrapper: every argument is handed over unchanged to the async
    client's ``search_for_annotations`` and the resulting async iterator is
    adapted through ``_iter_over_async``.
    """
    return self._iter_over_async(
        self.elis_api_client.search_for_annotations(
            query, query_string, ordering, sideloads, **kwargs
        )
    )

def retrieve_annotation(self, annotation_id: int, sideloads: Sequence[str] = ()) -> Annotation:
"""https://elis.rossum.ai/api/docs/#retrieve-an-annotation."""
return self.event_loop.run_until_complete(
Expand Down
34 changes: 34 additions & 0 deletions tests/elis_api_client/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,23 @@ async def test_list_all_annotations_with_content_sideloads_without_schema_ids(

assert not http_client.fetch_all.called

async def test_search_for_annotations(self, elis_client, dummy_annotation, mock_generator):
    """Searching yields Annotation objects and POSTs both criteria to annotations/search."""
    client, http_client = elis_client
    http_client.fetch_all.return_value = mock_generator(dummy_annotation)
    search_query = {"$and": []}
    search_string = {"string": "expl"}

    async for found in client.search_for_annotations(search_query, search_string):
        assert found == Annotation(**dummy_annotation)

    # Ordering and sideloads default to empty tuples; the criteria travel in the JSON body.
    http_client.fetch_all.assert_called_with(
        "annotations/search",
        (),
        (),
        json={"query": search_query, "query_string": search_string},
        method="POST",
    )

async def test_retrieve_annotation(self, elis_client, dummy_annotation):
client, http_client = elis_client
http_client.fetch_one.return_value = dummy_annotation
Expand Down Expand Up @@ -312,6 +329,23 @@ def test_list_all_annotations_with_content_sideloads_without_schema_ids(

assert not http_client.fetch_all.called

def test_search_for_annotations(self, elis_client_sync, dummy_annotation, mock_generator):
    """The sync client yields Annotation objects and POSTs both criteria to annotations/search."""
    client, http_client = elis_client_sync
    http_client.fetch_all.return_value = mock_generator(dummy_annotation)
    search_query = {"$and": []}
    search_string = {"string": "expl"}

    for found in client.search_for_annotations(search_query, search_string):
        assert found == Annotation(**dummy_annotation)

    # Ordering and sideloads default to empty tuples; the criteria travel in the JSON body.
    http_client.fetch_all.assert_called_with(
        "annotations/search",
        (),
        (),
        json={"query": search_query, "query_string": search_string},
        method="POST",
    )

def test_retrieve_annotation(self, elis_client_sync, dummy_annotation):
client, http_client = elis_client_sync
http_client.fetch_one.return_value = dummy_annotation
Expand Down

0 comments on commit b994053

Please sign in to comment.