Skip to content

Commit

Permalink
Added Response.search_after() method (#1829)
Browse files Browse the repository at this point in the history
* Added Response.search_after() method

* add match clause to pytest.raises

(cherry picked from commit 891ba7c)
  • Loading branch information
miguelgrinberg authored and github-actions[bot] committed May 21, 2024
1 parent d0c2c9b commit dcb3d4b
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 0 deletions.
32 changes: 32 additions & 0 deletions elasticsearch_dsl/response/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,38 @@ def aggs(self):
super(AttrDict, self).__setattr__("_aggs", aggs)
return self._aggs

def search_after(self):
    """
    Return a ``Search`` instance that retrieves the next page of results.

    The sort values of the last hit in this response are handed to the
    originating search through its ``search_after`` option, which gives a
    simple way to walk through a long list of results page by page::

        page_size = 20
        s = Search()[:page_size].sort("date")

        while True:
            # fetch one page (drop ``await`` with the synchronous client)
            r = await s.execute()

            # ... process the hits of this page ...

            # a short page means there is nothing left to fetch
            if len(r.hits) < page_size:
                break

            # build the search for the following page
            s = r.search_after()

    The search must have an explicit ``sort`` order, otherwise the hits
    carry no sort values to continue from.

    :raises ValueError: if the response has no hits, or if the last hit's
        metadata has no ``sort`` attribute.
    """
    hits = self.hits
    if len(hits) == 0:
        raise ValueError("Cannot use search_after when there are no search results")

    # Only the final hit matters: its sort values mark where to resume.
    last_meta = hits[-1].meta
    if not hasattr(last_meta, "sort"):
        raise ValueError("Cannot use search_after when results are not sorted")

    return self._search.extra(search_after=last_meta.sort)


class AggResponse(AttrDict):
def __init__(self, aggs, search, data):
Expand Down
30 changes: 30 additions & 0 deletions elasticsearch_dsl/search_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,36 @@ def suggest(self, name, text, **kwargs):
s._suggest[name].update(kwargs)
return s

def search_after(self):
    """
    Return a ``Search`` instance that retrieves the next page of results.

    This is a shortcut that delegates to the ``search_after()`` method of
    the response stored on this object as ``_response``. It allows paging
    through a long list of results with the ``search_after`` option::

        page_size = 20
        s = Search()[:page_size].sort("date")

        while True:
            # fetch one page of results
            r = await s.execute()

            # ... process the hits of this page ...

            # a short page means there is nothing left to fetch
            if len(r.hits) < page_size:
                break

            # build the search for the following page
            s = s.search_after()

    The search must have an explicit ``sort`` order for the
    ``search_after`` option to work.

    :raises ValueError: if this object has no ``_response`` attribute,
        i.e. the search has not been executed yet.
    """
    if hasattr(self, "_response"):
        return self._response.search_after()
    raise ValueError("A search must be executed before using search_after")

def to_dict(self, count=False, **kwargs):
"""
Serialize the search into the dictionary that will be sent over as the
Expand Down
54 changes: 54 additions & 0 deletions tests/test_integration/_async/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,60 @@ async def test_scan_iterates_through_all_docs(async_data_client):
assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits}


@pytest.mark.asyncio
async def test_search_after(async_data_client):
    """Page through the ``flat-git`` index with ``search_after``.

    Pages of 7 hits are fetched until a short page is returned; the
    accumulated hits must then match ``FLAT_DATA`` exactly.
    """
    per_page = 7
    search = AsyncSearch(index="flat-git")[:per_page].sort("authored_date")

    response = await search.execute()
    collected = list(response.hits)
    # A full page means there may be more; a short page ends the scan.
    while len(response.hits) >= per_page:
        search = response.search_after()
        response = await search.execute()
        collected += response.hits

    assert 52 == len(collected)
    expected_ids = {d["_id"] for d in FLAT_DATA}
    seen_ids = {c.meta.id for c in collected}
    assert expected_ids == seen_ids


@pytest.mark.asyncio
async def test_search_after_no_search(async_data_client):
    """``search_after()`` must raise until the search has been executed.

    ``search_after()`` is a regular (non-coroutine) method, so it is called
    without ``await``: awaiting its result would turn a missing exception
    into a confusing ``TypeError`` instead of pytest's "DID NOT RAISE".
    A ``count()`` call in between checks that counting alone does not make
    ``search_after()`` usable.
    """
    s = AsyncSearch(index="flat-git")
    with raises(
        ValueError, match="A search must be executed before using search_after"
    ):
        s.search_after()
    await s.count()
    with raises(
        ValueError, match="A search must be executed before using search_after"
    ):
        s.search_after()


@pytest.mark.asyncio
async def test_search_after_no_sort(async_data_client):
    """``Response.search_after()`` must reject results that were not sorted.

    The search is executed without a ``sort`` clause, so the call is
    expected to raise. ``search_after()`` is a regular (non-coroutine)
    method and is therefore called without ``await``.
    """
    s = AsyncSearch(index="flat-git")
    r = await s.execute()
    with raises(
        ValueError, match="Cannot use search_after when results are not sorted"
    ):
        r.search_after()


@pytest.mark.asyncio
async def test_search_after_no_results(async_data_client):
    """``Response.search_after()`` must raise on an empty page.

    The first page is sized to hold every document (52 are expected), so
    the follow-up page comes back empty and has no last hit to continue
    from. ``search_after()`` is a regular (non-coroutine) method, so neither
    call to it is awaited.
    """
    s = AsyncSearch(index="flat-git")[:100].sort("authored_date")
    r = await s.execute()
    assert 52 == len(r.hits)
    s = r.search_after()
    r = await s.execute()
    assert 0 == len(r.hits)
    with raises(
        ValueError, match="Cannot use search_after when there are no search results"
    ):
        r.search_after()


@pytest.mark.asyncio
async def test_response_is_cached(async_data_client):
s = Repository.search()
Expand Down
54 changes: 54 additions & 0 deletions tests/test_integration/_sync/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,60 @@ def test_scan_iterates_through_all_docs(data_client):
assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits}


@pytest.mark.sync
def test_search_after(data_client):
    """Page through the ``flat-git`` index with ``search_after``.

    Pages of 7 hits are fetched until a short page is returned; the
    accumulated hits must then match ``FLAT_DATA`` exactly.
    """
    per_page = 7
    search = Search(index="flat-git")[:per_page].sort("authored_date")

    response = search.execute()
    collected = list(response.hits)
    # A full page means there may be more; a short page ends the scan.
    while len(response.hits) >= per_page:
        search = response.search_after()
        response = search.execute()
        collected += response.hits

    assert 52 == len(collected)
    expected_ids = {d["_id"] for d in FLAT_DATA}
    seen_ids = {c.meta.id for c in collected}
    assert expected_ids == seen_ids


@pytest.mark.sync
def test_search_after_no_search(data_client):
    """``search_after()`` must raise until the search has been executed.

    The check is repeated after ``count()`` to confirm that counting alone
    does not make ``search_after()`` usable.
    """
    not_executed = "A search must be executed before using search_after"
    search = Search(index="flat-git")

    # Brand-new search: nothing has been executed yet.
    with raises(ValueError, match=not_executed):
        search.search_after()

    # A count request is not an execution either.
    search.count()
    with raises(ValueError, match=not_executed):
        search.search_after()


@pytest.mark.sync
def test_search_after_no_sort(data_client):
    """``Response.search_after()`` must reject results that were not sorted."""
    # No ``sort`` clause on purpose.
    response = Search(index="flat-git").execute()
    not_sorted = "Cannot use search_after when results are not sorted"
    with raises(ValueError, match=not_sorted):
        response.search_after()


@pytest.mark.sync
def test_search_after_no_results(data_client):
    """``Response.search_after()`` must raise on an empty page.

    The first page is sized to hold every document (52 are expected), so
    the follow-up page comes back empty and has no last hit to continue
    from.
    """
    first_page = Search(index="flat-git")[:100].sort("authored_date").execute()
    assert 52 == len(first_page.hits)

    # Everything was already returned, so the next page is empty.
    second_page = first_page.search_after().execute()
    assert 0 == len(second_page.hits)

    no_results = "Cannot use search_after when there are no search results"
    with raises(ValueError, match=no_results):
        second_page.search_after()


@pytest.mark.sync
def test_response_is_cached(data_client):
s = Repository.search()
Expand Down

0 comments on commit dcb3d4b

Please sign in to comment.