From e433f858bbca35c7cf3211a7dabbbea921ec431d Mon Sep 17 00:00:00 2001 From: Ondrej Slama Date: Thu, 24 Aug 2023 15:52:35 +0200 Subject: [PATCH 1/2] feat(api_client): allow passing json argument in `fetch_all` as request payload --- rossum_api/api_client.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/rossum_api/api_client.py b/rossum_api/api_client.py index 99dc2bd..74af213 100644 --- a/rossum_api/api_client.py +++ b/rossum_api/api_client.py @@ -136,6 +136,7 @@ async def fetch_all( content_schema_ids: Sequence[str] = (), method: str = "GET", max_pages: Optional[int] = None, + json: Optional[dict] = None, **filters: Any, ) -> AsyncIterator[Dict[str, Any]]: """Retrieve a list of objects in a specific resource. @@ -157,6 +158,8 @@ async def fetch_all( method so that export() can re-use fetch_all() implementation max_pages maximum number of pages to fetch + json + json payload sent with the request. Used for POST requests. filters mapping from resource field to value used to filter records """ @@ -168,13 +171,15 @@ async def fetch_all( **filters, } results, total_pages = await self._fetch_page( - f"{resource}", method, query_params, sideloads + f"{resource}", method, query_params, sideloads, json=json ) # Fire async tasks to fetch the rest of the pages and start yielding results from page 1 last_page = min(total_pages, max_pages or total_pages) page_requests = [ asyncio.create_task( - self._fetch_page(f"{resource}", method, {**query_params, "page": i}, sideloads) + self._fetch_page( + f"{resource}", method, {**query_params, "page": i}, sideloads, json=json + ) ) for i in range(2, last_page + 1) ] @@ -193,8 +198,9 @@ async def _fetch_page( method: str, query_params: Dict[str, Any], sideload_groups: Sequence[str], + json: Optional[dict] = None, ) -> Tuple[List[Dict[str, Any]], int]: - data = await self.request_json(method, resource, params=query_params) + data = await self.request_json(method, resource, params=query_params, 
json=json) self._embed_sideloads(data, sideload_groups) return data["results"], data["pagination"]["total_pages"] @@ -305,7 +311,7 @@ async def export( if export_format == "json": # JSON export is paginated just like a regular fetch_all, it abuses **filters kwargs of # fetch_all to pass export-specific query params - async for result in self.fetch_all(url, method=method, max_pages=None, **query_params): + async for result in self.fetch_all(url, method=method, **query_params): # type: ignore yield result else: # In CSV/XML/XLSX case, all annotations are returned, i.e. the response can be large, From 34b3f7249c149d0fa2da6fa1a171b92f0a0a2ffa Mon Sep 17 00:00:00 2001 From: Ondrej Slama Date: Thu, 24 Aug 2023 15:53:24 +0200 Subject: [PATCH 2/2] feat(elis_api_clients): add new `search_for_annotations` method --- rossum_api/elis_api_client.py | 22 +++++++++++++++ rossum_api/elis_api_client_sync.py | 15 ++++++++++ tests/elis_api_client/test_annotations.py | 34 +++++++++++++++++++++++ 3 files changed, 71 insertions(+) diff --git a/rossum_api/elis_api_client.py b/rossum_api/elis_api_client.py index 18a3a0c..d2d8b76 100644 --- a/rossum_api/elis_api_client.py +++ b/rossum_api/elis_api_client.py @@ -232,6 +232,28 @@ async def list_all_annotations( ): yield dacite.from_dict(Annotation, a) + async def search_for_annotations( + self, + query: Optional[dict] = None, + query_string: Optional[dict] = None, + ordering: Sequence[str] = (), + sideloads: Sequence[str] = (), + **kwargs: Any, + ) -> AsyncIterable[Annotation]: + """https://elis.rossum.ai/api/docs/#search-for-annotations.""" + if not query and not query_string: + raise ValueError("Either query or query_string must be provided") + json_payload = {} + if query: + json_payload["query"] = query + if query_string: + json_payload["query_string"] = query_string + + async for a in self._http_client.fetch_all( + "annotations/search", ordering, sideloads, json=json_payload, method="POST", **kwargs + ): + yield 
dacite.from_dict(Annotation, a) + async def retrieve_annotation( self, annotation_id: int, sideloads: Sequence[str] = () ) -> Annotation: diff --git a/rossum_api/elis_api_client_sync.py b/rossum_api/elis_api_client_sync.py index 48be7c3..1e6950a 100644 --- a/rossum_api/elis_api_client_sync.py +++ b/rossum_api/elis_api_client_sync.py @@ -230,6 +230,21 @@ def list_all_annotations( ) ) + def search_for_annotations( + self, + query: Optional[dict] = None, + query_string: Optional[dict] = None, + ordering: Sequence[str] = (), + sideloads: Sequence[str] = (), + **kwargs: Any, + ) -> Iterable[Annotation]: + """https://elis.rossum.ai/api/docs/#search-for-annotations.""" + return self._iter_over_async( + self.elis_api_client.search_for_annotations( + query, query_string, ordering, sideloads, **kwargs + ) + ) + def retrieve_annotation(self, annotation_id: int, sideloads: Sequence[str] = ()) -> Annotation: """https://elis.rossum.ai/api/docs/#retrieve-an-annotation.""" return self.event_loop.run_until_complete( diff --git a/tests/elis_api_client/test_annotations.py b/tests/elis_api_client/test_annotations.py index c71274c..64c2fbe 100644 --- a/tests/elis_api_client/test_annotations.py +++ b/tests/elis_api_client/test_annotations.py @@ -178,6 +178,23 @@ async def test_list_all_annotations_with_content_sideloads_without_schema_ids( assert not http_client.fetch_all.called + async def test_search_for_annotations(self, elis_client, dummy_annotation, mock_generator): + client, http_client = elis_client + http_client.fetch_all.return_value = mock_generator(dummy_annotation) + + annotations = client.search_for_annotations({"$and": []}, {"string": "expl"}) + + async for a in annotations: + assert a == Annotation(**dummy_annotation) + + http_client.fetch_all.assert_called_with( + "annotations/search", + (), + (), + json={"query": {"$and": []}, "query_string": {"string": "expl"}}, + method="POST", + ) + async def test_retrieve_annotation(self, elis_client, dummy_annotation):
client, http_client = elis_client http_client.fetch_one.return_value = dummy_annotation @@ -312,6 +329,23 @@ def test_list_all_annotations_with_content_sideloads_without_schema_ids( assert not http_client.fetch_all.called + def test_search_for_annotations(self, elis_client_sync, dummy_annotation, mock_generator): + client, http_client = elis_client_sync + http_client.fetch_all.return_value = mock_generator(dummy_annotation) + + annotations = client.search_for_annotations({"$and": []}, {"string": "expl"}) + + for a in annotations: + assert a == Annotation(**dummy_annotation) + + http_client.fetch_all.assert_called_with( + "annotations/search", + (), + (), + json={"query": {"$and": []}, "query_string": {"string": "expl"}}, + method="POST", + ) + def test_retrieve_annotation(self, elis_client_sync, dummy_annotation): client, http_client = elis_client_sync http_client.fetch_one.return_value = dummy_annotation