diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index 01ef94c428a41..28a7193646001 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -402,8 +402,7 @@ private SearchHits getHits(ReducedQueryPhase reducedQueryPhase, boolean ignoreFr hits.add(searchHit); } } - return new SearchHits(hits.toArray(new SearchHit[hits.size()]), reducedQueryPhase.totalHits, - reducedQueryPhase.maxScore); + return new SearchHits(hits.toArray(new SearchHit[hits.size()]), reducedQueryPhase.totalHits, reducedQueryPhase.maxScore); } /** @@ -443,7 +442,8 @@ private ReducedQueryPhase reducedQueryPhase(Collection aggregationsL public static final class ReducedQueryPhase { // the sum of all hits across all reduces shards - final long totalHits; + final TotalHits totalHits; // the number of returned hits (doc IDs) across all reduces shards final long fetchHits; // the max score across all reduces hits or {@link Float#NaN} if no hits returned @@ -575,7 +576,7 @@ public static final class ReducedQueryPhase { // sort value formats used to sort / format the result final DocValueFormat[] sortValueFormats; - ReducedQueryPhase(long totalHits, long fetchHits, float maxScore, boolean timedOut, Boolean terminatedEarly, Suggest suggest, + ReducedQueryPhase(TotalHits totalHits, long fetchHits, float maxScore, boolean timedOut, Boolean terminatedEarly, Suggest suggest, InternalAggregations aggregations, SearchProfileShardResults shardResults, ScoreDoc[] scoreDocs, SortField[] sortFields, DocValueFormat[] sortValueFormats, int numReducePhases, boolean isSortedByField, int size, int from, boolean isEmptyResult) { @@ -748,8 +749,8 @@ public ReducedQueryPhase reduce() { static final class TopDocsStats { final boolean trackTotalHits; - long totalHits; - TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO; + private long totalHits; + private TotalHits.Relation totalHitsRelation; long fetchHits; float maxScore = Float.NEGATIVE_INFINITY; @@ -759,7 +760,12 @@ static final class TopDocsStats { TopDocsStats(boolean trackTotalHits) { this.trackTotalHits = trackTotalHits; - this.totalHits = trackTotalHits ? 0 : -1; + this.totalHits = 0; + this.totalHitsRelation = trackTotalHits ? Relation.EQUAL_TO : Relation.GREATER_THAN_OR_EQUAL_TO; + } + + TotalHits getTotalHits() { + return trackTotalHits ? new TotalHits(totalHits, totalHitsRelation) : null; } void add(TopDocsAndMaxScore topDocs) { diff --git a/server/src/main/java/org/elasticsearch/index/SearchSlowLog.java b/server/src/main/java/org/elasticsearch/index/SearchSlowLog.java index 32c527f06ff3b..abd67a47049d3 100644 --- a/server/src/main/java/org/elasticsearch/index/SearchSlowLog.java +++ b/server/src/main/java/org/elasticsearch/index/SearchSlowLog.java @@ -163,7 +163,13 @@ public String toString() { .append(" ") .append("took[").append(TimeValue.timeValueNanos(tookInNanos)).append("], ") .append("took_millis[").append(TimeUnit.NANOSECONDS.toMillis(tookInNanos)).append("], ") - .append("total_hits[").append(context.queryResult().getTotalHits()).append("], "); + .append("total_hits["); + if (context.queryResult().getTotalHits() != null) { + sb.append(context.queryResult().getTotalHits()); + } else { + sb.append("-1"); + } + sb.append("], "); if (context.getQueryShardContext().getTypes() == null) { sb.append("types[], "); } else { diff --git a/server/src/main/java/org/elasticsearch/index/reindex/ClientScrollableHitSource.java b/server/src/main/java/org/elasticsearch/index/reindex/ClientScrollableHitSource.java index eb8c0e14f4343..7ba3013497990 100644 --- a/server/src/main/java/org/elasticsearch/index/reindex/ClientScrollableHitSource.java +++ b/server/src/main/java/org/elasticsearch/index/reindex/ClientScrollableHitSource.java @@ -198,7 +198,8 @@ private Response wrap(SearchResponse response) { } hits = unmodifiableList(hits); } - return new Response(response.isTimedOut(), failures, response.getHits().getTotalHits(), + long total = response.getHits().getTotalHits().value; + return new Response(response.isTimedOut(), failures, total, hits, response.getScrollId()); } diff --git a/server/src/main/java/org/elasticsearch/rest/action/cat/RestCountAction.java b/server/src/main/java/org/elasticsearch/rest/action/cat/RestCountAction.java index 840f4de5ea7b7..7dd0758c00869 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/cat/RestCountAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/cat/RestCountAction.java @@ -19,6 +19,7 @@ package org.elasticsearch.rest.action.cat; +import org.apache.lucene.search.TotalHits; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; @@ -60,7 +61,7 @@ protected void documentation(StringBuilder sb) { public RestChannelConsumer doCatRequest(final RestRequest request, final NodeClient client) { String[] indices = Strings.splitStringByCommaToArray(request.param("index")); SearchRequest countRequest = new SearchRequest(indices); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).trackTotalHits(true); countRequest.source(searchSourceBuilder); try { request.withContentOrSourceParamParserOrNull(parser -> { @@ -79,6 +80,7 @@ public RestChannelConsumer doCatRequest(final RestRequest request, final NodeCli return channel -> client.search(countRequest, new RestResponseListener(channel) { @Override public RestResponse buildResponse(SearchResponse countResponse) throws Exception { + assert countResponse.getHits().getTotalHits().relation == TotalHits.Relation.EQUAL_TO; return RestTable.buildResponse(buildTable(request, countResponse), channel); } }); @@ -96,7 +98,7 @@ protected Table getTableWithHeader(final RestRequest request) { private Table buildTable(RestRequest request, SearchResponse response) { Table table = getTableWithHeader(request); table.startRow(); - table.addCell(response.getHits().getTotalHits()); + table.addCell(response.getHits().getTotalHits().value); table.endRow(); return table; diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/RestCountAction.java b/server/src/main/java/org/elasticsearch/rest/action/search/RestCountAction.java index 99ea1c81fa956..cd09273396942 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/RestCountAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/RestCountAction.java @@ -108,7 +108,7 @@ public RestResponse buildResponse(SearchResponse response, XContentBuilder build if (terminateAfter != DEFAULT_TERMINATE_AFTER) { builder.field("terminated_early", response.isTerminatedEarly()); } - builder.field("count", response.getHits().getTotalHits()); + builder.field("count", response.getHits().getTotalHits().value); buildBroadcastShardsHeader(builder, request, response.getTotalShards(), response.getSuccessfulShards(), 0, response.getFailedShards(), response.getShardFailures()); diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/RestMultiSearchAction.java b/server/src/main/java/org/elasticsearch/rest/action/search/RestMultiSearchAction.java index d3a45fa727b26..01b46a07720a0 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/RestMultiSearchAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/RestMultiSearchAction.java @@ -40,7 +40,9 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import java.io.IOException; +import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Set; @@ -48,7 +50,14 @@ import static org.elasticsearch.rest.RestRequest.Method.POST; public class RestMultiSearchAction extends BaseRestHandler { - private static final Set RESPONSE_PARAMS = Collections.singleton(RestSearchAction.TYPED_KEYS_PARAM); + private static final Set RESPONSE_PARAMS; + + static { + final Set responseParams = new HashSet<>( + Arrays.asList(RestSearchAction.TYPED_KEYS_PARAM, RestSearchAction.TOTAL_HIT_AS_INT_PARAM) + ); + RESPONSE_PARAMS = Collections.unmodifiableSet(responseParams); + } private static final DeprecationLogger deprecationLogger = new DeprecationLogger( LogManager.getLogger(RestMultiSearchAction.class)); diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java b/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java index 60fd77e46aa3f..c26b5ab7c1ef1 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java @@ -45,6 +45,7 @@ import java.io.IOException; import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.Set; import java.util.function.IntConsumer; @@ -54,8 +55,18 @@ import static org.elasticsearch.search.suggest.SuggestBuilders.termSuggestion; public class RestSearchAction extends BaseRestHandler { + /** + * Indicates whether hits.total should be rendered as an integer or an object + * in the rest search response. + */ + public static final String TOTAL_HIT_AS_INT_PARAM = "rest_total_hits_as_int"; public static final String TYPED_KEYS_PARAM = "typed_keys"; - private static final Set RESPONSE_PARAMS = Collections.singleton(TYPED_KEYS_PARAM); + private static final Set RESPONSE_PARAMS; + + static { + final Set responseParams = new HashSet<>(Arrays.asList(TYPED_KEYS_PARAM, TOTAL_HIT_AS_INT_PARAM)); + RESPONSE_PARAMS = Collections.unmodifiableSet(responseParams); + } private static final DeprecationLogger deprecationLogger = new DeprecationLogger(LogManager.getLogger(RestSearchAction.class)); static final String TYPES_DEPRECATION_MESSAGE = "[types removal]" + diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchScrollAction.java b/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchScrollAction.java index bc3b0ccb56ac3..6d2f0971ad770 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchScrollAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchScrollAction.java @@ -29,12 +29,16 @@ import org.elasticsearch.search.Scroll; import java.io.IOException; +import java.util.Collections; +import java.util.Set; import static org.elasticsearch.common.unit.TimeValue.parseTimeValue; import static org.elasticsearch.rest.RestRequest.Method.GET; import static org.elasticsearch.rest.RestRequest.Method.POST; public class RestSearchScrollAction extends BaseRestHandler { + private static final Set RESPONSE_PARAMS = Collections.singleton(RestSearchAction.TOTAL_HIT_AS_INT_PARAM); + public RestSearchScrollAction(Settings settings, RestController controller) { super(settings); @@ -70,4 +74,9 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC }}); return channel -> client.searchScroll(searchScrollRequest, new RestStatusToXContentListener<>(channel)); } + + @Override + protected Set responseParams() { + return RESPONSE_PARAMS; + } } diff --git a/server/src/main/java/org/elasticsearch/search/SearchHit.java b/server/src/main/java/org/elasticsearch/search/SearchHit.java index 4532385f31381..3d8ea3845464f 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchHit.java +++ b/server/src/main/java/org/elasticsearch/search/SearchHit.java @@ -117,7 +117,7 @@ public final class SearchHit implements Streamable, ToXContentObject, Iterable innerHits; - private SearchHit() { + SearchHit() { } @@ -792,6 +792,8 @@ public void readFrom(StreamInput in) throws IOException { SearchHits value = SearchHits.readSearchHits(in); innerHits.put(key, value); } + } else { + innerHits = null; } } diff --git a/server/src/main/java/org/elasticsearch/search/SearchHits.java b/server/src/main/java/org/elasticsearch/search/SearchHits.java index edbcb021331f5..01f4e9f880e54 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchHits.java +++ b/server/src/main/java/org/elasticsearch/search/SearchHits.java @@ -19,14 +19,18 @@ package org.elasticsearch.search; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.TotalHits.Relation; import org.elasticsearch.Version; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Streamable; -import org.elasticsearch.common.xcontent.ToXContent.Params; +import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ToXContentFragment; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.rest.action.search.RestSearchAction; import java.io.IOException; import java.util.ArrayList; @@ -41,14 +45,14 @@ public final class SearchHits implements Streamable, ToXContentFragment, Iterabl public static SearchHits empty() { // We shouldn't use static final instance, since that could directly be returned by native transport clients - return new SearchHits(EMPTY, 0, 0); + return new SearchHits(EMPTY, new TotalHits(0, Relation.EQUAL_TO), 0); } public static final SearchHit[] EMPTY = new SearchHit[0]; private SearchHit[] hits; - public long totalHits; + private Total totalHits; private float maxScore; @@ -56,17 +60,18 @@ public static SearchHits empty() { } - public SearchHits(SearchHit[] hits, long totalHits, float maxScore) { + public SearchHits(SearchHit[] hits, @Nullable TotalHits totalHits, float maxScore) { this.hits = hits; - this.totalHits = totalHits; + this.totalHits = totalHits == null ? null : new Total(totalHits); this.maxScore = maxScore; } /** - * The total number of hits that matches the search request. + * The total number of hits for the query or null if the tracking of total hits + * is disabled in the request. */ - public long getTotalHits() { - return totalHits; + public TotalHits getTotalHits() { + return totalHits == null ? null : totalHits.in; } @@ -105,7 +110,15 @@ public static final class Fields { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(Fields.HITS); - builder.field(Fields.TOTAL, totalHits); + boolean totalHitAsInt = params.paramAsBoolean(RestSearchAction.TOTAL_HIT_AS_INT_PARAM, false); + if (totalHitAsInt) { + long total = totalHits == null ? -1 : totalHits.in.value; + builder.field(Fields.TOTAL, total); + } else if (totalHits != null) { + builder.startObject(Fields.TOTAL); + totalHits.toXContent(builder, params); + builder.endObject(); + } if (Float.isNaN(maxScore)) { builder.nullField(Fields.MAX_SCORE); } else { @@ -129,14 +142,16 @@ public static SearchHits fromXContent(XContentParser parser) throws IOException XContentParser.Token token = parser.currentToken(); String currentFieldName = null; List hits = new ArrayList<>(); - long totalHits = 0; + TotalHits totalHits = null; float maxScore = 0f; while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { currentFieldName = parser.currentName(); } else if (token.isValue()) { if (Fields.TOTAL.equals(currentFieldName)) { - totalHits = parser.longValue(); + // For BWC with nodes pre 7.0 + long value = parser.longValue(); + totalHits = value == -1 ? null : new TotalHits(value, Relation.EQUAL_TO); } else if (Fields.MAX_SCORE.equals(currentFieldName)) { maxScore = parser.floatValue(); } @@ -153,15 +168,17 @@ public static SearchHits fromXContent(XContentParser parser) throws IOException parser.skipChildren(); } } else if (token == XContentParser.Token.START_OBJECT) { - parser.skipChildren(); + if (Fields.TOTAL.equals(currentFieldName)) { + totalHits = parseTotalHitsFragment(parser); + } else { + parser.skipChildren(); + } } } - SearchHits searchHits = new SearchHits(hits.toArray(new SearchHit[hits.size()]), totalHits, - maxScore); + SearchHits searchHits = new SearchHits(hits.toArray(new SearchHit[hits.size()]), totalHits, maxScore); return searchHits; } - public static SearchHits readSearchHits(StreamInput in) throws IOException { SearchHits hits = new SearchHits(); hits.readFrom(in); @@ -170,16 +187,11 @@ public static SearchHits readSearchHits(StreamInput in) throws IOException { @Override public void readFrom(StreamInput in) throws IOException { - final boolean hasTotalHits; - if (in.getVersion().onOrAfter(Version.V_6_0_0_beta1)) { - hasTotalHits = in.readBoolean(); - } else { - hasTotalHits = true; - } - if (hasTotalHits) { - totalHits = in.readVLong(); + if (in.readBoolean()) { + totalHits = new Total(in); } else { - totalHits = -1; + // track_total_hits is false + totalHits = null; } maxScore = in.readFloat(); int size = in.readVInt(); @@ -195,16 +207,10 @@ public void readFrom(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { - final boolean hasTotalHits; - if (out.getVersion().onOrAfter(Version.V_6_0_0_beta1)) { - hasTotalHits = totalHits >= 0; - out.writeBoolean(hasTotalHits); - } else { - assert totalHits >= 0; - hasTotalHits = true; - } + final boolean hasTotalHits = totalHits != null; + out.writeBoolean(hasTotalHits); if (hasTotalHits) { - out.writeVLong(totalHits); + totalHits.writeTo(out); } out.writeFloat(maxScore); out.writeVInt(hits.length); @@ -228,6 +234,91 @@ public boolean equals(Object obj) { @Override public int hashCode() { - return Objects.hash(totalHits, maxScore, Arrays.hashCode(hits)); + return Objects.hash(totalHits, totalHits, maxScore, Arrays.hashCode(hits)); + } + + public static TotalHits parseTotalHitsFragment(XContentParser parser) throws IOException { + long value = -1; + Relation relation = null; + XContentParser.Token token; + String currentFieldName = null; + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (token.isValue()) { + if ("value".equals(currentFieldName)) { + value = parser.longValue(); + } else if ("relation".equals(currentFieldName)) { + relation = parseRelation(parser.text()); + } + } else { + parser.skipChildren(); + } + } + return new TotalHits(value, relation); + } + + private static Relation parseRelation(String relation) { + if ("gte".equals(relation)) { + return Relation.GREATER_THAN_OR_EQUAL_TO; + } else if ("eq".equals(relation)) { + return Relation.EQUAL_TO; + } else { + throw new IllegalArgumentException("invalid total hits relation: " + relation); + } + } + + private static String printRelation(Relation relation) { + return relation == Relation.EQUAL_TO ? "eq" : "gte"; + } + + private static class Total implements Writeable, ToXContentFragment { + final TotalHits in; + + Total(StreamInput in) throws IOException { + final long value = in.readVLong(); + final Relation relation; + if (in.getVersion().onOrAfter(Version.V_7_0_0)) { + relation = in.readEnum(Relation.class); + } else { + relation = Relation.EQUAL_TO; + } + this.in = new TotalHits(value, relation); + } + + Total(TotalHits in) { + this.in = Objects.requireNonNull(in); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Total total = (Total) o; + return in.value == total.in.value && + in.relation == total.in.relation; + } + + @Override + public int hashCode() { + return Objects.hash(in.value, in.relation); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVLong(in.value); + if (out.getVersion().onOrAfter(Version.V_7_0_0)) { + out.writeEnum(in.relation); + } else { + assert in.relation == Relation.EQUAL_TO; + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("value", in.value); + builder.field("relation", printRelation(in.relation)); + return builder; + } } } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/InternalTopHits.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/InternalTopHits.java index 0c85191379fa9..ff6e10baee534 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/InternalTopHits.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/InternalTopHits.java @@ -161,7 +161,7 @@ public InternalAggregation doReduce(List aggregations, Redu assert reducedTopDocs.totalHits.relation == Relation.EQUAL_TO; return new InternalTopHits(name, this.from, this.size, new TopDocsAndMaxScore(reducedTopDocs, maxScore), - new SearchHits(hits, reducedTopDocs.totalHits.value, maxScore), pipelineAggregators(), getMetaData()); + new SearchHits(hits, reducedTopDocs.totalHits, maxScore), pipelineAggregators(), getMetaData()); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java index d5d081e96972c..ab2f864bfce35 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java @@ -28,7 +28,6 @@ import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.TotalHits; -import org.apache.lucene.search.TotalHits.Relation; import org.apache.lucene.search.Weight; import org.apache.lucene.util.BitSet; import org.elasticsearch.ExceptionsHelper; @@ -182,8 +181,7 @@ public void execute(SearchContext context) { } TotalHits totalHits = context.queryResult().getTotalHits(); - long totalHitsAsLong = totalHits.relation == Relation.EQUAL_TO ? totalHits.value : -1; - context.fetchResult().hits(new SearchHits(hits, totalHitsAsLong, context.queryResult().getMaxScore())); + context.fetchResult().hits(new SearchHits(hits, totalHits, context.queryResult().getMaxScore())); } catch (IOException e) { throw ExceptionsHelper.convertToElastic(e); }