diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java index 920353abcf808..1bbca35cb9a54 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java @@ -109,8 +109,9 @@ private void innerRun() throws IOException { // query AND fetch optimization finishPhase.run(); } else { - final IntArrayList[] docIdsToLoad = searchPhaseController.fillDocIdsToLoad(numShards, reducedQueryPhase.scoreDocs); - if (reducedQueryPhase.scoreDocs.length == 0) { // no docs to fetch -- sidestep everything and return + ScoreDoc[] scoreDocs = reducedQueryPhase.sortedTopDocs.scoreDocs; + final IntArrayList[] docIdsToLoad = searchPhaseController.fillDocIdsToLoad(numShards, scoreDocs); + if (scoreDocs.length == 0) { // no docs to fetch -- sidestep everything and return phaseResults.stream() .map(SearchPhaseResult::queryResult) .forEach(this::releaseIrrelevantSearchContext); // we have to release contexts here to free up resources diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index b2b3ee6dd2bd7..82f7760c1abdc 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -211,18 +211,23 @@ static SortedTopDocs sortDocs(boolean ignoreFrom, Collection fetchResults, IntFunction resultsLookup) { - final boolean sorted = reducedQueryPhase.isSortedByField; - ScoreDoc[] sortedDocs = reducedQueryPhase.scoreDocs; + SortedTopDocs sortedTopDocs = reducedQueryPhase.sortedTopDocs; int sortScoreIndex = -1; - if (sorted) { - for (int i = 0; i < reducedQueryPhase.sortField.length; i++) { - if (reducedQueryPhase.sortField[i].getType() == SortField.Type.SCORE) { + if (sortedTopDocs.isSortedByField) { + SortField[] sortFields = sortedTopDocs.sortFields; + for (int i = 0; i < sortFields.length; i++) { + if (sortFields[i].getType() == SortField.Type.SCORE) { sortScoreIndex = i; } } @@ -362,12 +367,12 @@ private SearchHits getHits(ReducedQueryPhase reducedQueryPhase, boolean ignoreFr int from = ignoreFrom ? 0 : reducedQueryPhase.from; int numSearchHits = (int) Math.min(reducedQueryPhase.fetchHits - from, reducedQueryPhase.size); // with collapsing we can have more fetch hits than sorted docs - numSearchHits = Math.min(sortedDocs.length, numSearchHits); + numSearchHits = Math.min(sortedTopDocs.scoreDocs.length, numSearchHits); // merge hits List hits = new ArrayList<>(); if (!fetchResults.isEmpty()) { for (int i = 0; i < numSearchHits; i++) { - ScoreDoc shardDoc = sortedDocs[i]; + ScoreDoc shardDoc = sortedTopDocs.scoreDocs[i]; SearchPhaseResult fetchResultProvider = resultsLookup.apply(shardDoc.shardIndex); if (fetchResultProvider == null) { // this can happen if we are hitting a shard failure during the fetch phase @@ -381,21 +386,21 @@ private SearchHits getHits(ReducedQueryPhase reducedQueryPhase, boolean ignoreFr assert index < fetchResult.hits().getHits().length : "not enough hits fetched. index [" + index + "] length: " + fetchResult.hits().getHits().length; SearchHit searchHit = fetchResult.hits().getHits()[index]; - if (sorted == false) { - searchHit.score(shardDoc.score); - } searchHit.shard(fetchResult.getSearchShardTarget()); - if (sorted) { + if (sortedTopDocs.isSortedByField) { FieldDoc fieldDoc = (FieldDoc) shardDoc; searchHit.sortValues(fieldDoc.fields, reducedQueryPhase.sortValueFormats); if (sortScoreIndex != -1) { searchHit.score(((Number) fieldDoc.fields[sortScoreIndex]).floatValue()); } + } else { + searchHit.score(shardDoc.score); } hits.add(searchHit); } } - return new SearchHits(hits.toArray(new SearchHit[0]), reducedQueryPhase.totalHits, reducedQueryPhase.maxScore); + return new SearchHits(hits.toArray(new SearchHit[0]), reducedQueryPhase.totalHits, + reducedQueryPhase.maxScore, sortedTopDocs.sortFields, sortedTopDocs.collapseField, sortedTopDocs.collapseValues); } /** @@ -436,8 +441,7 @@ private ReducedQueryPhase reducedQueryPhase(Collectionnull if the results are not sorted - final SortField[] sortField; - // true iff the result score docs is sorted by a field (not score), this implies that sortField is set. - final boolean isSortedByField; + //encloses info about the merged top docs, the sort fields used to sort the score docs etc. + final SortedTopDocs sortedTopDocs; // the size of the top hits to return final int size; // true iff the query phase had no results. Otherwise false @@ -567,9 +567,8 @@ public static final class ReducedQueryPhase { final DocValueFormat[] sortValueFormats; ReducedQueryPhase(TotalHits totalHits, long fetchHits, float maxScore, boolean timedOut, Boolean terminatedEarly, Suggest suggest, - InternalAggregations aggregations, SearchProfileShardResults shardResults, ScoreDoc[] scoreDocs, - SortField[] sortFields, DocValueFormat[] sortValueFormats, int numReducePhases, boolean isSortedByField, int size, - int from, boolean isEmptyResult) { + InternalAggregations aggregations, SearchProfileShardResults shardResults, SortedTopDocs sortedTopDocs, + DocValueFormat[] sortValueFormats, int numReducePhases, int size, int from, boolean isEmptyResult) { if (numReducePhases <= 0) { throw new IllegalArgumentException("at least one reduce phase must have been applied but was: " + numReducePhases); } @@ -586,9 +585,7 @@ public static final class ReducedQueryPhase { this.aggregations = aggregations; this.shardResults = shardResults; this.numReducePhases = numReducePhases; - this.scoreDocs = scoreDocs; - this.sortField = sortFields; - this.isSortedByField = isSortedByField; + this.sortedTopDocs = sortedTopDocs; this.size = size; this.from = from; this.isEmptyResult = isEmptyResult; @@ -728,7 +725,7 @@ InitialSearchPhase.ArraySearchPhaseResults newSearchPhaseResu } return new InitialSearchPhase.ArraySearchPhaseResults(numShards) { @Override - public ReducedQueryPhase reduce() { + ReducedQueryPhase reduce() { return reducedQueryPhase(results.asList(), isScrollRequest, trackTotalHits); } }; @@ -770,15 +767,23 @@ void add(TopDocsAndMaxScore topDocs) { } static final class SortedTopDocs { - static final SortedTopDocs EMPTY = new SortedTopDocs(EMPTY_DOCS, false, null); + static final SortedTopDocs EMPTY = new SortedTopDocs(EMPTY_DOCS, false, null, null, null); + // the searches merged top docs final ScoreDoc[] scoreDocs; + // true iff the result score docs is sorted by a field (not score), this implies that sortField is set. final boolean isSortedByField; + // the top docs sort fields used to sort the score docs, null if the results are not sorted final SortField[] sortFields; + final String collapseField; + final Object[] collapseValues; - SortedTopDocs(ScoreDoc[] scoreDocs, boolean isSortedByField, SortField[] sortFields) { + SortedTopDocs(ScoreDoc[] scoreDocs, boolean isSortedByField, SortField[] sortFields, + String collapseField, Object[] collapseValues) { this.scoreDocs = scoreDocs; this.isSortedByField = isSortedByField; this.sortFields = sortFields; + this.collapseField = collapseField; + this.collapseValues = collapseValues; } } } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchScrollQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchScrollQueryThenFetchAsyncAction.java index 794e3c84f1363..df18296de2a4a 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchScrollQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchScrollQueryThenFetchAsyncAction.java @@ -35,7 +35,6 @@ import org.elasticsearch.search.query.ScrollQuerySearchResult; import org.elasticsearch.transport.Transport; -import java.io.IOException; import java.util.function.BiFunction; final class SearchScrollQueryThenFetchAsyncAction extends SearchScrollAsyncAction { @@ -68,16 +67,16 @@ protected void executeInitialPhase(Transport.Connection connection, InternalScro protected SearchPhase moveToNextPhase(BiFunction clusterNodeLookup) { return new SearchPhase("fetch") { @Override - public void run() throws IOException { + public void run() { final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedScrollQueryPhase( queryResults.asList()); - if (reducedQueryPhase.scoreDocs.length == 0) { + ScoreDoc[] scoreDocs = reducedQueryPhase.sortedTopDocs.scoreDocs; + if (scoreDocs.length == 0) { sendResponse(reducedQueryPhase, fetchResults); return; } - final IntArrayList[] docIdsToLoad = searchPhaseController.fillDocIdsToLoad(queryResults.length(), - reducedQueryPhase.scoreDocs); + final IntArrayList[] docIdsToLoad = searchPhaseController.fillDocIdsToLoad(queryResults.length(), scoreDocs); final ScoreDoc[] lastEmittedDocPerShard = searchPhaseController.getLastEmittedDocPerShard(reducedQueryPhase, queryResults.length()); final CountDown counter = new CountDown(docIdsToLoad.length); diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java index b00706b78aedb..4164e29a58279 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java @@ -785,22 +785,36 @@ public void writeArray(final Writer writer, final T[] array) throws IOExc } } - public void writeArray(T[] array) throws IOException { - writeVInt(array.length); - for (T value: array) { - value.writeTo(this); - } - } - - public void writeOptionalArray(@Nullable T[] array) throws IOException { + /** + * Same as {@link #writeArray(Writer, Object[])} but the provided array may be null. An additional boolean value is + * serialized to indicate whether the array was null or not. + */ + public void writeOptionalArray(final Writer writer, final @Nullable T[] array) throws IOException { if (array == null) { writeBoolean(false); } else { writeBoolean(true); - writeArray(array); + writeArray(writer, array); } } + /** + * Writes the specified array of {@link Writeable}s. This method can be seen as + * writer version of {@link StreamInput#readArray(Writeable.Reader, IntFunction)}. The length of array encoded as a variable-length + * integer is first written to the stream, and then the elements of the array are written to the stream. + */ + public void writeArray(T[] array) throws IOException { + writeArray((out, value) -> value.writeTo(out), array); + } + + /** + * Same as {@link #writeArray(Writeable[])} but the provided array may be null. An additional boolean value is + * serialized to indicate whether the array was null or not. + */ + public void writeOptionalArray(@Nullable T[] array) throws IOException { + writeOptionalArray((out, value) -> value.writeTo(out), array); + } + /** * Serializes a potential null value. */ diff --git a/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java b/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java index fd9d63ea225e1..4d4a2d838dbd3 100644 --- a/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java +++ b/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java @@ -128,6 +128,9 @@ public class Lucene { public static final TopDocs EMPTY_TOP_DOCS = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), EMPTY_SCORE_DOCS); + private Lucene() { + } + public static Version parseVersion(@Nullable String version, Version defaultVersion, Logger logger) { if (version == null) { return defaultVersion; @@ -201,7 +204,7 @@ public static SegmentInfos pruneUnreferencedFiles(String segmentsFileName, Direc try (Lock writeLock = directory.obtainLock(IndexWriter.WRITE_LOCK_NAME)) { int foundSegmentFiles = 0; for (final String file : directory.listAll()) { - /** + /* * we could also use a deletion policy here but in the case of snapshot and restore * sometimes we restore an index and override files that were referenced by a "future" * commit. If such a commit is opened by the IW it would likely throw a corrupted index exception @@ -227,7 +230,7 @@ public static SegmentInfos pruneUnreferencedFiles(String segmentsFileName, Direc .setCommitOnClose(false) .setMergePolicy(NoMergePolicy.INSTANCE) .setOpenMode(IndexWriterConfig.OpenMode.APPEND))) { - // do nothing and close this will kick of IndexFileDeleter which will remove all pending files + // do nothing and close this will kick off IndexFileDeleter which will remove all pending files } return si; } @@ -321,12 +324,7 @@ public static TopDocsAndMaxScore readTopDocs(StreamInput in) throws IOException } else if (type == 1) { TotalHits totalHits = readTotalHits(in); float maxScore = in.readFloat(); - - SortField[] fields = new SortField[in.readVInt()]; - for (int i = 0; i < fields.length; i++) { - fields[i] = readSortField(in); - } - + SortField[] fields = in.readArray(Lucene::readSortField, SortField[]::new); FieldDoc[] fieldDocs = new FieldDoc[in.readVInt()]; for (int i = 0; i < fieldDocs.length; i++) { fieldDocs[i] = readFieldDoc(in); @@ -337,10 +335,7 @@ public static TopDocsAndMaxScore readTopDocs(StreamInput in) throws IOException float maxScore = in.readFloat(); String field = in.readString(); - SortField[] fields = new SortField[in.readVInt()]; - for (int i = 0; i < fields.length; i++) { - fields[i] = readSortField(in); - } + SortField[] fields = in.readArray(Lucene::readSortField, SortField[]::new); int size = in.readVInt(); Object[] collapseValues = new Object[size]; FieldDoc[] fieldDocs = new FieldDoc[size]; @@ -385,7 +380,7 @@ public static FieldDoc readFieldDoc(StreamInput in) throws IOException { return new FieldDoc(in.readVInt(), in.readFloat(), cFields); } - private static Comparable readSortValue(StreamInput in) throws IOException { + public static Comparable readSortValue(StreamInput in) throws IOException { byte type = in.readByte(); if (type == 0) { return null; @@ -436,11 +431,7 @@ public static void writeTopDocs(StreamOutput out, TopDocsAndMaxScore topDocs) th out.writeFloat(topDocs.maxScore); out.writeString(collapseDocs.field); - - out.writeVInt(collapseDocs.fields.length); - for (SortField sortField : collapseDocs.fields) { - writeSortField(out, sortField); - } + out.writeArray(Lucene::writeSortField, collapseDocs.fields); out.writeVInt(topDocs.topDocs.scoreDocs.length); for (int i = 0; i < topDocs.topDocs.scoreDocs.length; i++) { @@ -455,10 +446,7 @@ public static void writeTopDocs(StreamOutput out, TopDocsAndMaxScore topDocs) th writeTotalHits(out, topDocs.topDocs.totalHits); out.writeFloat(topDocs.maxScore); - out.writeVInt(topFieldDocs.fields.length); - for (SortField sortField : topFieldDocs.fields) { - writeSortField(out, sortField); - } + out.writeArray(Lucene::writeSortField, topFieldDocs.fields); out.writeVInt(topDocs.topDocs.scoreDocs.length); for (ScoreDoc doc : topFieldDocs.scoreDocs) { @@ -501,8 +489,7 @@ private static Object readMissingValue(StreamInput in) throws IOException { } } - - private static void writeSortValue(StreamOutput out, Object field) throws IOException { + public static void writeSortValue(StreamOutput out, Object field) throws IOException { if (field == null) { out.writeByte((byte) 0); } else { @@ -687,11 +674,7 @@ public static void writeExplanation(StreamOutput out, Explanation explanation) t } } - private Lucene() { - - } - - public static final boolean indexExists(final Directory directory) throws IOException { + public static boolean indexExists(final Directory directory) throws IOException { return DirectoryReader.indexExists(directory); } @@ -701,7 +684,7 @@ public static final boolean indexExists(final Directory directory) throws IOExce * * Will retry the directory every second for at least {@code timeLimitMillis} */ - public static final boolean waitForIndex(final Directory directory, final long timeLimitMillis) + public static boolean waitForIndex(final Directory directory, final long timeLimitMillis) throws IOException { final long DELAY = 1000; long waited = 0; @@ -1070,7 +1053,7 @@ protected void doClose() { } public LeafMetaData getMetaData() { - return new LeafMetaData(Version.LATEST.major, Version.LATEST, (Sort)null); + return new LeafMetaData(Version.LATEST.major, Version.LATEST, null); } public CacheHelper getCoreCacheHelper() { diff --git a/server/src/main/java/org/elasticsearch/search/SearchHits.java b/server/src/main/java/org/elasticsearch/search/SearchHits.java index 0f1cb7f11e1af..2a432fc97fe09 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchHits.java +++ b/server/src/main/java/org/elasticsearch/search/SearchHits.java @@ -19,8 +19,10 @@ package org.elasticsearch.search; +import org.apache.lucene.search.SortField; import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits.Relation; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -56,14 +58,29 @@ public static SearchHits empty() { private float maxScore; + @Nullable + private SortField[] sortFields; + @Nullable + private String collapseField; + @Nullable + private Object[] collapseValues; + SearchHits() { } public SearchHits(SearchHit[] hits, @Nullable TotalHits totalHits, float maxScore) { + this(hits, totalHits, maxScore, null, null, null); + } + + public SearchHits(SearchHit[] hits, @Nullable TotalHits totalHits, float maxScore, @Nullable SortField[] sortFields, + @Nullable String collapseField, @Nullable Object[] collapseValues) { this.hits = hits; this.totalHits = totalHits == null ? null : new Total(totalHits); this.maxScore = maxScore; + this.sortFields = sortFields; + this.collapseField = collapseField; + this.collapseValues = collapseValues; } /** @@ -74,7 +91,6 @@ public TotalHits getTotalHits() { return totalHits == null ? null : totalHits.in; } - /** * The maximum score of this query. */ @@ -96,6 +112,31 @@ public SearchHit getAt(int position) { return hits[position]; } + /** + * In case documents were sorted by field(s), returns information about such field(s), null otherwise + * @see SortField + */ + @Nullable + public SortField[] getSortFields() { + return sortFields; + } + + /** + * In case field collapsing was performed, returns the field used for field collapsing, null otherwise + */ + @Nullable + public String getCollapseField() { + return collapseField; + } + + /** + * In case field collapsing was performed, returns the values of the field that field collapsing was performed on, null otherwise + */ + @Nullable + public Object[] getCollapseValues() { + return collapseValues; + } + @Override public Iterator iterator() { return Arrays.stream(getHits()).iterator(); @@ -175,8 +216,7 @@ public static SearchHits fromXContent(XContentParser parser) throws IOException } } } - SearchHits searchHits = new SearchHits(hits.toArray(new SearchHit[hits.size()]), totalHits, maxScore); - return searchHits; + return new SearchHits(hits.toArray(new SearchHit[0]), totalHits, maxScore); } public static SearchHits readSearchHits(StreamInput in) throws IOException { @@ -203,6 +243,12 @@ public void readFrom(StreamInput in) throws IOException { hits[i] = SearchHit.readSearchHit(in); } } + //TODO update version once backported + if (in.getVersion().onOrAfter(Version.V_7_0_0)) { + sortFields = in.readOptionalArray(Lucene::readSortField, SortField[]::new); + collapseField = in.readOptionalString(); + collapseValues = in.readOptionalArray(Lucene::readSortValue, Object[]::new); + } } @Override @@ -219,6 +265,12 @@ public void writeTo(StreamOutput out) throws IOException { hit.writeTo(out); } } + //TODO update version once backported + if (out.getVersion().onOrAfter(Version.V_7_0_0)) { + out.writeOptionalArray(Lucene::writeSortField, sortFields); + out.writeOptionalString(collapseField); + out.writeOptionalArray(Lucene::writeSortValue, collapseValues); + } } @Override @@ -229,12 +281,16 @@ public boolean equals(Object obj) { SearchHits other = (SearchHits) obj; return Objects.equals(totalHits, other.totalHits) && Objects.equals(maxScore, other.maxScore) - && Arrays.equals(hits, other.hits); + && Arrays.equals(hits, other.hits) + && Arrays.equals(sortFields, other.sortFields) + && Objects.equals(collapseField, other.collapseField) + && Arrays.equals(collapseValues, other.collapseValues); } @Override public int hashCode() { - return Objects.hash(totalHits, maxScore, Arrays.hashCode(hits)); + return Objects.hash(totalHits, maxScore, Arrays.hashCode(hits), + Arrays.hashCode(sortFields), collapseField, Arrays.hashCode(collapseValues)); } public static TotalHits parseTotalHitsFragment(XContentParser parser) throws IOException { diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java index c1a170c69ede6..4a8afe22b18aa 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java @@ -20,10 +20,16 @@ package org.elasticsearch.action.search; import com.carrotsearch.randomizedtesting.RandomizedContext; +import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.SortField; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits.Relation; +import org.apache.lucene.search.grouping.CollapseTopFieldDocs; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.text.Text; import org.elasticsearch.common.util.BigArrays; @@ -47,7 +53,6 @@ import org.elasticsearch.test.ESTestCase; import org.junit.Before; -import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -138,7 +143,7 @@ private AtomicArray generateSeededQueryResults(long seed, int () -> generateQueryResults(nShards, suggestions, searchHitsSize, useConstantScore)); } - public void testMerge() throws IOException { + public void testMerge() { List suggestions = new ArrayList<>(); int maxSuggestSize = 0; for (int i = 0; i < randomIntBetween(1, 5); i++) { @@ -152,8 +157,8 @@ public void testMerge() throws IOException { for (boolean trackTotalHits : new boolean[] {true, false}) { SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedQueryPhase(queryResults.asList(), false, trackTotalHits); - AtomicArray searchPhaseResultAtomicArray = generateFetchResults(nShards, reducedQueryPhase.scoreDocs, - reducedQueryPhase.suggest); + AtomicArray searchPhaseResultAtomicArray = generateFetchResults(nShards, + reducedQueryPhase.sortedTopDocs.scoreDocs, reducedQueryPhase.suggest); InternalSearchResponse mergedResponse = searchPhaseController.merge(false, reducedQueryPhase, searchPhaseResultAtomicArray.asList(), searchPhaseResultAtomicArray::get); @@ -166,7 +171,7 @@ public void testMerge() throws IOException { suggestSize += stream.collect(Collectors.summingInt(e -> e.getOptions().size())); } assertThat(suggestSize, lessThanOrEqualTo(maxSuggestSize)); - assertThat(mergedResponse.hits().getHits().length, equalTo(reducedQueryPhase.scoreDocs.length - suggestSize)); + assertThat(mergedResponse.hits().getHits().length, equalTo(reducedQueryPhase.sortedTopDocs.scoreDocs.length - suggestSize)); Suggest suggestResult = mergedResponse.suggest(); for (Suggest.Suggestion suggestion : reducedQueryPhase.suggest) { assertThat(suggestion, instanceOf(CompletionSuggestion.class)); @@ -183,24 +188,24 @@ public void testMerge() throws IOException { } } - private AtomicArray generateQueryResults(int nShards, + private static AtomicArray generateQueryResults(int nShards, List suggestions, int searchHitsSize, boolean useConstantScore) { AtomicArray queryResults = new AtomicArray<>(nShards); for (int shardIndex = 0; shardIndex < nShards; shardIndex++) { QuerySearchResult querySearchResult = new QuerySearchResult(shardIndex, new SearchShardTarget("", new Index("", ""), shardIndex, null)); - TopDocs topDocs = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); + final TopDocs topDocs; float maxScore = 0; - if (searchHitsSize > 0) { + if (searchHitsSize == 0) { + topDocs = Lucene.EMPTY_TOP_DOCS; + } else { int nDocs = randomIntBetween(0, searchHitsSize); ScoreDoc[] scoreDocs = new ScoreDoc[nDocs]; for (int i = 0; i < nDocs; i++) { float score = useConstantScore ? 1.0F : Math.abs(randomFloat()); scoreDocs[i] = new ScoreDoc(i, score); - if (score > maxScore) { - maxScore = score; - } + maxScore = Math.max(score, maxScore); } topDocs = new TopDocs(new TotalHits(scoreDocs.length, TotalHits.Relation.EQUAL_TO), scoreDocs); } @@ -283,7 +288,7 @@ private AtomicArray generateFetchResults(int nShards, ScoreDo } } } - SearchHit[] hits = searchHits.toArray(new SearchHit[searchHits.size()]); + SearchHit[] hits = searchHits.toArray(new SearchHit[0]); fetchSearchResult.hits(new SearchHits(hits, new TotalHits(hits.length, Relation.EQUAL_TO), maxScore)); fetchResults.set(shardIndex, fetchSearchResult); } @@ -336,6 +341,10 @@ public void testConsumer() { assertEquals(numTotalReducePhases, reduce.numReducePhases); InternalMax max = (InternalMax) reduce.aggregations.asList().get(0); assertEquals(3.0D, max.getValue(), 0.0D); + assertFalse(reduce.sortedTopDocs.isSortedByField); + assertNull(reduce.sortedTopDocs.sortFields); + assertNull(reduce.sortedTopDocs.collapseField); + assertNull(reduce.sortedTopDocs.collapseValues); } public void testConsumerConcurrently() throws InterruptedException { @@ -374,13 +383,17 @@ public void testConsumerConcurrently() throws InterruptedException { SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); InternalMax internalMax = (InternalMax) reduce.aggregations.asList().get(0); assertEquals(max.get(), internalMax.getValue(), 0.0D); - assertEquals(1, reduce.scoreDocs.length); + assertEquals(1, reduce.sortedTopDocs.scoreDocs.length); assertEquals(max.get(), reduce.maxScore, 0.0f); assertEquals(expectedNumResults, reduce.totalHits.value); - assertEquals(max.get(), reduce.scoreDocs[0].score, 0.0f); + assertEquals(max.get(), reduce.sortedTopDocs.scoreDocs[0].score, 0.0f); + assertFalse(reduce.sortedTopDocs.isSortedByField); + assertNull(reduce.sortedTopDocs.sortFields); + assertNull(reduce.sortedTopDocs.collapseField); + assertNull(reduce.sortedTopDocs.collapseValues); } - public void testConsumerOnlyAggs() throws InterruptedException { + public void testConsumerOnlyAggs() { int expectedNumResults = randomIntBetween(1, 100); int bufferSize = randomIntBetween(2, 200); SearchRequest request = new SearchRequest(); @@ -390,29 +403,31 @@ public void testConsumerOnlyAggs() throws InterruptedException { searchPhaseController.newSearchPhaseResults(request, expectedNumResults); AtomicInteger max = new AtomicInteger(); for (int i = 0; i < expectedNumResults; i++) { - int id = i; int number = randomIntBetween(1, 1000); max.updateAndGet(prev -> Math.max(prev, number)); - QuerySearchResult result = new QuerySearchResult(id, new SearchShardTarget("node", new Index("a", "b"), id, null)); + QuerySearchResult result = new QuerySearchResult(i, new SearchShardTarget("node", new Index("a", "b"), i, null)); result.topDocs(new TopDocsAndMaxScore(new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), number), new DocValueFormat[0]); InternalAggregations aggs = new InternalAggregations(Arrays.asList(new InternalMax("test", (double) number, DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap()))); result.aggregations(aggs); - result.setShardIndex(id); + result.setShardIndex(i); result.size(1); consumer.consumeResult(result); } SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); InternalMax internalMax = (InternalMax) reduce.aggregations.asList().get(0); assertEquals(max.get(), internalMax.getValue(), 0.0D); - assertEquals(0, reduce.scoreDocs.length); + assertEquals(0, reduce.sortedTopDocs.scoreDocs.length); assertEquals(max.get(), reduce.maxScore, 0.0f); assertEquals(expectedNumResults, reduce.totalHits.value); + assertFalse(reduce.sortedTopDocs.isSortedByField); + assertNull(reduce.sortedTopDocs.sortFields); + assertNull(reduce.sortedTopDocs.collapseField); + assertNull(reduce.sortedTopDocs.collapseValues); } - - public void testConsumerOnlyHits() throws InterruptedException { + public void testConsumerOnlyHits() { int expectedNumResults = randomIntBetween(1, 100); int bufferSize = randomIntBetween(2, 200); SearchRequest request = new SearchRequest(); @@ -424,24 +439,26 @@ public void testConsumerOnlyHits() throws InterruptedException { searchPhaseController.newSearchPhaseResults(request, expectedNumResults); AtomicInteger max = new AtomicInteger(); for (int i = 0; i < expectedNumResults; i++) { - int id = i; int number = randomIntBetween(1, 1000); max.updateAndGet(prev -> Math.max(prev, number)); - QuerySearchResult result = new QuerySearchResult(id, new SearchShardTarget("node", new Index("a", "b"), id, null)); + QuerySearchResult result = new QuerySearchResult(i, new SearchShardTarget("node", new Index("a", "b"), i, null)); result.topDocs(new TopDocsAndMaxScore(new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] {new ScoreDoc(0, number)}), number), new DocValueFormat[0]); - result.setShardIndex(id); + result.setShardIndex(i); result.size(1); consumer.consumeResult(result); } SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); - assertEquals(1, reduce.scoreDocs.length); + assertEquals(1, reduce.sortedTopDocs.scoreDocs.length); assertEquals(max.get(), reduce.maxScore, 0.0f); assertEquals(expectedNumResults, reduce.totalHits.value); - assertEquals(max.get(), reduce.scoreDocs[0].score, 0.0f); + assertEquals(max.get(), reduce.sortedTopDocs.scoreDocs[0].score, 0.0f); + assertFalse(reduce.sortedTopDocs.isSortedByField); + assertNull(reduce.sortedTopDocs.sortFields); + assertNull(reduce.sortedTopDocs.collapseField); + assertNull(reduce.sortedTopDocs.collapseValues); } - public void testNewSearchPhaseResults() { for (int i = 0; i < 10; i++) { int expectedNumResults = randomIntBetween(1, 10); @@ -497,15 +514,87 @@ public void testReduceTopNWithFromOffset() { consumer.consumeResult(result); } // 4*3 results = 12 we get result 5 to 10 here with from=5 and size=5 - SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); - assertEquals(5, reduce.scoreDocs.length); + ScoreDoc[] scoreDocs = reduce.sortedTopDocs.scoreDocs; + assertEquals(5, scoreDocs.length); assertEquals(100.f, reduce.maxScore, 0.0f); assertEquals(12, reduce.totalHits.value); - assertEquals(95.0f, reduce.scoreDocs[0].score, 0.0f); - assertEquals(94.0f, reduce.scoreDocs[1].score, 0.0f); - assertEquals(93.0f, reduce.scoreDocs[2].score, 0.0f); - assertEquals(92.0f, reduce.scoreDocs[3].score, 0.0f); - assertEquals(91.0f, reduce.scoreDocs[4].score, 0.0f); + assertEquals(95.0f, scoreDocs[0].score, 0.0f); + assertEquals(94.0f, scoreDocs[1].score, 0.0f); + assertEquals(93.0f, scoreDocs[2].score, 0.0f); + assertEquals(92.0f, scoreDocs[3].score, 0.0f); + assertEquals(91.0f, scoreDocs[4].score, 0.0f); + } + + public void testConsumerSortByField() { + int expectedNumResults = randomIntBetween(1, 100); + int bufferSize = randomIntBetween(2, 200); + SearchRequest request = new SearchRequest(); + int size = randomIntBetween(1, 10); + request.setBatchedReduceSize(bufferSize); + InitialSearchPhase.ArraySearchPhaseResults consumer = + searchPhaseController.newSearchPhaseResults(request, expectedNumResults); + AtomicInteger max = new AtomicInteger(); + SortField[] sortFields = {new SortField("field", SortField.Type.INT, true)}; + DocValueFormat[] docValueFormats = {DocValueFormat.RAW}; + for (int i = 0; i < expectedNumResults; i++) { + int number = randomIntBetween(1, 1000); + max.updateAndGet(prev -> Math.max(prev, number)); + FieldDoc[] fieldDocs = {new FieldDoc(0, Float.NaN, new Object[]{number})}; + TopDocs topDocs = new TopFieldDocs(new TotalHits(1, Relation.EQUAL_TO), fieldDocs, sortFields); + QuerySearchResult result = new QuerySearchResult(i, new SearchShardTarget("node", new Index("a", "b"), i, null)); + result.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), docValueFormats); + result.setShardIndex(i); + result.size(size); + consumer.consumeResult(result); + } + SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); + assertEquals(Math.min(expectedNumResults, size), reduce.sortedTopDocs.scoreDocs.length); + assertEquals(expectedNumResults, reduce.totalHits.value); + assertEquals(max.get(), ((FieldDoc)reduce.sortedTopDocs.scoreDocs[0]).fields[0]); + assertTrue(reduce.sortedTopDocs.isSortedByField); + assertEquals(1, reduce.sortedTopDocs.sortFields.length); + assertEquals("field", reduce.sortedTopDocs.sortFields[0].getField()); + assertEquals(SortField.Type.INT, reduce.sortedTopDocs.sortFields[0].getType()); + assertNull(reduce.sortedTopDocs.collapseField); + assertNull(reduce.sortedTopDocs.collapseValues); + } + + public void testConsumerFieldCollapsing() { + int expectedNumResults = randomIntBetween(30, 100); + int bufferSize = randomIntBetween(2, 200); + SearchRequest request = new SearchRequest(); + int size = randomIntBetween(5, 10); + request.setBatchedReduceSize(bufferSize); + InitialSearchPhase.ArraySearchPhaseResults consumer = + searchPhaseController.newSearchPhaseResults(request, expectedNumResults); + SortField[] sortFields = {new SortField("field", SortField.Type.STRING)}; + BytesRef a = new BytesRef("a"); + BytesRef b = new BytesRef("b"); + BytesRef c = new BytesRef("c"); + Object[] collapseValues = new Object[]{a, b, c}; + DocValueFormat[] docValueFormats = {DocValueFormat.RAW}; + for (int i = 0; i < expectedNumResults; i++) { + Object[] values = {randomFrom(collapseValues)}; + FieldDoc[] fieldDocs = {new FieldDoc(0, Float.NaN, values)}; + TopDocs topDocs = new CollapseTopFieldDocs("field", new TotalHits(1, Relation.EQUAL_TO), fieldDocs, sortFields, values); + QuerySearchResult result = new QuerySearchResult(i, new SearchShardTarget("node", new Index("a", "b"), i, null)); + result.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), docValueFormats); + result.setShardIndex(i); + result.size(size); + consumer.consumeResult(result); + } + SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); + assertEquals(3, reduce.sortedTopDocs.scoreDocs.length); + assertEquals(expectedNumResults, reduce.totalHits.value); + assertEquals(a, ((FieldDoc)reduce.sortedTopDocs.scoreDocs[0]).fields[0]); + assertEquals(b, ((FieldDoc)reduce.sortedTopDocs.scoreDocs[1]).fields[0]); + assertEquals(c, ((FieldDoc)reduce.sortedTopDocs.scoreDocs[2]).fields[0]); + assertTrue(reduce.sortedTopDocs.isSortedByField); + assertEquals(1, reduce.sortedTopDocs.sortFields.length); + assertEquals("field", reduce.sortedTopDocs.sortFields[0].getField()); + assertEquals(SortField.Type.STRING, reduce.sortedTopDocs.sortFields[0].getType()); + assertEquals("field", reduce.sortedTopDocs.collapseField); + assertArrayEquals(collapseValues, reduce.sortedTopDocs.collapseValues); } } diff --git a/server/src/test/java/org/elasticsearch/common/io/stream/StreamTests.java b/server/src/test/java/org/elasticsearch/common/io/stream/StreamTests.java index 6431a3469b6b0..05cc442c48e9e 100644 --- a/server/src/test/java/org/elasticsearch/common/io/stream/StreamTests.java +++ b/server/src/test/java/org/elasticsearch/common/io/stream/StreamTests.java @@ -213,7 +213,6 @@ public void testReadArraySize() throws IOException { } public void testWritableArrays() throws IOException { - final String[] strings = generateRandomStringArray(10, 10, false, true); WriteableString[] sourceArray = Arrays.stream(strings).map(WriteableString::new).toArray(WriteableString[]::new); WriteableString[] targetArray; @@ -233,6 +232,28 @@ public void testWritableArrays() throws IOException { assertThat(targetArray, equalTo(sourceArray)); } + public void testArrays() throws IOException { + final String[] strings; + final String[] deserialized; + Writeable.Writer writer = StreamOutput::writeString; + Writeable.Reader reader = StreamInput::readString; + BytesStreamOutput out = new BytesStreamOutput(); + if (randomBoolean()) { + if (randomBoolean()) { + strings = null; + } else { + strings = generateRandomStringArray(10, 10, false, true); + } + out.writeOptionalArray(writer, strings); + deserialized = out.bytes().streamInput().readOptionalArray(reader, String[]::new); + } else { + strings = generateRandomStringArray(10, 10, false, true); + out.writeArray(writer, strings); + deserialized = out.bytes().streamInput().readArray(reader, String[]::new); + } + assertThat(deserialized, equalTo(strings)); + } + public void testSetOfLongs() throws IOException { final int size = randomIntBetween(0, 6); final Set sourceSet = new HashSet<>(size); diff --git a/server/src/test/java/org/elasticsearch/common/lucene/LuceneTests.java b/server/src/test/java/org/elasticsearch/common/lucene/LuceneTests.java index b677247f266cd..ea894a2edd09a 100644 --- a/server/src/test/java/org/elasticsearch/common/lucene/LuceneTests.java +++ b/server/src/test/java/org/elasticsearch/common/lucene/LuceneTests.java @@ -23,6 +23,7 @@ import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.Field.Store; +import org.apache.lucene.document.LatLonDocValuesField; import org.apache.lucene.document.StringField; import org.apache.lucene.document.TextField; import org.apache.lucene.index.DirectoryReader; @@ -37,8 +38,12 @@ import org.apache.lucene.index.Term; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; -import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.SortedNumericSortField; +import org.apache.lucene.search.SortedSetSelector; +import org.apache.lucene.search.SortedSetSortField; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.Weight; @@ -46,8 +51,18 @@ import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.store.MockDirectoryWrapper; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.core.internal.io.IOUtils; +import org.elasticsearch.index.fielddata.IndexFieldData; +import org.elasticsearch.index.fielddata.fieldcomparator.BytesRefFieldComparatorSource; +import org.elasticsearch.index.fielddata.fieldcomparator.DoubleValuesComparatorSource; +import org.elasticsearch.index.fielddata.fieldcomparator.FloatValuesComparatorSource; +import org.elasticsearch.index.fielddata.fieldcomparator.LongValuesComparatorSource; +import org.elasticsearch.search.MultiValueMode; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.VersionUtils; import java.io.IOException; import java.io.StringReader; @@ -62,6 +77,8 @@ import static org.hamcrest.Matchers.equalTo; public class LuceneTests extends ESTestCase { + private static final NamedWriteableRegistry EMPTY_REGISTRY = new NamedWriteableRegistry(Collections.emptyList()); + public void testWaitForIndex() throws Exception { final MockDirectoryWrapper dir = newMockDirectory(); @@ -498,4 +515,148 @@ public void testWrapLiveDocsNotExposeAbortedDocuments() throws Exception { } IOUtils.close(writer, dir); } + + public void testSortFieldSerialization() throws IOException { + Tuple sortFieldTuple = randomSortField(); + SortField deserialized = copyInstance(sortFieldTuple.v1(), EMPTY_REGISTRY, Lucene::writeSortField, Lucene::readSortField, + VersionUtils.randomVersion(random())); + assertEquals(sortFieldTuple.v2(), deserialized); + } + + public void testSortValueSerialization() throws IOException { + Object sortValue = randomSortValue(); + Object deserialized = copyInstance(sortValue, EMPTY_REGISTRY, Lucene::writeSortValue, Lucene::readSortValue, + VersionUtils.randomVersion(random())); + assertEquals(sortValue, deserialized); + } + + public static Object randomSortValue() { + switch(randomIntBetween(0, 8)) { + case 0: + return randomAlphaOfLengthBetween(3, 10); + case 1: + return randomInt(); + case 2: + return randomLong(); + case 3: + return randomFloat(); + case 4: + return randomDouble(); + case 5: + return randomByte(); + case 6: + return randomShort(); + case 7: + return randomBoolean(); + case 8: + return new BytesRef(randomAlphaOfLengthBetween(3, 10)); + default: + throw new UnsupportedOperationException(); + } + } + + public static Tuple randomSortField() { + switch(randomIntBetween(0, 2)) { + case 0: + return randomSortFieldCustomComparatorSource(); + case 1: + return randomCustomSortField(); + case 2: + String field = randomAlphaOfLengthBetween(3, 10); + SortField.Type type = randomFrom(SortField.Type.values()); + if ((type == SortField.Type.SCORE || type == SortField.Type.DOC) && randomBoolean()) { + field = null; + } + SortField sortField = new SortField(field, type, randomBoolean()); + Object missingValue = randomMissingValue(sortField.getType()); + if (missingValue != null) { + sortField.setMissingValue(missingValue); + } + return Tuple.tuple(sortField, sortField); + default: + throw new UnsupportedOperationException(); + } + } + + private static Tuple randomSortFieldCustomComparatorSource() { + String field = randomAlphaOfLengthBetween(3, 10); + IndexFieldData.XFieldComparatorSource comparatorSource; + boolean reverse = randomBoolean(); + Object missingValue = null; + switch(randomIntBetween(0, 3)) { + case 0: + comparatorSource = new LongValuesComparatorSource(null, randomBoolean() ? randomLong() : null, + randomFrom(MultiValueMode.values()), null); + break; + case 1: + comparatorSource = new DoubleValuesComparatorSource(null, randomBoolean() ? randomDouble() : null, + randomFrom(MultiValueMode.values()), null); + break; + case 2: + comparatorSource = new FloatValuesComparatorSource(null, randomBoolean() ? randomFloat() : null, + randomFrom(MultiValueMode.values()), null); + break; + case 3: + comparatorSource = new BytesRefFieldComparatorSource(null, + randomBoolean() ? "_first" : "_last", randomFrom(MultiValueMode.values()), null); + missingValue = comparatorSource.missingValue(reverse); + break; + default: + throw new UnsupportedOperationException(); + } + SortField sortField = new SortField(field, comparatorSource, reverse); + SortField expected = new SortField(field, comparatorSource.reducedType(), reverse); + expected.setMissingValue(missingValue); + return Tuple.tuple(sortField, expected); + } + + private static Tuple randomCustomSortField() { + String field = randomAlphaOfLengthBetween(3, 10); + switch(randomIntBetween(0, 2)) { + case 0: { + SortField sortField = LatLonDocValuesField.newDistanceSort(field, 0, 0); + SortField expected = new SortField(field, SortField.Type.DOUBLE); + expected.setMissingValue(Double.POSITIVE_INFINITY); + return Tuple.tuple(sortField, expected); + } + case 1: { + SortedSetSortField sortField = new SortedSetSortField(field, randomBoolean(), randomFrom(SortedSetSelector.Type.values())); + SortField expected = new SortField(sortField.getField(), SortField.Type.STRING, sortField.getReverse()); + Object missingValue = randomMissingValue(SortField.Type.STRING); + sortField.setMissingValue(missingValue); + expected.setMissingValue(missingValue); + return Tuple.tuple(sortField, expected); + } + case 2: { + SortField.Type type = randomFrom(SortField.Type.DOUBLE, SortField.Type.INT, SortField.Type.FLOAT, SortField.Type.LONG); + SortedNumericSortField sortField = new SortedNumericSortField(field, type, randomBoolean()); + SortField expected = new SortField(sortField.getField(), sortField.getNumericType(), sortField.getReverse()); + Object missingValue = randomMissingValue(type); + if (missingValue != null) { + sortField.setMissingValue(missingValue); + expected.setMissingValue(missingValue); + } + return Tuple.tuple(sortField, expected); + } + default: + throw new UnsupportedOperationException(); + } + } + + private static Object randomMissingValue(SortField.Type type) { + switch(type) { + case INT: + return randomInt(); + case FLOAT: + return randomFloat(); + case DOUBLE: + return randomDouble(); + case LONG: + return randomLong(); + case STRING: + return randomBoolean() ? SortField.STRING_FIRST : SortField.STRING_LAST; + default: + return null; + } + } } diff --git a/server/src/test/java/org/elasticsearch/search/SearchHitTests.java b/server/src/test/java/org/elasticsearch/search/SearchHitTests.java index 8ea7af82b90da..f64b98502be6e 100644 --- a/server/src/test/java/org/elasticsearch/search/SearchHitTests.java +++ b/server/src/test/java/org/elasticsearch/search/SearchHitTests.java @@ -19,15 +19,6 @@ package org.elasticsearch.search; -import java.io.IOException; -import java.io.InputStream; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Predicate; - import org.apache.lucene.search.Explanation; import org.apache.lucene.search.TotalHits; import org.elasticsearch.action.OriginalIndices; @@ -52,6 +43,15 @@ import org.elasticsearch.test.AbstractStreamableTestCase; import org.elasticsearch.test.RandomObjects; +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; + import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.XContentTestUtils.insertRandomFields; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; @@ -65,7 +65,7 @@ public static SearchHit createTestItem(boolean withOptionalInnerHits, boolean wi return createTestItem(randomFrom(XContentType.values()), withOptionalInnerHits, withShardTarget); } - public static SearchHit createTestItem(XContentType xContentType, boolean withOptionalInnerHits, boolean withShardTarget) { + public static SearchHit createTestItem(XContentType xContentType, boolean withOptionalInnerHits, boolean transportSerialization) { int internalId = randomInt(); String uid = randomAlphaOfLength(10); Text type = new Text(randomAlphaOfLengthBetween(5, 10)); @@ -120,12 +120,12 @@ public static SearchHit createTestItem(XContentType xContentType, boolean withOp Map innerHits = new HashMap<>(innerHitsSize); for (int i = 0; i < innerHitsSize; i++) { innerHits.put(randomAlphaOfLength(5), - SearchHitsTests.createTestItem(xContentType, false, withShardTarget)); + SearchHitsTests.createTestItem(xContentType, false, transportSerialization)); } hit.setInnerHits(innerHits); } } - if (withShardTarget && randomBoolean()) { + if (transportSerialization && randomBoolean()) { String index = randomAlphaOfLengthBetween(5, 10); String clusterAlias = randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10); hit.shard(new SearchShardTarget(randomAlphaOfLengthBetween(5, 10), diff --git a/server/src/test/java/org/elasticsearch/search/SearchHitsTests.java b/server/src/test/java/org/elasticsearch/search/SearchHitsTests.java index 05ad84a4cc270..0bc9a72af7871 100644 --- a/server/src/test/java/org/elasticsearch/search/SearchHitsTests.java +++ b/server/src/test/java/org/elasticsearch/search/SearchHitsTests.java @@ -19,11 +19,15 @@ package org.elasticsearch.search; +import org.apache.lucene.search.SortField; import org.apache.lucene.search.TotalHits; import org.apache.lucene.util.TestUtil; +import org.elasticsearch.Version; import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.lucene.LuceneTests; import org.elasticsearch.common.text.Text; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.ToXContent; @@ -34,41 +38,75 @@ import org.elasticsearch.index.Index; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.test.AbstractStreamableXContentTestCase; +import org.elasticsearch.test.VersionUtils; import java.io.IOException; +import java.util.Base64; import java.util.Collections; import java.util.function.Predicate; public class SearchHitsTests extends AbstractStreamableXContentTestCase { + public static SearchHits createTestItem(boolean withOptionalInnerHits, boolean withShardTarget) { return createTestItem(randomFrom(XContentType.values()), withOptionalInnerHits, withShardTarget); } private static SearchHit[] createSearchHitArray(int size, XContentType xContentType, boolean withOptionalInnerHits, - boolean withShardTarget) { + boolean transportSerialization) { SearchHit[] hits = new SearchHit[size]; for (int i = 0; i < hits.length; i++) { - hits[i] = SearchHitTests.createTestItem(xContentType, withOptionalInnerHits, withShardTarget); + hits[i] = SearchHitTests.createTestItem(xContentType, withOptionalInnerHits, transportSerialization); } return hits; } - private static TotalHits randomTotalHits() { + private static TotalHits randomTotalHits(TotalHits.Relation relation) { long totalHits = TestUtil.nextLong(random(), 0, Long.MAX_VALUE); - TotalHits.Relation relation = randomFrom(TotalHits.Relation.values()); return new TotalHits(totalHits, relation); } - public static SearchHits createTestItem(XContentType xContentType, boolean withOptionalInnerHits, boolean withShardTarget) { + public static SearchHits createTestItem(XContentType xContentType, boolean withOptionalInnerHits, boolean transportSerialization) { + return createTestItem(xContentType, withOptionalInnerHits, transportSerialization, randomFrom(TotalHits.Relation.values())); + } + + private static SearchHits createTestItem(XContentType xContentType, boolean withOptionalInnerHits, boolean transportSerialization, + TotalHits.Relation totalHitsRelation) { int searchHits = randomIntBetween(0, 5); - SearchHit[] hits = createSearchHitArray(searchHits, xContentType, withOptionalInnerHits, withShardTarget); + SearchHit[] hits = createSearchHitArray(searchHits, xContentType, withOptionalInnerHits, transportSerialization); + TotalHits totalHits = frequently() ? randomTotalHits(totalHitsRelation) : null; float maxScore = frequently() ? randomFloat() : Float.NaN; - return new SearchHits(hits, frequently() ? randomTotalHits() : null, maxScore); + SortField[] sortFields = null; + String collapseField = null; + Object[] collapseValues = null; + if (transportSerialization) { + sortFields = randomBoolean() ? createSortFields(randomIntBetween(1, 5)) : null; + collapseField = randomAlphaOfLengthBetween(5, 10); + collapseValues = randomBoolean() ? createCollapseValues(randomIntBetween(1, 10)) : null; + } + return new SearchHits(hits, totalHits, maxScore, sortFields, collapseField, collapseValues); + } + + private static SortField[] createSortFields(int size) { + SortField[] sortFields = new SortField[size]; + for (int i = 0; i < sortFields.length; i++) { + //sort fields are simplified before serialization, we write directly the simplified version + //otherwise equality comparisons become complicated + sortFields[i] = LuceneTests.randomSortField().v2(); + } + return sortFields; + } + + private static Object[] createCollapseValues(int size) { + Object[] collapseValues = new Object[size]; + for (int i = 0; i < collapseValues.length; i++) { + collapseValues[i] = LuceneTests.randomSortValue(); + } + return collapseValues; } @Override protected SearchHits mutateInstance(SearchHits instance) { - switch (randomIntBetween(0, 2)) { + switch (randomIntBetween(0, 5)) { case 0: return new SearchHits(createSearchHitArray(instance.getHits().length + 1, randomFrom(XContentType.values()), false, randomBoolean()), @@ -76,7 +114,7 @@ protected SearchHits mutateInstance(SearchHits instance) { case 1: final TotalHits totalHits; if (instance.getTotalHits() == null) { - totalHits = randomTotalHits(); + totalHits = randomTotalHits(randomFrom(TotalHits.Relation.values())); } else { totalHits = null; } @@ -89,6 +127,33 @@ protected SearchHits mutateInstance(SearchHits instance) { maxScore = Float.NaN; } return new SearchHits(instance.getHits(), instance.getTotalHits(), maxScore); + case 3: + SortField[] sortFields; + if (instance.getSortFields() == null) { + sortFields = createSortFields(randomIntBetween(1, 5)); + } else { + sortFields = randomBoolean() ? createSortFields(instance.getSortFields().length + 1) : null; + } + return new SearchHits(instance.getHits(), instance.getTotalHits(), instance.getMaxScore(), + sortFields, instance.getCollapseField(), instance.getCollapseValues()); + case 4: + String collapseField; + if (instance.getCollapseField() == null) { + collapseField = randomAlphaOfLengthBetween(5, 10); + } else { + collapseField = randomBoolean() ? instance.getCollapseField() + randomAlphaOfLengthBetween(2, 5) : null; + } + return new SearchHits(instance.getHits(), instance.getTotalHits(), instance.getMaxScore(), + instance.getSortFields(), collapseField, instance.getCollapseValues()); + case 5: + Object[] collapseValues; + if (instance.getCollapseValues() == null) { + collapseValues = createCollapseValues(randomIntBetween(1, 5)); + } else { + collapseValues = randomBoolean() ? createCollapseValues(instance.getCollapseValues().length) : null; + } + return new SearchHits(instance.getHits(), instance.getTotalHits(), instance.getMaxScore(), + instance.getSortFields(), instance.getCollapseField(), collapseValues); default: throw new UnsupportedOperationException(); } @@ -125,7 +190,7 @@ protected SearchHits createXContextTestInstance(XContentType xContentType) { // deserialized hit cannot be equal to the original instance. // There is another test (#testFromXContentWithShards) that checks the // rest serialization with shard targets. - return createTestItem(xContentType,true, false); + return createTestItem(xContentType, true, false); } @Override @@ -205,4 +270,40 @@ public void testFromXContentWithShards() throws IOException { } } + + //TODO rename method and adapt versions after backport + public void testReadFromPre70() throws IOException { + try (StreamInput in = StreamInput.wrap(Base64.getDecoder().decode("AQC/gAAAAAA="))) { + in.setVersion(VersionUtils.randomVersionBetween(random(), Version.V_6_0_0, VersionUtils.getPreviousVersion(Version.V_7_0_0))); + SearchHits searchHits = new SearchHits(); + searchHits.readFrom(in); + assertEquals(0, searchHits.getHits().length); + assertNotNull(searchHits.getTotalHits()); + assertEquals(0L, searchHits.getTotalHits().value); + assertEquals(TotalHits.Relation.EQUAL_TO, searchHits.getTotalHits().relation); + assertEquals(-1F, searchHits.getMaxScore(), 0F); + assertNull(searchHits.getSortFields()); + assertNull(searchHits.getCollapseField()); + assertNull(searchHits.getCollapseValues()); + } + } + + //TODO rename method and adapt versions after backport + public void testSerializationPre70() throws IOException { + Version version = VersionUtils.randomVersionBetween(random(), Version.V_6_0_0, VersionUtils.getPreviousVersion(Version.V_7_0_0)); + SearchHits original = createTestItem(randomFrom(XContentType.values()), false, true, TotalHits.Relation.EQUAL_TO); + SearchHits deserialized = copyInstance(original, version); + assertArrayEquals(original.getHits(), deserialized.getHits()); + assertEquals(original.getMaxScore(), deserialized.getMaxScore(), 0F); + if (original.getTotalHits() == null) { + assertNull(deserialized.getTotalHits()); + } else { + assertNotNull(deserialized.getTotalHits()); + assertEquals(original.getTotalHits().value, deserialized.getTotalHits().value); + assertEquals(original.getTotalHits().relation, deserialized.getTotalHits().relation); + } + assertNull(deserialized.getSortFields()); + assertNull(deserialized.getCollapseField()); + assertNull(deserialized.getCollapseValues()); + } } diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java index 8e6a9e5e51b85..8cb35fa13db5d 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java @@ -1151,7 +1151,7 @@ public static T copyStreamable(T original, NamedWriteable Streamable.newWriteableReader(supplier), version); } - private static T copyInstance(T original, NamedWriteableRegistry namedWriteableRegistry, Writeable.Writer writer, + protected static T copyInstance(T original, NamedWriteableRegistry namedWriteableRegistry, Writeable.Writer writer, Writeable.Reader reader, Version version) throws IOException { try (BytesStreamOutput output = new BytesStreamOutput()) { output.setVersion(version);