From de77e610c1801fde6d52716dd1a1b46bb1c12903 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Tue, 6 Nov 2018 11:40:02 +0100 Subject: [PATCH] Add the ability to set the number of hits to track accurately In Lucene 8 searches can skip non-competitive hits if the total hit count is not requested. It is also possible to track the number of hits up to a certain threshold. This is a trade off to speed up searches while still being able to know a lower bound of the total hit count. This change adds the ability to set this threshold directly in the `track_total_hits` search option. A boolean value (`true`, `false`) indicates whether the total hit count should be tracked in the response. When set as an integer this option allows to compute a lower bound of the total hits while preserving the ability to skip non-competitive hits when enough hits have been collected. In order to ensure that the result is correctly interpreted this commit also adds a new section in the search response that indicates the number of tracked hits and whether the value is a lower bound (`gte`) or the exact count (`eq`): ``` GET /_search { "track_total_hits": 100, "query": { "term": { "title": "fast" } } } ``` ... will return: ``` { "_shards": ... 
"hits" : { "total" : -1, "tracked_total": { "value": 100, "relation": "gte" }, "max_score" : 0.42, "hits" : [] } } ``` Relates #33028 --- .../query-dsl/feature-query.asciidoc | 2 +- .../search/request/track-total-hits.asciidoc | 127 ++++++++++++++++++ docs/reference/search/uri-request.asciidoc | 5 +- .../test/search/230_track_total_hits.yml | 90 +++++++++++++ .../search/AbstractSearchAsyncAction.java | 7 +- .../action/search/SearchPhaseController.java | 62 +++++---- .../action/search/SearchRequestBuilder.java | 8 ++ .../rest/action/search/RestSearchAction.java | 8 +- .../search/DefaultSearchContext.java | 6 +- .../org/elasticsearch/search/SearchHits.java | 74 ++++++---- .../elasticsearch/search/SearchService.java | 2 +- .../search/TotalHitsWrapper.java | 103 ++++++++++++++ .../metrics/TopHitsAggregator.java | 2 +- .../search/builder/SearchSourceBuilder.java | 57 +++++--- .../search/fetch/FetchPhase.java | 4 +- .../internal/FilteredSearchContext.java | 4 +- .../internal/InternalSearchResponse.java | 7 +- .../search/internal/SearchContext.java | 5 +- .../search/query/QueryPhase.java | 4 +- .../search/query/TopDocsCollectorContext.java | 23 ++-- .../search/SearchPhaseControllerTests.java | 4 +- .../elasticsearch/search/SearchHitsTests.java | 6 +- .../search/query/QueryPhaseTests.java | 2 +- .../search/RandomSearchRequestGenerator.java | 6 +- .../elasticsearch/test/TestSearchContext.java | 6 +- 25 files changed, 515 insertions(+), 109 deletions(-) create mode 100644 docs/reference/search/request/track-total-hits.asciidoc create mode 100644 rest-api-spec/src/main/resources/rest-api-spec/test/search/230_track_total_hits.yml create mode 100644 server/src/main/java/org/elasticsearch/search/TotalHitsWrapper.java diff --git a/docs/reference/query-dsl/feature-query.asciidoc b/docs/reference/query-dsl/feature-query.asciidoc index 387278a432af6..6b60b95a05893 100644 --- a/docs/reference/query-dsl/feature-query.asciidoc +++ b/docs/reference/query-dsl/feature-query.asciidoc @@ 
-11,7 +11,7 @@ of the query. Compared to using <> or other ways to modify the score, this query has the benefit of being able to efficiently skip non-competitive hits when -<> is set to `false`. Speedups may be +<<search-request-track-total-hits>> is set to `false`. Speedups may be spectacular. Here is an example that indexes various features: diff --git a/docs/reference/search/request/track-total-hits.asciidoc b/docs/reference/search/request/track-total-hits.asciidoc new file mode 100644 index 0000000000000..179160a5d431d --- /dev/null +++ b/docs/reference/search/request/track-total-hits.asciidoc @@ -0,0 +1,127 @@ +[[search-request-track-total-hits]] +=== Track total hits + +The `track_total_hits` parameter allows you to configure the number of hits to +count accurately. +When set to `true` the search response will contain the total number of hits +that match the query: + +[source,js] +-------------------------------------------------- +GET /_search +{ + "track_total_hits": true, + "query" : { + "match_all" : {} + } +} +-------------------------------------------------- +// CONSOLE + +\... returns: + +[source,js] +-------------------------------------------------- +{ + "_shards": ... + "hits" : { + "total" : 2048, <1> + "max_score" : 1.0, + "hits" : [] + } +} +-------------------------------------------------- +// TESTRESPONSE[s/"_shards": \.\.\./"_shards": "$body._shards",/] +// TESTRESPONSE[s/"total" : 2048/"total": $body.hits.total/] + +<1> The total number of hits that match the query. + +If you don't need to track the total number of hits you can set this option +to `false`. In that case the total number of hits is unknown and the search +can efficiently skip non-competitive hits if the query is sorted by relevancy: + +[source,js] +-------------------------------------------------- +GET /_search +{ + "track_total_hits": false, + "query": { + "term": { + "title": "fast" + } + } +} +-------------------------------------------------- +// CONSOLE + +\...
returns: + +[source,js] +-------------------------------------------------- +{ + "_shards": ... + "hits" : { + "total" : -1, <1> + "max_score" : 0.42, + "hits" : [] + } +} +-------------------------------------------------- +// TESTRESPONSE[s/"_shards": \.\.\./"_shards": "$body._shards",/] +// TESTRESPONSE[s/"max_score" : 0\.42/"max_score": $body.hits.max_score/] + +<1> The total number of hits is unknown. + +The total hit count can't be computed accurately without visiting all matches, +which is costly for queries that match lots of documents. Given that it is +often enough to have a lower bound on the number of hits, such as +"there are at least 1000 hits", it is also possible to set `track_total_hits` +to an integer that represents the number of hits to count accurately. When this +option is set to a number the search response will contain a new section called +`tracked_total` that contains the number of tracked hits (`tracked_total.value`) +and a relation (`tracked_total.relation`) that indicates if the `value` is + accurate (`eq`) or a lower bound of the total hit count (`gte`): + +[source,js] +-------------------------------------------------- +GET /_search +{ + "track_total_hits": 100, + "query": { + "term": { + "title": "fast" + } + } +} +-------------------------------------------------- +// CONSOLE + +\... returns: + +[source,js] +-------------------------------------------------- +{ + "_shards": ... + "hits" : { + "total" : -1, <1> + "tracked_total": { <2> + "value": 100, + "relation": "gte" + }, + "max_score" : 0.42, + "hits" : [] + } +} +-------------------------------------------------- +// TESTRESPONSE[s/"_shards": \.\.\./"_shards": "$body._shards",/] +// TESTRESPONSE[s/"max_score" : 0\.42/"max_score": $body.hits.max_score/] +// TESTRESPONSE[s/"value": 100/"value": $body.hits.tracked_total.value/] +// TESTRESPONSE[s/"relation": "gte"/"relation": "$body.hits.tracked_total.relation"/] + +<1> The total number of hits is unknown.
+<2> There are at least (`gte`) 100 documents that match the query. + +Search can also skip non-competitive hits if the query is sorted by +relevancy but the optimization kicks in only after collecting at least +`track_total_hits` documents. This is a good trade-off to speed up searches +if you don't need an accurate number of hits beyond a certain threshold. \ No newline at end of file diff --git a/docs/reference/search/uri-request.asciidoc b/docs/reference/search/uri-request.asciidoc index bfc50e774bff8..c74285ddd1941 100644 --- a/docs/reference/search/uri-request.asciidoc +++ b/docs/reference/search/uri-request.asciidoc @@ -100,8 +100,11 @@ scores and return them as part of each hit. |`track_total_hits` |Set to `false` in order to disable the tracking of the total number of hits that match the query. -(see <> for more details). Defaults to true. +It also accepts an integer, which then represents the number of hits +to count accurately. +(see the <<search-request-track-total-hits>> documentation +for more details).
|`timeout` |A search timeout, bounding the search request to be executed within the specified time value and bail with the hits accumulated up to diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/search/230_track_total_hits.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/search/230_track_total_hits.yml new file mode 100644 index 0000000000000..3f8bdb7d88e7b --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/search/230_track_total_hits.yml @@ -0,0 +1,90 @@ +--- +"Track total hits": + + - skip: + version: " - 6.99.99" + reason: track_total_hits was introduced in 7.0.0 + + - do: + search: + index: test_1 + track_total_hits: false + + - match: { hits.total: -1 } + - is_false: "hits.tracked_total" + + - do: + search: + index: test_1 + track_total_hits: true + + - match: { hits.total: 0 } + - is_false: "hits.tracked_total" + + - do: + search: + index: test_1 + track_total_hits: 10 + + - match: { hits.total: -1 } + - match: { hits.tracked_total.value: 0 } + - match: { hits.tracked_total.relation: "eq" } + + - do: + index: + index: test_1 + id: 1 + body: {} + + - do: + index: + index: test_1 + id: 2 + body: {} + + - do: + index: + index: test_1 + id: 3 + body: {} + + - do: + index: + index: test_1 + id: 4 + body: {} + + - do: + indices.refresh: {} + + - do: + search: + index: test_1 + + - match: { hits.total: 4 } + + - do: + search: + index: test_1 + track_total_hits: false + + - match: { hits.total: -1 } + - is_false: "hits.tracked_total" + + - do: + search: + index: test_1 + track_total_hits: 10 + + - match: { hits.total: -1 } + - match: { hits.tracked_total.value: 4 } + - match: { hits.tracked_total.relation: "eq" } + + - do: + search: + index: test_1 + track_total_hits: 3 + + - match: { hits.total: -1 } + - match: { hits.tracked_total.value: 3 } + - match: { hits.tracked_total.relation: "gte" } \ No newline at end of file diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java 
b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java index 8aa847f753682..217c3ad8a5046 100644 --- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java @@ -34,6 +34,7 @@ import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.search.internal.InternalSearchResponse; +import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.ShardSearchTransportRequest; import org.elasticsearch.transport.Transport; @@ -113,8 +114,10 @@ public final void start() { if (getNumShards() == 0) { //no search shards to search on, bail with empty response //(it happens with search across _all with no indices around and consistent with broadcast operations) - listener.onResponse(new SearchResponse(InternalSearchResponse.empty(), null, 0, 0, 0, buildTookInMillis(), - ShardSearchFailure.EMPTY_ARRAY, clusters)); + int trackTotalHitsThreshold = request.source() != null ? 
+ request.source().trackTotalHitsThreshold() : SearchContext.DEFAULT_TRACK_TOTAL_HITS; + listener.onResponse(new SearchResponse(InternalSearchResponse.empty(trackTotalHitsThreshold), null, 0, 0, 0, + buildTookInMillis(), ShardSearchFailure.EMPTY_ARRAY, clusters)); return; } executePhase(this); diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index 2c24d2852217e..51b6e40024929 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -49,6 +49,7 @@ import org.elasticsearch.search.dfs.DfsSearchResult; import org.elasticsearch.search.fetch.FetchSearchResult; import org.elasticsearch.search.internal.InternalSearchResponse; +import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.profile.ProfileShardResult; import org.elasticsearch.search.profile.SearchProfileShardResults; import org.elasticsearch.search.query.QuerySearchResult; @@ -310,7 +311,7 @@ public IntArrayList[] fillDocIdsToLoad(int numShards, ScoreDoc[] shardDocs) { public InternalSearchResponse merge(boolean ignoreFrom, ReducedQueryPhase reducedQueryPhase, Collection fetchResults, IntFunction resultsLookup) { if (reducedQueryPhase.isEmptyResult) { - return InternalSearchResponse.empty(); + return InternalSearchResponse.empty(reducedQueryPhase.trackTotalHitsThreshold); } ScoreDoc[] sortedDocs = reducedQueryPhase.scoreDocs; SearchHits hits = getHits(reducedQueryPhase, ignoreFrom, fetchResults, resultsLookup); @@ -400,8 +401,8 @@ private SearchHits getHits(ReducedQueryPhase reducedQueryPhase, boolean ignoreFr hits.add(searchHit); } } - return new SearchHits(hits.toArray(new SearchHit[hits.size()]), reducedQueryPhase.totalHits, - reducedQueryPhase.maxScore); + return new SearchHits(hits.toArray(new SearchHit[hits.size()]), 
reducedQueryPhase.totalHits, reducedQueryPhase.maxScore, + reducedQueryPhase.trackedHits); } /** @@ -409,14 +410,14 @@ private SearchHits getHits(ReducedQueryPhase reducedQueryPhase, boolean ignoreFr * @param queryResults a list of non-null query shard results */ public ReducedQueryPhase reducedQueryPhase(Collection queryResults, boolean isScrollRequest) { - return reducedQueryPhase(queryResults, isScrollRequest, true); + return reducedQueryPhase(queryResults, isScrollRequest, SearchContext.DEFAULT_TRACK_TOTAL_HITS); } /** * Reduces the given query results and consumes all aggregations and profile results. * @param queryResults a list of non-null query shard results */ - public ReducedQueryPhase reducedQueryPhase(Collection queryResults, boolean isScrollRequest, boolean trackTotalHits) { + public ReducedQueryPhase reducedQueryPhase(Collection queryResults, boolean isScrollRequest, int trackTotalHits) { return reducedQueryPhase(queryResults, null, new ArrayList<>(), new TopDocsStats(trackTotalHits), 0, isScrollRequest); } @@ -440,8 +441,8 @@ private ReducedQueryPhase reducedQueryPhase(Collection aggregationsL } public static final class ReducedQueryPhase { + final int trackTotalHitsThreshold; // the sum of all hits across all reduces shards final long totalHits; + // the number of tracked hits across all reduces shards + final TotalHits trackedHits; // the number of returned hits (doc IDs) across all reduces shards final long fetchHits; // the max score across all reduces hits or {@link Float#NaN} if no hits returned @@ -571,19 +574,32 @@ public static final class ReducedQueryPhase { // sort value formats used to sort / format the result final DocValueFormat[] sortValueFormats; - ReducedQueryPhase(long totalHits, long fetchHits, float maxScore, boolean timedOut, Boolean terminatedEarly, Suggest suggest, + ReducedQueryPhase(TopDocsStats stats, boolean timedOut, Boolean terminatedEarly, Suggest suggest, InternalAggregations aggregations, SearchProfileShardResults 
shardResults, ScoreDoc[] scoreDocs, SortField[] sortFields, DocValueFormat[] sortValueFormats, int numReducePhases, boolean isSortedByField, int size, int from, boolean isEmptyResult) { if (numReducePhases <= 0) { throw new IllegalArgumentException("at least one reduce phase must have been applied but was: " + numReducePhases); } - this.totalHits = totalHits; - this.fetchHits = fetchHits; - if (Float.isInfinite(maxScore)) { + this.trackTotalHitsThreshold = stats.trackTotalHitsThreshold; + if (stats.trackTotalHitsThreshold == 0) { + this.totalHits = -1; + this.trackedHits = null; + } else if (stats.trackTotalHitsThreshold != Integer.MAX_VALUE) { + final long total = Math.min(stats.trackTotalHitsThreshold, stats.totalHits); + final TotalHits.Relation relation = total == stats.totalHits ? Relation.EQUAL_TO : Relation.GREATER_THAN_OR_EQUAL_TO; + this.trackedHits = stats.trackTotalHitsThreshold > 0 ? new TotalHits(total, relation) : null; + this.totalHits = -1; + } else { + assert stats.totalHitsRelation == Relation.EQUAL_TO; + this.totalHits = stats.totalHits; + this.trackedHits = null; + } + this.fetchHits = stats.fetchHits; + if (Float.isInfinite(stats.maxScore)) { this.maxScore = Float.NaN; } else { - this.maxScore = maxScore; + this.maxScore = stats.maxScore; } this.timedOut = timedOut; this.terminatedEarly = terminatedEarly; @@ -724,7 +740,7 @@ InitialSearchPhase.ArraySearchPhaseResults newSearchPhaseResu boolean isScrollRequest = request.scroll() != null; final boolean hasAggs = source != null && source.aggregations() != null; final boolean hasTopDocs = source == null || source.size() != 0; - final boolean trackTotalHits = source == null || source.trackTotalHits(); + final int trackTotalHits = source == null ? SearchContext.DEFAULT_TRACK_TOTAL_HITS : source.trackTotalHitsThreshold(); if (isScrollRequest == false && (hasAggs || hasTopDocs)) { // no incremental reduce if scroll is used - we only hit a single shard or sometimes more... 
@@ -742,23 +758,23 @@ public ReducedQueryPhase reduce() { } static final class TopDocsStats { - final boolean trackTotalHits; + final int trackTotalHitsThreshold; long totalHits; TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO; long fetchHits; float maxScore = Float.NEGATIVE_INFINITY; TopDocsStats() { - this(true); + this(SearchContext.DEFAULT_TRACK_TOTAL_HITS); } - TopDocsStats(boolean trackTotalHits) { - this.trackTotalHits = trackTotalHits; - this.totalHits = trackTotalHits ? 0 : -1; + TopDocsStats(int trackTotalHitsThreshold) { + this.trackTotalHitsThreshold = trackTotalHitsThreshold; + this.totalHits = trackTotalHitsThreshold > 0 ? 0 : -1; } void add(TopDocsAndMaxScore topDocs) { - if (trackTotalHits) { + if (trackTotalHitsThreshold > 0) { totalHits += topDocs.topDocs.totalHits.value; if (topDocs.topDocs.totalHits.relation == Relation.GREATER_THAN_OR_EQUAL_TO) { totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchRequestBuilder.java b/server/src/main/java/org/elasticsearch/action/search/SearchRequestBuilder.java index 9389edeb345fc..2e2dc52ed3d05 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchRequestBuilder.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchRequestBuilder.java @@ -376,6 +376,14 @@ public SearchRequestBuilder setTrackTotalHits(boolean trackTotalHits) { return this; } + /** + * Indicates the number of hits to count accurately. + */ + public SearchRequestBuilder setTrackTotalHitsThreshold(int trackTotalHits) { + sourceBuilder().trackTotalHitsThreshold(trackTotalHits); + return this; + } + /** * Adds stored fields to load and return (note, it must be stored) as part of the search request. * To disable the stored fields entirely (source and metadata fields) use {@code storedField("_none_")}. 
diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java b/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java index 3efa9e633de30..c9887712b13ac 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java @@ -23,6 +23,7 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.common.Booleans; import org.elasticsearch.common.Strings; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.settings.Settings; @@ -227,7 +228,12 @@ private static void parseSearchSource(final SearchSourceBuilder searchSourceBuil } if (request.hasParam("track_total_hits")) { - searchSourceBuilder.trackTotalHits(request.paramAsBoolean("track_total_hits", true)); + if (Booleans.isBoolean(request.param("track_total_hits"))) { + searchSourceBuilder.trackTotalHits(request.paramAsBoolean("track_total_hits", true)); + } else { + searchSourceBuilder.trackTotalHitsThreshold(request.paramAsInt("track_total_hits", + SearchContext.DEFAULT_TRACK_TOTAL_HITS)); + } } String sSorts = request.param("sort"); diff --git a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java index d681a186892db..5aaf34b4db9f4 100644 --- a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java @@ -116,7 +116,7 @@ final class DefaultSearchContext extends SearchContext { private SortAndFormats sort; private Float minimumScore; private boolean trackScores = false; // when sorting, track scores as well... 
- private boolean trackTotalHits = true; + private int trackTotalHits = SearchContext.DEFAULT_TRACK_TOTAL_HITS; private FieldDoc searchAfter; private CollapseContext collapse; private boolean lowLevelCancellation; @@ -558,13 +558,13 @@ public boolean trackScores() { } @Override - public SearchContext trackTotalHits(boolean trackTotalHits) { + public SearchContext trackTotalHits(int trackTotalHits) { this.trackTotalHits = trackTotalHits; return this; } @Override - public boolean trackTotalHits() { + public int trackTotalHits() { return trackTotalHits; } diff --git a/server/src/main/java/org/elasticsearch/search/SearchHits.java b/server/src/main/java/org/elasticsearch/search/SearchHits.java index edbcb021331f5..d716b19aae539 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchHits.java +++ b/server/src/main/java/org/elasticsearch/search/SearchHits.java @@ -19,14 +19,15 @@ package org.elasticsearch.search; +import org.apache.lucene.search.TotalHits; import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Streamable; -import org.elasticsearch.common.xcontent.ToXContent.Params; import org.elasticsearch.common.xcontent.ToXContentFragment; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.internal.SearchContext; import java.io.IOException; import java.util.ArrayList; @@ -38,10 +39,15 @@ import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; public final class SearchHits implements Streamable, ToXContentFragment, Iterable { - public static SearchHits empty() { + return empty(SearchContext.DEFAULT_TRACK_TOTAL_HITS); + } + + public static SearchHits empty(int trackTotalHits) { + final TotalHits trackedHits = trackTotalHits > 0 && trackTotalHits < Integer.MAX_VALUE ? 
+ new TotalHits(0, TotalHits.Relation.EQUAL_TO) : null; // We shouldn't use static final instance, since that could directly be returned by native transport clients - return new SearchHits(EMPTY, 0, 0); + return new SearchHits(EMPTY, 0, 0, trackedHits); } public static final SearchHit[] EMPTY = new SearchHit[0]; @@ -50,14 +56,21 @@ public static SearchHits empty() { public long totalHits; + private TotalHitsWrapper trackedHits; + private float maxScore; SearchHits() { - } public SearchHits(SearchHit[] hits, long totalHits, float maxScore) { + this(hits, totalHits, maxScore, null); + } + + public SearchHits(SearchHit[] hits, long totalHits, float maxScore, TotalHits trackedTotalHits) { + assert trackedTotalHits == null || totalHits == -1; this.hits = hits; + this.trackedHits = trackedTotalHits == null ? null : new TotalHitsWrapper(trackedTotalHits.value, trackedTotalHits.relation); this.totalHits = totalHits; this.maxScore = maxScore; } @@ -69,6 +82,13 @@ public long getTotalHits() { return totalHits; } + /** + * The total number of hits that matches the search request. + */ + public TotalHits getTrackedHits() { + return trackedHits == null ? null : trackedHits.totalHits; + } + /** * The maximum score of this query. 
@@ -99,6 +119,7 @@ public Iterator iterator() { public static final class Fields { public static final String HITS = "hits"; public static final String TOTAL = "total"; + public static final String TRACKED_TOTAL = "tracked_total"; public static final String MAX_SCORE = "max_score"; } @@ -106,6 +127,11 @@ public static final class Fields { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(Fields.HITS); builder.field(Fields.TOTAL, totalHits); + if (trackedHits != null) { + builder.startObject(Fields.TRACKED_TOTAL); + trackedHits.toXContent(builder, params); + builder.endObject(); + } if (Float.isNaN(maxScore)) { builder.nullField(Fields.MAX_SCORE); } else { @@ -131,6 +157,7 @@ public static SearchHits fromXContent(XContentParser parser) throws IOException List hits = new ArrayList<>(); long totalHits = 0; float maxScore = 0f; + TotalHitsWrapper trackedTotalHits = null; while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { currentFieldName = parser.currentName(); @@ -153,12 +180,15 @@ public static SearchHits fromXContent(XContentParser parser) throws IOException parser.skipChildren(); } } else if (token == XContentParser.Token.START_OBJECT) { - parser.skipChildren(); + if (Fields.TRACKED_TOTAL.equals(currentFieldName)) { + trackedTotalHits = TotalHitsWrapper.fromXContent(parser); + } else { + parser.skipChildren(); + } } } - SearchHits searchHits = new SearchHits(hits.toArray(new SearchHit[hits.size()]), totalHits, - maxScore); - return searchHits; + return new SearchHits(hits.toArray(new SearchHit[hits.size()]), totalHits, maxScore, + trackedTotalHits == null ? 
null : trackedTotalHits.totalHits); } @@ -170,12 +200,7 @@ public static SearchHits readSearchHits(StreamInput in) throws IOException { @Override public void readFrom(StreamInput in) throws IOException { - final boolean hasTotalHits; - if (in.getVersion().onOrAfter(Version.V_6_0_0_beta1)) { - hasTotalHits = in.readBoolean(); - } else { - hasTotalHits = true; - } + final boolean hasTotalHits = in.readBoolean(); if (hasTotalHits) { totalHits = in.readVLong(); } else { @@ -191,18 +216,17 @@ public void readFrom(StreamInput in) throws IOException { hits[i] = SearchHit.readSearchHit(in); } } + if (in.getVersion().onOrAfter(Version.V_7_0_0_alpha1)) { + trackedHits = in.readOptionalWriteable(TotalHitsWrapper::new); + } else { + trackedHits = null; + } } @Override public void writeTo(StreamOutput out) throws IOException { - final boolean hasTotalHits; - if (out.getVersion().onOrAfter(Version.V_6_0_0_beta1)) { - hasTotalHits = totalHits >= 0; - out.writeBoolean(hasTotalHits); - } else { - assert totalHits >= 0; - hasTotalHits = true; - } + final boolean hasTotalHits = totalHits >= 0; + out.writeBoolean(hasTotalHits); if (hasTotalHits) { out.writeVLong(totalHits); } @@ -213,6 +237,9 @@ public void writeTo(StreamOutput out) throws IOException { hit.writeTo(out); } } + if (out.getVersion().onOrAfter(Version.V_7_0_0_alpha1)) { + out.writeOptionalWriteable(trackedHits); + } } @Override @@ -223,11 +250,12 @@ public boolean equals(Object obj) { SearchHits other = (SearchHits) obj; return Objects.equals(totalHits, other.totalHits) && Objects.equals(maxScore, other.maxScore) + && Objects.equals(trackedHits, other.trackedHits) && Arrays.equals(hits, other.hits); } @Override public int hashCode() { - return Objects.hash(totalHits, maxScore, Arrays.hashCode(hits)); + return Objects.hash(totalHits, maxScore, Arrays.hashCode(hits), trackedHits); } } diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java 
b/server/src/main/java/org/elasticsearch/search/SearchService.java index 51750c3953ad6..c589adb5c0d5c 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -790,7 +790,7 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc if (source.trackTotalHits() == false && context.scrollContext() != null) { throw new SearchContextException(context, "disabling [track_total_hits] is not allowed in a scroll context"); } - context.trackTotalHits(source.trackTotalHits()); + context.trackTotalHits(source.trackTotalHitsThreshold()); if (source.minScore() != null) { context.minimumScore(source.minScore()); } diff --git a/server/src/main/java/org/elasticsearch/search/TotalHitsWrapper.java b/server/src/main/java/org/elasticsearch/search/TotalHitsWrapper.java new file mode 100644 index 0000000000000..54896ca46ed5d --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/TotalHitsWrapper.java @@ -0,0 +1,103 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.search; + +import org.apache.lucene.search.TotalHits; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentFragment; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +class TotalHitsWrapper implements Writeable, ToXContentFragment { + private static ConstructingObjectParser PARSER = new ConstructingObjectParser<>("tracked_total", + (a) -> new TotalHitsWrapper((long) a[0], (TotalHits.Relation) a[1])); + + static { + PARSER.declareLong(ConstructingObjectParser.constructorArg(), new ParseField("value")); + PARSER.declareField(ConstructingObjectParser.constructorArg(), + (p, c) -> parseRelation(p.text()), new ParseField("relation"), ObjectParser.ValueType.STRING); + } + + final TotalHits totalHits; + + TotalHitsWrapper(long value, TotalHits.Relation relation) { + this.totalHits = new TotalHits(value, relation); + } + + TotalHitsWrapper(StreamInput in) throws IOException { + this.totalHits = new TotalHits(in.readVLong(), in.readEnum(TotalHits.Relation.class)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVLong(totalHits.value); + out.writeEnum(totalHits.relation); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("value", totalHits.value); + builder.field("relation", getRelation()); + return builder; + } + + public static TotalHitsWrapper fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + @Override + public 
boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TotalHitsWrapper that = (TotalHitsWrapper) o; + return totalHits.value == that.totalHits.value && + totalHits.relation == that.totalHits.relation; + } + + @Override + public int hashCode() { + return Objects.hash(totalHits.value, totalHits.relation); + } + + private String getRelation() { + return totalHits.relation == TotalHits.Relation.EQUAL_TO ? "eq" : "gte"; + } + + private static TotalHits.Relation parseRelation(String rel) { + switch (rel) { + case "gte": + return TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + + case "eq": + return TotalHits.Relation.EQUAL_TO; + + default: + throw new IllegalArgumentException("invalid relation:[" + rel + "] for [tracked_total]"); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TopHitsAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TopHitsAggregator.java index c017eb4a5e3bc..57eb5191b6b67 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TopHitsAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/TopHitsAggregator.java @@ -217,7 +217,7 @@ public InternalTopHits buildEmptyAggregation() { topDocs = Lucene.EMPTY_TOP_DOCS; } return new InternalTopHits(name, subSearchContext.from(), subSearchContext.size(), new TopDocsAndMaxScore(topDocs, Float.NaN), - SearchHits.empty(), pipelineAggregators(), metaData()); + SearchHits.empty(subSearchContext.trackTotalHits()), pipelineAggregators(), metaData()); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java index a199ce3a37776..09b75d27252a3 100644 --- a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java +++ 
b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java @@ -22,6 +22,7 @@ import org.apache.logging.log4j.LogManager; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; +import org.elasticsearch.common.Booleans; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParsingException; @@ -110,7 +111,6 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R public static final ParseField SEARCH_AFTER = new ParseField("search_after"); public static final ParseField COLLAPSE = new ParseField("collapse"); public static final ParseField SLICE = new ParseField("slice"); - public static final ParseField ALL_FIELDS_FIELDS = new ParseField("all_fields"); public static SearchSourceBuilder fromXContent(XContentParser parser) throws IOException { return fromXContent(parser, true); @@ -152,7 +152,7 @@ public static HighlightBuilder highlight() { private boolean trackScores = false; - private boolean trackTotalHits = true; + private int trackTotalHitsThreshold = SearchContext.DEFAULT_TRACK_TOTAL_HITS; private SearchAfterBuilder searchAfterBuilder; @@ -249,10 +249,10 @@ public SearchSourceBuilder(StreamInput in) throws IOException { searchAfterBuilder = in.readOptionalWriteable(SearchAfterBuilder::new); sliceBuilder = in.readOptionalWriteable(SliceBuilder::new); collapse = in.readOptionalWriteable(CollapseBuilder::new); - if (in.getVersion().onOrAfter(Version.V_6_0_0_beta1)) { - trackTotalHits = in.readBoolean(); + if (in.getVersion().onOrAfter(Version.V_7_0_0_alpha1)) { + trackTotalHitsThreshold = in.readVInt(); } else { - trackTotalHits = true; + trackTotalHitsThreshold = in.readBoolean() ?
Integer.MAX_VALUE : 0; /* 0 == not tracked, as in trackTotalHits(false) */ } } @@ -312,8 +312,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalWriteable(searchAfterBuilder); out.writeOptionalWriteable(sliceBuilder); out.writeOptionalWriteable(collapse); - if (out.getVersion().onOrAfter(Version.V_6_0_0_beta1)) { - out.writeBoolean(trackTotalHits); + if (out.getVersion().onOrAfter(Version.V_7_0_0_alpha1)) { + out.writeVInt(trackTotalHitsThreshold); + } else { + out.writeBoolean(trackTotalHitsThreshold == Integer.MAX_VALUE); } } @@ -536,11 +538,23 @@ public boolean trackScores() { * Indicates if the total hit count for the query should be tracked. */ public boolean trackTotalHits() { - return trackTotalHits; + return trackTotalHitsThreshold == Integer.MAX_VALUE; } public SearchSourceBuilder trackTotalHits(boolean trackTotalHits) { - this.trackTotalHits = trackTotalHits; + this.trackTotalHitsThreshold = trackTotalHits ? Integer.MAX_VALUE : 0; + return this; + } + + /** + * Returns the number of hits to count accurately: {@code 0} disables tracking, {@link Integer#MAX_VALUE} counts them all.
+ */ + public int trackTotalHitsThreshold() { + return trackTotalHitsThreshold; + } + + public SearchSourceBuilder trackTotalHitsThreshold(int trackTotalHitsThreshold) { + this.trackTotalHitsThreshold = trackTotalHitsThreshold; return this; } @@ -596,9 +610,9 @@ public SearchSourceBuilder collapse(CollapseBuilder collapse) { */ public SearchSourceBuilder aggregation(AggregationBuilder aggregation) { if (aggregations == null) { - aggregations = AggregatorFactories.builder(); + aggregations = AggregatorFactories.builder(); } - aggregations.addAggregator(aggregation); + aggregations.addAggregator(aggregation); return this; } @@ -607,9 +621,9 @@ public SearchSourceBuilder aggregation(AggregationBuilder aggregation) { */ public SearchSourceBuilder aggregation(PipelineAggregationBuilder aggregation) { if (aggregations == null) { - aggregations = AggregatorFactories.builder(); + aggregations = AggregatorFactories.builder(); } - aggregations.addPipelineAggregator(aggregation); + aggregations.addPipelineAggregator(aggregation); return this; } @@ -979,7 +993,7 @@ private SearchSourceBuilder shallowCopy(QueryBuilder queryBuilder, QueryBuilder rewrittenBuilder.terminateAfter = terminateAfter; rewrittenBuilder.timeout = timeout; rewrittenBuilder.trackScores = trackScores; - rewrittenBuilder.trackTotalHits = trackTotalHits; + rewrittenBuilder.trackTotalHitsThreshold = trackTotalHitsThreshold; rewrittenBuilder.version = version; rewrittenBuilder.collapse = collapse; return rewrittenBuilder; @@ -1025,7 +1039,12 @@ public void parseXContent(XContentParser parser, boolean checkTrailingTokens) th } else if (TRACK_SCORES_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { trackScores = parser.booleanValue(); } else if (TRACK_TOTAL_HITS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - trackTotalHits = parser.booleanValue(); + if (token == XContentParser.Token.VALUE_BOOLEAN || + (token == XContentParser.Token.VALUE_STRING &&
Booleans.isBoolean(parser.text()))) { + trackTotalHitsThreshold = parser.booleanValue() ? Integer.MAX_VALUE : 0; + } else { + trackTotalHitsThreshold = parser.intValue(); + } } else if (_SOURCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { fetchSourceContext = FetchSourceContext.fromXContent(parser); } else if (STORED_FIELDS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { @@ -1231,8 +1250,10 @@ public XContentBuilder innerToXContent(XContentBuilder builder, Params params) t builder.field(TRACK_SCORES_FIELD.getPreferredName(), true); } - if (trackTotalHits == false) { + if (trackTotalHitsThreshold == 0) { builder.field(TRACK_TOTAL_HITS_FIELD.getPreferredName(), false); + } else if (trackTotalHitsThreshold != Integer.MAX_VALUE) { + builder.field(TRACK_TOTAL_HITS_FIELD.getPreferredName(), trackTotalHitsThreshold); } if (searchAfterBuilder != null) { @@ -1500,7 +1521,7 @@ public int hashCode() { return Objects.hash(aggregations, explain, fetchSourceContext, docValueFields, storedFieldsContext, from, highlightBuilder, indexBoosts, minScore, postQueryBuilder, queryBuilder, rescoreBuilders, scriptFields, size, sorts, searchAfterBuilder, sliceBuilder, stats, suggestBuilder, terminateAfter, timeout, trackScores, version, - profile, extBuilders, collapse, trackTotalHits); + profile, extBuilders, collapse, trackTotalHitsThreshold); } @Override @@ -1538,7 +1559,7 @@ public boolean equals(Object obj) { && Objects.equals(profile, other.profile) && Objects.equals(extBuilders, other.extBuilders) && Objects.equals(collapse, other.collapse) - && Objects.equals(trackTotalHits, other.trackTotalHits); + && Objects.equals(trackTotalHitsThreshold, other.trackTotalHitsThreshold); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java index 1b4cbbbd882bc..239b8e2db18fc 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java
+++ b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java @@ -26,7 +26,6 @@ import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.TotalHits; -import org.apache.lucene.search.TotalHits.Relation; import org.apache.lucene.search.Weight; import org.apache.lucene.util.BitSet; import org.elasticsearch.ExceptionsHelper; @@ -174,8 +173,7 @@ public void execute(SearchContext context) { } TotalHits totalHits = context.queryResult().getTotalHits(); - long totalHitsAsLong = totalHits.relation == Relation.EQUAL_TO ? totalHits.value : -1; - context.fetchResult().hits(new SearchHits(hits, totalHitsAsLong, context.queryResult().getMaxScore())); + context.fetchResult().hits(new SearchHits(hits, totalHits.value, context.queryResult().getMaxScore(), null)); } catch (IOException e) { throw ExceptionsHelper.convertToElastic(e); } diff --git a/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java b/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java index 4f95fcc0195c0..fc4be822ade4a 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java @@ -322,12 +322,12 @@ public boolean trackScores() { } @Override - public SearchContext trackTotalHits(boolean trackTotalHits) { + public SearchContext trackTotalHits(int trackTotalHits) { return in.trackTotalHits(trackTotalHits); } @Override - public boolean trackTotalHits() { + public int trackTotalHits() { return in.trackTotalHits(); } diff --git a/server/src/main/java/org/elasticsearch/search/internal/InternalSearchResponse.java b/server/src/main/java/org/elasticsearch/search/internal/InternalSearchResponse.java index 48ab4914e386c..cdc4262f5ca47 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/InternalSearchResponse.java +++ 
b/server/src/main/java/org/elasticsearch/search/internal/InternalSearchResponse.java @@ -35,9 +35,12 @@ * {@link SearchResponseSections} subclass that can be serialized over the wire. */ public class InternalSearchResponse extends SearchResponseSections implements Writeable, ToXContentFragment { - public static InternalSearchResponse empty() { - return new InternalSearchResponse(SearchHits.empty(), null, null, null, false, null, 1); + return empty(SearchContext.DEFAULT_TRACK_TOTAL_HITS); + } + + public static InternalSearchResponse empty(int trackTotalHits) { + return new InternalSearchResponse(SearchHits.empty(trackTotalHits), null, null, null, false, null, 1); } public InternalSearchResponse(SearchHits hits, InternalAggregations aggregations, Suggest suggest, diff --git a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java index 70a52c39ee110..02d9da31b1765 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java @@ -82,6 +82,7 @@ public abstract class SearchContext extends AbstractRefCounted implements Releasable { public static final int DEFAULT_TERMINATE_AFTER = 0; + public static final int DEFAULT_TRACK_TOTAL_HITS = Integer.MAX_VALUE; private Map> clearables = null; private final AtomicBoolean closed = new AtomicBoolean(false); private InnerHitsContext innerHitsContext; @@ -240,12 +241,12 @@ public InnerHitsContext innerHits() { public abstract boolean trackScores(); - public abstract SearchContext trackTotalHits(boolean trackTotalHits); + public abstract SearchContext trackTotalHits(int trackTotalHits); /** - * Indicates if the total hit count for the query should be tracked. Defaults to {@code true}
+ * Returns the number of hits to count accurately; {@code 0} disables tracking. Defaults to {@link #DEFAULT_TRACK_TOTAL_HITS}. */ - public abstract boolean trackTotalHits(); + public abstract int trackTotalHits(); public abstract SearchContext searchAfter(FieldDoc searchAfter); diff --git a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java index 3523966b7eda4..25536bc6e6256 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java +++ b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java @@ -158,7 +158,7 @@ static boolean execute(SearchContext searchContext, } // ... and stop collecting after ${size} matches searchContext.terminateAfter(searchContext.size()); - searchContext.trackTotalHits(false); + searchContext.trackTotalHits(0); } else if (canEarlyTerminate(reader, searchContext.sort())) { // now this gets interesting: since the search sort is a prefix of the index sort, we can directly // skip to the desired doc @@ -169,7 +169,7 @@ static boolean execute(SearchContext searchContext, .build(); query = bq; } - searchContext.trackTotalHits(false); + searchContext.trackTotalHits(0); } } } diff --git a/server/src/main/java/org/elasticsearch/search/query/TopDocsCollectorContext.java b/server/src/main/java/org/elasticsearch/search/query/TopDocsCollectorContext.java index fcf70a4f98c05..61c36e1f2bcaa 100644 --- a/server/src/main/java/org/elasticsearch/search/query/TopDocsCollectorContext.java +++ b/server/src/main/java/org/elasticsearch/search/query/TopDocsCollectorContext.java @@ -96,9 +96,9 @@ static class EmptyTopDocsCollectorContext extends TopDocsCollectorContext { * @param hasFilterCollector True if the collector chain contains a filter */ private EmptyTopDocsCollectorContext(IndexReader reader, Query query, - boolean trackTotalHits, boolean hasFilterCollector) throws IOException { + int trackTotalHits, boolean hasFilterCollector) throws IOException { super(REASON_SEARCH_COUNT, 0); - if (trackTotalHits) { + if (trackTotalHits > 0) {
TotalHitCountCollector hitCountCollector = new TotalHitCountCollector(); // implicit total hit counts are valid only when there is no filter collector in the chain int hitCount = hasFilterCollector ? -1 : shortcutTotalHitCount(reader, query); @@ -207,7 +207,7 @@ private SimpleTopDocsCollectorContext(IndexReader reader, @Nullable ScoreDoc searchAfter, int numHits, boolean trackMaxScore, - boolean trackTotalHits, + int trackTotalHits, boolean hasFilterCollector) throws IOException { super(REASON_SEARCH_TOP_HITS, numHits); this.sortAndFormats = sortAndFormats; @@ -215,19 +215,14 @@ private SimpleTopDocsCollectorContext(IndexReader reader, // implicit total hit counts are valid only when there is no filter collector in the chain final int hitCount = hasFilterCollector ? -1 : shortcutTotalHitCount(reader, query); final TopDocsCollector topDocsCollector; - if (hitCount == -1 && trackTotalHits) { - topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, Integer.MAX_VALUE); + if (hitCount == -1) { + topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, trackTotalHits); topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs); totalHitsSupplier = () -> topDocsSupplier.get().totalHits; } else { topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, 1); // don't compute hit counts via the collector topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs); - if (hitCount == -1) { - assert trackTotalHits == false; - totalHitsSupplier = () -> new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO); - } else { - totalHitsSupplier = () -> new TotalHits(hitCount, TotalHits.Relation.EQUAL_TO); - } + totalHitsSupplier = () -> new TotalHits(hitCount, TotalHits.Relation.EQUAL_TO); } MaxScoreCollector maxScoreCollector = null; if (sortAndFormats == null) { @@ -273,10 +268,9 @@ private ScrollingTopDocsCollectorContext(IndexReader reader, int numHits, boolean trackMaxScore, int numberOfShards, - boolean 
trackTotalHits, boolean hasFilterCollector) throws IOException { super(reader, query, sortAndFormats, scrollContext.lastEmittedDoc, numHits, trackMaxScore, - trackTotalHits, hasFilterCollector); + Integer.MAX_VALUE, hasFilterCollector); this.scrollContext = Objects.requireNonNull(scrollContext); this.numberOfShards = numberOfShards; } @@ -356,8 +350,7 @@ static TopDocsCollectorContext createTopDocsCollectorContext(SearchContext searc // no matter what the value of from is int numDocs = Math.min(searchContext.size(), totalNumDocs); return new ScrollingTopDocsCollectorContext(reader, query, searchContext.scrollContext(), - searchContext.sort(), numDocs, searchContext.trackScores(), searchContext.numberOfShards(), - searchContext.trackTotalHits(), hasFilterCollector); + searchContext.sort(), numDocs, searchContext.trackScores(), searchContext.numberOfShards(), hasFilterCollector); } else if (searchContext.collapse() != null) { boolean trackScores = searchContext.sort() == null ? true : searchContext.trackScores(); int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs); diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java index f4cb7d224d2aa..4c0682927c328 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java @@ -148,7 +148,7 @@ public void testMerge() throws IOException { int nShards = randomIntBetween(1, 20); int queryResultSize = randomBoolean() ? 
0 : randomIntBetween(1, nShards * 2); AtomicArray queryResults = generateQueryResults(nShards, suggestions, queryResultSize, false); - for (boolean trackTotalHits : new boolean[] {true, false}) { + for (int trackTotalHits : new int[] {0, 200, Integer.MAX_VALUE}) { SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedQueryPhase(queryResults.asList(), false, trackTotalHits); AtomicArray searchPhaseResultAtomicArray = generateFetchResults(nShards, reducedQueryPhase.scoreDocs, @@ -156,7 +156,7 @@ public void testMerge() throws IOException { InternalSearchResponse mergedResponse = searchPhaseController.merge(false, reducedQueryPhase, searchPhaseResultAtomicArray.asList(), searchPhaseResultAtomicArray::get); - if (trackTotalHits == false) { + if (trackTotalHits != Integer.MAX_VALUE) { assertThat(mergedResponse.hits.totalHits, equalTo(-1L)); } int suggestSize = 0; diff --git a/server/src/test/java/org/elasticsearch/search/SearchHitsTests.java b/server/src/test/java/org/elasticsearch/search/SearchHitsTests.java index a42804692fbf3..a1b64684b8fef 100644 --- a/server/src/test/java/org/elasticsearch/search/SearchHitsTests.java +++ b/server/src/test/java/org/elasticsearch/search/SearchHitsTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.search; +import org.apache.lucene.search.TotalHits; import org.apache.lucene.util.TestUtil; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; @@ -46,9 +47,10 @@ public static SearchHits createTestItem() { for (int i = 0; i < searchHits; i++) { hits[i] = SearchHitTests.createTestItem(false); // creating random innerHits could create loops } - long totalHits = frequently() ? TestUtil.nextLong(random(), 0, Long.MAX_VALUE) : -1; float maxScore = frequently() ? randomFloat() : Float.NaN; - return new SearchHits(hits, totalHits, maxScore); + return randomBoolean() ? 
new SearchHits(hits, TestUtil.nextLong(random(), 0, Long.MAX_VALUE), maxScore) : + new SearchHits(hits, -1, maxScore, new TotalHits(TestUtil.nextLong(random(), 0, Long.MAX_VALUE), + randomFrom(random(), TotalHits.Relation.values()))); } public void testFromXContent() throws IOException { diff --git a/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java b/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java index 7e9c0153b728f..b98751eac366a 100644 --- a/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java @@ -453,7 +453,7 @@ public void testIndexSortingEarlyTermination() throws Exception { { contextSearcher = getAssertingEarlyTerminationSearcher(reader, 1); - context.trackTotalHits(false); + context.trackTotalHits(0); QueryPhase.execute(context, contextSearcher, checkCancelled -> {}); assertNull(context.queryResult().terminatedEarly()); assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1)); diff --git a/test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java b/test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java index d534af5789448..9b687f15e1caa 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java +++ b/test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java @@ -147,7 +147,11 @@ public static SearchSourceBuilder randomSearchSourceBuilder( builder.terminateAfter(randomIntBetween(1, 100000)); } if (randomBoolean()) { - builder.trackTotalHits(randomBoolean()); + if (randomBoolean()) { + builder.trackTotalHits(randomBoolean()); + } else { + builder.trackTotalHitsThreshold(randomIntBetween(0, Integer.MAX_VALUE)); + } } switch(randomInt(2)) { diff --git a/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java 
b/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java index 9d03383561614..2553537a56bff 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java +++ b/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java @@ -83,7 +83,7 @@ public class TestSearchContext extends SearchContext { SearchTask task; SortAndFormats sort; boolean trackScores = false; - boolean trackTotalHits = true; + int trackTotalHits = Integer.MAX_VALUE; ContextIndexSearcher searcher; int size; @@ -364,13 +364,13 @@ public boolean trackScores() { } @Override - public SearchContext trackTotalHits(boolean trackTotalHits) { + public SearchContext trackTotalHits(int trackTotalHits) { this.trackTotalHits = trackTotalHits; return this; } @Override - public boolean trackTotalHits() { + public int trackTotalHits() { return trackTotalHits; }