diff --git a/elasticsearch_dsl/search.py b/elasticsearch_dsl/search.py index 8a5efd712..4c7a6036a 100644 --- a/elasticsearch_dsl/search.py +++ b/elasticsearch_dsl/search.py @@ -666,8 +666,8 @@ def count(self): Return the number of hits matching the query and filters. Note that only the actual number is returned. """ - if hasattr(self, '_response'): - return self._response.hits.total + if hasattr(self, '_response') and self._response.hits.total.relation == 'eq': + return self._response.hits.total.value es = connections.get_connection(self._using) diff --git a/setup.py b/setup.py index 907c5fe16..9de16c359 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ "mock", "pytest>=3.0.0", "pytest-cov", + "pytest-mock", "pytz", "coverage<5.0.0" ] diff --git a/test_elasticsearch_dsl/test_integration/test_count.py b/test_elasticsearch_dsl/test_integration/test_count.py index ce8efa0fa..2326b174a 100644 --- a/test_elasticsearch_dsl/test_integration/test_count.py +++ b/test_elasticsearch_dsl/test_integration/test_count.py @@ -1,9 +1,24 @@ from elasticsearch_dsl.search import Search, Q + def test_count_all(data_client): s = Search(using=data_client).index('git') assert 53 == s.count() + +def test_count_prefetch(data_client, mocker): + mocker.spy(data_client, 'count') + + search = Search(using=data_client).index('git') + search.execute() + assert search.count() == 53 + assert data_client.count.call_count == 0 + + search._response.hits.total.relation = 'gte' + assert search.count() == 53 + assert data_client.count.call_count == 1 + + def test_count_filter(data_client): s = Search(using=data_client).index('git').filter(~Q('exists', field='parent_shas')) # initial commit + repo document diff --git a/test_elasticsearch_dsl/test_search.py b/test_elasticsearch_dsl/test_search.py index 40a28dd98..592a932d4 100644 --- a/test_elasticsearch_dsl/test_search.py +++ b/test_elasticsearch_dsl/test_search.py @@ -35,12 +35,6 @@ def test_iter_iterates_over_hits(): assert [1, 2, 3] == list(s) -def test_count_uses_cache(): - s = search.Search() - s._response = utils.AttrDict({'hits': {'total': 42}}) - - assert 42 == s.count() - def test_cache_isnt_cloned(): s = search.Search() s._response = object() @@ -544,4 +538,4 @@ def test_update_from_dict(): 'id', 'name' ] - } == s.to_dict() \ No newline at end of file + } == s.to_dict()