diff --git a/cl/api/pagination.py b/cl/api/pagination.py index 9e79185d66..27ee3928d2 100644 --- a/cl/api/pagination.py +++ b/cl/api/pagination.py @@ -44,6 +44,8 @@ class VersionBasedPagination(PageNumberPagination): } ordering = "" cursor_ordering_fields = [] + is_count_request = False + count = 0 def __init__(self): super().__init__() @@ -88,6 +90,14 @@ def paginate_queryset(self, queryset, request, view=None): self.version = request.version self.request = request + self.is_count_request = ( + request.query_params.get("count") == "on" and self.version == "v4" + ) + + if self.is_count_request: + self.count = queryset.count() + return [] + do_cursor_pagination, requested_ordering = ( self.do_v4_cursor_pagination() ) @@ -103,10 +113,18 @@ def paginate_queryset(self, queryset, request, view=None): ) def get_paginated_response(self, data): + if self.is_count_request: + return Response({"count": self.count}) + do_cursor_pagination, _ = self.do_v4_cursor_pagination() if do_cursor_pagination: - # Get paginated response for CursorPagination - return self.cursor_paginator.get_paginated_response(data) + response = self.cursor_paginator.get_paginated_response(data) + # Build and include the count URL: + count_url = self.request.build_absolute_uri() + count_url = replace_query_param(count_url, "count", "on") + response.data["count"] = count_url + response.data.move_to_end("count", last=False) + return response # Get paginated response for PageNumberPagination return super().get_paginated_response(data) diff --git a/cl/api/templates/includes/toc_sidebar.html b/cl/api/templates/includes/toc_sidebar.html index 60a99cd69f..e609bfd201 100644 --- a/cl/api/templates/includes/toc_sidebar.html +++ b/cl/api/templates/includes/toc_sidebar.html @@ -33,6 +33,7 @@
Ordering by fields with duplicate values is non-deterministic. If you wish to order by such a field, you should provide a second field as a tie-breaker to consistently order results. For example, ordering by date_filed
will not return consistent ordering for items that have the same date, but this can be fixed by ordering by date_filed,id
. In that case, if two items have the same date_filed
value, the tie will be broken by the id
field.
To retrieve the total number of items matching your query without fetching all the data, you can use the count=on
parameter. This is useful for verifying filters and understanding the scope of your query results without incurring the overhead of retrieving full datasets.
+
curl "{% get_full_host %}{% url "opinion-list" version="v4" %}?cited_opinion=32239&count=on" + +{"count": 3302}+
When count=on is specified:
- The API returns only the count key with the total number of matching items.
- Pagination parameters such as cursor are ignored.

