Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pagination in Hybrid query #1048

Merged
merged 19 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x)
### Features
- Pagination in Hybrid query ([#963](https://github.com/opensearch-project/neural-search/pull/963))
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
### Enhancements
- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970))
- Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@

import org.opensearch.index.query.MatchQueryBuilder;

import static org.opensearch.knn.index.query.KNNQueryBuilder.EXPAND_NESTED_FIELD;
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion;
import static org.opensearch.neuralsearch.util.TestUtils.getModelId;
import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public final class MinClusterVersionUtil {

private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0;
private static final Version MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY = Version.V_2_19_0;

// Note this minimal version will act as a override
private static final Map<String, Version> MINIMAL_VERSION_NEURAL = ImmutableMap.<String, Version>builder()
Expand All @@ -38,6 +39,10 @@ public static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH);
}

public static boolean isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY);
}

public static boolean isClusterOnOrAfterMinReqVersion(String key) {
Version version;
if (MINIMAL_VERSION_NEURAL.containsKey(key)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ private <Result extends SearchPhaseResult> void prepareAndExecuteNormalizationWo
.combinationTechnique(combinationTechnique)
.explain(explain)
.pipelineProcessingContext(requestContextOptional.orElse(null))
.searchPhaseContext(searchPhaseContext)
.build();
normalizationWorkflow.execute(request);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.FieldDoc;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.combination.CombineScoresDto;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
Expand Down Expand Up @@ -64,25 +65,30 @@ public void execute(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
final ScoreCombinationTechnique combinationTechnique,
final SearchPhaseContext searchPhaseContext
) {
NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder()
.querySearchResults(querySearchResults)
.fetchSearchResultOptional(fetchSearchResultOptional)
.normalizationTechnique(normalizationTechnique)
.combinationTechnique(combinationTechnique)
.explain(false)
.searchPhaseContext(searchPhaseContext)
.build();
execute(request);
}

public void execute(final NormalizationProcessorWorkflowExecuteRequest request) {
List<QuerySearchResult> querySearchResults = request.getQuerySearchResults();
Optional<FetchSearchResult> fetchSearchResultOptional = request.getFetchSearchResultOptional();

// save original state
List<Integer> unprocessedDocIds = unprocessedDocIds(request.getQuerySearchResults());
List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);

// pre-process data
log.debug("Pre-process query results");
List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(request.getQuerySearchResults());
List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);

explain(request, queryTopDocs);

Expand All @@ -93,8 +99,9 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request)
CombineScoresDto combineScoresDTO = CombineScoresDto.builder()
.queryTopDocs(queryTopDocs)
.scoreCombinationTechnique(request.getCombinationTechnique())
.querySearchResults(request.getQuerySearchResults())
.sort(evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs))
.querySearchResults(querySearchResults)
.sort(evaluateSortCriteria(querySearchResults, queryTopDocs))
.fromValueForSingleShard(getFromValueIfSingleShard(request))
.build();

// combine
Expand All @@ -103,8 +110,26 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request)

// post-process data
log.debug("Post-process query results after score normalization and combination");
updateOriginalQueryResults(combineScoresDTO);
updateOriginalFetchResults(request.getQuerySearchResults(), request.getFetchSearchResultOptional(), unprocessedDocIds);
updateOriginalQueryResults(combineScoresDTO, fetchSearchResultOptional.isPresent());
updateOriginalFetchResults(
querySearchResults,
fetchSearchResultOptional,
unprocessedDocIds,
combineScoresDTO.getFromValueForSingleShard()
);
}

/**
* Get value of from parameter when there is a single shard
* and fetch phase is already executed
* Ref https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/SearchService.java#L715
*/
private int getFromValueIfSingleShard(final NormalizationProcessorWorkflowExecuteRequest request) {
final SearchPhaseContext searchPhaseContext = request.getSearchPhaseContext();
if (searchPhaseContext.getNumShards() > 1 || request.fetchSearchResultOptional.isEmpty()) {
return -1;
}
return searchPhaseContext.getRequest().source().from();
}

