diff --git a/elasticsearch_dsl/response/__init__.py b/elasticsearch_dsl/response/__init__.py index 7af054b5b..482400a4d 100644 --- a/elasticsearch_dsl/response/__init__.py +++ b/elasticsearch_dsl/response/__init__.py @@ -90,6 +90,38 @@ def aggs(self): super(AttrDict, self).__setattr__("_aggs", aggs) return self._aggs + def search_after(self): + """ + Return a ``Search`` instance that retrieves the next page of results. + + This method provides an easy way to paginate a long list of results using + the ``search_after`` option. For example:: + + page_size = 20 + s = Search()[:page_size].sort("date") + + while True: + # get a page of results + r = await s.execute() + + # do something with this page of results + + # exit the loop if we reached the end + if len(r.hits) < page_size: + break + + # get a search object with the next page of results + s = r.search_after() + + Note that the ``search_after`` option requires the search to have an + explicit ``sort`` order. + """ + if len(self.hits) == 0: + raise ValueError("Cannot use search_after when there are no search results") + if not hasattr(self.hits[-1].meta, "sort"): + raise ValueError("Cannot use search_after when results are not sorted") + return self._search.extra(search_after=self.hits[-1].meta.sort) + class AggResponse(AttrDict): def __init__(self, aggs, search, data): diff --git a/elasticsearch_dsl/search_base.py b/elasticsearch_dsl/search_base.py index d54b6b925..5680778cb 100644 --- a/elasticsearch_dsl/search_base.py +++ b/elasticsearch_dsl/search_base.py @@ -760,6 +760,36 @@ def suggest(self, name, text, **kwargs): s._suggest[name].update(kwargs) return s + def search_after(self): + """ + Return a ``Search`` instance that retrieves the next page of results. + + This method provides an easy way to paginate a long list of results using + the ``search_after`` option. For example:: + + page_size = 20 + s = Search()[:page_size].sort("date") + + while True: + # get a page of results + r = await s.execute() + + # do something with this page of results + + # exit the loop if we reached the end + if len(r.hits) < page_size: + break + + # get a search object with the next page of results + s = s.search_after() + + Note that the ``search_after`` option requires the search to have an + explicit ``sort`` order. + """ + if not hasattr(self, "_response"): + raise ValueError("A search must be executed before using search_after") + return self._response.search_after() + def to_dict(self, count=False, **kwargs): """ Serialize the search into the dictionary that will be sent over as the diff --git a/tests/test_integration/_async/test_search.py b/tests/test_integration/_async/test_search.py index 7fc56c870..6d6a5ab98 100644 --- a/tests/test_integration/_async/test_search.py +++ b/tests/test_integration/_async/test_search.py @@ -125,6 +125,60 @@ async def test_scan_iterates_through_all_docs(async_data_client): assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits} +@pytest.mark.asyncio +async def test_search_after(async_data_client): + page_size = 7 + s = AsyncSearch(index="flat-git")[:page_size].sort("authored_date") + commits = [] + while True: + r = await s.execute() + commits += r.hits + if len(r.hits) < page_size: + break + s = r.search_after() + + assert 52 == len(commits) + assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits} + + +@pytest.mark.asyncio +async def test_search_after_no_search(async_data_client): + s = AsyncSearch(index="flat-git") + with raises( + ValueError, match="A search must be executed before using search_after" + ): + await s.search_after() + await s.count() + with raises( + ValueError, match="A search must be executed before using search_after" + ): + await s.search_after() + + +@pytest.mark.asyncio +async def test_search_after_no_sort(async_data_client): + s = AsyncSearch(index="flat-git") + r = await s.execute() + with raises( + ValueError, match="Cannot use search_after when results are not sorted" + ): + await r.search_after() + + +@pytest.mark.asyncio +async def test_search_after_no_results(async_data_client): + s = AsyncSearch(index="flat-git")[:100].sort("authored_date") + r = await s.execute() + assert 52 == len(r.hits) + s = r.search_after() + r = await s.execute() + assert 0 == len(r.hits) + with raises( + ValueError, match="Cannot use search_after when there are no search results" + ): + await r.search_after() + + @pytest.mark.asyncio async def test_response_is_cached(async_data_client): s = Repository.search() diff --git a/tests/test_integration/_sync/test_search.py b/tests/test_integration/_sync/test_search.py index b31ef8b3d..09c318369 100644 --- a/tests/test_integration/_sync/test_search.py +++ b/tests/test_integration/_sync/test_search.py @@ -117,6 +117,60 @@ def test_scan_iterates_through_all_docs(data_client): assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits} +@pytest.mark.sync +def test_search_after(data_client): + page_size = 7 + s = Search(index="flat-git")[:page_size].sort("authored_date") + commits = [] + while True: + r = s.execute() + commits += r.hits + if len(r.hits) < page_size: + break + s = r.search_after() + + assert 52 == len(commits) + assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits} + + +@pytest.mark.sync +def test_search_after_no_search(data_client): + s = Search(index="flat-git") + with raises( + ValueError, match="A search must be executed before using search_after" + ): + s.search_after() + s.count() + with raises( + ValueError, match="A search must be executed before using search_after" + ): + s.search_after() + + +@pytest.mark.sync +def test_search_after_no_sort(data_client): + s = Search(index="flat-git") + r = s.execute() + with raises( + ValueError, match="Cannot use search_after when results are not sorted" + ): + r.search_after() + + +@pytest.mark.sync +def test_search_after_no_results(data_client): + s = Search(index="flat-git")[:100].sort("authored_date") + r = s.execute() + assert 52 == len(r.hits) + s = r.search_after() + r = s.execute() + assert 0 == len(r.hits) + with raises( + ValueError, match="Cannot use search_after when there are no search results" + ): + r.search_after() + + @pytest.mark.sync def test_response_is_cached(data_client): s = Repository.search()