From 8008a0b160d3548e687a3f7aa2e804040d7074c6 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Mon, 30 Dec 2024 17:48:40 -0800 Subject: [PATCH 01/13] Pagination in Hybrid query Signed-off-by: Varun Jain --- CHANGELOG.md | 1 + .../neuralsearch/bwc/HybridSearchIT.java | 2 - .../common/MinClusterVersionUtil.java | 5 + .../processor/NormalizationProcessor.java | 1 + .../NormalizationProcessorWorkflow.java | 71 +++++-- ...zationProcessorWorkflowExecuteRequest.java | 2 + .../combination/CombineScoresDto.java | 1 + .../processor/combination/ScoreCombiner.java | 10 +- .../neuralsearch/query/HybridQuery.java | 15 +- .../query/HybridQueryBuilder.java | 44 ++++- .../search/query/HybridCollectorManager.java | 46 ++++- .../query/HybridQueryPhaseSearcher.java | 10 +- .../NormalizationProcessorTests.java | 5 +- .../NormalizationProcessorWorkflowTests.java | 128 +++++++++++-- .../query/HybridQueryBuilderTests.java | 115 +++++++++++ .../neuralsearch/query/HybridQueryIT.java | 180 ++++++++++++++++-- .../neuralsearch/query/HybridQueryTests.java | 44 +++-- .../query/HybridQueryWeightTests.java | 9 +- .../HybridAggregationProcessorTests.java | 4 +- .../query/HybridCollectorManagerTests.java | 96 ++++++++-- .../query/HybridQueryPhaseSearcherTests.java | 28 +++ .../util/HybridQueryUtilTests.java | 38 +++- .../neuralsearch/BaseNeuralSearchIT.java | 2 - 23 files changed, 765 insertions(+), 92 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5345d416f..7751478d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x) ### Features +- Pagination in Hybrid query ([#963](https://github.com/opensearch-project/neural-search/pull/963)) ### Enhancements - Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970)) - Support new knn query parameter expand_nested 
([#1013](https://github.com/opensearch-project/neural-search/pull/1013)) diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java index d4ae88a3f..d2657555d 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java @@ -13,8 +13,6 @@ import org.opensearch.index.query.MatchQueryBuilder; -import static org.opensearch.knn.index.query.KNNQueryBuilder.EXPAND_NESTED_FIELD; -import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion; import static org.opensearch.neuralsearch.util.TestUtils.getModelId; import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; diff --git a/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java b/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java index 0f5cbefcf..05e04e84a 100644 --- a/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java @@ -22,6 +22,7 @@ public final class MinClusterVersionUtil { private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0; private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0; + private static final Version MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY = Version.V_2_19_0; // Note this minimal version will act as a override private static final Map MINIMAL_VERSION_NEURAL = ImmutableMap.builder() @@ -38,6 +39,10 @@ public static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() { return 
NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH); } + public static boolean isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery() { + return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY); + } + public static boolean isClusterOnOrAfterMinReqVersion(String key) { Version version; if (MINIMAL_VERSION_NEURAL.containsKey(key)) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index d2008ae97..d2fa03fde 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -93,6 +93,7 @@ private void prepareAndExecuteNormalizationWo .combinationTechnique(combinationTechnique) .explain(explain) .pipelineProcessingContext(requestContextOptional.orElse(null)) + .searchPhaseContext(searchPhaseContext) .build(); normalizationWorkflow.execute(request); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index f2699d967..0bedf0e64 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -19,6 +19,7 @@ import org.apache.lucene.search.Sort; import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.FieldDoc; +import org.opensearch.action.search.SearchPhaseContext; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.neuralsearch.processor.combination.CombineScoresDto; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; @@ -64,7 +65,8 @@ public void 
execute( final List querySearchResults, final Optional fetchSearchResultOptional, final ScoreNormalizationTechnique normalizationTechnique, - final ScoreCombinationTechnique combinationTechnique + final ScoreCombinationTechnique combinationTechnique, + final SearchPhaseContext searchPhaseContext ) { NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder() .querySearchResults(querySearchResults) @@ -72,17 +74,21 @@ public void execute( .normalizationTechnique(normalizationTechnique) .combinationTechnique(combinationTechnique) .explain(false) + .searchPhaseContext(searchPhaseContext) .build(); execute(request); } public void execute(final NormalizationProcessorWorkflowExecuteRequest request) { + List querySearchResults = request.getQuerySearchResults(); + Optional fetchSearchResultOptional = request.getFetchSearchResultOptional(); + // save original state - List unprocessedDocIds = unprocessedDocIds(request.getQuerySearchResults()); + List unprocessedDocIds = unprocessedDocIds(querySearchResults); // pre-process data log.debug("Pre-process query results"); - List queryTopDocs = getQueryTopDocs(request.getQuerySearchResults()); + List queryTopDocs = getQueryTopDocs(querySearchResults); explain(request, queryTopDocs); @@ -93,8 +99,9 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request) CombineScoresDto combineScoresDTO = CombineScoresDto.builder() .queryTopDocs(queryTopDocs) .scoreCombinationTechnique(request.getCombinationTechnique()) - .querySearchResults(request.getQuerySearchResults()) - .sort(evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs)) + .querySearchResults(querySearchResults) + .sort(evaluateSortCriteria(querySearchResults, queryTopDocs)) + .fromValueForSingleShard(getFromValueIfSingleShard(request)) .build(); // combine @@ -103,8 +110,26 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request) // post-process data 
log.debug("Post-process query results after score normalization and combination"); - updateOriginalQueryResults(combineScoresDTO); - updateOriginalFetchResults(request.getQuerySearchResults(), request.getFetchSearchResultOptional(), unprocessedDocIds); + updateOriginalQueryResults(combineScoresDTO, fetchSearchResultOptional.isPresent()); + updateOriginalFetchResults( + querySearchResults, + fetchSearchResultOptional, + unprocessedDocIds, + combineScoresDTO.getFromValueForSingleShard() + ); + } + + /** + * Get value of from parameter when there is a single shard + * and fetch phase is already executed + * Ref https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/SearchService.java#L715 + */ + private int getFromValueIfSingleShard(final NormalizationProcessorWorkflowExecuteRequest request) { + final SearchPhaseContext searchPhaseContext = request.getSearchPhaseContext(); + if (searchPhaseContext.getNumShards() > 1 || request.fetchSearchResultOptional.isEmpty()) { + return -1; + } + return searchPhaseContext.getRequest().source().from(); } /** @@ -173,19 +198,33 @@ private List getQueryTopDocs(final List quer return queryTopDocs; } - private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO) { + private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO, final boolean isFetchPhaseExecuted) { final List querySearchResults = combineScoresDTO.getQuerySearchResults(); final List queryTopDocs = getCompoundTopDocs(combineScoresDTO, querySearchResults); final Sort sort = combineScoresDTO.getSort(); + int totalScoreDocsCount = 0; for (int index = 0; index < querySearchResults.size(); index++) { QuerySearchResult querySearchResult = querySearchResults.get(index); CompoundTopDocs updatedTopDocs = queryTopDocs.get(index); + totalScoreDocsCount += updatedTopDocs.getScoreDocs().size(); TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore( buildTopDocs(updatedTopDocs, 
sort), maxScoreForShard(updatedTopDocs, sort != null) ); + // Fetch Phase ran before the normalization phase, therefore update the from value in result of each shard. + // This will ensure the trimming of the search results. + if (isFetchPhaseExecuted) { + querySearchResult.from(combineScoresDTO.getFromValueForSingleShard()); + } querySearchResult.topDocs(updatedTopDocsAndMaxScore, querySearchResult.sortValueFormats()); } + + final int from = querySearchResults.get(0).from(); + if (from > totalScoreDocsCount) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results") + ); + } } private List getCompoundTopDocs(CombineScoresDto combineScoresDTO, List querySearchResults) { @@ -244,7 +283,8 @@ private TopDocs buildTopDocs(CompoundTopDocs updatedTopDocs, Sort sort) { private void updateOriginalFetchResults( final List querySearchResults, final Optional fetchSearchResultOptional, - final List docIds + final List docIds, + final int fromValueForSingleShard ) { if (fetchSearchResultOptional.isEmpty()) { return; } @@ -276,14 +316,21 @@ private void updateOriginalFetchResults( QuerySearchResult querySearchResult = querySearchResults.get(0); TopDocs topDocs = querySearchResult.topDocs().topDocs; + // Scenario to handle when calculating the trimmed length of updated search hits + // When normalization process runs after fetch phase, then search hits are already fetched. Therefore, use the from value sent in the + // search request to calculate the effective length of updated search hits array. 
+ int trimmedLengthOfSearchHits = topDocs.scoreDocs.length - fromValueForSingleShard; // iterate over the normalized/combined scores, that solves (1) and (3) - SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> { + SearchHit[] updatedSearchHitArray = new SearchHit[trimmedLengthOfSearchHits]; + for (int i = fromValueForSingleShard; i < topDocs.scoreDocs.length; i++) { + // Read topDocs after the desired from length + ScoreDoc scoreDoc = topDocs.scoreDocs[i]; // get fetched hit content by doc_id SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc); // update score to normalized/combined value (3) searchHit.score(scoreDoc.score); - return searchHit; - }).toArray(SearchHit[]::new); + updatedSearchHitArray[i - fromValueForSingleShard] = searchHit; + } SearchHits updatedSearchHits = new SearchHits( updatedSearchHitArray, querySearchResult.getTotalHits(), diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java index ea0b54b9c..e818c1b31 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java @@ -7,6 +7,7 @@ import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Getter; +import org.opensearch.action.search.SearchPhaseContext; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.search.fetch.FetchSearchResult; @@ -29,4 +30,5 @@ public class NormalizationProcessorWorkflowExecuteRequest { final ScoreCombinationTechnique combinationTechnique; boolean explain; final PipelineProcessingContext pipelineProcessingContext; + final SearchPhaseContext 
searchPhaseContext; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java index c4783969b..fecf5ca09 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java @@ -29,4 +29,5 @@ public class CombineScoresDto { private List querySearchResults; @Nullable private Sort sort; + private int fromValueForSingleShard; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java index 1779f20f7..40625adfb 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -70,14 +70,10 @@ public class ScoreCombiner { public void combineScores(final CombineScoresDto combineScoresDTO) { // iterate over results from each shard. 
Every CompoundTopDocs object has results from // multiple sub queries, doc ids may repeat for each sub query results + ScoreCombinationTechnique scoreCombinationTechnique = combineScoresDTO.getScoreCombinationTechnique(); + Sort sort = combineScoresDTO.getSort(); combineScoresDTO.getQueryTopDocs() - .forEach( - compoundQueryTopDocs -> combineShardScores( - combineScoresDTO.getScoreCombinationTechnique(), - compoundQueryTopDocs, - combineScoresDTO.getSort() - ) - ); + .forEach(compoundQueryTopDocs -> combineShardScores(scoreCombinationTechnique, compoundQueryTopDocs, sort)); } private void combineShardScores( diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java index 60d5870da..fb3b6dd00 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java @@ -15,6 +15,7 @@ import java.util.Objects; import java.util.concurrent.Callable; +import lombok.Getter; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; @@ -31,20 +32,25 @@ * Implementation of Query interface for type "hybrid". It allows execution of multiple sub-queries and collect individual * scores for each sub-query. */ +@Getter public final class HybridQuery extends Query implements Iterable { private final List subQueries; + private int paginationDepth; /** * Create new instance of hybrid query object based on collection of sub queries and filter query * @param subQueries collection of queries that are executed individually and contribute to a final list of combined scores * @param filterQueries list of filters that will be applied to each sub query. Each filter from the list is added as bool "filter" clause. 
If this is null sub queries will be executed as is */ - public HybridQuery(final Collection subQueries, final List filterQueries) { + public HybridQuery(final Collection subQueries, final List filterQueries, final int paginationDepth) { Objects.requireNonNull(subQueries, "collection of queries must not be null"); if (subQueries.isEmpty()) { throw new IllegalArgumentException("collection of queries must not be empty"); } + if (paginationDepth == 0) { + throw new IllegalArgumentException("pagination_depth must not be zero"); + } if (Objects.isNull(filterQueries) || filterQueries.isEmpty()) { this.subQueries = new ArrayList<>(subQueries); } else { @@ -57,10 +63,11 @@ public HybridQuery(final Collection subQueries, final List filterQ } this.subQueries = modifiedSubQueries; } + this.paginationDepth = paginationDepth; } - public HybridQuery(final Collection subQueries) { - this(subQueries, List.of()); + public HybridQuery(final Collection subQueries, final int paginationDepth) { + this(subQueries, List.of(), paginationDepth); } /** @@ -128,7 +135,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { return super.rewrite(indexSearcher); } final List rewrittenSubQueries = manager.getQueriesAfterRewrite(collectors); - return new HybridQuery(rewrittenSubQueries); + return new HybridQuery(rewrittenSubQueries, paginationDepth); } private Void rewriteQuery(Query query, HybridQueryExecutorCollector> collector) { diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index 60d9fd639..2e6c1ce47 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -35,6 +35,8 @@ import lombok.experimental.Accessors; import lombok.extern.log4j.Log4j2; +import static 
org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery; + /** * Class abstract creation of a Query type "hybrid". Hybrid query will allow execution of multiple sub-queries and * collects score for each of those sub-query. @@ -48,16 +50,23 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder queries = new ArrayList<>(); private String fieldName; + private int paginationDepth; static final int MAX_NUMBER_OF_SUB_QUERIES = 5; + private final static int DEFAULT_PAGINATION_DEPTH = 10; + private static final int LOWER_BOUND_OF_PAGINATION_DEPTH = 1; public HybridQueryBuilder(StreamInput in) throws IOException { super(in); queries.addAll(readQueries(in)); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + paginationDepth = in.readInt(); + } } /** @@ -68,6 +77,9 @@ public HybridQueryBuilder(StreamInput in) throws IOException { @Override protected void doWriteTo(StreamOutput out) throws IOException { writeQueries(out, queries); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + out.writeInt(paginationDepth); + } } /** @@ -97,6 +109,9 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep queryBuilder.toXContent(builder, params); } builder.endArray(); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + builder.field(PAGINATION_DEPTH_FIELD.getPreferredName(), paginationDepth == 0 ? 
DEFAULT_PAGINATION_DEPTH : paginationDepth); + } printBoostAndQueryName(builder); builder.endObject(); } @@ -113,7 +128,8 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio if (queryCollection.isEmpty()) { return Queries.newMatchNoDocsQuery(String.format(Locale.ROOT, "no clauses for %s query", NAME)); } - return new HybridQuery(queryCollection); + validatePaginationDepth(paginationDepth, queryShardContext); + return new HybridQuery(queryCollection, paginationDepth); } /** @@ -149,6 +165,7 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOException { float boost = AbstractQueryBuilder.DEFAULT_BOOST; + int paginationDepth = DEFAULT_PAGINATION_DEPTH; final List queries = new ArrayList<>(); String queryName = null; @@ -196,6 +213,8 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx } } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { queryName = parser.text(); + } else if (PAGINATION_DEPTH_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + paginationDepth = parser.intValue(); } else { log.error(String.format(Locale.ROOT, "[%s] query does not support [%s]", NAME, currentFieldName)); throw new ParsingException( @@ -216,6 +235,9 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx HybridQueryBuilder compoundQueryBuilder = new HybridQueryBuilder(); compoundQueryBuilder.queryName(queryName); compoundQueryBuilder.boost(boost); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + compoundQueryBuilder.paginationDepth(paginationDepth); + } for (QueryBuilder query : queries) { compoundQueryBuilder.add(query); } @@ -235,6 +257,9 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryShardContext) throws I if (changed) { newBuilder.queryName(queryName); newBuilder.boost(boost); + if 
(isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + newBuilder.paginationDepth(paginationDepth); + } return newBuilder; } else { return this; @@ -257,6 +282,9 @@ protected boolean doEquals(HybridQueryBuilder obj) { EqualsBuilder equalsBuilder = new EqualsBuilder(); equalsBuilder.append(fieldName, obj.fieldName); equalsBuilder.append(queries, obj.queries); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + equalsBuilder.append(paginationDepth, obj.paginationDepth); + } return equalsBuilder.isEquals(); } @@ -297,6 +325,20 @@ private Collection toQueries(Collection queryBuilders, Quer return queries; } + private static void validatePaginationDepth(final int paginationDepth, final QueryShardContext queryShardContext) { + if (paginationDepth < LOWER_BOUND_OF_PAGINATION_DEPTH) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "pagination_depth should be greater than 0")); + } + // compare pagination depth with OpenSearch setting index.max_result_window + // see https://opensearch.org/docs/latest/install-and-configure/configuring-opensearch/index-settings/ + int maxResultWindowIndexSetting = queryShardContext.getIndexSettings().getMaxResultWindow(); + if (paginationDepth > maxResultWindowIndexSetting) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "pagination_depth should be less than index.max_result_window setting") + ); + } + } + /** * visit method to parse the HybridQueryBuilder by a visitor */ diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index f9457f6ca..7ff737990 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -8,6 +8,7 @@ import lombok.RequiredArgsConstructor; import lombok.extern.log4j.Log4j2; import 
org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Collector; import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.Weight; @@ -22,6 +23,7 @@ import org.opensearch.common.Nullable; import org.opensearch.common.lucene.search.FilteredCollector; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.neuralsearch.query.HybridQuery; import org.opensearch.neuralsearch.search.HitsThresholdChecker; import org.opensearch.neuralsearch.search.collector.HybridSearchCollector; import org.opensearch.neuralsearch.search.collector.HybridTopFieldDocSortCollector; @@ -80,14 +82,28 @@ public abstract class HybridCollectorManager implements CollectorManager 0) { + searchContext.from(0); + } + Weight filteringWeight = null; // Check for post filter to create weight for filter query and later use that weight in the search workflow if (Objects.nonNull(searchContext.parsedPostFilter()) && Objects.nonNull(searchContext.parsedPostFilter().query())) { @@ -461,6 +477,34 @@ private ReduceableSearchResult reduceSearchResults(final List BooleanClause.Occur.FILTER == clause.getOccur()) .map(BooleanClause::getQuery) .collect(Collectors.toList()); - HybridQuery hybridQueryWithFilter = new HybridQuery(hybridQuery.getSubQueries(), filterQueries); + HybridQuery hybridQueryWithFilter = new HybridQuery( + hybridQuery.getSubQueries(), + filterQueries, + hybridQuery.getPaginationDepth() + ); return hybridQueryWithFilter; } return query; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java index 5f45b14fe..87dac8674 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -274,7 +274,7 @@ public void 
testEmptySearchResults_whenEmptySearchResults_thenDoNotExecuteWorkfl SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); normalizationProcessor.process(null, searchPhaseContext); - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any()); + verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any(), any()); } public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResult_thenDoNotExecuteWorkflow() { @@ -330,7 +330,7 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul when(searchPhaseContext.getNumShards()).thenReturn(numberOfShards); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any()); + verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any(), any()); } public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormalization() { @@ -346,6 +346,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz ); SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); searchRequest.setBatchedReduceSize(4); AtomicReference onPartialMergeFailure = new AtomicReference<>(); QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index 59fb51563..9969081a6 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -12,6 +12,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Optional; @@ -19,6 +20,8 @@ import 
org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchRequest; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.core.index.shard.ShardId; import org.opensearch.neuralsearch.util.TestUtils; @@ -29,6 +32,7 @@ import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.FetchSearchResult; import org.opensearch.search.internal.ShardSearchRequest; import org.opensearch.search.query.QuerySearchResult; @@ -71,12 +75,18 @@ public void testSearchResultTypes_whenResultsOfHybridSearch_thenDoNormalizationC querySearchResult.setShardIndex(shardId); querySearchResults.add(querySearchResult); } - + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + when(searchRequest.source()).thenReturn(searchSourceBuilder); normalizationProcessorWorkflow.execute( querySearchResults, Optional.empty(), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + searchPhaseContext ); TestUtils.assertQueryResultScores(querySearchResults); @@ -113,12 +123,18 @@ public void testSearchResultTypes_whenNoMatches_thenReturnZeroResults() { querySearchResult.setShardIndex(shardId); querySearchResults.add(querySearchResult); } - + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + 
searchSourceBuilder.from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + when(searchRequest.source()).thenReturn(searchSourceBuilder); normalizationProcessorWorkflow.execute( querySearchResults, Optional.empty(), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + searchPhaseContext ); TestUtils.assertQueryResultScoresWithNoMatches(querySearchResults); @@ -172,12 +188,18 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo new SearchHit(0, "10", Map.of(), Map.of()), }; SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); fetchSearchResult.hits(searchHits); - + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + when(searchRequest.source()).thenReturn(searchSourceBuilder); normalizationProcessorWorkflow.execute( querySearchResults, Optional.of(fetchSearchResult), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + searchPhaseContext ); TestUtils.assertQueryResultScores(querySearchResults); @@ -232,12 +254,18 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom new SearchHit(-1, "10", Map.of(), Map.of()), }; SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); fetchSearchResult.hits(searchHits); - + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(0); + 
when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + when(searchRequest.source()).thenReturn(searchSourceBuilder); normalizationProcessorWorkflow.execute( querySearchResults, Optional.of(fetchSearchResult), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + searchPhaseContext ); TestUtils.assertQueryResultScores(querySearchResults); @@ -284,14 +312,20 @@ public void testFetchResultsAndNoCache_whenOneShardAndMultipleNodesAndMismatchRe querySearchResults.add(querySearchResult); SearchHits searchHits = getSearchHits(); fetchSearchResult.hits(searchHits); - + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + when(searchRequest.source()).thenReturn(searchSourceBuilder); expectThrows( IllegalStateException.class, () -> normalizationProcessorWorkflow.execute( querySearchResults, Optional.of(fetchSearchResult), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + searchPhaseContext ) ); } @@ -336,18 +370,88 @@ public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResu querySearchResults.add(querySearchResult); SearchHits searchHits = getSearchHits(); fetchSearchResult.hits(searchHits); - + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + when(searchRequest.source()).thenReturn(searchSourceBuilder); normalizationProcessorWorkflow.execute( querySearchResults, Optional.of(fetchSearchResult), 
ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + searchPhaseContext ); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); } + public void testNormalization_whenFromIsGreaterThanResultsSize_thenFail() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + + List querySearchResults = new ArrayList<>(); + for (int shardId = 0; shardId < 4; shardId++) { + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + querySearchResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + new ScoreDoc(4, 0.25f), + new ScoreDoc(10, 0.2f), + createStartStopElementForHybridSearchResults(0) } + ), + 0.5f + ), + null + ); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + // requested page is out of bound for the total number of results + querySearchResult.from(17); + querySearchResults.add(querySearchResult); + } + + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + when(searchPhaseContext.getNumShards()).thenReturn(4); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(17); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + + NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDto = NormalizationProcessorWorkflowExecuteRequest.builder() + 
.querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.empty()) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> normalizationProcessorWorkflow.execute(normalizationExecuteDto) + ); + + assertEquals( + String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results"), + illegalArgumentException.getMessage() + ); + } + private static SearchHits getSearchHits() { SearchHit[] searchHitArray = new SearchHit[] { new SearchHit(-1, "10", Map.of(), Map.of()), diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index f40d4bf59..71c066ec1 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -11,6 +11,7 @@ import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.DEFAULT_BOOST; +import static org.opensearch.index.remote.RemoteStoreEnums.PathType.HASHED_PREFIX; import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; import static org.opensearch.neuralsearch.util.TestUtils.xContentBuilderToMap; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.K_FIELD; @@ -33,7 +34,9 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.Version; +import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.UUIDs; import 
org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; @@ -50,6 +53,7 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexSettings; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.MatchAllQueryBuilder; @@ -57,6 +61,7 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.remote.RemoteStoreEnums; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; @@ -119,6 +124,7 @@ public void testDoToQuery_whenNoSubqueries_thenBuildSuccessfully() { @SneakyThrows public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() { HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + queryBuilder.paginationDepth(10); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); @@ -130,6 +136,10 @@ public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() { when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(3)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + 
when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) .queryText(QUERY_TEXT) @@ -152,6 +162,7 @@ public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() { @SneakyThrows public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() { HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + queryBuilder.paginationDepth(10); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); @@ -163,6 +174,10 @@ public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() { when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(3)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) .queryText(QUERY_TEXT) @@ -197,6 +212,74 @@ public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() { assertEquals(TERM_QUERY_TEXT, termQuery.getTerm().text()); } + @SneakyThrows + public void testDoToQuery_whenPaginationDepthIsGreaterThan10000_thenBuildSuccessfully() { + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + queryBuilder.paginationDepth(10001); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldType mockKNNVectorField = 
mock(KNNVectorFieldType.class); + KNNMappingConfig mockKNNMappingConfig = mock(KNNMappingConfig.class); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, MethodComponentContext.EMPTY); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(mockKNNMappingConfig); + when(mockKNNMappingConfig.getKnnMethodContext()).thenReturn(Optional.of(knnMethodContext)); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(3)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER); + + queryBuilder.add(neuralQueryBuilder); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> queryBuilder.doToQuery(mockQueryShardContext) + ); + assertThat(exception.getMessage(), containsString("pagination_depth should be less than index.max_result_window setting")); + } + + @SneakyThrows + public void testDoToQuery_whenPaginationDepthIsLessThanZero_thenBuildSuccessfully() { + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + queryBuilder.paginationDepth(-1); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); + KNNMappingConfig mockKNNMappingConfig = 
mock(KNNMappingConfig.class); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, MethodComponentContext.EMPTY); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(mockKNNMappingConfig); + when(mockKNNMappingConfig.getKnnMethodContext()).thenReturn(Optional.of(knnMethodContext)); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getKnnMappingConfig().getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(3)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER); + + queryBuilder.add(neuralQueryBuilder); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> queryBuilder.doToQuery(mockQueryShardContext) + ); + assertThat(exception.getMessage(), containsString("pagination_depth should be greater than 0")); + } + @SneakyThrows public void testDoToQuery_whenTooManySubqueries_thenFail() { // create query with 6 sub-queries, which is more than current max allowed @@ -332,6 +415,7 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() { assertEquals(2, queryTwoSubQueries.queries().size()); assertTrue(queryTwoSubQueries.queries().get(0) instanceof NeuralQueryBuilder); assertTrue(queryTwoSubQueries.queries().get(1) instanceof TermQueryBuilder); + assertEquals(10, queryTwoSubQueries.paginationDepth()); // verify knn vector query 
NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryTwoSubQueries.queries().get(0); assertEquals(VECTOR_FIELD_NAME, neuralQueryBuilder.fieldName()); @@ -405,6 +489,7 @@ public void testFromXContent_whenIncorrectFormat_thenFail() { @SneakyThrows public void testToXContent_whenIncomingJsonIsCorrect_thenSuccessful() { + setUpClusterService(); HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); @@ -526,6 +611,7 @@ public void testHashAndEquals_whenSameOrIdenticalObject_thenReturnEqual() { } public void testHashAndEquals_whenSubQueriesDifferent_thenReturnNotEqual() { + setUpClusterService(); String modelId = "testModelId"; String fieldName = "fieldTwo"; String queryText = "query text"; @@ -614,6 +700,7 @@ public void testHashAndEquals_whenSubQueriesDifferent_thenReturnNotEqual() { @SneakyThrows public void testRewrite_whenMultipleSubQueries_thenReturnBuilderForEachSubQuery() { + setUpClusterService(); HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) .queryText(QUERY_TEXT) @@ -719,6 +806,7 @@ public void testBoost_whenNonDefaultBoostSet_thenFail() { @SneakyThrows public void testBoost_whenDefaultBoostSet_thenBuildSuccessfully() { + setUpClusterService(); // create query with 6 sub-queries, which is more than current max allowed XContentBuilder xContentBuilderWithNonDefaultBoost = XContentFactory.jsonBuilder() .startObject() @@ -769,6 +857,10 @@ public void testBuild_whenValidParameters_thenCreateQuery() { MappedFieldType fieldType = mock(MappedFieldType.class); when(context.fieldMapper(fieldName)).thenReturn(fieldType); when(fieldType.typeName()).thenReturn("rank_features"); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 
Integer.toString(3)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(context.getIndexSettings()).thenReturn(indexSettings); // Create HybridQueryBuilder instance (no spy since it's final) NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder(); @@ -777,6 +869,7 @@ public void testBuild_whenValidParameters_thenCreateQuery() { .modelId(modelId) .queryTokensSupplier(() -> Map.of("token1", 1.0f, "token2", 0.5f)); HybridQueryBuilder builder = new HybridQueryBuilder().add(neuralSparseQueryBuilder); + builder.paginationDepth(10); // Build query Query query = builder.toQuery(context); @@ -788,6 +881,7 @@ public void testBuild_whenValidParameters_thenCreateQuery() { @SneakyThrows public void testDoEquals_whenSameParameters_thenEqual() { + setUpClusterService(); // Create neural queries NeuralQueryBuilder neuralQueryBuilder1 = new NeuralQueryBuilder().queryText("test").modelId("test_model"); @@ -859,4 +953,25 @@ private void initKNNSettings() { when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, defaultClusterSettings)); KNNSettings.state().setClusterService(clusterService); } + + private static IndexMetadata getIndexMetadata() { + Map remoteCustomData = Map.of( + RemoteStoreEnums.PathType.NAME, + HASHED_PREFIX.name(), + RemoteStoreEnums.PathHashAlgorithm.NAME, + RemoteStoreEnums.PathHashAlgorithm.FNV_1A_BASE64.name(), + IndexMetadata.TRANSLOG_METADATA_KEY, + "false" + ); + Settings idxSettings = Settings.builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .put(IndexMetadata.SETTING_INDEX_UUID, UUIDs.randomBase64UUID()) + .build(); + IndexMetadata indexMetadata = new IndexMetadata.Builder("test").settings(idxSettings) + .numberOfShards(1) + .numberOfReplicas(0) + .putCustom(IndexMetadata.REMOTE_STORE_CUSTOM_KEY, remoteCustomData) + .build(); + return indexMetadata; + } } diff --git 
a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 610e08dd0..d13612c61 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -28,6 +28,7 @@ import org.junit.Before; import org.opensearch.client.ResponseException; import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.index.query.NestedQueryBuilder; import org.opensearch.index.query.QueryBuilder; @@ -793,21 +794,142 @@ public void testConcurrentSearchWithMultipleSlices_whenMultipleShardsIndex_thenS } } - // TODO remove this test after following issue https://github.com/opensearch-project/neural-search/issues/280 gets resolved. @SneakyThrows - public void testHybridQuery_whenFromIsSetInSearchRequest_thenFail() { + public void testPaginationOnSingleShard_whenConcurrentSearchEnabled_thenSuccessful() { try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); - MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); - HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder(); - hybridQueryBuilderOnlyTerm.add(matchQueryBuilder); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + // testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + // testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + // testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + } finally { + 
wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); + } + } - ResponseException exceptionNoNestedTypes = expectThrows( + @SneakyThrows + public void testPaginationOnSingleShard_whenConcurrentSearchDisabled_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + // testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + // testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + // testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testPaginationOnMultipleShard_whenConcurrentSearchEnabled_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); + // testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); + // testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME); + // testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testPaginationOnMultipleShard_whenConcurrentSearchDisabled_thenSuccessful() { + try { + 
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); + // testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); + // testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME); + // testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(String indexName) { + HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + hybridQueryBuilderOnlyMatchAll.paginationDepth(10); + + Map searchResponseAsMap = search( + indexName, + hybridQueryBuilderOnlyMatchAll, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + null, + false, + null, + 2 + ); + + assertEquals(2, getHitCount(searchResponseAsMap)); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(4, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } + + @SneakyThrows + public void testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + + Map searchResponseAsMap 
= search( + TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, + hybridQueryBuilderOnlyMatchAll, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + null, + false, + null, + 2 + ); + + assertEquals(2, getHitCount(searchResponseAsMap)); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(4, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + + ResponseException responseException = assertThrows( ResponseException.class, () -> search( TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, - hybridQueryBuilderOnlyTerm, + hybridQueryBuilderOnlyMatchAll, null, 10, Map.of("search_pipeline", SEARCH_PIPELINE), @@ -816,18 +938,50 @@ public void testHybridQuery_whenFromIsSetInSearchRequest_thenFail() { null, false, null, - 10 + 5 ) - ); org.hamcrest.MatcherAssert.assertThat( - exceptionNoNestedTypes.getMessage(), - allOf( - containsString("In the current OpenSearch version pagination is not supported with hybrid query"), - containsString("illegal_argument_exception") + responseException.getMessage(), + allOf(containsString("Reached end of search result, increase pagination_depth value to see more results")) + ); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail() { + 
try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + hybridQueryBuilderOnlyMatchAll.paginationDepth(100001); + + ResponseException responseException = assertThrows( + ResponseException.class, + () -> search( + TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, + hybridQueryBuilderOnlyMatchAll, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + null, + false, + null, + 0 ) ); + + org.hamcrest.MatcherAssert.assertThat( + responseException.getMessage(), + allOf(containsString("pagination_depth should be less than index.max_result_window setting")) + ); } finally { wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java index 15f0621e8..ea007b9db 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java @@ -72,16 +72,19 @@ public void testQueryBasics_whenMultipleDifferentQueries_thenSuccessful() { when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); HybridQuery query1 = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + 10 ); HybridQuery query2 = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + 10 ); HybridQuery 
query3 = new HybridQuery( List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT).toQuery(mockQueryShardContext) - ) + ), + 10 ); QueryUtils.check(query1); QueryUtils.checkEqual(query1, query2); @@ -96,6 +99,7 @@ public void testQueryBasics_whenMultipleDifferentQueries_thenSuccessful() { countOfQueries++; } assertEquals(2, countOfQueries); + assertEquals(10, query3.getPaginationDepth()); } @SneakyThrows @@ -103,6 +107,7 @@ public void testRewrite_whenRewriteQuery_thenSuccessful() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + String field1Value = "text1"; Directory directory = newDirectory(); final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); @@ -120,14 +125,15 @@ public void testRewrite_whenRewriteQuery_thenSuccessful() { // Test with TermQuery HybridQuery hybridQueryWithTerm = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + 10 ); Query rewritten = hybridQueryWithTerm.rewrite(reader); // term query is the same after we rewrite it assertSame(hybridQueryWithTerm, rewritten); // Test empty query list - IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of())); + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of(), 10)); assertThat(exception.getMessage(), containsString("collection of queries must not be empty")); w.close(); @@ -160,7 +166,8 @@ public void 
testWithRandomDocuments_whenMultipleTermSubQueriesWithMatch_thenRetu IndexSearcher searcher = newSearcher(reader); HybridQuery query = new HybridQuery( - List.of(new TermQuery(new Term(TEXT_FIELD_NAME, field1Value)), new TermQuery(new Term(TEXT_FIELD_NAME, field2Value))) + List.of(new TermQuery(new Term(TEXT_FIELD_NAME, field1Value)), new TermQuery(new Term(TEXT_FIELD_NAME, field2Value))), + 10 ); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -206,7 +213,7 @@ public void testWithRandomDocuments_whenOneTermSubQueryWithoutMatch_thenReturnSu DirectoryReader reader = DirectoryReader.open(w); IndexSearcher searcher = newSearcher(reader); - HybridQuery query = new HybridQuery(List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)))); + HybridQuery query = new HybridQuery(List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), 10); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -242,7 +249,8 @@ public void testWithRandomDocuments_whenMultipleTermSubQueriesWithoutMatch_thenR IndexSearcher searcher = newSearcher(reader); HybridQuery query = new HybridQuery( - List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)), new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))) + List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)), new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), + 10 ); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -256,10 +264,22 @@ public void testWithRandomDocuments_whenMultipleTermSubQueriesWithoutMatch_thenR @SneakyThrows public void testWithRandomDocuments_whenNoSubQueries_thenFail() { - IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of())); + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of(), 10)); 
assertThat(exception.getMessage(), containsString("collection of queries must not be empty")); } + @SneakyThrows + public void testWithRandomDocuments_whenPaginationDepthIsZero_thenFail() { + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> new HybridQuery( + List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)), new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), + 0 + ) + ); + assertThat(exception.getMessage(), containsString("pagination_depth must not be zero")); + } + @SneakyThrows public void testToString_whenCallQueryToString_thenSuccessful() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); @@ -273,7 +293,8 @@ public void testToString_whenCallQueryToString_thenSuccessful() { new BoolQueryBuilder().should(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)) .should(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT)) .toQuery(mockQueryShardContext) - ) + ), + 10 ); String queryString = query.toString(TEXT_FIELD_NAME); @@ -293,7 +314,8 @@ public void testFilter_whenSubQueriesWithFilterPassed_thenSuccessful() { QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT).toQuery(mockQueryShardContext) ), - List.of(filter) + List.of(filter), + 10 ); QueryUtils.check(hybridQuery); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java index 0e32b5e78..5096a8ab7 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java @@ -61,7 +61,8 @@ public void testScorerIterator_whenExecuteQuery_thenScorerIteratorSuccessful() { IndexReader reader = DirectoryReader.open(w); HybridQuery hybridQueryWithTerm = new HybridQuery( - 
List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + 10 ); IndexSearcher searcher = newSearcher(reader); Weight weight = hybridQueryWithTerm.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f); @@ -117,7 +118,8 @@ public void testSubQueries_whenMultipleEqualSubQueries_thenSuccessful() { .rewrite(mockQueryShardContext) .toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext) - ) + ), + 10 ); IndexSearcher searcher = newSearcher(reader); Weight weight = hybridQueryWithTerm.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f); @@ -164,7 +166,8 @@ public void testExplain_whenCallExplain_thenSuccessful() { IndexReader reader = DirectoryReader.open(w); HybridQuery hybridQueryWithTerm = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + 10 ); IndexSearcher searcher = newSearcher(reader); Weight weight = searcher.createWeight(hybridQueryWithTerm, ScoreMode.COMPLETE, 1.0f); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java index f44e762f0..6d432aba6 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java @@ -69,7 +69,7 @@ public void testCollectorManager_whenHybridQueryAndNotConcurrentSearch_thenSucce TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); 
TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -129,7 +129,7 @@ public void testCollectorManager_whenHybridQueryAndConcurrentSearch_thenSuccessf TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index 24ebebe5b..d0b91df37 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -52,12 +52,14 @@ import org.opensearch.neuralsearch.search.query.exception.HybridSearchRescoreQueryException; import org.opensearch.search.DocValueFormat; import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.ScrollContext; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.query.ReduceableSearchResult; 
import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import static org.mockito.ArgumentMatchers.any; @@ -91,7 +93,7 @@ public void testNewCollector_whenNotConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -122,7 +124,7 @@ public void testNewCollector_whenConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -153,7 +155,7 @@ public void testPostFilter_whenNotConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new 
HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world"); ParsedQuery parsedQuery = new ParsedQuery(postFilterQuery.toQuery(mockQueryShardContext)); @@ -197,7 +199,7 @@ public void testPostFilter_whenConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world"); Query pfQuery = postFilterQuery.toQuery(mockQueryShardContext); @@ -242,7 +244,8 @@ public void testReduce_whenMatchedDocs_thenSuccessful() { when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); HybridQuery hybridQueryWithTerm = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)), + 10 ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -343,7 +346,7 @@ public void testNewCollector_whenNotConcurrentSearchAndSortingIsApplied_thenSucc TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - 
HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -380,7 +383,7 @@ public void testNewCollector_whenNotConcurrentSearchAndSortingAndSearchAfterAreA TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -411,7 +414,7 @@ public void testReduce_whenMatchedDocsAndSortingIsApplied_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); - HybridQuery hybridQueryWithMatchAll = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext))); + HybridQuery hybridQueryWithMatchAll = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), 10); when(searchContext.query()).thenReturn(hybridQueryWithMatchAll); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); @@ -508,7 +511,8 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), 
QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) - ) + ), + 10 ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -594,7 +598,7 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); - HybridQuery hybridQueryWithTerm = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext))); + HybridQuery hybridQueryWithTerm = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), 10); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); @@ -723,7 +727,8 @@ public void testReduceAndRescore_whenMatchedDocsAndRescoreContextPresent_thenSuc List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) - ) + ), + 10 ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -841,7 +846,8 @@ public void testRescoreWithConcurrentSegmentSearch_whenMatchedDocsAndRescore_the QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY3).toQuery(mockQueryShardContext) - ) + ), + 10 ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -984,7 +990,8 @@ public void testReduceAndRescore_whenRescorerThrowsException_thenFail() { List.of( 
QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) - ) + ), + 10 ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -1042,4 +1049,67 @@ public void testReduceAndRescore_whenRescorerThrowsException_thenFail() { reader.close(); directory.close(); } + + @SneakyThrows + public void testCreateCollectorManager_whenFromAreEqualToZeroAndPaginationDepthInRange_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + // pagination_depth=10 + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + assertNotNull(hybridCollectorManager); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorNonConcurrentManager); + + Collector collector = hybridCollectorManager.newCollector(); + 
assertNotNull(collector); + assertTrue(collector instanceof HybridTopScoreDocCollector); + + Collector secondCollector = hybridCollectorManager.newCollector(); + assertSame(collector, secondCollector); + } + + @SneakyThrows + public void testScrollWithHybridQuery_thenFail() { + SearchContext searchContext = mock(SearchContext.class); + ScrollContext scrollContext = new ScrollContext(); + when(searchContext.scrollContext()).thenReturn(scrollContext); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> HybridCollectorManager.createHybridCollectorManager(searchContext) + ); + assertEquals( + String.format(Locale.ROOT, "Scroll operation is not supported in hybrid query"), + illegalArgumentException.getMessage() + ); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index 
a8cad5ec7..2aafa2ece 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -138,6 +138,10 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); when(searchContext.mapperService()).thenReturn(mapperService); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -150,6 +154,7 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { TermQueryBuilder termSubQuery2 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2); queryBuilder.add(termSubQuery1); queryBuilder.add(termSubQuery2); + queryBuilder.paginationDepth(10); Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); @@ -287,6 +292,10 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { when(searchContext.queryResult()).thenReturn(querySearchResult); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); when(searchContext.mapperService()).thenReturn(mapperService); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); 
LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -296,6 +305,7 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); queryBuilder.add(termSubQuery); + queryBuilder.paginationDepth(10); Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); @@ -372,6 +382,10 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); when(searchContext.mapperService()).thenReturn(mapperService); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -382,6 +396,7 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); queryBuilder.add(QueryBuilders.matchAllQuery()); + queryBuilder.paginationDepth(10); Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); @@ -473,6 +488,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() { queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + queryBuilder.paginationDepth(10); TermQueryBuilder termQuery3 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); 
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery().should(queryBuilder).should(termQuery3); @@ -578,6 +594,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolAndIncorrectStructur queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + queryBuilder.paginationDepth(10); BooleanQuery.Builder builder = new BooleanQuery.Builder(); builder.add(Queries.newNonNestedFilter(), BooleanClause.Occur.FILTER) @@ -694,6 +711,7 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_then queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + queryBuilder.paginationDepth(10); BooleanQuery.Builder builder = new BooleanQuery.Builder(); builder.add(queryBuilder.toQuery(mockQueryShardContext), BooleanClause.Occur.MUST) @@ -868,6 +886,10 @@ public void testAggregations_whenMetricAggregation_thenSuccessful() { when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); when(searchContext.mapperService()).thenReturn(mapperService); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); LinkedList collectors = new LinkedList<>(); @@ -881,6 +903,7 @@ public void testAggregations_whenMetricAggregation_thenSuccessful() { TermQueryBuilder termSubQuery2 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2); queryBuilder.add(termSubQuery1); queryBuilder.add(termSubQuery2); + queryBuilder.paginationDepth(10); Query query = queryBuilder.toQuery(mockQueryShardContext); 
when(searchContext.query()).thenReturn(query); @@ -965,6 +988,10 @@ public void testAliasWithFilter_whenHybridWrappedIntoBoolBecauseOfIndexAlias_the when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); when(searchContext.mapperService()).thenReturn(mapperService); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -974,6 +1001,7 @@ public void testAliasWithFilter_whenHybridWrappedIntoBoolBecauseOfIndexAlias_the queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + queryBuilder.paginationDepth(10); Query termFilter = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1).toQuery(mockQueryShardContext); BooleanQuery.Builder builder = new BooleanQuery.Builder(); diff --git a/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java b/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java index be9dbc2cc..6c2377c6f 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java @@ -6,20 +6,28 @@ import lombok.SneakyThrows; import org.apache.lucene.search.Query; +import org.opensearch.Version; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.common.UUIDs; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.IndexSettings; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; import 
org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.remote.RemoteStoreEnums; import org.opensearch.neuralsearch.query.HybridQuery; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.search.internal.SearchContext; import java.util.List; +import java.util.Map; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.index.remote.RemoteStoreEnums.PathType.HASHED_PREFIX; public class HybridQueryUtilTests extends OpenSearchQueryTestCase { @@ -45,7 +53,8 @@ public void testIsHybridQueryCheck_whenQueryIsHybridQueryInstance_thenSuccess() .rewrite(mockQueryShardContext) .toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext) - ) + ), + 10 ); SearchContext searchContext = mock(SearchContext.class); @@ -58,13 +67,17 @@ public void testIsHybridQueryCheck_whenHybridWrappedIntoBoolAndNoNested_thenSucc MapperService mapperService = createMapperService(); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + IndexMetadata indexMetadata = getIndexMetadata(); + Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build(); + IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); hybridQueryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)); hybridQueryBuilder.add( 
QueryBuilders.rangeQuery(RANGE_FIELD).from(FROM_TEXT).to(TO_TEXT).rewrite(mockQueryShardContext).rewrite(mockQueryShardContext) ); - + hybridQueryBuilder.paginationDepth(10); Query booleanQuery = QueryBuilders.boolQuery() .should(hybridQueryBuilder) .should(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)) @@ -97,4 +110,25 @@ public void testIsHybridQueryCheck_whenNoHybridQuery_thenSuccess() { assertFalse(HybridQueryUtil.isHybridQuery(booleanQuery, searchContext)); } + + private static IndexMetadata getIndexMetadata() { + Map remoteCustomData = Map.of( + RemoteStoreEnums.PathType.NAME, + HASHED_PREFIX.name(), + RemoteStoreEnums.PathHashAlgorithm.NAME, + RemoteStoreEnums.PathHashAlgorithm.FNV_1A_BASE64.name(), + IndexMetadata.TRANSLOG_METADATA_KEY, + "false" + ); + Settings idxSettings = Settings.builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .put(IndexMetadata.SETTING_INDEX_UUID, UUIDs.randomBase64UUID()) + .build(); + IndexMetadata indexMetadata = new IndexMetadata.Builder("test").settings(idxSettings) + .numberOfShards(1) + .numberOfReplicas(0) + .putCustom(IndexMetadata.REMOTE_STORE_CUSTOM_KEY, remoteCustomData) + .build(); + return indexMetadata; + } } diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 08628c247..1279b33ef 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -600,13 +600,11 @@ protected Map search( if (requestParams != null && !requestParams.isEmpty()) { requestParams.forEach(request::addParameter); } - logger.info("Sorting request " + builder.toString()); request.setJsonEntity(builder.toString()); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); 
String responseBody = EntityUtils.toString(response.getEntity()); - logger.info("Response " + responseBody); return XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false); } From db56f5ab3074ef8235fc5e0cb88d485227277ffd Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Thu, 2 Jan 2025 09:36:42 -0800 Subject: [PATCH 02/13] Remove unwanted code Signed-off-by: Varun Jain --- .../opensearch/neuralsearch/query/HybridQueryIT.java | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index d13612c61..89fe58808 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -801,9 +801,6 @@ public void testPaginationOnSingleShard_whenConcurrentSearchEnabled_thenSuccessf initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); - // testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); - // testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); - // testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); } finally { wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); } @@ -816,9 +813,6 @@ public void testPaginationOnSingleShard_whenConcurrentSearchDisabled_thenSuccess initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); - // 
testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); - // testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); - // testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); } finally { wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); } @@ -831,9 +825,6 @@ public void testPaginationOnMultipleShard_whenConcurrentSearchEnabled_thenSucces initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); - // testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); - // testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME); - // testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME); } finally { wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); } @@ -846,9 +837,6 @@ public void testPaginationOnMultipleShard_whenConcurrentSearchDisabled_thenSucce initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); - // testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); - // testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME); - // testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME); } finally { wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); } From 8fa8fc6bfd569f1e0a5ad2c6ae4cd0a0f5cec3d4 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Fri, 3 Jan 2025 14:30:15 -0800 Subject: [PATCH 
03/13] Adding hybrid query context dto Signed-off-by: Varun Jain --- .../NormalizationProcessorWorkflow.java | 6 +-- .../neuralsearch/query/HybridQuery.java | 20 +++++----- .../query/HybridQueryBuilder.java | 3 +- .../query/HybridQueryContext.java | 16 ++++++++ .../search/query/HybridCollectorManager.java | 11 ++++-- .../query/HybridQueryPhaseSearcher.java | 6 +-- .../neuralsearch/query/HybridQueryTests.java | 32 +++++++++------- .../query/HybridQueryWeightTests.java | 6 +-- .../HybridAggregationProcessorTests.java | 5 ++- .../query/HybridCollectorManagerTests.java | 37 +++++++++++-------- .../util/HybridQueryUtilTests.java | 3 +- 11 files changed, 89 insertions(+), 56 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/query/HybridQueryContext.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index 0bedf0e64..db3747a13 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -322,14 +322,14 @@ private void updateOriginalFetchResults( int trimmedLengthOfSearchHits = topDocs.scoreDocs.length - fromValueForSingleShard; // iterate over the normalized/combined scores, that solves (1) and (3) SearchHit[] updatedSearchHitArray = new SearchHit[trimmedLengthOfSearchHits]; - for (int i = fromValueForSingleShard; i < topDocs.scoreDocs.length; i++) { + for (int i = 0; i < trimmedLengthOfSearchHits; i++) { // Read topDocs after the desired from length - ScoreDoc scoreDoc = topDocs.scoreDocs[i]; + ScoreDoc scoreDoc = topDocs.scoreDocs[i + fromValueForSingleShard]; // get fetched hit content by doc_id SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc); // update score to normalized/combined value (3) searchHit.score(scoreDoc.score); - updatedSearchHitArray[i - 
fromValueForSingleShard] = searchHit; + updatedSearchHitArray[i] = searchHit; } SearchHits updatedSearchHits = new SearchHits( updatedSearchHitArray, diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java index fb3b6dd00..f08120ce9 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java @@ -15,7 +15,6 @@ import java.util.Objects; import java.util.concurrent.Callable; -import lombok.Getter; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; @@ -32,23 +31,22 @@ * Implementation of Query interface for type "hybrid". It allows execution of multiple sub-queries and collect individual * scores for each sub-query. */ -@Getter public final class HybridQuery extends Query implements Iterable { private final List subQueries; - private int paginationDepth; + private final HybridQueryContext queryContext; /** * Create new instance of hybrid query object based on collection of sub queries and filter query * @param subQueries collection of queries that are executed individually and contribute to a final list of combined scores * @param filterQueries list of filters that will be applied to each sub query. Each filter from the list is added as bool "filter" clause. 
If this is null sub queries will be executed as is */ - public HybridQuery(final Collection subQueries, final List filterQueries, final int paginationDepth) { + public HybridQuery(final Collection subQueries, final List filterQueries, final HybridQueryContext hybridQueryContext) { Objects.requireNonNull(subQueries, "collection of queries must not be null"); if (subQueries.isEmpty()) { throw new IllegalArgumentException("collection of queries must not be empty"); } - if (paginationDepth == 0) { + if (hybridQueryContext.getPaginationDepth() == 0) { throw new IllegalArgumentException("pagination_depth must not be zero"); } if (Objects.isNull(filterQueries) || filterQueries.isEmpty()) { @@ -63,11 +61,11 @@ public HybridQuery(final Collection subQueries, final List filterQ } this.subQueries = modifiedSubQueries; } - this.paginationDepth = paginationDepth; + this.queryContext = hybridQueryContext; } - public HybridQuery(final Collection subQueries, final int paginationDepth) { - this(subQueries, List.of(), paginationDepth); + public HybridQuery(final Collection subQueries, final HybridQueryContext hybridQueryContext) { + this(subQueries, List.of(), hybridQueryContext); } /** @@ -135,7 +133,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { return super.rewrite(indexSearcher); } final List rewrittenSubQueries = manager.getQueriesAfterRewrite(collectors); - return new HybridQuery(rewrittenSubQueries, paginationDepth); + return new HybridQuery(rewrittenSubQueries, queryContext); } private Void rewriteQuery(Query query, HybridQueryExecutorCollector> collector) { @@ -197,6 +195,10 @@ public Collection getSubQueries() { return Collections.unmodifiableCollection(subQueries); } + public HybridQueryContext getQueryContext() { + return queryContext; + } + /** * Create the Weight used to score this query * diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java 
b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index 2e6c1ce47..dbff0cde8 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -129,7 +129,8 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio return Queries.newMatchNoDocsQuery(String.format(Locale.ROOT, "no clauses for %s query", NAME)); } validatePaginationDepth(paginationDepth, queryShardContext); - return new HybridQuery(queryCollection, paginationDepth); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(paginationDepth).build(); + return new HybridQuery(queryCollection, hybridQueryContext); } /** diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryContext.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryContext.java new file mode 100644 index 000000000..9a9877dac --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryContext.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; + +@AllArgsConstructor +@Builder +@Getter +public class HybridQueryContext { + private int paginationDepth; +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 7ff737990..8b8235450 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -483,17 +483,20 @@ private ReduceableSearchResult reduceSearchResults(final List BooleanClause.Occur.FILTER == clause.getOccur()) .map(BooleanClause::getQuery) .collect(Collectors.toList()); - 
HybridQuery hybridQueryWithFilter = new HybridQuery( - hybridQuery.getSubQueries(), - filterQueries, - hybridQuery.getPaginationDepth() - ); + HybridQuery hybridQueryWithFilter = new HybridQuery(hybridQuery.getSubQueries(), filterQueries, hybridQuery.getQueryContext()); return hybridQueryWithFilter; } return query; diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java index ea007b9db..8e429e270 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java @@ -73,18 +73,18 @@ public void testQueryBasics_whenMultipleDifferentQueries_thenSuccessful() { HybridQuery query1 = new HybridQuery( List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), - 10 + new HybridQueryContext(10) ); HybridQuery query2 = new HybridQuery( List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), - 10 + new HybridQueryContext(10) ); HybridQuery query3 = new HybridQuery( List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT).toQuery(mockQueryShardContext) ), - 10 + new HybridQueryContext(10) ); QueryUtils.check(query1); QueryUtils.checkEqual(query1, query2); @@ -99,7 +99,7 @@ public void testQueryBasics_whenMultipleDifferentQueries_thenSuccessful() { countOfQueries++; } assertEquals(2, countOfQueries); - assertEquals(10, query3.getPaginationDepth()); + assertEquals(10, query3.getQueryContext().getPaginationDepth()); } @SneakyThrows @@ -126,14 +126,17 @@ public void testRewrite_whenRewriteQuery_thenSuccessful() { // Test with TermQuery HybridQuery hybridQueryWithTerm = new HybridQuery( List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), - 10 + new 
HybridQueryContext(10) ); Query rewritten = hybridQueryWithTerm.rewrite(reader); // term query is the same after we rewrite it assertSame(hybridQueryWithTerm, rewritten); // Test empty query list - IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of(), 10)); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> new HybridQuery(List.of(), new HybridQueryContext(10)) + ); assertThat(exception.getMessage(), containsString("collection of queries must not be empty")); w.close(); @@ -167,7 +170,7 @@ public void testWithRandomDocuments_whenMultipleTermSubQueriesWithMatch_thenRetu HybridQuery query = new HybridQuery( List.of(new TermQuery(new Term(TEXT_FIELD_NAME, field1Value)), new TermQuery(new Term(TEXT_FIELD_NAME, field2Value))), - 10 + new HybridQueryContext(10) ); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -213,7 +216,7 @@ public void testWithRandomDocuments_whenOneTermSubQueryWithoutMatch_thenReturnSu DirectoryReader reader = DirectoryReader.open(w); IndexSearcher searcher = newSearcher(reader); - HybridQuery query = new HybridQuery(List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), 10); + HybridQuery query = new HybridQuery(List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), new HybridQueryContext(10)); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -250,7 +253,7 @@ public void testWithRandomDocuments_whenMultipleTermSubQueriesWithoutMatch_thenR HybridQuery query = new HybridQuery( List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)), new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), - 10 + new HybridQueryContext(10) ); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -264,7 +267,10 @@ public void 
testWithRandomDocuments_whenMultipleTermSubQueriesWithoutMatch_thenR @SneakyThrows public void testWithRandomDocuments_whenNoSubQueries_thenFail() { - IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of(), 10)); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> new HybridQuery(List.of(), new HybridQueryContext(10)) + ); assertThat(exception.getMessage(), containsString("collection of queries must not be empty")); } @@ -274,7 +280,7 @@ public void testWithRandomDocuments_whenPaginationDepthIsZero_thenFail() { IllegalArgumentException.class, () -> new HybridQuery( List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)), new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), - 0 + new HybridQueryContext(0) ) ); assertThat(exception.getMessage(), containsString("pagination_depth must not be zero")); @@ -294,7 +300,7 @@ public void testToString_whenCallQueryToString_thenSuccessful() { .should(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT)) .toQuery(mockQueryShardContext) ), - 10 + new HybridQueryContext(10) ); String queryString = query.toString(TEXT_FIELD_NAME); @@ -315,7 +321,7 @@ public void testFilter_whenSubQueriesWithFilterPassed_thenSuccessful() { QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT).toQuery(mockQueryShardContext) ), List.of(filter), - 10 + new HybridQueryContext(10) ); QueryUtils.check(hybridQuery); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java index 5096a8ab7..024c5e6e3 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java @@ -62,7 +62,7 @@ public void testScorerIterator_whenExecuteQuery_thenScorerIteratorSuccessful() { IndexReader reader = DirectoryReader.open(w); 
HybridQuery hybridQueryWithTerm = new HybridQuery( List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), - 10 + new HybridQueryContext(10) ); IndexSearcher searcher = newSearcher(reader); Weight weight = hybridQueryWithTerm.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f); @@ -119,7 +119,7 @@ public void testSubQueries_whenMultipleEqualSubQueries_thenSuccessful() { .toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext) ), - 10 + new HybridQueryContext(10) ); IndexSearcher searcher = newSearcher(reader); Weight weight = hybridQueryWithTerm.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f); @@ -167,7 +167,7 @@ public void testExplain_whenCallExplain_thenSuccessful() { IndexReader reader = DirectoryReader.open(w); HybridQuery hybridQueryWithTerm = new HybridQuery( List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), - 10 + new HybridQueryContext(10) ); IndexSearcher searcher = newSearcher(reader); Weight weight = searcher.createWeight(hybridQueryWithTerm, ScoreMode.COMPLETE, 1.0f); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java index 6d432aba6..28e3af663 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java @@ -20,6 +20,7 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.neuralsearch.query.HybridQuery; +import org.opensearch.neuralsearch.query.HybridQueryContext; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchShardTarget; @@ -69,7 +70,7 @@ 
public void testCollectorManager_whenHybridQueryAndNotConcurrentSearch_thenSucce TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -129,7 +130,7 @@ public void testCollectorManager_whenHybridQueryAndConcurrentSearch_thenSuccessf TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index d0b91df37..83d29fa73 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -44,6 +44,7 @@ import org.opensearch.index.query.QueryBuilder; import 
org.opensearch.index.query.ParsedQuery; import org.opensearch.neuralsearch.query.HybridQuery; +import org.opensearch.neuralsearch.query.HybridQueryContext; import org.opensearch.neuralsearch.query.HybridQueryWeight; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.neuralsearch.search.collector.HybridTopScoreDocCollector; @@ -93,7 +94,7 @@ public void testNewCollector_whenNotConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -124,7 +125,7 @@ public void testNewCollector_whenConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -155,7 +156,7 @@ public void testPostFilter_whenNotConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) 
createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world"); ParsedQuery parsedQuery = new ParsedQuery(postFilterQuery.toQuery(mockQueryShardContext)); @@ -199,7 +200,7 @@ public void testPostFilter_whenConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world"); Query pfQuery = postFilterQuery.toQuery(mockQueryShardContext); @@ -245,7 +246,7 @@ public void testReduce_whenMatchedDocs_thenSuccessful() { HybridQuery hybridQueryWithTerm = new HybridQuery( List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)), - 10 + new HybridQueryContext(10) ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -346,7 +347,7 @@ public void testNewCollector_whenNotConcurrentSearchAndSortingIsApplied_thenSucc TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); 
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -383,7 +384,7 @@ public void testNewCollector_whenNotConcurrentSearchAndSortingAndSearchAfterAreA TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -414,7 +415,10 @@ public void testReduce_whenMatchedDocsAndSortingIsApplied_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); - HybridQuery hybridQueryWithMatchAll = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), 10); + HybridQuery hybridQueryWithMatchAll = new HybridQuery( + List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), + new HybridQueryContext(10) + ); when(searchContext.query()).thenReturn(hybridQueryWithMatchAll); ContextIndexSearcher indexSearcher = 
mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); @@ -512,7 +516,7 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) ), - 10 + new HybridQueryContext(10) ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -598,7 +602,10 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); - HybridQuery hybridQueryWithTerm = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), 10); + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), + new HybridQueryContext(10) + ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); @@ -728,7 +735,7 @@ public void testReduceAndRescore_whenMatchedDocsAndRescoreContextPresent_thenSuc QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) ), - 10 + new HybridQueryContext(10) ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -847,7 +854,7 @@ public void testRescoreWithConcurrentSegmentSearch_whenMatchedDocsAndRescore_the QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY3).toQuery(mockQueryShardContext) ), 
- 10 + new HybridQueryContext(10) ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -991,7 +998,7 @@ public void testReduceAndRescore_whenRescorerThrowsException_thenFail() { QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) ), - 10 + new HybridQueryContext(10) ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -1058,7 +1065,7 @@ public void testCreateCollectorManager_whenFromAreEqualToZeroAndPaginationDepthI when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); // pagination_depth=10 - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -1091,7 +1098,7 @@ public void testScrollWithHybridQuery_thenFail() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); diff --git 
a/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java b/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java index 6c2377c6f..0a733550a 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java @@ -18,6 +18,7 @@ import org.opensearch.index.remote.RemoteStoreEnums; import org.opensearch.neuralsearch.query.HybridQuery; import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.HybridQueryContext; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.search.internal.SearchContext; @@ -54,7 +55,7 @@ public void testIsHybridQueryCheck_whenQueryIsHybridQueryInstance_thenSuccess() .toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext) ), - 10 + new HybridQueryContext(10) ); SearchContext searchContext = mock(SearchContext.class); From 14749059533872706387a804818694cb8376127d Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Mon, 6 Jan 2025 10:02:19 -0800 Subject: [PATCH 04/13] Adding javadoc in hybridquerycontext and addressing few comments from review Signed-off-by: Varun Jain --- .../query/HybridQueryContext.java | 7 ++- .../search/query/HybridCollectorManager.java | 2 +- .../HybridAggregationProcessorTests.java | 6 ++- .../query/HybridCollectorManagerTests.java | 49 +++++++++++++------ .../util/HybridQueryUtilTests.java | 3 +- 5 files changed, 46 insertions(+), 21 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryContext.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryContext.java index 9a9877dac..3be741bcd 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryContext.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryContext.java @@ -4,13 +4,16 @@ */ package org.opensearch.neuralsearch.query; 
-import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Getter; +import lombok.NonNull; -@AllArgsConstructor +/** + * Class that holds the low level information of hybrid query in the form of context + */ @Builder @Getter public class HybridQueryContext { + @NonNull private int paginationDepth; } diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 8b8235450..1d75877ad 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -494,7 +494,7 @@ private static int getSubqueryResultsRetrievalSize(final SearchContext searchCon } /** - * Extract hybrid query from generic query object retrieved from search context + * Unwraps a HybridQuery from either a direct query or a nested BooleanQuery */ private static HybridQuery extractHybridQueryFromAbstractQuery(Query query) { HybridQuery hybridQuery; diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java index 28e3af663..0b896e33d 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java @@ -70,7 +70,8 @@ public void testCollectorManager_whenHybridQueryAndNotConcurrentSearch_thenSucce TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); - HybridQuery hybridQuery = new 
HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -130,7 +131,8 @@ public void testCollectorManager_whenHybridQueryAndConcurrentSearch_thenSuccessf TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index 83d29fa73..e18ab6b12 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -94,7 +94,8 @@ public void testNewCollector_whenNotConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); 
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -125,7 +126,8 @@ public void testNewCollector_whenConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -156,7 +158,8 @@ public void testPostFilter_whenNotConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); + 
HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world"); ParsedQuery parsedQuery = new ParsedQuery(postFilterQuery.toQuery(mockQueryShardContext)); @@ -200,7 +203,8 @@ public void testPostFilter_whenConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world"); Query pfQuery = postFilterQuery.toQuery(mockQueryShardContext); @@ -243,10 +247,11 @@ public void testReduce_whenMatchedDocs_thenSuccessful() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQueryWithTerm = new HybridQuery( List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)), - new HybridQueryContext(10) + hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher 
indexSearcher = mock(ContextIndexSearcher.class); @@ -347,7 +352,9 @@ public void testNewCollector_whenNotConcurrentSearchAndSortingIsApplied_thenSucc TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -384,7 +391,9 @@ public void testNewCollector_whenNotConcurrentSearchAndSortingAndSearchAfterAreA TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -414,10 +423,11 @@ public void testReduce_whenMatchedDocsAndSortingIsApplied_thenSuccessful() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = 
(TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQueryWithMatchAll = new HybridQuery( List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), - new HybridQueryContext(10) + hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithMatchAll); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -510,13 +520,14 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQueryWithTerm = new HybridQuery( List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) ), - new HybridQueryContext(10) + hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -601,10 +612,11 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQueryWithTerm = new 
HybridQuery( List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), - new HybridQueryContext(10) + hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -729,13 +741,14 @@ public void testReduceAndRescore_whenMatchedDocsAndRescoreContextPresent_thenSuc QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQueryWithTerm = new HybridQuery( List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) ), - new HybridQueryContext(10) + hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -847,6 +860,7 @@ public void testRescoreWithConcurrentSegmentSearch_whenMatchedDocsAndRescore_the QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQueryWithTerm = new HybridQuery( List.of( @@ -854,7 +868,7 @@ public void testRescoreWithConcurrentSegmentSearch_whenMatchedDocsAndRescore_the QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY3).toQuery(mockQueryShardContext) ), - new HybridQueryContext(10) + 
hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -992,13 +1006,14 @@ public void testReduceAndRescore_whenRescorerThrowsException_thenFail() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQueryWithTerm = new HybridQuery( List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) ), - new HybridQueryContext(10) + hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -1064,8 +1079,10 @@ public void testCreateCollectorManager_whenFromAreEqualToZeroAndPaginationDepthI TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + // pagination_depth=10 - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -1098,7 +1115,9 @@ public void 
testScrollWithHybridQuery_thenFail() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), new HybridQueryContext(10)); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); diff --git a/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java b/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java index 0a733550a..ab882b388 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java @@ -43,6 +43,7 @@ public void testIsHybridQueryCheck_whenQueryIsHybridQueryInstance_thenSuccess() QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery query = new HybridQuery( List.of( @@ -55,7 +56,7 @@ public void testIsHybridQueryCheck_whenQueryIsHybridQueryInstance_thenSuccess() .toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext) ), - new HybridQueryContext(10) + hybridQueryContext ); SearchContext searchContext = 
mock(SearchContext.class); From d39a1b887312850fc164c8a958ee015aa6043884 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Mon, 6 Jan 2025 10:32:35 -0800 Subject: [PATCH 05/13] rename hybrid query extraction method Signed-off-by: Varun Jain --- .../neuralsearch/search/query/HybridCollectorManager.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 1d75877ad..66cf784d4 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -483,7 +483,7 @@ private ReduceableSearchResult reduceSearchResults(final List Date: Wed, 8 Jan 2025 13:48:24 -0800 Subject: [PATCH 06/13] Refactoring to optimize extractHybridQuery method calls Signed-off-by: Varun Jain --- .../neuralsearch/query/HybridQueryBuilder.java | 8 +++++--- .../search/query/HybridCollectorManager.java | 8 +++++--- .../search/query/HybridQueryPhaseSearcher.java | 12 ++---------- .../neuralsearch/util/HybridQueryUtil.java | 8 +++++++- 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index dbff0cde8..8b706c1ca 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -59,7 +59,7 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder toQueries(Collection queryBuilders, Quer } private static void validatePaginationDepth(final int paginationDepth, final QueryShardContext queryShardContext) { - if (paginationDepth < LOWER_BOUND_OF_PAGINATION_DEPTH) { - throw new IllegalArgumentException(String.format(Locale.ROOT, 
"pagination_depth should be greater than 0")); + if (paginationDepth <= LOWER_BOUND_OF_PAGINATION_DEPTH) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "pagination_depth should be greater than %s", LOWER_BOUND_OF_PAGINATION_DEPTH) + ); } // compare pagination depth with OpenSearch setting index.max_result_window // see https://opensearch.org/docs/latest/install-and-configure/configuring-opensearch/index-settings/ diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 66cf784d4..03290f421 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -54,6 +54,7 @@ import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocDelimiterElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createSortFieldsForDelimiterResults; +import static org.opensearch.neuralsearch.util.HybridQueryUtil.isHybridQueryWrappedInBooleanQuery; /** * Collector manager based on HybridTopScoreDocCollector that allows users to parallelize counting the number of hits. 
@@ -483,7 +484,7 @@ private ReduceableSearchResult reduceSearchResults(final List clauseQuery.getQuery() instanceof HybridQuery); - } - @VisibleForTesting protected Query extractHybridQuery(final SearchContext searchContext, final Query query) { - if ((hasAliasFilter(query, searchContext) || hasNestedFieldOrNestedDocs(query, searchContext)) - && isWrappedHybridQuery(query) - && !((BooleanQuery) query).clauses().isEmpty()) { + if (isHybridQueryWrappedInBooleanQuery(searchContext, query)) { List booleanClauses = ((BooleanQuery) query).clauses(); if (!(booleanClauses.get(0).getQuery() instanceof HybridQuery)) { throw new IllegalStateException("cannot process hybrid query due to incorrect structure of top level query"); diff --git a/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java b/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java index d19985f5c..77c2e79cb 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java @@ -56,7 +56,7 @@ public static boolean hasNestedFieldOrNestedDocs(final Query query, final Search return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query); } - private static boolean isWrappedHybridQuery(final Query query) { + public static boolean isWrappedHybridQuery(final Query query) { return query instanceof BooleanQuery && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery); } @@ -64,4 +64,10 @@ private static boolean isWrappedHybridQuery(final Query query) { public static boolean hasAliasFilter(final Query query, final SearchContext searchContext) { return Objects.nonNull(searchContext.aliasFilter()); } + + public static boolean isHybridQueryWrappedInBooleanQuery(final SearchContext searchContext, final Query query) { + return ((hasAliasFilter(query, searchContext) || 
hasNestedFieldOrNestedDocs(query, searchContext)) + && isWrappedHybridQuery(query) + && !((BooleanQuery) query).clauses().isEmpty()); + } } From 962fbaa9ab4674abb64496cac12ba1949e9fd03b Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Wed, 8 Jan 2025 14:04:59 -0800 Subject: [PATCH 07/13] Changes in tests to adapt with builder pattern in querybuilder Signed-off-by: Varun Jain --- .../opensearch/neuralsearch/bwc/HybridSearchIT.java | 1 - .../neuralsearch/query/HybridQueryBuilderTests.java | 12 ++++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java index 95924112c..44671ed4a 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java @@ -20,7 +20,6 @@ import static org.opensearch.neuralsearch.util.TestUtils.getModelId; import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.QueryBuilders; import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index 6383f1ed0..0dfd2071c 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -236,11 +236,13 @@ public void testDoToQuery_whenPaginationDepthIsGreaterThan10000_thenBuildSuccess IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); - NeuralQueryBuilder neuralQueryBuilder = new 
NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) + NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder() + .fieldName(VECTOR_FIELD_NAME) .queryText(QUERY_TEXT) .modelId(MODEL_ID) .k(K) - .vectorSupplier(TEST_VECTOR_SUPPLIER); + .vectorSupplier(TEST_VECTOR_SUPPLIER) + .build(); queryBuilder.add(neuralQueryBuilder); IllegalArgumentException exception = expectThrows( @@ -270,11 +272,13 @@ public void testDoToQuery_whenPaginationDepthIsLessThanZero_thenBuildSuccessfull IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) + NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder() + .fieldName(VECTOR_FIELD_NAME) .queryText(QUERY_TEXT) .modelId(MODEL_ID) .k(K) - .vectorSupplier(TEST_VECTOR_SUPPLIER); + .vectorSupplier(TEST_VECTOR_SUPPLIER) + .build(); queryBuilder.add(neuralQueryBuilder); IllegalArgumentException exception = expectThrows( From d2a101d7f98dd0dbf1bd8840623481ba5860b72c Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Wed, 8 Jan 2025 14:39:07 -0800 Subject: [PATCH 08/13] Add mapper service mock in tests Signed-off-by: Varun Jain --- .../HybridAggregationProcessorTests.java | 5 +++ .../query/HybridCollectorManagerTests.java | 32 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java index 0b896e33d..acbc2148c 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java @@ -15,6 +15,7 @@ import org.opensearch.action.OriginalIndices; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import 
org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; @@ -74,6 +75,8 @@ public void testCollectorManager_whenHybridQueryAndNotConcurrentSearch_thenSucce HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); + MapperService mapperService = mock(MapperService.class); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexSearcher.getIndexReader()).thenReturn(indexReader); @@ -135,6 +138,8 @@ public void testCollectorManager_whenHybridQueryAndConcurrentSearch_thenSuccessf HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexSearcher.getIndexReader()).thenReturn(indexReader); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index e18ab6b12..e0d95f24e 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -36,6 +36,7 @@ import org.apache.lucene.tests.analysis.MockAnalyzer; import org.opensearch.common.lucene.search.FilteredCollector; import 
org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.BoostingQueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -91,12 +92,14 @@ public class HybridCollectorManagerTests extends OpenSearchQueryTestCase { public void testNewCollector_whenNotConcurrentSearch_thenSuccessful() { SearchContext searchContext = mock(SearchContext.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MapperService mapperService = createMapperService(); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); + when(searchContext.mapperService()).thenReturn(mapperService); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); @@ -129,6 +132,8 @@ public void testNewCollector_whenConcurrentSearch_thenSuccessful() { HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); @@ -161,6 +166,8 @@ public void 
testPostFilter_whenNotConcurrentSearch_thenSuccessful() { HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world"); ParsedQuery parsedQuery = new ParsedQuery(postFilterQuery.toQuery(mockQueryShardContext)); searchContext.parsedQuery(parsedQuery); @@ -206,6 +213,9 @@ public void testPostFilter_whenConcurrentSearch_thenSuccessful() { HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); + QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world"); Query pfQuery = postFilterQuery.toQuery(mockQueryShardContext); ParsedQuery parsedQuery = new ParsedQuery(pfQuery); @@ -249,6 +259,9 @@ public void testReduce_whenMatchedDocs_thenSuccessful() { when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build(); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); + HybridQuery hybridQueryWithTerm = new HybridQuery( List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)), hybridQueryContext @@ -357,6 +370,8 @@ public void testNewCollector_whenNotConcurrentSearchAndSortingIsApplied_thenSucc HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); 
when(searchContext.query()).thenReturn(hybridQuery); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexSearcher.getIndexReader()).thenReturn(indexReader); @@ -396,6 +411,8 @@ public void testNewCollector_whenNotConcurrentSearchAndSortingAndSearchAfterAreA HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexSearcher.getIndexReader()).thenReturn(indexReader); @@ -437,6 +454,9 @@ public void testReduce_whenMatchedDocsAndSortingIsApplied_thenSuccessful() { when(searchContext.searcher()).thenReturn(indexSearcher); when(searchContext.size()).thenReturn(1); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); @@ -530,6 +550,8 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexReader.numDocs()).thenReturn(3); @@ -619,6 +641,8 @@ public void 
testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexReader.numDocs()).thenReturn(2); @@ -751,6 +775,8 @@ public void testReduceAndRescore_whenMatchedDocsAndRescoreContextPresent_thenSuc hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexReader.numDocs()).thenReturn(3); @@ -871,6 +897,8 @@ public void testRescoreWithConcurrentSegmentSearch_whenMatchedDocsAndRescore_the hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexReader.numDocs()).thenReturn(2); @@ -1016,6 +1044,8 @@ public void testReduceAndRescore_whenRescorerThrowsException_thenFail() { hybridQueryContext ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexReader.numDocs()).thenReturn(3); @@ -1085,6 +1115,8 @@ public void testCreateCollectorManager_whenFromAreEqualToZeroAndPaginationDepthI HybridQuery hybridQuery = new 
HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), hybridQueryContext); when(searchContext.query()).thenReturn(hybridQuery); + MapperService mapperService = createMapperService(); + when(searchContext.mapperService()).thenReturn(mapperService); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); when(indexSearcher.getIndexReader()).thenReturn(indexReader); From 1c0bdebc32ea042c88e931f139a08c51e0676137 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Wed, 8 Jan 2025 16:04:10 -0800 Subject: [PATCH 09/13] Fix error message of index.max_result_window setting Signed-off-by: Varun Jain --- .../neuralsearch/query/HybridQueryBuilder.java | 2 +- .../neuralsearch/util/HybridQueryUtil.java | 12 +++++++++--- .../neuralsearch/query/HybridQueryBuilderTests.java | 5 ++++- .../opensearch/neuralsearch/query/HybridQueryIT.java | 2 +- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index 8b706c1ca..653caf07c 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -337,7 +337,7 @@ private static void validatePaginationDepth(final int paginationDepth, final Que int maxResultWindowIndexSetting = queryShardContext.getIndexSettings().getMaxResultWindow(); if (paginationDepth > maxResultWindowIndexSetting) { throw new IllegalArgumentException( - String.format(Locale.ROOT, "pagination_depth should be less than index.max_result_window setting") + String.format(Locale.ROOT, "pagination_depth should be less than or equal to index.max_result_window setting") ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java b/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java index 77c2e79cb..e8794131f 100644 
--- a/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java @@ -20,6 +20,9 @@ @NoArgsConstructor(access = AccessLevel.PRIVATE) public class HybridQueryUtil { + /** + * This method validates whether the query object is an instance of hybrid query + */ public static boolean isHybridQuery(final Query query, final SearchContext searchContext) { if (query instanceof HybridQuery) { return true; @@ -52,19 +55,22 @@ public static boolean isHybridQuery(final Query query, final SearchContext searc return false; } - public static boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { + private static boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query); } - public static boolean isWrappedHybridQuery(final Query query) { + private static boolean isWrappedHybridQuery(final Query query) { return query instanceof BooleanQuery && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery); } - public static boolean hasAliasFilter(final Query query, final SearchContext searchContext) { + private static boolean hasAliasFilter(final Query query, final SearchContext searchContext) { return Objects.nonNull(searchContext.aliasFilter()); } + /** + * This method checks whether hybrid query is wrapped under boolean query object + */ public static boolean isHybridQueryWrappedInBooleanQuery(final SearchContext searchContext, final Query query) { return ((hasAliasFilter(query, searchContext) || hasNestedFieldOrNestedDocs(query, searchContext)) && isWrappedHybridQuery(query) diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index 
0dfd2071c..2385300a6 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -249,7 +249,10 @@ public void testDoToQuery_whenPaginationDepthIsGreaterThan10000_thenBuildSuccess IllegalArgumentException.class, () -> queryBuilder.doToQuery(mockQueryShardContext) ); - assertThat(exception.getMessage(), containsString("pagination_depth should be less than index.max_result_window setting")); + assertThat( + exception.getMessage(), + containsString("pagination_depth should be less than or equal to index.max_result_window setting") + ); } @SneakyThrows diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 89fe58808..c3087a1e4 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -968,7 +968,7 @@ public void testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail() { org.hamcrest.MatcherAssert.assertThat( responseException.getMessage(), - allOf(containsString("pagination_depth should be less than index.max_result_window setting")) + allOf(containsString("pagination_depth should be less than or equal to index.max_result_window setting")) ); } finally { wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); From eae90dc5f9d74a18f5daad68cc5a80c93a34def6 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Wed, 8 Jan 2025 16:15:09 -0800 Subject: [PATCH 10/13] Fix error message of index.max_result_window setting Signed-off-by: Varun Jain --- .../opensearch/neuralsearch/query/HybridQueryBuilder.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index 
653caf07c..a8486d227 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -22,6 +22,7 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexSettings; import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; @@ -337,7 +338,11 @@ private static void validatePaginationDepth(final int paginationDepth, final Que int maxResultWindowIndexSetting = queryShardContext.getIndexSettings().getMaxResultWindow(); if (paginationDepth > maxResultWindowIndexSetting) { throw new IllegalArgumentException( - String.format(Locale.ROOT, "pagination_depth should be less than or equal to index.max_result_window setting") + String.format( + Locale.ROOT, + "pagination_depth should be less than or equal to %s setting", + IndexSettings.MAX_RESULT_WINDOW_SETTING.getKey() + ) ); } } From 5e0477290c96a1f27163f315f2ffebf059850a88 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Fri, 10 Jan 2025 08:39:40 -0800 Subject: [PATCH 11/13] Fixing validation condition for lower bound Signed-off-by: Varun Jain --- CHANGELOG.md | 2 +- .../neuralsearch/query/HybridQueryBuilder.java | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bcfb30322..6203e6e88 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x) ### Features -- Pagination in Hybrid query ([#963](https://github.com/opensearch-project/neural-search/pull/963)) +- Pagination in Hybrid query 
([#1048](https://github.com/opensearch-project/neural-search/pull/1048)) ### Enhancements - Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970)) - Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013)) diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index 562363aff..de4680dc1 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -294,7 +294,12 @@ protected boolean doEquals(HybridQueryBuilder obj) { */ @Override protected int doHashCode() { - return Objects.hash(queries, paginationDepth); + List hashValues = new ArrayList<>(); + hashValues.add(queries); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + hashValues.add(paginationDepth); + } + return Objects.hash(hashValues.toArray()); } /** @@ -326,7 +331,7 @@ private Collection toQueries(Collection queryBuilders, Quer } private static void validatePaginationDepth(final int paginationDepth, final QueryShardContext queryShardContext) { - if (paginationDepth <= LOWER_BOUND_OF_PAGINATION_DEPTH) { + if (paginationDepth < LOWER_BOUND_OF_PAGINATION_DEPTH) { throw new IllegalArgumentException( String.format(Locale.ROOT, "pagination_depth should be greater than %s", LOWER_BOUND_OF_PAGINATION_DEPTH) ); From 33e328e03ddcd14d522a36e74688e90bd15efff5 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Fri, 10 Jan 2025 08:59:51 -0800 Subject: [PATCH 12/13] fix tests Signed-off-by: Varun Jain --- .../opensearch/neuralsearch/query/HybridQueryBuilderTests.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index 
c22d174ec..89f3459e1 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -595,6 +595,7 @@ public void testStreams_whenWrittingToStream_thenSuccessful() { } public void testHashAndEquals_whenSameOrIdenticalObject_thenReturnEqual() { + setUpClusterService(); HybridQueryBuilder hybridQueryBuilderBaseline = new HybridQueryBuilder(); hybridQueryBuilderBaseline.add( NeuralQueryBuilder.builder() From 372ecbff387349c59262ab76a6a275a4e27cc33d Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Fri, 10 Jan 2025 15:36:21 -0800 Subject: [PATCH 13/13] Removing version check from doEquals and doHashCode method Signed-off-by: Varun Jain --- .../neuralsearch/query/HybridQueryBuilder.java | 11 ++--------- .../neuralsearch/query/HybridQueryBuilderTests.java | 1 - 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index de4680dc1..55f6b4a96 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -282,9 +282,7 @@ protected boolean doEquals(HybridQueryBuilder obj) { } EqualsBuilder equalsBuilder = new EqualsBuilder(); equalsBuilder.append(queries, obj.queries); - if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { - equalsBuilder.append(paginationDepth, obj.paginationDepth); - } + equalsBuilder.append(paginationDepth, obj.paginationDepth); return equalsBuilder.isEquals(); } @@ -294,12 +292,7 @@ protected boolean doEquals(HybridQueryBuilder obj) { */ @Override protected int doHashCode() { - List hashValues = new ArrayList<>(); - hashValues.add(queries); - if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { - hashValues.add(paginationDepth); - } - return 
Objects.hash(hashValues.toArray()); + return Objects.hash(queries, paginationDepth); } /** diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index 89f3459e1..c22d174ec 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -595,7 +595,6 @@ public void testStreams_whenWrittingToStream_thenSuccessful() { } public void testHashAndEquals_whenSameOrIdenticalObject_thenReturnEqual() { - setUpClusterService(); HybridQueryBuilder hybridQueryBuilderBaseline = new HybridQueryBuilder(); hybridQueryBuilderBaseline.add( NeuralQueryBuilder.builder()