In standard paginated responses, a count key is included with the URL to obtain the total count for your query:
curl "{% get_full_host %}{% url "opinion-list" version="v4" %}?cited_opinion=32239" + +{ + "count": "https://www.courtlistener.com/api/rest/v4/opinions/?cited_opinion=32239&count=on", + "next": "https://www.courtlistener.com/api/rest/v4/opinions/?cited_opinion=32239&cursor=2", + "previous": null, + "results": [ + // paginated results + ] +}+
You can follow this URL to get the total count of items matching your query.
To save bandwidth and increase serialization performance, fields can be limited by using the fields
parameter with a comma-separated list of fields.
diff --git a/cl/api/tests.py b/cl/api/tests.py
index 94f068b819..e523d46a80 100644
--- a/cl/api/tests.py
+++ b/cl/api/tests.py
@@ -320,6 +320,56 @@ def test_recap_api_required_filter(self, mock_logging_prefix) -> None:
r = self.client.get(path, {"pacer_doc_id__in": "17711118263,asdf"})
self.assertEqual(r.status_code, HTTPStatus.OK)
+ def test_count_on_query_counts(self, mock_logging_prefix) -> None:
+ """
+ Check that a v4 API request with param `count=on` only performs
+ 2 queries to the database: one to check the authenticated user,
+ and another to select the count.
+ """
+ with CaptureQueriesContext(connection) as ctx:
+ path = reverse("docket-list", kwargs={"version": "v4"})
+ params = {"count": "on"}
+ self.client.get(path, params)
+
+ self.assertEqual(
+ len(ctx.captured_queries),
+ 2,
+ msg=f"{len(ctx.captured_queries)} queries executed, 2 expected",
+ )
+
+ executed_queries = [query["sql"] for query in ctx.captured_queries]
+ expected_queries = [
+ 'FROM "auth_user" WHERE "auth_user"."id" =',
+ 'SELECT COUNT(*) AS "__count"',
+ ]
+ for executed_query, expected_fragment in zip(
+ executed_queries, expected_queries
+ ):
+ self.assertIn(
+ expected_fragment,
+ executed_query,
+ msg=f"Expected query fragment not found: {expected_fragment}",
+ )
+
+ def test_standard_request_no_count_query(
+ self, mock_logging_prefix
+ ) -> None:
+ """
+ Check that a v4 API request without param `count=on` doesn't perform
+ a count query.
+ """
+ with CaptureQueriesContext(connection) as ctx:
+ path = reverse("docket-list", kwargs={"version": "v4"})
+ self.client.get(path)
+
+ executed_queries = [query["sql"] for query in ctx.captured_queries]
+ for sql in executed_queries:
+ self.assertNotIn(
+ 'SELECT COUNT(*) AS "__count"',
+ sql,
+ msg="Unexpected COUNT query found in standard request.",
+ )
+
class ApiEventCreationTestCase(TestCase):
"""Check that events are created properly."""
@@ -2775,3 +2825,100 @@ async def test_avoid_logging_not_successful_webhook_events(
self.assertEqual(await webhook_events.acount(), 2)
# Confirm no milestone event should be created.
self.assertEqual(await milestone_events.acount(), 0)
+
+
+class CountParameterTests(TestCase):
+ @classmethod
+ def setUpTestData(cls) -> None:
+ cls.user_1 = UserProfileWithParentsFactory.create(
+ user__username="recap-user",
+ user__password=make_password("password"),
+ )
+ permissions = Permission.objects.filter(
+ codename__in=["has_recap_api_access", "has_recap_upload_access"]
+ )
+ cls.user_1.user.user_permissions.add(*permissions)
+
+ cls.court_canb = CourtFactory(id="canb")
+ cls.court_cand = CourtFactory(id="cand")
+
+ cls.url = reverse("docket-list", kwargs={"version": "v4"})
+
+ for i in range(7):
+ DocketFactory(
+ court=cls.court_canb,
+ source=Docket.RECAP,
+ pacer_case_id=str(100 + i),
+ )
+ for i in range(5):
+ DocketFactory(
+ court=cls.court_cand,
+ source=Docket.HARVARD,
+ pacer_case_id=str(200 + i),
+ )
+
+ def setUp(self):
+ self.client = make_client(self.user_1.user.pk)
+
+ async def test_count_on_returns_only_count(self):
+ """
+ Test that when 'count=on' is specified, the API returns only the count.
+ """
+ params = {"count": "on"}
+ response = await self.client.get(self.url, params)
+
+ self.assertEqual(response.status_code, 200)
+ # The response should only contain the 'count' key
+ self.assertEqual(list(response.data.keys()), ["count"])
+ self.assertIsInstance(response.data["count"], int)
+ # The count should match the total number of dockets
+ expected_count = await Docket.objects.acount()
+ self.assertEqual(response.data["count"], expected_count)
+
+ async def test_standard_response_includes_count_url(self):
+ """
+ Test that the standard response includes a 'count' key with the count URL.
+ """
+ response = await self.client.get(self.url)
+
+ self.assertEqual(response.status_code, 200)
+ self.assertIn("count", response.data)
+ count_url = response.data["count"]
+ self.assertIsInstance(count_url, str)
+ self.assertIn("count=on", count_url)
+
+ async def test_invalid_count_parameter(self):
+ """
+ Test that invalid 'count' parameter values are handled appropriately.
+ """
+ params = {"count": "invalid"}
+ response = await self.client.get(self.url, params)
+
+ self.assertEqual(response.status_code, 200)
+ # The response should be the standard paginated response
+ self.assertIn("results", response.data)
+ self.assertIsInstance(response.data["results"], list)
+
+ async def test_count_with_filters(self):
+ """
+ Test that the count returned matches the filters applied.
+ """
+ params = {"court": "canb", "source": Docket.RECAP, "count": "on"}
+ response = await self.client.get(self.url, params)
+
+ self.assertEqual(response.status_code, 200)
+ expected_count = await Docket.objects.filter(
+ court__id="canb",
+ source=Docket.RECAP,
+ ).acount()
+ self.assertEqual(response.data["count"], expected_count)
+
+ async def test_count_with_no_results(self):
+ """
+ Test that 'count=on' returns zero when no results match the filters.
+ """
+ params = {"court": "cand", "source": Docket.RECAP, "count": "on"}
+ response = await self.client.get(self.url, params)
+
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.data["count"], 0)