/**
Expand Down Expand Up @@ -173,19 +198,33 @@ private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> quer
return queryTopDocs;
}

private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO) {
private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO, final boolean isFetchPhaseExecuted) {
final List<QuerySearchResult> querySearchResults = combineScoresDTO.getQuerySearchResults();
final List<CompoundTopDocs> queryTopDocs = getCompoundTopDocs(combineScoresDTO, querySearchResults);
final Sort sort = combineScoresDTO.getSort();
int totalScoreDocsCount = 0;
for (int index = 0; index < querySearchResults.size(); index++) {
QuerySearchResult querySearchResult = querySearchResults.get(index);
CompoundTopDocs updatedTopDocs = queryTopDocs.get(index);
totalScoreDocsCount += updatedTopDocs.getScoreDocs().size();
TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(
buildTopDocs(updatedTopDocs, sort),
maxScoreForShard(updatedTopDocs, sort != null)
);
// Fetch Phase had ran before the normalization phase, therefore update the from value in result of each shard.
// This will ensure the trimming of the search results.
if (isFetchPhaseExecuted) {
querySearchResult.from(combineScoresDTO.getFromValueForSingleShard());
}
querySearchResult.topDocs(updatedTopDocsAndMaxScore, querySearchResult.sortValueFormats());
}

final int from = querySearchResults.get(0).from();
if (from > totalScoreDocsCount) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results")
);
}
}

private List<CompoundTopDocs> getCompoundTopDocs(CombineScoresDto combineScoresDTO, List<QuerySearchResult> querySearchResults) {
Expand Down Expand Up @@ -244,7 +283,8 @@ private TopDocs buildTopDocs(CompoundTopDocs updatedTopDocs, Sort sort) {
private void updateOriginalFetchResults(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final List<Integer> docIds
final List<Integer> docIds,
final int fromValueForSingleShard
) {
if (fetchSearchResultOptional.isEmpty()) {
return;
Expand Down Expand Up @@ -276,14 +316,21 @@ private void updateOriginalFetchResults(

QuerySearchResult querySearchResult = querySearchResults.get(0);
TopDocs topDocs = querySearchResult.topDocs().topDocs;
// Scenario to handle when calculating the trimmed length of updated search hits
// When normalization process runs after fetch phase, then search hits already fetched. Therefore, use the from value sent in the
// search request to calculate the effective length of updated search hits array.
int trimmedLengthOfSearchHits = topDocs.scoreDocs.length - fromValueForSingleShard;
// iterate over the normalized/combined scores, that solves (1) and (3)
SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> {
SearchHit[] updatedSearchHitArray = new SearchHit[trimmedLengthOfSearchHits];
for (int i = 0; i < trimmedLengthOfSearchHits; i++) {
// Read topDocs after the desired from length
ScoreDoc scoreDoc = topDocs.scoreDocs[i + fromValueForSingleShard];
// get fetched hit content by doc_id
SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc);
// update score to normalized/combined value (3)
searchHit.score(scoreDoc.score);
return searchHit;
}).toArray(SearchHit[]::new);
updatedSearchHitArray[i] = searchHit;
}
SearchHits updatedSearchHits = new SearchHits(
updatedSearchHitArray,
querySearchResult.getTotalHits(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
Expand All @@ -29,4 +30,5 @@ public class NormalizationProcessorWorkflowExecuteRequest {
final ScoreCombinationTechnique combinationTechnique;
boolean explain;
final PipelineProcessingContext pipelineProcessingContext;
final SearchPhaseContext searchPhaseContext;
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@ public class CombineScoresDto {
private List<QuerySearchResult> querySearchResults;
@Nullable
private Sort sort;
private int fromValueForSingleShard;
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,10 @@ public class ScoreCombiner {
public void combineScores(final CombineScoresDto combineScoresDTO) {
// iterate over results from each shard. Every CompoundTopDocs object has results from
// multiple sub queries, doc ids may repeat for each sub query results
ScoreCombinationTechnique scoreCombinationTechnique = combineScoresDTO.getScoreCombinationTechnique();
Sort sort = combineScoresDTO.getSort();
combineScoresDTO.getQueryTopDocs()
.forEach(
compoundQueryTopDocs -> combineShardScores(
combineScoresDTO.getScoreCombinationTechnique(),
compoundQueryTopDocs,
combineScoresDTO.getSort()
)
);
.forEach(compoundQueryTopDocs -> combineShardScores(scoreCombinationTechnique, compoundQueryTopDocs, sort));
}

private void combineShardScores(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,21 @@
public final class HybridQuery extends Query implements Iterable<Query> {

private final List<Query> subQueries;
private final HybridQueryContext queryContext;

/**
* Create new instance of hybrid query object based on collection of sub queries and filter query
* @param subQueries collection of queries that are executed individually and contribute to a final list of combined scores
* @param filterQueries list of filters that will be applied to each sub query. Each filter from the list is added as bool "filter" clause. If this is null sub queries will be executed as is
*/
public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries) {
public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries, final HybridQueryContext hybridQueryContext) {
Objects.requireNonNull(subQueries, "collection of queries must not be null");
if (subQueries.isEmpty()) {
throw new IllegalArgumentException("collection of queries must not be empty");
}
if (hybridQueryContext.getPaginationDepth() == 0) {
vibrantvarun marked this conversation as resolved.
Show resolved Hide resolved
throw new IllegalArgumentException("pagination_depth must not be zero");
}
if (Objects.isNull(filterQueries) || filterQueries.isEmpty()) {
this.subQueries = new ArrayList<>(subQueries);
} else {
Expand All @@ -57,10 +61,11 @@ public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQ
}
this.subQueries = modifiedSubQueries;
}
this.queryContext = hybridQueryContext;
}

public HybridQuery(final Collection<Query> subQueries) {
this(subQueries, List.of());
public HybridQuery(final Collection<Query> subQueries, final HybridQueryContext hybridQueryContext) {
this(subQueries, List.of(), hybridQueryContext);
}

/**
Expand Down Expand Up @@ -128,7 +133,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
return super.rewrite(indexSearcher);
}
final List<Query> rewrittenSubQueries = manager.getQueriesAfterRewrite(collectors);
return new HybridQuery(rewrittenSubQueries);
return new HybridQuery(rewrittenSubQueries, queryContext);
}

private Void rewriteQuery(Query query, HybridQueryExecutorCollector<IndexSearcher, Map.Entry<Query, Boolean>> collector) {
Expand Down Expand Up @@ -190,6 +195,10 @@ public Collection<Query> getSubQueries() {
return Collections.unmodifiableCollection(subQueries);
}

public HybridQueryContext getQueryContext() {
return queryContext;
}

/**
* Create the Weight used to score this query
*
Expand Down
Loading
Loading